# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for preparing generation/scoring artifacts used by the training loop.
The training loop expects a consistent set of artifacts for every batch:
``PreparedBatch``
Bundles grouped completions, reward statistics, reference log-probability
tensors, weighting diagnostics, and derived scores.
``_collect_batch_stats``
Bridges generation/reward outputs with the scoring stack by building
:class:`~training.scoring.ScoreBatch` objects,
gathering reference log-probs when necessary, and computing weighting and
length summaries.
``prepare_training_batch``
High-level orchestration that runs the generation function, computes
rewards, fetches reference log-probs, scores the policy, and returns a
:class:`PreparedBatch` instance to the optimizer.
The helpers raise the internal :class:`_SkipBatch` exception when any step
fails; :func:`prepare_training_batch` catches it and returns ``None`` so the
caller can skip the problematic batch gracefully.
"""
from __future__ import annotations
# pylint: disable=broad-exception-caught
import logging
import math
import os
import time
import sys
import traceback
from collections.abc import Iterable, Mapping, Sized
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, TypeVar, TYPE_CHECKING, cast
from types import SimpleNamespace
from .rewards import (
_apply_group_scales,
_group_q_distribution,
compute_reward_statistics,
group_advantages,
prepare_generation_batch,
reward_moments,
)
from .scoring import (
build_score_batch,
build_sequence_scores,
gather_reference_logprobs,
reference_stats_from_policy_logprobs,
reference_from_vllm_meta,
score_model_outputs,
summarize_completion_lengths,
token_counts_from_score_batch,
)
from .runtime import require_torch
from .types import (
BatchingSettings,
GenerationBatch,
GenerationFn,
GenerationSettings,
LengthStats,
PreTrainedTokenizer,
PromptCacheEntry,
ReferenceLogprobs,
RewardComputation,
RewardMoments,
ScoreBatch,
SequenceScores,
Tensor,
TrainingLoopContext,
AdvantageStats,
QDistribution,
)
from .weighting import WeightStats, WeightingSettings
from .weighting.logic import compute_weight_stats, build_uniform_weight_stats
if TYPE_CHECKING:
    # Static type checkers see the real ``torch`` module here; at runtime the
    # name ``torch`` is bound later via ``require_torch("training_pipeline")``.
    import torch
# Module-level logger named after this module.
LOG = logging.getLogger(__name__)
def _progress_log_enabled() -> bool:
raw = os.getenv("MAXENT_PROGRESS_LOG")
if raw is None or not str(raw).strip():
return False
return str(raw).strip().lower() not in {"0", "false", "no", "off"}
# Maximum number of occurrences handed out by ``_REF_LOGPROB_TRACE_LIMITER``
# (presumably caps reference log-prob traceback logging; its consumers are not
# visible in this section — confirm).
_REF_LOGPROB_TRACE_LIMIT = 3
# Runtime ``torch`` handle. Most runtime ``torch`` references in this module
# resolve to this binding; a few helpers look up ``sys.modules["torch"]``
# instead. What ``require_torch`` returns when torch is unavailable is defined
# in ``.runtime`` and not visible here.
torch = require_torch("training_pipeline")
class _TraceCounter:
"""Stateful helper to guard noisy tracebacks."""
def __init__(self, limit: int) -> None:
self._limit = limit
self._count = 0
def next_occurrence(self) -> Optional[int]:
"""Return the next occurrence number or None when exhausted."""
if self._count >= self._limit:
return None
self._count += 1
return self._count
def reset(self) -> None:
"""Reset the counter so new traces can be emitted."""
self._count = 0
# Module-level limiter built from ``_REF_LOGPROB_TRACE_LIMIT``; its count is
# shared by every caller in this process and persists until ``reset()``.
_REF_LOGPROB_TRACE_LIMITER = _TraceCounter(_REF_LOGPROB_TRACE_LIMIT)
def _deepspeed_zero_stage(accelerator: Any) -> int:
"""Return DeepSpeed ZeRO stage from Accelerate plugin state when present."""
state = getattr(accelerator, "state", None)
ds_plugin = getattr(state, "deepspeed_plugin", None)
try:
return int(getattr(ds_plugin, "zero_stage", 0) or 0)
except (TypeError, ValueError):
return 0
def _rank_tag(accelerator: Any = None) -> str:
"""Return best-effort rank string for logging."""
rank = getattr(accelerator, "process_index", None)
world = getattr(accelerator, "num_processes", None)
if rank is None:
try:
dist = getattr(torch, "distributed", None)
if dist is not None and dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world = dist.get_world_size()
except Exception:
rank = None
world = None
if rank is None:
return "rank=na"
if world is None:
return f"rank={rank}"
return f"rank={rank}/{world}"
def _mean(values: List[float]) -> float:
"""Return the arithmetic mean for a non-empty list, else 0.0."""
return float(sum(values)) / float(len(values)) if values else 0.0
def _weighted_mean(values: List[float], weights: List[float]) -> float:
"""Return the weighted mean or 0.0 when weights are empty."""
if not values or not weights:
return 0.0
total_weight = float(sum(weights))
if total_weight <= 0.0:
return 0.0
return float(sum(v * w for v, w in zip(values, weights))) / total_weight
def _tokenize_for_diversity(text: str, tokenizer: Any = None) -> List[Any]:
"""Tokenize a completion for diversity metrics.
Prefers the configured tokenizer when available; falls back to whitespace.
"""
if not text:
return []
if tokenizer is not None:
try:
encode = getattr(tokenizer, "encode", None)
if callable(encode):
return list(encode(text, add_special_tokens=False))
if callable(tokenizer):
tokenized = tokenizer(text, add_special_tokens=False)
if isinstance(tokenized, dict) and "input_ids" in tokenized:
return list(tokenized["input_ids"])
if isinstance(tokenized, (list, tuple)):
return list(tokenized)
except Exception:
pass
return [tok for tok in text.strip().split() if tok]
def _merge_gathered_diversity_metrics(
    gathered: Iterable[Any],
) -> List[Dict[str, float]]:
    """Flatten gather output into one list of per-group metric dicts.

    Items may be per-rank lists of dicts or bare dicts (e.g. when the gather
    primitive already flattened per-rank lists); anything else is ignored.
    """
    merged: List[Dict[str, float]] = []
    for item in gathered:
        if isinstance(item, list):
            merged.extend(m for m in item if isinstance(m, dict))
        elif isinstance(item, dict):
            merged.append(item)
    return merged


def _gather_diversity_metrics(
    group_metrics: List[Dict[str, float]],
    accelerator: Any,
) -> List[Dict[str, float]]:
    """Best-effort cross-rank gather of per-group diversity metrics.

    Gathering happens only when ``accelerator`` reports more than one process.
    ``accelerator.gather_object`` is preferred. When it is missing,
    ``torch.distributed.all_gather_object`` is used if a process group with
    world size > 1 is initialized. Failures and empty gathers fall back to the
    local metrics.

    These gathers are collectives, so every rank must call this, even with an
    empty ``group_metrics``. Otherwise ranks that did call it block forever.
    """
    if accelerator is None or not (getattr(accelerator, "num_processes", 1) > 1):
        return group_metrics
    gather_fn = getattr(accelerator, "gather_object", None)
    if callable(gather_fn):
        try:
            gather_fn_typed = cast(Callable[[Any], Any], gather_fn)
            gathered = gather_fn_typed(group_metrics)  # pylint: disable=not-callable
            if isinstance(gathered, list):
                merged = _merge_gathered_diversity_metrics(gathered)
                if merged:
                    return merged
        except Exception:
            # Diversity metrics are diagnostics only; keep local values.
            pass
        return group_metrics
    dist = getattr(torch, "distributed", None)
    if (
        dist is None
        or not callable(getattr(dist, "is_available", None))
        or not callable(getattr(dist, "is_initialized", None))
        or not dist.is_available()
        or not dist.is_initialized()
    ):
        return group_metrics
    try:
        world = int(getattr(dist, "get_world_size")())
    except (TypeError, ValueError, RuntimeError):
        world = 0
    if world <= 1:
        return group_metrics
    try:
        gather_obj = getattr(dist, "all_gather_object", None)
        if callable(gather_obj):
            slots: List[Any] = [None for _ in range(world)]
            gather_obj(slots, group_metrics)
            merged = _merge_gathered_diversity_metrics(slots)
            if merged:
                return merged
    except (RuntimeError, ValueError, TypeError):
        pass
    return group_metrics


def _completion_diversity_metrics(
    grouped_completions: List[List[str]],
    *,
    tokenizer: Any = None,
    accelerator: Any = None,
) -> Dict[str, float]:
    """Return coarse diversity metrics for grouped completions.

    For each non-empty prompt group, this computes:

    * ``distinct_1`` / ``distinct_2``: unique-n-gram ratio over the
      concatenated tokens of all completions in the group.
    * ``jaccard``: mean pairwise Jaccard distance between per-completion
      token sets.

    The plain keys average groups equally (macro). The ``*_micro`` keys weight
    each group by its completion count. When running with multiple processes,
    group metrics are gathered from all ranks first. The gather always runs,
    even when this rank has no completions, so the collective never deadlocks.

    :param grouped_completions: Completions per prompt (``None`` entries skipped).
    :param tokenizer: Optional tokenizer passed to ``_tokenize_for_diversity``.
    :param accelerator: Optional Accelerate-like object used for gathering.
    :returns: Metric dict, or ``{}`` when no group produced metrics anywhere.
    """

    def _distinct_n(tokens: List[Any], n: int) -> float:
        # Fraction of unique n-grams among all n-gram positions.
        if n <= 0 or len(tokens) < n:
            return 0.0
        total = len(tokens) - n + 1
        if total <= 0:
            return 0.0
        ngrams = {tuple(tokens[i : i + n]) for i in range(total)}
        return float(len(ngrams)) / float(total)

    def _jaccard_distance(sets: List[set[Any]]) -> float:
        # Mean of (1 - |a & b| / |a | b|) over unordered pairs; two empty
        # sets count as distance 0.
        if len(sets) < 2:
            return 0.0
        total_dist = 0.0
        pairs = 0
        for i, first in enumerate(sets):
            for second in sets[i + 1 :]:
                union = first | second
                if union:
                    total_dist += 1.0 - (len(first & second) / float(len(union)))
                pairs += 1
        return total_dist / float(pairs) if pairs > 0 else 0.0

    group_metrics: List[Dict[str, float]] = []
    for group in grouped_completions or []:
        if not group:
            continue
        normalized = [comp.strip() for comp in group if comp is not None]
        if not normalized:
            continue
        all_tokens: List[Any] = []
        token_sets: List[set[Any]] = []
        for comp in normalized:
            tokens = _tokenize_for_diversity(comp, tokenizer)
            all_tokens.extend(tokens)
            token_sets.append(set(tokens))
        group_metrics.append(
            {
                "group_size": float(len(normalized)),
                "distinct_1": _distinct_n(all_tokens, 1),
                "distinct_2": _distinct_n(all_tokens, 2),
                "jaccard": _jaccard_distance(token_sets),
            }
        )

    # Gather BEFORE the emptiness check so every rank joins the collective.
    group_metrics = _gather_diversity_metrics(group_metrics, accelerator)
    if not group_metrics:
        return {}

    macro: Dict[str, float] = {}
    micro: Dict[str, float] = {}
    for key in ("distinct_1", "distinct_2", "jaccard"):
        # Pair each value with its own group's weight so values and weights
        # cannot drift out of alignment if a gathered dict lacks a key.
        rows = [m for m in group_metrics if key in m]
        values = [m[key] for m in rows]
        weights = [m.get("group_size", 0.0) for m in rows]
        macro[key] = _mean(values)
        micro[f"{key}_micro"] = _weighted_mean(values, weights)
    macro.update(micro)
    return macro
def _dist_any_flag(accelerator: Any, flag: bool) -> bool:
"""Return True if flag is True on any rank (best-effort, object gather)."""
if getattr(accelerator, "num_processes", 1) <= 1:
return bool(flag)
torch_mod = sys.modules.get("torch")
dist = getattr(torch_mod, "distributed", None) if torch_mod is not None else None
if (
dist is None
or not callable(getattr(dist, "is_available", None))
or not callable(getattr(dist, "is_initialized", None))
or not dist.is_available()
or not dist.is_initialized()
):
return bool(flag)
get_world_size = getattr(dist, "get_world_size", None)
if not callable(get_world_size):
return bool(flag)
try:
world_size = int(cast(Any, get_world_size()))
except (TypeError, ValueError, RuntimeError):
return bool(flag)
gathered = [None for _ in range(max(world_size, 1))]
gather_fn = getattr(dist, "all_gather_object", None)
if not callable(gather_fn):
return bool(flag)
try:
gather_fn(gathered, bool(flag))
return any(bool(x) for x in gathered)
except (RuntimeError, ValueError, TypeError):
return bool(flag)
def _resolve_weighting_value(
ctx: TrainingLoopContext,
attribute: str,
default: Optional[float] = None,
) -> Optional[float]:
"""Return a weighting attribute with graceful fallbacks.
Some lightweight test contexts construct ``ctx.scoring.weighting`` as a
simple namespace that omits optional controller fields (e.g., ``beta`` and
``tau``). When those values are absent we try ``ctx.settings.scoring`` and
finally fall back to a default so :func:`compute_reward_statistics` always
receives valid arguments.
:param ctx: Training loop context supplying scoring configs.
:param attribute: Weighting attribute name to resolve.
:param default: Value returned when the attribute is missing everywhere.
:returns: Attribute value or ``default`` when undefined.
"""
scoring_cfg = getattr(ctx, "scoring", None)
weighting = getattr(scoring_cfg, "weighting", None)
if weighting is not None:
value = getattr(weighting, attribute, None)
if value is not None:
return value
settings = getattr(ctx, "settings", None)
if settings is not None:
settings_scoring = getattr(settings, "scoring", None)
settings_weighting = getattr(settings_scoring, "weighting", None)
if settings_weighting is not None:
value = getattr(settings_weighting, attribute, None)
if value is not None:
return value
return default
def _maybe_apply_entropy_bonus(
    ctx: TrainingLoopContext,
    gen_batch: GenerationBatch,
    reward_comp: RewardComputation,
    ref_stats: ReferenceLogprobs,
    policy_entropy_sum: Optional[Any],
) -> RewardComputation:
    """Optionally add a policy-entropy bonus to rewards and refresh stats.

    The bonus is active when ``ctx.scoring.policy_entropy_bonus_coef`` is a
    finite, non-zero number and ``policy_entropy_sum`` is provided. Steps:

    1. Divide each completion's entropy sum by its reference token count
       (``ref_stats.ref_tok_counts``, clamped to at least 1).
    2. Z-score the per-token entropies within each prompt group
       (``gen_batch.grouped_completions``).
    3. Scale by ``coef * reward_std``. ``reward_std`` is the pre-bonus reward
       std, or ``1.0`` when that std is missing, non-finite or non-positive.
    4. Add the result to ``total_utils``.

    Moments, advantages and the q-distribution are then recomputed from the
    new utilities. Diagnostics are stored in ``per_reward_values`` under
    ``policy_entropy_group_zscore``, ``policy_entropy_per_token`` and
    ``entropy_bonus``.

    :param ctx: Training loop context (scoring config, runtime device,
        training args).
    :param gen_batch: Generation outputs; group sizes come from
        ``grouped_completions``.
    :param reward_comp: Reward statistics to update. Mutated in place when its
        attributes are assignable; otherwise a new ``RewardComputation`` is
        built.
    :param ref_stats: Reference statistics supplying ``ref_tok_counts``.
    :param policy_entropy_sum: Entropy sums (tensor or array-like), assumed
        aligned index-by-index with ``total_utils``. ``None`` disables the
        bonus.
    :returns: The updated reward computation, or ``reward_comp`` unchanged
        when the bonus is disabled or inputs are unusable.
    """

    def _flatten_float(values: Any) -> Any:
        # ``view(-1)`` avoids a copy but raises RuntimeError for
        # non-contiguous layouts (e.g. strided slices). ``reshape(-1)``
        # copies in that case instead of aborting batch preparation.
        try:
            flat = values.view(-1)
        except RuntimeError:
            reshape = getattr(values, "reshape", None)
            if not callable(reshape):
                raise
            flat = reshape(-1)
        return flat.float()

    scoring_cfg = getattr(ctx, "scoring", None)
    bonus_coef = getattr(scoring_cfg, "policy_entropy_bonus_coef", 0.0)
    try:
        bonus_coef = float(bonus_coef)
    except (TypeError, ValueError):
        LOG.warning(
            "Invalid policy_entropy_bonus_coef=%s; skipping entropy bonus.", bonus_coef
        )
        return reward_comp
    if bonus_coef == 0.0 or not math.isfinite(bonus_coef):
        return reward_comp
    if policy_entropy_sum is None:
        return reward_comp
    total_utils = list(getattr(reward_comp, "total_utils", []) or [])
    if not total_utils:
        return reward_comp
    device = getattr(policy_entropy_sum, "device", None)
    dtype = getattr(policy_entropy_sum, "dtype", None)
    try:
        if not isinstance(policy_entropy_sum, torch.Tensor):
            # ``arr`` is honoured for array-wrapper inputs that expose their
            # raw data under that attribute.
            entropy_tensor = torch.tensor(
                getattr(policy_entropy_sum, "arr", policy_entropy_sum),
                device=device,
                dtype=dtype or getattr(torch, "float32", None),
            )
        else:
            entropy_tensor = policy_entropy_sum
    except (TypeError, ValueError, RuntimeError):
        return reward_comp
    tok_counts = getattr(ref_stats, "ref_tok_counts", None)
    try:
        if not isinstance(tok_counts, torch.Tensor):
            tok_tensor = torch.tensor(
                getattr(tok_counts, "arr", tok_counts),
                device=getattr(entropy_tensor, "device", None),
                dtype=getattr(entropy_tensor, "dtype", None)
                or getattr(torch, "float32", None),
            )
        else:
            tok_tensor = tok_counts
    except (TypeError, ValueError, RuntimeError):
        return reward_comp
    if getattr(entropy_tensor, "numel", lambda: 0)() == 0:
        return reward_comp
    if getattr(tok_tensor, "numel", lambda: 0)() == 0:
        return reward_comp
    entropy_tensor = _flatten_float(entropy_tensor)
    tok_tensor = _flatten_float(tok_tensor)
    target_len = len(total_utils)
    try:
        ent_len = int(entropy_tensor.numel())
    except (TypeError, ValueError, RuntimeError, AttributeError):
        ent_len = len(getattr(entropy_tensor, "data", []))
    try:
        tok_len = int(tok_tensor.numel())
    except (TypeError, ValueError, RuntimeError, AttributeError):
        tok_len = len(getattr(tok_tensor, "data", []))
    # Only the common prefix of rewards/entropies/token counts gets a bonus;
    # any trailing rewards receive 0.0 (see padding below).
    n = min(target_len, ent_len, tok_len)
    if n <= 0:
        return reward_comp
    if ent_len != target_len or tok_len != target_len:
        LOG.debug(
            "Entropy bonus length mismatch | rewards=%d entropy=%d tok_counts=%d; aligning to %d",
            target_len,
            ent_len,
            tok_len,
            n,
        )
    device = getattr(entropy_tensor, "device", None)
    if device is not None and hasattr(tok_tensor, "to"):
        try:
            tok_tensor = tok_tensor.to(device)
        except (TypeError, ValueError, RuntimeError):
            pass
    entropy_slice = entropy_tensor[:n]
    tok_slice = tok_tensor[:n].clamp(min=1.0)
    entropy_per_tok_raw = entropy_slice / tok_slice
    entropy_per_tok = entropy_per_tok_raw
    group_sizes = [
        len(group) for group in getattr(gen_batch, "grouped_completions", []) or []
    ]
    if group_sizes:
        zscored = entropy_per_tok.clone()
        offset = 0
        for group_size in group_sizes:
            if offset >= n:
                break
            take = min(group_size, n - offset)
            if take <= 0:
                offset += max(group_size, 0)
                continue
            slice_vals = entropy_per_tok[offset : offset + take]
            mean_val = slice_vals.mean()
            try:
                std_val = slice_vals.std(unbiased=False)
            except (TypeError, RuntimeError, ValueError, AttributeError):
                std_val = slice_vals.std()
            # Clamp avoids division by zero for constant (or size-1) groups.
            std_val = std_val.clamp(min=1e-6)
            zscored[offset : offset + take] = (slice_vals - mean_val) / std_val
            offset += group_size
        if offset < n:
            LOG.debug(
                "Entropy bonus group sizes shorter than rewards | groups_total=%d rewards=%d",
                sum(group_sizes),
                n,
            )
            # NOTE(review): uncovered entries keep the raw per-token entropy,
            # which is on a different scale than the z-scores — confirm intended.
            zscored[offset:n] = entropy_per_tok[offset:n]
        entropy_per_tok = zscored
    try:
        entropy_per_tok = torch.nan_to_num(
            entropy_per_tok, nan=0.0, posinf=0.0, neginf=0.0
        )
    except (AttributeError, TypeError, RuntimeError):
        isfinite = getattr(torch, "isfinite", None)
        if callable(isfinite):
            entropy_per_tok = torch.where(
                isfinite(entropy_per_tok),
                entropy_per_tok,
                torch.zeros_like(entropy_per_tok),
            )
    reward_std = None
    moments = getattr(reward_comp, "moments", None)
    if moments is not None:
        reward_std = getattr(moments, "std", None)
    try:
        reward_std = float(reward_std)
    except (TypeError, ValueError):
        reward_std = None
    if reward_std is None or not math.isfinite(reward_std) or reward_std <= 0.0:
        reward_std = 1.0
    zscale = bonus_coef * reward_std
    bonus_tensor = entropy_per_tok * zscale
    try:
        bonus_vals = bonus_tensor.detach().float().cpu().tolist()
    except (AttributeError, RuntimeError, TypeError, ValueError):
        bonus_vals = [float(x) for x in getattr(bonus_tensor, "arr", bonus_tensor)]
    if len(bonus_vals) < target_len:
        bonus_vals.extend([0.0] * (target_len - len(bonus_vals)))
    # Recover the (sanitised) z-scores from the scaled bonus for logging.
    entropy_vals = [b / zscale if zscale != 0.0 else 0.0 for b in bonus_vals]
    try:
        entropy_raw_vals = entropy_per_tok_raw.detach().float().cpu().tolist()
    except (AttributeError, RuntimeError, TypeError, ValueError):
        entropy_raw_vals = [
            float(x) for x in getattr(entropy_per_tok_raw, "arr", entropy_per_tok_raw)
        ]
    if len(entropy_raw_vals) < target_len:
        entropy_raw_vals.extend([0.0] * (target_len - len(entropy_raw_vals)))
    new_total_utils = [float(u) + b for u, b in zip(total_utils, bonus_vals)]
    per_reward_values = dict(getattr(reward_comp, "per_reward_values", {}) or {})
    per_reward_values["policy_entropy_group_zscore"] = entropy_vals
    per_reward_values["policy_entropy_per_token"] = entropy_raw_vals
    per_reward_values["entropy_bonus"] = bonus_vals
    moments = RewardMoments(*reward_moments(new_total_utils, ctx.runtime.device))
    training_args = getattr(ctx, "training_args", None)
    scale_rewards = True
    if training_args is not None:
        scale_rewards = bool(getattr(training_args, "scale_rewards", True))
    advantage_grouped, advantage_samples = group_advantages(
        gen_batch.grouped_completions,
        new_total_utils,
        scale_rewards=scale_rewards,
    )
    advantage_grouped, advantage_samples = _apply_group_scales(
        advantage_grouped,
        getattr(reward_comp, "seed_advantage_scales", None),
    )
    advantage_stats = AdvantageStats(advantage_grouped, advantage_samples)
    q_temperature = _resolve_weighting_value(ctx, "q_temperature", 1.0) or 1.0
    q_epsilon = _resolve_weighting_value(ctx, "q_epsilon", 1e-6) or 1e-6
    q_distribution = QDistribution(
        *_group_q_distribution(
            gen_batch.grouped_completions,
            new_total_utils,
            q_temperature,
            q_epsilon,
        )
    )
    try:
        reward_comp.total_utils = new_total_utils
        reward_comp.per_reward_values = per_reward_values
        reward_comp.advantage = advantage_stats
        reward_comp.moments = moments
        reward_comp.q_distribution = q_distribution
        reward_comp.entropy_bonus_scale = reward_std
    except (AttributeError, TypeError, ValueError):
        # Immutable/frozen reward containers: rebuild instead of mutating.
        reward_comp = RewardComputation(
            total_utils=new_total_utils,
            per_reward_values=per_reward_values,
            advantage=advantage_stats,
            pairs=reward_comp.pairs,
            q_distribution=q_distribution,
            moments=moments,
            ref_logprob_meta=getattr(reward_comp, "ref_logprob_meta", None),
            completion_metadata=getattr(reward_comp, "completion_metadata", None),
            entropy_bonus_scale=reward_std,
            seed_semantic_entropies=getattr(
                reward_comp, "seed_semantic_entropies", None
            ),
            seed_advantage_scales=getattr(
                reward_comp, "seed_advantage_scales", None
            ),
            seed_alpha_effective=getattr(
                reward_comp, "seed_alpha_effective", None
            ),
            seed_max_possible_entropy=getattr(
                reward_comp, "seed_max_possible_entropy", None
            ),
        )
    return reward_comp
@dataclass
class _BatchStats:
    """Aggregated batch statistics before building losses.

    Returned by :func:`_collect_batch_stats` and stored on
    :attr:`PreparedBatch.batch_stats`. Field order is positional API; do not
    reorder.
    """
    # Scoring inputs built for the batch.
    score_batch: ScoreBatch
    # Reference-model log-probability statistics.
    ref_stats: ReferenceLogprobs
    # Weighting diagnostics for the batch.
    weight_stats: WeightStats
    # Completion-length summaries.
    length_stats: LengthStats
    # Total completion token count (surfaced as PreparedBatch.num_completion_tokens).
    num_completion_tokens: float
    # Prompt token count; presumably feeds throughput logging — confirm with caller.
    prompt_token_count: float
[docs]
@dataclass
class PreparedBatch:
    """Everything the optimizer needs to run a step on one training batch.

    :param grouped_completions: Completions per prompt.
    :type grouped_completions: list[list[str]]
    :param reward_comp: Reward statistics from
        :func:`training.rewards.compute_reward_statistics`.
    :type reward_comp: ~maxent_grpo.training.types.rewards.RewardComputation
    :param batch_stats: Scoring/weighting artifacts built by
        :func:`_collect_batch_stats`.
    :type batch_stats: _BatchStats
    :param total_input_tokens: Prompt + completion token count, used for
        throughput logging.
    :type total_input_tokens: float
    :param scores: Current-model log-probabilities aligned with the reference
        statistics.
    :type scores: ~maxent_grpo.training.types.rewards.SequenceScores
    :param diversity_metrics: Optional completion-diversity metrics
        (presumably from :func:`_completion_diversity_metrics`); ``None``
        when not computed.
    :type diversity_metrics: dict[str, float] | None
    """

    grouped_completions: List[List[str]]
    reward_comp: RewardComputation
    batch_stats: _BatchStats
    total_input_tokens: float
    scores: SequenceScores
    diversity_metrics: Optional[Dict[str, float]] = None

    @property
    def weight_stats(self) -> WeightStats:
        """Weighting statistics carried by :attr:`batch_stats`."""
        stats = self.batch_stats
        return stats.weight_stats

    @property
    def ref_stats(self) -> ReferenceLogprobs:
        """Reference log-probability statistics carried by :attr:`batch_stats`."""
        stats = self.batch_stats
        return stats.ref_stats

    @property
    def length_stats(self) -> LengthStats:
        """Completion-length summaries carried by :attr:`batch_stats`."""
        stats = self.batch_stats
        return stats.length_stats

    @property
    def num_completion_tokens(self) -> float:
        """Completion token total carried by :attr:`batch_stats`."""
        stats = self.batch_stats
        return stats.num_completion_tokens
class _SkipBatch(RuntimeError):
"""Internal control-flow exception to skip invalid batches."""
def __init__(self, stage: str) -> None:
super().__init__(stage)
self.stage = stage or "unknown"
_T = TypeVar("_T")
def _require_artifact(value: Optional[_T], stage: str) -> _T:
"""Return ``value`` or raise the internal ``_SkipBatch`` sentinel.
:param value: Artifact produced by a preparation step.
:type value: Any | None
:raises _SkipBatch: When ``value`` is ``None`` indicating the step failed.
:returns: The validated artifact.
:rtype: Any
"""
if value is None:
raise _SkipBatch(stage)
return value
def _reference_stats_from_meta(
flat_meta: Optional[List[Optional[Any]]],
total_sequences: int,
device: "torch.device",
) -> Optional[ReferenceLogprobs]:
"""Return reference stats when metadata fully covers all sequences.
:param flat_meta: Flattened list of reference metadata per sequence.
:type flat_meta: list | None
:param total_sequences: Number of sequences expected in the batch.
:type total_sequences: int
:param device: Target device used for the resulting tensors.
:type device: torch.device
:returns: Reference log-probability statistics or ``None`` if metadata is
missing/partial.
:rtype: ~maxent_grpo.training.types.rewards.ReferenceLogprobs | None
"""
if not flat_meta:
return None
if total_sequences <= 0:
total_sequences = len(flat_meta)
if total_sequences <= 0:
return None
ref_fn = reference_from_vllm_meta
try:
return ref_fn(flat_meta, total_sequences, device)
except (RuntimeError, TypeError, ValueError):
return None
def _behavior_logp_tensor_from_meta(
flat_meta: Optional[List[Optional[Any]]],
total_sequences: int,
template_tensor: Any,
) -> Optional[Tensor]:
"""Return a tensor of behavior log-prob sums derived from metadata.
The metadata is expected to contain ``logprob_sum`` entries aligned with
the flattened completions list. When metadata is missing or incomplete
``None`` is returned so downstream callers can fall back to current-policy
log-probs.
:param flat_meta: Flattened metadata per sequence emitted by generation.
:type flat_meta: list | None
:param total_sequences: Expected number of completions in the batch.
:type total_sequences: int
:param template_tensor: Tensor used to infer device/dtype for the result.
:type template_tensor: torch.Tensor
:returns: Tensor of log-prob sums or ``None`` if unavailable.
:rtype: torch.Tensor | None
"""
if total_sequences <= 0:
return None
fallback_vals: Optional[List[float]] = None
if template_tensor is not None:
try:
if isinstance(template_tensor, torch.Tensor):
fallback_vals = template_tensor.detach().float().cpu().view(-1).tolist()
elif hasattr(template_tensor, "tolist"):
fallback_raw = template_tensor.tolist()
if (
isinstance(fallback_raw, list)
and fallback_raw
and isinstance(fallback_raw[0], list)
):
fallback_raw = fallback_raw[0]
fallback_vals = [float(val) for val in fallback_raw]
else:
fallback_vals = [float(val) for val in list(template_tensor)]
except (TypeError, ValueError, RuntimeError):
fallback_vals = None
meta_len = len(flat_meta) if flat_meta else 0
if (not flat_meta or meta_len < total_sequences) and not fallback_vals:
if meta_len > 0:
LOG.debug(
"Behavior log-prob metadata too short | meta_len=%d | sequences=%d",
meta_len,
total_sequences,
)
return None
logprob_vals: List[float] = []
missing = 0
for idx in range(total_sequences):
entry = flat_meta[idx] if flat_meta and idx < meta_len else None
if entry is None:
if fallback_vals is not None and idx < len(fallback_vals):
logprob_vals.append(float(fallback_vals[idx]))
missing += 1
continue
LOG.debug("Behavior log-prob metadata missing entry at idx=%d", idx)
return None
logprob_sum = getattr(entry, "logprob_sum", None)
if logprob_sum is None and isinstance(entry, dict):
logprob_sum = entry.get("logprob_sum")
if logprob_sum is None:
if fallback_vals is not None and idx < len(fallback_vals):
logprob_vals.append(float(fallback_vals[idx]))
missing += 1
continue
LOG.debug("Behavior log-prob metadata missing logprob_sum at idx=%d", idx)
return None
try:
logprob_vals.append(float(logprob_sum))
except (TypeError, ValueError):
if fallback_vals is not None and idx < len(fallback_vals):
logprob_vals.append(float(fallback_vals[idx]))
missing += 1
continue
LOG.debug(
"Behavior log-prob metadata has non-castable value at idx=%d: %s",
idx,
logprob_sum,
)
return None
if missing:
if not getattr(_behavior_logp_tensor_from_meta, "_warned_partial", False):
LOG.warning(
"Behavior log-prob metadata missing entries | missing_entries=%d/%d | "
"falling back to policy logprobs for missing entries.",
missing,
total_sequences,
)
setattr(_behavior_logp_tensor_from_meta, "_warned_partial", True)
else:
LOG.debug(
"Behavior log-prob metadata missing entries | missing_entries=%d/%d",
missing,
total_sequences,
)
new_tensor = getattr(template_tensor, "new_tensor", None)
tensor_obj: Optional[Tensor] = None
if callable(new_tensor):
try:
tensor_obj = cast(
Tensor,
new_tensor(
logprob_vals,
dtype=getattr(template_tensor, "dtype", None),
device=getattr(template_tensor, "device", None),
),
)
except (TypeError, ValueError, RuntimeError):
tensor_obj = None
if tensor_obj is None:
torch_mod = sys.modules.get("torch")
tensor_fn = (
getattr(torch_mod, "tensor", None) if torch_mod is not None else None
)
if callable(tensor_fn):
try:
tensor_obj = cast(
Tensor,
tensor_fn(
logprob_vals,
dtype=getattr(template_tensor, "dtype", None),
device=getattr(template_tensor, "device", None),
),
)
except (TypeError, ValueError, RuntimeError):
tensor_obj = None
if tensor_obj is None:
LOG.debug("Unable to convert behavior log-prob metadata into a tensor.")
return None
view_attr = getattr(tensor_obj, "view", None)
if callable(view_attr):
try:
tensor_obj = tensor_obj.view(-1)
except (TypeError, ValueError, RuntimeError):
LOG.debug("Failed to reshape logprob tensor to 1D.")
return tensor_obj
def _coerce_token_logprob_value(value: object) -> Optional[float]:
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, Mapping):
if "logprob" in value:
return _coerce_token_logprob_value(value.get("logprob"))
if "log_prob" in value:
return _coerce_token_logprob_value(value.get("log_prob"))
if len(value) == 1:
return _coerce_token_logprob_value(next(iter(value.values())))
return None
attr_val = getattr(value, "logprob", None)
if attr_val is not None:
return _coerce_token_logprob_value(attr_val)
return None
def _extract_token_logprob_seq(entry: Optional[Any]) -> Optional[List[float]]:
if entry is None:
return None
token_logprobs = None
if isinstance(entry, Mapping):
token_logprobs = entry.get("token_logprobs") or entry.get("logprobs")
if isinstance(token_logprobs, Mapping):
token_logprobs = token_logprobs.get("token_logprobs") or token_logprobs.get(
"logprobs"
)
else:
token_logprobs = getattr(entry, "token_logprobs", None) or getattr(
entry, "logprobs", None
)
if token_logprobs is None:
return None
if isinstance(token_logprobs, (str, bytes, bytearray)):
return None
if not isinstance(token_logprobs, Iterable):
return None
cleaned: List[float] = []
for item in token_logprobs:
val = _coerce_token_logprob_value(item)
if val is None:
continue
cleaned.append(val)
return cleaned if cleaned else None
def _token_logp_tensor_from_meta(
flat_meta: Optional[List[Optional[Any]]],
total_sequences: int,
token_mask: Optional[Tensor],
fallback_token_logp: Optional[Tensor],
) -> Optional[Tensor]:
"""Return per-token log-prob tensor derived from vLLM metadata when available."""
if total_sequences <= 0 or not flat_meta:
return None
meta_len = len(flat_meta)
if meta_len <= 0:
return None
target_len: Optional[int] = None
if token_mask is not None:
try:
target_len = int(getattr(token_mask, "shape", [0, 0])[1])
except (TypeError, ValueError, IndexError):
target_len = None
if target_len is None and fallback_token_logp is not None:
try:
target_len = int(getattr(fallback_token_logp, "shape", [0, 0])[1])
except (TypeError, ValueError, IndexError):
target_len = None
sequences: List[List[float]] = []
missing = 0
for idx in range(total_sequences):
entry = flat_meta[idx] if idx < meta_len else None
seq_vals = _extract_token_logprob_seq(entry)
if seq_vals is None:
if fallback_token_logp is not None:
try:
row = fallback_token_logp[idx].detach().float().cpu().tolist()
if isinstance(row, list):
seq_vals = [float(val) for val in row]
else:
seq_vals = None
except (
TypeError,
ValueError,
RuntimeError,
AttributeError,
IndexError,
):
seq_vals = None
if seq_vals is None:
return None
missing += 1
if target_len is None:
target_len = max(len(seq_vals), 0)
sequences.append(seq_vals)
if target_len is None or target_len <= 0:
return None
aligned: List[List[float]] = []
for seq_vals in sequences:
seq_len = len(seq_vals)
if seq_len >= target_len:
aligned.append(seq_vals[-target_len:])
else:
pad = [0.0] * (target_len - seq_len)
aligned.append(pad + seq_vals)
if missing:
if not getattr(_token_logp_tensor_from_meta, "_warned_partial", False):
LOG.warning(
"Token logprob metadata missing entries | missing_entries=%d/%d | "
"falling back to policy token logprobs for missing rows.",
missing,
total_sequences,
)
setattr(_token_logp_tensor_from_meta, "_warned_partial", True)
else:
LOG.debug(
"Token logprob metadata missing entries | missing_entries=%d/%d",
missing,
total_sequences,
)
dtype = None
device = None
if isinstance(fallback_token_logp, torch.Tensor):
dtype = fallback_token_logp.dtype
device = fallback_token_logp.device
elif isinstance(token_mask, torch.Tensor):
device = token_mask.device
if dtype is None:
dtype = getattr(torch, "float32", None)
try:
return torch.tensor(aligned, dtype=dtype, device=device)
except (TypeError, ValueError, RuntimeError):
return None
def _collect_batch_stats(
    ctx: TrainingLoopContext,
    gen_batch: GenerationBatch,
    reward_comp: RewardComputation,
    *,
    score_batch: Optional[ScoreBatch] = None,
    cur_logp_sum: Optional[Any] = None,
    policy_entropy_sum: Optional[Any] = None,
) -> Optional[_BatchStats]:
    """Gather scoring, reference, and weighting artifacts for a batch.

    Reference log-probabilities are resolved in this order (GRPO mode always
    forces a reference-model pass and ignores vLLM metadata):

    1. Policy log-probs, when ``reference_logprobs_source`` is ``policy`` or
       ``none`` and ``cur_logp_sum`` is available.
    2. vLLM ``ref_logprob_meta`` attached to ``reward_comp``.
    3. Policy log-probs as a fallback when metadata did not yield stats.
    4. A reference-model pass via :func:`gather_reference_logprobs` (run on
       every rank under DeepSpeed ZeRO >= 2 when any rank needs it).

    Empty results trigger a reduced slice/chunk retry, a forced gather with
    slice/chunk size 1 when nothing is cached yet, and finally (only when
    ``allow_stale_reference_logprobs`` is set and GRPO mode is off) reuse of
    the stats cached on ``ctx._last_ref_stats``.

    :param ctx: Training loop context supplying runtime/scoring handles.
    :type ctx: training.types.TrainingLoopContext
    :param gen_batch: Outputs from :func:`prepare_generation_batch`.
    :type gen_batch: training.types.GenerationBatch
    :param reward_comp: Reward statistics used to build weighting/reward logs.
    :type reward_comp: ~maxent_grpo.training.types.rewards.RewardComputation
    :param score_batch: Prebuilt score batch; built from ``reward_comp`` when
        ``None``.
    :param cur_logp_sum: Policy sequence log-prob sums, used as the reference
        when configured or as a fallback.
    :param policy_entropy_sum: Optional policy entropy sums forwarded to
        ``_maybe_apply_entropy_bonus``.
    :returns: Aggregated structures required downstream, or ``None`` when any
        stage fails (e.g., reference log-prob gathering).
    :rtype: _BatchStats | None
    """
    ref_stats = None
    last_ref_stats = getattr(ctx, "_last_ref_stats", None)

    def _safe_numel(tensor: Any) -> Any:
        # Return ``tensor.numel()``, the string ``"error"`` when that call
        # raises, or ``None`` when the object has no callable ``numel``.
        numel_fn = getattr(tensor, "numel", None)
        if callable(numel_fn):
            try:
                return numel_fn()
            except (RuntimeError, ValueError, TypeError):
                return "error"
        return None

    def _describe_ref(obj: Any) -> str:
        # Compact shape/numel summary of a reference-stats object for logs.
        if obj is None:
            return "None"
        return (
            f"{type(obj).__name__}(shape_logp_sum_raw="
            f"{getattr(getattr(obj, 'ref_logp_sum_raw', None), 'shape', None)}, "
            f"shape_logp_sum={getattr(getattr(obj, 'ref_logp_sum', None), 'shape', None)}, "
            f"shape_tok_counts={getattr(getattr(obj, 'ref_tok_counts', None), 'shape', None)}, "
            f"numel_logp_sum_raw={_safe_numel(getattr(obj, 'ref_logp_sum_raw', None))}, "
            f"numel_logp_sum={_safe_numel(getattr(obj, 'ref_logp_sum', None))}, "
            f"numel_tok_counts={_safe_numel(getattr(obj, 'ref_tok_counts', None))})"
        )

    def _ref_stats_empty(candidate: Optional[ReferenceLogprobs]) -> bool:
        # Return True when ``candidate`` carries no usable reference log-probs.
        # Diagnostics are logged at most once per process via function attrs.
        if candidate is not None and not getattr(
            _collect_batch_stats, "_ref_candidate_seen", False
        ):
            LOG.debug(
                "Reference stats candidate present | type=%s | ref_logp_sum_raw_shape=%s | ref_logp_sum_shape=%s | ref_tok_counts_shape=%s",
                type(candidate).__name__,
                getattr(getattr(candidate, "ref_logp_sum_raw", None), "shape", None),
                getattr(getattr(candidate, "ref_logp_sum", None), "shape", None),
                getattr(getattr(candidate, "ref_tok_counts", None), "shape", None),
            )
            setattr(_collect_batch_stats, "_ref_candidate_seen", True)
        if candidate is None:
            if not getattr(_collect_batch_stats, "_ref_diag_logged", False):
                LOG.warning(
                    "Reference stats deemed empty: candidate=None | ref_logp_sum_raw=None | ref_logp_sum=None | ref_tok_counts=None"
                )
                setattr(_collect_batch_stats, "_ref_diag_logged", True)
            return True
        tensor = getattr(candidate, "ref_logp_sum_raw", None)
        if tensor is None:
            tensor = getattr(candidate, "ref_logp_sum", None)
        if tensor is None:
            try:
                length_fn = getattr(candidate, "__len__", None)
                if callable(length_fn):
                    return length_fn() == 0
            except TypeError:
                return False
        numel = getattr(tensor, "numel", None)
        if callable(numel):
            try:
                return numel() == 0
            except (RuntimeError, TypeError, ValueError):
                LOG.debug("Unable to compute numel for reference stats tensor.")
        to_list = getattr(tensor, "tolist", None)
        data = tensor
        if callable(to_list):
            try:
                data = to_list()
            except (RuntimeError, TypeError, ValueError):
                data = tensor
        if not isinstance(data, Sized):
            return False
        if len(data) == 0:
            # Zero-length logp sums are tolerated when token counts exist.
            # Check this before warning so we never log "deemed empty" for a
            # candidate that is then accepted.
            tok_counts = getattr(candidate, "ref_tok_counts", None)
            if tok_counts is not None:
                tok_numel = _safe_numel(tok_counts)
                logp_sum_raw_numel = _safe_numel(
                    getattr(candidate, "ref_logp_sum_raw", None)
                )
                logp_sum_numel = _safe_numel(getattr(candidate, "ref_logp_sum", None))
                # ``_safe_numel`` may return None or "error"; comparing the
                # string with 0 would raise TypeError, so require a number.
                if (
                    isinstance(tok_numel, (int, float))
                    and tok_numel > 0
                    and (logp_sum_raw_numel == 0 or logp_sum_numel == 0)
                ):
                    LOG.debug(
                        "Reference stats: allowing zero-length logp_sum because tok_counts exist | tok_numel=%s | logp_sum_raw_numel=%s | logp_sum_numel=%s",
                        tok_numel,
                        logp_sum_raw_numel,
                        logp_sum_numel,
                    )
                    return False
            if not getattr(_collect_batch_stats, "_ref_diag_logged", False):

                def _describe(obj: Any) -> str:
                    if obj is None:
                        return "None"
                    shape = getattr(obj, "shape", None)
                    numel_fn = getattr(obj, "numel", None)
                    numel_val = None
                    if callable(numel_fn):
                        try:
                            numel_val = numel_fn()
                        except (RuntimeError, TypeError, ValueError):
                            numel_val = "error"
                    return f"{type(obj).__name__}(shape={shape}, numel={numel_val})"

                LOG.warning(
                    "Reference stats deemed empty: candidate=%s | ref_logp_sum_raw=%s | ref_logp_sum=%s | ref_tok_counts=%s",
                    type(candidate).__name__,
                    _describe(getattr(candidate, "ref_logp_sum_raw", None)),
                    _describe(getattr(candidate, "ref_logp_sum", None)),
                    _describe(getattr(candidate, "ref_tok_counts", None)),
                )
                setattr(_collect_batch_stats, "_ref_diag_logged", True)
            return True
        if not getattr(_collect_batch_stats, "_ref_diag_logged_success", False):
            LOG.debug(
                "Reference stats non-empty | ref_logp_sum_raw_shape=%s | ref_tok_counts_shape=%s",
                getattr(getattr(candidate, "ref_logp_sum_raw", None), "shape", None),
                getattr(getattr(candidate, "ref_tok_counts", None), "shape", None),
            )
            setattr(_collect_batch_stats, "_ref_diag_logged_success", True)
        return False

    def _warn_fallback(reason: str) -> None:
        # Log (once per process) that cached reference stats are being reused.
        if getattr(_collect_batch_stats, "_fallback_warned", False):
            return
        LOG.error(
            "Reference scoring degraded (%s); configured to reuse last cached ReferenceLogprobs. "
            "Set maxent_allow_stale_reference_logprobs=false to skip batches instead.",
            reason,
        )
        setattr(_collect_batch_stats, "_fallback_warned", True)

    def _retry_reference_gather(
        score_batch_retry: ScoreBatch, batching_cfg_retry: BatchingSettings
    ) -> Optional[ReferenceLogprobs]:
        # Retry the reference gather once with roughly halved slice/chunk
        # sizes. Returns None when the slice cannot shrink further or the
        # gather raises. ``trl_reference_scoring`` and ``ref_temperature`` are
        # resolved later in the enclosing function, before this is called.
        original_slice = getattr(score_batch_retry, "slice_size", None)
        original_chunk = getattr(batching_cfg_retry, "logprob_chunk_size", None)
        slice_val = int(original_slice or score_batch_retry.total_sequences or 1)
        reduced_slice = max(1, slice_val // 2)
        if reduced_slice == original_slice:
            if reduced_slice > 1:
                reduced_slice -= 1
            else:
                reduced_slice = 1
        if original_slice is not None and reduced_slice == original_slice:
            return None
        logprob_chunk = int(original_chunk or reduced_slice)
        reduced_chunk = max(1, logprob_chunk // 2) if logprob_chunk > 1 else 1
        try:
            score_batch_retry.slice_size = reduced_slice
            batching_cfg_retry.logprob_chunk_size = reduced_chunk
            try:
                result = gather_reference_logprobs(
                    score_batch_retry,
                    ctx.runtime,
                    batching_cfg_retry,
                    trl_reference_scoring=trl_reference_scoring,
                    temperature=ref_temperature,
                )
                LOG.debug(
                    "Retry reference gather result | slice_size=%s | chunk_size=%s | result=%s",
                    reduced_slice,
                    reduced_chunk,
                    _describe_ref(result),
                )
                return result
            except (
                RuntimeError,
                ValueError,
                TypeError,
                AssertionError,
            ) as exc:  # pragma: no cover - best-effort logging
                LOG.warning("Retry reference gather failed: %s", exc)
                return None
        finally:
            # Always restore the caller's sizes (including an original None);
            # otherwise the reduced values would leak into later scoring.
            score_batch_retry.slice_size = original_slice
            batching_cfg_retry.logprob_chunk_size = original_chunk

    # Resolve scoring/batching/generation configs, preferring ctx attributes and
    # falling back to ctx.settings.
    scoring_cfg = getattr(ctx, "scoring", None)
    if scoring_cfg is None:
        scoring_cfg = getattr(
            getattr(ctx, "settings", SimpleNamespace()), "scoring", None
        )
    if scoring_cfg is None:
        scoring_cfg = SimpleNamespace()
    batching_cfg = getattr(scoring_cfg, "batching", SimpleNamespace())
    if not getattr(batching_cfg, "prompt_length_cache_get", None):
        # Install a prompt-length lookup: the runtime cache when available,
        # otherwise a stub returning an empty PromptCacheEntry.
        runtime_cache = getattr(getattr(ctx, "runtime", None), "prompt_cache_get", None)
        if callable(runtime_cache):
            batching_cfg.prompt_length_cache_get = runtime_cache
        else:
            batching_cfg.prompt_length_cache_get = (
                lambda _p, _cls=PromptCacheEntry: _cls(input_ids=[], attention_mask=[])
            )
    batching_cfg = cast(BatchingSettings, batching_cfg)
    gen_cfg = getattr(ctx, "generation", None)
    if gen_cfg is None:
        gen_cfg = getattr(
            getattr(ctx, "settings", SimpleNamespace()), "generation", None
        )
    gen_cfg = cast(GenerationSettings, gen_cfg)
    trl_reference_scoring = bool(getattr(scoring_cfg, "trl_reference_scoring", False))
    ref_temperature = getattr(gen_cfg, "gen_temperature", None)

    if score_batch is None:
        score_batch = build_score_batch(
            reward_comp,
            ctx.runtime.tokenizer,
            gen_cfg,
            batching_cfg,
        )
    accelerator = getattr(ctx.runtime, "accelerator", None)
    # Under DeepSpeed ZeRO, reference scoring may invoke collective param gathers even
    # in no_grad forward passes. If ranks diverge (some build an empty score batch),
    # later collectives can hang. Make the skip decision consistent across ranks.
    if _deepspeed_zero_stage(accelerator) >= 2 and _dist_any_flag(
        accelerator, score_batch is None
    ):
        if score_batch is None:
            LOG.warning(
                "Score batch build failed; completions=%d | prompts=%d",
                len(getattr(reward_comp.pairs, "completions", []) or []),
                len(getattr(reward_comp.pairs, "prompts", []) or []),
            )
        LOG.warning(
            "Skipping batch because at least one rank could not build a ScoreBatch "
            "(DeepSpeed ZeRO safety guard)."
        )
        return None
    if score_batch is None:
        LOG.warning(
            "Score batch build failed; completions=%d | prompts=%d",
            len(getattr(reward_comp.pairs, "completions", []) or []),
            len(getattr(reward_comp.pairs, "prompts", []) or []),
        )
        return None
    completion_ids = getattr(score_batch, "completion_ids", None)
    completion_attention_mask = getattr(score_batch, "completion_attention_mask", None)
    LOG.debug(
        "Score batch built | total_sequences=%d | max_prompt_len=%s | slice_size=%s | comp_ids_shape=%s | comp_mask_shape=%s | pad_id=%s",
        getattr(score_batch, "total_sequences", 0),
        getattr(score_batch, "max_prompt_len", None),
        getattr(score_batch, "slice_size", None),
        completion_ids.shape if completion_ids is not None else None,
        (
            completion_attention_mask.shape
            if completion_attention_mask is not None
            else None
        ),
        getattr(score_batch, "pad_token_id", None),
    )

    # Decide where reference log-probs come from.
    weighting_cfg = getattr(scoring_cfg, "weighting", None)
    grpo_mode = bool(getattr(weighting_cfg, "train_grpo_objective", False))
    ref_meta = getattr(reward_comp, "ref_logprob_meta", None)
    ref_source = (
        str(getattr(scoring_cfg, "reference_logprobs_source", "auto") or "auto")
        .strip()
        .lower()
    )
    if grpo_mode:
        # GRPO always scores against the reference model.
        ref_source = "model"
        ref_meta = None
    force_reference_model = ref_source in {
        "model",
        "reference_model",
        "ref_model",
        "reference",
    }
    ref_stats_source = "unknown"
    ref_meta_len = len(ref_meta) if ref_meta else 0
    if not grpo_mode and ref_source in {"policy", "none"}:
        ref_meta_len = 0
        if cur_logp_sum is not None:
            try:
                tok_counts = token_counts_from_score_batch(
                    score_batch, ctx.runtime, batching_cfg
                )
                ref_stats = reference_stats_from_policy_logprobs(
                    cur_logp_sum, tok_counts
                )
                ref_stats_source = "policy_logprobs"
                if not getattr(
                    _collect_batch_stats, "_policy_ref_forced_warned", False
                ):
                    LOG.info(
                        "Reference logprobs source=%s; using policy logprobs as reference (no frozen reference model).",
                        ref_source,
                    )
                    setattr(_collect_batch_stats, "_policy_ref_forced_warned", True)
            except (RuntimeError, ValueError, TypeError, AttributeError) as exc:
                LOG.warning("Policy-logprob reference fallback failed: %s", exc)
        else:
            LOG.warning(
                "Reference logprobs source=%s but policy logprobs are unavailable; "
                "reference scoring may fall back to reference model.",
                ref_source,
            )
    elif not grpo_mode and ref_meta_len and not force_reference_model:
        # Reconstruct from vLLM metadata. The helper always receives the
        # score-batch length, so a single call covers the case where
        # ref_meta_len != total_sequences (the old "retry" repeated the exact
        # same call).
        ref_stats = _reference_stats_from_meta(
            ref_meta,
            score_batch.total_sequences,
            ctx.runtime.device,
        )
        if ref_stats is not None:
            ref_stats_source = "vllm_meta"
    if (
        not grpo_mode
        and ref_stats is None
        and not force_reference_model
        and cur_logp_sum is not None
    ):
        # Prefer policy logprobs over a reference-model pass when metadata is missing.
        try:
            tok_counts = token_counts_from_score_batch(
                score_batch, ctx.runtime, batching_cfg
            )
            ref_stats = reference_stats_from_policy_logprobs(cur_logp_sum, tok_counts)
            ref_stats_source = "policy_logprobs"
            if not getattr(_collect_batch_stats, "_policy_ref_warned", False):
                LOG.warning(
                    "vLLM did not provide reference logprob metadata; using policy logprobs as reference "
                    "(KL ~= 0 fallback; set vllm_return_logprobs=true or "
                    "maxent_reference_logprobs_source=model to force a reference-model pass)."
                )
                setattr(_collect_batch_stats, "_policy_ref_warned", True)
        except (
            RuntimeError,
            ValueError,
            TypeError,
            AttributeError,
        ) as exc:  # pragma: no cover - defensive diagnostics
            LOG.warning("Policy-logprob reference fallback failed: %s", exc)

    needs_ref_model_local = bool(force_reference_model or ref_stats is None)
    needs_ref_model_any = needs_ref_model_local
    # Keep reference scoring branches aligned under ZeRO to avoid mismatched collectives.
    if _deepspeed_zero_stage(accelerator) >= 2:
        needs_ref_model_any = _dist_any_flag(accelerator, needs_ref_model_local)
    if needs_ref_model_any:
        ref_try = None
        try:
            ref_try = gather_reference_logprobs(
                score_batch,
                ctx.runtime,
                batching_cfg,
                trl_reference_scoring=trl_reference_scoring,
                temperature=ref_temperature,
            )
        except (RuntimeError, AssertionError) as exc:
            LOG.warning("Failed to gather reference logprobs: %s", exc)
            occurrence = _REF_LOGPROB_TRACE_LIMITER.next_occurrence()
            if occurrence is not None:
                LOG.error(
                    "Reference logprob traceback (occurrence %d/%d):\n%s",
                    occurrence,
                    _REF_LOGPROB_TRACE_LIMIT,
                    traceback.format_exc(),
                )
        except (
            ValueError,
            TypeError,
            AttributeError,
        ):  # pragma: no cover - defensive diag
            LOG.error(
                "Unexpected exception during gather_reference_logprobs: %s",
                traceback.format_exc(),
            )
        # Always run the gather on every rank when any rank needs it, but only
        # overwrite precomputed metadata-derived stats when needed/forced.
        if needs_ref_model_local:
            ref_stats = ref_try
            if ref_stats is None:
                LOG.warning(
                    "gather_reference_logprobs returned None | slice_size=%s chunk_size=%s device=%s | ref_meta_len=%d | total_sequences=%d",
                    getattr(score_batch, "slice_size", None),
                    getattr(batching_cfg, "logprob_chunk_size", None),
                    getattr(ctx.runtime, "device", None),
                    ref_meta_len,
                    getattr(score_batch, "total_sequences", 0),
                )
            else:
                ref_stats_source = "reference_model"
                LOG.debug(
                    "Reference stats gathered | type=%s | ref_logp_sum_shape=%s | ref_tok_counts_shape=%s",
                    type(ref_stats).__name__,
                    getattr(getattr(ref_stats, "ref_logp_sum", None), "shape", None),
                    getattr(getattr(ref_stats, "ref_tok_counts", None), "shape", None),
                )
        elif ref_try is None:
            LOG.debug(
                "Reference gather ran for ZeRO alignment but returned None; keeping metadata-derived stats."
            )

    LOG.debug(
        "Reference stats post gather | is_none=%s | type=%s | ref_logp_sum_raw_shape=%s | ref_logp_sum_shape=%s | ref_tok_counts_shape=%s | ref_logp_sum_raw_numel=%s | ref_logp_sum_numel=%s | ref_tok_counts_numel=%s",
        ref_stats is None,
        type(ref_stats).__name__ if ref_stats is not None else None,
        getattr(getattr(ref_stats, "ref_logp_sum_raw", None), "shape", None),
        getattr(getattr(ref_stats, "ref_logp_sum", None), "shape", None),
        getattr(getattr(ref_stats, "ref_tok_counts", None), "shape", None),
        (
            _safe_numel(getattr(ref_stats, "ref_logp_sum_raw", None))
            if ref_stats is not None
            else None
        ),
        (
            _safe_numel(getattr(ref_stats, "ref_logp_sum", None))
            if ref_stats is not None
            else None
        ),
        (
            _safe_numel(getattr(ref_stats, "ref_tok_counts", None))
            if ref_stats is not None
            else None
        ),
    )
    fallback_guard = getattr(_collect_batch_stats, "_fallback_warned", False)
    LOG.debug(
        "Reference stats emptiness check | ref_stats=%s | last_ref_stats=%s | ref_meta_len=%d | total_sequences=%d",
        _describe_ref(ref_stats),
        _describe_ref(last_ref_stats),
        ref_meta_len,
        getattr(score_batch, "total_sequences", 0),
    )
    if _ref_stats_empty(ref_stats):
        retry_stats = _retry_reference_gather(score_batch, batching_cfg)
        if not _ref_stats_empty(retry_stats):
            ref_stats = retry_stats
        elif last_ref_stats is None and not fallback_guard:
            # No cache yet and retry failed: force a minimal slice/chunk attempt.
            fallback_guard = True
            LOG.debug(
                "Attempting forced minimal reference gather | orig_slice=%s orig_chunk=%s",
                getattr(score_batch, "slice_size", None),
                getattr(batching_cfg, "logprob_chunk_size", None),
            )
            single_score = ScoreBatch(
                prompt_entries=score_batch.prompt_entries,
                completion_ids=score_batch.completion_ids,
                completion_attention_mask=score_batch.completion_attention_mask,
                pad_token_id=score_batch.pad_token_id,
                max_prompt_len=score_batch.max_prompt_len,
                slice_size=1,
                total_sequences=score_batch.total_sequences,
                score_tail_tokens=score_batch.score_tail_tokens,
            )
            single_batching = BatchingSettings(
                logprob_chunk_size=1,
                score_slice=1,
                prompt_length_cache_get=getattr(
                    batching_cfg,
                    "prompt_length_cache_get",
                    lambda _p, _cls=PromptCacheEntry: _cls(
                        input_ids=[], attention_mask=[]
                    ),
                ),
                score_tail_tokens=getattr(batching_cfg, "score_tail_tokens", None),
                slice_prefetch=0,
                prompt_cache_size=getattr(batching_cfg, "prompt_cache_size", 0),
            )
            try:
                forced = gather_reference_logprobs(
                    single_score,
                    ctx.runtime,
                    single_batching,
                    trl_reference_scoring=trl_reference_scoring,
                    temperature=ref_temperature,
                )
            except (RuntimeError, ValueError, TypeError, AssertionError):
                LOG.warning(
                    "Forced minimal reference gather raised an exception; skipping."
                )
                forced = None
            else:
                LOG.debug(
                    "Forced minimal reference gather result | ref_stats=%s",
                    _describe_ref(forced),
                )
            if not _ref_stats_empty(forced):
                ref_stats = forced
    if _ref_stats_empty(ref_stats) and last_ref_stats is not None:
        allow_stale = (
            False
            if grpo_mode
            else bool(getattr(scoring_cfg, "allow_stale_reference_logprobs", False))
        )
        if not allow_stale:
            LOG.warning(
                "Reference gather empty; skipping batch instead of reusing stale ref stats "
                "(enable maxent_allow_stale_reference_logprobs to override)."
            )
            try:
                setattr(ctx.runtime, "_last_skip_stage", "reference_logprobs")
            except (AttributeError, TypeError):
                LOG.debug("Failed to record reference_logprobs skip stage on runtime.")
            return None
        LOG.warning(
            "Reference gather empty; reusing last ref stats | last_ref_shapes=%s/%s",
            getattr(getattr(last_ref_stats, "ref_logp_sum", None), "shape", None),
            getattr(getattr(last_ref_stats, "ref_tok_counts", None), "shape", None),
        )
        _warn_fallback("reference gather returned empty tensors")
        ref_stats = last_ref_stats
        ref_stats_source = "stale_cached"
    # ``_ref_stats_empty(None)`` is True, so past this check ref_stats is set.
    if _ref_stats_empty(ref_stats):
        LOG.error(
            "Reference scoring returned empty tensors even after retries; meta_len=%d | sequences=%d",
            ref_meta_len,
            getattr(score_batch, "total_sequences", 0),
        )
        try:
            setattr(ctx.runtime, "_last_skip_stage", "reference_logprobs")
        except (AttributeError, TypeError):
            LOG.debug("Failed to record reference_logprobs skip stage on runtime.")
        return None
    prev_source = getattr(ctx, "_ref_logprobs_source", None)
    if prev_source != ref_stats_source and ref_stats_source != "unknown":
        LOG.info("Reference logprobs source=%s", ref_stats_source)
    try:
        setattr(ctx, "_ref_logprobs_source", ref_stats_source)
    except (AttributeError, TypeError):
        LOG.debug("Failed to update reference logprobs source on context.")
    try:
        # Cache for the stale-reuse fallback on later batches.
        setattr(ctx, "_last_ref_stats", ref_stats)
    except (AttributeError, TypeError):
        LOG.debug("Failed to cache reference stats on context.")
    LOG.debug(
        "Reference stats gathered | avg_completion_tokens=%.2f",
        getattr(ref_stats, "avg_completion_tokens", 0.0),
    )
    ref_stats = cast(ReferenceLogprobs, ref_stats)

    # Prompt tokens are counted after truncation to max_prompt_len.
    prompt_token_count = 0.0
    prompt_entries = score_batch.prompt_entries
    if prompt_entries:
        max_prompt_len = score_batch.max_prompt_len
        prompt_token_count = float(
            sum(min(entry.length, max_prompt_len) for entry in prompt_entries)
        )
    reward_comp = _maybe_apply_entropy_bonus(
        ctx,
        gen_batch,
        reward_comp,
        ref_stats,
        policy_entropy_sum,
    )
    weighting_cfg = cast(
        WeightingSettings, getattr(scoring_cfg, "weighting", SimpleNamespace())
    )
    weight_stats = compute_weight_stats(
        gen_batch.grouped_completions,
        reward_comp,
        ref_stats,
        weighting_cfg,
    )
    grpo_mode = bool(getattr(weighting_cfg, "train_grpo_objective", False))
    fallback_enabled = (
        bool(getattr(weighting_cfg, "allow_empty_weight_fallback", False))
        and not grpo_mode
    )
    if weight_stats is None or not getattr(weight_stats, "flat_weights", None):
        fallback_weights = None
        if fallback_enabled:
            fallback_weights = build_uniform_weight_stats(gen_batch.grouped_completions)
            if fallback_weights is not None:
                LOG.warning(
                    "MaxEnt weighting returned no samples; falling back to uniform GRPO weights for this batch."
                )
        if fallback_weights is not None:
            weight_stats = fallback_weights
        else:
            if grpo_mode:
                LOG.error(
                    "GRPO weighting returned no samples; check reward outputs or scale_rewards."
                )
            else:
                LOG.error(
                    "MaxEnt weighting returned no samples; check reward outputs, `maxent_tau`, or `maxent_q_temperature`."
                )
            return None
    LOG.debug(
        "Weight stats ready | entropy=%.4f",
        getattr(weight_stats, "weight_entropy", 0.0),
    )
    # Use the resolved generation config (ctx.generation or ctx.settings.generation).
    _, length_stats, num_completion_tokens = summarize_completion_lengths(
        ref_stats,
        gen_cfg.max_completion_len,
    )
    return _BatchStats(
        score_batch=score_batch,
        ref_stats=ref_stats,
        weight_stats=weight_stats,
        length_stats=length_stats,
        num_completion_tokens=num_completion_tokens,
        prompt_token_count=prompt_token_count,
    )
[docs]
def prepare_training_batch(
    ctx: TrainingLoopContext,
    generator: GenerationFn[Any],
    batch: Dict[str, List[str]],
) -> Optional[PreparedBatch]:
    """Return a :class:`PreparedBatch` or ``None`` when any stage fails.

    Stages run in order: generation, reward statistics, score-batch
    construction, policy scoring, batch statistics (reference log-probs and
    weights), and sequence scoring. Stage results are checked with
    ``_require_artifact``; a resulting :class:`_SkipBatch` is caught here, its
    stage name is recorded on ``ctx.runtime._last_skip_stage``, and ``None`` is
    returned.

    :param ctx: Full training context containing generation/scoring configs.
    :type ctx: training.types.TrainingLoopContext
    :param generator: Callable that produces grouped completions (typically
        from :class:`training.rollout.CompletionGenerator`).
    :type generator: training.types.GenerationFn
    :param batch: Mini-batch produced by the training dataloader. A scalar
        ``prompt``/``answer`` string is wrapped into a one-element list.
    :type batch: dict[str, list[str]]
    :returns: Fully-populated batch artifacts or ``None`` if generation,
        reward computation, reference scoring, or policy scoring fails.
    :rtype: PreparedBatch | None
    """
    try:
        # Normalise single-example batches (copying so the caller's dict is untouched).
        prompt_value = cast(Any, batch.get("prompt"))
        if isinstance(prompt_value, str):
            batch = dict(batch)
            batch["prompt"] = [prompt_value]
        answer_value = cast(Any, batch.get("answer"))
        if isinstance(answer_value, str):
            batch = dict(batch)
            batch["answer"] = [answer_value]
        retry_limit = ctx.generation.vllm_rounds_cfg
        if retry_limit <= 0:
            retry_limit = ctx.optimization.schedule.num_generations
        accelerator = getattr(ctx.runtime, "accelerator", None)
        rank_tag = _rank_tag(accelerator)
        is_main = bool(getattr(accelerator, "is_main_process", True))
        progress_log = _progress_log_enabled()
        num_generations = ctx.optimization.schedule.num_generations
        LOG.debug(
            "Preparing training batch | %s | prompts=%d | retry_limit=%d",
            rank_tag,
            len(batch.get("prompt", [])),
            retry_limit,
        )

        # --- Generation -------------------------------------------------
        gen_start = time.monotonic()
        if progress_log and is_main:
            LOG.info(
                "Stage generation start | %s | prompts=%d | num_generations=%d | retry_limit=%d",
                rank_tag,
                len(batch.get("prompt", [])),
                num_generations,
                retry_limit,
            )
        try:
            gen_batch = _require_artifact(
                prepare_generation_batch(
                    batch,
                    generator,
                    ctx.generation.generation_stats,
                    num_generations,
                    max_retry_rounds=retry_limit,
                ),
                stage="generation",
            )
        except TypeError:
            # Compatibility: implementations that do not accept
            # ``max_retry_rounds``. The previous fallback repeated the exact
            # same call (kwarg included), so it could never help.
            gen_batch = _require_artifact(
                prepare_generation_batch(
                    batch,
                    generator,
                    ctx.generation.generation_stats,
                    num_generations,
                ),
                stage="generation",
            )
        grouped = getattr(gen_batch, "grouped_completions", []) or []
        group_count = len(grouped)
        group_sizes = [len(group) for group in grouped]
        total_comps = sum(group_sizes)
        empty_groups = sum(1 for size in group_sizes if size == 0)
        min_group = min(group_sizes, default=0)
        max_group = max(group_sizes, default=0)
        avg_group = total_comps / max(group_count, 1)
        runtime_tokenizer = getattr(ctx.runtime, "tokenizer", None)
        diversity_metrics = _completion_diversity_metrics(
            grouped,
            tokenizer=runtime_tokenizer if callable(runtime_tokenizer) else None,
            accelerator=accelerator,
        )
        LOG.debug(
            "Generation complete | %s | grouped_prompts=%d | total_completions=%d | empty_groups=%d | min_group=%d | max_group=%d | avg_group_size=%.2f",
            rank_tag,
            group_count,
            total_comps,
            empty_groups,
            min_group,
            max_group,
            avg_group,
        )
        if progress_log and is_main:
            LOG.info(
                "Stage generation done | %s | grouped_prompts=%d | total_completions=%d | seconds=%.2f",
                rank_tag,
                group_count,
                total_comps,
                time.monotonic() - gen_start,
            )

        # --- Reward statistics ------------------------------------------
        q_temperature = _resolve_weighting_value(ctx, "q_temperature", 1.0)
        if q_temperature is None:
            q_temperature = 1.0
        q_epsilon = _resolve_weighting_value(ctx, "q_epsilon", 1e-6)
        if q_epsilon is None:
            q_epsilon = 1e-6
        reward_start = time.monotonic()
        if progress_log and is_main:
            LOG.info(
                "Stage reward stats start | %s | completions=%d",
                rank_tag,
                total_comps,
            )
        # ``getattr(None, name, default)`` returns ``default``, so a missing
        # ``training_args`` yields the same defaults as explicit None checks.
        training_args = getattr(ctx, "training_args", None)
        reward_comp = _require_artifact(
            compute_reward_statistics(
                gen_batch,
                ctx.reward,
                ctx.runtime.device,
                q_temperature,
                q_epsilon,
                _resolve_weighting_value(ctx, "beta"),
                _resolve_weighting_value(ctx, "tau"),
                scale_rewards=bool(getattr(training_args, "scale_rewards", True)),
                zero_truncated_completion_rewards=bool(
                    getattr(training_args, "zero_truncated_completion_rewards", False)
                ),
                max_completion_len=int(
                    getattr(training_args, "max_completion_length", 0) or 0
                ),
                seed_grpo_enabled=bool(
                    getattr(training_args, "seed_grpo_enabled", False)
                ),
                seed_grpo_alpha=float(
                    getattr(training_args, "seed_grpo_alpha", 0.0417)
                ),
                seed_grpo_alpha_normalize_by_max_entropy=bool(
                    getattr(
                        training_args,
                        "seed_grpo_alpha_normalize_by_max_entropy",
                        True,
                    )
                ),
                seed_grpo_length_normalize_logprobs=bool(
                    getattr(training_args, "seed_grpo_length_normalize_logprobs", True)
                ),
                seed_grpo_num_generations=int(
                    getattr(training_args, "num_generations", 0) or 0
                ),
            ),
            stage="reward_stats",
        )
        reward_mean = float(getattr(getattr(reward_comp, "moments", None), "mean", 0.0))
        reward_std = float(getattr(getattr(reward_comp, "moments", None), "std", 0.0))
        LOG.debug(
            "Reward statistics ready | %s | completions=%d | reward_mean=%.4f | reward_std=%.4f",
            rank_tag,
            len(getattr(reward_comp.pairs, "completions", []) or []),
            reward_mean,
            reward_std,
        )
        if progress_log and is_main:
            LOG.info(
                "Stage reward stats done | %s | reward_mean=%.4f | reward_std=%.4f | seconds=%.2f",
                rank_tag,
                reward_mean,
                reward_std,
                time.monotonic() - reward_start,
            )

        # --- Score batch ------------------------------------------------
        if not callable(runtime_tokenizer):
            # Unit tests often stub `ctx.runtime.tokenizer`; preserve the previous
            # control-flow by letting `_collect_batch_stats` supply the ScoreBatch.
            stats = _require_artifact(
                _collect_batch_stats(ctx, gen_batch, reward_comp),
                stage="batch_stats",
            )
            score_batch = stats.score_batch
            LOG.debug(
                "Batch stats ready | %s | sequences=%d | prompt_tokens=%.0f | completion_tokens=%.0f",
                rank_tag,
                getattr(stats.score_batch, "total_sequences", 0),
                stats.prompt_token_count,
                stats.num_completion_tokens,
            )
        else:
            score_batch = build_score_batch(
                reward_comp,
                cast(PreTrainedTokenizer, runtime_tokenizer),
                ctx.generation,
                ctx.scoring.batching,
            )
            # Under ZeRO >= 2, skip on every rank if any rank failed, so later
            # collectives stay aligned.
            if _deepspeed_zero_stage(accelerator) >= 2 and _dist_any_flag(
                accelerator, score_batch is None
            ):
                if score_batch is None:
                    LOG.warning(
                        "Score batch build failed | %s | completions=%d | prompts=%d",
                        rank_tag,
                        len(getattr(reward_comp.pairs, "completions", []) or []),
                        len(getattr(reward_comp.pairs, "prompts", []) or []),
                    )
                LOG.warning(
                    "Skipping batch because at least one rank could not build a ScoreBatch "
                    "(DeepSpeed ZeRO safety guard)."
                )
                return None
            score_batch = _require_artifact(score_batch, stage="score_batch")
        completion_ids = getattr(score_batch, "completion_ids", None)
        completion_attention_mask = getattr(
            score_batch, "completion_attention_mask", None
        )
        LOG.debug(
            "Score batch built | %s | total_sequences=%d | max_prompt_len=%s | slice_size=%s | comp_ids_shape=%s | comp_mask_shape=%s | pad_id=%s",
            rank_tag,
            getattr(score_batch, "total_sequences", 0),
            getattr(score_batch, "max_prompt_len", None),
            getattr(score_batch, "slice_size", None),
            completion_ids.shape if completion_ids is not None else None,
            (
                completion_attention_mask.shape
                if completion_attention_mask is not None
                else None
            ),
            getattr(score_batch, "pad_token_id", None),
        )

        # --- Policy scoring ---------------------------------------------
        return_entropy = bool(getattr(ctx.scoring, "policy_entropy", False))
        entropy_mode = getattr(ctx.scoring, "policy_entropy_mode", "exact")
        # Token-level logprobs are only requested for the GRPO objective.
        return_token_logp = bool(
            getattr(
                getattr(ctx.scoring, "weighting", None), "train_grpo_objective", False
            )
        )
        score_start = time.monotonic()
        if progress_log and is_main:
            LOG.info(
                "Stage policy scoring start | %s | total_sequences=%d",
                rank_tag,
                getattr(score_batch, "total_sequences", 0),
            )
        try:
            cur_logp_result = _require_artifact(
                score_model_outputs(
                    ctx.runtime.model,
                    score_batch,
                    ctx.scoring.batching,
                    ctx.runtime,
                    return_hidden=False,
                    return_entropy=return_entropy,
                    entropy_mode=entropy_mode,
                    return_token_logp=return_token_logp,
                ),
                stage="policy_scoring",
            )
        except TypeError:
            # Compatibility: scorers without the optional keyword arguments.
            cur_logp_result = _require_artifact(
                score_model_outputs(
                    ctx.runtime.model,
                    score_batch,
                    ctx.scoring.batching,
                    ctx.runtime,
                ),
                stage="policy_scoring",
            )
        logprob_tensor = (
            cur_logp_result[0]
            if isinstance(cur_logp_result, tuple)
            else cur_logp_result
        )
        LOG.debug(
            "Policy scoring complete | %s | logprob_shape=%s",
            rank_tag,
            getattr(logprob_tensor, "shape", None),
        )
        if progress_log and is_main:
            LOG.info(
                "Stage policy scoring done | %s | logprob_shape=%s | seconds=%.2f",
                rank_tag,
                getattr(logprob_tensor, "shape", None),
                time.monotonic() - score_start,
            )
        # Unpack according to which optional outputs were requested and returned.
        policy_entropy_sum = None
        token_logp = None
        token_mask = None
        if isinstance(cur_logp_result, tuple):
            if return_entropy:
                if return_token_logp and len(cur_logp_result) >= 5:
                    (
                        cur_logp_sum,
                        pooled_hidden,
                        policy_entropy_sum,
                        token_logp,
                        token_mask,
                    ) = cur_logp_result
                elif len(cur_logp_result) >= 3:
                    cur_logp_sum, pooled_hidden, policy_entropy_sum = cur_logp_result[
                        :3
                    ]
                else:
                    cur_logp_sum, pooled_hidden = cur_logp_result[:2]
            else:
                if return_token_logp and len(cur_logp_result) >= 4:
                    (
                        cur_logp_sum,
                        pooled_hidden,
                        token_logp,
                        token_mask,
                    ) = cur_logp_result
                else:
                    cur_logp_sum, pooled_hidden = cur_logp_result[:2]
        else:
            cur_logp_sum, pooled_hidden = cur_logp_result, None

        # --- Reference / weighting stats --------------------------------
        if callable(runtime_tokenizer):
            stats = _require_artifact(
                _collect_batch_stats(
                    ctx,
                    gen_batch,
                    reward_comp,
                    score_batch=score_batch,
                    cur_logp_sum=cur_logp_sum,
                    policy_entropy_sum=policy_entropy_sum,
                ),
                stage="batch_stats",
            )
            LOG.debug(
                "Batch stats ready | %s | sequences=%d | prompt_tokens=%.0f | completion_tokens=%.0f",
                rank_tag,
                getattr(getattr(stats, "score_batch", None), "total_sequences", 0),
                stats.prompt_token_count,
                stats.num_completion_tokens,
            )

        # --- Behavior logprobs & sequence scores ------------------------
        behavior_source = (
            str(getattr(ctx.scoring, "behavior_logprobs_source", "model") or "model")
            .strip()
            .lower()
        )
        use_vllm_behavior = behavior_source in {"vllm", "metadata", "meta"}
        behavior_tensor = None
        if use_vllm_behavior:
            behavior_tensor = _behavior_logp_tensor_from_meta(
                getattr(reward_comp, "ref_logprob_meta", None),
                stats.score_batch.total_sequences,
                cur_logp_sum,
            )
        old_token_logp = None
        if return_token_logp and use_vllm_behavior:
            old_token_logp = _token_logp_tensor_from_meta(
                getattr(reward_comp, "ref_logprob_meta", None),
                stats.score_batch.total_sequences,
                token_mask,
                token_logp,
            )
        try:
            scores = build_sequence_scores(
                cur_logp_sum,
                stats.ref_stats,
                pooled_hidden,
                behavior_logp_sum=behavior_tensor,
                policy_entropy_sum=policy_entropy_sum,
                token_logp=token_logp,
                token_mask=token_mask,
                old_token_logp=old_token_logp,
            )
        except TypeError:
            # Compatibility: older build_sequence_scores signatures.
            if behavior_tensor is not None:
                try:
                    scores = build_sequence_scores(
                        cur_logp_sum,
                        stats.ref_stats,
                        behavior_logp_sum=behavior_tensor,
                        policy_entropy_sum=policy_entropy_sum,
                    )
                except TypeError:
                    scores = build_sequence_scores(cur_logp_sum, stats.ref_stats)
            else:
                scores = build_sequence_scores(
                    cur_logp_sum, stats.ref_stats, policy_entropy_sum=policy_entropy_sum
                )
        return PreparedBatch(
            grouped_completions=gen_batch.grouped_completions,
            reward_comp=reward_comp,
            batch_stats=stats,
            total_input_tokens=stats.prompt_token_count + stats.num_completion_tokens,
            scores=scores,
            diversity_metrics=diversity_metrics or None,
        )
    except _SkipBatch as exc:
        skip_stage = getattr(exc, "stage", "unknown")
        try:
            setattr(ctx.runtime, "_last_skip_stage", skip_stage)
        except (AttributeError, TypeError):
            LOG.debug("Failed to record skip stage on runtime.")
        LOG.debug(
            "Skipping training batch: stage=%s returned None | %s",
            skip_stage,
            _rank_tag(getattr(ctx.runtime, "accelerator", None)),
        )
        return None
# Names exported by ``from <this module> import *``.
__all__ = [
    "PreparedBatch",
    "prepare_training_batch",
]