"""Custom TRL GRPOTrainer wrapper used by the MaxEnt-GRPO pipelines.
This module is the single place where GRPO-vs-MaxEnt objective behavior should
diverge at runtime. The surrounding training pipeline (dataset mapping, reward
loading, model/tokenizer setup, trainer wiring, launch entrypoints) is kept
shared so objective comparisons stay fair and easy to audit.
"""
from __future__ import annotations
# pylint: disable=broad-exception-caught
import logging
import math
import os
import json
from contextlib import contextmanager, nullcontext
from collections.abc import Mapping
from functools import partial
import inspect
import re
from types import SimpleNamespace
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, cast
import torch
import torch.nn.functional as F
AutoModelForCausalLM = None # type: ignore[assignment]
try:
from accelerate.utils import gather
except (ImportError, ModuleNotFoundError): # pragma: no cover - test fallback
def gather(value: Any) -> Any:
return value
try:
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
except (ImportError, ModuleNotFoundError): # pragma: no cover - test fallback
def apply_chat_template(example: Any, _tokenizer: Any) -> Dict[str, str]:
return {"text": str(example)}
def is_conversational(example: Any) -> bool:
return isinstance(example, list)
def maybe_apply_chat_template(example: Any, _tokenizer: Any) -> Dict[str, str]:
return {"prompt": str(example)}
_trl_create_reference_model = None
_trl_prepare_deepspeed = None
_trl_prepare_fsdp = None
from maxent_grpo.rewards.basic import (
pure_accuracy_math_correctness,
truncate_after_first_boxed_answer,
uses_pure_accuracy_math_reward,
)
from maxent_grpo.methods import resolve_method_spec_from_args
from maxent_grpo.objectives import resolve_objective_routing
from maxent_grpo.training.rewards import _compute_seed_grpo_statistics
from maxent_grpo.training.controller_objective import (
ControllerMetaContext,
build_controller_objective,
)
from maxent_grpo.training.telemetry.trl_logging import ensure_weighting_logging
from maxent_grpo.training.weighting import (
apply_meta_controller_update,
collect_weight_entropy,
maybe_update_beta,
maybe_update_tau,
weight_matrix_from_q,
)
from maxent_grpo.training.weighting.logic import build_weighting_settings
from maxent_grpo.training.scoring_common import (
_coerce_optional_int,
_get_config_value,
_get_embedding_vocab_size,
)
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
# NOTE(review): presumably the reward value counted as a "pass" and the
# tolerance used when comparing against it; the consumers are outside this
# view -- confirm before relying on this description.
_PASS_METRIC_SUCCESS_REWARD = 1.0
_PASS_METRIC_EPS = 1e-6
@contextmanager
def _adapter_disabled_context(model: Any):
    """Disable adapters when the model exposes a supported API.

    This trainer runs both plain Transformers models and PEFT-enabled models.
    Older PEFT integrations expose ``disable_adapter()`` as a context manager,
    while newer Transformers PEFT shims expose ``disable_adapters()`` /
    ``enable_adapters()`` as imperative methods. Plain base models expose
    neither and should be treated as a no-op.

    Args:
        model: Object probed for ``disable_adapter`` or for the
            ``disable_adapters`` / ``enable_adapters`` pair.

    Yields:
        None, exactly once, while adapters are disabled (or untouched when no
        supported API is found).

    Raises:
        ValueError: Propagated from ``disable_adapters()`` or
            ``enable_adapters()`` unless its message contains
            ``"PEFT is not installed"``.
    """
    # Preferred path: a context-manager style API handles re-enabling itself.
    disable_adapter = getattr(model, "disable_adapter", None)
    if callable(disable_adapter):
        with disable_adapter():
            yield
        return
    # Imperative path: both halves of the pair must exist before it is used.
    disable_adapters = getattr(model, "disable_adapters", None)
    enable_adapters = getattr(model, "enable_adapters", None)
    if callable(disable_adapters) and callable(enable_adapters):
        try:
            disable_adapters()
        except ValueError as exc:
            # The methods exist but PEFT itself is missing: behave as a no-op.
            if "PEFT is not installed" in str(exc):
                with nullcontext():
                    yield
                return
            raise
        try:
            yield
        finally:
            # Always attempt to restore adapters, even if the body raised.
            try:
                enable_adapters()
            except ValueError as exc:
                if "PEFT is not installed" not in str(exc):
                    raise
        return
    # No adapter API at all: nothing to disable.
    with nullcontext():
        yield
# Symmetric bound applied to log-probability deltas by ``_clamp_log_delta``.
_LOG_DELTA_CLAMP = 5.0
# Matches runs of non-alphanumeric characters; ``_metric_suffix_from_benchmark``
# replaces each run with a single underscore.
_BENCHMARK_SUFFIX_SANITIZER = re.compile(r"[^A-Za-z0-9]+")
# Wrapper prefixes that ``_strip_ema_param_prefixes`` removes from the front of
# parameter names.
_EMA_PARAM_NAME_PREFIXES: Tuple[str, ...] = (
    "_fsdp_wrapped_module.",
    "_checkpoint_wrapped_module.",
    "base_model.model.",
    "module.",
    "model.",
)
def _mean(values: List[float]) -> float:
"""Return the arithmetic mean for a non-empty list, else 0.0."""
return float(sum(values)) / float(len(values)) if values else 0.0
def _weighted_mean(values: List[float], weights: List[float]) -> float:
"""Return the weighted mean or 0.0 when weights are empty."""
if not values or not weights:
return 0.0
total_weight = float(sum(weights))
if total_weight <= 0.0:
return 0.0
return float(sum(v * w for v, w in zip(values, weights))) / total_weight
def _nanmin_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Return the min value while ignoring NaNs."""
finite = tensor[~torch.isnan(tensor)]
if finite.numel() == 0:
return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
return finite.min()
def _nanmax_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Return the max value while ignoring NaNs."""
finite = tensor[~torch.isnan(tensor)]
if finite.numel() == 0:
return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
return finite.max()
def _clamp_log_delta(delta: torch.Tensor) -> torch.Tensor:
    """Cast ``delta`` to float and bound it to ``[-_LOG_DELTA_CLAMP, +_LOG_DELTA_CLAMP]``.

    Intended to be applied to log-probability deltas before they are
    exponentiated.
    """
    bound = _LOG_DELTA_CLAMP
    as_float = delta.float()
    return as_float.clamp(min=-bound, max=bound)
def _resolve_vocab_size_limit(model: Any) -> Optional[int]:
    """Return the smallest positive vocab-size limit exposed by the model.

    Three sources are consulted: the embedding-derived size reported by
    ``_get_embedding_vocab_size``, the config's ``vocab_size`` entry, and a
    ``vocab_size`` attribute on the model itself.

    Args:
        model: Model object; its ``config`` attribute is used when present.

    Returns:
        The minimum of the positive integer sizes found, or ``None`` when no
        source yields a positive integer.
    """
    config = getattr(model, "config", None)
    reported_sizes = (
        _get_embedding_vocab_size(model, config),
        _coerce_optional_int(_get_config_value(config, "vocab_size", None)),
        _coerce_optional_int(getattr(model, "vocab_size", None)),
    )
    candidates = [
        int(size)
        for size in reported_sizes
        if isinstance(size, int) and int(size) > 0
    ]
    if not candidates:
        return None
    # Use the minimum, as documented: a token id is only safe when every
    # reported size can address it. Returning the maximum here would let the
    # largest (possibly stale) size define the model's limit.
    return min(candidates)
def _resolve_tokenizer_vocab_limit(tokenizer: Any) -> Optional[int]:
    """Return the largest positive vocab size the tokenizer reports.

    Both ``tokenizer.vocab_size`` and ``len(tokenizer)`` are consulted; a
    failing ``len()`` is ignored. ``None`` is returned when neither gives a
    positive integer.
    """
    declared_size = _coerce_optional_int(getattr(tokenizer, "vocab_size", None))
    try:
        measured_len = _coerce_optional_int(len(tokenizer))
    except Exception:
        measured_len = None
    positive_sizes = [
        int(size)
        for size in (declared_size, measured_len)
        if isinstance(size, int) and size > 0
    ]
    # `tokenizer.vocab_size` often excludes added special tokens while
    # `len(tokenizer)` includes them; use the larger addressable range.
    return max(positive_sizes) if positive_sizes else None
def _resolve_token_id_upper_bound(model: Any, tokenizer: Any = None) -> Optional[int]:
    """Return a conservative upper bound for valid token IDs.

    The bound is the smaller of the model-side limit and the tokenizer-side
    limit; sides that report no positive integer are ignored. ``None`` means
    neither side produced a usable limit.
    """
    limits: List[int] = [
        int(limit)
        for limit in (
            _resolve_vocab_size_limit(model),
            _resolve_tokenizer_vocab_limit(tokenizer),
        )
        if isinstance(limit, int) and limit > 0
    ]
    return min(limits) if limits else None
def _mask_invalid_logit_columns(
logits: torch.Tensor,
*,
valid_vocab_size: Optional[int],
) -> torch.Tensor:
"""Mask logit columns that correspond to tokenizer-inaccessible token IDs.
Some Qwen checkpoints expose larger output embeddings than the tokenizer can
address. Leaving those extra columns active lets entropy-regularized losses
push probability mass into dead token rows, which later surface as sampled
token IDs outside the tokenizer range.
"""
if not isinstance(valid_vocab_size, int) or valid_vocab_size <= 0:
return logits
if logits.ndim < 1:
return logits
last_dim = int(logits.size(-1))
if last_dim <= valid_vocab_size:
return logits
masked = logits.clone()
mask_value = torch.finfo(masked.dtype).min
masked[..., valid_vocab_size:] = mask_value
return masked
def _entropy_normalization_scale(valid_vocab_size: Optional[int]) -> float:
"""Return the log-vocab normalization constant for exact entropy metrics."""
if not isinstance(valid_vocab_size, int) or valid_vocab_size <= 1:
return 1.0
try:
scale = float(math.log(float(valid_vocab_size)))
except (TypeError, ValueError, OverflowError):
return 1.0
if not math.isfinite(scale) or scale <= 0.0:
return 1.0
return scale
def _tokenize_for_diversity(text: str, tokenizer: Any = None) -> List[Any]:
"""Tokenize a completion for diversity metrics."""
if not text:
return []
if tokenizer is not None:
try:
encode = getattr(tokenizer, "encode", None)
if callable(encode):
return list(encode(text, add_special_tokens=False))
if callable(tokenizer):
tokenized = tokenizer(text, add_special_tokens=False)
if isinstance(tokenized, dict) and "input_ids" in tokenized:
return list(tokenized["input_ids"])
if isinstance(tokenized, (list, tuple)):
return list(tokenized)
except Exception:
pass
return [tok for tok in text.strip().split() if tok]
def _completion_diversity_metrics(
    grouped_completions: List[List[str]],
    *,
    tokenizer: Any = None,
    accelerator: Any = None,
) -> Dict[str, float]:
    """Return coarse diversity metrics for grouped completions.

    For every non-empty group, distinct-1, distinct-2 (over the group's
    concatenated tokens) and the mean pairwise Jaccard distance between the
    completions' token sets are computed. Per-group results are optionally
    merged across ranks and then averaged.

    Args:
        grouped_completions: One list of completion strings per group;
            ``None`` entries inside a group are dropped.
        tokenizer: Forwarded to ``_tokenize_for_diversity``.
        accelerator: When it reports ``num_processes > 1``, per-group metrics
            are gathered across ranks (best effort).

    Returns:
        An empty dict when there is nothing to measure; otherwise the keys
        ``distinct_1``, ``distinct_2``, ``jaccard`` (unweighted means over
        groups) and their ``*_micro`` variants (weighted by group size).
    """
    if not grouped_completions:
        return {}

    def _distinct_n(tokens: List[Any], n: int) -> float:
        # Fraction of n-gram positions that hold a unique n-gram.
        if n <= 0 or len(tokens) < n:
            return 0.0
        total = len(tokens) - n + 1
        if total <= 0:
            return 0.0
        ngrams = {tuple(tokens[i : i + n]) for i in range(total)}
        return float(len(ngrams)) / float(total)

    def _jaccard_distance(sets: List[set[Any]]) -> float:
        # Mean of (1 - |A & B| / |A | B|) over all unordered pairs.
        if len(sets) < 2:
            return 0.0
        total_dist = 0.0
        pairs = 0
        for i in range(len(sets)):
            for j in range(i + 1, len(sets)):
                a = sets[i]
                b = sets[j]
                union = a | b
                if not union:
                    # Two empty sets are treated as identical.
                    dist = 0.0
                else:
                    dist = 1.0 - (len(a & b) / float(len(union)))
                total_dist += dist
                pairs += 1
        return total_dist / float(pairs) if pairs > 0 else 0.0

    group_metrics: List[Dict[str, float]] = []
    for group in grouped_completions:
        if not group:
            continue
        normalized = [comp.strip() for comp in group if comp is not None]
        group_size = len(normalized)
        if group_size <= 0:
            continue
        all_tokens: List[Any] = []
        token_sets: List[set[Any]] = []
        for comp in normalized:
            tokens = _tokenize_for_diversity(comp, tokenizer)
            if tokens:
                all_tokens.extend(tokens)
                token_sets.append(set(tokens))
            else:
                # Keep a placeholder so empty completions still count as a
                # member of the group in the pairwise distance.
                token_sets.append(set())
        group_metrics.append(
            {
                "group_size": float(group_size),
                "distinct_1": _distinct_n(all_tokens, 1),
                "distinct_2": _distinct_n(all_tokens, 2),
                "jaccard": _jaccard_distance(token_sets),
            }
        )
    if not group_metrics:
        return {}
    if accelerator is not None and getattr(accelerator, "num_processes", 1) > 1:
        gather_fn = getattr(accelerator, "gather_object", None)
        if callable(gather_fn):
            try:
                gather_fn_typed = cast(Callable[[Any], Any], gather_fn)
                gathered = gather_fn_typed(group_metrics)  # pylint: disable=not-callable
                if isinstance(gathered, list):
                    merged: List[Dict[str, float]] = []
                    # Accept both a flat list of dicts and a list of per-rank lists.
                    for item in gathered:
                        if isinstance(item, list):
                            merged.extend([m for m in item if isinstance(m, dict)])
                        elif isinstance(item, dict):
                            merged.append(item)
                    if merged:
                        group_metrics = merged
            except Exception:
                # Gathering is auxiliary; fall back to the local metrics.
                pass
        else:
            # No ``gather_object`` on the accelerator: try torch.distributed.
            dist = getattr(torch, "distributed", None)
            if (
                dist is not None
                and callable(getattr(dist, "is_available", None))
                and callable(getattr(dist, "is_initialized", None))
                and dist.is_available()
                and dist.is_initialized()
            ):
                try:
                    world = int(getattr(dist, "get_world_size")())
                except (TypeError, ValueError, RuntimeError):
                    world = 0
                if world > 1:
                    try:
                        gathered = [None for _ in range(world)]
                        gather_obj = getattr(dist, "all_gather_object", None)
                        if callable(gather_obj):
                            gather_obj(gathered, group_metrics)
                            merged: List[Dict[str, float]] = []
                            for item in gathered:
                                if isinstance(item, list):
                                    merged.extend(
                                        [m for m in item if isinstance(m, dict)]
                                    )
                                elif isinstance(item, dict):
                                    merged.append(item)
                            if merged:
                                group_metrics = merged
                    except (RuntimeError, ValueError, TypeError):
                        pass
    distinct1_vals = [m["distinct_1"] for m in group_metrics if "distinct_1" in m]
    distinct2_vals = [m["distinct_2"] for m in group_metrics if "distinct_2" in m]
    jaccard_vals = [m["jaccard"] for m in group_metrics if "jaccard" in m]
    # Group sizes act as weights for the "micro" averages.
    weights = [m.get("group_size", 0.0) for m in group_metrics]
    return {
        "distinct_1": _mean(distinct1_vals),
        "distinct_2": _mean(distinct2_vals),
        "jaccard": _mean(jaccard_vals),
        "distinct_1_micro": _weighted_mean(distinct1_vals, weights),
        "distinct_2_micro": _weighted_mean(distinct2_vals, weights),
        "jaccard_micro": _weighted_mean(jaccard_vals, weights),
    }
def _is_main_process(trainer: Any) -> bool:
"""Return whether the active trainer rank should emit shared metrics."""
accelerator = getattr(trainer, "accelerator", None)
return bool(getattr(accelerator, "is_main_process", True))
def _use_lightweight_greedy_eval(trainer: Any, mode: str) -> bool:
"""Return whether training-time eval is using the lightweight greedy path."""
if mode != "eval":
return False
args = getattr(trainer, "args", None)
return bool(getattr(args, "eval_greedy_only_enabled", False))
def _use_sharded_prompt_major_greedy_eval(trainer: Any, mode: str) -> bool:
    """Tell whether greedy-only eval shards prompt-major batches across ranks.

    Requires the lightweight greedy eval path, a truthy
    ``args.disable_distributed_sampler``, and more than one accelerator
    process. An unparsable process count is treated as one.
    """
    if not _use_lightweight_greedy_eval(trainer, mode):
        return False
    trainer_args = getattr(trainer, "args", None)
    sampler_disabled = bool(getattr(trainer_args, "disable_distributed_sampler", False))
    if not sampler_disabled:
        return False
    accelerator = getattr(trainer, "accelerator", None)
    try:
        world_size = int(getattr(accelerator, "num_processes", 1) or 1)
    except (TypeError, ValueError):
        world_size = 1
    return world_size > 1
def _use_local_only_lightweight_eval_metrics(trainer: Any, mode: str) -> bool:
    """Tell whether greedy-only eval keeps metrics on the main rank only.

    That is the case when the lightweight greedy path is active but the
    sharded prompt-major variant is not.
    """
    if not _use_lightweight_greedy_eval(trainer, mode):
        return False
    return not _use_sharded_prompt_major_greedy_eval(trainer, mode)
def _use_local_only_eval_diversity_metrics(trainer: Any, mode: str) -> bool:
"""Return whether eval diversity logging should stay local to main rank.
Full eval still runs through the standard Trainer loop, but completion
diversity logging is auxiliary and uses Python-object gathers that have been
the concrete failure mode under DDP. When eval inputs are replicated across
ranks (the stable math configs set ``disable_distributed_sampler=True``), we
can compute those diversity summaries on the main rank only without any
cross-rank synchronization.
"""
if mode != "eval":
return False
args = getattr(trainer, "args", None)
if not bool(getattr(args, "disable_distributed_sampler", False)):
return False
accelerator = getattr(trainer, "accelerator", None)
try:
num_processes = int(getattr(accelerator, "num_processes", 1) or 1)
except (TypeError, ValueError):
num_processes = 1
return num_processes > 1
def _metric_tensor_for_logging(
    trainer: Any,
    value: Any,
    *,
    mode: str,
) -> Optional[torch.Tensor]:
    """Return a metric tensor for logging, avoiding DDP gathers in local-only eval.

    Non-tensor values give ``None``. In local-only lightweight eval the value
    is returned untouched on the main rank and ``None`` elsewhere. In every
    other case the value is gathered, and ``None`` is returned if the gather
    does not produce a tensor.
    """
    if not isinstance(value, torch.Tensor):
        return None
    sharded = _use_sharded_prompt_major_greedy_eval(trainer, mode)
    if not sharded and _use_local_only_lightweight_eval_metrics(trainer, mode):
        return value if _is_main_process(trainer) else None
    # Sharded eval and the default path both need the cross-rank gather.
    collected = gather(value)
    if isinstance(collected, torch.Tensor):
        return collected
    return None
def _local_metric_tensor(value: Any) -> Optional[torch.Tensor]:
"""Return a detached local metric tensor without any distributed gather."""
if not isinstance(value, torch.Tensor):
return None
if value.numel() <= 0:
return None
return value.detach()
def _apply_eos_completion_mask(
    completion_ids: torch.Tensor,
    eos_token_id: Optional[int],
    completion_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Mask completion tokens after the first EOS token (TRL-style).

    Args:
        completion_ids: Token ids; the code reads ``size(0)`` as the batch
            size and ``size(1)`` as the sequence length.
        eos_token_id: EOS id to search for. When ``None`` no EOS masking is
            applied.
        completion_mask: Only consulted when ``eos_token_id`` is ``None``, in
            which case it is returned unchanged if provided.

    Returns:
        A mask that is 1 up to and including the first EOS of each row and 0
        afterwards; rows without EOS are all ones. An all-ones mask is also
        returned when ``eos_token_id`` is ``None`` without a
        ``completion_mask``, or when any tensor operation fails.
    """
    if eos_token_id is None:
        if completion_mask is not None:
            return completion_mask
        return torch.ones_like(completion_ids, dtype=getattr(torch, "long", None))
    try:
        is_eos = completion_ids == eos_token_id
        batch = int(is_eos.size(0))
        seq_len = int(is_eos.size(1))
        # Default "EOS position" is seq_len, i.e. past the end, so rows
        # without EOS keep every position.
        eos_idx = torch.full(
            (batch,),
            seq_len,
            dtype=getattr(torch, "long", None),
            device=getattr(completion_ids, "device", None),
        )
        any_eos = is_eos.any(dim=1)
        if bool(any_eos.any()):
            # argmax over the 0/1 indicator yields the first EOS position.
            eos_pos = is_eos.int().argmax(dim=1)
            eos_idx = eos_idx.clone()
            eos_idx[any_eos] = eos_pos[any_eos]
        seq_idx = torch.arange(
            seq_len, device=getattr(completion_ids, "device", None)
        ).unsqueeze(0)
        seq_idx = seq_idx.expand(batch, -1)
        # ``<=`` keeps the EOS token itself inside the mask.
        mask = seq_idx <= eos_idx.unsqueeze(1)
        to_fn = getattr(mask, "to", None)
        if callable(to_fn):
            mask = to_fn(dtype=getattr(torch, "long", None))
        return cast(torch.Tensor, mask)
    except Exception:
        # Defensive fallback used for test doubles that only implement a subset
        # of tensor ops; preserve prior behavior by returning an all-ones mask.
        return torch.ones_like(completion_ids, dtype=getattr(torch, "long", None))
def _normalize_text_for_prefix_match(text: str) -> str:
"""Normalize text for lightweight decode-prefix comparisons."""
return " ".join(str(text).split()).strip()
def _build_prompt_text(example: Dict[str, Any], tokenizer: Any) -> str:
    """Render one trainer example into the exact text sent to generation.

    Resolution order:
      1. Non-dict examples are stringified as-is.
      2. A message-list prompt is rendered with the tokenizer's own
         ``apply_chat_template`` when available.
      3. Otherwise the module-level ``is_conversational`` /
         ``maybe_apply_chat_template`` helpers are tried.
      4. Finally the raw prompt is stringified.

    Args:
        example: Trainer example; the ``"prompt"`` and ``"messages"`` keys are
            the only ones read here.
        tokenizer: Tokenizer used for chat-template rendering.

    Returns:
        The rendered prompt text.
    """
    if not isinstance(example, dict):
        return str(example)
    prompt = example.get("prompt", "")
    if isinstance(prompt, list) and prompt:
        first_message = prompt[0]
        last_message = prompt[-1]
        if (
            isinstance(first_message, dict)
            and isinstance(last_message, dict)
            and "role" in first_message
            and "content" in first_message
            and "role" in last_message
        ):
            # This local intentionally shadows the module-level
            # ``apply_chat_template`` with the tokenizer's bound method.
            apply_chat_template = getattr(tokenizer, "apply_chat_template", None)
            if callable(apply_chat_template):
                try:
                    last_role = str(last_message.get("role", ""))
                    # A trailing user turn gets a generation prompt appended;
                    # a trailing assistant turn is continued instead.
                    add_generation_prompt = last_role == "user"
                    continue_final_message = last_role == "assistant"
                    return str(
                        apply_chat_template(
                            prompt,
                            tokenize=False,
                            add_generation_prompt=add_generation_prompt,
                            continue_final_message=continue_final_message,
                        )
                    )
                except Exception:
                    # Fall through to the helper-based rendering below.
                    pass
    conversational_example: Dict[str, Any] | None = None
    if "prompt" in example:
        conversational_example = {"prompt": prompt}
    elif "messages" in example:
        conversational_example = {"messages": example.get("messages")}
    if (
        isinstance(conversational_example, dict)
        and is_conversational(conversational_example)
    ):
        try:
            rendered = maybe_apply_chat_template(conversational_example, tokenizer)
        except Exception:
            rendered = {"prompt": str(prompt)}
        if isinstance(rendered, dict):
            # Prefer the "prompt" field, then "text".
            prompt_text = rendered.get("prompt")
            if prompt_text is None:
                prompt_text = rendered.get("text")
            if prompt_text is not None:
                return str(prompt_text)
        return str(rendered)
    return str(prompt)
def _normalize_group_mass_proxy(values: Sequence[float]) -> List[float]:
"""Convert a per-group signal into a non-negative mass proxy."""
cleaned: List[float] = []
for value in values:
try:
cleaned.append(float(value))
except (TypeError, ValueError):
cleaned.append(float("nan"))
if not cleaned:
return []
if all(math.isfinite(val) and val >= 0.0 for val in cleaned):
total = sum(cleaned)
if total > 0.0:
return [val / total for val in cleaned]
positives = [max(val, 0.0) if math.isfinite(val) else 0.0 for val in cleaned]
pos_total = sum(positives)
if pos_total > 0.0:
return [val / pos_total for val in positives]
return [float("nan")] * len(cleaned)
def _build_rich_rollout_rows(
    *,
    step: int,
    group_size: int,
    prompt_texts: Sequence[str],
    completion_texts: Sequence[str],
    rewards: Sequence[float],
    advantages: Sequence[float],
    q_values: Optional[Sequence[float]] = None,
) -> tuple[list[str], list[list[Any]]]:
    """Build prompt-major rollout rows for within-group distribution analysis.

    Inputs are flat and prompt-major: consecutive runs of ``group_size`` rows
    belong to the same prompt. Only the first ``min(len(...))`` rows across the
    four required sequences are used.

    Args:
        step: Training step recorded on every row.
        group_size: Rows per prompt group; values below one are treated as one.
        prompt_texts: Prompt text per row.
        completion_texts: Completion text per row.
        rewards: Reward per row; also drives the descending in-group rank.
        advantages: Advantage per row.
        q_values: Optional per-row q masses. A group uses them only when they
            cover the whole group and are all finite; otherwise advantages
            serve as the raw update weight and ``q_mass`` is NaN.

    Returns:
        ``(columns, rows)``; both are empty when there are no usable rows.
    """
    row_count = min(
        len(prompt_texts),
        len(completion_texts),
        len(rewards),
        len(advantages),
    )
    if row_count <= 0:
        return [], []
    q_all = list(q_values or [])
    columns = [
        "step",
        "prompt_index",
        "completion_index",
        "group_size",
        "reward_rank_desc",
        "prompt",
        "completion",
        "reward_total",
        "advantage",
        "q_mass",
        "update_weight_raw",
        "update_mass_proxy",
    ]
    missing = float("nan")
    span = max(int(group_size), 1)
    table: List[List[Any]] = []
    for group_start in range(0, row_count, span):
        # The last group may be shorter than ``span``.
        group_stop = min(group_start + span, row_count)
        members = range(group_start, group_stop)
        width = len(members)
        group_rewards = [float(rewards[row]) for row in members]
        group_advantages = [float(advantages[row]) for row in members]
        if len(q_all) >= group_stop:
            group_q = [float(q_all[row]) for row in members]
        else:
            group_q = [missing] * width
        # Rank 1 is the highest reward; ties keep their original order.
        ranked = sorted(range(width), key=lambda pos: (-group_rewards[pos], pos))
        rank_of = {pos: rank for rank, pos in enumerate(ranked, start=1)}
        q_is_usable = all(math.isfinite(q) for q in group_q)
        mass_proxy = _normalize_group_mass_proxy(
            group_q if q_is_usable else group_advantages
        )
        group_number = group_start // span
        for pos, row in enumerate(members):
            q_mass = group_q[pos] if q_is_usable else missing
            raw_weight = q_mass if q_is_usable else group_advantages[pos]
            proxy = mass_proxy[pos] if pos < len(mass_proxy) else missing
            table.append(
                [
                    int(step),
                    int(group_number),
                    int(pos),
                    int(width),
                    int(rank_of.get(pos, pos + 1)),
                    str(prompt_texts[row]),
                    str(completion_texts[row]),
                    float(group_rewards[pos]),
                    float(group_advantages[pos]),
                    float(q_mass),
                    float(raw_weight),
                    float(proxy),
                ]
            )
    return columns, table
def _write_rich_rollout_sidecar(
*,
output_dir: str,
table_key: str,
step: int,
columns: Sequence[str],
rows: Sequence[Sequence[Any]],
) -> Optional[str]:
"""Persist prompt-major rollout rows for downstream figure generation."""
if not output_dir:
return None
try:
sidecar_dir = os.path.join(output_dir, "rich_completions")
os.makedirs(sidecar_dir, exist_ok=True)
path = os.path.join(sidecar_dir, f"{table_key}_step_{int(step):06d}.json")
with open(path, "w", encoding="utf-8") as handle:
json.dump(
{"columns": list(columns), "data": [list(row) for row in rows]},
handle,
)
return path
except OSError:
return None
def _token_prefix_search_order(target_len: int, max_len: int) -> List[int]:
"""Return a small symmetric search window around a candidate prefix length."""
if max_len <= 0:
return []
bounded = max(1, min(target_len, max_len))
order = [bounded]
radius = 1
while radius <= 8:
lower = bounded - radius
upper = bounded + radius
if lower >= 1:
order.append(lower)
if upper <= max_len:
order.append(upper)
radius += 1
if max_len not in order:
order.append(max_len)
return order
def _find_token_prefix_len_for_text(
    tokenizer: Any,
    token_ids: List[int],
    target_text: str,
) -> Optional[int]:
    """Best-effort map a decoded text prefix back onto a token prefix length.

    The length of ``tokenizer.encode(target_text)`` seeds a small search
    window. The first candidate length whose decoded, whitespace-normalized
    prefix equals the normalized target is returned.

    Returns:
        The matching prefix length, or ``None`` when inputs are empty, the
        tokenizer lacks ``encode``/``decode``, or no candidate matches.
    """
    if not token_ids:
        return None
    wanted = _normalize_text_for_prefix_match(target_text)
    if not wanted:
        return None
    encode = getattr(tokenizer, "encode", None)
    decode = getattr(tokenizer, "decode", None)
    if not (callable(encode) and callable(decode)):
        return None
    try:
        estimate = len(list(encode(target_text, add_special_tokens=False)))
    except Exception:
        # An encode failure just means the search starts from the shortest prefix.
        estimate = 0
    total = len(token_ids)
    candidates = _token_prefix_search_order(estimate, total) or list(range(1, total + 1))
    for length in candidates:
        try:
            text = decode(token_ids[:length], skip_special_tokens=True)
        except Exception:
            continue
        if _normalize_text_for_prefix_match(text) == wanted:
            return length
    return None
def _pad_completion_rows(
rows: List[List[int]],
*,
pad_token_id: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pad variable-length completion token rows and return ids + mask tensors."""
if not rows:
empty = torch.empty((0, 0), dtype=torch.long, device=device)
return empty, empty
max_len = max(len(row) for row in rows)
if max_len <= 0:
empty = torch.empty((len(rows), 0), dtype=torch.long, device=device)
mask = torch.empty((len(rows), 0), dtype=torch.long, device=device)
return empty, mask
completion_ids = torch.full(
(len(rows), max_len),
int(pad_token_id),
dtype=torch.long,
device=device,
)
completion_mask = torch.zeros(
(len(rows), max_len),
dtype=torch.long,
device=device,
)
for idx, row in enumerate(rows):
if not row:
continue
width = len(row)
completion_ids[idx, :width] = torch.tensor(
row,
dtype=torch.long,
device=device,
)
completion_mask[idx, :width] = 1
return completion_ids, completion_mask
def _pad_logprob_rows(
rows: List[torch.Tensor],
*,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Pad per-token log-prob rows with zeros to a dense tensor."""
if not rows:
return torch.empty((0, 0), dtype=dtype, device=device)
max_len = max(int(row.numel()) for row in rows)
if max_len <= 0:
return torch.empty((len(rows), 0), dtype=dtype, device=device)
padded = torch.zeros((len(rows), max_len), dtype=dtype, device=device)
for idx, row in enumerate(rows):
width = int(row.numel())
if width <= 0:
continue
padded[idx, :width] = row.to(device=device, dtype=dtype)
return padded
def _metric_suffix_from_benchmark(name: Any) -> str:
    """Return a metric-safe benchmark suffix (e.g., ``AIME24``).

    The stringified name is trimmed, every run of non-alphanumeric characters
    becomes one underscore, edge underscores are dropped, and the result is
    upper-cased. ``"UNKNOWN"`` is returned when nothing usable remains.
    """
    raw = str(name).strip()
    if raw:
        suffix = _BENCHMARK_SUFFIX_SANITIZER.sub("_", raw).strip("_").upper()
        if suffix:
            return suffix
    return "UNKNOWN"
def _gather_eval_benchmark_ids_for_prompts(
    trainer: Any,
    prompt_inputs: List[Dict[str, Any]],
    *,
    device: torch.device,
    local_only: bool = False,
) -> Optional[torch.Tensor]:
    """Return gathered prompt-major benchmark ids when present.

    The first of ``"eval_benchmark_id"`` / ``"benchmark_id"`` that has at
    least one non-``None`` value across the examples supplies the ids.
    Missing or non-integer entries are encoded as ``-1``.

    Args:
        trainer: Used only to decide main-rank status in ``local_only`` mode.
        prompt_inputs: Per-prompt example dicts.
        device: Device for the id tensor.
        local_only: Skip the cross-rank gather; non-main ranks get ``None``.

    Returns:
        A ``torch.long`` tensor of ids, or ``None`` when no ids are available.
    """
    if not prompt_inputs:
        return None

    raw_ids: Optional[List[Any]] = None
    for field in ("eval_benchmark_id", "benchmark_id"):
        values = [example.get(field) for example in prompt_inputs]
        if values and any(entry is not None for entry in values):
            raw_ids = values
            break
    if not raw_ids:
        return None

    def _as_id(raw: Any) -> int:
        # -1 marks "no benchmark" as well as unparsable values.
        try:
            return int(raw) if raw is not None else -1
        except (TypeError, ValueError):
            return -1

    ids_tensor = torch.tensor([_as_id(raw) for raw in raw_ids], dtype=torch.long, device=device)
    if local_only:
        return ids_tensor if _is_main_process(trainer) else None
    collected = gather(ids_tensor)
    if isinstance(collected, torch.Tensor) and collected.numel() > 0:
        return collected.to(torch.long)
    return None
def _empty_dataset_like(dataset: Any) -> Any:
"""Return an empty dataset preserving the input dataset type when possible."""
if dataset is None:
return []
select_fn = getattr(dataset, "select", None)
if callable(select_fn):
try:
return select_fn([])
except Exception:
pass
if isinstance(dataset, list):
return []
if isinstance(dataset, tuple):
return tuple()
try:
return dataset[:0]
except Exception:
return []
def _build_seed_worker(num_workers: int, rank: int):
"""Return a worker_init_fn compatible with the active transformers seed_worker signature."""
try:
from transformers.trainer_utils import seed_worker as hf_seed_worker
except Exception: # pragma: no cover - transformers is required for training
return None
try:
params = list(inspect.signature(hf_seed_worker).parameters)
except (TypeError, ValueError): # pragma: no cover - defensive fallback
return hf_seed_worker
if len(params) <= 1:
return hf_seed_worker
return partial(hf_seed_worker, num_workers=num_workers, rank=rank)
def _numeric_or_none(value: Any) -> Optional[float]:
"""Best-effort numeric conversion used for logging filters."""
if isinstance(value, bool):
return None
try:
return float(value)
except (TypeError, ValueError):
item_fn = getattr(value, "item", None)
if callable(item_fn):
try:
return float(item_fn())
except (TypeError, ValueError):
return None
return None
_BOOL_TRUE = {"1", "true", "t", "yes", "y", "on"}
_BOOL_FALSE = {"0", "false", "f", "no", "n", "off", ""}
def _coerce_bool(value: Any, *, default: bool) -> bool:
"""Convert flexible config values to bool without surprising string truthiness."""
if isinstance(value, bool):
return value
if value is None:
return default
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
lowered = value.strip().lower()
if lowered in _BOOL_TRUE:
return True
if lowered in _BOOL_FALSE:
return False
return default
try:
return bool(value)
except Exception:
return default
def _coerce_non_negative_float(value: Any, *, default: float = 0.0) -> float:
    """Convert a config value to a finite float clamped at zero from below.

    Values that cannot be converted, or that are not finite, give ``default``.
    """
    parsed = _numeric_or_none(value)
    if parsed is not None and math.isfinite(parsed):
        return max(float(parsed), 0.0)
    return default
def _reshape_prompt_major_tensor(
tensor: torch.Tensor,
group_size: int,
) -> Optional[torch.Tensor]:
"""Reshape prompt-major flat rollouts into ``[prompts, generations, ...]``."""
if group_size <= 0:
return None
total_rows = int(tensor.size(0))
if total_rows <= 0 or total_rows % group_size != 0:
return None
num_prompts = total_rows // group_size
if num_prompts <= 0:
return None
shape = (num_prompts, group_size) + tuple(tensor.shape[1:])
return tensor.reshape(shape).contiguous()
def _flatten_prompt_major_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Convert a prompt-major ``[prompts, generations, ...]`` tensor to flat order."""
if tensor.dim() < 2:
return tensor.reshape(-1)
shape = (-1,) + tuple(tensor.shape[2:])
return tensor.reshape(shape).contiguous()
def _resolve_prompt_group_sizes(
tensor_dict: Dict[str, Optional[torch.Tensor]],
group_size: int,
) -> Tuple[int, int]:
"""Infer flat row count and prompt count for listwise prompt groups."""
if group_size <= 0:
raise ValueError("group_size must be positive")
for key in (
"completion_ids",
"completion_mask",
"advantages",
"old_per_token_logps",
"prompt_ids",
"prompt_mask",
):
tensor = tensor_dict.get(key)
if isinstance(tensor, torch.Tensor):
total_rows = int(tensor.size(0))
break
else:
raise ValueError("Listwise prompt grouping requires flat rollout tensors.")
usable = (total_rows // group_size) * group_size
if usable <= 0:
raise ValueError("Listwise prompt grouping requires at least one full prompt group.")
if usable != total_rows:
raise ValueError(
"Listwise prompt grouping requires the flat batch size to be divisible by num_generations."
)
return total_rows, total_rows // group_size
def _shuffle_listwise_tensor_dict(
    tensor_dict: Dict[str, Optional[torch.Tensor]],
    group_size: int,
) -> Dict[str, Optional[torch.Tensor]]:
    """Shuffle prompt groups while preserving candidate order within each group.

    One random permutation of prompts is drawn and applied to every entry.
    Tensors with one row per rollout are permuted group-wise, tensors with one
    row per prompt are permuted directly, and everything else, including
    ``None``, is passed through unchanged.

    Raises:
        ValueError: From ``_resolve_prompt_group_sizes``, or when a flat
            tensor cannot be reshaped into prompt groups.
    """
    total_rows, num_prompts = _resolve_prompt_group_sizes(tensor_dict, group_size)
    # Draw the permutation on the device of the first tensor in the dict.
    first_tensor = next(
        (item for item in tensor_dict.values() if isinstance(item, torch.Tensor)),
        None,
    )
    order_device: Optional[torch.device] = (
        first_tensor.device if first_tensor is not None else None
    )
    order = torch.randperm(num_prompts, device=order_device)

    result: Dict[str, Optional[torch.Tensor]] = {}
    for name, item in tensor_dict.items():
        if item is None:
            result[name] = None
            continue
        leading = int(item.size(0))
        if leading == total_rows:
            by_prompt = _reshape_prompt_major_tensor(item, group_size)
            if by_prompt is None:
                raise ValueError(f"Could not reshape listwise tensor {name!r} for shuffling.")
            result[name] = _flatten_prompt_major_tensor(by_prompt[order])
        elif leading == num_prompts:
            result[name] = item[order]
        else:
            result[name] = item
    return result
def _normalize_listwise_q_targets(
q_grouped: torch.Tensor,
*,
num_prompts: int,
group_size: int,
context: str,
) -> torch.Tensor:
"""Validate listwise q targets and project them onto the simplex."""
if q_grouped.dim() != 2:
raise ValueError(f"{context} requires rank-2 listwise q targets.")
expected_shape = (num_prompts, group_size)
actual_shape = (int(q_grouped.size(0)), int(q_grouped.size(1)))
if actual_shape != expected_shape:
raise ValueError(
f"{context} requires listwise q targets with shape {expected_shape}, "
f"got {actual_shape}."
)
if not torch.isfinite(q_grouped).all():
raise ValueError(f"{context} requires finite listwise q targets.")
if (q_grouped < 0).any():
raise ValueError(f"{context} requires non-negative listwise q targets.")
row_sums = q_grouped.sum(dim=1, keepdim=True)
if (row_sums <= 0).any():
raise ValueError(f"{context} requires listwise q targets with positive mass.")
return q_grouped / row_sums
def _split_listwise_tensor_dict(
    tensor_dict: Dict[str, Optional[torch.Tensor]],
    num_chunks: int,
    group_size: int,
) -> List[Dict[str, Optional[torch.Tensor]]]:
    """Split buffered listwise tensors by whole prompt groups.

    :param tensor_dict: Buffered tensors keyed by name.
    :param num_chunks: Number of chunks to produce; must be positive.
    :param group_size: Number of candidate rows that form one prompt group.
    :returns: ``num_chunks`` dicts. When the prompt-group count is divisible by
        ``num_chunks`` each dict holds a contiguous slice of whole prompt
        groups. Otherwise every dict is a shallow copy of the full input with
        an extra ``"maxent_listwise_loss_scale"`` scalar of ``1 / num_chunks``.
    :raises ValueError: If ``num_chunks`` is not positive.
    """
    if num_chunks <= 0:
        raise ValueError("num_chunks must be positive")
    total_rows, num_prompts = _resolve_prompt_group_sizes(tensor_dict, group_size)
    chunks: List[Dict[str, Optional[torch.Tensor]]] = []
    if num_prompts % num_chunks != 0:
        # The local rollout does not hold a whole number of prompt groups per
        # microstep, so reuse the full prompt-group batch for each microstep
        # and attenuate each reuse so that one full reuse cycle matches the
        # intended total loss contribution of a normally split rollout.
        scale = torch.tensor(1.0 / float(num_chunks), dtype=torch.float32)
        for tensor in tensor_dict.values():
            if isinstance(tensor, torch.Tensor):
                scale = scale.to(device=tensor.device)
                break
        for _ in range(num_chunks):
            reused = dict(tensor_dict)
            reused["maxent_listwise_loss_scale"] = scale
            chunks.append(reused)
        return chunks
    prompts_per_chunk = num_prompts // num_chunks
    rows_per_chunk = prompts_per_chunk * group_size
    for chunk_idx in range(num_chunks):
        row_start = chunk_idx * rows_per_chunk
        row_end = (chunk_idx + 1) * rows_per_chunk
        prompt_start = chunk_idx * prompts_per_chunk
        prompt_end = (chunk_idx + 1) * prompts_per_chunk
        chunk: Dict[str, Optional[torch.Tensor]] = {}
        for key, tensor in tensor_dict.items():
            if tensor is None:
                chunk[key] = None
            elif tensor.dim() == 0:
                # Scalars have no leading dimension to slice;
                # ``tensor.size(0)`` would raise on them.
                chunk[key] = tensor
            elif int(tensor.size(0)) == total_rows:
                chunk[key] = tensor[row_start:row_end]
            elif int(tensor.size(0)) == num_prompts:
                chunk[key] = tensor[prompt_start:prompt_end]
            else:
                chunk[key] = tensor
        chunks.append(chunk)
    return chunks
def _strip_ema_param_prefixes(name: str) -> Tuple[str, int]:
    """Peel known wrapper prefixes off a parameter name.

    Prefixes listed in ``_EMA_PARAM_NAME_PREFIXES`` are removed repeatedly from
    the front of the name until none matches.

    :param name: Parameter name to canonicalize.
    :returns: ``(canonical_name, prefixes_removed)``. If stripping would leave
        an empty string, the original name is returned instead.
    """
    original = str(name)
    remainder = original
    removed = 0
    while remainder:
        # First listed prefix that matches wins, as in a linear scan.
        hit = next(
            (
                prefix
                for prefix in _EMA_PARAM_NAME_PREFIXES
                if remainder.startswith(prefix)
            ),
            None,
        )
        if hit is None:
            break
        remainder = remainder[len(hit) :]
        removed += 1
    return (remainder or original), removed
def _build_ema_alias_index(
    params: Dict[str, torch.Tensor],
) -> Dict[str, List[Tuple[str, torch.Tensor, int]]]:
    """Group tensors under their prefix-stripped names for alias-aware lookup.

    :param params: Mapping from raw parameter name to tensor.
    :returns: Mapping from canonical name to ``(raw_name, tensor, stripped)``
        entries, each list ordered by fewest stripped prefixes, then shortest
        raw name, then raw name alphabetically.
    """
    index: Dict[str, List[Tuple[str, torch.Tensor, int]]] = {}
    for raw_name, tensor in params.items():
        canonical_name, depth = _strip_ema_param_prefixes(raw_name)
        bucket = index.get(canonical_name)
        if bucket is None:
            bucket = index[canonical_name] = []
        bucket.append((raw_name, tensor, depth))
    for bucket in index.values():
        bucket.sort(key=lambda entry: (entry[2], len(entry[0]), entry[0]))
    return index
def _selected_logps_and_entropy(
logits: torch.Tensor,
token_ids: torch.Tensor,
*,
entropy_mode: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return selected token log-probs and a differentiable entropy term."""
log_probs = F.log_softmax(logits, dim=-1)
selected_logps = torch.gather(
log_probs, dim=-1, index=token_ids.unsqueeze(-1)
).squeeze(-1)
if entropy_mode == "sample":
entropy = -selected_logps
else:
probs = log_probs.exp()
entropy = -(probs * log_probs).sum(dim=-1)
return selected_logps, entropy
def _resolve_ema_source_param(
    ref_name: str,
    ref_param: torch.Tensor,
    policy_params: Dict[str, torch.Tensor],
    policy_alias_index: Dict[str, List[Tuple[str, torch.Tensor, int]]],
) -> Tuple[Optional[torch.Tensor], bool]:
    """Find the policy tensor that corresponds to a reference parameter.

    An exact name match with an equal shape is preferred. Otherwise the
    prefix-stripped name is looked up in ``policy_alias_index`` and the first
    candidate with an equal shape is used.

    :param ref_name: Name of the reference parameter.
    :param ref_param: Reference tensor whose shape a match must share.
    :param policy_params: Policy tensors keyed by raw name.
    :param policy_alias_index: Policy tensors keyed by canonical name.
    :returns: ``(tensor, used_alias)``; ``(None, False)`` when nothing matches.
    """
    exact = policy_params.get(ref_name)
    if isinstance(exact, torch.Tensor) and exact.shape == ref_param.shape:
        return exact, False
    canonical_name = _strip_ema_param_prefixes(ref_name)[0]
    for alias_name, alias_tensor, _depth in policy_alias_index.get(canonical_name, ()):
        if not isinstance(alias_tensor, torch.Tensor):
            continue
        if alias_tensor.shape != ref_param.shape:
            continue
        return alias_tensor, alias_name != ref_name
    return None, False
def _strip_mode_prefix(key: str, mode: str) -> str:
"""Remove a train/eval prefix from metric keys when applicable."""
if mode == "train" and key.startswith("train/"):
return key[len("train/") :]
if mode == "eval" and key.startswith("eval/"):
return key[len("eval/") :]
return key
_CANONICAL_METRIC_KEYS: Dict[str, str] = {
"completions/mean_length": "completions/mean_length_sampled",
"completions/min_length": "completions/min_length_sampled",
"completions/max_length": "completions/max_length_sampled",
"completions/clipped_ratio": "completions/clipped_frac",
"completions/mean_terminated_length": "completions/mean_length_terminated",
"completions/min_terminated_length": "completions/min_length_terminated",
"completions/max_terminated_length": "completions/max_length_terminated",
}
_LEGACY_METRIC_ALIASES: Dict[str, Tuple[str, ...]] = {
"completions/mean_length_sampled": ("completions/mean_length",),
"completions/min_length_sampled": ("completions/min_length",),
"completions/max_length_sampled": ("completions/max_length",),
"completions/clipped_frac": ("completions/clipped_ratio",),
"completions/mean_length_terminated": ("completions/mean_terminated_length",),
"completions/min_length_terminated": ("completions/min_terminated_length",),
"completions/max_length_terminated": ("completions/max_terminated_length",),
}
def _canonical_metric_key(key: str) -> str:
"""Normalize metric aliases to one canonical key namespace."""
if key.startswith("diversity/"):
return f"completions/{key}"
return _CANONICAL_METRIC_KEYS.get(key, key)
def _legacy_metric_aliases(key: str) -> Tuple[str, ...]:
    """List the compatibility names under which a canonical key is mirrored.

    :param key: Canonical metric key.
    :returns: Aliases from ``_LEGACY_METRIC_ALIASES`` followed, for keys under
        ``completions/diversity/``, by the key without its ``completions/``
        prefix. Duplicates are dropped keeping first occurrence; the tuple is
        empty when there are no aliases.
    """
    # A dict is used as an insertion-ordered set.
    collected = dict.fromkeys(_LEGACY_METRIC_ALIASES.get(key, ()))
    if key.startswith("completions/diversity/"):
        collected.setdefault(key[len("completions/") :], None)
    return tuple(collected)
def _supports_adapter_disabled_reference(model: Any) -> bool:
"""Return whether the model exposes an adapter-disable reference path."""
return callable(getattr(model, "disable_adapter", None)) or (
callable(getattr(model, "disable_adapters", None))
and callable(getattr(model, "enable_adapters", None))
)
[docs]
def build_custom_grpo_trainer(parent_cls: Type[Any]) -> Type[Any]:
"""Return a GRPOTrainer subclass with MaxEnt hooks enabled.
:param parent_cls: Base TRL GRPOTrainer class.
:returns: Wrapped GRPOTrainer subclass.
"""
if getattr(parent_cls, "_MAXENT_CUSTOM_TRAINER", False):
return parent_cls
class CustomGRPOTrainer(parent_cls):
"""Thin GRPOTrainer subclass used as a future extension point."""
_MAXENT_CUSTOM_TRAINER = True
@staticmethod
def _resolve_parent_training_args(
init_args: Tuple[Any, ...],
init_kwargs: Dict[str, Any],
) -> Any:
"""Best-effort retrieval of TRL trainer args from constructor inputs."""
if "args" in init_kwargs:
return init_kwargs.get("args")
if len(init_args) >= 3:
return init_args[2]
return None
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Build the parent trainer, then resolve MaxEnt routing and state.

    Objective routing is resolved once from the raw constructor inputs (to
    derive a default ``maxent_alpha``) and again from ``self.args`` after the
    parent constructor has run.

    :raises ValueError: If the listwise loss is selected and the configured
        ``maxent_tau`` is not strictly positive.
    """
    parent_args = self._resolve_parent_training_args(args, kwargs)
    parent_routing = resolve_objective_routing(
        objective=getattr(parent_args, "objective", None),
        train_grpo_objective=getattr(
            parent_args, "train_grpo_objective", True
        ),
        maxent_objective_variant=getattr(
            parent_args, "maxent_objective_variant", None
        ),
        maxent_alpha=getattr(parent_args, "maxent_alpha", None),
        policy_entropy_bonus_coef=getattr(
            parent_args, "policy_entropy_bonus_coef", 0.0
        ),
    )
    maxent_requested = parent_routing.maxent_requested
    parent_alpha_default = 1.0 if maxent_requested else 0.0
    parent_maxent_alpha = _coerce_non_negative_float(
        (
            getattr(parent_args, "maxent_alpha", parent_alpha_default)
            if parent_args is not None
            else parent_alpha_default
        ),
        default=parent_alpha_default,
    )
    super().__init__(*args, **kwargs)
    trainer_args = getattr(self, "args", None)
    self.objective_routing = resolve_objective_routing(
        objective=getattr(trainer_args, "objective", None),
        train_grpo_objective=getattr(trainer_args, "train_grpo_objective", True),
        maxent_objective_variant=getattr(
            trainer_args, "maxent_objective_variant", None
        ),
        maxent_alpha=getattr(trainer_args, "maxent_alpha", parent_maxent_alpha),
        policy_entropy_bonus_coef=getattr(
            trainer_args, "policy_entropy_bonus_coef", 0.0
        ),
    )
    self.method_spec = resolve_method_spec_from_args(trainer_args)
    self.maxent_enabled = self.objective_routing.maxent_requested
    self.maxent_objective_variant = (
        self.objective_routing.maxent_objective_variant
    )
    self.maxent_alpha = self.objective_routing.maxent_alpha
    self._maybe_initialize_reference_model_for_maxent()
    controller_meta_requested = bool(
        getattr(trainer_args, "controller_meta_enabled", False)
    )
    self._controller_meta_requested = controller_meta_requested
    if self.objective_routing.uses_listwise_loss:
        configured_tau = _coerce_non_negative_float(
            getattr(trainer_args, "maxent_tau", 0.0),
            default=0.0,
        )
        if configured_tau <= 0.0:
            raise ValueError("Listwise MaxEnt requires maxent_tau > 0.")
    # Weighting settings are only built when MaxEnt or the meta controller
    # is requested and trainer args are available.
    self._maxent_weighting = (
        build_weighting_settings(trainer_args)
        if (
            (self.maxent_enabled or controller_meta_requested)
            and trainer_args is not None
        )
        else None
    )
    self._maxent_controller_objective = (
        build_controller_objective(trainer_args, self._maxent_weighting)
        if self._maxent_weighting is not None
        else None
    )
    self._sync_weighting_scalars()
    route_mode = self.objective_routing.route_mode
    LOG.info(
        "Objective routing selected | mode=%s | objective=%s | "
        "maxent_variant=%s | maxent_alpha=%s",
        route_mode,
        getattr(trainer_args, "objective", None),
        self.maxent_objective_variant,
        self.maxent_alpha,
    )
    if self.method_spec is not None:
        LOG.info(
            "Resolved training method | name=%s | family=%s | backend=%s | "
            "objective=%s | seed_grpo=%s | slug=%s",
            self.method_spec.canonical_name,
            self.method_spec.family,
            self.method_spec.loss_backend,
            self.method_spec.objective,
            self.method_spec.seed_grpo_enabled,
            self.method_spec.slug,
        )
    # The statements below use ``controller_meta_requested`` and
    # ``route_mode``, which exist only in this constructor, so they must run
    # here rather than after the ``prediction_step`` return.
    if controller_meta_requested and not self.objective_routing.uses_listwise_loss:
        LOG.info(
            "Controller meta enabled for objective=%s; beta updates stay active, "
            "but tau only affects listwise MaxEnt.",
            route_mode,
        )
    if self.objective_routing.uses_listwise_loss:
        if self.maxent_alpha > 0.0:
            LOG.info(
                "Listwise MaxEnt selected; maxent_alpha=%.4f is inactive in this objective.",
                self.maxent_alpha,
            )
    elif self.maxent_enabled and self.maxent_objective_variant == "entropy":
        if float(getattr(getattr(self, "args", None), "maxent_tau", 0.0) or 0.0) > 0.0:
            LOG.info(
                "Entropy-regularized MaxEnt selected; listwise tau/q weighting knobs stay inactive."
            )
    # Per-run bookkeeping used by the rollout buffer and metric helpers.
    self._step = 0
    self._buffered_inputs: Optional[List[Dict[str, Optional[torch.Tensor]]]] = None
    self._last_train_kl_for_alpha: Optional[float] = None
    self._last_grpo_debug_step: Optional[int] = None
    self._last_reference_ema_step: Optional[int] = None

def evaluation_loop(  # type: ignore[override]
    self,
    dataloader: Any,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
) -> Any:
    """Run the parent evaluation loop, trimmed for lightweight greedy eval.

    When ``_use_lightweight_greedy_eval(self, "eval")`` is truthy the parent
    loop is called with ``prediction_loss_only=True`` and, for the duration
    of the call, ``"inputs"`` is removed from ``args.include_for_metrics``
    (the original value is restored afterwards). Otherwise the call is
    forwarded unchanged.
    """
    if _use_lightweight_greedy_eval(self, "eval"):
        original_include = getattr(self.args, "include_for_metrics", None)
        filtered_include: Any = original_include
        if original_include is not None:
            filtered_include = tuple(
                item for item in original_include if item != "inputs"
            )
        try:
            if original_include is not None:
                self.args.include_for_metrics = filtered_include
            return super().evaluation_loop(
                dataloader,
                description,
                prediction_loss_only=True,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            # Always restore the caller-visible setting, even on error.
            if original_include is not None:
                self.args.include_for_metrics = original_include
    return super().evaluation_loop(
        dataloader,
        description,
        prediction_loss_only=prediction_loss_only,
        ignore_keys=ignore_keys,
        metric_key_prefix=metric_key_prefix,
    )

def prediction_step(  # type: ignore[override]
    self,
    model: Any,
    inputs: Dict[str, Any],
    prediction_loss_only: bool,
    ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Run one prediction step, skipping the loss in lightweight greedy eval.

    :returns: ``(None, None, None)`` when lightweight greedy eval is active
        and ``model`` is not in training mode; otherwise whatever the parent
        ``prediction_step`` returns.
    """
    if _use_lightweight_greedy_eval(self, "eval") and not bool(
        getattr(model, "training", False)
    ):
        # Greedy-only eval should do the minimum work required to log
        # pass@1 metrics: prepare one greedy completion per prompt and
        # skip the extra eval loss computation that the Trainer would
        # otherwise perform for every batch.
        self._prepare_inputs(inputs)
        return None, None, None
    return super().prediction_step(
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=ignore_keys,
    )
def _entropy_alpha_kl_control_requested(self) -> bool:
"""Return whether entropy-MaxEnt KL-based alpha control is active."""
args = getattr(self, "args", None)
return bool(
getattr(args, "maxent_alpha_raise_on_low_kl", False)
or getattr(args, "maxent_alpha_lower_on_high_kl", False)
or getattr(args, "maxent_alpha_disable_outside_trust_zone", False)
)
def _dr_grpo_denominator_mode(self) -> str:
"""Return the normalized Dr.GRPO denominator mode."""
args = getattr(self, "args", None)
mode = str(
getattr(args, "dr_grpo_denominator_mode", "fixed_max") or "fixed_max"
).strip().lower()
return "active_tokens" if mode == "active_tokens" else "fixed_max"
def _dr_grpo_loss_denominator(
self,
completion_mask: torch.Tensor,
*,
loss_tensor: torch.Tensor,
mode: str,
) -> torch.Tensor:
"""Return the denominator used by the Dr.GRPO loss."""
denominator_mode = self._dr_grpo_denominator_mode()
if denominator_mode == "active_tokens":
denominator = completion_mask.sum().clamp(min=1).to(loss_tensor.dtype)
else:
max_completion_length = int(
getattr(self, "max_completion_length", 0)
or getattr(getattr(self, "args", None), "max_completion_length", 0)
or completion_mask.size(1)
or 1
)
denominator = loss_tensor.new_tensor(
float(max(loss_tensor.size(0) * max(max_completion_length, 1), 1))
)
self._append_metric_value(
mode,
"loss/dr_grpo_denominator",
float(denominator.detach().item()),
include_legacy_aliases=False,
)
self._append_metric_value(
mode,
"loss/dr_grpo_denominator_active_tokens",
1.0 if denominator_mode == "active_tokens" else 0.0,
include_legacy_aliases=False,
)
return denominator
def _should_force_reference_model_for_maxent(self) -> bool:
"""Return whether MaxEnt should materialize a frozen reference model."""
if not bool(getattr(self, "maxent_enabled", False)):
return False
args = getattr(self, "args", None)
if bool(getattr(args, "maxent_share_reference_model", False)):
return False
if getattr(self, "ref_model", None) is not None:
return False
unwrapped_model = getattr(self, "model", None)
unwrap_fn = getattr(getattr(self, "accelerator", None), "unwrap_model", None)
if callable(unwrap_fn):
try:
unwrapped_model = unwrap_fn(unwrapped_model)
except Exception:
unwrapped_model = getattr(self, "model", None)
if _supports_adapter_disabled_reference(unwrapped_model):
return False
ref_source = str(
getattr(args, "maxent_reference_logprobs_source", "auto") or "auto"
).strip().lower()
force_model_reference = bool(
getattr(args, "maxent_trl_reference_scoring", False)
) or ref_source in {
"model",
"reference",
"reference_model",
"ref_model",
}
needs_entropy_kl_measure = bool(
self.objective_routing.uses_entropy_regularized_loss
and (
float(getattr(self, "beta", 0.0) or 0.0) != 0.0
or self._entropy_alpha_kl_control_requested()
)
)
return bool(force_model_reference or needs_entropy_kl_measure)
def _maybe_initialize_reference_model_for_maxent(self) -> None:
"""Materialize a frozen reference model when MaxEnt needs one and TRL skipped it."""
if not self._should_force_reference_model_for_maxent():
return
global AutoModelForCausalLM
global _trl_create_reference_model
global _trl_prepare_deepspeed
global _trl_prepare_fsdp
if _trl_create_reference_model is None:
try:
from trl.models import create_reference_model as _create_reference_model
from trl.models import prepare_deepspeed as _prepare_deepspeed
from trl.models import prepare_fsdp as _prepare_fsdp
_trl_create_reference_model = _create_reference_model
_trl_prepare_deepspeed = _prepare_deepspeed
_trl_prepare_fsdp = _prepare_fsdp
except Exception:
_trl_create_reference_model = None
_trl_prepare_deepspeed = None
_trl_prepare_fsdp = None
if AutoModelForCausalLM is None:
try:
from transformers import AutoModelForCausalLM as _AutoModelForCausalLM
AutoModelForCausalLM = _AutoModelForCausalLM # type: ignore[assignment]
except Exception:
AutoModelForCausalLM = None # type: ignore[assignment]
policy_model = getattr(self, "model", None)
unwrap_fn = getattr(getattr(self, "accelerator", None), "unwrap_model", None)
unwrapped_policy = policy_model
if callable(unwrap_fn):
try:
unwrapped_policy = unwrap_fn(policy_model)
except Exception:
unwrapped_policy = policy_model
ref_model: Any = None
if (
not bool(getattr(self, "is_deepspeed_enabled", False))
and not bool(getattr(self, "is_fsdp_enabled", False))
and callable(_trl_create_reference_model)
):
try:
ref_model = _trl_create_reference_model(unwrapped_policy)
except Exception as exc:
LOG.warning(
"Failed to clone a frozen reference model for MaxEnt KL measurement; "
"retrying from pretrained weights: %s",
exc,
)
if ref_model is None and AutoModelForCausalLM is not None:
args = getattr(self, "args", None)
model_init_kwargs_raw = getattr(args, "model_init_kwargs", None)
model_init_kwargs = (
dict(model_init_kwargs_raw)
if isinstance(model_init_kwargs_raw, Mapping)
else {}
)
ref_revision = getattr(args, "reference_model_revision", None)
if ref_revision:
model_init_kwargs["revision"] = ref_revision
model_id = (
getattr(args, "reference_model_name_or_path", None)
or getattr(getattr(unwrapped_policy, "config", None), "_name_or_path", None)
or getattr(unwrapped_policy, "name_or_path", None)
)
if model_id:
try:
ref_model = AutoModelForCausalLM.from_pretrained(
model_id,
**model_init_kwargs,
)
except Exception as exc:
LOG.warning(
"Failed to load a frozen reference model for MaxEnt KL measurement "
"from %s: %s",
model_id,
exc,
)
if ref_model is None:
if not bool(getattr(self, "_maxent_missing_ref_model_warned", False)):
LOG.warning(
"MaxEnt requested model-based KL measurement, but no frozen reference "
"model could be initialized. KL metrics may collapse to rollout behavior."
)
setattr(self, "_maxent_missing_ref_model_warned", True)
return
for param in getattr(ref_model, "parameters", lambda: [])():
try:
param.requires_grad = False
except (AttributeError, RuntimeError):
continue
eval_fn = getattr(ref_model, "eval", None)
if callable(eval_fn):
eval_fn()
if bool(getattr(self, "is_deepspeed_enabled", False)) and callable(
_trl_prepare_deepspeed
):
ref_model = _trl_prepare_deepspeed(ref_model, self.accelerator)
elif bool(getattr(self, "is_fsdp_enabled", False)) and callable(
_trl_prepare_fsdp
):
ref_model = _trl_prepare_fsdp(ref_model, self.accelerator)
else:
prepare_model = getattr(self.accelerator, "prepare_model", None)
if callable(prepare_model):
try:
ref_model = prepare_model(
ref_model,
evaluation_mode=True,
)
except TypeError:
ref_model = prepare_model(ref_model)
else:
prepare = getattr(self.accelerator, "prepare", None)
if callable(prepare):
ref_model = prepare(ref_model)
self.ref_model = ref_model # pylint: disable=attribute-defined-outside-init
LOG.info(
"Materialized a frozen reference model for MaxEnt KL measurement despite beta=0."
)
def _should_use_model_reference_logprobs(
self,
*,
default_to_model_reference: bool,
) -> bool:
"""Return whether the current loss should score against a model-based reference."""
args = getattr(self, "args", None)
ref_source = str(
getattr(args, "maxent_reference_logprobs_source", "auto") or "auto"
).strip().lower()
if ref_source == "none":
ref_source = "policy"
if bool(getattr(args, "maxent_trl_reference_scoring", False)):
return True
if ref_source in {"model", "reference", "reference_model", "ref_model"}:
return True
if ref_source == "policy":
return False
if ref_source == "auto":
if getattr(self, "ref_model", None) is not None:
return True
unwrap_fn = getattr(getattr(self, "accelerator", None), "unwrap_model", None)
unwrapped_model = getattr(self, "model", None)
if callable(unwrap_fn):
try:
unwrapped_model = unwrap_fn(unwrapped_model)
except Exception:
unwrapped_model = getattr(self, "model", None)
if _supports_adapter_disabled_reference(unwrapped_model):
return True
return default_to_model_reference
return default_to_model_reference
def _get_reference_per_token_logps(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
logits_to_keep: int,
*,
batch_size: int,
) -> Optional[torch.Tensor]:
"""Return per-token log-probs from the frozen/model-based reference path."""
if getattr(self, "ref_model", None) is not None:
return self._get_per_token_logps(
self.ref_model,
input_ids,
attention_mask,
logits_to_keep,
batch_size=batch_size,
)
unwrap_fn = getattr(getattr(self, "accelerator", None), "unwrap_model", None)
unwrapped_model = getattr(self, "model", None)
if callable(unwrap_fn):
try:
unwrapped_model = unwrap_fn(unwrapped_model)
except Exception:
unwrapped_model = getattr(self, "model", None)
if not _supports_adapter_disabled_reference(unwrapped_model):
return None
with _adapter_disabled_context(unwrapped_model):
return self._get_per_token_logps(
self.model,
input_ids,
attention_mask,
logits_to_keep,
batch_size=batch_size,
)
def _sync_weighting_scalars(self) -> None:
"""Expose controller scalars on the trainer for logging helpers."""
weighting = getattr(self, "_maxent_weighting", None)
if weighting is None:
return
tau_val = float(getattr(weighting, "tau", 0.0) or 0.0)
beta_val = float(getattr(weighting, "beta", 0.0) or 0.0)
denom_val = float(getattr(weighting, "denom", 1.0) or 1.0)
self.tau = tau_val # pylint: disable=attribute-defined-outside-init
self.maxent_tau = tau_val # pylint: disable=attribute-defined-outside-init
self.beta = beta_val # pylint: disable=attribute-defined-outside-init
self.weight_norm_denom = (
denom_val # pylint: disable=attribute-defined-outside-init
)
def _maybe_apply_controller_meta(
self,
*,
mode: str,
kl_value: Optional[float],
weight_entropy: Optional[float] = None,
total_loss: Optional[float] = None,
) -> bool:
"""Apply one meta-controller update when the active route enables it."""
if mode != "train":
return False
weighting = getattr(self, "_maxent_weighting", None)
meta_objective = getattr(self, "_maxent_controller_objective", None)
if weighting is None or meta_objective is None:
return False
update_interval = max(
1,
int(
getattr(
getattr(weighting, "controller_meta", None),
"update_interval",
1,
)
or 1
),
)
global_step = int(getattr(self.state, "global_step", 0) or 0)
if global_step % update_interval != 0:
return False
weight_stats = SimpleNamespace()
if isinstance(weight_entropy, (int, float)) and math.isfinite(weight_entropy):
weight_stats = SimpleNamespace(weight_entropy=float(weight_entropy))
loss_outputs = SimpleNamespace(
kl_loss_scalar=(
float(kl_value)
if isinstance(kl_value, (int, float)) and math.isfinite(kl_value)
else None
),
total_loss_scalar=(
float(total_loss)
if isinstance(total_loss, (int, float)) and math.isfinite(total_loss)
else None
),
)
grads = meta_objective.compute(
ControllerMetaContext(
weighting=weighting,
weight_stats=weight_stats,
loss_outputs=loss_outputs,
global_step=global_step,
kl_value=kl_value,
)
)
if grads is None:
return False
updated = apply_meta_controller_update(
weighting,
tau_grad=grads.tau_grad,
beta_grad=grads.beta_grad,
)
if not updated:
return False
if isinstance(getattr(grads, "tau_grad", None), (int, float)):
self._append_metric_value(
mode,
"meta/tau_grad",
float(getattr(grads, "tau_grad", 0.0) or 0.0),
)
if isinstance(getattr(grads, "beta_grad", None), (int, float)):
self._append_metric_value(
mode,
"meta/beta_grad",
float(getattr(grads, "beta_grad", 0.0) or 0.0),
)
return True
def get_train_dataloader(self):  # type: ignore[override]
    """Build the training dataloader and hand it to ``accelerator.prepare``.

    The batch size is ``self._train_batch_size * args.steps_per_generation``.
    Unused columns are removed from the dataset when it is a
    ``datasets.Dataset``; otherwise the collator is wrapped to drop them.

    :raises ValueError: If ``self.train_dataset`` is ``None``.
    """
    # Preserve native TRL batching/sampling behavior for GRPO/MaxEnt while
    # adapting worker_init_fn to the active transformers seed_worker signature.
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")
    train_dataset = self.train_dataset
    data_collator = self.data_collator
    try:
        from transformers.utils import is_datasets_available
    except (
        Exception
    ):  # pragma: no cover - transformers is required for training

        # Fallback probe: report ``datasets`` as unavailable.
        def is_datasets_available() -> bool:
            return False

    if is_datasets_available():
        try:
            import datasets
        except Exception:
            datasets = None  # type: ignore
        if datasets is not None and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(
                train_dataset, description="training"
            )
        else:
            data_collator = self._get_collator_with_removed_columns(
                data_collator, description="training"
            )
    else:
        data_collator = self._get_collator_with_removed_columns(
            data_collator, description="training"
        )
    dataloader_params = {
        # One dataloader batch covers all microsteps of a generation round.
        "batch_size": self._train_batch_size * self.args.steps_per_generation,
        "collate_fn": data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
        "persistent_workers": self.args.dataloader_persistent_workers,
    }
    # NOTE(review): the extracted source lost its indentation; the statements
    # below are assumed to sit inside this branch - confirm against the repo.
    if not isinstance(train_dataset, torch.utils.data.IterableDataset):
        dataloader_params["sampler"] = self._get_train_sampler()
        dataloader_params["drop_last"] = self.args.dataloader_drop_last
        dataloader_params["prefetch_factor"] = (
            self.args.dataloader_prefetch_factor
        )
        worker_init_fn = _build_seed_worker(
            self.args.dataloader_num_workers, self.args.process_index
        )
        # ``_build_seed_worker`` may return ``None``; only set the hook when
        # one was built.
        if worker_init_fn is not None:
            dataloader_params["worker_init_fn"] = worker_init_fn
    return self.accelerator.prepare(
        torch.utils.data.DataLoader(train_dataset, **dataloader_params)
    )
def get_eval_dataloader(self, eval_dataset: Optional[Any] = None):  # type: ignore[override]
    """Return the eval dataloader, using a local loader for greedy-only eval.

    When ``_use_lightweight_greedy_eval(self, "eval")`` is falsy the parent
    implementation is used. Otherwise a plain ``DataLoader`` is built that is
    not passed through ``accelerator.prepare``; map-style datasets get a
    non-shuffling ``DistributedSampler`` when
    ``_use_sharded_prompt_major_greedy_eval`` is truthy and a
    ``SequentialSampler`` otherwise.

    :param eval_dataset: Dataset to evaluate; defaults to ``self.eval_dataset``.
    :raises ValueError: If lightweight eval is active and no dataset is set.
    """
    eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
    lightweight_eval = _use_lightweight_greedy_eval(self, "eval")
    if not lightweight_eval:
        # Record that neither custom loader variant is in use.
        setattr(self, "_local_only_eval_prompt_major_loader_active", False)
        setattr(self, "_sharded_eval_prompt_major_loader_active", False)
        return super().get_eval_dataloader(eval_dataset)
    if eval_dataset is None:
        raise ValueError("Trainer: evaluation requires an eval_dataset.")
    sharded_eval = _use_sharded_prompt_major_greedy_eval(self, "eval")
    setattr(self, "_local_only_eval_prompt_major_loader_active", True)
    setattr(self, "_sharded_eval_prompt_major_loader_active", bool(sharded_eval))
    prompt_major_dataset = eval_dataset
    data_collator = self.data_collator
    try:
        from transformers.utils import is_datasets_available
    except Exception:  # pragma: no cover - transformers is required for training

        # Fallback probe: report ``datasets`` as unavailable.
        def is_datasets_available() -> bool:
            return False

    if is_datasets_available():
        try:
            import datasets
        except Exception:
            datasets = None  # type: ignore
        if datasets is not None and isinstance(
            prompt_major_dataset, datasets.Dataset
        ):
            prompt_major_dataset = self._remove_unused_columns(
                prompt_major_dataset,
                description="evaluation",
            )
        else:
            # Column trimming via the collator is optional here: it is only
            # applied when the trainer exposes the helper.
            trim_collator = getattr(
                self,
                "_get_collator_with_removed_columns",
                None,
            )
            if callable(trim_collator):
                data_collator = trim_collator(
                    data_collator,
                    description="evaluation",
                )
    else:
        trim_collator = getattr(
            self,
            "_get_collator_with_removed_columns",
            None,
        )
        if callable(trim_collator):
            data_collator = trim_collator(
                data_collator,
                description="evaluation",
            )
    dataloader_params = {
        # First truthy value among the two eval batch-size settings, else 1.
        "batch_size": int(
            getattr(self.args, "per_device_eval_batch_size", 0)
            or getattr(self.args, "eval_batch_size", 0)
            or 1
        ),
        "collate_fn": data_collator,
        "num_workers": int(getattr(self.args, "dataloader_num_workers", 0) or 0),
        "pin_memory": bool(
            getattr(self.args, "dataloader_pin_memory", False)
        ),
        "persistent_workers": bool(
            getattr(self.args, "dataloader_persistent_workers", False)
        ),
    }
    # NOTE(review): the extracted source lost its indentation; the nesting of
    # the statements from here to the return is a best guess - confirm
    # against the repo.
    if not isinstance(prompt_major_dataset, torch.utils.data.IterableDataset):
        if sharded_eval:
            accelerator = getattr(self, "accelerator", None)
            # Unparseable or missing values fall back to a single process.
            try:
                num_processes = int(
                    getattr(accelerator, "num_processes", 1) or 1
                )
            except (TypeError, ValueError):
                num_processes = 1
            try:
                process_index = int(
                    getattr(accelerator, "process_index", 0) or 0
                )
            except (TypeError, ValueError):
                process_index = 0
            dataloader_params["sampler"] = (
                torch.utils.data.distributed.DistributedSampler(
                    prompt_major_dataset,
                    num_replicas=max(num_processes, 1),
                    rank=max(process_index, 0),
                    shuffle=False,
                    drop_last=False,
                )
            )
        else:
            dataloader_params["sampler"] = torch.utils.data.SequentialSampler(
                prompt_major_dataset
            )
        dataloader_params["drop_last"] = False
    if dataloader_params["num_workers"] > 0:
        prefetch = getattr(self.args, "dataloader_prefetch_factor", None)
        if prefetch is not None:
            dataloader_params["prefetch_factor"] = int(prefetch)
    worker_init_fn = _build_seed_worker(
        dataloader_params["num_workers"],
        int(getattr(self.args, "process_index", 0) or 0),
    )
    if worker_init_fn is not None:
        dataloader_params["worker_init_fn"] = worker_init_fn
    # Keep this dataloader out of ``accelerator.prepare``, which would shard
    # or wrap it again; any sharding across ranks is already expressed by the
    # sampler chosen above.
    return torch.utils.data.DataLoader(prompt_major_dataset, **dataloader_params)
def _prepare_inputs(self, generation_batch: Any) -> Any:  # type: ignore[override]
    """Prepare model inputs, buffering listwise rollouts across microsteps.

    Routes that do not use the listwise loss defer to the parent. For the
    listwise loss in training mode, a fresh rollout is generated every
    ``steps_per_generation * num_iterations`` calls (or when nothing is
    buffered); its q targets are normalized, prompt groups are shuffled, and
    the result is split into ``steps_per_generation`` chunks that are served
    one per call. In eval mode a rollout is generated on every call.

    :raises ValueError: If a training rollout lacks a tensor under
        ``"maxent_listwise_q"``.
    """
    if not self.objective_routing.uses_listwise_loss:
        return super()._prepare_inputs(generation_batch)
    mode = "train" if self.model.training else "eval"
    if mode == "train":
        generate_every = self.args.steps_per_generation * self.num_iterations
        if self._step % generate_every == 0 or self._buffered_inputs is None:
            generated = self._generate_and_score_completions(generation_batch)
            group_size = max(int(getattr(self, "num_generations", 1) or 1), 1)
            steps_per_generation = int(
                getattr(self.args, "steps_per_generation", 1) or 1
            )
            _, num_prompts = _resolve_prompt_group_sizes(generated, group_size)
            q_targets = generated.get("maxent_listwise_q")
            if not isinstance(q_targets, torch.Tensor):
                raise ValueError(
                    "Listwise MaxEnt rollout generation must provide maxent_listwise_q targets."
                )
            # Validate and renormalize before anything is reordered.
            generated["maxent_listwise_q"] = _normalize_listwise_q_targets(
                q_targets,
                num_prompts=num_prompts,
                group_size=group_size,
                context="Listwise MaxEnt rollout generation",
            )
            generated = _shuffle_listwise_tensor_dict(
                generated,
                group_size,
            )
            # Warn once per trainer when the split below will have to fall
            # back to reusing the whole batch.
            if (
                num_prompts % steps_per_generation != 0
                and not bool(
                    getattr(self, "_listwise_batch_reuse_warned", False)
                )
            ):
                LOG.warning(
                    "Listwise MaxEnt local rollout prompt groups (%d) do not divide "
                    "steps_per_generation (%d); reusing the full local listwise batch "
                    "across microsteps with a per-microstep loss scale of 1/%d. "
                    "Increase the local prompt-group count to avoid this fallback.",
                    num_prompts,
                    steps_per_generation,
                    steps_per_generation,
                )
                setattr(self, "_listwise_batch_reuse_warned", True)
            self._buffered_inputs = _split_listwise_tensor_dict(
                generated,
                steps_per_generation,
                group_size,
            )
        # Serve the chunk for this microstep, then advance the step counter.
        inputs = self._buffered_inputs[
            self._step % int(getattr(self.args, "steps_per_generation", 1) or 1)
        ]
        self._step += 1
        return inputs
    return self._generate_and_score_completions(generation_batch)
def _append_metric_value(
    self,
    mode: str,
    key: str,
    value: Any,
    *,
    include_legacy_aliases: bool = True,
) -> None:
    """Append one numeric sample to the metric store for ``mode``.

    Values that ``_numeric_or_none`` rejects are ignored. The key has its
    mode prefix stripped and is canonicalized before storage. The
    ``num_tokens`` metric is overwritten with a single sample instead of
    appended. A training ``kl`` sample is also remembered on
    ``self._last_train_kl_for_alpha``.

    :param include_legacy_aliases: Also record the sample under every alias
        returned by ``_legacy_metric_aliases``.
    """
    numeric = _numeric_or_none(value)
    if numeric is None:
        return
    canonical = _canonical_metric_key(_strip_mode_prefix(str(key), mode))
    store = self._metrics[mode]
    overwrite = canonical == "num_tokens"

    def _record(name: str) -> None:
        if overwrite:
            store[name] = [numeric]
        else:
            store.setdefault(name, []).append(numeric)

    _record(canonical)
    if mode == "train" and canonical == "kl":
        self._last_train_kl_for_alpha = float(numeric)
    if include_legacy_aliases:
        for alias in _legacy_metric_aliases(canonical):
            if alias != canonical:
                _record(alias)
def _set_latest_metric_value(
    self,
    mode: str,
    key: str,
    value: Any,
    *,
    include_legacy_aliases: bool = True,
) -> None:
    """Overwrite the newest sample stored for a metric, appending if absent.

    The key is normalized the same way as in ``_append_metric_value`` (mode
    prefix stripped, then canonicalized). Non-numeric values are ignored. The
    ``num_tokens`` bucket is always replaced by a single-element list. A
    train-mode ``kl`` value is also cached on ``_last_train_kl_for_alpha``.

    Args:
        mode: Metric namespace key into ``self._metrics``.
        key: Metric name, with or without the mode prefix.
        value: Candidate sample; skipped when it is not numeric.
        include_legacy_aliases: When true, the same overwrite is applied to
            every alias returned by ``_legacy_metric_aliases``.
    """
    sample = _numeric_or_none(value)
    if sample is None:
        return
    metric_name = _canonical_metric_key(_strip_mode_prefix(str(key), mode))
    buckets = self._metrics[mode]
    replace_history = metric_name == "num_tokens"

    def _store(name: str) -> None:
        if replace_history:
            buckets[name] = [sample]
            return
        history = buckets.setdefault(name, [])
        if history:
            # Replace the tail so the bucket length stays unchanged.
            history[-1] = sample
        else:
            history.append(sample)

    _store(metric_name)
    if mode == "train" and metric_name == "kl":
        setattr(self, "_last_train_kl_for_alpha", float(sample))
    if include_legacy_aliases:
        for alias in _legacy_metric_aliases(metric_name):
            if alias != metric_name:
                _store(alias)
def _recompute_completion_metrics(
self,
outputs: Dict[str, Any],
*,
mode: str,
) -> None:
"""Overwrite TRL completion metrics with correctly gathered values."""
completion_ids = outputs.get("completion_ids")
completion_mask = outputs.get("completion_mask")
if not isinstance(completion_ids, torch.Tensor) or not isinstance(
completion_mask, torch.Tensor
):
return
try:
completion_lengths = completion_mask.sum(dim=1).to(torch.float32)
except Exception:
return
gathered_lengths = _metric_tensor_for_logging(
self,
completion_lengths,
mode=mode,
)
if not isinstance(gathered_lengths, torch.Tensor) or gathered_lengths.numel() <= 0:
return
self._set_latest_metric_value(
mode,
"completions/mean_length",
float(gathered_lengths.mean().item()),
)
self._set_latest_metric_value(
mode,
"completions/min_length",
float(gathered_lengths.min().item()),
)
self._set_latest_metric_value(
mode,
"completions/max_length",
float(gathered_lengths.max().item()),
)
terminated_mask: Optional[torch.Tensor] = None
eos_token_id = _coerce_optional_int(
getattr(getattr(self, "processing_class", None), "eos_token_id", None)
)
if eos_token_id is not None:
try:
active_eos = (completion_ids == int(eos_token_id)) & completion_mask.to(
dtype=torch.bool
)
terminated_mask = active_eos.any(dim=1)
except Exception:
terminated_mask = None
max_completion_length = int(
getattr(self, "max_completion_length", 0)
or getattr(getattr(self, "args", None), "max_completion_length", 0)
or 0
)
if max_completion_length > 0:
try:
shorter_than_cap = completion_lengths.to(torch.long) < int(
max_completion_length
)
terminated_mask = (
shorter_than_cap
if terminated_mask is None
else (terminated_mask | shorter_than_cap)
)
except Exception:
pass
if not isinstance(terminated_mask, torch.Tensor):
return
gathered_terminated = _metric_tensor_for_logging(
self,
terminated_mask.to(torch.bool),
mode=mode,
)
if (
not isinstance(gathered_terminated, torch.Tensor)
or gathered_terminated.numel() <= 0
):
return
total = int(min(gathered_lengths.numel(), gathered_terminated.numel()))
if total <= 0:
return
gathered_lengths = gathered_lengths[:total]
gathered_terminated = gathered_terminated[:total].to(torch.bool)
term_completion_lengths = gathered_lengths[gathered_terminated]
clipped_ratio = 1.0 - (
float(term_completion_lengths.numel()) / float(max(total, 1))
)
self._set_latest_metric_value(
mode,
"completions/clipped_ratio",
clipped_ratio,
)
if term_completion_lengths.numel() <= 0:
zero_val = 0.0
self._set_latest_metric_value(
mode,
"completions/mean_terminated_length",
zero_val,
)
self._set_latest_metric_value(
mode,
"completions/min_terminated_length",
zero_val,
)
self._set_latest_metric_value(
mode,
"completions/max_terminated_length",
zero_val,
)
return
self._set_latest_metric_value(
mode,
"completions/mean_terminated_length",
float(term_completion_lengths.mean().item()),
)
self._set_latest_metric_value(
mode,
"completions/min_terminated_length",
float(term_completion_lengths.min().item()),
)
self._set_latest_metric_value(
mode,
"completions/max_terminated_length",
float(term_completion_lengths.max().item()),
)
def _log_grpo_debug(
self,
inputs: List[Dict[str, Any]],
outputs: Dict[str, Any],
*,
mode: str,
) -> None:
if mode != "train":
return
step = int(getattr(self.state, "global_step", 0))
last_logged_step = getattr(self, "_last_grpo_debug_step", None)
if last_logged_step == step:
return
self._last_grpo_debug_step = step
completion_mask = outputs.get("completion_mask")
advantages = outputs.get("advantages")
completion_ids = outputs.get("completion_ids")
token_mask_sum = None
completion_length_mean = None
if isinstance(completion_mask, torch.Tensor):
try:
completion_lengths = completion_mask.sum(1)
agg_lengths = self.accelerator.gather(completion_lengths)
completion_length_mean = float(agg_lengths.float().mean().item())
token_mask_sum = float(agg_lengths.sum().item())
except Exception:
completion_length_mean = None
token_mask_sum = None
advantages_std = None
if isinstance(advantages, torch.Tensor):
try:
agg_adv = self.accelerator.gather(advantages)
advantages_std = float(agg_adv.float().std().item())
except Exception:
advantages_std = None
reward_std = None
try:
reward_history = self._metrics.get(mode, {}).get("reward_std")
if reward_history:
reward_std = float(reward_history[-1])
except Exception:
reward_std = None
local_expected = len(inputs)
local_actual = (
int(completion_ids.shape[0])
if isinstance(completion_ids, torch.Tensor)
else local_expected
)
try:
counts = torch.tensor(
[local_expected, local_actual],
device=self.accelerator.device,
dtype=torch.long,
)
agg_counts = self.accelerator.gather(counts)
expected_total = int(agg_counts[0::2].sum().item())
actual_total = int(agg_counts[1::2].sum().item())
except Exception:
expected_total = local_expected
actual_total = local_actual
dropped_total = max(expected_total - actual_total, 0)
if self.accelerator.is_main_process:
LOG.info(
"GRPO debug | step=%d | token_mask_sum=%s | completion_length_mean=%s | "
"advantages_std=%s | reward_std=%s | num_sequences=%d | dropped_groups=%d",
step,
token_mask_sum,
completion_length_mean,
advantages_std,
reward_std,
expected_total,
dropped_total,
)
self._maybe_update_grpo_beta(mode)
def _maybe_update_grpo_beta(self, mode: str) -> None:
if self.maxent_enabled:
return
if getattr(self, "_maxent_controller_objective", None) is not None:
return
args = getattr(self, "args", None)
if args is None:
return
if not bool(getattr(args, "grpo_beta_controller_enabled", False)):
return
kl_target = float(getattr(args, "kl_target", 0.0) or 0.0)
kl_horizon = int(getattr(args, "kl_horizon", 0) or 0)
kl_ctl_step_size = float(getattr(args, "kl_ctl_step_size", 0.0) or 0.0)
if kl_target <= 0.0 or kl_horizon <= 0 or kl_ctl_step_size <= 0.0:
return
kl_history = self._metrics.get(mode, {}).get("kl")
if not kl_history:
return
try:
measured_kl = float(kl_history[-1])
except (TypeError, ValueError):
return
if not math.isfinite(measured_kl):
return
current_beta = float(getattr(self, "beta", 0.0) or 0.0)
if current_beta <= 0.0:
return
ratio = measured_kl / max(kl_target, 1e-8)
error = ratio - 1.0
if abs(error) < 1e-8:
return
limit = kl_ctl_step_size
clipped_error = max(min(error, limit), -limit)
horizon = max(1, kl_horizon)
scale = 1.0 + clipped_error / float(horizon)
if scale <= 0.0:
scale = 1e-6
new_beta = max(0.0, current_beta * scale)
self.beta = new_beta # pylint: disable=attribute-defined-outside-init
def _maybe_update_reference_model_ema(self) -> None:
"""Soft-update frozen reference weights from the current policy weights."""
args = getattr(self, "args", None)
if args is None:
return
if not bool(getattr(self.model, "training", False)):
return
if not bool(getattr(args, "maxent_reference_ema_enabled", False)):
return
if bool(getattr(args, "maxent_share_reference_model", False)):
if not bool(getattr(self, "_maxent_ref_ema_share_warned", False)):
LOG.warning(
"Reference EMA requested but maxent_share_reference_model=true; skipping EMA updates."
)
setattr(self, "_maxent_ref_ema_share_warned", True)
return
ref_model = getattr(self, "ref_model", None)
if ref_model is None:
if not bool(getattr(self, "_maxent_ref_ema_missing_warned", False)):
LOG.warning(
"Reference EMA requested but no frozen reference model is available; skipping EMA updates."
)
setattr(self, "_maxent_ref_ema_missing_warned", True)
return
step = int(getattr(self.state, "global_step", 0) or 0)
if step <= 0:
return
if self._last_reference_ema_step == step:
return
warmup_raw = getattr(args, "maxent_reference_ema_warmup_steps", 0)
interval_raw = getattr(args, "maxent_reference_ema_update_interval", 1)
beta_raw = getattr(args, "maxent_reference_ema_beta", 0.995)
try:
warmup_steps = int(warmup_raw)
except (TypeError, ValueError):
warmup_steps = 0
if warmup_steps < 0:
warmup_steps = 0
if step < warmup_steps:
return
try:
update_interval = int(interval_raw)
except (TypeError, ValueError):
update_interval = 1
if update_interval < 1:
update_interval = 1
if (step - warmup_steps) % update_interval != 0:
return
beta = _numeric_or_none(beta_raw)
if beta is None or not math.isfinite(beta):
beta = 0.995
beta = min(max(float(beta), 0.0), 1.0)
alpha = 1.0 - beta
if alpha <= 0.0:
return
unwrap_fn = getattr(self.accelerator, "unwrap_model", None)
policy_model = self.model
if callable(unwrap_fn):
try:
policy_model = unwrap_fn(policy_model)
except Exception:
policy_model = self.model
try:
ref_model = unwrap_fn(ref_model)
except Exception:
ref_model = getattr(self, "ref_model", None)
if ref_model is None:
return
if ref_model is policy_model:
if not bool(getattr(self, "_maxent_ref_ema_alias_warned", False)):
LOG.warning(
"Reference EMA requested but reference model aliases the policy model; skipping EMA updates."
)
setattr(self, "_maxent_ref_ema_alias_warned", True)
return
policy_named = getattr(policy_model, "named_parameters", None)
ref_named = getattr(ref_model, "named_parameters", None)
if not callable(policy_named) or not callable(ref_named):
return
try:
policy_named_fn = cast(
Callable[[], Iterable[tuple[str, Any]]], policy_named
)
ref_named_fn = cast(Callable[[], Iterable[tuple[str, Any]]], ref_named)
policy_params = {
str(name): param
for name, param in policy_named_fn()
if isinstance(param, torch.Tensor)
}
ref_params = {
str(name): param
for name, param in ref_named_fn() # pylint: disable=not-callable
if isinstance(param, torch.Tensor)
}
if not policy_params or not ref_params:
return
policy_alias_index = _build_ema_alias_index(policy_params)
total_ref = len(ref_params)
updated = 0
mismatched = 0
alias_hits = 0
mismatch_examples: List[str] = []
with torch.no_grad():
for name, ref_param in ref_params.items():
src_param, alias_used = _resolve_ema_source_param(
name,
ref_param,
policy_params,
policy_alias_index,
)
if not isinstance(src_param, torch.Tensor):
mismatched += 1
if len(mismatch_examples) < 5:
mismatch_examples.append(name)
continue
src_tensor = src_param.detach().to(
device=ref_param.device, dtype=ref_param.dtype
)
ref_param.data.mul_(beta).add_(src_tensor, alpha=alpha)
updated += 1
if alias_used:
alias_hits += 1
except Exception as exc:
if not bool(getattr(self, "_maxent_ref_ema_error_warned", False)):
LOG.warning(
"Reference EMA update failed once and will be retried on later steps: %s",
exc,
)
setattr(self, "_maxent_ref_ema_error_warned", True)
return
if updated <= 0:
if not bool(getattr(self, "_maxent_ref_ema_no_params_warned", False)):
LOG.warning(
"Reference EMA enabled but no compatible parameters were updated."
)
setattr(self, "_maxent_ref_ema_no_params_warned", True)
return
self._last_reference_ema_step = step
self._append_metric_value("train", "maxent/ref_ema_applied", 1.0)
self._append_metric_value("train", "maxent/ref_ema_beta", beta)
self._append_metric_value(
"train",
"maxent/ref_ema_updated_frac",
float(updated) / float(max(total_ref, 1)),
)
self._append_metric_value(
"train",
"maxent/ref_ema_alias_hit_frac",
float(alias_hits) / float(max(total_ref, 1)),
)
if mismatched > 0 and not bool(
getattr(self, "_maxent_ref_ema_mismatch_warned", False)
):
LOG.warning(
"Reference EMA skipped %d/%d reference parameters due to missing/mismatched policy counterparts. "
"sample_missing=%s",
mismatched,
total_ref,
mismatch_examples,
)
setattr(self, "_maxent_ref_ema_mismatch_warned", True)
def _resolve_effective_maxent_alpha(
self,
mode: str,
*,
measured_kl_override: Optional[float] = None,
) -> Tuple[float, float, Optional[float], float, bool, float, float, float, bool]:
"""Return effective MaxEnt alpha with optional KL-based up/down scaling.
Returns ``(effective_alpha, multiplier, measured_kl, kl_threshold,
kl_control_enabled, direction, min_multiplier, max_multiplier,
trust_zone_blocked)`` where direction is ``+1`` (raised), ``-1``
(lowered/blocked), or ``0`` (unchanged).
"""
del mode
base_alpha = float(getattr(self, "maxent_alpha", 0.0) or 0.0)
if base_alpha <= 0.0:
return 0.0, 1.0, None, 0.0, False, 0.0, 1.0, 1.0, False
args = getattr(self, "args", None)
raise_on_low_kl = bool(getattr(args, "maxent_alpha_raise_on_low_kl", False))
lower_on_high_kl = bool(
getattr(args, "maxent_alpha_lower_on_high_kl", False)
)
trust_zone_gate_enabled = bool(
getattr(args, "maxent_alpha_disable_outside_trust_zone", False)
)
enabled = raise_on_low_kl or lower_on_high_kl or trust_zone_gate_enabled
threshold_raw = getattr(args, "maxent_alpha_kl_threshold", 0.04)
try:
threshold = float(threshold_raw)
except (TypeError, ValueError):
threshold = 0.04
max_mult_raw = getattr(args, "maxent_alpha_kl_max_multiplier", 2.0)
try:
max_multiplier = float(max_mult_raw)
except (TypeError, ValueError):
max_multiplier = 2.0
if not math.isfinite(max_multiplier) or max_multiplier < 1.0:
max_multiplier = 1.0
min_mult_raw = getattr(args, "maxent_alpha_kl_min_multiplier", 0.5)
try:
min_multiplier = float(min_mult_raw)
except (TypeError, ValueError):
min_multiplier = 0.5
if not math.isfinite(min_multiplier) or min_multiplier <= 0.0:
min_multiplier = 0.5
min_multiplier = min(max(min_multiplier, 1e-8), 1.0)
if not math.isfinite(threshold) or threshold <= 0.0:
return (
base_alpha,
1.0,
None,
threshold,
enabled,
0.0,
min_multiplier,
max_multiplier,
False,
)
if not enabled:
return (
base_alpha,
1.0,
None,
threshold,
False,
0.0,
min_multiplier,
max_multiplier,
False,
)
measured_kl: Optional[float] = None
if isinstance(measured_kl_override, (int, float)):
measured_kl = float(measured_kl_override)
else:
cached_kl = getattr(self, "_last_train_kl_for_alpha", None)
if isinstance(cached_kl, (int, float)):
measured_kl = float(cached_kl)
else:
kl_history = self._metrics.get("train", {}).get("kl")
if kl_history:
try:
measured_kl = float(kl_history[-1])
except (TypeError, ValueError):
measured_kl = None
if measured_kl is None:
return (
base_alpha,
1.0,
None,
threshold,
True,
0.0,
min_multiplier,
max_multiplier,
False,
)
if not math.isfinite(measured_kl):
if trust_zone_gate_enabled:
return (
0.0,
0.0,
measured_kl,
threshold,
True,
-1.0,
min_multiplier,
max_multiplier,
True,
)
if lower_on_high_kl:
return (
base_alpha * min_multiplier,
min_multiplier,
measured_kl,
threshold,
True,
-1.0,
min_multiplier,
max_multiplier,
False,
)
return (
base_alpha,
1.0,
measured_kl,
threshold,
True,
0.0,
min_multiplier,
max_multiplier,
False,
)
gain_raw = getattr(args, "maxent_alpha_kl_gain", 1.0)
try:
gain = float(gain_raw)
except (TypeError, ValueError):
gain = 1.0
if not math.isfinite(gain) or gain < 0.0:
gain = 0.0
direction = 0.0
multiplier = 1.0
trust_zone_blocked = False
if measured_kl < threshold and raise_on_low_kl:
low_kl_frac = max(threshold - measured_kl, 0.0) / max(threshold, 1e-8)
multiplier = 1.0 + gain * low_kl_frac
direction = 1.0
elif measured_kl > threshold and trust_zone_gate_enabled:
multiplier = 0.0
direction = -1.0
trust_zone_blocked = True
elif measured_kl > threshold and lower_on_high_kl:
high_kl_frac = max(measured_kl - threshold, 0.0) / max(threshold, 1e-8)
multiplier = 1.0 / (1.0 + gain * high_kl_frac)
direction = -1.0
if not math.isfinite(multiplier):
multiplier = 1.0
direction = 0.0
trust_zone_blocked = False
if trust_zone_blocked:
effective_alpha = 0.0
else:
multiplier = min(max(multiplier, min_multiplier), max_multiplier)
effective_alpha = base_alpha * multiplier
return (
effective_alpha,
multiplier,
measured_kl,
threshold,
True,
direction,
min_multiplier,
max_multiplier,
trust_zone_blocked,
)
def _log_grpo_diversity(
self,
outputs: Dict[str, Any],
*,
mode: str,
) -> None:
completion_ids = outputs.get("completion_ids")
if not isinstance(completion_ids, torch.Tensor):
return
tokenizer = getattr(self, "processing_class", None)
decode = getattr(tokenizer, "batch_decode", None)
if not callable(decode):
return
try:
decode_fn = cast(Callable[..., List[str]], decode)
completions_text = decode_fn( # pylint: disable=not-callable
completion_ids, skip_special_tokens=True
)
except Exception:
return
group_size = max(int(getattr(self, "num_generations", 1) or 1), 1)
usable = len(completions_text) - (len(completions_text) % group_size)
if usable <= 0:
return
if usable != len(completions_text):
completions_text = completions_text[:usable]
grouped = [
completions_text[i : i + group_size]
for i in range(0, usable, group_size)
]
local_only_eval = (
_use_local_only_lightweight_eval_metrics(self, mode)
or _use_local_only_eval_diversity_metrics(self, mode)
)
if local_only_eval and not _is_main_process(self):
return
use_tokenizer = (
tokenizer
if callable(getattr(tokenizer, "encode", None)) or callable(tokenizer)
else None
)
metrics = _completion_diversity_metrics(
grouped,
tokenizer=use_tokenizer,
accelerator=None if local_only_eval else self.accelerator,
)
if metrics:
for key, val in metrics.items():
self._append_metric_value(
mode,
f"completions/diversity/{key}",
float(val),
)
def _maybe_log_rich_rollout_sidecar(
self,
inputs: List[Dict[str, Any]],
outputs: Dict[str, Any],
*,
mode: str,
) -> None:
"""Write prompt-major rollout rows for distribution figures."""
if mode != "train":
return
args = getattr(self, "args", None)
if not bool(getattr(args, "rich_log_completions", False)):
return
if not _is_main_process(self):
return
output_dir = getattr(args, "output_dir", None)
if not isinstance(output_dir, str) or not output_dir.strip():
return
completion_ids = outputs.get("completion_ids")
completion_mask = outputs.get("completion_mask")
advantages = outputs.get("advantages")
tokenizer = getattr(self, "processing_class", None)
decode = getattr(tokenizer, "decode", None)
if not isinstance(completion_ids, torch.Tensor) or not isinstance(
completion_mask, torch.Tensor
):
return
if not isinstance(advantages, torch.Tensor):
return
if not callable(decode):
return
rewards_local = self._recompute_local_rewards_for_outputs(inputs, outputs)
if not isinstance(rewards_local, torch.Tensor) or rewards_local.numel() <= 0:
return
total_rows = min(
int(completion_ids.size(0)),
int(completion_mask.size(0)),
int(advantages.numel()),
int(rewards_local.numel()),
len(inputs),
)
if total_rows <= 0:
return
prompt_texts = [
_build_prompt_text(example, tokenizer)
for example in inputs[:total_rows]
]
completion_texts: List[str] = []
for row_idx in range(total_rows):
mask_row = completion_mask[row_idx].to(torch.long)
active_ids = [
int(tok.item())
for tok, keep in zip(completion_ids[row_idx], mask_row)
if int(keep.item()) != 0
]
try:
completion_text = str(decode(active_ids, skip_special_tokens=True))
except Exception:
completion_text = ""
completion_texts.append(completion_text)
q_grouped = outputs.get("maxent_listwise_q")
q_values: Optional[List[float]] = None
if isinstance(q_grouped, torch.Tensor) and q_grouped.numel() > 0:
try:
q_values = [
float(val)
for val in q_grouped.detach().to(torch.float32).reshape(-1).tolist()
][:total_rows]
except Exception:
q_values = None
step_value = int(getattr(getattr(self, "state", None), "global_step", 0) or 0)
if mode == "train":
step_value += 1
columns, rows = _build_rich_rollout_rows(
step=step_value,
group_size=max(int(getattr(self, "num_generations", 1) or 1), 1),
prompt_texts=prompt_texts,
completion_texts=completion_texts,
rewards=[
float(val)
for val in rewards_local[:total_rows].detach().to(torch.float32).tolist()
],
advantages=[
float(val)
for val in advantages[:total_rows].detach().to(torch.float32).tolist()
],
q_values=q_values,
)
if not rows:
return
table_key = str(
getattr(args, "rich_log_completions_key", "rich_completions")
or "rich_completions"
).strip()
path = _write_rich_rollout_sidecar(
output_dir=output_dir.strip(),
table_key=table_key,
step=step_value,
columns=columns,
rows=rows,
)
if path:
LOG.info(
"Wrote rich rollout sidecar | step=%d rows=%d path=%s",
step_value,
len(rows),
path,
)
def _recompute_grouped_advantages(
self,
inputs: List[Dict[str, Any]],
outputs: Dict[str, Any],
*,
group_size: Optional[int] = None,
) -> Optional[torch.Tensor]:
"""Recompute local GRPO-style advantages after completion postprocessing."""
rewards_local = self._recompute_local_rewards_for_outputs(inputs, outputs)
if not isinstance(rewards_local, torch.Tensor) or rewards_local.numel() <= 0:
return None
rewards = gather(rewards_local)
if not isinstance(rewards, torch.Tensor) or rewards.numel() <= 0:
return None
effective_group = max(
int(group_size or getattr(self, "num_generations", 1) or 1),
1,
)
if int(rewards.numel()) % effective_group != 0:
return None
grouped_rewards = rewards.view(-1, effective_group)
mean_grouped_rewards = grouped_rewards.mean(dim=1)
std_grouped_rewards = grouped_rewards.std(dim=1)
repeated_means = mean_grouped_rewards.repeat_interleave(
effective_group, dim=0
)
repeated_stds = std_grouped_rewards.repeat_interleave(
effective_group, dim=0
)
advantages = rewards - repeated_means
if bool(getattr(self, "scale_rewards", False)):
advantages = advantages / (repeated_stds + 1e-4)
local_count = len(inputs)
process_index = int(getattr(self.accelerator, "process_index", 0) or 0)
process_slice = slice(
process_index * local_count,
(process_index + 1) * local_count,
)
outputs["advantages"] = advantages[process_slice].to(
device=rewards_local.device,
dtype=torch.float32,
)
return rewards_local
def _maybe_backfill_old_per_token_logps(
self,
outputs: Dict[str, Any],
*,
mode: str,
) -> None:
"""Populate rollout behavior log-probs when the parent TRL path omits them.
TRL intentionally skips ``old_per_token_logps`` when the current rollout can
reuse the policy log-probs in the loss. SEED-GRPO needs those rollout
log-probs earlier, during advantage scaling, so backfill them from the
current policy before any truncation/postprocessing changes the sequence.
"""
if mode != "train":
return
if isinstance(outputs.get("old_per_token_logps"), torch.Tensor):
return
args = getattr(self, "args", None)
if not bool(getattr(args, "seed_grpo_enabled", False)):
return
prompt_ids = outputs.get("prompt_ids")
prompt_mask = outputs.get("prompt_mask")
completion_ids = outputs.get("completion_ids")
completion_mask = outputs.get("completion_mask")
if not isinstance(prompt_ids, torch.Tensor) or not isinstance(
prompt_mask, torch.Tensor
):
return
if not isinstance(completion_ids, torch.Tensor) or not isinstance(
completion_mask, torch.Tensor
):
return
logits_to_keep = int(completion_ids.size(1))
if logits_to_keep <= 0:
return
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
configured_batch_size = int(
getattr(args, "per_device_train_batch_size", 1) or 1
)
chunk_size = int(
getattr(args, "maxent_logprob_chunk_size", 0)
or configured_batch_size
or 1
)
behavior_source = str(
getattr(args, "behavior_logprobs_source", "model") or "model"
).strip().lower()
if behavior_source not in {"", "model"} and not bool(
getattr(self, "_seed_grpo_behavior_source_warned", False)
):
LOG.warning(
"SEED-GRPO requested behavior_logprobs_source=%s, but the shared "
"trainer rollout did not return per-token behavior log-probs. "
"Recomputing them from the current policy for rollout parity.",
behavior_source,
)
setattr(self, "_seed_grpo_behavior_source_warned", True)
if not bool(getattr(self, "_seed_grpo_backfill_preflight_logged", False)):
tokenizer = getattr(self, "processing_class", None)
upper_bound = _resolve_token_id_upper_bound(
getattr(self, "model", None),
tokenizer,
)
def _token_range_stats(
tensor: torch.Tensor,
) -> Tuple[Optional[int], Optional[int], int]:
try:
min_token = int(tensor.min().item())
max_token = int(tensor.max().item())
except Exception:
min_token = None
max_token = None
invalid_count = 0
if isinstance(upper_bound, int) and upper_bound > 0:
try:
invalid_count = int(
((tensor < 0) | (tensor >= upper_bound))
.to(torch.long)
.sum()
.item()
)
except Exception:
invalid_count = 0
return min_token, max_token, invalid_count
prompt_min, prompt_max, prompt_invalid = _token_range_stats(prompt_ids)
completion_min, completion_max, completion_invalid = _token_range_stats(
completion_ids
)
LOG.info(
"SEED-GRPO backfill preflight | upper_bound=%s | "
"prompt_ids[min=%s max=%s invalid=%d] | "
"completion_ids[min=%s max=%s invalid=%d]",
upper_bound,
prompt_min,
prompt_max,
prompt_invalid,
completion_min,
completion_max,
completion_invalid,
)
setattr(self, "_seed_grpo_backfill_preflight_logged", True)
try:
with torch.no_grad():
old_per_token_logps = self._get_per_token_logps(
self.model,
input_ids,
attention_mask,
logits_to_keep,
batch_size=chunk_size,
)
except Exception as exc:
if not bool(getattr(self, "_seed_grpo_backfill_warned", False)):
LOG.warning(
"SEED-GRPO could not backfill rollout log-prob metadata; "
"falling back to unscaled GRPO advantages: %s",
exc,
)
setattr(self, "_seed_grpo_backfill_warned", True)
return
if not isinstance(old_per_token_logps, torch.Tensor):
return
outputs["old_per_token_logps"] = old_per_token_logps.detach().to(
device=completion_ids.device,
dtype=torch.float32,
)
self._append_metric_value(
"train",
"seed_grpo/behavior_logprobs_backfilled",
1.0,
include_legacy_aliases=False,
)
def _sanitize_rollout_token_ids(
    self,
    outputs: Dict[str, Any],
    *,
    mode: str,
) -> None:
    """Replace rollout token ids that fall outside the model vocab range.

    This keeps the shared rollout pipeline robust when tokenizer/vLLM ids
    exceed the policy model vocab, which otherwise crashes later log-prob
    gathers with CUDA index assertions.

    Ids below zero or at/above the resolved upper bound in ``prompt_ids`` and
    ``completion_ids`` are overwritten with a replacement id. The replacement
    is the tokenizer's pad id, then its eos id, then ``vocab_size - 1``,
    taking the first that lies inside the range. Counts are recorded as train
    metrics and on ``_last_rollout_invalid_token_id_count``. A warning is
    logged once. When ``MAXENT_FATAL_INVALID_ROLLOUT_TOKEN_IDS`` is truthy a
    ``RuntimeError`` is raised after sanitizing. ``mode`` is unused.

    Args:
        outputs: Rollout dict; sanitized tensors replace the originals.
        mode: Accepted for interface parity; ignored.

    Raises:
        RuntimeError: If invalid ids were found and the fatal env flag is set.
    """
    del mode
    setattr(self, "_last_rollout_invalid_token_id_count", 0.0)

    base_model = getattr(self, "model", None)
    unwrap_fn = getattr(getattr(self, "accelerator", None), "unwrap_model", None)
    if callable(unwrap_fn):
        try:
            base_model = unwrap_fn(base_model)
        except Exception:
            base_model = getattr(self, "model", None)
    tokenizer = getattr(self, "processing_class", None)
    vocab_size = _resolve_token_id_upper_bound(base_model, tokenizer)
    if not isinstance(vocab_size, int) or vocab_size <= 0:
        return

    def _in_vocab(candidate: Optional[int]) -> bool:
        return candidate is not None and 0 <= candidate < vocab_size

    # Prefer pad, then eos, and finally the last valid vocab index.
    replacement_id = _coerce_optional_int(getattr(tokenizer, "pad_token_id", None))
    if not _in_vocab(replacement_id):
        replacement_id = _coerce_optional_int(
            getattr(tokenizer, "eos_token_id", None)
        )
    if not _in_vocab(replacement_id):
        replacement_id = max(vocab_size - 1, 0)

    total_invalid = 0
    details: List[str] = []
    for key in ("prompt_ids", "completion_ids"):
        tensor = outputs.get(key)
        if not isinstance(tensor, torch.Tensor):
            continue
        if tensor.dtype.is_floating_point or tensor.dtype == torch.bool:
            continue
        invalid_mask = (tensor < 0) | (tensor >= vocab_size)
        invalid_count = int(invalid_mask.to(torch.long).sum().item())
        if invalid_count <= 0:
            continue
        total_invalid += invalid_count
        try:
            offenders = tensor[invalid_mask]
            min_invalid = int(offenders.min().item())
            max_invalid = int(offenders.max().item())
        except Exception:
            min_invalid = 0
            max_invalid = 0
        # Work on a copy so the caller's original tensor is not mutated.
        sanitized = tensor.clone()
        sanitized[invalid_mask] = int(replacement_id)
        outputs[key] = sanitized
        details.append(
            f"{key}:count={invalid_count}:min={min_invalid}:max={max_invalid}"
        )
    if total_invalid <= 0:
        return

    setattr(self, "_last_rollout_invalid_token_id_count", float(total_invalid))
    self._append_metric_value(
        "train",
        "rollout/invalid_token_id_count",
        float(total_invalid),
        include_legacy_aliases=False,
    )
    self._append_metric_value(
        "train",
        "rollout/invalid_token_id_replacement",
        float(replacement_id),
        include_legacy_aliases=False,
    )
    if not bool(getattr(self, "_invalid_rollout_token_ids_warned", False)):
        LOG.warning(
            "Sanitized %d rollout token ids outside model vocab_size=%d using replacement_id=%d (%s)",
            total_invalid,
            vocab_size,
            replacement_id,
            ", ".join(details),
        )
        setattr(self, "_invalid_rollout_token_ids_warned", True)

    fatal_flag = (
        str(os.getenv("MAXENT_FATAL_INVALID_ROLLOUT_TOKEN_IDS", "0")).strip().lower()
    )
    if fatal_flag not in {"1", "true", "yes", "on"}:
        return
    self._append_metric_value(
        "train",
        "rollout/invalid_token_id_guard_triggered",
        1.0,
        include_legacy_aliases=False,
    )
    raise RuntimeError(
        "Detected rollout token ids outside the tokenizer-addressable "
        f"range (count={total_invalid}, vocab_size={vocab_size}, "
        f"replacement_id={replacement_id})."
    )
def _maybe_apply_seed_grpo_advantages(
    self,
    inputs: List[Dict[str, Any]],
    outputs: Dict[str, Any],
    *,
    mode: str,
    group_size: Optional[int] = None,
) -> None:
    """Apply SEED-GRPO semantic-entropy scaling to prepared advantages.

    The rows of ``outputs`` are split into consecutive groups of
    ``group_size`` (falling back to ``self.num_generations``). For every row
    the mask-selected completion tokens are decoded and paired with the sum
    of the first ``active_len`` rollout log-probs; the grouped texts and
    log-prob metadata are passed to ``_compute_seed_grpo_statistics`` and
    each row's advantage is multiplied by the scale returned for its group.

    On success ``outputs["advantages"]`` is replaced and
    ``outputs["seed_grpo_semantic_entropies"]`` /
    ``outputs["seed_grpo_advantage_scales"]`` are added, and rank-local
    summary metrics are appended. Any failed precondition leaves ``outputs``
    untouched and (where a message exists) logs one warning per failure kind.

    Args:
        inputs: Not read by this method.
        outputs: Batch dict; reads ``advantages``, ``completion_ids``,
            ``completion_mask`` and ``old_per_token_logps``.
        mode: Scaling is only applied when this equals ``"train"``.
        group_size: Optional override for the number of rows per group.
    """
    if mode != "train":
        return
    args = getattr(self, "args", None)
    if not bool(getattr(args, "seed_grpo_enabled", False)):
        return

    def _warn_once(flag_name: str, message: str, *fmt_args: Any) -> None:
        """Log ``message`` unless ``flag_name`` is already set on the trainer."""
        if bool(getattr(self, flag_name, False)):
            return
        LOG.warning(message, *fmt_args)
        setattr(self, flag_name, True)

    advantages = outputs.get("advantages")
    completion_ids = outputs.get("completion_ids")
    completion_mask = outputs.get("completion_mask")
    old_per_token_logps = outputs.get("old_per_token_logps")
    if not isinstance(advantages, torch.Tensor):
        return
    if not isinstance(completion_ids, torch.Tensor):
        return
    if not isinstance(completion_mask, torch.Tensor):
        return
    if not isinstance(old_per_token_logps, torch.Tensor):
        _warn_once(
            "_seed_grpo_missing_logprobs_warned",
            "SEED-GRPO enabled but rollout log-prob metadata is missing; "
            "falling back to unscaled GRPO advantages.",
        )
        return
    if old_per_token_logps.ndim != 2:
        _warn_once(
            "_seed_grpo_logprob_rank_mismatch_warned",
            "SEED-GRPO requires a rank-2 per-token logprob tensor; got shape=%s. "
            "Falling back to unscaled GRPO advantages.",
            getattr(old_per_token_logps, "shape", None),
        )
        return
    logprob_rows = int(old_per_token_logps.size(0))
    completion_rows = int(completion_ids.size(0))
    if logprob_rows != completion_rows:
        _warn_once(
            "_seed_grpo_logprob_row_mismatch_warned",
            "SEED-GRPO requires rollout logprobs aligned with completions; "
            "got logprob_rows=%d completion_rows=%d. Falling back to unscaled GRPO advantages.",
            logprob_rows,
            completion_rows,
        )
        return
    try:
        required_width = int(
            completion_mask.to(torch.long).sum(dim=1).max().item()
        )
    except Exception:
        # Best effort: an empty or oddly shaped mask is treated as needing
        # no log-prob columns, so the width check below passes.
        required_width = 0
    logprob_width = int(old_per_token_logps.size(1))
    if required_width > logprob_width:
        _warn_once(
            "_seed_grpo_logprob_width_mismatch_warned",
            "SEED-GRPO requires rollout logprobs wide enough for active completion tokens; "
            "got logprob_width=%d required_width=%d. Falling back to unscaled GRPO advantages.",
            logprob_width,
            required_width,
        )
        return
    tokenizer = getattr(self, "processing_class", None)
    decode = getattr(tokenizer, "decode", None)
    if not callable(decode):
        return
    effective_group = max(
        int(group_size or getattr(self, "num_generations", 1) or 1),
        1,
    )
    total = completion_rows
    if total <= 0 or total % effective_group != 0:
        _warn_once(
            "_seed_grpo_group_shape_warned",
            "SEED-GRPO requires local rollout batches to contain whole "
            "prompt groups; got batch=%d with num_generations=%d. "
            "Falling back to unscaled GRPO advantages.",
            total,
            effective_group,
        )
        return

    grouped_completions: List[List[str]] = []
    grouped_ref_meta: List[List[Dict[str, Any]]] = []
    for start in range(0, total, effective_group):
        completion_group: List[str] = []
        meta_group: List[Dict[str, Any]] = []
        for row_idx in range(start, start + effective_group):
            mask_row = completion_mask[row_idx].to(torch.long)
            active_len = int(mask_row.sum().item())
            ids_row = completion_ids[row_idx]
            # Select the kept tokens with one masked gather and one
            # ``tolist`` transfer instead of a ``.item()`` call per token.
            # Clamp to the shorter of ids/mask, as ``zip`` would.
            width = min(int(ids_row.size(0)), int(mask_row.size(0)))
            active_ids = [
                int(tok)
                for tok in ids_row[:width][mask_row[:width] != 0].tolist()
            ]
            try:
                completion_text = str(
                    decode(active_ids, skip_special_tokens=True)
                )
            except Exception:
                # A row that cannot be decoded contributes an empty string
                # rather than aborting the whole batch.
                completion_text = ""
            completion_group.append(completion_text)
            token_logps = old_per_token_logps[row_idx, :active_len]
            logprob_sum = (
                float(token_logps.sum().item()) if active_len > 0 else 0.0
            )
            meta_group.append(
                {
                    "logprob_sum": logprob_sum,
                    "token_count": active_len,
                }
            )
        grouped_completions.append(completion_group)
        grouped_ref_meta.append(meta_group)

    try:
        (
            semantic_entropies,
            advantage_scales,
            alpha_effective,
            max_possible_entropy,
        ) = _compute_seed_grpo_statistics(
            SimpleNamespace(
                grouped_completions=grouped_completions,
                grouped_ref_meta=grouped_ref_meta,
            ),
            alpha=float(getattr(args, "seed_grpo_alpha", 0.0417) or 0.0417),
            normalize_by_max_entropy=bool(
                getattr(
                    args,
                    "seed_grpo_alpha_normalize_by_max_entropy",
                    True,
                )
            ),
            length_normalize_logprobs=bool(
                getattr(args, "seed_grpo_length_normalize_logprobs", True)
            ),
            num_generations=int(
                getattr(args, "num_generations", effective_group)
                or effective_group
            ),
        )
    except Exception as exc:
        _warn_once(
            "_seed_grpo_compute_warned",
            "SEED-GRPO scaling failed during rollout prep; falling back "
            "to unscaled GRPO advantages. Error: %s",
            exc,
        )
        return
    if not advantage_scales:
        return

    # One scale per group, repeated for every row in that group.
    repeated_scales = torch.tensor(
        [
            float(scale)
            for scale in advantage_scales
            for _ in range(effective_group)
        ],
        device=advantages.device,
        dtype=advantages.dtype,
    )
    if int(repeated_scales.numel()) != int(advantages.numel()):
        return
    # Element counts match, so align shapes explicitly; this prevents an
    # accidental outer-product broadcast when ``advantages`` is not 1-D.
    outputs["advantages"] = advantages * repeated_scales.reshape(advantages.shape)
    outputs["seed_grpo_semantic_entropies"] = torch.tensor(
        semantic_entropies,
        device=advantages.device,
        dtype=torch.float32,
    )
    outputs["seed_grpo_advantage_scales"] = torch.tensor(
        advantage_scales,
        device=advantages.device,
        dtype=torch.float32,
    )

    # Keep SEED diagnostics rank-local. These metrics are not used for
    # correctness, and cross-rank gathers here can desync collectives if
    # one rank bails out of SEED scaling earlier than another.
    local_entropies = _local_metric_tensor(outputs["seed_grpo_semantic_entropies"])
    if isinstance(local_entropies, torch.Tensor) and local_entropies.numel() > 0:
        self._append_metric_value(
            mode,
            "seed_grpo/semantic_entropy_mean",
            float(local_entropies.mean().item()),
        )
        self._append_metric_value(
            mode,
            "seed_grpo/semantic_entropy_min",
            float(local_entropies.min().item()),
        )
        self._append_metric_value(
            mode,
            "seed_grpo/semantic_entropy_max",
            float(local_entropies.max().item()),
        )
    local_scales = _local_metric_tensor(outputs["seed_grpo_advantage_scales"])
    if isinstance(local_scales, torch.Tensor) and local_scales.numel() > 0:
        self._append_metric_value(
            mode,
            "seed_grpo/advantage_scale_mean",
            float(local_scales.mean().item()),
        )
        self._append_metric_value(
            mode,
            "seed_grpo/advantage_scale_min",
            float(local_scales.min().item()),
        )
        self._append_metric_value(
            mode,
            "seed_grpo/advantage_scale_max",
            float(local_scales.max().item()),
        )
    self._append_metric_value(
        mode,
        "seed_grpo/alpha_effective",
        float(alpha_effective),
    )
    self._append_metric_value(
        mode,
        "seed_grpo/max_possible_entropy",
        float(max_possible_entropy),
    )
def _maybe_apply_seed_grpo_advantages_in_loss(
    self,
    inputs: Dict[str, Any],
    *,
    completion_ids: torch.Tensor,
    completion_mask: torch.Tensor,
    behavior_logps: torch.Tensor,
    mode: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run SEED-GRPO advantage scaling from inside the loss computation.

    The rollout path may not carry ``old_per_token_logps``; this hook lets
    the loss path supply the log-probs it already computed
    (``behavior_logps``) so that ``_maybe_apply_seed_grpo_advantages`` can
    be applied without another scoring pass.

    The first time a given ``inputs`` dict is seen, its ``advantages`` and
    ``old_per_token_logps`` entries are overwritten with the (possibly
    scaled) advantages and the detached behaviour log-probs, and the dict's
    ``id`` is remembered on ``self._seed_grpo_scaled_batch_ids`` so that a
    repeated call with the same dict does not scale twice.

    Returns:
        ``(advantages, logps)``. When scaling is skipped (advantages
        missing, non-train mode, feature disabled, or ``behavior_logps`` not
        a tensor) the stored advantages and ``behavior_logps`` are returned
        unchanged.
    """
    batch_advantages = inputs.get("advantages")
    if not isinstance(batch_advantages, torch.Tensor):
        return batch_advantages, behavior_logps
    trainer_args = getattr(self, "args", None)
    if not (
        mode == "train"
        and bool(getattr(trainer_args, "seed_grpo_enabled", False))
    ):
        return batch_advantages, behavior_logps
    if not isinstance(behavior_logps, torch.Tensor):
        return batch_advantages, behavior_logps

    # Lazily create the per-trainer registry of already-scaled batches.
    seen_batches = getattr(self, "_seed_grpo_scaled_batch_ids", None)
    if not isinstance(seen_batches, set):
        seen_batches = set()
        setattr(self, "_seed_grpo_scaled_batch_ids", seen_batches)

    # NOTE(review): batches are keyed by ``id(inputs)``. CPython can hand a
    # freed dict's id to a new dict, and this method never removes entries,
    # so a later batch could be mistaken for an already-scaled one unless
    # the set is cleared elsewhere - confirm against the rest of the file.
    batch_key = id(inputs)
    if batch_key in seen_batches:
        cached_advantages = inputs.get("advantages")
        cached_logps = inputs.get("old_per_token_logps")
        both_cached = isinstance(cached_advantages, torch.Tensor) and isinstance(
            cached_logps, torch.Tensor
        )
        if both_cached:
            return cached_advantages, cached_logps
        return batch_advantages, behavior_logps

    # Hand the shared scaler a scratch dict so it can rewrite entries freely.
    scratch: Dict[str, Any] = {}
    scratch["advantages"] = batch_advantages
    scratch["completion_ids"] = completion_ids
    scratch["completion_mask"] = completion_mask
    scratch["old_per_token_logps"] = behavior_logps
    self._maybe_apply_seed_grpo_advantages([], scratch, mode=mode)

    new_advantages = scratch.get("advantages")
    if not isinstance(new_advantages, torch.Tensor):
        new_advantages = batch_advantages
    new_logps = scratch.get("old_per_token_logps")
    if not isinstance(new_logps, torch.Tensor):
        new_logps = behavior_logps

    inputs["advantages"] = new_advantages
    # Store a detached copy on the batch; the caller still gets the live tensor.
    inputs["old_per_token_logps"] = new_logps.detach()
    seen_batches.add(batch_key)
    self._append_metric_value(
        mode,
        "seed_grpo/behavior_logprobs_deferred_to_loss",
        1.0,
        include_legacy_aliases=False,
    )
    return new_advantages, new_logps
def _maybe_truncate_completions_at_first_boxed_answer(
    self,
    inputs: List[Dict[str, Any]],
    outputs: Dict[str, Any],
    *,
    mode: str,
    group_size: Optional[int] = None,
    update_logged_metrics: bool = True,
) -> None:
    """Cut each completion after its first boxed answer when configured.

    Active only when ``args.truncate_completions_at_first_boxed_answer`` is
    truthy. Each row's active tokens are decoded, the text is shortened with
    ``truncate_after_first_boxed_answer``, and the matching token prefix is
    located with ``_find_token_prefix_len_for_text``. If at least one row
    was shortened, ``completion_ids`` / ``completion_mask`` (and
    ``old_per_token_logps`` when present) in ``outputs`` are rebuilt with
    padding, grouped advantages are recomputed unless lightweight greedy
    eval is in use, and length/reward metrics are refreshed.

    Args:
        inputs: Examples forwarded to the advantage recomputation.
        outputs: Batch dict that is updated in place.
        mode: Metric namespace passed to the logging helpers.
        group_size: Optional override for completions per prompt.
        update_logged_metrics: When false, the reward and length metrics
            are left as they were.
    """
    trainer_args = getattr(self, "args", None)
    if not bool(
        getattr(trainer_args, "truncate_completions_at_first_boxed_answer", False)
    ):
        return
    completion_ids = outputs.get("completion_ids")
    if not isinstance(completion_ids, torch.Tensor):
        return
    completion_mask = outputs.get("completion_mask")
    if not isinstance(completion_mask, torch.Tensor):
        # No mask supplied: derive one from the EOS token.
        completion_mask = _apply_eos_completion_mask(
            completion_ids,
            getattr(self.processing_class, "eos_token_id", None),
        )
        if not isinstance(completion_mask, torch.Tensor):
            return
    tokenizer = getattr(self, "processing_class", None)
    decode = getattr(tokenizer, "decode", None)
    if not callable(decode):
        return

    rollout_logps = outputs.get("old_per_token_logps")
    has_logps = isinstance(rollout_logps, torch.Tensor)
    kept_logp_rows: List[torch.Tensor] = []
    kept_id_rows: List[List[int]] = []
    num_trimmed = 0
    for row_idx, (id_row, mask_row) in enumerate(
        zip(completion_ids, completion_mask)
    ):
        active_len = int(mask_row.to(torch.long).sum().item())
        if active_len <= 0:
            # Nothing active in this row: keep an empty placeholder so the
            # row order stays aligned.
            kept_id_rows.append([])
            if has_logps:
                kept_logp_rows.append(rollout_logps.new_zeros((0,)))
            continue
        # The first ``active_len`` positions are taken as the active span.
        row_tokens = [int(tok) for tok in id_row[:active_len].tolist()]
        try:
            decoded = str(decode(row_tokens, skip_special_tokens=True))
        except Exception:
            decoded = ""
        clipped_text = truncate_after_first_boxed_answer(decoded)
        prefix_len = _find_token_prefix_len_for_text(
            tokenizer,
            row_tokens,
            clipped_text,
        )
        shorten = (
            bool(clipped_text)
            and clipped_text != decoded
            and prefix_len is not None
            and 0 < prefix_len < active_len
        )
        if shorten:
            row_tokens = row_tokens[:prefix_len]
            num_trimmed += 1
        kept_id_rows.append(row_tokens)
        if has_logps:
            kept_logp_rows.append(
                rollout_logps[row_idx, : len(row_tokens)].detach()
            )
    if num_trimmed <= 0:
        return

    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "eos_token_id", 0)
    new_completion_ids, new_completion_mask = _pad_completion_rows(
        kept_id_rows,
        pad_token_id=int(pad_token_id or 0),
        device=completion_ids.device,
    )
    outputs["completion_ids"] = new_completion_ids
    outputs["completion_mask"] = new_completion_mask
    if has_logps:
        outputs["old_per_token_logps"] = _pad_logprob_rows(
            kept_logp_rows,
            device=rollout_logps.device,
            dtype=rollout_logps.dtype,
        )

    rewards_local: Optional[torch.Tensor] = None
    if not _use_lightweight_greedy_eval(self, mode):
        rewards_local = self._recompute_grouped_advantages(
            inputs,
            outputs,
            group_size=group_size,
        )

    if update_logged_metrics:
        if isinstance(rewards_local, torch.Tensor) and rewards_local.numel() > 0:
            reward_view = _metric_tensor_for_logging(
                self,
                rewards_local.to(torch.float32),
                mode=mode,
            )
            if isinstance(reward_view, torch.Tensor) and reward_view.numel() > 0:
                effective_group = max(
                    int(group_size or getattr(self, "num_generations", 1) or 1),
                    1,
                )
                if int(reward_view.numel()) % effective_group == 0:
                    per_prompt = reward_view.view(-1, effective_group)
                    group_mean = per_prompt.mean(dim=1)
                    group_std = per_prompt.std(dim=1)
                    zero_std_frac = (
                        torch.isclose(group_std, torch.zeros_like(group_std))
                        .to(torch.float32)
                        .mean()
                        .item()
                    )
                    reward_updates = (
                        ("reward", float(group_mean.mean().item())),
                        ("reward_std", float(group_std.mean().item())),
                        ("frac_reward_zero_std", float(zero_std_frac)),
                    )
                    for metric_key, metric_value in reward_updates:
                        self._set_latest_metric_value(mode, metric_key, metric_value)
        length_view = _metric_tensor_for_logging(
            self,
            new_completion_mask.sum(dim=1).to(torch.float32),
            mode=mode,
        )
        if isinstance(length_view, torch.Tensor) and length_view.numel() > 0:
            length_updates = (
                ("completions/mean_length", float(length_view.mean().item())),
                ("completions/min_length", float(length_view.min().item())),
                ("completions/max_length", float(length_view.max().item())),
            )
            for metric_key, metric_value in length_updates:
                self._set_latest_metric_value(mode, metric_key, metric_value)

    # Under local-only lightweight eval only the main process logs the ratio.
    suppress_ratio = _use_local_only_lightweight_eval_metrics(
        self, mode
    ) and not _is_main_process(self)
    if not suppress_ratio:
        self._append_metric_value(
            mode,
            "completions/boxed_stop_ratio",
            float(num_trimmed) / float(max(len(kept_id_rows), 1)),
        )
def _prepare_greedy_eval_prompt_batch(
    self,
    inputs: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], torch.Tensor, torch.Tensor]:
    """Pick one example per prompt group and tokenize it for greedy eval.

    When ``self._local_only_eval_prompt_major_loader_active`` is truthy every
    input is used as-is. Otherwise ``inputs`` is read as consecutive groups
    of ``num_generations`` rows and only the first row of each complete
    group is kept. The selected prompts are rendered with
    ``_build_prompt_text``, tokenized with left padding, moved to the
    trainer device and cropped to the last ``max_prompt_length`` columns
    when that attribute is set.

    Returns:
        ``(prompt_inputs, prompt_ids, prompt_mask)``.

    Raises:
        ValueError: If no complete group / no prompt example is available.
    """
    prompt_major = bool(
        getattr(self, "_local_only_eval_prompt_major_loader_active", False)
    )
    if prompt_major:
        prompt_inputs = list(inputs)
    else:
        group_size = max(int(getattr(self, "num_generations", 1) or 1), 1)
        full_groups = len(inputs) // group_size
        if full_groups <= 0:
            raise ValueError(
                "Greedy eval requires at least one full prompt group."
            )
        # First row of each complete group; a trailing partial group is dropped.
        prompt_inputs = [inputs[g * group_size] for g in range(full_groups)]
    if not prompt_inputs:
        raise ValueError("Greedy eval requires at least one prompt example.")

    tokenizer = getattr(self, "processing_class", None)
    prompts_text = []
    for example in prompt_inputs:
        prompts_text.append(_build_prompt_text(example, tokenizer))
    encoded = self._move_prompt_tensors_to_device(
        tokenizer(
            text=prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
    )
    prompt_ids = encoded["input_ids"]
    prompt_mask = encoded["attention_mask"]

    max_prompt_length = getattr(self, "max_prompt_length", None)
    if max_prompt_length is None:
        return prompt_inputs, prompt_ids, prompt_mask
    # Keep the right-most columns of the batch.
    return (
        prompt_inputs,
        prompt_ids[:, -max_prompt_length:],
        prompt_mask[:, -max_prompt_length:],
    )
def _move_prompt_tensors_to_device(self, value: Any) -> Any:
    """Recursively place tokenizer output on the trainer's device.

    The target device is ``self.accelerator.device`` or, failing that,
    ``self.args.device``. Tensors and any object exposing a callable ``to``
    are moved directly; mappings, lists and tuples are rebuilt with their
    items moved; everything else is returned unchanged. With no device
    available, tensors and ``to``-capable objects are returned as-is.
    """
    target = getattr(getattr(self, "accelerator", None), "device", None)
    if target is None:
        target = getattr(getattr(self, "args", None), "device", None)

    if isinstance(value, torch.Tensor):
        if target is None:
            return value
        return value.to(device=target)

    mover = getattr(value, "to", None)
    if callable(mover) and not isinstance(value, (str, bytes)):
        if target is None:
            return value
        # Try the keyword form first, then the positional form. If both
        # signatures are rejected, fall through to the container handling.
        attempts = (
            lambda: mover(device=target),
            lambda: mover(target),
        )
        for attempt in attempts:
            try:
                return attempt()
            except TypeError:
                continue

    if isinstance(value, Mapping):
        moved = {}
        for key, item in value.items():
            moved[key] = self._move_prompt_tensors_to_device(item)
        return moved
    if isinstance(value, list):
        return [self._move_prompt_tensors_to_device(item) for item in value]
    if isinstance(value, tuple):
        return tuple(self._move_prompt_tensors_to_device(item) for item in value)
    return value
def _generate_greedy_eval_outputs(
    self,
    inputs: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Produce one ``do_sample=False`` completion per eval prompt.

    Prompts come from ``_prepare_greedy_eval_prompt_batch``. The unwrapped
    model is switched to eval mode for generation and restored afterwards.
    The returned dict carries prompt/completion ids and masks, zero
    advantages, ``old_per_token_logps=None`` and the
    ``_greedy_eval_precomputed`` / ``_eval_prompt_inputs`` markers. Reward
    and completion-length metrics are appended under ``"eval"``.

    Raises:
        ValueError: If no positive ``max_completion_length`` is configured,
            the model has no callable ``generate``, or generation does not
            yield a tensor of sequences.
    """
    prompt_inputs, prompt_ids, prompt_mask = self._prepare_greedy_eval_prompt_batch(
        inputs
    )
    is_lightweight = _use_lightweight_greedy_eval(self, "eval")

    tokenizer = getattr(self, "processing_class", None)
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = eos_token_id

    # Trainer attribute first, then the args object, then zero.
    length_budget = getattr(self, "max_completion_length", 0)
    if not length_budget:
        length_budget = getattr(
            getattr(self, "args", None), "max_completion_length", 0
        )
    max_new_tokens = int(length_budget or 0)
    if max_new_tokens <= 0:
        raise ValueError("Greedy eval requires a positive max_completion_length.")

    unwrap = getattr(self.accelerator, "unwrap_model", None)
    policy = self.model
    if callable(unwrap):
        try:
            policy = unwrap(policy)
        except Exception:
            policy = self.model
    run_generate = getattr(policy, "generate", None)
    if not callable(run_generate):
        raise ValueError("Greedy eval requires model.generate to be available.")

    restore_train_mode = bool(getattr(policy, "training", False))
    if restore_train_mode:
        policy.eval()

    gen_kwargs: Dict[str, Any] = dict(
        input_ids=prompt_ids,
        attention_mask=prompt_mask,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
    if pad_token_id is not None:
        gen_kwargs["pad_token_id"] = int(pad_token_id)
    if eos_token_id is not None:
        gen_kwargs["eos_token_id"] = int(eos_token_id)
    multi_process = getattr(self.accelerator, "num_processes", 1) > 1
    if multi_process and not is_lightweight:
        gen_kwargs["synced_gpus"] = True

    try:
        with torch.inference_mode():
            try:
                generated = run_generate(**gen_kwargs)
            except TypeError as exc:
                # Retry without ``synced_gpus`` only when that is the
                # argument the TypeError complains about.
                if "synced_gpus" not in str(exc):
                    raise
                gen_kwargs.pop("synced_gpus", None)
                generated = run_generate(**gen_kwargs)
    finally:
        if restore_train_mode:
            policy.train()

    sequences = getattr(generated, "sequences", generated)
    if not isinstance(sequences, torch.Tensor):
        raise ValueError("Greedy eval generation did not return tensor sequences.")

    # Everything after the prompt columns is the completion.
    completion_ids = sequences[:, int(prompt_ids.size(1)) :].contiguous()
    completion_mask = _apply_eos_completion_mask(completion_ids, eos_token_id)
    zero_advantages = torch.zeros(
        (completion_ids.size(0),),
        dtype=torch.float32,
        device=completion_ids.device,
    )
    outputs: Dict[str, Any] = {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "advantages": zero_advantages,
        "old_per_token_logps": None,
        "_greedy_eval_precomputed": True,
        "_eval_prompt_inputs": prompt_inputs,
    }
    self._maybe_truncate_completions_at_first_boxed_answer(
        prompt_inputs,
        outputs,
        mode="eval",
        group_size=1,
        update_logged_metrics=False,
    )

    rewards_local = self._recompute_local_rewards_for_outputs(
        prompt_inputs,
        outputs,
    )
    if isinstance(rewards_local, torch.Tensor) and rewards_local.numel() > 0:
        reward_view = _metric_tensor_for_logging(
            self,
            rewards_local.to(torch.float32),
            mode="eval",
        )
        if isinstance(reward_view, torch.Tensor) and reward_view.numel() > 0:
            # One completion per prompt, so the std metrics are constants.
            reward_metrics = (
                ("reward", float(reward_view.mean().item())),
                ("reward_std", 0.0),
                ("frac_reward_zero_std", 1.0),
            )
            for metric_key, metric_value in reward_metrics:
                self._append_metric_value("eval", metric_key, metric_value)

    length_view = _metric_tensor_for_logging(
        self,
        outputs["completion_mask"].sum(dim=1).to(torch.float32),
        mode="eval",
    )
    if isinstance(length_view, torch.Tensor) and length_view.numel() > 0:
        mean_len = float(length_view.mean().item())
        min_len = float(length_view.min().item())
        max_len = float(length_view.max().item())
        # The "terminated" metrics repeat the plain length statistics and
        # the clipped ratio is reported as zero.
        length_metrics = (
            ("completions/mean_length", mean_len),
            ("completions/min_length", min_len),
            ("completions/max_length", max_len),
            ("completions/clipped_ratio", 0.0),
            ("completions/mean_terminated_length", mean_len),
            ("completions/min_terminated_length", min_len),
            ("completions/max_terminated_length", max_len),
        )
        for metric_key, metric_value in length_metrics:
            self._append_metric_value("eval", metric_key, metric_value)
    return outputs
def _log_eval_pass_at_k(
    self,
    inputs: List[Dict[str, Any]],
    outputs: Dict[str, Any],
    *,
    mode: str,
) -> None:
    """Log global and per-benchmark pass@8/pass@1/mean@1 metrics.

    Runs only for ``mode == "eval"`` and only when ``num_generations`` is at
    least 8. Per-rollout success flags are computed locally (exact math
    correctness when the pure-accuracy math reward is configured, otherwise
    a reward threshold), gathered across processes, reshaped to
    ``(num_prompts, num_generations)`` and cut to the first 8 rollouts per
    prompt. ``pass_at_8``, ``pass_at_1``, ``mean_at_1`` and ``mean_at_8``
    are appended, followed by per-benchmark variants when the examples carry
    ``eval_benchmark_id`` / ``benchmark_id``.

    Args:
        inputs: Local eval examples, one per rollout row.
        outputs: Rollout batch dict; ``completion_ids`` is read for the
            math-correctness path.
        mode: Metric namespace; anything other than ``"eval"`` is a no-op.
    """
    if mode != "eval":
        return
    # ``max(..., 1)`` already guarantees a positive group size.
    group_size = max(int(getattr(self, "num_generations", 1) or 1), 1)
    target_k = 8
    if group_size < target_k:
        if not bool(getattr(self, "_pass_at_8_warned", False)):
            LOG.warning(
                "Skipping eval pass_at_8 metrics because num_generations=%d < 8.",
                group_size,
            )
            setattr(self, "_pass_at_8_warned", True)
        return

    def _reshape_eval_rollouts(
        flat_values: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Reshape prompt-major flat rollouts into prompt-major groups."""
        usable = (int(flat_values.numel()) // group_size) * group_size
        if usable <= 0:
            return None
        if usable != int(flat_values.numel()):
            flat_values = flat_values[:usable]
        num_prompts = usable // group_size
        if num_prompts <= 0:
            return None
        # TRL rollout order is prompt-major: [p0g0, p0g1, ..., p1g0, p1g1, ...].
        return flat_values.view(num_prompts, group_size)

    def _gather_benchmark_ids(
        expected_flat_count: int,
    ) -> Optional[torch.Tensor]:
        """Return prompt-major benchmark ids aligned with ``successes``."""
        if expected_flat_count <= 0:
            return None
        keys = ("eval_benchmark_id", "benchmark_id")
        raw_vals: Optional[List[Any]] = None
        for key in keys:
            candidate = [example.get(key) for example in inputs]
            if candidate and any(val is not None for val in candidate):
                raw_vals = candidate
                break
        if not raw_vals:
            return None
        usable = min(len(raw_vals), expected_flat_count)
        usable = usable - (usable % group_size)
        if usable <= 0:
            return None
        ids: List[int] = []
        for val in raw_vals[:usable]:
            try:
                ids.append(int(val) if val is not None else -1)
            except (TypeError, ValueError):
                # Unparseable ids are mapped to -1 and skipped when reporting.
                ids.append(-1)
        ids_tensor = torch.tensor(
            ids,
            dtype=torch.long,
            device=self.accelerator.device,
        )
        ids_global = gather(ids_tensor)
        grouped_ids = _reshape_eval_rollouts(ids_global)
        if grouped_ids is None:
            return None
        # Every rollout of a prompt shares its benchmark id; keep the first.
        return grouped_ids[:, 0].to(torch.long)

    def _append_per_benchmark_metrics(successes_tensor: torch.Tensor) -> None:
        """Append per-benchmark pass metrics when benchmark ids are present."""
        if successes_tensor.numel() <= 0:
            return
        total_prompts = int(successes_tensor.size(0))
        # ``successes_tensor`` keeps only the first ``target_k`` rollouts per
        # prompt, so its width under-counts rows whenever num_generations
        # exceeds ``target_k``. Benchmark ids exist once per rollout row, so
        # the cap must use the real group size; otherwise ids are truncated,
        # the count check below fails and the metrics are silently dropped.
        benchmark_ids = _gather_benchmark_ids(total_prompts * group_size)
        if not isinstance(benchmark_ids, torch.Tensor):
            return
        if benchmark_ids.numel() != total_prompts:
            return
        id_to_name = getattr(self, "eval_benchmark_id_to_name", {}) or {}
        unique_ids = torch.unique(benchmark_ids)
        for bench_id_tensor in unique_ids:
            bench_id = int(bench_id_tensor.item())
            if bench_id < 0:
                continue
            mask = benchmark_ids == bench_id_tensor
            bench_count = int(mask.to(torch.long).sum().item())
            if bench_count <= 0:
                continue
            bench_successes = successes_tensor[mask]
            bench_pass_at_8 = float(
                bench_successes.any(dim=1).to(torch.float32).mean().item()
            )
            bench_pass_at_1 = float(
                bench_successes[:, 0].to(torch.float32).mean().item()
            )
            bench_mean_at_1 = float(
                bench_successes[:, :1].to(torch.float32).mean().item()
            )
            bench_label = id_to_name.get(bench_id, f"BENCH_{bench_id}")
            suffix = _metric_suffix_from_benchmark(bench_label)
            self._append_metric_value(
                mode, f"pass_at_8_{suffix}", bench_pass_at_8
            )
            self._append_metric_value(
                mode, f"pass_at_1_{suffix}", bench_pass_at_1
            )
            self._append_metric_value(
                mode, f"mean_at_1_{suffix}", bench_mean_at_1
            )

    successes: Optional[torch.Tensor] = None
    reward_funcs = list(getattr(self, "reward_funcs", []) or [])
    if uses_pure_accuracy_math_reward(reward_funcs):
        completion_ids = outputs.get("completion_ids")
        tokenizer = getattr(self, "processing_class", None)
        decode = getattr(tokenizer, "batch_decode", None)
        if isinstance(completion_ids, torch.Tensor) and callable(decode):
            try:
                decode_fn = cast(Callable[..., List[str]], decode)
                completions_text = decode_fn(  # pylint: disable=not-callable
                    completion_ids, skip_special_tokens=True
                )
            except Exception:
                # A decode failure falls through to the reward-based path.
                completions_text = []
            answers = [str(example.get("answer", "")) for example in inputs]
            usable_local = min(len(completions_text), len(answers))
            usable_local = usable_local - (usable_local % group_size)
            if usable_local > 0:
                # Paper-facing pass metrics: exact canonical answer match,
                # allowing only a final-line exact fallback (no shaping).
                correctness_local = pure_accuracy_math_correctness(
                    completions_text[:usable_local],
                    answers[:usable_local],
                    allow_last_line_fallback=True,
                )
                local_successes = torch.tensor(
                    correctness_local,
                    dtype=torch.bool,
                    device=completion_ids.device,
                )
                global_successes = gather(local_successes)
                grouped_successes = _reshape_eval_rollouts(global_successes)
                if grouped_successes is not None:
                    successes = grouped_successes[:, :target_k]
    if successes is None:
        try:
            rewards = self._recompute_global_rewards_for_outputs(
                inputs, outputs
            )
        except Exception as exc:
            LOG.debug(
                "Skipping eval pass@k logging due to reward error: %s", exc
            )
            return
        if not isinstance(rewards, torch.Tensor) or rewards.numel() <= 0:
            return
        grouped_rewards = _reshape_eval_rollouts(rewards)
        if grouped_rewards is None:
            return
        grouped_rewards = grouped_rewards[:, :target_k]
        # Fallback for non-math/custom reward functions.
        successes = grouped_rewards >= (
            _PASS_METRIC_SUCCESS_REWARD - _PASS_METRIC_EPS
        )
    pass_at_8 = float(successes.any(dim=1).to(torch.float32).mean().item())
    pass_at_1 = float(successes[:, 0].to(torch.float32).mean().item())
    mean_at_1 = float(successes[:, :1].to(torch.float32).mean().item())
    mean_at_8 = float(successes.to(torch.float32).mean().item())
    self._append_metric_value(mode, "pass_at_8", pass_at_8)
    self._append_metric_value(mode, "pass_at_1", pass_at_1)
    self._append_metric_value(mode, "mean_at_1", mean_at_1)
    self._append_metric_value(mode, "mean_at_8", mean_at_8)
    _append_per_benchmark_metrics(successes)
def _log_eval_greedy_metrics(
    self,
    inputs: List[Dict[str, Any]],
    outputs: Dict[str, Any],
    *,
    mode: str,
) -> None:
    """Log a deterministic greedy pass@1 eval beside sampled metrics.

    Runs only when ``mode == "eval"`` and ``args.greedy_eval_enabled`` is
    truthy. If ``outputs`` is flagged ``_greedy_eval_precomputed`` the
    completions already stored in ``outputs`` are scored; otherwise the
    first prompt of every ``num_generations``-sized group is re-generated
    with ``do_sample=False``.

    Appends ``greedy/reward``, ``greedy/pass_at_1``, ``greedy/mean_at_1``,
    ``greedy/completions/mean_length`` and, when benchmark ids are
    available, ``greedy/pass_at_1_<suffix>``. When
    ``args.eval_greedy_only_enabled`` is truthy the un-prefixed
    ``pass_at_1`` / ``mean_at_1`` keys are written as well.

    Args:
        inputs: Local eval examples (one row per sampled rollout).
        outputs: Rollout batch dict for this eval step.
        mode: Metric namespace; anything other than ``"eval"`` is a no-op.
    """
    if mode != "eval":
        return
    args = getattr(self, "args", None)
    if not bool(getattr(args, "greedy_eval_enabled", False)):
        return
    lightweight_eval = _use_lightweight_greedy_eval(self, mode)
    local_only_eval = _use_local_only_lightweight_eval_metrics(self, mode)
    # In local-only mode only the main process computes and logs metrics.
    if local_only_eval and not _is_main_process(self):
        return
    precomputed = bool(outputs.get("_greedy_eval_precomputed", False))
    if precomputed:
        # Greedy completions were already generated; reuse them directly.
        prompt_inputs = outputs.get("_eval_prompt_inputs")
        completion_ids = outputs.get("completion_ids")
        completion_mask = outputs.get("completion_mask")
        if not isinstance(prompt_inputs, list):
            return
        if not isinstance(completion_ids, torch.Tensor) or not isinstance(
            completion_mask, torch.Tensor
        ):
            return
        greedy_outputs = {
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
        }
    else:
        prompt_ids = outputs.get("prompt_ids")
        prompt_mask = outputs.get("prompt_mask")
        if not isinstance(prompt_ids, torch.Tensor) or not isinstance(
            prompt_mask, torch.Tensor
        ):
            return
        group_size = max(int(getattr(self, "num_generations", 1) or 1), 1)
        usable = min(
            len(inputs),
            int(prompt_ids.size(0)),
            int(prompt_mask.size(0)),
        )
        # Drop any trailing partial group.
        usable = usable - (usable % group_size)
        if usable <= 0:
            return
        # Keep the first row of each group: one prompt per group.
        prompt_inputs = list(inputs[:usable:group_size])
        prompt_ids = prompt_ids[:usable:group_size].contiguous()
        prompt_mask = prompt_mask[:usable:group_size].contiguous()
        if prompt_ids.numel() <= 0 or not prompt_inputs:
            return
        unwrap_fn = getattr(self.accelerator, "unwrap_model", None)
        gen_model = self.model
        if callable(unwrap_fn):
            try:
                gen_model = unwrap_fn(gen_model)
            except Exception:
                # Fall back to the wrapped model if unwrapping fails.
                gen_model = self.model
        generate_fn = getattr(gen_model, "generate", None)
        if not callable(generate_fn):
            if not bool(
                getattr(self, "_greedy_eval_generate_warned", False)
            ):
                LOG.warning(
                    "Skipping greedy eval metrics because model.generate is unavailable."
                )
                setattr(self, "_greedy_eval_generate_warned", True)
            return
        tokenizer = getattr(self, "processing_class", None)
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = eos_token_id
        # Trainer attribute first, then the args object, then zero.
        max_new_tokens = int(
            getattr(self, "max_completion_length", 0)
            or getattr(args, "max_completion_length", 0)
            or 0
        )
        if max_new_tokens <= 0:
            if not bool(
                getattr(self, "_greedy_eval_length_warned", False)
            ):
                LOG.warning(
                    "Skipping greedy eval metrics because max_completion_length is invalid."
                )
                setattr(self, "_greedy_eval_length_warned", True)
            return
        # Remember the mode so the ``finally`` below can restore it.
        was_training = bool(getattr(gen_model, "training", False))
        if was_training:
            gen_model.eval()
        generate_kwargs: Dict[str, Any] = {
            "input_ids": prompt_ids,
            "attention_mask": prompt_mask,
            "do_sample": False,
            "max_new_tokens": max_new_tokens,
            "num_return_sequences": 1,
        }
        if pad_token_id is not None:
            generate_kwargs["pad_token_id"] = int(pad_token_id)
        if eos_token_id is not None:
            generate_kwargs["eos_token_id"] = int(eos_token_id)
        if (
            getattr(self.accelerator, "num_processes", 1) > 1
            and not lightweight_eval
        ):
            generate_kwargs["synced_gpus"] = True
        try:
            with torch.inference_mode():
                try:
                    generated = generate_fn(**generate_kwargs)
                except TypeError as exc:
                    # Retry without ``synced_gpus`` only when that argument
                    # is what the TypeError complains about.
                    if "synced_gpus" not in str(exc):
                        raise
                    generate_kwargs.pop("synced_gpus", None)
                    generated = generate_fn(**generate_kwargs)
        except Exception as exc:
            # Greedy metrics are optional: warn once and skip on any failure.
            if not bool(
                getattr(self, "_greedy_eval_failed_warned", False)
            ):
                LOG.warning(
                    "Skipping greedy eval metrics because greedy generation failed: %s",
                    exc,
                )
                setattr(self, "_greedy_eval_failed_warned", True)
            return
        finally:
            if was_training:
                gen_model.train()
        sequences = getattr(generated, "sequences", generated)
        if not isinstance(sequences, torch.Tensor):
            return
        prompt_width = int(prompt_ids.size(1))
        if int(sequences.size(1)) < prompt_width:
            return
        # Everything after the prompt columns is the completion.
        completion_ids = sequences[:, prompt_width:].contiguous()
        completion_mask = _apply_eos_completion_mask(
            completion_ids,
            eos_token_id,
        )
        greedy_outputs = {
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
        }
    # Shared tail for both branches: optional boxed-answer truncation, then
    # scoring. ``group_size=1`` because there is one completion per prompt.
    self._maybe_truncate_completions_at_first_boxed_answer(
        prompt_inputs,
        greedy_outputs,
        mode=mode,
        group_size=1,
        update_logged_metrics=False,
    )
    # Re-read the tensors: truncation may have replaced them.
    completion_ids = greedy_outputs["completion_ids"]
    completion_mask = greedy_outputs["completion_mask"]
    rewards = self._recompute_local_rewards_for_outputs(
        prompt_inputs,
        greedy_outputs,
    )
    if not isinstance(rewards, torch.Tensor) or rewards.numel() <= 0:
        return
    # Local-only eval skips the cross-process gather.
    global_rewards = (
        rewards.to(torch.float32)
        if local_only_eval
        else gather(rewards.to(torch.float32))
    )
    if isinstance(global_rewards, torch.Tensor) and global_rewards.numel() > 0:
        self._append_metric_value(
            mode,
            "greedy/reward",
            float(global_rewards.mean().item()),
        )
    reward_funcs = list(getattr(self, "reward_funcs", []) or [])
    successes: Optional[torch.Tensor] = None
    if uses_pure_accuracy_math_reward(reward_funcs):
        tokenizer = getattr(self, "processing_class", None)
        decode = getattr(tokenizer, "batch_decode", None)
        if callable(decode):
            try:
                decode_fn = cast(Callable[..., List[str]], decode)
                completions_text = decode_fn(
                    completion_ids, skip_special_tokens=True
                )
            except Exception:
                # A decode failure falls through to the reward threshold.
                completions_text = []
            answers = [str(example.get("answer", "")) for example in prompt_inputs]
            usable_local = min(len(completions_text), len(answers))
            if usable_local > 0:
                correctness_local = pure_accuracy_math_correctness(
                    completions_text[:usable_local],
                    answers[:usable_local],
                    allow_last_line_fallback=True,
                )
                local_successes = torch.tensor(
                    correctness_local,
                    dtype=torch.bool,
                    device=completion_ids.device,
                )
                successes = (
                    local_successes
                    if local_only_eval
                    else gather(local_successes)
                )
    if successes is None:
        # Fallback: a rollout counts as a success when its reward reaches
        # the pass threshold.
        local_successes = (
            rewards >= (_PASS_METRIC_SUCCESS_REWARD - _PASS_METRIC_EPS)
        ).to(torch.bool)
        successes = (
            local_successes
            if local_only_eval
            else gather(local_successes)
        )
    if not isinstance(successes, torch.Tensor) or successes.numel() <= 0:
        return
    successes = successes.to(torch.bool)
    pass_at_1 = float(successes.to(torch.float32).mean().item())
    # With a single completion per prompt, mean@1 equals pass@1.
    self._append_metric_value(mode, "greedy/pass_at_1", pass_at_1)
    self._append_metric_value(mode, "greedy/mean_at_1", pass_at_1)
    if bool(getattr(args, "eval_greedy_only_enabled", False)):
        self._append_metric_value(mode, "pass_at_1", pass_at_1)
        self._append_metric_value(mode, "mean_at_1", pass_at_1)
    try:
        completion_lengths = completion_mask.sum(dim=1).to(torch.float32)
        gathered_lengths = (
            completion_lengths if local_only_eval else gather(completion_lengths)
        )
        if (
            isinstance(gathered_lengths, torch.Tensor)
            and gathered_lengths.numel() > 0
        ):
            self._append_metric_value(
                mode,
                "greedy/completions/mean_length",
                float(gathered_lengths.mean().item()),
            )
    except Exception:
        # The length metric is best-effort; errors are swallowed silently.
        pass
    benchmark_ids = _gather_eval_benchmark_ids_for_prompts(
        self,
        prompt_inputs,
        device=completion_ids.device,
        local_only=local_only_eval,
    )
    if not isinstance(benchmark_ids, torch.Tensor):
        return
    # Per-benchmark metrics need exactly one id per success flag.
    if benchmark_ids.numel() != successes.numel():
        return
    id_to_name = getattr(self, "eval_benchmark_id_to_name", {}) or {}
    for bench_id_tensor in torch.unique(benchmark_ids):
        bench_id = int(bench_id_tensor.item())
        # Negative ids are skipped.
        if bench_id < 0:
            continue
        mask = benchmark_ids == bench_id_tensor
        bench_count = int(mask.to(torch.long).sum().item())
        if bench_count <= 0:
            continue
        bench_pass_at_1 = float(
            successes[mask].to(torch.float32).mean().item()
        )
        bench_label = id_to_name.get(bench_id, f"BENCH_{bench_id}")
        suffix = _metric_suffix_from_benchmark(bench_label)
        self._append_metric_value(
            mode,
            f"greedy/pass_at_1_{suffix}",
            bench_pass_at_1,
        )
        if bool(getattr(args, "eval_greedy_only_enabled", False)):
            self._append_metric_value(
                mode,
                f"pass_at_1_{suffix}",
                bench_pass_at_1,
            )
            self._append_metric_value(
                mode,
                f"mean_at_1_{suffix}",
                bench_pass_at_1,
            )
def _get_per_token_logps_and_entropy(
    self,
    model: Any,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    logits_to_keep: int,
    *,
    entropy_mode: str,
    batch_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute selected-token log-probs and entropy on the completion span.

    The batch is scored in row chunks of ``batch_size``.  Token ids are
    sanitized before the forward pass and again before selection, logits
    are divided by ``self.temperature`` and, when the logits are wider than
    the resolved token-id upper bound, the surplus columns are masked.

    Args:
        model: Callable invoked with ``input_ids`` / ``attention_mask``.
        input_ids: Prompt + completion token ids, one row per sequence.
        attention_mask: Mask aligned with ``input_ids``.
        logits_to_keep: Number of trailing positions that are scored.
        entropy_mode: Passed through to ``_selected_logps_and_entropy``.
        batch_size: Rows per forward pass; falls back to the full batch.

    Returns:
        ``(logps, entropy)``, each concatenated over chunks along dim 0.
    """
    keep = int(logits_to_keep)
    chunk_size = int(batch_size or input_ids.size(0) or 1)
    all_logps: List[torch.Tensor] = []
    all_entropy: List[torch.Tensor] = []
    mode = "train" if bool(getattr(model, "training", False)) else "eval"
    unwrap_fn = getattr(getattr(self, "accelerator", None), "unwrap_model", None)
    base_model = model
    if callable(unwrap_fn):
        try:
            base_model = unwrap_fn(model)
        except Exception:
            # Best effort: fall back to the wrapped model for vocab lookup.
            base_model = model
    tokenizer = getattr(self, "processing_class", None)
    vocab_size = _resolve_token_id_upper_bound(base_model, tokenizer)
    for start in range(0, int(input_ids.size(0)), chunk_size):
        stop = start + chunk_size
        input_ids_batch = input_ids[start:stop]
        attention_mask_batch = attention_mask[start:stop]
        input_ids_batch = self._sanitize_scoring_token_ids(
            input_ids_batch,
            upper_bound=vocab_size,
            mode=mode,
            context="model_input",
        )
        # Same contract as _get_per_token_logps: a forward that does not
        # accept ``logits_to_keep`` is retried without it, and the full
        # logits are trimmed to the completion span below.
        try:
            outputs = model(
                input_ids=input_ids_batch,
                attention_mask=attention_mask_batch,
                logits_to_keep=keep + 1,
            )
        except TypeError:
            outputs = model(
                input_ids=input_ids_batch,
                attention_mask=attention_mask_batch,
            )
        logits = getattr(outputs, "logits", outputs)
        # Drop the last position: it predicts a token beyond the input.
        logits = logits[:, :-1, :]
        token_ids = input_ids_batch[:, -keep:]
        logits = logits[:, -keep:]
        logits = logits / self.temperature
        if isinstance(vocab_size, int) and int(logits.size(-1)) > vocab_size:
            if not bool(
                getattr(self, "_invalid_logit_columns_warned_entropy", False)
            ):
                # Warn once per trainer instance.
                LOG.warning(
                    "Masking %d tokenizer-inaccessible logit columns in exact-entropy scoring (valid_vocab_size=%d, logits_width=%d).",
                    int(logits.size(-1)) - vocab_size,
                    vocab_size,
                    int(logits.size(-1)),
                )
                setattr(self, "_invalid_logit_columns_warned_entropy", True)
            logits = _mask_invalid_logit_columns(
                logits,
                valid_vocab_size=vocab_size,
            )
        # Selection indices must be valid for the logits' last dimension.
        token_ids = self._sanitize_scoring_token_ids(
            token_ids,
            upper_bound=int(logits.size(-1)),
            mode=mode,
            context="token_select",
        )
        logps, entropy = _selected_logps_and_entropy(
            logits,
            token_ids,
            entropy_mode=entropy_mode,
        )
        all_logps.append(logps)
        all_entropy.append(entropy)
    return torch.cat(all_logps, dim=0), torch.cat(all_entropy, dim=0)
def _recompute_local_rewards_for_outputs(
    self,
    inputs: List[Dict[str, Any]],
    outputs: Dict[str, Any],
) -> Optional[torch.Tensor]:
    """Recompute this process's weighted rewards for generated completions.

    Decodes ``outputs["completion_ids"]``, runs every entry of
    ``self.reward_funcs`` on the prompt/completion pairs and combines the
    per-function scores with ``self.reward_weights``.  No cross-process
    gathering happens here.

    Args:
        inputs: Per-sample example dicts.  Each must hold ``"prompt"``;
            remaining columns (except ``completion`` / ``completion_ids``)
            are forwarded to callable reward functions as keyword lists.
        outputs: Rollout outputs holding ``"completion_ids"`` and,
            optionally, ``"completion_mask"``.

    Returns:
        A float32 tensor with one reward per sample (non-finite values
        replaced by 0.0), or ``None`` when rewards cannot be recomputed:
        empty ``inputs``, missing completion tensors, a module reward
        without a processing class, or a non-callable reward entry.
    """
    if not inputs:
        return None
    completion_ids = outputs.get("completion_ids")
    if not isinstance(completion_ids, torch.Tensor):
        return None
    completion_mask = outputs.get("completion_mask")
    if not isinstance(completion_mask, torch.Tensor):
        # No mask supplied: derive one from the tokenizer's EOS id.
        completion_mask = _apply_eos_completion_mask(
            completion_ids,
            getattr(self.processing_class, "eos_token_id", None),
        )
    if not isinstance(completion_mask, torch.Tensor):
        return None
    completion_mask = completion_mask.to(
        device=completion_ids.device, dtype=torch.long
    )
    prompts = [example["prompt"] for example in inputs]
    completions_text = self.processing_class.batch_decode(
        completion_ids, skip_special_tokens=True
    )
    conversational = is_conversational(inputs[0])
    if conversational:
        completions: List[Any] = []
        for prompt, completion in zip(prompts, completions_text):
            # When the prompt ends with an assistant turn, its content is
            # prepended so the reward sees the full assistant message.
            bootstrap = ""
            if (
                isinstance(prompt, list)
                and prompt
                and isinstance(prompt[-1], dict)
                and prompt[-1].get("role") == "assistant"
            ):
                bootstrap = str(prompt[-1].get("content", ""))
            completions.append(
                [{"role": "assistant", "content": f"{bootstrap}{completion}"}]
            )
    else:
        completions = completions_text
    # Copy ids and mask to the host once, then select per row.  Calling
    # ``.item()`` per token would force a device sync for every token.
    ids_host = completion_ids.detach().to("cpu")
    keep_host = completion_mask.to("cpu") != 0
    completion_ids_list = [
        row[keep_row].tolist() for row, keep_row in zip(ids_host, keep_host)
    ]
    rewards_per_func_local = torch.zeros(
        (len(prompts), len(self.reward_funcs)),
        device=completion_ids.device,
        dtype=torch.float32,
    )
    # Extra dataset columns are taken from the first example's keys.
    keys = [
        key
        for key in inputs[0]
        if key not in {"prompt", "completion", "completion_ids"}
    ]

    def _reward_value_for_key(example: Dict[str, Any], key: str) -> Any:
        """Look up ``key``, treating ``answer`` and ``solution`` as aliases."""
        if key in example:
            return example[key]
        if key == "answer":
            return example.get("solution")
        if key == "solution":
            return example.get("answer")
        return None

    reward_kwargs = {
        key: [_reward_value_for_key(example, key) for example in inputs]
        for key in keys
    }
    # Expose both spellings so reward functions may ask for either one.
    if "answer" not in reward_kwargs and "solution" in reward_kwargs:
        reward_kwargs["answer"] = list(reward_kwargs["solution"])
    if "solution" not in reward_kwargs and "answer" in reward_kwargs:
        reward_kwargs["solution"] = list(reward_kwargs["answer"])
    reward_processing_classes = list(
        getattr(
            self, "reward_processing_classes", [None] * len(self.reward_funcs)
        )
    )
    for i, reward_func in enumerate(self.reward_funcs):
        if isinstance(reward_func, torch.nn.Module):
            # Module reward: tokenize prompt+completion text and read the
            # first logit column as the score.
            reward_processing_class = (
                reward_processing_classes[i]
                if i < len(reward_processing_classes)
                else None
            )
            if reward_processing_class is None:
                return None
            if conversational:
                messages = [
                    {"messages": p + c} for p, c in zip(prompts, completions)
                ]
                texts = [
                    apply_chat_template(x, reward_processing_class)["text"]
                    for x in messages
                ]
            else:
                texts = [p + c for p, c in zip(prompts, completions)]
            reward_inputs = reward_processing_class(
                text=texts,
                return_tensors="pt",
                padding=True,
                padding_side="right",
                add_special_tokens=False,
            )
            reward_inputs = super()._prepare_inputs(reward_inputs)
            with torch.inference_mode():
                rewards_per_func_local[:, i] = reward_func(
                    **reward_inputs
                ).logits[:, 0]
        else:
            if not callable(reward_func):
                return None
            output_reward_func = reward_func(
                prompts=prompts,
                completions=completions,
                completion_ids=completion_ids_list,
                **reward_kwargs,
            )
            # ``None`` marks "not applicable"; NaN lets nansum skip it.
            output_reward_func = [
                reward if reward is not None else torch.nan
                for reward in output_reward_func
            ]
            rewards_per_func_local[:, i] = torch.tensor(
                output_reward_func,
                dtype=torch.float32,
                device=completion_ids.device,
            )
    reward_weights = getattr(self, "reward_weights", None)
    if isinstance(reward_weights, torch.Tensor):
        weights = reward_weights.to(
            device=rewards_per_func_local.device, dtype=torch.float32
        )
    elif isinstance(reward_weights, (list, tuple)):
        weights = torch.tensor(
            list(reward_weights),
            dtype=torch.float32,
            device=rewards_per_func_local.device,
        )
    else:
        weights = torch.ones(
            (len(self.reward_funcs),),
            dtype=torch.float32,
            device=rewards_per_func_local.device,
        )
    if weights.numel() != rewards_per_func_local.size(1):
        # Mismatched weight vector: fall back to equal weighting.
        weights = torch.ones(
            (rewards_per_func_local.size(1),),
            dtype=torch.float32,
            device=rewards_per_func_local.device,
        )
    rewards = (rewards_per_func_local * weights.unsqueeze(0)).nansum(dim=1)
    return torch.nan_to_num(rewards, nan=0.0, posinf=0.0, neginf=0.0)
def _recompute_global_rewards_for_outputs(
    self,
    inputs: List[Dict[str, Any]],
    outputs: Dict[str, Any],
) -> Optional[torch.Tensor]:
    """Recompute local rewards and gather them across processes.

    Returns ``None`` when the local recomputation yields no tensor;
    otherwise the gathered rewards with non-finite entries zeroed.
    """
    local_rewards = self._recompute_local_rewards_for_outputs(inputs, outputs)
    if isinstance(local_rewards, torch.Tensor):
        return torch.nan_to_num(
            gather(local_rewards), nan=0.0, posinf=0.0, neginf=0.0
        )
    return None
def _prepare_listwise_rollout_targets(
    self,
    inputs: List[Dict[str, Any]],
    outputs: Dict[str, Any],
) -> None:
    """Attach listwise q targets and grouped rewards to ``outputs``.

    Rewards are recomputed locally, gathered, reshaped into prompt groups
    of ``num_generations`` and turned into a temperature softmax with
    optional epsilon smoothing.  Only this process's slice is stored, under
    ``"maxent_listwise_q"`` and ``"maxent_listwise_rewards"``.  Nothing is
    written when no rewards are available.

    Raises:
        ValueError: if the gathered or local rewards do not form whole
            prompt groups, or this process's slice exceeds the gathered
            total.
    """
    local_flat = self._recompute_local_rewards_for_outputs(inputs, outputs)
    if not isinstance(local_flat, torch.Tensor) or local_flat.numel() <= 0:
        return
    num_gen = max(int(getattr(self, "num_generations", 1) or 1), 1)
    all_flat = gather(local_flat)
    if not isinstance(all_flat, torch.Tensor):
        all_flat = torch.as_tensor(
            all_flat,
            device=local_flat.device,
            dtype=local_flat.dtype,
        )
    all_grouped = _reshape_prompt_major_tensor(all_flat, num_gen)
    if all_grouped is None:
        raise ValueError(
            "Listwise MaxEnt rollout rewards must arrive as whole prompt "
            f"groups with flat batch size divisible by num_generations={num_gen}."
        )

    trainer_args = getattr(self, "args", None)
    q_temperature = _coerce_non_negative_float(
        getattr(trainer_args, "maxent_q_temperature", 1.0),
        default=1.0,
    )
    q_epsilon = _coerce_non_negative_float(
        getattr(trainer_args, "maxent_q_epsilon", 1e-6),
        default=1e-6,
    )
    q_all = torch.softmax(all_grouped / max(q_temperature, 1e-8), dim=1)
    if q_epsilon > 0.0:
        # Cap epsilon just below 1/K so the smoothed mass stays positive.
        epsilon_cap = max((1.0 / float(max(q_all.size(1), 1))) - 1e-8, 0.0)
        q_epsilon = min(q_epsilon, epsilon_cap)
        if q_epsilon > 0.0:
            q_all = q_all * (1.0 - q_epsilon * q_all.size(1)) + q_epsilon
            row_mass = q_all.sum(dim=1, keepdim=True).clamp(min=1e-12)
            q_all = q_all / row_mass

    # This process owns rows [lo, hi) of the gathered flat batch.
    rows_here = int(local_flat.size(0))
    rank = int(getattr(self.accelerator, "process_index", 0) or 0)
    lo = rank * rows_here
    hi = lo + rows_here
    if hi > int(all_flat.size(0)):
        raise ValueError(
            "Listwise MaxEnt gathered reward totals are shorter than the "
            "current rank slice."
        )
    if lo % num_gen != 0:
        raise ValueError(
            "Listwise MaxEnt requires each rank slice to begin on a whole "
            "prompt-group boundary after reward gathering."
        )
    rewards_here = _reshape_prompt_major_tensor(all_flat[lo:hi], num_gen)
    q_here = _reshape_prompt_major_tensor(q_all.reshape(-1)[lo:hi], num_gen)
    if rewards_here is None or q_here is None:
        raise ValueError(
            "Listwise MaxEnt local rollout slice must contain whole prompt "
            "groups after reward gathering."
        )
    outputs["maxent_listwise_q"] = _normalize_listwise_q_targets(
        q_here.detach(),
        num_prompts=int(rewards_here.size(0)),
        group_size=num_gen,
        context="Listwise MaxEnt rollout targets",
    )
    outputs["maxent_listwise_rewards"] = rewards_here.detach()
def _resolve_listwise_reference_mode(self) -> bool:
    """Decide whether the listwise MaxEnt loss includes a reference term.

    Delegates to ``_should_use_model_reference_logprobs`` with
    ``default_to_model_reference=False``.
    """
    use_reference = self._should_use_model_reference_logprobs(
        default_to_model_reference=False
    )
    return use_reference
def _compute_listwise_maxent_loss(self, model: Any, inputs: Any) -> torch.Tensor:
"""Match the sampled candidate distribution to the tau/q/beta posterior."""
# Overview:
#   1. Score completion tokens with the current policy and sum the masked
#      log-probs into one sequence log-prob per candidate.
#   2. Build per-group target weights via weight_matrix_from_q from the
#      rollout q targets, optionally including a reference term.
#   3. Policy loss = cross-entropy between those weights and the group-wise
#      softmax over sequence log-probs, averaged over active groups.
#   4. Optionally add a clipped sequence-ratio objective, apply the loss
#      scale, log metrics, and run the tau/beta controller hooks.
# Raises NotImplementedError when use_liger_loss is set, and ValueError for
# missing q targets, failed group reshapes, misaligned group counts, missing
# weighting settings, or an invalid loss scale.
if bool(getattr(self, "use_liger_loss", False)):
raise NotImplementedError(
"Listwise MaxEnt loss is not implemented for liger loss."
)
q_grouped = inputs.get("maxent_listwise_q")
if not isinstance(q_grouped, torch.Tensor) or q_grouped.numel() <= 0:
raise ValueError(
"Listwise MaxEnt requires rollout q targets from _generate_and_score_completions."
)
# Score prompt+completion as one sequence; only the completion span is kept.
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = (
inputs["completion_ids"],
inputs["completion_mask"],
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
mode = "train" if self.model.training else "eval"
configured_batch_size = (
int(getattr(self.args, "per_device_train_batch_size", 1) or 1)
if self.model.training
else int(getattr(self.args, "per_device_eval_batch_size", 1) or 1)
)
# Scoring chunk size: explicit maxent_logprob_chunk_size when non-zero,
# otherwise the per-device batch size for the current mode.
chunk_size = int(
getattr(self.args, "maxent_logprob_chunk_size", 0)
or configured_batch_size
or 1
)
per_token_logps = self._get_per_token_logps(
model,
input_ids,
attention_mask,
logits_to_keep,
batch_size=chunk_size,
)
# Sequence log-prob = sum of completion-token log-probs under the mask.
seq_logps = (per_token_logps * completion_mask).sum(dim=1)
group_size = max(int(getattr(self, "num_generations", 1) or 1), 1)
# _reshape_prompt_major_tensor returns None on failure (checked below);
# presumably it reshapes flat per-sample values to (prompts, group_size).
seq_logps_grouped = _reshape_prompt_major_tensor(seq_logps, group_size)
token_counts_grouped = _reshape_prompt_major_tensor(
completion_mask.sum(dim=1).to(torch.float32),
group_size,
)
if seq_logps_grouped is None or token_counts_grouped is None:
raise ValueError("Listwise MaxEnt could not reshape sequence log-probs.")
# Behavior (rollout) log-probs; when the stored value is None the detached
# current log-probs stand in.
# NOTE(review): direct indexing raises KeyError if the key is absent from
# inputs altogether - confirm the rollout always sets it.
old_per_token_logps = (
per_token_logps.detach()
if inputs["old_per_token_logps"] is None
else inputs["old_per_token_logps"].to(
device=per_token_logps.device,
dtype=per_token_logps.dtype,
)
)
old_seq_logps_grouped = _reshape_prompt_major_tensor(
(old_per_token_logps * completion_mask).sum(dim=1),
group_size,
)
if old_seq_logps_grouped is None:
raise ValueError("Listwise MaxEnt could not reshape behavior log-probs.")
num_prompts = int(seq_logps_grouped.size(0))
if int(token_counts_grouped.size(0)) != num_prompts:
raise ValueError("Listwise MaxEnt token counts are misaligned with prompts.")
if int(old_seq_logps_grouped.size(0)) != num_prompts:
raise ValueError("Listwise MaxEnt behavior log-probs are misaligned with prompts.")
policy_seq_logps_grouped = seq_logps_grouped
behavior_seq_logps_grouped = old_seq_logps_grouped
# Optional length normalization: divide policy and behavior sequence
# log-probs by the completion token count (clamped to at least 1).
if bool(getattr(self.args, "maxent_length_normalize_policy", False)):
token_denoms = token_counts_grouped.to(seq_logps_grouped.dtype).clamp(
min=1.0
)
policy_seq_logps_grouped = seq_logps_grouped / token_denoms
behavior_seq_logps_grouped = old_seq_logps_grouped / token_denoms
q_grouped = _normalize_listwise_q_targets(
q_grouped.to(
device=seq_logps_grouped.device,
dtype=seq_logps_grouped.dtype,
),
num_prompts=num_prompts,
group_size=group_size,
context="Listwise MaxEnt loss",
)
skip_zero_variance_groups = bool(
getattr(self.args, "maxent_listwise_skip_zero_variance_groups", False)
)
# A group is "neutral" when its q targets are flat (max - min <= 1e-8).
# The mask is all-False unless the skip flag is enabled.
neutral_group_mask = (
(
q_grouped.to(torch.float32).amax(dim=1)
- q_grouped.to(torch.float32).amin(dim=1)
)
<= 1e-8
) if skip_zero_variance_groups else torch.zeros(
num_prompts,
device=q_grouped.device,
dtype=torch.bool,
)
active_group_mask = ~neutral_group_mask
active_group_count = int(active_group_mask.to(torch.int64).sum().item())
weighting = getattr(self, "_maxent_weighting", None)
if weighting is None:
raise ValueError("Listwise MaxEnt requires initialized weighting settings.")
include_reference_term = self._resolve_listwise_reference_mode()
# The reference term defaults to zeros and is replaced below when requested.
ref_seq_logps_grouped = torch.zeros_like(seq_logps_grouped)
measured_kl: Optional[torch.Tensor] = None
if include_reference_term:
with torch.no_grad():
ref_per_token_logps = self._get_reference_per_token_logps(
input_ids,
attention_mask,
logits_to_keep,
batch_size=chunk_size,
)
if ref_per_token_logps is not None:
ref_per_token_logps = ref_per_token_logps.to(
device=per_token_logps.device,
dtype=per_token_logps.dtype,
)
# Guard the exponentials against rare runaway log-prob deltas.
kl_delta = _clamp_log_delta(ref_per_token_logps - per_token_logps)
# Per-token KL estimate exp(d) - d - 1 with d = ref - policy log-prob,
# averaged over unmasked completion tokens.
per_token_kl = (
torch.exp(kl_delta)
- kl_delta
- 1
).to(per_token_logps.dtype)
measured_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(
min=1.0
)
ref_seq_logps = (ref_per_token_logps * completion_mask).sum(dim=1)
# Reference sequence log-probs are length-normalized unless the
# weighting settings set len_norm_ref to a falsy value.
if bool(getattr(weighting, "len_norm_ref", True)):
ref_seq_logps = ref_seq_logps / completion_mask.sum(dim=1).clamp(
min=1.0
)
ref_seq_logps_grouped = _reshape_prompt_major_tensor(
ref_seq_logps,
int(q_grouped.size(1)),
)
if ref_seq_logps_grouped is None:
raise ValueError(
"Listwise MaxEnt could not reshape reference log-probs."
)
if int(ref_seq_logps_grouped.size(0)) != num_prompts:
raise ValueError(
"Listwise MaxEnt reference log-probs are misaligned with prompts."
)
else:
# The one-time warning is tracked with an instance flag.
if not bool(getattr(self, "_maxent_listwise_ref_warned", False)):
LOG.warning(
"Listwise MaxEnt requested reference weighting but no model-based "
"reference path is available; using rollout behavior log-probs "
"as the reference term."
)
setattr(self, "_maxent_listwise_ref_warned", True)
# Reuse the rollout behavior log-probs as a fixed reference term when
# no separate frozen model is available. This preserves listwise
# weighting signal on full-model runs without allocating another copy.
ref_seq_logps_grouped = behavior_seq_logps_grouped.detach()
# Same KL estimate as above, measured against the behavior log-probs.
kl_delta = _clamp_log_delta(old_per_token_logps - per_token_logps)
per_token_kl = (
torch.exp(kl_delta)
- kl_delta
- 1
).to(per_token_logps.dtype)
measured_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(
min=1.0
)
# Target weights per candidate; the result is cast to the dtype and device
# of the policy sequence log-probs.
weights_grouped = weight_matrix_from_q(
q_grouped,
ref_seq_logps_grouped,
token_counts_grouped,
weighting,
include_reference_term=include_reference_term,
normalize_by_tokens=False,
).to(
device=seq_logps_grouped.device,
dtype=seq_logps_grouped.dtype,
)
# Neutral groups get uniform weights 1/K in place of the computed ones.
if bool(neutral_group_mask.any().item()):
uniform_weights = torch.full_like(
weights_grouped,
1.0 / float(max(weights_grouped.size(1), 1)),
)
weights_grouped = torch.where(
neutral_group_mask.unsqueeze(1),
uniform_weights,
weights_grouped,
)
weights_grouped_list = weights_grouped.detach().cpu().tolist()
# Cross-entropy between target weights and the softmax over each group's
# sequence log-probs; only active groups contribute to the mean.
log_probs_grouped = torch.log_softmax(policy_seq_logps_grouped, dim=1)
per_group_policy_loss = -(weights_grouped * log_probs_grouped).sum(dim=1)
if active_group_count > 0:
policy_loss = per_group_policy_loss[active_group_mask].mean()
else:
# No active group: a zero derived from the per-group losses (presumably so
# the result stays attached to the autograd graph).
policy_loss = (per_group_policy_loss * 0.0).sum()
loss = policy_loss
clip_loss: Optional[torch.Tensor] = None
# Optional clipped sequence-ratio objective, weighted by clip_coef.
if bool(getattr(self.args, "maxent_use_clip_objective", False)):
clip_coef = _coerce_non_negative_float(
getattr(self.args, "maxent_clip_objective_coef", 1.0),
default=1.0,
)
if clip_coef > 0.0:
# maxent_clip_range, when set, overrides both epsilon_low and epsilon_high.
clip_range = getattr(self.args, "maxent_clip_range", None)
clip_low = (
_coerce_non_negative_float(clip_range, default=self.epsilon_low)
if clip_range is not None
else float(self.epsilon_low)
)
clip_high = (
_coerce_non_negative_float(clip_range, default=self.epsilon_high)
if clip_range is not None
else float(self.epsilon_high)
)
# Advantage = weight minus a baseline; the baseline defaults to 1/K.
baseline = getattr(self.args, "maxent_clip_adv_baseline", None)
if baseline is None:
baseline_value = 1.0 / float(max(weights_grouped.size(1), 1))
else:
baseline_value = float(baseline)
clip_adv = weights_grouped - baseline_value
# Sequence-level ratio between current policy and behavior log-probs.
log_seq_ratio = _clamp_log_delta(
policy_seq_logps_grouped - behavior_seq_logps_grouped
)
seq_ratio = torch.exp(log_seq_ratio).to(seq_logps_grouped.dtype)
seq_ratio_clipped = torch.clamp(
seq_ratio,
1.0 - clip_low,
1.0 + clip_high,
)
clip_obj = torch.min(
seq_ratio * clip_adv,
seq_ratio_clipped * clip_adv,
)
per_group_clip_loss = -clip_obj.sum(dim=1)
if active_group_count > 0:
clip_loss = per_group_clip_loss[active_group_mask].mean()
else:
clip_loss = (per_group_clip_loss * 0.0).sum()
loss = loss + clip_coef * clip_loss
# Fractions of candidates whose ratio left the clip window on the side
# matching the sign of their advantage.
is_low_clipped = (seq_ratio < 1.0 - clip_low) & (clip_adv < 0.0)
is_high_clipped = (seq_ratio > 1.0 + clip_high) & (clip_adv > 0.0)
clip_region = is_low_clipped | is_high_clipped
self._append_metric_value(
mode,
"clip_ratio/low_mean",
is_low_clipped.to(torch.float32).mean().item(),
)
self._append_metric_value(
mode,
"clip_ratio/high_mean",
is_high_clipped.to(torch.float32).mean().item(),
)
self._append_metric_value(
mode,
"clip_ratio/region_mean",
clip_region.to(torch.float32).mean().item(),
)
weight_entropy, entropy_min, entropy_max, _ = collect_weight_entropy(
weights_grouped_list
)
# Optional scalar loss scale supplied with the batch: it must be a single
# finite, positive value and is applied only when it differs from 1.0.
loss_scale_raw = inputs.get("maxent_listwise_loss_scale")
loss_scale_value = 1.0
if isinstance(loss_scale_raw, torch.Tensor):
if loss_scale_raw.numel() != 1:
raise ValueError("Listwise MaxEnt loss scale must be scalar.")
loss_scale_value = float(loss_scale_raw.detach().cpu().item())
elif isinstance(loss_scale_raw, (int, float)):
loss_scale_value = float(loss_scale_raw)
if not math.isfinite(loss_scale_value) or loss_scale_value <= 0.0:
raise ValueError("Listwise MaxEnt loss scale must be finite and positive.")
if loss_scale_value != 1.0:
loss = loss * loss.new_tensor(loss_scale_value)
# Metric logging: loss/policy records the unscaled policy loss.
self._append_metric_value(mode, "loss/policy", float(policy_loss.item()))
self._append_metric_value(
mode, "weight_entropy", float(weight_entropy), include_legacy_aliases=False
)
self._append_metric_value(
mode, "weight_entropy_min", float(entropy_min), include_legacy_aliases=False
)
self._append_metric_value(
mode, "weight_entropy_max", float(entropy_max), include_legacy_aliases=False
)
self._append_metric_value(mode, "maxent/objective_variant_listwise", 1.0)
self._append_metric_value(mode, "maxent/objective_variant_entropy", 0.0)
self._append_metric_value(
mode,
"maxent/listwise_weight_mean",
float(weights_grouped.mean().item()),
)
self._append_metric_value(
mode,
"maxent/listwise_weight_std",
float(weights_grouped.to(torch.float32).std(unbiased=False).item()),
)
self._append_metric_value(
mode,
"maxent/listwise_neutral_group_frac",
float(neutral_group_mask.to(torch.float32).mean().item()),
)
self._append_metric_value(
mode,
"maxent/listwise_active_group_frac",
float(active_group_mask.to(torch.float32).mean().item()),
)
self._append_metric_value(
mode,
"maxent/listwise_loss_scale",
loss_scale_value,
)
if clip_loss is not None:
self._append_metric_value(mode, "loss/clip", float(clip_loss.item()))
# Reduce the measured KL to one float for logging and the beta controller;
# kl_value stays None when no KL was measured or nothing usable came back.
if measured_kl is not None:
gathered_kl = _metric_tensor_for_logging(self, measured_kl, mode=mode)
if isinstance(gathered_kl, torch.Tensor) and gathered_kl.numel() > 0:
kl_value = float(gathered_kl.nanmean().item())
self._append_metric_value(mode, "kl", kl_value)
else:
kl_value = None
else:
kl_value = None
# Train-mode controller hooks: the controller meta objective when one is
# configured, otherwise the tau update, with the beta update gated on the
# controller flag, the reference term and an available KL value.
if mode == "train":
meta_objective = getattr(self, "_maxent_controller_objective", None)
beta_controller_enabled = bool(
getattr(self.args, "maxent_beta_controller_enabled", False)
)
if meta_objective is not None:
self._maybe_apply_controller_meta(
mode=mode,
kl_value=kl_value,
weight_entropy=weight_entropy,
total_loss=float(loss.item()),
)
else:
maybe_update_tau(
weighting,
SimpleNamespace(weight_entropy=weight_entropy),
global_step=int(getattr(self.state, "global_step", 0) or 0),
)
if (
beta_controller_enabled
and include_reference_term
and kl_value is not None
):
maybe_update_beta(weighting, measured_kl=kl_value)
self._sync_weighting_scalars()
# The controller flag is logged under two metric names.
self._append_metric_value(
mode,
"kl_controller/enabled",
1.0 if beta_controller_enabled else 0.0,
include_legacy_aliases=False,
)
self._append_metric_value(
mode,
"kl_controller_enabled",
1.0 if beta_controller_enabled else 0.0,
include_legacy_aliases=False,
)
self._append_metric_value(mode, "tau", float(self.tau))
self._append_metric_value(mode, "beta", float(self.beta))
self._append_metric_value(
mode,
"weight_norm_denom",
float(getattr(weighting, "denom", 1.0)),
include_legacy_aliases=False,
)
return loss
def _sanitize_scoring_token_ids(
    self,
    token_ids: torch.Tensor,
    *,
    upper_bound: Optional[int],
    mode: str,
    context: str,
) -> torch.Tensor:
    """Replace out-of-range token ids before model or gather indexing.

    Ids below 0 or at/above ``upper_bound`` are overwritten in a copy with
    a replacement id: the tokenizer's pad id, else its EOS id, else
    ``upper_bound - 1``.  The input is returned untouched when it is not an
    integer tensor, when ``upper_bound`` is not a positive int, or when
    every id is already in range.  Replacements are counted in metrics and
    logged once per ``context``.
    """
    if not isinstance(token_ids, torch.Tensor):
        return token_ids
    ids_dtype = token_ids.dtype
    if ids_dtype == torch.bool or ids_dtype.is_floating_point:
        return token_ids
    if not (isinstance(upper_bound, int) and upper_bound > 0):
        return token_ids

    def _usable(candidate: Optional[int]) -> bool:
        """True when ``candidate`` is a valid id below ``upper_bound``."""
        return candidate is not None and 0 <= candidate < upper_bound

    tokenizer = getattr(self, "processing_class", None)
    replacement_id = _coerce_optional_int(getattr(tokenizer, "pad_token_id", None))
    if not _usable(replacement_id):
        replacement_id = _coerce_optional_int(
            getattr(tokenizer, "eos_token_id", None)
        )
    if not _usable(replacement_id):
        replacement_id = max(upper_bound - 1, 0)

    out_of_range = (token_ids < 0) | (token_ids >= upper_bound)
    bad_total = int(out_of_range.to(torch.long).sum().item())
    if bad_total <= 0:
        return token_ids

    # Extremes of the offending ids are only used in the warning text.
    try:
        offenders = token_ids[out_of_range]
        lowest_bad = int(offenders.min().item())
        highest_bad = int(offenders.max().item())
    except Exception:
        lowest_bad = 0
        highest_bad = 0

    cleaned = token_ids.clone()
    cleaned[out_of_range] = int(replacement_id)

    for metric_name, metric_value in (
        ("scoring/invalid_token_id_count", float(bad_total)),
        (f"scoring/{context}_invalid_token_id_count", float(bad_total)),
        ("scoring/invalid_token_id_replacement", float(replacement_id)),
    ):
        self._append_metric_value(
            mode,
            metric_name,
            metric_value,
            include_legacy_aliases=False,
        )

    already_warned = getattr(
        self, "_invalid_scoring_token_ids_warned_contexts", None
    )
    if not isinstance(already_warned, set):
        already_warned = set()
        setattr(self, "_invalid_scoring_token_ids_warned_contexts", already_warned)
    if context not in already_warned:
        LOG.warning(
            "Sanitized %d scoring token ids for %s outside upper_bound=%d using replacement_id=%d (min=%d max=%d)",
            bad_total,
            context,
            upper_bound,
            replacement_id,
            lowest_bad,
            highest_bad,
        )
        already_warned.add(context)
    return cleaned
def _get_per_token_logps(self, *args: Any, **kwargs: Any) -> torch.Tensor: # type: ignore[override]
    """Score per-token log-probs for the trailing ``logits_to_keep`` tokens.

    ``model, input_ids, attention_mask, logits_to_keep, batch_size`` are
    accepted positionally or by keyword.  When ``input_ids`` or
    ``attention_mask`` is not a tensor, the call is forwarded unchanged to
    the parent implementation.  Otherwise rows are scored in chunks with
    sanitized token ids, temperature-scaled logits and masking of logit
    columns beyond the resolved token-id upper bound.  Outside MaxEnt mode
    the main process also emits a debug log line.
    """

    def _pick(position: int, keyword: str) -> Any:
        """Fetch an argument by position, falling back to its keyword."""
        return args[position] if len(args) > position else kwargs.get(keyword)

    model = _pick(0, "model")
    input_ids = _pick(1, "input_ids")
    attention_mask = _pick(2, "attention_mask")
    logits_to_keep = _pick(3, "logits_to_keep")
    batch_size = _pick(4, "batch_size")

    if isinstance(input_ids, torch.Tensor) and isinstance(
        attention_mask, torch.Tensor
    ):
        chunk_size = int(batch_size or input_ids.size(0) or 1)
        mode = "train" if bool(getattr(model, "training", False)) else "eval"
        unwrap_fn = getattr(getattr(self, "accelerator", None), "unwrap_model", None)
        base_model = model
        if callable(unwrap_fn):
            try:
                base_model = unwrap_fn(model)
            except Exception:
                base_model = model
        tokenizer = getattr(self, "processing_class", None)
        vocab_size = _resolve_token_id_upper_bound(base_model, tokenizer)
        chunk_results: List[torch.Tensor] = []
        for offset in range(0, int(input_ids.size(0)), chunk_size):
            rows = slice(offset, offset + chunk_size)
            ids_chunk = self._sanitize_scoring_token_ids(
                input_ids[rows],
                upper_bound=vocab_size,
                mode=mode,
                context="model_input",
            )
            mask_chunk = attention_mask[rows]
            # Retry without ``logits_to_keep`` when the forward rejects it.
            try:
                outputs = model(
                    input_ids=ids_chunk,
                    attention_mask=mask_chunk,
                    logits_to_keep=int(logits_to_keep) + 1,
                )
            except TypeError:
                outputs = model(
                    input_ids=ids_chunk,
                    attention_mask=mask_chunk,
                )
            # Drop the final position, then keep the completion span only.
            logits = getattr(outputs, "logits", outputs)[:, :-1, :]
            token_ids = ids_chunk[:, -int(logits_to_keep) :]
            logits = logits[:, -int(logits_to_keep) :] / self.temperature
            logits_width = int(logits.size(-1))
            if isinstance(vocab_size, int) and logits_width > vocab_size:
                warned = bool(
                    getattr(self, "_invalid_logit_columns_warned_logps", False)
                )
                if not warned:
                    LOG.warning(
                        "Masking %d tokenizer-inaccessible logit columns in shared logprob scoring (valid_vocab_size=%d, logits_width=%d).",
                        logits_width - vocab_size,
                        vocab_size,
                        logits_width,
                    )
                    setattr(self, "_invalid_logit_columns_warned_logps", True)
                logits = _mask_invalid_logit_columns(
                    logits,
                    valid_vocab_size=vocab_size,
                )
            token_ids = self._sanitize_scoring_token_ids(
                token_ids,
                upper_bound=int(logits.size(-1)),
                mode=mode,
                context="token_select",
            )
            selected = torch.gather(
                F.log_softmax(logits, dim=-1),
                dim=-1,
                index=token_ids.unsqueeze(-1),
            )
            chunk_results.append(selected.squeeze(-1))
        logps = torch.cat(chunk_results, dim=0)
    else:
        logps = super()._get_per_token_logps(*args, **kwargs)

    # Debug logging applies only outside MaxEnt mode, on the main process.
    if not self.maxent_enabled and self.accelerator.is_main_process:
        step = int(getattr(self.state, "global_step", 0))
        try:
            requires_grad = bool(getattr(logps, "requires_grad", False))
        except Exception:
            requires_grad = False
        LOG.info(
            "GRPO debug | step=%d | token_logp_requires_grad=%s | grad_enabled=%s | logps_shape=%s",
            step,
            requires_grad,
            torch.is_grad_enabled(),
            getattr(logps, "shape", None),
        )
    return logps
def _compute_maxent_loss(self, model: Any, inputs: Any) -> torch.Tensor:
    """TRL-style GRPO loss with a true entropy regularizer in the loss.

    Builds the clipped-ratio policy loss (optionally with a ``beta``-weighted
    KL penalty), reduces it according to ``self.loss_type`` and subtracts
    ``alpha * entropy_bonus_basis``. The basis is the masked policy entropy
    divided by ``_entropy_normalization_scale(valid_vocab_size)``. Clip-ratio,
    KL, alpha and entropy diagnostics are recorded through
    ``self._append_metric_value``.

    Args:
        model: Policy model scored by ``_get_per_token_logps_and_entropy``.
        inputs: Prepared batch mapping. Must contain ``prompt_ids``,
            ``prompt_mask``, ``completion_ids``, ``completion_mask`` and
            ``advantages``; ``old_per_token_logps`` is optional and may be
            missing or ``None``.

    Returns:
        The loss tensor with the entropy bonus already subtracted.

    Raises:
        NotImplementedError: If ``self.use_liger_loss`` is truthy.
        ValueError: If ``self.loss_type`` is not ``grpo``, ``bnpo`` or
            ``dr_grpo``.
    """
    if bool(getattr(self, "use_liger_loss", False)):
        raise NotImplementedError(
            "MaxEnt loss regularization is not implemented for liger loss."
        )
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = (
        inputs["completion_ids"],
        inputs["completion_mask"],
    )
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    # Number of trailing (completion) positions handed to the scoring helpers.
    logits_to_keep = completion_ids.size(1)
    mode = "train" if self.model.training else "eval"
    configured_batch_size = (
        int(getattr(self.args, "per_device_train_batch_size", 1) or 1)
        if self.model.training
        else int(getattr(self.args, "per_device_eval_batch_size", 1) or 1)
    )
    # An explicit chunk size wins; otherwise fall back to the per-device batch.
    chunk_size = int(
        getattr(self.args, "maxent_logprob_chunk_size", 0)
        or configured_batch_size
        or 1
    )
    requested_entropy_mode = str(
        getattr(self.args, "maxent_policy_entropy_mode", "exact") or "exact"
    )
    entropy_mode = requested_entropy_mode.strip().lower() or "exact"
    if entropy_mode != "exact":
        # Warn only once per trainer instance, then force the exact estimator.
        if not bool(
            getattr(self, "_maxent_sample_entropy_loss_warned", False)
        ):
            LOG.warning(
                "Entropy-regularized MaxEnt requested maxent_policy_entropy_mode=%s, "
                "but the training loss uses exact entropy. The sample estimator is "
                "only valid for logging or GRPO reward bonuses, not direct "
                "entropy-loss gradients.",
                requested_entropy_mode,
            )
            setattr(self, "_maxent_sample_entropy_loss_warned", True)
        entropy_mode = "exact"
    unwrap_fn = getattr(getattr(self, "accelerator", None), "unwrap_model", None)
    base_model = model
    if callable(unwrap_fn):
        try:
            base_model = unwrap_fn(model)
        except Exception:
            # Best effort: keep the wrapped model if unwrapping fails.
            base_model = model
    tokenizer = getattr(self, "processing_class", None)
    valid_vocab_size = _resolve_token_id_upper_bound(base_model, tokenizer)
    entropy_normalization_scale = _entropy_normalization_scale(valid_vocab_size)
    per_token_logps, per_token_entropy = self._get_per_token_logps_and_entropy(
        model,
        input_ids,
        attention_mask,
        logits_to_keep,
        entropy_mode=entropy_mode,
        batch_size=chunk_size,
    )
    per_token_kl: Optional[torch.Tensor] = None
    kl_value: Optional[float] = None
    alpha_kl_control_requested = self._entropy_alpha_kl_control_requested()
    # The KL term is needed either for the beta penalty or as the signal that
    # drives KL-based alpha control.
    if self.beta != 0.0 or alpha_kl_control_requested:
        use_model_reference = self._should_use_model_reference_logprobs(
            default_to_model_reference=alpha_kl_control_requested
        )
        with torch.no_grad():
            if use_model_reference:
                ref_per_token_logps = self._get_reference_per_token_logps(
                    input_ids,
                    attention_mask,
                    logits_to_keep,
                    batch_size=chunk_size,
                )
            else:
                ref_per_token_logps = None
            if ref_per_token_logps is not None:
                ref_per_token_logps = ref_per_token_logps.to(
                    device=per_token_logps.device,
                    dtype=per_token_logps.dtype,
                )
            else:
                # No reference-model scores: fall back to the behavior
                # log-probs, then to the detached current log-probs.
                old_ref = inputs.get("old_per_token_logps")
                if isinstance(old_ref, torch.Tensor):
                    ref_per_token_logps = old_ref.to(
                        device=per_token_logps.device,
                        dtype=per_token_logps.dtype,
                    )
                else:
                    ref_per_token_logps = per_token_logps.detach()
        # Guard the exponentials against rare runaway log-prob deltas.
        kl_delta = _clamp_log_delta(ref_per_token_logps - per_token_logps)
        per_token_kl = (
            torch.exp(kl_delta)
            - kl_delta
            - 1
        ).to(per_token_logps.dtype)
    current_batch_kl_measure: Optional[float] = None
    if mode == "train" and per_token_kl is not None:
        mean_kl_for_alpha = (
            (per_token_kl.detach() * completion_mask).sum()
            / completion_mask.sum().clamp(min=1.0)
        ).to(torch.float32)
        gathered_kl_for_alpha = self.accelerator.gather(mean_kl_for_alpha)
        if torch.isfinite(gathered_kl_for_alpha).all():
            current_batch_kl_measure = float(gathered_kl_for_alpha.mean().item())
        else:
            # A non-finite KL anywhere is reported to alpha control as +inf.
            current_batch_kl_measure = float("inf")
    # ``old_per_token_logps`` may be missing entirely (not just ``None``);
    # read it with ``.get`` like the reference fallback above so a missing key
    # degrades to the detached current log-probs instead of raising KeyError.
    behavior_logps = inputs.get("old_per_token_logps")
    old_per_token_logps = (
        per_token_logps.detach() if behavior_logps is None else behavior_logps
    )
    advantages = inputs["advantages"]
    advantages, old_per_token_logps = self._maybe_apply_seed_grpo_advantages_in_loss(
        inputs,
        completion_ids=completion_ids,
        completion_mask=completion_mask,
        behavior_logps=old_per_token_logps.detach(),
        mode=mode,
    )
    log_ratio = _clamp_log_delta(per_token_logps - old_per_token_logps)
    coef_1 = torch.exp(log_ratio).to(per_token_logps.dtype)
    coef_2 = torch.clamp(
        coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high
    )
    if self.args.delta is not None:
        coef_1 = torch.clamp(coef_1, max=self.args.delta)
    per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
    if per_token_kl is not None:
        per_token_loss = per_token_loss + self.beta * per_token_kl
    (
        alpha,
        alpha_multiplier,
        alpha_kl_measure,
        alpha_kl_threshold,
        alpha_kl_control_enabled,
        alpha_direction,
        alpha_kl_min_multiplier,
        alpha_kl_max_multiplier,
        alpha_trust_zone_blocked,
    ) = self._resolve_effective_maxent_alpha(
        mode,
        measured_kl_override=current_batch_kl_measure,
    )
    completion_mask_f = completion_mask.to(
        device=per_token_entropy.device,
        dtype=per_token_entropy.dtype,
    )
    token_count_per_seq = completion_mask_f.sum(dim=1).clamp(min=1.0)
    # Token-weighted mean over the whole batch.
    mean_entropy = (per_token_entropy * completion_mask_f).sum() / completion_mask_f.sum().clamp(
        min=1.0
    )
    # Per-sequence mean, so every sample carries equal weight.
    entropy_per_seq = (
        (per_token_entropy * completion_mask_f).sum(dim=1) / token_count_per_seq
    )
    mean_entropy_per_seq = entropy_per_seq.mean()
    if self.loss_type == "grpo":
        loss = (
            (per_token_loss * completion_mask).sum(-1)
            / completion_mask.sum(-1).clamp(min=1.0)
        ).mean()
        entropy_bonus_basis = mean_entropy_per_seq / entropy_normalization_scale
    elif self.loss_type == "bnpo":
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(
            min=1.0
        )
        entropy_bonus_basis = mean_entropy / entropy_normalization_scale
    elif self.loss_type == "dr_grpo":
        loss = (per_token_loss * completion_mask).sum() / self._dr_grpo_loss_denominator(
            completion_mask,
            loss_tensor=per_token_loss,
            mode=mode,
        )
        # Under fixed-denominator Dr.GRPO, a raw per-token entropy bonus
        # creates a spurious incentive to emit longer completions. Apply
        # the bonus on sequence-mean entropy instead so each sample gets
        # equal entropy weight regardless of realized length.
        entropy_bonus_basis = mean_entropy_per_seq / entropy_normalization_scale
    else:
        raise ValueError(f"Unknown loss type: {self.loss_type}")
    alpha_before_invalid_guard = float(alpha)
    invalid_rollout_token_count = float(
        getattr(self, "_last_rollout_invalid_token_id_count", 0.0) or 0.0
    )
    # Disable the entropy bonus when the last rollout recorded invalid token ids.
    invalid_rollout_bonus_blocked = (
        invalid_rollout_token_count > 0.0 and alpha_before_invalid_guard > 0.0
    )
    if invalid_rollout_bonus_blocked:
        alpha = 0.0
    entropy_bonus = alpha * entropy_bonus_basis
    loss = loss - entropy_bonus
    if per_token_kl is not None:
        mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
        gathered_kl = _metric_tensor_for_logging(self, mean_kl, mode=mode)
        if isinstance(gathered_kl, torch.Tensor) and gathered_kl.numel() > 0:
            kl_value = float(gathered_kl.nanmean().item())
            self._append_metric_value(
                mode,
                "kl",
                kl_value,
            )
    # Clip diagnostics: a token counts as clipped only when the clip bound is
    # the side that is active for the sign of its advantage.
    is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (
        advantages.unsqueeze(1) < 0
    )
    is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (
        advantages.unsqueeze(1) > 0
    )
    is_region_clipped = is_low_clipped | is_high_clipped
    low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum()
    high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
    clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()
    gathered_low_clip = _metric_tensor_for_logging(self, low_clip, mode=mode)
    if isinstance(gathered_low_clip, torch.Tensor) and gathered_low_clip.numel() > 0:
        self._append_metric_value(
            mode, "clip_ratio/low_mean", gathered_low_clip.nanmean().item()
        )
        self._append_metric_value(
            mode, "clip_ratio/low_min", _nanmin_tensor(gathered_low_clip).item()
        )
    gathered_high_clip = _metric_tensor_for_logging(self, high_clip, mode=mode)
    if isinstance(gathered_high_clip, torch.Tensor) and gathered_high_clip.numel() > 0:
        self._append_metric_value(
            mode, "clip_ratio/high_mean", gathered_high_clip.nanmean().item()
        )
        self._append_metric_value(
            mode, "clip_ratio/high_max", _nanmax_tensor(gathered_high_clip).item()
        )
    gathered_clip_ratio = _metric_tensor_for_logging(self, clip_ratio, mode=mode)
    if isinstance(gathered_clip_ratio, torch.Tensor) and gathered_clip_ratio.numel() > 0:
        self._append_metric_value(
            mode, "clip_ratio/region_mean", gathered_clip_ratio.nanmean().item()
        )
    gathered_entropy = _metric_tensor_for_logging(self, mean_entropy, mode=mode)
    gathered_entropy_per_seq = _metric_tensor_for_logging(
        self, entropy_per_seq, mode=mode
    )
    self._append_metric_value(mode, "maxent/alpha", alpha)
    self._append_metric_value(mode, "maxent/alpha_base", float(self.maxent_alpha))
    self._append_metric_value(
        mode,
        "maxent/alpha_before_invalid_token_guard",
        alpha_before_invalid_guard,
        include_legacy_aliases=False,
    )
    self._append_metric_value(mode, "maxent/alpha_multiplier", alpha_multiplier)
    self._append_metric_value(mode, "maxent/objective_variant_entropy", 1.0)
    self._append_metric_value(mode, "maxent/objective_variant_listwise", 0.0)
    self._append_metric_value(
        mode,
        "maxent/invalid_rollout_bonus_blocked",
        1.0 if invalid_rollout_bonus_blocked else 0.0,
        include_legacy_aliases=False,
    )
    self._append_metric_value(
        mode,
        "maxent/rollout_invalid_token_id_count",
        invalid_rollout_token_count,
        include_legacy_aliases=False,
    )
    self._append_metric_value(
        mode,
        "maxent/alpha_kl_control_enabled",
        1.0 if alpha_kl_control_enabled else 0.0,
    )
    self._append_metric_value(
        mode,
        "maxent/alpha_trust_zone_blocked",
        1.0 if alpha_trust_zone_blocked else 0.0,
    )
    self._append_metric_value(mode, "maxent/alpha_kl_direction", alpha_direction)
    self._append_metric_value(mode, "maxent/alpha_kl_threshold", alpha_kl_threshold)
    self._append_metric_value(
        mode, "maxent/alpha_kl_min_multiplier", alpha_kl_min_multiplier
    )
    self._append_metric_value(
        mode, "maxent/alpha_kl_max_multiplier", alpha_kl_max_multiplier
    )
    if alpha_kl_measure is not None:
        self._append_metric_value(
            mode, "maxent/alpha_kl_measure", alpha_kl_measure
        )
    if isinstance(gathered_entropy, torch.Tensor) and gathered_entropy.numel() > 0:
        # Report the same entropy statistic that fed the loss bonus: the
        # sequence mean for grpo/dr_grpo, the token mean otherwise.
        raw_entropy_metric = (
            gathered_entropy_per_seq.nanmean()
            if self.loss_type in {"grpo", "dr_grpo"}
            and isinstance(gathered_entropy_per_seq, torch.Tensor)
            and gathered_entropy_per_seq.numel() > 0
            else gathered_entropy.nanmean()
        )
        normalized_entropy_metric = raw_entropy_metric / entropy_normalization_scale
        normalized_entropy_token = gathered_entropy.nanmean() / entropy_normalization_scale
        self._append_metric_value(
            mode, "maxent/policy_entropy_mean", raw_entropy_metric.item()
        )
        self._append_metric_value(
            mode,
            "maxent/policy_entropy_mean_token",
            gathered_entropy.nanmean().item(),
            include_legacy_aliases=False,
        )
        self._append_metric_value(
            mode,
            "maxent/policy_entropy_mean_normalized",
            normalized_entropy_metric.item(),
            include_legacy_aliases=False,
        )
        self._append_metric_value(
            mode,
            "maxent/policy_entropy_mean_token_normalized",
            normalized_entropy_token.item(),
            include_legacy_aliases=False,
        )
        if (
            isinstance(gathered_entropy_per_seq, torch.Tensor)
            and gathered_entropy_per_seq.numel() > 0
        ):
            normalized_entropy_seq = (
                gathered_entropy_per_seq.nanmean() / entropy_normalization_scale
            )
            self._append_metric_value(
                mode,
                "maxent/policy_entropy_mean_seq",
                gathered_entropy_per_seq.nanmean().item(),
                include_legacy_aliases=False,
            )
            self._append_metric_value(
                mode,
                "maxent/policy_entropy_mean_seq_normalized",
                normalized_entropy_seq.item(),
                include_legacy_aliases=False,
            )
        self._append_metric_value(
            mode,
            "maxent/entropy_bonus_length_normalized",
            1.0 if self.loss_type in {"grpo", "dr_grpo"} else 0.0,
            include_legacy_aliases=False,
        )
        self._append_metric_value(
            mode,
            "maxent/entropy_normalization_log_vocab",
            entropy_normalization_scale,
            include_legacy_aliases=False,
        )
        self._append_metric_value(
            mode,
            "maxent/valid_vocab_size",
            float(valid_vocab_size) if valid_vocab_size is not None else 0.0,
            include_legacy_aliases=False,
        )
        self._append_metric_value(
            mode,
            "maxent/loss_entropy_bonus",
            (-alpha * normalized_entropy_metric).item(),
        )
    if (
        isinstance(gathered_entropy_per_seq, torch.Tensor)
        and gathered_entropy_per_seq.numel() > 0
    ):
        self._append_metric_value(
            mode,
            "maxent/policy_entropy_std",
            gathered_entropy_per_seq.to(torch.float32).std(unbiased=False).item(),
        )
    if mode == "train" and getattr(self, "_maxent_controller_objective", None) is not None:
        self._maybe_apply_controller_meta(
            mode=mode,
            kl_value=kl_value,
            total_loss=float(loss.item()),
        )
        # Log tau/beta after the controller step and scalar sync.
        self._sync_weighting_scalars()
        self._append_metric_value(mode, "tau", float(self.tau))
        self._append_metric_value(mode, "beta", float(self.beta))
        self._append_metric_value(
            mode,
            "weight_norm_denom",
            float(getattr(self, "weight_norm_denom", 1.0)),
            include_legacy_aliases=False,
        )
    return loss
def _compute_grpo_native_loss(
    self,
    *,
    model: Any,
    inputs: Any,
    return_outputs: bool,
    num_items_in_batch: Any = None,
) -> Any:
    """Delegate the GRPO loss to the parent class's ``compute_loss``.

    The parent is first called with ``num_items_in_batch``. If that raises a
    ``TypeError`` whose message mentions ``num_items_in_batch``, the call is
    repeated without that keyword; any other ``TypeError`` propagates.
    """
    parent_compute_loss = super().compute_loss
    try:
        return parent_compute_loss(
            model,
            inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )
    except TypeError as signature_error:
        # Older TRL signatures may not accept num_items_in_batch.
        if "num_items_in_batch" in str(signature_error):
            return parent_compute_loss(
                model,
                inputs,
                return_outputs=return_outputs,
            )
        raise
def _compute_stable_grpo_loss(self, model: Any, inputs: Any) -> torch.Tensor:
    """GRPO loss using the same stabilized exponentials as the MaxEnt path.

    Log-prob deltas are passed through ``_clamp_log_delta`` before every
    ``torch.exp``. The clipped-ratio loss, plus a ``beta``-weighted KL penalty
    when ``self.beta`` is non-zero, is reduced according to ``self.loss_type``.
    KL and clip-ratio diagnostics are recorded via ``self._append_metric_value``.

    Args:
        model: Policy model scored by ``_get_per_token_logps``.
        inputs: Prepared batch mapping. Must contain ``prompt_ids``,
            ``prompt_mask``, ``completion_ids``, ``completion_mask`` and
            ``advantages``; ``old_per_token_logps`` is optional and may be
            missing or ``None``.

    Returns:
        The reduced loss tensor.

    Raises:
        NotImplementedError: If ``self.use_liger_loss`` is truthy.
        ValueError: If ``self.loss_type`` is not ``grpo``, ``bnpo`` or
            ``dr_grpo``.
    """
    if bool(getattr(self, "use_liger_loss", False)):
        raise NotImplementedError(
            "Stable GRPO loss is not implemented for liger loss."
        )
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = (
        inputs["completion_ids"],
        inputs["completion_mask"],
    )
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    # Number of trailing (completion) positions handed to the scoring helpers.
    logits_to_keep = completion_ids.size(1)
    mode = "train" if self.model.training else "eval"
    configured_batch_size = (
        int(getattr(self.args, "per_device_train_batch_size", 1) or 1)
        if self.model.training
        else int(getattr(self.args, "per_device_eval_batch_size", 1) or 1)
    )
    # An explicit chunk size wins; otherwise fall back to the per-device batch.
    chunk_size = int(
        getattr(self.args, "maxent_logprob_chunk_size", 0)
        or configured_batch_size
        or 1
    )
    per_token_logps = self._get_per_token_logps(
        model,
        input_ids,
        attention_mask,
        logits_to_keep,
        batch_size=chunk_size,
    )
    per_token_kl: Optional[torch.Tensor] = None
    if self.beta != 0.0:
        use_model_reference = self._should_use_model_reference_logprobs(
            default_to_model_reference=False
        )
        with torch.no_grad():
            if use_model_reference:
                ref_per_token_logps = self._get_reference_per_token_logps(
                    input_ids,
                    attention_mask,
                    logits_to_keep,
                    batch_size=chunk_size,
                )
            else:
                ref_per_token_logps = None
            if ref_per_token_logps is not None:
                ref_per_token_logps = ref_per_token_logps.to(
                    device=per_token_logps.device,
                    dtype=per_token_logps.dtype,
                )
            else:
                # No reference-model scores: fall back to the behavior
                # log-probs, then to the detached current log-probs.
                old_ref = inputs.get("old_per_token_logps")
                if isinstance(old_ref, torch.Tensor):
                    ref_per_token_logps = old_ref.to(
                        device=per_token_logps.device,
                        dtype=per_token_logps.dtype,
                    )
                else:
                    ref_per_token_logps = per_token_logps.detach()
        # Clamp before exponentiating so a runaway delta cannot overflow.
        kl_delta = _clamp_log_delta(ref_per_token_logps - per_token_logps)
        per_token_kl = (
            torch.exp(kl_delta)
            - kl_delta
            - 1
        ).to(per_token_logps.dtype)
    # ``old_per_token_logps`` may be missing entirely (not just ``None``);
    # read it with ``.get`` like the reference fallback above so a missing key
    # degrades to the detached current log-probs instead of raising KeyError.
    behavior_logps = inputs.get("old_per_token_logps")
    old_per_token_logps = (
        per_token_logps.detach() if behavior_logps is None else behavior_logps
    )
    advantages = inputs["advantages"]
    advantages, old_per_token_logps = self._maybe_apply_seed_grpo_advantages_in_loss(
        inputs,
        completion_ids=completion_ids,
        completion_mask=completion_mask,
        behavior_logps=old_per_token_logps.detach(),
        mode=mode,
    )
    log_ratio = _clamp_log_delta(per_token_logps - old_per_token_logps)
    coef_1 = torch.exp(log_ratio).to(per_token_logps.dtype)
    coef_2 = torch.clamp(
        coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high
    )
    if self.args.delta is not None:
        coef_1 = torch.clamp(coef_1, max=self.args.delta)
    per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
    if per_token_kl is not None:
        per_token_loss = per_token_loss + self.beta * per_token_kl
    if self.loss_type == "grpo":
        loss = (
            (per_token_loss * completion_mask).sum(-1)
            / completion_mask.sum(-1).clamp(min=1.0)
        ).mean()
    elif self.loss_type == "bnpo":
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(
            min=1.0
        )
    elif self.loss_type == "dr_grpo":
        loss = (per_token_loss * completion_mask).sum() / self._dr_grpo_loss_denominator(
            completion_mask,
            loss_tensor=per_token_loss,
            mode=mode,
        )
    else:
        raise ValueError(f"Unknown loss type: {self.loss_type}")
    if per_token_kl is not None:
        mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
        gathered_kl = _metric_tensor_for_logging(self, mean_kl, mode=mode)
        if isinstance(gathered_kl, torch.Tensor) and gathered_kl.numel() > 0:
            kl_value = float(gathered_kl.nanmean().item())
            self._append_metric_value(mode, "kl", kl_value)
    # Clip diagnostics: a token counts as clipped only when the clip bound is
    # the side that is active for the sign of its advantage.
    is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (
        advantages.unsqueeze(1) < 0
    )
    is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (
        advantages.unsqueeze(1) > 0
    )
    is_region_clipped = is_low_clipped | is_high_clipped
    low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum()
    high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
    clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()
    gathered_low_clip = _metric_tensor_for_logging(self, low_clip, mode=mode)
    if isinstance(gathered_low_clip, torch.Tensor) and gathered_low_clip.numel() > 0:
        self._append_metric_value(
            mode, "clip_ratio/low_mean", gathered_low_clip.nanmean().item()
        )
        self._append_metric_value(
            mode, "clip_ratio/low_min", _nanmin_tensor(gathered_low_clip).item()
        )
    gathered_high_clip = _metric_tensor_for_logging(self, high_clip, mode=mode)
    if isinstance(gathered_high_clip, torch.Tensor) and gathered_high_clip.numel() > 0:
        self._append_metric_value(
            mode, "clip_ratio/high_mean", gathered_high_clip.nanmean().item()
        )
        self._append_metric_value(
            mode, "clip_ratio/high_max", _nanmax_tensor(gathered_high_clip).item()
        )
    gathered_clip_ratio = _metric_tensor_for_logging(self, clip_ratio, mode=mode)
    if isinstance(gathered_clip_ratio, torch.Tensor) and gathered_clip_ratio.numel() > 0:
        self._append_metric_value(
            mode, "clip_ratio/region_mean", gathered_clip_ratio.nanmean().item()
        )
    # Mark this step as neither the entropy nor the listwise objective variant.
    self._append_metric_value(mode, "maxent/objective_variant_entropy", 0.0)
    self._append_metric_value(mode, "maxent/objective_variant_listwise", 0.0)
    return loss
def compute_loss(  # type: ignore[override]
    self,
    model: Any,
    inputs: Any,
    return_outputs: bool = False,
    num_items_in_batch: Any = None,
) -> Any:
    """Dispatch a batch to the loss implementation selected by the routing.

    Routing order:
      1. eval with ``args.eval_greedy_only_enabled`` on prepared inputs ->
         ``_compute_stable_grpo_loss``;
      2. listwise routing on prepared inputs -> ``_compute_listwise_maxent_loss``;
      3. entropy-regularized routing on prepared inputs -> ``_compute_maxent_loss``;
      4. prepared inputs without ``return_outputs`` -> ``_compute_stable_grpo_loss``;
      5. anything else -> ``_compute_grpo_native_loss``.

    "Prepared inputs" means ``inputs`` is a dict holding the prompt and
    completion ids/masks and ``advantages``. In training mode the latest
    ``kl`` train metric is cached on ``_last_train_kl_for_alpha`` afterwards.

    Raises:
        ValueError: If ``return_outputs`` is requested on routes 1-3.
    """
    required_keys = (
        "prompt_ids",
        "prompt_mask",
        "completion_ids",
        "completion_mask",
        "advantages",
    )
    trl_prepared_inputs = isinstance(inputs, dict) and all(
        name in inputs for name in required_keys
    )
    lightweight_eval = bool(
        (not getattr(self.model, "training", False))
        and trl_prepared_inputs
        and getattr(
            getattr(self, "args", None),
            "eval_greedy_only_enabled",
            False,
        )
    )

    def _reject_returning_outputs(message: str) -> None:
        # The custom loss paths only produce a loss, never model outputs.
        if return_outputs:
            raise ValueError(message)

    took_native_route = False
    if lightweight_eval:
        _reject_returning_outputs(
            "The lightweight greedy eval path does not support returning outputs"
        )
        loss = self._compute_stable_grpo_loss(model=model, inputs=inputs)
    elif self.objective_routing.uses_listwise_loss and trl_prepared_inputs:
        _reject_returning_outputs(
            "The custom listwise MaxEnt GRPOTrainer does not support returning outputs"
        )
        loss = self._compute_listwise_maxent_loss(model=model, inputs=inputs)
    elif (
        self.objective_routing.uses_entropy_regularized_loss
        and trl_prepared_inputs
    ):
        _reject_returning_outputs(
            "The custom MaxEnt GRPOTrainer does not support returning outputs"
        )
        loss = self._compute_maxent_loss(model=model, inputs=inputs)
    elif trl_prepared_inputs and not return_outputs:
        loss = self._compute_stable_grpo_loss(model=model, inputs=inputs)
    else:
        took_native_route = True
        loss = self._compute_grpo_native_loss(
            model=model,
            inputs=inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )
    # Cache the latest train KL immediately after native TRL loss
    # computation so adaptive MaxEnt alpha can be evaluated every
    # optimizer step (the next rollout consumes this cached value).
    if bool(getattr(self.model, "training", False)):
        train_metrics = self._metrics.get("train", {})
        kl_history = None
        if isinstance(train_metrics, dict):
            kl_history = train_metrics.get("kl")
        kl_value = None
        if isinstance(kl_history, list) and kl_history:
            kl_value = _numeric_or_none(kl_history[-1])
            if kl_value is not None:
                setattr(self, "_last_train_kl_for_alpha", float(kl_value))
        if (
            took_native_route
            and getattr(self, "_maxent_controller_objective", None) is not None
        ):
            # The native route may return ``(loss, outputs)``.
            scalar_loss = loss[0] if isinstance(loss, tuple) else loss
            self._maybe_apply_controller_meta(
                mode="train",
                kl_value=kl_value,
                total_loss=_numeric_or_none(scalar_loss),
            )
            self._sync_weighting_scalars()
            self._append_metric_value("train", "tau", float(self.tau))
            self._append_metric_value("train", "beta", float(self.beta))
            self._append_metric_value(
                "train",
                "weight_norm_denom",
                float(getattr(self, "weight_norm_denom", 1.0)),
                include_legacy_aliases=False,
            )
    if self.maxent_enabled:
        self._sync_weighting_scalars()
    self._maybe_update_reference_model_ema()
    return loss
def _generate_and_score_completions(  # type: ignore[override]
    self, inputs: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Generate rollouts via the parent class, then post-process and log them.

    In eval mode with ``args.eval_greedy_only_enabled`` set, the parent
    generation is skipped: outputs come from ``_generate_greedy_eval_outputs``
    and only the greedy eval metrics are logged. Otherwise the parent outputs
    are sanitized, optionally truncated, and passed through the seed-GRPO,
    listwise and logging hooks before being returned.
    """
    mode = "train" if self.model.training else "eval"
    trainer_args = getattr(self, "args", None)
    if mode == "eval":
        if bool(getattr(trainer_args, "eval_greedy_only_enabled", False)):
            greedy_outputs = self._generate_greedy_eval_outputs(inputs)
            self._log_eval_greedy_metrics(inputs, greedy_outputs, mode=mode)
            return greedy_outputs

    rollout = super()._generate_and_score_completions(inputs)
    self._sanitize_rollout_token_ids(rollout, mode=mode)
    self._maybe_truncate_completions_at_first_boxed_answer(
        inputs,
        rollout,
        mode=mode,
    )

    # With seed-GRPO enabled during training, advantage scaling is left to
    # the loss-side hook instead of being applied to the rollout here.
    seed_scaling_deferred = mode == "train" and bool(
        getattr(trainer_args, "seed_grpo_enabled", False)
    )
    if not seed_scaling_deferred:
        self._maybe_backfill_old_per_token_logps(rollout, mode=mode)
        self._maybe_apply_seed_grpo_advantages(
            inputs,
            rollout,
            mode=mode,
        )

    if self.objective_routing.uses_listwise_loss:
        self._prepare_listwise_rollout_targets(inputs, rollout)

    self._recompute_completion_metrics(rollout, mode=mode)
    self._maybe_log_rich_rollout_sidecar(inputs, rollout, mode=mode)
    self._log_grpo_diversity(rollout, mode=mode)
    self._log_eval_pass_at_k(inputs, rollout, mode=mode)
    self._log_eval_greedy_metrics(inputs, rollout, mode=mode)
    if not self.maxent_enabled:
        self._log_grpo_debug(inputs, rollout, mode=mode)
    return rollout
CustomGRPOTrainer.__name__ = "CustomGRPOTrainer"
return ensure_weighting_logging(CustomGRPOTrainer)
# [docs]
def wrap_trl_trainer(trainer_cls: Type[Any]) -> Type[Any]:
    """Return ``trainer_cls`` after passing it through ``ensure_weighting_logging``.

    The wrapper exists so the trainer class emits TRL-style logs and metrics.
    """
    wrapped_cls = ensure_weighting_logging(trainer_cls)
    return wrapped_cls
__all__ = ["build_custom_grpo_trainer", "wrap_trl_trainer"]