Source code for maxent_grpo.grpo

"""Baseline GRPO training entrypoint.

Provides a thin wrapper around the baseline training pipeline. Callers may pass
the script, training, and model arguments explicitly; when they are not
provided, TRL arguments are parsed from the command line via
``maxent_grpo.cli.parse_grpo_args`` (there is no Hydra fallback). Exposed for
``python -m maxent_grpo.grpo`` and for programmatic invocation inside
orchestration code.
"""

from __future__ import annotations

import importlib
import sys
from pathlib import Path

# When this file is executed as a plain script (e.g. ``accelerate launch
# src/maxent_grpo/grpo.py``) it has no parent package, so the directory that
# contains ``maxent_grpo`` must be placed at the front of ``sys.path``.
if not __package__:
    _source_root = str(Path(__file__).resolve().parents[1])
    # Drop an existing entry first so the root is moved to the front instead of
    # being added a second time.
    try:
        sys.path.remove(_source_root)
    except ValueError:
        pass
    sys.path.insert(0, _source_root)

from typing import TYPE_CHECKING, Any, Callable, Optional, cast

from maxent_grpo.cli._test_hooks import ensure_usercustomize_loaded

from maxent_grpo.config import GRPOConfig, GRPOScriptArguments

if TYPE_CHECKING:
    from trl import ModelConfig

try:  # Best-effort to expose CLI helpers when available.
    from maxent_grpo.cli import parse_grpo_args
except (
    ImportError,
    ModuleNotFoundError,
    AttributeError,
):  # pragma: no cover - optional deps may be absent
    parse_grpo_args = None

# Public API of this module: the programmatic entrypoint and its CLI wrapper.
__all__ = ["cli", "main"]


def _resolve_cli_attr(attr_name: str) -> Any:
    """Best-effort import helper for optional CLI attributes."""

    try:
        cli_mod = importlib.import_module("maxent_grpo.cli")
    except (ImportError, ModuleNotFoundError, AttributeError):
        cli_mod = None
    if cli_mod is not None:
        attr = getattr(cli_mod, attr_name, None)
        if attr is not None:
            return attr
    try:
        pkg = importlib.import_module("maxent_grpo")
    except (ImportError, ModuleNotFoundError, AttributeError):
        return None
    cli_pkg = getattr(pkg, "cli", None)
    if cli_pkg is None:
        return None
    return getattr(cli_pkg, attr_name, None)


def main(
    script_args: Optional[GRPOScriptArguments] = None,
    training_args: Optional[GRPOConfig] = None,
    model_args: "Optional[ModelConfig]" = None,
) -> Any:
    """Run the baseline GRPO trainer, parsing CLI arguments when needed.

    When any of the three argument objects is omitted, TRL-style arguments are
    parsed from the command line using the module-level ``parse_grpo_args``
    (or, if that is not callable, ``parse_grpo_args`` looked up on
    ``maxent_grpo.cli``). Arguments the caller supplied explicitly are kept;
    only the missing ones are taken from the parsed result. There is no
    fallback to Hydra.

    :param script_args: Dataset/reward script arguments parsed via TRL or provided directly.
    :param training_args: GRPO training configuration produced by TRL.
    :param model_args: Model configuration passed to TRL/transformers trainers.
    :returns: Training result from
        :func:`maxent_grpo.training.baseline.run_baseline_training`.
    :raises RuntimeError: If arguments are missing and no CLI parser is available.
    :raises Exception: Propagates parser or training pipeline exceptions.
    """
    ensure_usercustomize_loaded()
    if script_args is None or training_args is None or model_args is None:
        # Prefer the module-level name so monkeypatched parsers (used in tests)
        # win; otherwise resolve it from ``maxent_grpo.cli``. Deliberately no
        # fallback to Hydra.
        parser_fn = parse_grpo_args
        if not callable(parser_fn):
            parser_fn = _resolve_cli_attr("parse_grpo_args")
        if not callable(parser_fn):
            raise RuntimeError(
                "No CLI parser available. Ensure TRL is installed and "
                "maxent_grpo.cli.parse_grpo_args is importable."
            )
        # String form keeps the cast purely static (no runtime type evaluation).
        parser = cast(
            "Callable[[], tuple[GRPOScriptArguments, GRPOConfig, ModelConfig]]",
            parser_fn,
        )
        parsed_script, parsed_training, parsed_model = parser()
        # Keep arguments the caller passed explicitly; only fill in the gaps.
        if script_args is None:
            script_args = parsed_script
        if training_args is None:
            training_args = parsed_training
        if model_args is None:
            model_args = parsed_model

    # Reuse an already-loaded (possibly monkeypatched) baseline module from
    # ``sys.modules`` when it exposes the entrypoint; otherwise import normally.
    baseline_mod = sys.modules.get("maxent_grpo.training.baseline")
    run_baseline_training = getattr(baseline_mod, "run_baseline_training", None)
    if run_baseline_training is None:
        from maxent_grpo.training.baseline import run_baseline_training
    return run_baseline_training(script_args, training_args, model_args)


def cli() -> None:
    """Invoke :func:`main` with no explicit arguments (CLI style).

    All training arguments are parsed from the command line by :func:`main`;
    this wrapper does not delegate to Hydra, and the training result returned
    by :func:`main` is discarded.

    :returns: ``None``. Side effect: runs baseline GRPO training.
    :raises RuntimeError: Propagated from :func:`main` when no CLI parser is available.
    """
    main()


if __name__ == "__main__":  # pragma: no cover - CLI entrypoint
    # Reached via ``python -m maxent_grpo.grpo`` or by running this file directly.
    cli()