Source code for maxent_grpo.training.data

"""Dataset loading helpers for the training pipeline."""

from __future__ import annotations

import hashlib
import json
import os
import random
import shutil
import logging
from typing import Any, Callable, List, Mapping, Optional, Tuple, cast

try:  # pragma: no cover - optional dependency guard for stripped test envs
    from datasets import load_from_disk as _hf_load_from_disk
except (
    ImportError,
    ModuleNotFoundError,
    AttributeError,
    OSError,
    RuntimeError,
    ValueError,
):  # pragma: no cover - fallback when datasets is unavailable
    _hf_load_from_disk = None

from maxent_grpo.core.data import get_dataset, load_dataset_split
from maxent_grpo.training.runtime.prompts import (
    _prompt_char_limit_from_tokens,
    _to_prompt,
)

LOG = logging.getLogger(__name__)


def _stable_hash(value: Any) -> str:
    """Return a short, stable hash for arbitrarily typed values."""

    blob = json.dumps(value, sort_keys=True, default=str).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:16]


def _dataset_cache_path(
    script_args: Any,
    training_args: Any,
    *,
    prompt_column: str,
    solution_column: str,
    train_split: str,
    char_limit: int,
) -> str:
    """Resolve an on-disk cache directory for the processed dataset."""

    base_dir = (
        getattr(training_args, "dataset_cache_dir", None)
        or getattr(script_args, "dataset_cache_dir", None)
        or os.environ.get("MAXENT_DATASET_CACHE_DIR")
    )
    if not base_dir:
        output_dir = getattr(training_args, "output_dir", None)
        if not output_dir:
            output_dir = os.getcwd()
        base_dir = os.path.join(output_dir, "dataset_cache")
    dataset_id = {
        "name": getattr(script_args, "dataset_name", None),
        "mixture": getattr(script_args, "dataset_mixture", None),
        "revision": getattr(script_args, "dataset_revision", None),
        "train_split": train_split,
        "prompt_column": prompt_column,
        "solution_column": solution_column,
        "char_limit": int(char_limit or 0),
        "prompt_template": getattr(training_args, "prompt_template", None),
        "system_prompt_hash": _stable_hash(
            getattr(training_args, "system_prompt", "") or ""
        ),
        "max_prompt_length": int(getattr(training_args, "max_prompt_length", 0) or 0),
    }
    cache_key = _stable_hash(dataset_id)
    return os.path.join(base_dir, cache_key)


def _format_eval_row(
    example: Mapping[str, Any],
    *,
    prompt_column: str,
    solution_column: str,
    tokenizer: Any,
    prompt_template: Optional[str],
    system_prompt: Optional[str],
    char_limit: int,
) -> dict:
    example_map = dict(example)
    prompt_col = prompt_column
    if (
        prompt_col not in example_map
        and prompt_col == "problem"
        and "prompt" in example_map
    ):
        prompt_col = "prompt"
    out = _to_prompt(
        example_map,
        tokenizer,
        prompt_col,
        system_prompt,
        char_limit=char_limit,
        prompt_template=prompt_template,
    )
    answer = example_map.get(solution_column, out.get("answer", ""))
    out["answer"] = "" if answer is None else str(answer)
    return out


def _normalize_eval_rows(rows: Any) -> Optional[List[dict]]:
    if rows is None:
        return None
    if isinstance(rows, list):
        return [dict(row) for row in rows]
    normalized: List[dict] = []
    try:
        iterator = iter(rows)
    except TypeError:
        return normalized
    for row in iterator:
        normalized.append(dict(row))
    return normalized


def _ensure_split_mapping(dataset: Any) -> Mapping[str, Any]:
    """Coerce a dataset-like object into a split->dataset mapping."""

    if isinstance(dataset, dict):
        return dataset
    if hasattr(dataset, "keys") and hasattr(dataset, "__getitem__"):
        return cast(Mapping[str, Any], dataset)
    return {"train": dataset}


def _sample_eval_rows(rows: List[dict], keep: int, seed: int) -> List[dict]:
    if keep <= 0 or keep >= len(rows):
        return rows
    indices = list(range(len(rows)))
    random.Random(int(seed or 0)).shuffle(indices)
    return [rows[idx] for idx in indices[:keep]]


[docs] def load_datasets( script_args: Any, training_args: Any, tokenizer: Any, *, accelerator: Any | None = None, ) -> Tuple[Any, list]: """Load train/eval datasets and return ``(train_dataset, eval_rows)``. The helper handles prompt/answer column normalization, optional dataset caching, and prompt truncation. Evaluation rows are normalized into a list of dictionaries with ``prompt``/``answer`` keys. :param script_args: Script arguments describing dataset identifiers and prompt/answer columns. :param training_args: Training configuration providing prompt limits and cache settings. :param tokenizer: Tokenizer used to format prompts. :param accelerator: Optional Accelerator used for process synchronization. :returns: Tuple of the processed training dataset and a list of evaluation rows (possibly empty when eval is disabled). :rtype: tuple[Any, list] :raises ValueError: If required dataset columns are missing. """ pc = getattr(script_args, "dataset_prompt_column", "problem") sc = getattr(script_args, "dataset_solution_column", "answer") char_limit = _prompt_char_limit_from_tokens( getattr(training_args, "max_prompt_length", 0) ) prompt_template = getattr(training_args, "prompt_template", None) def _map_fn(ex: Mapping[str, Any]) -> dict: ex_map = dict(ex) prompt_col = pc if prompt_col not in ex_map and prompt_col == "problem" and "prompt" in ex_map: prompt_col = "prompt" out = _to_prompt( ex_map, tokenizer, prompt_col, getattr(training_args, "system_prompt", None), char_limit=char_limit, prompt_template=prompt_template, ) answer = ex_map.get(sc, out.get("answer", "")) out["answer"] = "" if answer is None else str(answer) return out def _is_valid(row: dict) -> bool: """Drop rows missing a prompt or answer to keep collate happy.""" prompt = row.get("prompt") answer = row.get("answer") return ( isinstance(prompt, str) and isinstance(answer, str) and prompt.strip() != "" and answer.strip() != "" ) train_split = getattr(script_args, "dataset_train_split", "train") test_split = getattr(script_args, "dataset_test_split", None) dataset = None cache_path = None is_main_process = bool(getattr(accelerator, "is_main_process", True)) def _wait_for_everyone() -> None: if accelerator is None: return wait_for_all: Optional[Callable[[], None]] = getattr( accelerator, "wait_for_everyone", None ) if wait_for_all is not None: wait_for_all() cache_path = _dataset_cache_path( script_args, training_args, prompt_column=pc, solution_column=sc, train_split=train_split, char_limit=char_limit, ) cache_enabled = bool(cache_path and _hf_load_from_disk) if cache_enabled and os.path.isdir(cache_path) and _hf_load_from_disk: dataset = _hf_load_from_disk(cache_path) def _build_hf_dataset() -> Any: raw_ds = cast(Any, get_dataset(script_args)) # Remove all original columns so the DataLoader only sees prompt/answer. remove_cols = getattr(raw_ds, "column_names", None) if isinstance(remove_cols, dict): # DatasetDict: merge column names across splits. merged = set() for cols in remove_cols.values(): merged.update(cols) remove_cols = list(merged) if isinstance(remove_cols, list): remove_cols = [c for c in remove_cols if c not in {"prompt", "answer"}] else: remove_cols = None mapped = raw_ds.map(_map_fn, remove_columns=remove_cols, desc="Map") if hasattr(mapped, "filter"): mapped = mapped.filter(_is_valid, desc="Filter") return mapped if ( dataset is None and accelerator is not None and not is_main_process and cache_enabled ): _wait_for_everyone() if os.path.isdir(cache_path) and _hf_load_from_disk: dataset = _hf_load_from_disk(cache_path) else: # pragma: no cover - indicates main process failed before caching raise RuntimeError( f"Expected dataset cache at {cache_path} but it was not created." ) if dataset is None: raw_ds = cast(Any, get_dataset(script_args)) if hasattr(raw_ds, "map"): if accelerator is None or is_main_process: dataset = _build_hf_dataset() if cache_enabled and hasattr(dataset, "save_to_disk"): os.makedirs(os.path.dirname(cache_path), exist_ok=True) tmp_dir = f"{cache_path}.tmp" if os.path.isdir(tmp_dir): shutil.rmtree(tmp_dir) dataset.save_to_disk(tmp_dir) os.replace(tmp_dir, cache_path) _wait_for_everyone() if dataset is None: if cache_enabled and os.path.isdir(cache_path) and _hf_load_from_disk: dataset = _hf_load_from_disk(cache_path) else: dataset = _build_hf_dataset() elif isinstance(raw_ds, dict): dataset = {} for split, rows in raw_ds.items(): mapped = [_map_fn(cast(Mapping[str, Any], row)) for row in rows] dataset[split] = [row for row in mapped if _is_valid(row)] else: mapped_rows = [_map_fn(cast(Mapping[str, Any], row)) for row in raw_ds] dataset = {"train": [row for row in mapped_rows if _is_valid(row)]} dataset_map = _ensure_split_mapping(dataset) if test_split is None: test_split = "validation" if "validation" in dataset_map else None if test_split is None and "test" in dataset_map: test_split = "test" if train_split not in dataset_map: train_split = "train" if "train" in dataset_map else list(dataset_map.keys())[0] train_ds = dataset_map[train_split] eval_rows = _normalize_eval_rows(getattr(script_args, "eval_rows", None)) if eval_rows is None and getattr(training_args, "do_eval", False): eval_dataset_name = getattr(script_args, "eval_dataset_name", None) eval_prompt_col = getattr(script_args, "eval_dataset_prompt_column", None) or pc eval_solution_col = ( getattr(script_args, "eval_dataset_solution_column", None) or sc ) if eval_dataset_name: eval_split = getattr(script_args, "eval_dataset_split", "validation") eval_ds_raw = load_dataset_split( eval_dataset_name, getattr(script_args, "eval_dataset_config", None), eval_split, ) eval_rows = [ _format_eval_row( cast(Mapping[str, Any], row), prompt_column=eval_prompt_col, solution_column=eval_solution_col, tokenizer=tokenizer, prompt_template=prompt_template, system_prompt=getattr(training_args, "system_prompt", None), char_limit=char_limit, ) for row in eval_ds_raw ] elif test_split is not None and test_split in dataset_map: full_eval = dataset_map[test_split] try: n_total = len(full_eval) except (TypeError, AttributeError): n_total = 0 n_keep = min(1000, max(1, int(0.1 * n_total))) if n_total > 0 else 0 shuffle_fn = getattr(full_eval, "shuffle", None) if callable(shuffle_fn) and n_keep > 0: try: shuffled = shuffle_fn(seed=training_args.seed) select_fn = getattr(shuffled, "select", None) if callable(select_fn): subset = select_fn(range(n_keep)) eval_rows = _normalize_eval_rows(subset) else: eval_rows = _normalize_eval_rows(shuffled) except (AttributeError, RuntimeError, TypeError, ValueError): eval_rows = None if eval_rows is None: eval_rows = _normalize_eval_rows(full_eval) if eval_rows and n_keep > 0: eval_rows = _sample_eval_rows( eval_rows, n_keep, getattr(training_args, "seed", 0) ) if eval_rows is None: eval_rows = [] return train_ds, eval_rows
[docs] def resolve_dataloader_kwargs(training_args: Any) -> dict: """Return ``torch.utils.data.DataLoader`` kwargs derived from training_args. :param training_args: Training config or namespace containing DataLoader knobs such as ``dataloader_num_workers`` and ``dataloader_pin_memory``. :returns: Dictionary of keyword arguments suitable for ``DataLoader``. :rtype: dict """ kwargs: dict[str, Any] = {} num_workers = int(getattr(training_args, "dataloader_num_workers", 0) or 0) if num_workers < 0: num_workers = 0 kwargs["num_workers"] = num_workers pin_memory = getattr(training_args, "dataloader_pin_memory", None) if pin_memory is not None: kwargs["pin_memory"] = bool(pin_memory) prefetch = getattr(training_args, "dataloader_prefetch_factor", None) if prefetch is not None: if num_workers > 0: try: kwargs["prefetch_factor"] = int(prefetch) except (TypeError, ValueError): LOG.warning( "Invalid dataloader_prefetch_factor=%s; ignoring.", prefetch ) else: LOG.debug("Ignoring dataloader_prefetch_factor because num_workers=0.") persistent = getattr(training_args, "dataloader_persistent_workers", None) if persistent is not None: if num_workers > 0: kwargs["persistent_workers"] = bool(persistent) else: LOG.debug("Ignoring dataloader_persistent_workers because num_workers=0.") return kwargs
__all__ = ["load_datasets", "resolve_dataloader_kwargs"]