# Source code for maxent_grpo.grpo
"""Baseline GRPO training entrypoint.
Provides a thin wrapper around the training pipeline. When explicit arguments
are not provided, TRL arguments are parsed from the CLI via
``maxent_grpo.cli.parse_grpo_args``; this module does not itself fall back to
the Hydra-based CLI. Exposed for ``python -m maxent_grpo.grpo`` and for
programmatic invocation inside orchestration code.
"""
from __future__ import annotations
import importlib
import sys
from pathlib import Path
# Support direct execution of this file (e.g. ``accelerate launch
# src/maxent_grpo/grpo.py``): with no package context, the directory two
# levels above this file is moved to the front of sys.path so that the
# ``maxent_grpo`` imports below can resolve.
if not __package__:  # covers both ``None`` and the empty string
    project_src = Path(__file__).resolve().parents[1]
    project_src_str = str(project_src)
    try:
        # Drop the first existing entry so the path is not listed twice.
        sys.path.remove(project_src_str)
    except ValueError:
        # Not present yet; nothing to remove.
        pass
    sys.path.insert(0, project_src_str)
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from maxent_grpo.cli._test_hooks import ensure_usercustomize_loaded
from maxent_grpo.config import GRPOConfig, GRPOScriptArguments
if TYPE_CHECKING:
from trl import ModelConfig
# The CLI helpers are optional: when they cannot be imported, the name is
# still bound (to ``None``) so later code can test it with ``callable``.
try:
    from maxent_grpo.cli import parse_grpo_args
except (
    # ModuleNotFoundError is a subclass of ImportError, so it is covered here.
    ImportError,
    AttributeError,
):  # pragma: no cover - optional deps may be absent
    parse_grpo_args = None
__all__ = ["cli", "main"]
def _resolve_cli_attr(attr_name: str) -> Any:
"""Best-effort import helper for optional CLI attributes."""
try:
cli_mod = importlib.import_module("maxent_grpo.cli")
except (ImportError, ModuleNotFoundError, AttributeError):
cli_mod = None
if cli_mod is not None:
attr = getattr(cli_mod, attr_name, None)
if attr is not None:
return attr
try:
pkg = importlib.import_module("maxent_grpo")
except (ImportError, ModuleNotFoundError, AttributeError):
return None
cli_pkg = getattr(pkg, "cli", None)
if cli_pkg is None:
return None
return getattr(cli_pkg, attr_name, None)
# [docs]
def main(
    script_args: Optional[GRPOScriptArguments] = None,
    training_args: Optional[GRPOConfig] = None,
    model_args: "Optional[ModelConfig]" = None,
) -> Any:
    """Run baseline GRPO training, parsing arguments from the CLI if needed.

    ``ensure_usercustomize_loaded`` is invoked first. If any of the three
    arguments is ``None``, *all three* are replaced by the tuple returned from
    ``parse_grpo_args`` (the module-level import when callable, otherwise the
    attribute resolved through :func:`_resolve_cli_attr`). This function does
    not fall back to Hydra when no parser is available.

    :param script_args: Script arguments; parsed from the CLI when omitted.
    :param training_args: GRPO training configuration; parsed from the CLI
        when omitted.
    :param model_args: Model configuration; parsed from the CLI when omitted.
    :returns: Whatever ``run_baseline_training`` returns.
    :raises RuntimeError: If arguments are missing and no callable
        ``parse_grpo_args`` can be found.
    :raises Exception: Propagates parser or training pipeline exceptions.
    """
    ensure_usercustomize_loaded()

    args_incomplete = any(
        value is None for value in (script_args, training_args, model_args)
    )
    if args_incomplete:
        # Prefer the module-level attribute (tests may monkeypatch it); only
        # then try a lazy lookup. Do not fall back to Hydra.
        candidate = parse_grpo_args
        if not callable(candidate):
            candidate = _resolve_cli_attr("parse_grpo_args")
        if not callable(candidate):
            raise RuntimeError(
                "No CLI parser available. Ensure TRL is installed and "
                "maxent_grpo.cli.parse_grpo_args is importable."
            )
        parse = cast(
            Callable[
                [],
                tuple[GRPOScriptArguments, GRPOConfig, "ModelConfig"],
            ],
            candidate,
        )
        script_args, training_args, model_args = parse()

    # Reuse an already-loaded baseline module when it exposes the runner
    # (e.g. one placed in sys.modules by a test); otherwise import it now.
    loaded_baseline = sys.modules.get("maxent_grpo.training.baseline")
    if loaded_baseline and hasattr(loaded_baseline, "run_baseline_training"):
        runner = loaded_baseline.run_baseline_training
    else:
        from maxent_grpo.training.baseline import run_baseline_training as runner
    return runner(script_args, training_args, model_args)
# [docs]
def cli() -> None:
    """Console-style entrypoint: call :func:`main` with no arguments.

    Because no arguments are passed, :func:`main` obtains them from the CLI
    parser. The value returned by :func:`main` is discarded.

    :returns: ``None``.
    """
    main()
if __name__ == "__main__":  # pragma: no cover - CLI entrypoint
    # Only runs when the file is executed as a script, not when imported.
    cli()