# Source code for maxent_grpo.core.data

"""
Dataset loading utilities with support for mixtures.

This module wraps Hugging Face ``datasets.load_dataset`` to handle either a
single dataset (``dataset_name``) or a declarative mixture with optional column
selection, subsampling via weights, shuffling, and an optional train/test split.
It returns a mapping compatible with downstream training/evaluation code (a
``datasets.DatasetDict`` when the library is installed, or a lightweight stub
during tests).

The import of ``datasets`` is guarded so this module can be imported in
environments where the library is unavailable; tests monkey-patch the missing
symbols when needed.

License
Copyright 2025 Liv d'Aliberti

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the specific language governing permissions and
limitations under the License.
"""

from __future__ import annotations

import logging
import os
import random
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, cast
from collections.abc import Mapping

# Static type checkers always see the real ``datasets`` symbols. At runtime the
# import is attempted and, if it fails for any of the listed reasons, the module
# falls back to stand-ins that raise only when dataset functionality is used.
if TYPE_CHECKING:  # pragma: no cover - type hints only
    from collections.abc import Sequence
    from typing import Callable

    import datasets
    from datasets import Dataset, DatasetDict
    from maxent_grpo.config import ScriptArguments

    concatenate_datasets: Callable[[Sequence[Dataset]], Dataset]
else:
    try:
        import datasets
        from datasets import Dataset, DatasetDict

        # Looked up via getattr so that an installed ``datasets`` package that
        # lacks ``concatenate_datasets`` is routed to the same fallback below.
        _concatenate = getattr(datasets, "concatenate_datasets", None)
        if _concatenate is None:
            raise AttributeError(
                "The 'datasets' package is missing concatenate_datasets."
            )
        concatenate_datasets = _concatenate
    except (
        ModuleNotFoundError,
        AttributeError,
        ImportError,
        OSError,
        RuntimeError,
        ValueError,
    ):  # pragma: no cover - optional dependency

        class _DatasetsStub:
            """Lightweight stub so imports succeed when ``datasets`` is absent."""

            def load_dataset(self, *_args: Any, **_kwargs: Any) -> Any:
                # Defined explicitly so the attribute exists on the stub (e.g.
                # for tests that patch it); calling it unpatched still fails.
                raise ModuleNotFoundError(
                    "The 'datasets' package is required for dataset loading. "
                    "Install with `pip install datasets`."
                )

            def __getattr__(self, _name: str) -> Any:
                # Any other attribute access (``load_from_disk`` included)
                # raises immediately with the same installation hint.
                raise ModuleNotFoundError(
                    "The 'datasets' package is required for dataset loading. "
                    "Install with `pip install datasets`."
                )

        datasets = _DatasetsStub()
        # ``Dataset`` degrades to ``Any`` and ``DatasetDict`` to the builtin
        # ``dict``, so code in this module that calls ``DatasetDict(...)``
        # produces a plain dictionary when the library is unavailable.
        Dataset = Any
        DatasetDict = dict

        def concatenate_datasets(*_datasets: Any, **_kwargs: Any) -> Any:
            """Fallback that always raises because ``datasets`` is not usable."""
            raise ModuleNotFoundError(
                "The 'datasets' package is required for dataset concatenation. "
                "Install with `pip install datasets`."
            )


# Re-export ``ScriptArguments`` so callers and tests can import it from this
# module. If the config package cannot be imported for any of the listed
# reasons, the name is bound to ``Any`` so it is still defined at runtime.
try:  # pragma: no cover - convenience re-export for callers/tests
    from maxent_grpo.config.dataset import ScriptArguments
except (
    ImportError,
    ModuleNotFoundError,
    AttributeError,
    OSError,
    RuntimeError,
    ValueError,
):  # pragma: no cover - optional dependency
    ScriptArguments = Any  # type: ignore[assignment]

# pyarrow is optional: when it can be imported, its ``ArrowInvalid`` error is
# added to the exceptions tolerated while converting plain lists with
# ``Dataset.from_list``; otherwise only the builtin exceptions are tolerated.
try:  # pragma: no cover - optional pyarrow exception for from_list conversions
    from pyarrow.lib import ArrowInvalid as _ArrowInvalid
except (ImportError, ModuleNotFoundError, AttributeError, OSError, RuntimeError):
    _ArrowInvalid = None

_FROM_LIST_EXCEPTIONS = (TypeError, ValueError, RuntimeError) + (
    () if _ArrowInvalid is None else (_ArrowInvalid,)
)


logger = logging.getLogger(__name__)

# Defaults for the retry loop around ``datasets.load_dataset``; each one can be
# overridden through the environment variable named in parentheses.
# Number of retries after the first failed attempt (MAXENT_HF_DATASET_RETRIES).
_DEFAULT_HF_RETRIES = 6
# Base delay in seconds, doubled on each attempt (MAXENT_HF_DATASET_RETRY_SLEEP).
_DEFAULT_HF_RETRY_SLEEP = 2.0
# Cap in seconds on the delay before jitter (MAXENT_HF_DATASET_RETRY_MAX_SLEEP).
_DEFAULT_HF_RETRY_MAX_SLEEP = 60.0


def _is_saved_hf_dataset_dir(candidate: Any) -> bool:
    """Report whether ``candidate`` is a directory holding a saved-dataset marker.

    A directory qualifies when it contains a regular file named
    ``dataset_dict.json`` or ``state.json``.

    :param candidate: Value to inspect; anything ``pathlib.Path`` rejects with
        ``TypeError`` (for example ``None``) is treated as "not a directory".
    :type candidate: Any
    :returns: ``True`` only for an existing directory with one of the markers.
    :rtype: bool
    """

    try:
        root = Path(candidate)
    except TypeError:
        return False
    if not root.is_dir():
        return False
    return any(
        (root / marker).is_file() for marker in ("dataset_dict.json", "state.json")
    )


def _dataset_load_retry_settings() -> tuple[int, float, float]:
    """Resolve retry settings for dataset loading from the environment.

    ``MAXENT_HF_DATASET_RETRIES``, ``MAXENT_HF_DATASET_RETRY_SLEEP`` and
    ``MAXENT_HF_DATASET_RETRY_MAX_SLEEP`` are read; a variable that is unset or
    cannot be parsed falls back to the matching module default.

    :returns: ``(retries, sleep_s, max_sleep_s)`` where ``retries`` and
        ``sleep_s`` are clamped to be non-negative and ``max_sleep_s`` is never
        smaller than ``sleep_s``.
    :rtype: tuple[int, float, float]
    """

    def _env_number(name: str, fallback: Any, parse: Any) -> Any:
        # ``parse`` is ``int`` or ``float``; an unparseable value is treated
        # like an unset variable.
        text = os.environ.get(name)
        if text is None:
            return fallback
        try:
            return parse(text)
        except (TypeError, ValueError):
            return fallback

    retry_count = _env_number("MAXENT_HF_DATASET_RETRIES", _DEFAULT_HF_RETRIES, int)
    base_sleep = _env_number(
        "MAXENT_HF_DATASET_RETRY_SLEEP", _DEFAULT_HF_RETRY_SLEEP, float
    )
    sleep_cap = _env_number(
        "MAXENT_HF_DATASET_RETRY_MAX_SLEEP", _DEFAULT_HF_RETRY_MAX_SLEEP, float
    )

    retry_count = max(0, retry_count)
    base_sleep = max(0.0, base_sleep)
    return retry_count, base_sleep, max(base_sleep, sleep_cap)


# HTTP status codes that mark a failure as transient when they appear as a
# standalone word in an exception message.
_RETRYABLE_STATUS_WORDS = frozenset({"429", "500", "502", "503", "504"})
# Punctuation that commonly hugs a status code in error text ("503:", "(502).").
_STATUS_WORD_PUNCTUATION = "()[]{}<>:;,.!?'\""


def _should_retry_dataset_load(exc: BaseException) -> bool:
    """Decide whether a failed dataset load looks transient and worth retrying.

    Two signals are checked in order:

    1. ``exc.response.status_code`` when present: ``429`` or any ``5xx`` value.
    2. The exception message: one of ``429``, ``500``, ``502``, ``503`` or
       ``504`` appearing as a whitespace-delimited word. Surrounding
       punctuation is stripped first, so ``"HTTP Error 503: ..."`` and
       ``"(502)."`` are recognised, while numbers embedded in a longer token
       such as ``"1.503"`` are not.

    :param exc: Exception raised by the load attempt.
    :type exc: BaseException
    :returns: ``True`` when the failure should be retried.
    :rtype: bool
    """
    status = getattr(getattr(exc, "response", None), "status_code", None)
    if isinstance(status, int) and (status == 429 or 500 <= status <= 599):
        return True
    return any(
        word.strip(_STATUS_WORD_PUNCTUATION) in _RETRYABLE_STATUS_WORDS
        for word in str(exc).split()
    )


def _load_dataset_with_retries(*args: Any, **kwargs: Any) -> Any:
    """Load a dataset from a saved directory or through ``datasets.load_dataset``.

    When the first positional argument is a saved-dataset directory (as judged
    by ``_is_saved_hf_dataset_dir``), the data is read with
    ``datasets.load_from_disk``. On that path the remaining positional
    arguments and every keyword other than ``split`` are ignored. A ``split``
    is used to index the result only when the result is a mapping of splits;
    a directory that stores a single dataset is returned unchanged, because
    indexing it with a split name would address a column rather than a split.

    Otherwise ``datasets.load_dataset(*args, **kwargs)`` is called. Failures
    that ``_should_retry_dataset_load`` classifies as transient are retried
    with exponential backoff plus jitter, using the settings from
    ``_dataset_load_retry_settings``.

    :param args: Positional arguments forwarded to ``datasets.load_dataset``;
        the first one may instead be a path to a saved dataset directory.
    :param kwargs: Keyword arguments forwarded to ``datasets.load_dataset``.
    :returns: Whatever the underlying loader returns.
    :raises ConnectionError, OSError, RuntimeError, ValueError: Re-raised from
        ``datasets.load_dataset`` when the failure is not retryable or the
        retry budget is exhausted.
    """
    if args and _is_saved_hf_dataset_dir(args[0]):
        split = kwargs.pop("split", None)
        dataset = datasets.load_from_disk(str(args[0]))  # type: ignore[attr-defined]
        if split is None:
            return dataset
        if not isinstance(dataset, Mapping):
            # A directory saved from a single dataset has no splits to pick from.
            logger.debug(
                "Ignoring split=%r for single saved dataset at %s", split, args[0]
            )
            return dataset
        return dataset[split]

    retries, sleep_s, max_sleep_s = _dataset_load_retry_settings()
    for attempt in range(retries + 1):
        try:
            return datasets.load_dataset(*args, **kwargs)  # nosec B615
        except (
            ConnectionError,
            OSError,
            RuntimeError,
            ValueError,
        ) as exc:  # pragma: no cover - network failures are environment dependent
            if attempt >= retries or not _should_retry_dataset_load(exc):
                raise
            delay = min(max_sleep_s, sleep_s * (2**attempt))
            # Small jitter to avoid all ranks retrying in lockstep if this runs multi-process.
            delay = delay * (0.85 + 0.3 * random.random())
            logger.warning(
                "datasets.load_dataset failed (attempt %d/%d); retrying in %.1fs | error=%s",
                attempt + 1,
                retries + 1,
                delay,
                exc,
            )
            time.sleep(delay)
    # Defensive only: every iteration above either returns or raises.
    raise RuntimeError("datasets.load_dataset failed unexpectedly without an exception")


def _to_dataset_dict(payload: Any) -> DatasetDict:
    """Wrap an in-memory payload in a ``DatasetDict``.

    A ``DatasetDict`` is returned unchanged. A plain ``dict`` has each of its
    values coerced and is wrapped key by key. Anything else is coerced and
    stored under the ``"train"`` key.

    Coercion only applies to lists: a list that is empty or made solely of
    mappings is passed to ``Dataset.from_list`` when that constructor exists,
    and the original list is kept if the conversion raises one of
    ``_FROM_LIST_EXCEPTIONS``. Every other value is left as is.

    :param payload: Dataset-like object supplied directly by the caller.
    :type payload: Any
    :returns: The payload wrapped as described above.
    :rtype: datasets.DatasetDict
    """

    def _coerce_split(candidate: Any) -> Any:
        builder = getattr(Dataset, "from_list", None)
        if not callable(builder):
            return candidate
        try:
            already_dataset = isinstance(candidate, Dataset)
        except TypeError:
            # ``Dataset`` may be a placeholder that ``isinstance`` rejects.
            return candidate
        if already_dataset or not isinstance(candidate, list):
            return candidate
        if candidate and any(not isinstance(row, Mapping) for row in candidate):
            return candidate
        try:
            return builder(candidate)
        except _FROM_LIST_EXCEPTIONS:
            return candidate

    if isinstance(payload, DatasetDict):
        return payload
    if not isinstance(payload, dict):
        return DatasetDict({"train": _coerce_split(payload)})
    return DatasetDict(
        {name: _coerce_split(split) for name, split in payload.items()}
    )


def get_dataset(args: ScriptArguments) -> DatasetDict:
    """Load a dataset or a weighted mixture and return a dictionary.

    Resolution order:

    1. If ``args`` carries a non-``None`` ``dataset`` attribute, that inline
       payload is wrapped with ``_to_dataset_dict`` and returned.
    2. If ``dataset_name`` is set and no mixture is configured, the dataset is
       loaded with ``_load_dataset_with_retries``.
    3. If ``dataset_mixture`` is configured, every entry is loaded, optionally
       reduced to the requested columns, optionally subsampled to
       ``int(len(ds) * weight)`` examples after a seeded shuffle, and the
       results are concatenated and shuffled again with the same seed. When
       ``test_split_size`` is set, the combined data is split with
       ``train_test_split``; otherwise it is returned under ``"train"``.

    NOTE(review): ``weight`` is presumably a fraction in ``[0, 1]``; nothing
    here validates it, so confirm against the mixture config.

    :param args: Parsed script arguments that describe either a single dataset
        (``dataset_name`` / ``dataset_config``) or a declarative mixture
        (``dataset_mixture``).
    :type args: maxent_grpo.config.ScriptArguments
    :returns: The loaded dataset dictionary; for a mixture, a ``train`` split
        plus whatever ``train_test_split`` adds when a test size was requested.
    :rtype: datasets.DatasetDict
    :raises ValueError: If neither a dataset name nor mixture is supplied, or
        when a mixture resolves to zero loaded datasets.
    """
    inline_ds = getattr(args, "dataset", None)
    if inline_ds is not None:
        return _to_dataset_dict(inline_ds)

    dataset_name = args.dataset_name
    mixture = args.dataset_mixture

    if dataset_name and not mixture:
        logger.info("Loading dataset: %s", dataset_name)
        return cast(
            DatasetDict,
            _load_dataset_with_retries(dataset_name, args.dataset_config),
        )
    if not mixture:
        raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")

    logger.info(
        "Creating dataset mixture with %d datasets",
        len(mixture.datasets),
    )
    seed: int = mixture.seed
    datasets_list: List[Dataset] = []
    for dataset_config in mixture.datasets:
        logger.info(
            "Loading dataset for mixture: %s (config: %s)",
            dataset_config.id,
            dataset_config.config,
        )
        ds = _load_dataset_with_retries(
            dataset_config.id,
            dataset_config.config,
            split=dataset_config.split,
        )
        if dataset_config.columns is not None:
            ds = ds.select_columns(dataset_config.columns)
        if dataset_config.weight is not None:
            # Seeded shuffle first so the kept prefix is a reproducible sample.
            ds = ds.shuffle(seed=seed).select(
                range(int(len(ds) * dataset_config.weight))
            )
            logger.info(
                "Subsampled dataset '%s' (config: %s) with weight=%s to %d examples",
                dataset_config.id,
                dataset_config.config,
                dataset_config.weight,
                len(ds),
            )
        datasets_list.append(ds)

    if not datasets_list:
        raise ValueError("No datasets were loaded from the mixture configuration")

    combined_dataset = concatenate_datasets(datasets_list)
    combined_dataset = combined_dataset.shuffle(seed=seed)
    logger.info("Created dataset mixture with %d examples", len(combined_dataset))

    if mixture.test_split_size is None:
        return DatasetDict({"train": combined_dataset})

    split_dataset = combined_dataset.train_test_split(
        test_size=mixture.test_split_size, seed=seed
    )
    logger.info(
        "Split dataset into train and test sets with test size: %s",
        mixture.test_split_size,
    )
    return split_dataset


def load_dataset_split(
    dataset_name: str,
    dataset_config: Optional[str] = None,
    split: str = "validation",
) -> Dataset:
    """Load a single split from a dataset independent of ``ScriptArguments``.

    This helper serves callers that do not have the full CLI argument object.
    It forwards to ``_load_dataset_with_retries`` and therefore shares its
    retry behaviour and its handling of saved dataset directories.

    :param dataset_name: Dataset identifier passed as the first argument to
        the loader (for example a repository ID on the Hugging Face Hub).
    :type dataset_name: str
    :param dataset_config: Optional dataset configuration name to disambiguate
        multiple configurations.
    :type dataset_config: str | None
    :param split: Split to load (for example ``"train"``, ``"validation"``,
        or ``"test"``).
    :type split: str
    :returns: The requested dataset split as returned by the loader.
    :rtype: datasets.Dataset
    :raises ValueError: If ``split`` is falsy; an explicit split is required so
        that an entire dataset is not loaded inadvertently.
    """
    if split:
        return cast(
            Dataset,
            _load_dataset_with_retries(dataset_name, dataset_config, split=split),
        )
    raise ValueError("split must be provided when loading an eval dataset")