"""
Minimal GRPO training entrypoint built on TRL.
This script wires up a standard ``trl.GRPOTrainer`` with:
* Dataset loading via ``core.data.get_dataset``.
* Simple chat‑templated prompts built from a dataset column.
* A small registry of reward functions from ``maxent_grpo.rewards.basic``.
It aims to be a clean baseline without experimental features (e.g., replay
buffers, schedulers, or custom trainers). Use together with
``maxent_grpo.config.ScriptArguments``/``maxent_grpo.config.GRPOConfig`` and TRL's ``TrlParser``.
Key functions
* ``_to_prompt``: Convert a dataset row to a chat prompt + gold answer.
* ``main``: Load data/model, construct ``GRPOTrainer``, train/eval, and handle
Hub push and model card creation.
License
Copyright 2025 Liv d'Aliberti
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# The module is import-light: heavy libs are imported lazily inside functions.
from __future__ import annotations
# pylint: disable=broad-exception-caught
import atexit
from contextlib import contextmanager, nullcontext
from collections.abc import MutableMapping as MutableMappingABC
from importlib import import_module
import json
import logging
import os
import sys
import threading
import time
from types import SimpleNamespace
from urllib.parse import urlparse
from typing import (
Dict,
Optional,
Any,
List,
MutableMapping,
Union,
Callable,
Iterator,
Protocol,
cast,
runtime_checkable,
TYPE_CHECKING,
)
from maxent_grpo.config import GRPOConfig, GRPOScriptArguments
from maxent_grpo.prompt_templates import (
normalize_prompt_template,
resolve_generation_stop_settings,
)
from maxent_grpo.training.rewards import load_reward_functions
from maxent_grpo.training.data import resolve_dataloader_kwargs
from maxent_grpo.rewards.basic import get_reward_funcs as _compat_get_reward_funcs
from maxent_grpo.core.data import get_dataset, load_dataset_split
from maxent_grpo.core.hub import ensure_hf_repo_ready
from maxent_grpo.core.model import get_model, get_tokenizer
from maxent_grpo.training.runtime import log_run_header, require_torch
from maxent_grpo.training.seed_paper_eval_callback import SeedPaperEvalCallback
from maxent_grpo.training.runtime.prompts import (
PROMPT_CHAR_LIMIT,
_prompt_char_limit_from_tokens,
_to_prompt,
)
from maxent_grpo.training.scoring_common import (
_coerce_optional_int,
_get_embedding_vocab_size,
)
from maxent_grpo.training.trl_trainer import (
build_custom_grpo_trainer,
wrap_trl_trainer,
)
from maxent_grpo.utils.deps_guard import ensure_real_dependencies
if TYPE_CHECKING:
from trl import ModelConfig
class _LazyModuleProxy:
"""Proxy that lazily imports a module on first attribute access."""
def __init__(self, module_name: str) -> None:
self._module_name = module_name
self._module: Any | None = None
def _load(self) -> Any:
if self._module is None:
self._module = import_module(self._module_name)
return self._module
def __getattr__(self, name: str) -> Any:
if name in self.__dict__:
return self.__dict__[name]
module = self._load()
value = getattr(module, name)
setattr(self, name, value)
return value
# Lazy handle for the ``transformers`` package: nothing is imported here; the
# real import happens on the first attribute access through the proxy.
transformers = _LazyModuleProxy("transformers")
def _maybe_align_model_tokenizer_vocab(model: Any, tokenizer: Any) -> None:
    """Grow the model's token embeddings to cover every tokenizer id.

    Nothing happens when the tokenizer length cannot be determined or is not
    positive, when the model's embedding vocabulary is already at least as
    large, or when the model exposes no callable ``resize_token_embeddings``.

    :param model: Model object; its ``config`` attribute (if any) is passed to
        the vocab-size helper.
    :param tokenizer: Tokenizer whose ``len()`` gives the target size.
    """
    try:
        wanted = _coerce_optional_int(len(tokenizer))
    except Exception:
        # Tokenizers without a usable length cannot drive a resize.
        return
    if not (isinstance(wanted, int) and wanted > 0):
        return

    current = _get_embedding_vocab_size(model, getattr(model, "config", None))
    if isinstance(current, int) and current >= wanted:
        return

    resize = getattr(model, "resize_token_embeddings", None)
    if callable(resize):
        LOG.info(
            "Resizing model token embeddings from %s to tokenizer size %s to align special tokens.",
            current,
            wanted,
        )
        resize(int(wanted))
def _guided_decoding_kwargs(guided_decoding: Any) -> Dict[str, Any]:
"""Extract vLLM guided-decoding fields across version variants."""
kwargs = dict(getattr(guided_decoding, "kwargs", {}) or {})
for name in (
"json",
"regex",
"choice",
"grammar",
"json_object",
"disable_fallback",
"disable_any_whitespace",
"disable_additional_properties",
"whitespace_pattern",
"structural_tag",
):
if name not in kwargs:
value = getattr(guided_decoding, name, None)
if value is not None:
kwargs[name] = value
return kwargs
def _patch_vllm_guided_decoding_compat() -> None:
"""Bridge TRL 0.18 guided decoding onto vLLM 0.16 structured outputs."""
try:
sampling_params_mod = import_module("vllm.sampling_params")
except Exception:
return
structured_outputs_cls = getattr(
sampling_params_mod, "StructuredOutputsParams", None
)
if structured_outputs_cls is None:
return
guided_cls = getattr(sampling_params_mod, "GuidedDecodingParams", None)
if guided_cls is None:
class GuidedDecodingParams:
def __init__(self, backend: Optional[str] = None, **kwargs: Any) -> None:
self.backend = backend
self.kwargs = dict(kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
guided_cls = GuidedDecodingParams
setattr(sampling_params_mod, "GuidedDecodingParams", guided_cls)
vllm_mod = import_module("vllm")
original_sampling_params = getattr(vllm_mod, "SamplingParams", None)
if original_sampling_params is None:
return
def _guided_to_structured_outputs(guided_decoding: Any) -> Any:
if guided_decoding is None:
return None
if isinstance(guided_decoding, structured_outputs_cls):
return guided_decoding
structured = structured_outputs_cls(
**_guided_decoding_kwargs(guided_decoding)
)
backend = getattr(guided_decoding, "backend", None)
if backend is not None:
try:
setattr(structured, "_backend", backend)
except Exception:
pass
return structured
def _compat_sampling_params(*args: Any, **kwargs: Any) -> Any:
if "guided_decoding" in kwargs and "structured_outputs" not in kwargs:
kwargs["structured_outputs"] = _guided_to_structured_outputs(
kwargs.pop("guided_decoding")
)
else:
kwargs.pop("guided_decoding", None)
return original_sampling_params(*args, **kwargs)
for module_name in ("trl.trainer.grpo_trainer", "trl.scripts.vllm_serve"):
module = sys.modules.get(module_name)
if module is None or getattr(module, "_maxent_guided_decoding_patch", False):
continue
setattr(module, "GuidedDecodingParams", guided_cls)
setattr(module, "SamplingParams", _compat_sampling_params)
setattr(module, "_maxent_guided_decoding_patch", True)
def _main_process_first(training_args: Any, desc: str) -> Any:
"""Return a process-ordering context when TrainingArguments provides one."""
main_process_first = getattr(training_args, "main_process_first", None)
if not callable(main_process_first):
return nullcontext()
try:
return main_process_first(local=True, desc=desc)
except TypeError:
try:
return main_process_first(desc=desc)
except TypeError:
return main_process_first()
@contextmanager
def _force_vllm_dtype(
    training_args: GRPOConfig,
    tokenizer: Optional[Any] = None,
) -> Iterator[None]:
    """Ensure TRL vLLM init respects local dtype and colocate engine overrides.

    While the context is active, ``trl.trainer.grpo_trainer.LLM`` is replaced
    by a factory that defaults ``dtype`` to the training precision and, when
    ``vllm_mode`` is ``"colocate"``, returns a local colocate wrapper instead
    of constructing ``vllm.LLM``. The original attribute is restored on exit.

    The context is a no-op when neither ``fp16`` nor ``bf16`` is set, when
    ``use_vllm`` is falsy, when TRL/vLLM cannot be imported, or when the TRL
    module has no ``LLM`` attribute.

    :param training_args: Config read for ``fp16``, ``bf16``, ``use_vllm``,
        ``vllm_mode`` and (in colocate mode) ``max_prompt_length``.
    :param tokenizer: Tokenizer handed to the colocate wrapper; used to encode
        completion text when no token ids come back with the outputs.
    """
    dtype_override = None
    if getattr(training_args, "fp16", False):
        dtype_override = "float16"
    elif getattr(training_args, "bf16", False):
        dtype_override = "bfloat16"
    # NOTE(review): with neither fp16 nor bf16 set this returns before the
    # colocate wrapper is considered, so colocate mode is only wrapped when a
    # dtype override exists — confirm that is intended.
    if not (dtype_override and getattr(training_args, "use_vllm", False)):
        yield
        return
    try:
        import trl.trainer.grpo_trainer as grpo_mod
        from vllm import LLM as _LLM
    except (ImportError, AttributeError, RuntimeError):
        # If vLLM/TRL isn't available, fall through without patching.
        yield
        return
    orig_llm = getattr(grpo_mod, "LLM", None)
    use_colocate_wrapper = (
        bool(getattr(training_args, "use_vllm", False))
        and str(getattr(training_args, "vllm_mode", "server") or "server").strip().lower()
        == "colocate"
    )

    class _TRLColocateContextProxy:
        """Small adapter exposing the attrs expected by the local colocate engine.

        Attributes not set on the proxy itself are looked up on the wrapped
        training arguments via ``__getattr__``.
        """

        def __init__(
            self,
            args_obj: GRPOConfig,
            tokenizer_obj: Optional[Any],
            model_id: Optional[str],
        ) -> None:
            self.training_args = args_obj
            self.tokenizer = tokenizer_obj
            self.generation_stats: Dict[str, Any] = {}
            # The TRL constructor path does not have access to the local generator fallback.
            self.vllm_disable_local_fallback = True
            self.prompt_char_limit = PROMPT_CHAR_LIMIT
            self.max_prompt_len = int(getattr(args_obj, "max_prompt_length", 0) or 0)
            # Model identifiers are only set when a non-empty string id is known;
            # otherwise lookups fall through to the training arguments.
            if isinstance(model_id, str) and model_id:
                self.model_name_or_path = model_id
                self.model_id = model_id
                self.vllm_model_id = model_id

        def __getattr__(self, name: str) -> Any:
            # Only reached when normal attribute lookup fails on the proxy.
            return getattr(self.training_args, name)

    @contextmanager
    def _temporary_attr_overrides(target: Any, overrides: Dict[str, Any]) -> Iterator[None]:
        """Set ``overrides`` as attributes on ``target`` and undo them on exit.

        Attributes that existed before (per ``hasattr``) are reset to their
        prior value; attributes that did not exist are deleted again.
        """
        previous: Dict[str, Any] = {}
        missing: List[str] = []
        for key, value in overrides.items():
            if hasattr(target, key):
                previous[key] = getattr(target, key)
            else:
                missing.append(key)
            setattr(target, key, value)
        try:
            yield
        finally:
            for key in overrides:
                if key in previous:
                    setattr(target, key, previous[key])
                else:
                    try:
                        delattr(target, key)
                    except AttributeError:
                        # Tolerate an already-removed attribute only for keys
                        # that were recorded as absent before the override.
                        if key not in missing:
                            raise

    class _TRLColocateSyncModel:
        """Expose TRL's expected ``load_weights`` surface via the local sync client."""

        def __init__(self, sync_client: Any) -> None:
            self._sync_client = sync_client

        def load_weights(self, named_params: List[tuple[str, Any]]) -> None:
            """Push each ``(name, param)`` pair through the sync client."""
            self._sync_client.ensure_ready()
            for name, param in named_params:
                self._sync_client.update_named_param(name, param)

    class _TRLColocateLLMWrapper:
        """Thin adapter so TRL can drive the local subprocess-isolated colocate engine."""

        def __init__(self, ctx: Any, tokenizer_obj: Optional[Any]) -> None:
            from maxent_grpo.training.rollout.vllm_colocate import ColocateVLLMEngine

            self._ctx = ctx
            self._tokenizer = tokenizer_obj
            # The second argument is a callable that always raises when invoked:
            # no local fallback generation exists on this path.
            self._engine = ColocateVLLMEngine(
                ctx,
                lambda *_args, **_kwargs: (_raise_no_local_fallback(), None),
            )
            self._sync_client = self._engine.sync_client()
            sync_model = _TRLColocateSyncModel(self._sync_client)
            # Nested namespaces place the sync model at
            # ``llm_engine.model_executor.driver_worker.model_runner.model``;
            # presumably the attribute path TRL walks to reach ``load_weights``
            # — TODO confirm against the TRL version in use.
            self.llm_engine = SimpleNamespace(
                model_executor=SimpleNamespace(
                    driver_worker=SimpleNamespace(
                        model_runner=SimpleNamespace(model=sync_model)
                    )
                )
            )

        def _encode_completion(self, text: str) -> List[int]:
            """Tokenize completion text (no special tokens) into integer ids.

            :raises RuntimeError: If no tokenizer is available or the tokenizer
                output carries no ``input_ids``.
            """
            if self._tokenizer is None:
                raise RuntimeError("Tokenizer unavailable for colocated vLLM completion encoding.")
            encoded = self._tokenizer(
                text,
                add_special_tokens=False,
                return_attention_mask=False,
            )
            # Accept both mapping-style and attribute-style tokenizer outputs.
            if isinstance(encoded, MutableMappingABC):
                token_ids = encoded.get("input_ids")
            else:
                token_ids = getattr(encoded, "input_ids", None)
            if token_ids is None:
                raise RuntimeError("Tokenizer did not return input_ids for colocated vLLM output.")
            return [int(token_id) for token_id in token_ids]

        def _token_ids_from_meta(self, entry: Any) -> Optional[List[int]]:
            """Extract token ids from a metadata entry, or ``None`` if unavailable.

            Looks at ``token_ids`` on the entry (mapping key or attribute) and,
            failing that, at ``raw_output["token_ids"]`` when ``raw_output`` is
            a mapping. Ids that cannot be converted to ``int`` yield ``None``.
            """
            raw_output = None
            token_ids = None
            if isinstance(entry, MutableMappingABC):
                token_ids = entry.get("token_ids")
                raw_output = entry.get("raw_output")
            else:
                token_ids = getattr(entry, "token_ids", None)
                raw_output = getattr(entry, "raw_output", None)
            if token_ids is None and isinstance(raw_output, MutableMappingABC):
                token_ids = raw_output.get("token_ids")
            if token_ids is None:
                return None
            try:
                return [int(token_id) for token_id in token_ids]
            except (TypeError, ValueError):
                return None

        def generate(self, prompts: List[str], sampling_params: Any = None, use_tqdm: bool = False) -> List[Any]:
            """Generate completions through the colocate engine.

            Sampling settings are read from ``sampling_params`` (falling back
            to the context's current values) and applied to the context only
            for the duration of the engine request.

            :returns: One namespace per prompt with an ``outputs`` list whose
                items each carry ``token_ids``.
            :raises RuntimeError: If the engine returns no outputs.
            """
            # Accepted for signature compatibility only; not used here.
            del use_tqdm
            request_count = max(1, int(getattr(sampling_params, "n", 1) or 1))
            guided = getattr(sampling_params, "guided_decoding", None)
            guided_regex = getattr(guided, "regex", None) if guided is not None else None
            stop_sequences = getattr(sampling_params, "stop", None)
            overrides = {
                "gen_temperature": float(
                    getattr(sampling_params, "temperature", getattr(self._ctx, "gen_temperature", 1.0))
                ),
                "gen_top_p": float(
                    getattr(sampling_params, "top_p", getattr(self._ctx, "gen_top_p", 1.0))
                ),
                "gen_top_k": int(
                    getattr(sampling_params, "top_k", getattr(self._ctx, "gen_top_k", -1))
                ),
                "gen_min_p": float(
                    getattr(sampling_params, "min_p", getattr(self._ctx, "gen_min_p", 0.0))
                ),
                "gen_repetition_penalty": float(
                    getattr(
                        sampling_params,
                        "repetition_penalty",
                        getattr(self._ctx, "gen_repetition_penalty", 1.0),
                    )
                ),
                "gen_frequency_penalty": float(
                    getattr(
                        sampling_params,
                        "frequency_penalty",
                        getattr(self._ctx, "gen_frequency_penalty", 0.0),
                    )
                ),
                "gen_presence_penalty": float(
                    getattr(
                        sampling_params,
                        "presence_penalty",
                        getattr(self._ctx, "gen_presence_penalty", 0.0),
                    )
                ),
                "max_completion_len": int(
                    getattr(
                        sampling_params,
                        "max_tokens",
                        getattr(self._ctx, "max_completion_len", 0),
                    )
                ),
                "gen_stop_sequences": stop_sequences,
                "vllm_guided_regex": guided_regex,
            }
            with _temporary_attr_overrides(self._ctx, overrides):
                grouped, grouped_meta = self._engine.request_batch(
                    list(prompts), request_count
                )
            if grouped is None:
                raise RuntimeError("Colocated vLLM returned no outputs.")
            # Pad the metadata so every output group has a (possibly empty) partner.
            meta_groups = list(grouped_meta or [])
            if len(meta_groups) < len(grouped):
                meta_groups.extend([] for _ in range(len(grouped) - len(meta_groups)))
            # Prefer token ids from metadata; when they are missing or empty,
            # re-tokenize the completion text instead.
            return [
                SimpleNamespace(
                    outputs=[
                        SimpleNamespace(token_ids=token_ids)
                        for idx, text in enumerate(prompt_outputs)
                        for token_ids in [
                            (
                                self._token_ids_from_meta(meta_group[idx])
                                if idx < len(meta_group)
                                else None
                            )
                            or self._encode_completion(text)
                        ]
                    ]
                )
                for prompt_outputs, meta_group in zip(grouped, meta_groups)
            ]

        def reset_prefix_cache(self) -> None:
            """Forward the prefix-cache reset to the sync client."""
            self._sync_client.reset_prefix_cache()

    def _raise_no_local_fallback() -> None:
        """Always raise: local fallback generation does not exist on this path."""
        raise RuntimeError("Local fallback generation is unavailable in TRL colocate mode.")

    def _patched_llm(*args: Any, **kwargs: Any) -> Any:
        """Replacement for ``grpo_mod.LLM`` used while the context is active."""
        # An explicit dtype from the caller is kept; otherwise use the override.
        kwargs.setdefault("dtype", dtype_override)
        if use_colocate_wrapper:
            # The model id comes from the ``model`` keyword or, failing that,
            # from a string first positional argument.
            model_id = kwargs.get("model")
            if not isinstance(model_id, str) and args:
                first_arg = args[0]
                if isinstance(first_arg, str):
                    model_id = first_arg
            ctx = _TRLColocateContextProxy(training_args, tokenizer, model_id)
            return _TRLColocateLLMWrapper(ctx, tokenizer)
        return _LLM(*args, **kwargs)

    # Patch only when TRL actually exposes ``LLM``; always restore on exit.
    if orig_llm is not None:
        grpo_mod.LLM = _patched_llm
    try:
        yield
    finally:
        if orig_llm is not None:
            grpo_mod.LLM = orig_llm
# Module-level logger shared by the helpers in this file.
LOG = logging.getLogger(__name__)
# Prefix placed in front of the JSON payload that ``_encode_vllm_batched_update``
# stores in the ``name`` field of a batched weight update.
_VLLM_BATCH_UPDATE_PREFIX = "__maxent_vllm_batch__:"
# Optional override hooks; both default to ``None``.
# NOTE(review): presumably assigned by tests or callers to swap the trainer
# class / PEFT-config factory — confirm against the code that reads them.
GRPOTrainerOverride: Optional[type] = None
get_peft_config_override: Optional[Any] = (
    None  # Expected to be a callable when set; annotated loosely as ``Any``.
)
# Names exported by ``from <module> import *``.
__all__ = [
    "GRPOTrainerOverride",
    "get_peft_config_override",
    "get_reward_funcs",
    "run_baseline_training",
    "_to_prompt",
    "PROMPT_CHAR_LIMIT",
]
# Backward compatibility hook for tests/legacy callers that monkeypatch reward resolution.
get_reward_funcs = _compat_get_reward_funcs
# Built-in evaluation benchmark presets, keyed by canonical preset name. Each
# entry gives the dataset id, dataset config, split, and the names of the
# prompt and solution columns.
_EVAL_DATASET_PRESETS: Dict[str, Dict[str, Optional[str]]] = {
    "math_500": {
        "dataset_name": "HuggingFaceH4/MATH-500",
        "dataset_config": "default",
        "split": "test",
        "prompt_column": "problem",
        "solution_column": "answer",
    },
    "aime24": {
        "dataset_name": "HuggingFaceH4/aime_2024",
        "dataset_config": "default",
        "split": "train",
        "prompt_column": "problem",
        "solution_column": "answer",
    },
    "aime25": {
        "dataset_name": "yentinglin/aime_2025",
        "dataset_config": "default",
        "split": "train",
        "prompt_column": "problem",
        "solution_column": "answer",
    },
    "amc": {
        "dataset_name": "AI-MO/aimo-validation-amc",
        "dataset_config": "default",
        "split": "train",
        "prompt_column": "problem",
        "solution_column": "answer",
    },
    "minerva": {
        "dataset_name": "math-ai/minervamath",
        "dataset_config": "default",
        "split": "test",
        "prompt_column": "question",
        "solution_column": "answer",
    },
    "olympiad_bench": {
        "dataset_name": "knoveleng/OlympiadBench",
        "dataset_config": "default",
        "split": "train",
        "prompt_column": "question",
        "solution_column": "answer",
    },
}
# Alternate spellings mapped onto canonical preset keys. Lookups normalise the
# spec first (strip, lower-case, "-" -> "_"), so keys here use underscores.
_EVAL_DATASET_ALIASES = {
    "math": "math_500",
    "aime_24": "aime24",
    "aime_2024": "aime24",
    "aime_25": "aime25",
    "aime_2025": "aime25",
    "olympiadbench": "olympiad_bench",
    "olympiad": "olympiad_bench",
    "oly": "olympiad_bench",
}
def _resolve_eval_dataset_preset(spec: str) -> Optional[Dict[str, Optional[str]]]:
    """Look up a built-in benchmark preset by name or alias.

    The spec is stripped, lower-cased and has ``-`` replaced by ``_`` before
    alias resolution.

    :param spec: Preset name or alias as written in an eval config.
    :returns: A fresh copy of the preset mapping, or ``None`` if unknown.
    """
    key = spec.strip().lower().replace("-", "_")
    key = _EVAL_DATASET_ALIASES.get(key, key)
    found = _EVAL_DATASET_PRESETS.get(key)
    # Hand out a copy so callers cannot mutate the shared preset table.
    return None if found is None else dict(found)
@runtime_checkable
class ChatTemplate(Protocol):
    """Structural type for anything that can render chat conversations.

    Being ``runtime_checkable``, ``isinstance(obj, ChatTemplate)`` succeeds for
    any object that has an ``apply_chat_template`` attribute.
    """

    def apply_chat_template(
        self,
        conversation: List[Dict[str, str]],
        tokenize: bool = True,
        add_generation_prompt: bool = True,
    ) -> Union[str, List[int]]:
        """Turn an ordered list of chat messages into templated output.

        :param conversation: Chat messages in order.
        :type conversation: list[dict[str, str]]
        :param tokenize: Return token IDs rather than text when true.
        :type tokenize: bool
        :param add_generation_prompt: Whether to append the assistant prefix.
        :type add_generation_prompt: bool
        :returns: Templated conversation, as text or as token IDs.
        :rtype: str | list[int]
        :raises NotImplementedError: Always, in this protocol stub.
        """
        raise NotImplementedError
def _collect_dataset_columns(dataset: Any) -> Dict[str, List[str]]:
"""Return per-split column names when discoverable."""
col_map: Dict[str, List[str]] = {}
cols = getattr(dataset, "column_names", None)
if isinstance(cols, dict):
for split, names in cols.items():
if isinstance(names, (list, tuple)) and names:
col_map[str(split)] = list(names)
return col_map
if isinstance(cols, (list, tuple)) and cols:
return {"all": list(cols)}
if isinstance(dataset, dict):
for split, split_ds in dataset.items():
split_cols = getattr(split_ds, "column_names", None)
if isinstance(split_cols, (list, tuple)) and split_cols:
col_map[str(split)] = list(split_cols)
continue
if isinstance(split_ds, list) and split_ds:
first = split_ds[0]
if isinstance(first, dict):
col_map[str(split)] = list(first.keys())
return col_map
def _get_column_names(dataset: Any) -> List[str]:
"""Return a best-effort list of column names for a dataset split."""
cols = getattr(dataset, "column_names", None)
if isinstance(cols, (list, tuple)):
return list(cols)
return []
def _validate_dataset_columns(
    dataset: Any,
    *,
    prompt_column: str,
    solution_column: str,
    label: str,
) -> None:
    """Raise early when a dataset lacks its prompt/solution columns.

    Validation is skipped when no columns can be inferred or when every split
    only has ``messages``/``message`` columns. A split that has ``messages``
    and neither required column is also ignored. If the only thing missing
    anywhere is the solution column, the situation is logged and accepted.

    :param dataset: Dataset-like object to inspect.
    :param prompt_column: Required prompt column name.
    :param solution_column: Expected solution column name.
    :param label: Human-readable dataset label used in logs and errors.
    :raises ValueError: If a split is missing the prompt column.
    """
    columns_by_split = _collect_dataset_columns(dataset)
    if not columns_by_split:
        LOG.debug("Unable to infer columns for %s; skipping early validation.", label)
        return

    chat_only = {"messages", "message"}
    if all(
        cols and set(cols).issubset(chat_only) for cols in columns_by_split.values()
    ):
        LOG.debug(
            "Detected message-only dataset columns for %s; skipping early validation.",
            label,
        )
        return

    required = (prompt_column, solution_column)
    gaps: Dict[str, List[str]] = {}
    for split, cols in columns_by_split.items():
        # Chat-style splits that carry neither required column are left alone.
        if "messages" in cols and not any(name in cols for name in required):
            continue
        absent = [name for name in required if name not in cols]
        if absent:
            gaps[split] = absent
    if not gaps:
        return

    if all(set(absent) == {solution_column} for absent in gaps.values()):
        LOG.info(
            "%s is missing '%s'; continuing with empty answers.",
            label,
            solution_column,
        )
        return

    missing_desc = "; ".join(
        f"{split} missing {', '.join(cols)}" for split, cols in gaps.items()
    )
    available_desc = "; ".join(
        f"{split}={sorted(cols)}" for split, cols in columns_by_split.items()
    )
    raise ValueError(
        f"{label} is missing required columns: {missing_desc}. "
        f"Available columns: {available_desc}"
    )
def _resolve_prompt_column(dataset: Any, prompt_column: str) -> str:
"""Return an inferred prompt column when the default is missing."""
if prompt_column != "problem":
return prompt_column
col_map = _collect_dataset_columns(dataset)
if not col_map:
return prompt_column
if all("problem" in cols for cols in col_map.values()):
return prompt_column
if all("prompt" in cols for cols in col_map.values()):
LOG.info("Prompt column '%s' missing; falling back to 'prompt'.", prompt_column)
return "prompt"
return prompt_column
def _split_eval_dataset_specs(raw_name: Any) -> List[str]:
"""Return normalized evaluation dataset spec entries from config."""
if raw_name is None:
return []
if isinstance(raw_name, (list, tuple)):
specs: List[str] = []
for item in raw_name:
item_text = str(item).strip()
if item_text:
specs.extend(
part.strip() for part in item_text.split(",") if part.strip()
)
return specs
text = str(raw_name).strip()
if not text:
return []
return [part.strip() for part in text.split(",") if part.strip()]
def _canonical_eval_benchmark_label(spec: str) -> str:
"""Return stable benchmark labels used in eval metric suffixes."""
normalized = spec.strip().lower().replace("-", "_")
aliases = {
"math": "MATH",
"math500": "MATH",
"math_500": "MATH",
"aime24": "AIME24",
"aime_24": "AIME24",
"aime_2024": "AIME24",
"amc": "AMC",
}
if normalized in aliases:
return aliases[normalized]
cleaned = "".join(ch if ch.isalnum() else "_" for ch in spec.strip())
cleaned = cleaned.strip("_").upper()
return cleaned or "EVAL"
def _resolve_eval_dataset_spec(
    spec: str,
    *,
    default_dataset_config: Optional[str],
    default_split: str,
    default_prompt_column: str,
    default_solution_column: str,
) -> tuple[str, Optional[str], str, str, str, str]:
    """Expand one eval spec into concrete dataset settings.

    A spec that matches a built-in preset takes its values from the preset
    (falling back to the defaults for empty entries, except the dataset config
    which is taken as-is). Any other spec is treated as a dataset id and
    combined with the supplied defaults.

    :returns: ``(dataset_name, dataset_config, split, prompt_column,
        solution_column, benchmark_label)``.
    """
    preset = _resolve_eval_dataset_preset(spec)
    label = _canonical_eval_benchmark_label(spec)
    if preset is None:
        return (
            spec,
            default_dataset_config,
            default_split,
            default_prompt_column,
            default_solution_column,
            label,
        )

    # A preset's dataset_config replaces the default even when it is None.
    preset_config = preset.get("dataset_config")
    return (
        str(preset.get("dataset_name") or spec),
        None if preset_config is None else str(preset_config),
        str(preset.get("split") or default_split),
        str(preset.get("prompt_column") or default_prompt_column),
        str(preset.get("solution_column") or default_solution_column),
        label,
    )
def _ensure_split_mapping(dataset: Any) -> MutableMapping[str, Any]:
"""Coerce dataset-like objects into a split->dataset mapping."""
if isinstance(dataset, MutableMappingABC):
return cast(MutableMapping[str, Any], dataset)
if hasattr(dataset, "keys") and hasattr(dataset, "__getitem__"):
return cast(MutableMapping[str, Any], dataset)
return {"train": dataset}
def _resolve_vllm_group_port() -> Optional[int]:
"""Resolve the vLLM communicator port from launcher environment."""
for key in ("VLLM_GROUP_PORT", "PORT_FOR_COMMUNICATION"):
raw = str(os.getenv(key, "")).strip()
if not raw:
continue
try:
port = int(raw)
except ValueError:
LOG.warning("Ignoring invalid %s=%r (expected integer port).", key, raw)
continue
if 1 <= port <= 65535:
return port
LOG.warning("Ignoring out-of-range %s=%r (expected 1..65535).", key, raw)
return None
@contextmanager
def _temporary_env(overrides: Dict[str, str]) -> Iterator[None]:
"""Temporarily set environment variables while preserving prior values."""
if not overrides:
yield
return
previous: Dict[str, Optional[str]] = {}
for key, value in overrides.items():
previous[key] = os.environ.get(key)
os.environ[key] = value
try:
yield
finally:
for key, prior in previous.items():
if prior is None:
os.environ.pop(key, None)
else:
os.environ[key] = prior
def _loopback_host(base_url: str) -> bool:
try:
parsed = urlparse(base_url)
host = parsed.hostname or ""
except Exception:
host = ""
if not host:
host = base_url
host = host.strip().lower()
return host in {"localhost", "127.0.0.1", "::1"}
def _vllm_client_nccl_overrides(base_url: str) -> Dict[str, str]:
"""Return conservative NCCL settings for loopback vLLM sync."""
overrides: Dict[str, str] = {}
enable_overrides = str(
os.getenv("MAXENT_VLLM_CLIENT_NCCL_OVERRIDES", "0")
).strip().lower() in {"1", "true", "yes", "on"}
if not enable_overrides:
return overrides
if not _loopback_host(base_url):
explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME")
if explicit and "NCCL_SOCKET_IFNAME" not in os.environ:
overrides["NCCL_SOCKET_IFNAME"] = explicit
explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_P2P_DISABLE")
if explicit and "NCCL_P2P_DISABLE" not in os.environ:
overrides["NCCL_P2P_DISABLE"] = explicit
explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_IB_DISABLE")
if explicit and "NCCL_IB_DISABLE" not in os.environ:
overrides["NCCL_IB_DISABLE"] = explicit
return overrides
if "NCCL_SOCKET_IFNAME" not in os.environ:
overrides["NCCL_SOCKET_IFNAME"] = os.getenv(
"MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME", "lo"
)
if "NCCL_P2P_DISABLE" not in os.environ:
overrides["NCCL_P2P_DISABLE"] = os.getenv(
"MAXENT_VLLM_CLIENT_NCCL_P2P_DISABLE", "1"
)
if "NCCL_IB_DISABLE" not in os.environ:
overrides["NCCL_IB_DISABLE"] = os.getenv(
"MAXENT_VLLM_CLIENT_NCCL_IB_DISABLE", "1"
)
return overrides
def _vllm_sync_chunk_bytes() -> int:
"""Return the max weight-sync batch size for server-mode vLLM updates."""
raw = os.getenv("MAXENT_VLLM_SYNC_CHUNK_MB", "64")
try:
mb = int(raw)
except (TypeError, ValueError):
mb = 64
if mb <= 0:
mb = 64
return mb * 1024 * 1024
def _encode_vllm_batched_update(
    names: list[str],
    dtypes: list[str],
    shapes: list[list[int]],
) -> dict[str, Any]:
    """Pack metadata for several weights into one single-weight style request.

    The full ``names``/``dtypes``/``shapes`` lists are serialised as compact
    JSON and stored, behind the batch prefix, in the ``name`` field. The
    ``dtype`` and ``shape`` fields carry the first entry of each list, or
    ``"float16"`` / ``[0]`` when the list is empty.

    :returns: Dict with ``name``, ``dtype`` and ``shape`` keys.
    """
    payload = json.dumps(
        {"names": names, "dtypes": dtypes, "shapes": shapes},
        separators=(",", ":"),
    )
    first_dtype = dtypes[0] if dtypes else "float16"
    first_shape = shapes[0] if shapes else [0]
    return {
        "name": _VLLM_BATCH_UPDATE_PREFIX + payload,
        "dtype": first_dtype,
        "shape": first_shape,
    }
def _tensor_nbytes(tensor: Any) -> int:
"""Best-effort tensor size in bytes for batching decisions."""
try:
return int(tensor.numel()) * int(tensor.element_size())
except Exception:
return 0
def _import_builtin_vllm_weight_transfer() -> Optional[type]:
"""Return vLLM's built-in NCCL transfer engine when available."""
try:
nccl_engine_mod = import_module(
"vllm.distributed.weight_transfer.nccl_engine"
)
gpu_worker_mod = import_module("vllm.v1.worker.gpu_worker")
except Exception:
return None
engine_cls = getattr(nccl_engine_mod, "NCCLWeightTransferEngine", None)
if engine_cls is None:
return None
if not (
callable(getattr(engine_cls, "trainer_init", None))
or callable(getattr(engine_cls, "init_process_group", None))
):
return None
if not callable(getattr(engine_cls, "trainer_send_weights", None)):
return None
worker_cls = getattr(gpu_worker_mod, "GPUWorker", None)
if worker_cls is not None:
if not callable(getattr(worker_cls, "init_weight_transfer_engine", None)):
return None
if not callable(getattr(worker_cls, "update_weights", None)):
return None
return engine_cls
def _builtin_weight_transfer_trainer_init(
engine_cls: type,
init_info: dict[str, Any],
) -> Any:
"""Initialize trainer-side vLLM weight transfer across version variants."""
trainer_init = getattr(engine_cls, "trainer_init", None)
if callable(trainer_init):
return trainer_init(init_info)
legacy_init = getattr(engine_cls, "init_process_group", None)
if callable(legacy_init):
return legacy_init(init_info)
raise RuntimeError("Built-in vLLM weight transfer lacks trainer init entrypoint")
def _clear_vllm_client_buffer(client: Any) -> None:
"""Reset any buffered trainer-side weight updates."""
setattr(client, "_maxent_weight_buffer", [])
setattr(client, "_maxent_weight_buffer_bytes", 0)
def _resolve_vllm_client_generate_boundary(client: Any) -> Dict[str, Any]:
"""Resolve tokenizer/model boundary metadata for live server-mode rollouts."""
cached = getattr(client, "_maxent_generate_boundary", None)
if isinstance(cached, dict):
return cached
model_id = str(os.getenv("MAXENT_VLLM_SERVER_MODEL_NAME", "") or "").strip()
if not model_id:
boundary = {
"model_id": None,
"tokenizer_limit": None,
"model_limit": None,
"blocked_token_ids": [],
}
setattr(client, "_maxent_generate_boundary", boundary)
return boundary
tokenizer_limit_env = _coerce_optional_int(
os.getenv("MAXENT_VLLM_SERVER_TOKENIZER_VOCAB_LIMIT")
)
model_limit_env = _coerce_optional_int(
os.getenv("MAXENT_VLLM_SERVER_MODEL_VOCAB_LIMIT")
)
tokenizer_limit = (
int(tokenizer_limit_env)
if isinstance(tokenizer_limit_env, int) and tokenizer_limit_env > 0
else None
)
model_limit = (
int(model_limit_env)
if isinstance(model_limit_env, int) and model_limit_env > 0
else None
)
if tokenizer_limit is None or model_limit is None:
try:
transformers_mod = import_module("transformers")
auto_tokenizer = getattr(transformers_mod, "AutoTokenizer")
auto_config = getattr(transformers_mod, "AutoConfig")
tokenizer = auto_tokenizer.from_pretrained(model_id, trust_remote_code=True)
config = auto_config.from_pretrained(model_id, trust_remote_code=True)
except Exception as exc:
raise RuntimeError(
"Failed to resolve vLLM token boundary for live server-mode rollouts "
f"(model={model_id}): {exc}"
) from exc
if tokenizer_limit is None:
tokenizer_limit = max(
int(getattr(tokenizer, "vocab_size", 0) or 0),
int(len(tokenizer)),
)
if model_limit is None:
model_limit = int(getattr(config, "vocab_size", 0) or 0)
if tokenizer_limit <= 0 or model_limit <= 0:
raise RuntimeError(
"Resolved invalid vLLM token boundary values "
f"(model={model_id}, tokenizer_limit={tokenizer_limit}, model_limit={model_limit})"
)
blocked_token_ids: List[int] = []
if model_limit > tokenizer_limit:
blocked_token_ids = list(range(int(tokenizer_limit), int(model_limit)))
boundary = {
"model_id": model_id,
"tokenizer_limit": int(tokenizer_limit),
"model_limit": int(model_limit),
"blocked_token_ids": blocked_token_ids,
}
setattr(client, "_maxent_generate_boundary", boundary)
if not bool(getattr(client, "_maxent_generate_boundary_logged", False)):
LOG.warning(
"Patched TRL VLLMClient.generate boundary | model=%s tokenizer_limit=%d model_limit=%d blocked_tail=%d",
model_id,
int(tokenizer_limit),
int(model_limit),
len(blocked_token_ids),
)
setattr(client, "_maxent_generate_boundary_logged", True)
return boundary
def _normalize_vllm_generate_url(base_url: str) -> str:
"""Return the canonical /generate endpoint for a vLLM server base URL."""
base = str(base_url or "").strip()
if not base:
raise RuntimeError("vLLM client base_url is unavailable")
if base.endswith("/generate/"):
return base
if base.endswith("/generate"):
return f"{base}/"
return f"{base.rstrip('/')}/generate/"
def _validate_vllm_completion_ids(
completion_ids: List[List[int]],
*,
tokenizer_limit: Optional[int],
model_id: Optional[str],
) -> None:
"""Fail fast when live rollouts contain tokenizer-inaccessible token IDs."""
if not isinstance(tokenizer_limit, int) or tokenizer_limit <= 0:
return
invalid_tokens = [
int(token_id)
for sequence in completion_ids
for token_id in sequence
if int(token_id) < 0 or int(token_id) >= int(tokenizer_limit)
]
if not invalid_tokens:
return
sample = invalid_tokens[:16]
raise RuntimeError(
"Detected completion token ids outside the tokenizer-addressable range "
f"(model={model_id or 'unknown'}, tokenizer_limit={int(tokenizer_limit)}, sample={sample})"
)
def _patch_trl_vllm_client_init() -> None:
    """Patch TRL ``VLLMClient`` init handshake to avoid POST-first deadlocks.

    The patch is applied to the class itself and at most once (guarded by the
    ``_maxent_async_init_patch`` class attribute). The function returns
    without patching when ``trl.extras.vllm_client`` or the async init helper
    cannot be imported, when ``VLLMClient`` is missing, or when any of
    ``__init__``, ``init_communicator``, ``update_named_param`` or
    ``generate`` is not callable on the class.

    Attributes replaced or added on ``VLLMClient``:

    * ``__init__`` - fills in ``group_port`` when absent/``None``/``0`` and
      initialises the per-client weight-buffer state.
    * ``generate`` - posts to the normalised ``/generate/`` URL, optionally
      sends ``blocked_token_ids``, and validates returned token ids.
    * ``init_communicator`` - applies temporary NCCL env overrides and either
      delegates to the async init helper or runs a retried handshake using
      the built-in weight-transfer module.
    * ``update_named_param`` / ``flush`` - direct broadcast without built-in
      transfer; otherwise buffer weights and send them in batches.
    * ``update_model_params``, ``reset_prefix_cache``, ``close_communicator``
      - wrapped only when the original attribute is callable.

    On the built-in path ``init_communicator`` reads
    ``MAXENT_VLLM_INIT_TIMEOUT_S`` (default 60), ``MAXENT_VLLM_INIT_RETRIES``
    (default 2, minimum 1) and ``MAXENT_VLLM_INIT_RETRY_BACKOFF_S``
    (default 2.0); malformed values fall back to those defaults.

    :returns: ``None``. The side effect is the class-level monkeypatch.
    :rtype: None
    """
    try:
        import trl.extras.vllm_client as trl_vllm_client_mod
    except Exception as exc:  # pragma: no cover - optional dependency path
        LOG.debug("Skipping vLLM client patch; trl.extras import failed: %s", exc)
        return
    client_cls = getattr(trl_vllm_client_mod, "VLLMClient", None)
    if client_cls is None:
        return
    # Idempotence guard: never wrap already-wrapped methods.
    if getattr(client_cls, "_maxent_async_init_patch", False):
        return
    try:
        from maxent_grpo.training.generation.vllm_utils import (
            init_vllm_client_communicator as _init_vllm_client_communicator,
        )
    except Exception as exc:  # pragma: no cover - defensive
        LOG.warning("Failed to import async vLLM init helper: %s", exc)
        return
    original_ctor = getattr(client_cls, "__init__", None)
    original_init_communicator = getattr(client_cls, "init_communicator", None)
    original_update_named_param = getattr(client_cls, "update_named_param", None)
    original_update_model_params = getattr(client_cls, "update_model_params", None)
    original_reset_prefix_cache = getattr(client_cls, "reset_prefix_cache", None)
    original_close_communicator = getattr(client_cls, "close_communicator", None)
    original_generate = getattr(client_cls, "generate", None)
    if (
        not callable(original_ctor)
        or not callable(original_init_communicator)
        or not callable(original_update_named_param)
        or not callable(original_generate)
    ):
        return
    # ``None`` selects the per-parameter broadcast path; anything else selects
    # the buffered built-in weight-transfer path throughout the closures below.
    builtin_weight_transfer = _import_builtin_vllm_weight_transfer()

    def _drop_local_communicator(client: Any) -> None:
        """Forget the client's local ``pynccl_comm`` handle, if it has one."""
        if getattr(client, "pynccl_comm", None) is None:
            return
        try:
            delattr(client, "pynccl_comm")
        except Exception:
            # Attribute could not be deleted; null it out instead.
            setattr(client, "pynccl_comm", None)

    def _patched_ctor(self: Any, *args: Any, **kwargs: Any) -> None:
        """Run the original constructor, then reset the patch's client state."""
        if "group_port" not in kwargs or kwargs.get("group_port") in (None, 0):
            resolved_group_port = _resolve_vllm_group_port()
            if resolved_group_port is not None:
                kwargs["group_port"] = resolved_group_port
        original_ctor(self, *args, **kwargs)
        _clear_vllm_client_buffer(self)
        # A chunk size of 0 disables size-triggered flushes (see
        # ``_patched_update_named_param``); it is used on the built-in path.
        setattr(
            self,
            "_maxent_weight_chunk_bytes",
            0 if builtin_weight_transfer is not None else _vllm_sync_chunk_bytes(),
        )
        setattr(
            self,
            "_maxent_builtin_weight_transfer",
            builtin_weight_transfer is not None,
        )
        setattr(self, "_maxent_generate_boundary", None)

    def _patched_generate(
        self: Any,
        prompts: List[str],
        n: int = 1,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        max_tokens: int = 16,
        guided_decoding_regex: Optional[str] = None,
    ) -> List[List[int]]:
        """POST a generate request and return validated completion token ids.

        :raises RuntimeError: On a non-200 response, a malformed
            ``completion_ids`` payload, or out-of-range token ids.
        """
        boundary = _resolve_vllm_client_generate_boundary(self)
        blocked_token_ids = list(boundary.get("blocked_token_ids") or [])
        url = _normalize_vllm_generate_url(getattr(self, "base_url", ""))
        payload: Dict[str, Any] = {
            "prompts": prompts,
            "n": int(n),
            "repetition_penalty": float(repetition_penalty),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            "min_p": float(min_p),
            "max_tokens": int(max_tokens),
            "guided_decoding_regex": guided_decoding_regex,
        }
        # Only include the key when there is something to block.
        if blocked_token_ids:
            payload["blocked_token_ids"] = blocked_token_ids
        response = self.session.post(url, json=payload)
        if response.status_code != 200:
            raise RuntimeError(
                f"Request failed: {response.status_code}, {response.text}"
            )
        response_payload = response.json()
        raw_completion_ids = response_payload.get("completion_ids")
        if not isinstance(raw_completion_ids, list):
            raise RuntimeError("vLLM generate response missing completion_ids")
        completion_ids: List[List[int]] = []
        for idx, item in enumerate(raw_completion_ids):
            if not isinstance(item, list):
                raise RuntimeError(f"completion_ids[{idx}] is not a list")
            completion_ids.append([int(token_id) for token_id in item])
        _validate_vllm_completion_ids(
            completion_ids,
            tokenizer_limit=cast(Optional[int], boundary.get("tokenizer_limit")),
            model_id=cast(Optional[str], boundary.get("model_id")),
        )
        return completion_ids

    def _patched_init_communicator(self: Any) -> None:
        """Initialise the weight-sync communicator with env overrides/retries.

        :raises RuntimeError: If host or group port is unavailable, or if
            every handshake attempt fails (chained from the last error).
        """
        base_url = str(getattr(self, "base_url", ""))
        overrides = _vllm_client_nccl_overrides(base_url)
        if overrides:
            LOG.info(
                "vLLM client NCCL overrides applied | %s",
                ", ".join(f"{k}={v}" for k, v in overrides.items()),
            )
        with _temporary_env(overrides):
            if builtin_weight_transfer is None:
                # Without built-in transfer, hand TRL's own bound
                # init_communicator to the async init helper.
                bound_original = cast(
                    Callable[[], None],
                    original_init_communicator.__get__(self, type(self)),
                )
                _init_vllm_client_communicator(
                    self,
                    log=LOG.info,
                    init_fn=bound_original,
                )
                return
            timeout_raw = os.getenv("MAXENT_VLLM_INIT_TIMEOUT_S", "60")
            retries_raw = os.getenv("MAXENT_VLLM_INIT_RETRIES", "2")
            backoff_raw = os.getenv("MAXENT_VLLM_INIT_RETRY_BACKOFF_S", "2.0")
            # Parse all three knobs the same way: a malformed value falls back
            # to the default instead of aborting communicator setup.
            try:
                timeout = float(timeout_raw)
            except (TypeError, ValueError):
                timeout = 60.0
            try:
                retries = max(1, int(retries_raw))
            except (TypeError, ValueError):
                retries = 2
            try:
                backoff_s = max(0.0, float(backoff_raw))
            except (TypeError, ValueError):
                backoff_s = 2.0
            host = str(getattr(self, "host", "") or "").strip()
            if not host:
                raise RuntimeError("vLLM client host is unavailable for weight sync")
            group_port = getattr(self, "group_port", None)
            if group_port in (None, 0):
                raise RuntimeError("vLLM group_port is unavailable for weight sync")
            last_error: Optional[BaseException] = None
            for attempt in range(1, retries + 1):
                # Start each attempt without a stale local communicator.
                _drop_local_communicator(self)
                try:
                    response = self.session.get(
                        f"{self.base_url}/get_world_size/",
                        timeout=timeout,
                    )
                    response.raise_for_status()
                    vllm_world_size = int(response.json()["world_size"])
                    # The trainer process joins as one extra rank.
                    world_size = vllm_world_size + 1
                    init_url = f"{self.base_url}/init_communicator/"
                    payload = {
                        "host": host,
                        "port": int(group_port),
                        "world_size": world_size,
                    }
                    post_resp = self.session.post(
                        init_url,
                        json=payload,
                        timeout=timeout,
                    )
                    if post_resp.status_code != 200:
                        raise RuntimeError(
                            "vLLM init_communicator POST failed: "
                            f"{post_resp.status_code} {getattr(post_resp, 'text', '')}"
                        )
                    # Match TRL's original init ordering: let the server accept
                    # the init request first, then join the NCCL group locally.
                    time.sleep(0.1)
                    comm = _builtin_weight_transfer_trainer_init(
                        builtin_weight_transfer,
                        {
                            "master_address": host,
                            "master_port": int(group_port),
                            "rank_offset": 1,
                            "world_size": world_size,
                        },
                    )
                    self.rank = 0
                    self.pynccl_comm = comm
                    if self.pynccl_comm is None:
                        raise RuntimeError(
                            "vLLM trainer weight-transfer init produced no communicator"
                        )
                    atexit.register(self.close_communicator)
                    return
                except Exception as exc:
                    last_error = exc
                    LOG.info(
                        "vLLM init_communicator failed (attempt %d): %s",
                        attempt,
                        exc,
                    )
                    _drop_local_communicator(self)
                    if attempt >= retries:
                        break
                    time.sleep(backoff_s)
            if last_error is not None:
                raise RuntimeError(str(last_error)) from last_error

    def _flush_weight_buffer(self: Any) -> None:
        """Send all buffered ``(name, tensor)`` pairs to the server in one batch.

        No-op when the buffer is empty.

        :raises RuntimeError: If built-in transfer or the communicator is
            unavailable, or the ``update_named_param`` POST fails.
        """
        buffer = list(getattr(self, "_maxent_weight_buffer", []) or [])
        if not buffer:
            return
        if builtin_weight_transfer is None:
            raise RuntimeError("Built-in vLLM weight transfer is unavailable")
        if getattr(self, "pynccl_comm", None) is None:
            raise RuntimeError(
                "Communicator not initialized. Call `init_communicator` first."
            )
        names = [str(name) for name, _ in buffer]
        # e.g. a dtype printed as "torch.float16" is sent as "float16".
        dtypes = [str(weight.dtype).split(".")[-1] for _, weight in buffer]
        shapes = [list(tuple(weight.shape)) for _, weight in buffer]
        url = f"{self.base_url}/update_named_param/"
        response_holder: Dict[str, Any] = {}

        def _post_update() -> None:
            try:
                response_holder["resp"] = self.session.post(
                    url,
                    json=_encode_vllm_batched_update(names, dtypes, shapes),
                )
            except Exception as exc:
                response_holder["error"] = exc

        # The metadata POST runs on a helper thread while this thread sends the
        # tensors; the POST outcome is inspected only after the join below.
        post_thread = threading.Thread(target=_post_update, daemon=True)
        post_thread.start()
        builtin_weight_transfer.trainer_send_weights(
            iter(buffer),
            self.pynccl_comm,
            src=0,
            packed=True,
        )
        post_thread.join()
        post_error = response_holder.get("error")
        if post_error is not None:
            raise RuntimeError(f"vLLM update_named_param POST failed: {post_error}")
        response = response_holder.get("resp")
        if response is None:
            raise RuntimeError("vLLM update_named_param POST returned no response")
        if response.status_code != 200:
            raise RuntimeError(
                f"Request failed: {response.status_code}, {response.text}"
            )
        _clear_vllm_client_buffer(self)

    def _patched_update_named_param(self: Any, name: str, weights: Any) -> None:
        """Push one parameter directly, or buffer it on the built-in path."""
        if builtin_weight_transfer is None:
            dtype, shape = str(weights.dtype), tuple(weights.shape)
            url = f"{self.base_url}/update_named_param/"
            response = self.session.post(
                url,
                json={"name": name, "dtype": dtype, "shape": shape},
            )
            if response.status_code != 200:
                raise RuntimeError(
                    f"Request failed: {response.status_code}, {response.text}"
                )
            # vLLM's NCCL broadcast is launched asynchronously on the current
            # CUDA stream. Replacing the broken store-backed barrier with a
            # stream sync keeps the source buffer valid until the transfer is
            # complete.
            self.pynccl_comm.broadcast(weights, src=self.rank)
            require_torch("baseline_vllm_weight_sync").cuda.current_stream(
                device=weights.device
            ).synchronize()
            return
        if weights is None:
            return
        # Use ``weights.detach()`` when available, else the object as given.
        tensor = getattr(weights, "detach", None)
        tensor = tensor() if callable(tensor) else weights
        weight_buffer = list(getattr(self, "_maxent_weight_buffer", []) or [])
        weight_buffer.append((str(name), tensor))
        setattr(self, "_maxent_weight_buffer", weight_buffer)
        total_bytes = int(getattr(self, "_maxent_weight_buffer_bytes", 0) or 0)
        total_bytes += _tensor_nbytes(tensor)
        setattr(self, "_maxent_weight_buffer_bytes", total_bytes)
        chunk_bytes = int(
            getattr(self, "_maxent_weight_chunk_bytes", _vllm_sync_chunk_bytes()) or 0
        )
        # chunk_bytes == 0 disables size-triggered flushing.
        if chunk_bytes > 0 and total_bytes >= chunk_bytes:
            _flush_weight_buffer(self)

    def _patched_update_model_params(self: Any, model: Any) -> None:
        """Run the original update, then flush buffered weights if applicable."""
        if builtin_weight_transfer is None or not callable(original_update_model_params):
            if callable(original_update_model_params):
                original_update_model_params(self, model)
            return
        original_update_model_params(self, model)
        _flush_weight_buffer(self)

    def _patched_reset_prefix_cache(self: Any) -> Any:
        """Flush pending weights (built-in path) before resetting the cache."""
        if builtin_weight_transfer is not None:
            _flush_weight_buffer(self)
        if callable(original_reset_prefix_cache):
            return original_reset_prefix_cache(self)
        return None

    def _patched_close_communicator(self: Any) -> Any:
        """Best-effort shutdown: flush, drop the communicator, close the session."""
        if builtin_weight_transfer is None:
            if callable(original_close_communicator):
                return original_close_communicator(self)
            return None
        try:
            _flush_weight_buffer(self)
        except Exception:
            # Shutdown stays best-effort, but keep the traceback for debugging.
            LOG.debug(
                "Failed to flush pending vLLM weights during shutdown.",
                exc_info=True,
            )
        _clear_vllm_client_buffer(self)
        _drop_local_communicator(self)
        session = getattr(self, "session", None)
        close = getattr(session, "close", None)
        if callable(close):
            try:
                close()
            except Exception:
                LOG.debug(
                    "Failed to close vLLM client session cleanly.",
                    exc_info=True,
                )
        return None

    setattr(client_cls, "__init__", _patched_ctor)
    setattr(client_cls, "generate", _patched_generate)
    setattr(client_cls, "init_communicator", _patched_init_communicator)
    setattr(client_cls, "update_named_param", _patched_update_named_param)
    setattr(client_cls, "flush", _flush_weight_buffer)
    if callable(original_update_model_params):
        setattr(client_cls, "update_model_params", _patched_update_model_params)
    if callable(original_reset_prefix_cache):
        setattr(client_cls, "reset_prefix_cache", _patched_reset_prefix_cache)
    if callable(original_close_communicator):
        setattr(client_cls, "close_communicator", _patched_close_communicator)
    setattr(client_cls, "_maxent_async_init_patch", True)
    try:  # Keep GRPOTrainer's module-local alias in sync if it was imported earlier.
        import trl.trainer.grpo_trainer as trl_grpo_mod

        if getattr(trl_grpo_mod, "VLLMClient", None) is not client_cls:
            setattr(trl_grpo_mod, "VLLMClient", client_cls)
    except Exception as exc:
        # Alias sync is optional; record why it was skipped instead of hiding it.
        LOG.debug("Skipping GRPOTrainer VLLMClient alias sync: %s", exc)
    LOG.info("Applied async vLLM communicator patch to TRL VLLMClient.")
# --- Public training entrypoint ---
def run_baseline_training(
    script_args: GRPOScriptArguments,
    training_args: GRPOConfig,
    model_args: "ModelConfig",
) -> None:
    """Entrypoint that loads data/model, builds trainer, and runs GRPO.

    Steps, in order: dependency checks and optional vLLM patches; seeding;
    stop-sequence and dataloader normalisation on ``training_args``; logging
    setup; dataset loading and prompt mapping; tokenizer/model loading and PAD
    token setup; eval dataset resolution; reward loading; trainer
    construction; checkpoint resolution; training; final save; evaluation;
    optional Hub push.

    When ``training_args.do_eval`` is enabled, the eval dataset comes from
    ``script_args.eval_dataset_name`` when set (one or more specs, concatenated
    and tagged with a benchmark label/id), otherwise from the resolved
    ``validation``/``test`` split of the training dataset. The eval split is
    used in full; no subsampling happens in this function.

    Note that ``training_args`` is mutated in place (e.g. ``return_reward``,
    stop sequences, dataloader settings, ``reward_weights``,
    ``resume_from_checkpoint``).

    :param script_args: Script configuration including dataset and rewards.
    :type script_args: GRPOScriptArguments
    :param training_args: GRPO trainer arguments from TRL.
    :type training_args: GRPOConfig
    :param model_args: Model configuration for TRL/transformers.
    :type model_args: ``trl.ModelConfig``
    :returns: ``None``. Side effects include training, evaluation, and checkpointing.
    :rtype: None
    """
    # Ensure logs directory exists for any file redirections by launchers
    os.makedirs(os.environ.get("LOG_DIR", "var/artifacts/logs"), exist_ok=True)
    ensure_real_dependencies(context="baseline GRPO training")
    ensure_hf_repo_ready(training_args)
    if getattr(training_args, "controller_meta_enabled", False):
        LOG.info(
            "controller_meta_enabled is set; CustomGRPOTrainer will handle controller/meta updates."
        )
    # Import selected pieces lazily to keep module import light-weight
    # NOTE(review): the compat patch is invoked both before and after the TRL
    # imports below; presumably the second call covers modules that only exist
    # once TRL has been imported - confirm before de-duplicating.
    if bool(getattr(training_args, "use_vllm", False)):
        _patch_vllm_guided_decoding_compat()
    from transformers.trainer_utils import get_last_checkpoint
    from trl import (
        GRPOTrainer as _GRPOTrainer,
        get_peft_config as _get_peft_config,
    )
    from trl.data_utils import maybe_apply_chat_template

    if bool(getattr(training_args, "use_vllm", False)):
        _patch_vllm_guided_decoding_compat()
    # A module-level ``GRPOTrainerOverride`` attribute, when set, replaces the
    # default custom trainer class for this call only.
    override = getattr(sys.modules[__name__], "GRPOTrainerOverride", None)
    if override is not None:
        trainer_cls = wrap_trl_trainer(override)
    else:
        trainer_cls = build_custom_grpo_trainer(_GRPOTrainer)
    # Avoid leaking overrides across calls/tests.
    setattr(sys.modules[__name__], "GRPOTrainerOverride", None)
    peft_factory = get_peft_config_override or _get_peft_config
    # Keep custom communicator patch opt-in to preserve open-r1 parity by default.
    if bool(getattr(training_args, "use_vllm", False)):
        patch_vllm_client = str(
            os.getenv("MAXENT_TRL_VLLM_CLIENT_PATCH", "0")
        ).strip().lower() in {"1", "true", "yes", "on"}
        if patch_vllm_client:
            _patch_trl_vllm_client_init()
        else:
            LOG.info(
                "Skipping custom TRL vLLM communicator patch "
                "(MAXENT_TRL_VLLM_CLIENT_PATCH=0)."
            )
    transformers_mod = transformers
    set_seed_fn = getattr(transformers_mod, "set_seed", None)
    if callable(set_seed_fn):
        set_seed_fn(training_args.seed)
    # Force ``return_reward`` on when it is missing or falsy.
    if not getattr(training_args, "return_reward", False):
        setattr(training_args, "return_reward", True)
    active_prompt_template = normalize_prompt_template(
        getattr(training_args, "prompt_template", None),
        default=None,
    )
    template_stop_sequences, template_include_stop = resolve_generation_stop_settings(
        active_prompt_template
    )
    # Template-derived stop settings only fill in values the user left unset.
    if (
        getattr(training_args, "vllm_stop_sequences", None) in (None, [], "")
        and template_stop_sequences
    ):
        setattr(training_args, "vllm_stop_sequences", list(template_stop_sequences))
    if (
        getattr(training_args, "vllm_include_stop_str_in_output", None) is None
        and template_include_stop
    ):
        setattr(training_args, "vllm_include_stop_str_in_output", True)
    # Keep stop sequences aligned across train/eval and vLLM/HF generation.
    vllm_stops = getattr(training_args, "vllm_stop_sequences", None)
    if getattr(training_args, "gen_stop_sequences", None) in (None, []):
        setattr(training_args, "gen_stop_sequences", vllm_stops)
    if getattr(training_args, "eval_stop_sequences", None) in (None, []):
        setattr(training_args, "eval_stop_sequences", vllm_stops)
    # Logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logging.getLogger(__name__).setLevel(log_level)
    log_run_header(training_args)
    dl_kwargs = resolve_dataloader_kwargs(training_args)
    if dl_kwargs:
        # Normalize dataloader settings onto training_args for TRL/Trainer usage.
        # Each assignment is individually guarded so one rejected field does
        # not prevent the others from being applied.
        try:
            training_args.dataloader_num_workers = int(
                dl_kwargs.get(
                    "num_workers", getattr(training_args, "dataloader_num_workers", 0)
                )
            )
        except (AttributeError, TypeError, ValueError) as exc:
            LOG.debug("Failed to set dataloader_num_workers: %s", exc)
        if "pin_memory" in dl_kwargs:
            try:
                training_args.dataloader_pin_memory = bool(dl_kwargs["pin_memory"])
            except (AttributeError, TypeError, ValueError) as exc:
                LOG.debug("Failed to set dataloader_pin_memory: %s", exc)
        if getattr(training_args, "dataloader_num_workers", 0) > 0:
            if "prefetch_factor" in dl_kwargs:
                try:
                    training_args.dataloader_prefetch_factor = int(
                        dl_kwargs["prefetch_factor"]
                    )
                except (AttributeError, TypeError, ValueError) as exc:
                    LOG.debug("Failed to set dataloader_prefetch_factor: %s", exc)
            if "persistent_workers" in dl_kwargs:
                try:
                    training_args.dataloader_persistent_workers = bool(
                        dl_kwargs["persistent_workers"]
                    )
                except (AttributeError, TypeError, ValueError) as exc:
                    LOG.debug("Failed to set dataloader_persistent_workers: %s", exc)
        else:
            # Avoid invalid prefetch/persistent settings when workers are disabled.
            try:
                training_args.dataloader_prefetch_factor = None
            except (AttributeError, TypeError, ValueError) as exc:
                LOG.debug("Failed to clear dataloader_prefetch_factor: %s", exc)
            try:
                training_args.dataloader_persistent_workers = None
            except (AttributeError, TypeError, ValueError) as exc:
                LOG.debug("Failed to clear dataloader_persistent_workers: %s", exc)
        LOG.info(
            "Baseline dataloader settings | num_workers=%s | pin_memory=%s | prefetch_factor=%s | persistent_workers=%s",
            getattr(training_args, "dataloader_num_workers", None),
            getattr(training_args, "dataloader_pin_memory", None),
            getattr(training_args, "dataloader_prefetch_factor", None),
            getattr(training_args, "dataloader_persistent_workers", None),
        )
    # Optional: datasets logging if available
    try:  # pragma: no cover - environment dependent
        import datasets as _hf_datasets

        datasets_utils = getattr(_hf_datasets, "utils", None)
        datasets_logging = getattr(datasets_utils, "logging", None)
        set_verbosity = getattr(datasets_logging, "set_verbosity", None)
        if callable(set_verbosity):
            set_verbosity(log_level)
    except (
        ImportError,
        ModuleNotFoundError,
        AttributeError,
        OSError,
        RuntimeError,
        ValueError,
    ) as exc:
        LOG.debug("Skipping datasets logging setup: %s", exc)
    # Apply the same verbosity to transformers' logging when it is exposed.
    tf_logging_module = getattr(
        getattr(transformers_mod, "utils", None), "logging", None
    )
    if tf_logging_module is not None:
        set_verbosity = getattr(tf_logging_module, "set_verbosity", None)
        if callable(set_verbosity):
            set_verbosity(log_level)
        enable_default_handler = getattr(
            tf_logging_module, "enable_default_handler", None
        )
        if callable(enable_default_handler):
            enable_default_handler()
        enable_explicit_format = getattr(
            tf_logging_module, "enable_explicit_format", None
        )
        if callable(enable_explicit_format):
            enable_explicit_format()
    # Data / model
    raw_ds = get_dataset(script_args)
    # ``pc``/``sc`` are the prompt and solution column names used below.
    pc = getattr(script_args, "dataset_prompt_column", "problem")
    pc = _resolve_prompt_column(raw_ds, pc)
    sc = getattr(script_args, "dataset_solution_column", "answer")
    dataset_label = getattr(script_args, "dataset_name", None) or getattr(
        script_args, "dataset_mixture", None
    )
    _validate_dataset_columns(
        raw_ds,
        prompt_column=pc,
        solution_column=sc,
        label=f"training dataset {dataset_label or ''}".strip(),
    )
    tokenizer = get_tokenizer(model_args, training_args)
    model = get_model(model_args, training_args)
    # Second dependency check: library checks are switched off and only the
    # loaded model/tokenizer objects are passed in.
    ensure_real_dependencies(
        context="baseline GRPO training",
        require_torch=False,
        require_transformers=False,
        require_trl=False,
        require_datasets=False,
        model=model,
        tokenizer=tokenizer,
    )
    # Ensure PAD token exists (left padding recommended for causal LMs)
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            # Reuse EOS as PAD; a list-valued eos_token contributes its first entry.
            eos_token = tokenizer.eos_token
            if isinstance(eos_token, list):
                eos_token = eos_token[0] if eos_token else None
            if isinstance(eos_token, str):
                tokenizer.pad_token = eos_token
        else:
            # No EOS either: add a dedicated PAD token and grow the embeddings.
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            resize_fn = getattr(model, "resize_token_embeddings", None)
            if callable(resize_fn):
                resize_fn(len(tokenizer))
    _maybe_align_model_tokenizer_vocab(model, tokenizer)
    config = getattr(model, "config", None)
    if config is not None and getattr(config, "pad_token_id", None) is None:
        setattr(config, "pad_token_id", tokenizer.pad_token_id)
    try:
        tokenizer.padding_side = "left"
    except AttributeError as exc:
        LOG.debug("Unable to set tokenizer.padding_side: %s", exc)
    # Map dataset → prompt text + gold answer
    char_limit = _prompt_char_limit_from_tokens(
        getattr(training_args, "max_prompt_length", 0)
    )
    # Keep prompt mapping identical for GRPO and MaxEnt so startup prompt
    # preprocessing and prompt-format behavior stay aligned.
    # Without an explicit prompt template, rows become chat-message lists;
    # with one, rows go through ``_to_prompt``.
    use_prompt_messages = active_prompt_template is None

    def _make_conversation(ex: Dict[str, Any]) -> Dict[str, Any]:
        """Build a chat-message prompt (optional system + user) and answer."""
        if pc not in ex:
            raise ValueError(f"Dataset Question Field Error: {pc} is not supported.")
        prompt: List[Dict[str, str]] = []
        if training_args.system_prompt is not None:
            prompt.append({"role": "system", "content": training_args.system_prompt})
        prompt.append({"role": "user", "content": str(ex[pc])})
        # The answer falls back to a "solution" field, then to "".
        return {"prompt": prompt, "answer": str(ex.get(sc, ex.get("solution", "")))}

    def _map_fn(ex: Dict[str, Any]) -> Dict[str, Any]:
        """Map a training split example to prompt/answer text.

        :param ex: Dataset row containing prompt/answer fields.
        :type ex: dict[str, Any]
        :returns: Mapping with ``prompt``/``answer`` keys for training.
        :rtype: dict[str, Any]
        """
        if use_prompt_messages:
            return _make_conversation(ex)
        prompt_col = pc
        # Rows lacking the default "problem" column may carry "prompt" instead.
        if prompt_col not in ex and prompt_col == "problem" and "prompt" in ex:
            prompt_col = "prompt"
        out = _to_prompt(
            ex,
            cast(Any, tokenizer),
            prompt_col,
            training_args.system_prompt,
            char_limit=char_limit,
            return_messages=use_prompt_messages,
            prompt_template=active_prompt_template,
        )
        out["answer"] = str(ex.get(sc, out.get("answer", "")))
        return out

    dataset: MutableMapping[str, Any]
    map_fn = getattr(raw_ds, "map", None)
    dataset: MutableMapping[str, Any]
    if callable(map_fn):
        # presumably lets the main process run the mapping before the other
        # ranks - confirm against ``_main_process_first``.
        with _main_process_first(training_args, "dataset prompt mapping"):
            dataset = _ensure_split_mapping(map_fn(_map_fn))
    else:
        # Fallback for dataset objects without ``.map``: minimal stand-ins that
        # expose just the methods this function calls on a split.

        class _Split:
            """List-backed split whose dataset-style methods are no-ops."""

            def __init__(self, rows: List[Any]) -> None:
                self._rows = rows

            @property
            def column_names(self) -> List[str]:
                return []

            def remove_columns(self, *_cols: Any) -> "_Split":
                return self

            def shuffle(self, seed: Any = None) -> "_Split":
                _ = seed
                return self

            def select(self, _indices: Any) -> "_Split":
                return self

            def __len__(self) -> int:
                return len(self._rows)

        class _DictDataset(dict):
            """Dict of splits with a ``map`` that applies ``fn`` to every row."""

            def map(self, fn: Callable[[Any], Any]) -> "_DictDataset":
                return _DictDataset(
                    {k: _Split([fn(ex) for ex in v]) for k, v in self.items()}
                )

        raw_splits = raw_ds if isinstance(raw_ds, dict) else {"train": raw_ds}
        dataset = _ensure_split_mapping(_DictDataset(raw_splits).map(_map_fn))
    # Drop a "messages" column from every split that has one.
    for split in list(dataset):
        split_ds = dataset[split]
        if "messages" in _get_column_names(split_ds):
            remove_columns = getattr(split_ds, "remove_columns", None)
            if callable(remove_columns):
                dataset[split] = remove_columns("messages")
    # Log a short prompt preview when local_rank resolves to -1 or 0.
    try:
        rank = int(getattr(training_args, "local_rank", -1) or -1)
    except (TypeError, ValueError):
        rank = -1
    if rank in (-1, 0):
        try:
            sample = dataset[getattr(script_args, "dataset_train_split", "train")][0]
            sample_prompt = sample.get("prompt") if isinstance(sample, dict) else None
            if isinstance(sample_prompt, str):
                rendered = sample_prompt
            else:
                rendered = maybe_apply_chat_template(sample, cast(Any, tokenizer)).get(
                    "prompt"
                )
            if isinstance(rendered, str):
                preview = rendered[:400].replace("\n", "\\n")
                LOG.info(
                    "Prompt preview (chat template applied): %s%s",
                    preview,
                    "..." if len(rendered) > 400 else "",
                )
        except Exception as exc:
            # The preview is purely informational; never fail the run over it.
            LOG.debug("Failed to render prompt preview: %s", exc)
    # Resolve splits
    train_split = getattr(script_args, "dataset_train_split", "train")
    test_split = getattr(script_args, "dataset_test_split", None)
    if test_split is None:
        # prefer 'validation' then 'test' if present
        if "validation" in dataset:
            test_split = "validation"
        elif "test" in dataset:
            test_split = "test"
    train_ds = dataset[train_split]
    eval_ds = None
    # Bidirectional benchmark label <-> integer id maps for eval rows.
    eval_benchmark_name_to_id: Dict[str, int] = {}
    eval_benchmark_id_to_name: Dict[int, str] = {}
    eval_dataset_name = getattr(script_args, "eval_dataset_name", None)
    # Eval column names default to the training prompt/solution columns.
    eval_prompt_col = getattr(script_args, "eval_dataset_prompt_column", None) or pc
    eval_solution_col = getattr(script_args, "eval_dataset_solution_column", None) or sc
    if training_args.do_eval:
        if eval_dataset_name:
            eval_split = getattr(script_args, "eval_dataset_split", "validation")
            eval_specs = _split_eval_dataset_specs(eval_dataset_name)
            if not eval_specs:
                eval_specs = [str(eval_dataset_name)]
            eval_dataset_parts: List[Any] = []
            for spec in eval_specs:
                (
                    spec_dataset_name,
                    spec_dataset_config,
                    spec_split,
                    spec_prompt_col,
                    spec_solution_col,
                    spec_benchmark,
                ) = _resolve_eval_dataset_spec(
                    spec,
                    default_dataset_config=getattr(
                        script_args, "eval_dataset_config", None
                    ),
                    default_split=eval_split,
                    default_prompt_column=eval_prompt_col,
                    default_solution_column=eval_solution_col,
                )
                eval_ds_raw = load_dataset_split(
                    spec_dataset_name,
                    spec_dataset_config,
                    spec_split,
                )
                spec_prompt_col = _resolve_prompt_column(eval_ds_raw, spec_prompt_col)
                _validate_dataset_columns(
                    eval_ds_raw,
                    prompt_column=spec_prompt_col,
                    solution_column=spec_solution_col,
                    label=f"eval dataset {spec_dataset_name}:{spec_split}",
                )
                # Ids are assigned in first-seen order; a repeated benchmark
                # label reuses its existing id.
                benchmark_id = eval_benchmark_name_to_id.setdefault(
                    spec_benchmark, len(eval_benchmark_name_to_id)
                )
                eval_benchmark_id_to_name.setdefault(benchmark_id, spec_benchmark)

                # Loop variables are bound as keyword defaults so each mapper
                # keeps its own spec's values rather than the last iteration's.
                def _map_eval_fn(
                    ex: Dict[str, Any],
                    *,
                    prompt_col: str = spec_prompt_col,
                    solution_col: str = spec_solution_col,
                    benchmark_label: str = spec_benchmark,
                    benchmark_idx: int = benchmark_id,
                ) -> Dict[str, Any]:
                    """Convert evaluation dataset rows into prompt/answer pairs."""
                    if use_prompt_messages:
                        if prompt_col not in ex:
                            raise ValueError(
                                f"Dataset Question Field Error: {prompt_col} is not supported."
                            )
                        prompt: List[Dict[str, str]] = []
                        if training_args.system_prompt is not None:
                            prompt.append(
                                {
                                    "role": "system",
                                    "content": training_args.system_prompt,
                                }
                            )
                        prompt.append({"role": "user", "content": str(ex[prompt_col])})
                        return {
                            "prompt": prompt,
                            "answer": str(ex.get(solution_col, ex.get("solution", ""))),
                            "eval_benchmark": benchmark_label,
                            "eval_benchmark_id": int(benchmark_idx),
                        }
                    resolved_prompt_col = prompt_col
                    if (
                        resolved_prompt_col not in ex
                        and resolved_prompt_col == "problem"
                        and "prompt" in ex
                    ):
                        resolved_prompt_col = "prompt"
                    out = _to_prompt(
                        ex,
                        cast(Any, tokenizer),
                        resolved_prompt_col,
                        training_args.system_prompt,
                        char_limit=char_limit,
                        return_messages=use_prompt_messages,
                        prompt_template=active_prompt_template,
                    )
                    out["answer"] = str(ex.get(solution_col, out.get("answer", "")))
                    out["eval_benchmark"] = benchmark_label
                    out["eval_benchmark_id"] = int(benchmark_idx)
                    return out

                with _main_process_first(training_args, "eval dataset prompt mapping"):
                    mapped_eval = eval_ds_raw.map(_map_eval_fn)
                if "messages" in _get_column_names(mapped_eval):
                    remove_columns = getattr(mapped_eval, "remove_columns", None)
                    if callable(remove_columns):
                        mapped_eval = remove_columns("messages")
                eval_dataset_parts.append(mapped_eval)
            if len(eval_dataset_parts) == 1:
                eval_ds = eval_dataset_parts[0]
            elif eval_dataset_parts:
                # Merge several eval sets: try ``concatenate_datasets``, then a
                # ``Dataset`` rebuilt from rows, then a plain list of rows.
                try:
                    from datasets import concatenate_datasets as _hf_concat

                    eval_ds = _hf_concat(eval_dataset_parts)
                except Exception:
                    merged_rows: List[Any] = []
                    for part in eval_dataset_parts:
                        merged_rows.extend(list(part))
                    try:
                        from datasets import Dataset as _HFDataset

                        eval_ds = _HFDataset.from_list(merged_rows)
                    except Exception:
                        eval_ds = merged_rows
            if eval_benchmark_id_to_name:
                LOG.info(
                    "Configured eval benchmarks: %s",
                    ", ".join(
                        f"{idx}:{name}"
                        for idx, name in sorted(eval_benchmark_id_to_name.items())
                    ),
                )
        elif test_split is not None and test_split in dataset:
            # No dedicated eval dataset: evaluate on the whole held-out split.
            full_eval = dataset[test_split]
            eval_ds = full_eval
    # Rewards
    reward_funcs, reward_weights = load_reward_functions(
        script_args, tokenizer, training_args
    )
    # Keep TRL args aligned with the resolved reward spec so GRPOTrainer's
    # validation (length match) succeeds even when recipes store rewards on
    # script_args only.
    try:
        setattr(training_args, "reward_weights", reward_weights)
    except (AttributeError, TypeError) as exc:
        LOG.debug("Failed to attach reward_weights to training_args: %s", exc)
    # Trainer
    with _force_vllm_dtype(training_args, tokenizer):
        trainer = trainer_cls(
            model=model,
            reward_funcs=reward_funcs,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            peft_config=peft_factory(model_args),
            processing_class=tokenizer,
        )
    # Expose trainer kwargs for tests that introspect trainer construction.
    # NOTE: ``peft_factory`` is called a second time here, so this
    # ``peft_config`` comes from a separate call than the one given to the trainer.
    setattr(
        trainer,
        "_init_kwargs",
        dict(
            model=model,
            reward_funcs=reward_funcs,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            peft_config=peft_factory(model_args),
            processing_class=tokenizer,
        ),
    )
    if eval_benchmark_id_to_name:
        # Attach copies of the benchmark maps to the trainer.
        setattr(
            trainer,
            "eval_benchmark_id_to_name",
            dict(eval_benchmark_id_to_name),
        )
        setattr(
            trainer,
            "eval_benchmark_name_to_id",
            dict(eval_benchmark_name_to_id),
        )
    if bool(getattr(training_args, "seed_paper_eval_enabled", False)) and hasattr(
        trainer, "add_callback"
    ):
        trainer.add_callback(SeedPaperEvalCallback(training_args))
    # Train
    logger = logging.getLogger(__name__)
    # Checkpoint resolution, by the shape of ``resume_from_checkpoint``:
    #   non-empty str -> that directory, if it exists
    #   None          -> latest checkpoint under output_dir, if any
    #   other truthy  -> latest checkpoint under output_dir, warn if none
    #   other falsy   -> start from scratch
    resume_request = getattr(training_args, "resume_from_checkpoint", None)
    last_ckpt: Optional[str] = None
    if isinstance(resume_request, str) and resume_request:
        if os.path.isdir(resume_request):
            last_ckpt = resume_request
        else:
            logger.warning(
                "resume_from_checkpoint=%s was provided but the path does not exist; "
                "starting from scratch.",
                resume_request,
            )
    elif resume_request is None:
        # Backward compatible behavior: if the output directory already contains a
        # checkpoint and the user did not explicitly opt out of resuming, prefer
        # picking up from the latest checkpoint.
        output_dir = getattr(training_args, "output_dir", None)
        if output_dir and os.path.isdir(output_dir):
            last_ckpt = get_last_checkpoint(output_dir)
    elif resume_request:
        output_dir = getattr(training_args, "output_dir", None)
        if output_dir and os.path.isdir(output_dir):
            last_ckpt = get_last_checkpoint(output_dir)
        if last_ckpt is None:
            logger.warning(
                "resume_from_checkpoint was requested but no checkpoint was found under %s; "
                "starting from scratch.",
                output_dir or "<unspecified>",
            )
    else:
        last_ckpt = None
    # Write the resolved value back so ``training_args`` matches what is used.
    if last_ckpt is not None:
        training_args.resume_from_checkpoint = last_ckpt
    else:
        training_args.resume_from_checkpoint = None
    if bool(getattr(training_args, "seed_paper_eval_enabled", False)):
        eval_strategy = str(getattr(training_args, "eval_strategy", "") or "").strip().lower()
        built_in_eval_enabled = bool(getattr(training_args, "do_eval", False)) and eval_strategy not in {
            "",
            "no",
            "none",
        }
        if not built_in_eval_enabled and hasattr(training_args, "eval_on_start"):
            # Copy the flag to ``seed_paper_eval_on_start`` and clear
            # ``eval_on_start`` when built-in eval is not enabled.
            setattr(
                training_args,
                "seed_paper_eval_on_start",
                bool(getattr(training_args, "eval_on_start", False)),
            )
            setattr(training_args, "eval_on_start", False)
    train_result = trainer.train(resume_from_checkpoint=last_ckpt)
    if hasattr(trainer, "log_metrics"):
        trainer.log_metrics("train", train_result.metrics)
    if hasattr(trainer, "save_metrics"):
        trainer.save_metrics("train", train_result.metrics)
    if hasattr(trainer, "save_state"):
        trainer.save_state()
    # Save
    if bool(getattr(training_args, "final_model_save_enabled", True)):
        try:
            trainer.save_model(training_args.output_dir)
        except TypeError:
            # Retry without the positional output dir.
            trainer.save_model()
        # Model card and config are written on the main process only.
        if getattr(trainer, "accelerator", None) is not None and getattr(
            trainer.accelerator, "is_main_process", False
        ):
            if hasattr(trainer, "create_model_card"):
                trainer.create_model_card(
                    dataset_name=script_args.dataset_name, tags=["open-r1"]
                )
            if hasattr(trainer, "model") and hasattr(trainer.model, "config"):
                # Set use_cache=True on the config before saving it.
                trainer.model.config.use_cache = True
                if hasattr(trainer.model.config, "save_pretrained"):
                    trainer.model.config.save_pretrained(training_args.output_dir)
    # Eval
    if training_args.do_eval and eval_ds is not None:
        if hasattr(trainer, "evaluate"):
            metrics = trainer.evaluate()
            if hasattr(trainer, "log_metrics"):
                trainer.log_metrics("eval", metrics)
            if hasattr(trainer, "save_metrics"):
                trainer.save_metrics("eval", metrics)
    # Hub
    if getattr(training_args, "push_to_hub", False):
        trainer.push_to_hub(dataset_name=script_args.dataset_name, tags=["open-r1"])