# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reward and generation helpers extracted from the training loop."""
from __future__ import annotations
# pylint: disable=broad-exception-caught
import logging
import math
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, cast
from types import SimpleNamespace
from maxent_grpo.training.generation import (
AggregatedGenerationState,
drop_empty_prompt_groups,
drop_incomplete_prompt_groups,
flatten_prompt_completions as _flatten_prompt_completions,
flatten_ref_metadata as _flatten_ref_metadata,
retry_incomplete_prompts,
seed_generation_groups,
truncate_to_expected_counts,
)
from maxent_grpo.training.generation.errors import (
GenerationServiceError,
log_generation_service_error,
)
from maxent_grpo.training.runtime import require_torch
from .run_helpers import _group_softmax
from maxent_grpo.rewards.basic import (
RewardConfig,
_answer_pat,
_extract_boxed_answer,
get_reward_funcs,
)
from .types import (
AdvantageStats,
GenerationBatch,
GenerationFn,
QDistribution,
RewardComputation,
RewardMoments,
RewardSpec,
)
torch = require_torch("training")
LOG = logging.getLogger(__name__)
if TYPE_CHECKING:
import torch as torch_types
TorchDevice = torch_types.device
else: # pragma: no cover - runtime uses optional torch stub
TorchDevice = Any
def _rank_tag() -> str:
"""Return best-effort rank string for logging."""
try:
dist = getattr(torch, "distributed", None)
if dist is not None and dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world = dist.get_world_size()
return f"rank={rank}/{world}"
except Exception:
pass
return "rank=na"
def _extract_ref_logprob_fields(meta_entry: Any) -> Tuple[Optional[Any], Optional[Any]]:
"""Return ``(logprob_sum, token_count)`` when present in metadata entries."""
if meta_entry is None:
return None, None
logprob_sum = getattr(meta_entry, "logprob_sum", None)
token_count = getattr(meta_entry, "token_count", None)
token_logprobs = getattr(meta_entry, "token_logprobs", None)
if isinstance(meta_entry, dict):
if logprob_sum is None:
logprob_sum = meta_entry.get("logprob_sum")
if logprob_sum is None:
logprob_sum = meta_entry.get("cumulative_logprob")
if token_logprobs is None:
token_logprobs = meta_entry.get("token_logprobs")
if token_logprobs is None:
token_logprobs = meta_entry.get("logprobs")
if token_count is None:
token_count = meta_entry.get("token_count")
if token_count is None:
token_count = meta_entry.get("num_tokens")
if token_count is None:
if token_logprobs is not None:
try:
token_count = len(token_logprobs)
except (TypeError, ValueError):
token_count = None
if logprob_sum is None and token_logprobs is not None:
try:
logprob_sum = float(sum(float(val) for val in token_logprobs))
except (TypeError, ValueError):
logprob_sum = None
return logprob_sum, token_count
def _sanitize_ref_logprob_meta(
flat_meta: Optional[List[Optional[Any]]], total_sequences: int
) -> Optional[List[Optional[Any]]]:
"""
Drop reference metadata when any entry is missing logprob information.
:param flat_meta: Flattened metadata aligned to completions.
:type flat_meta: list | None
:param total_sequences: Expected number of completions for the batch.
:type total_sequences: int
:returns: Metadata when complete, otherwise ``None``.
:rtype: list | None
"""
if not flat_meta or total_sequences <= 0:
return None
if len(flat_meta) != total_sequences:
return None
missing_idx: List[int] = []
saw_any_logprob_fields = False
for idx, entry in enumerate(flat_meta):
logprob_sum, token_count = _extract_ref_logprob_fields(entry)
if logprob_sum is not None or token_count is not None:
saw_any_logprob_fields = True
if logprob_sum is None or token_count is None:
missing_idx.append(idx)
continue
try:
float(logprob_sum)
int(token_count)
except (TypeError, ValueError):
missing_idx.append(idx)
# If none of the entries advertise logprob information, treat the metadata
# as opaque and keep it. When some entries include logprob fields but others
# don't, drop the entire batch to avoid mixing stale/partial ref stats.
if not saw_any_logprob_fields:
return flat_meta
if missing_idx:
if not getattr(_sanitize_ref_logprob_meta, "_warned", False):
LOG.warning(
"Incomplete reference logprob metadata detected | missing_entries=%d/%d | first_missing_idx=%d | "
"keeping metadata for behavior-logprob fallbacks.",
len(missing_idx),
total_sequences,
missing_idx[0],
)
setattr(_sanitize_ref_logprob_meta, "_warned", True)
return flat_meta
return flat_meta
def _call_reward_fn(
reward_fn: Any,
completions: List[str],
answers: List[str],
*,
is_eval: bool,
split: str,
) -> List[float]:
"""Call a reward fn with backward-compatible kwargs handling."""
try:
return reward_fn(completions, answers, is_eval=is_eval, split=split)
except TypeError:
try:
return reward_fn(completions, answers)
except TypeError:
return reward_fn(completions, answers, is_eval=is_eval)
def _extract_completion_runtime_info(
entry_dict: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
"""Return compact completion metadata needed by scoring and truncation logic."""
if not entry_dict:
return {}
info: Dict[str, Any] = {}
raw_output = entry_dict.get("raw_output")
raw_dict = raw_output if isinstance(raw_output, dict) else {}
token_ids = entry_dict.get("token_ids")
if token_ids is None:
token_ids = raw_dict.get("token_ids") or raw_dict.get("output_token_ids")
if hasattr(token_ids, "tolist"):
try:
token_ids = token_ids.tolist()
except (AttributeError, TypeError, ValueError):
token_ids = None
if (
isinstance(token_ids, list)
and token_ids
and isinstance(token_ids[0], list)
):
token_ids = token_ids[0]
if isinstance(token_ids, list):
coerced_ids: List[int] = []
for val in token_ids:
try:
coerced_ids.append(int(val))
except (TypeError, ValueError):
coerced_ids = []
break
if coerced_ids:
info["token_ids"] = coerced_ids
token_count = entry_dict.get("token_count")
if token_count is None:
token_count = raw_dict.get("token_count") or raw_dict.get("num_tokens")
if token_count is None and "token_ids" in info:
token_count = len(info["token_ids"])
if token_count is not None:
try:
info["token_count"] = int(token_count)
except (TypeError, ValueError):
pass
finish_reason = entry_dict.get("finish_reason")
if finish_reason is None:
finish_reason = raw_dict.get("finish_reason") or raw_dict.get("finishReason")
if finish_reason is not None:
info["finish_reason"] = str(finish_reason)
stop_reason = entry_dict.get("stop_reason")
if stop_reason is None:
stop_reason = raw_dict.get("stop_reason") or raw_dict.get("stopReason")
if stop_reason is not None:
info["stop_reason"] = str(stop_reason)
if str(info.get("finish_reason", "")).strip().lower() == "length":
info["truncated"] = True
return info
def _completion_was_truncated(
metadata: Optional[Dict[str, Any]],
*,
max_completion_len: Optional[int],
) -> bool:
"""Return ``True`` when completion metadata indicates a length stop."""
if not isinstance(metadata, dict):
return False
truncated = metadata.get("truncated")
if truncated is not None:
return bool(truncated)
finish_reason = metadata.get("finish_reason")
if isinstance(finish_reason, str) and finish_reason.strip().lower() == "length":
return True
if max_completion_len is None or int(max_completion_len) <= 0:
return False
token_count = metadata.get("token_count")
if token_count is None:
token_ids = metadata.get("token_ids")
if isinstance(token_ids, list):
token_count = len(token_ids)
try:
return int(token_count) >= int(max_completion_len)
except (TypeError, ValueError):
return False
def _zero_truncated_completion_rewards(
total_utils: List[float],
completion_metadata: Optional[List[Dict[str, Any]]],
*,
max_completion_len: Optional[int],
) -> List[float]:
"""Zero sequence rewards for samples that appear truncated."""
if not total_utils or not completion_metadata:
return total_utils
adjusted = list(total_utils)
for idx, metadata in enumerate(completion_metadata[: len(adjusted)]):
if _completion_was_truncated(
metadata,
max_completion_len=max_completion_len,
):
adjusted[idx] = 0.0
return adjusted
[docs]
def compute_reward_totals(
reward_spec: RewardSpec,
completion_batch: List[str],
flat_answers: List[str],
) -> Tuple[List[float], Dict[str, List[float]]]:
"""Evaluate reward functions and aggregate per-sequence utilities.
:param reward_spec: Reward configuration specifying callables/weights.
:type reward_spec: RewardSpec
:param completion_batch: Flattened completion texts.
:type completion_batch: list[str]
:param flat_answers: Flattened answer strings aligned with completions.
:type flat_answers: list[str]
:returns: Tuple of total utilities and per-reward raw values.
:rtype: tuple[list[float], dict[str, list[float]]]
"""
total_utils = [0.0] * len(completion_batch)
per_reward_values: Dict[str, List[float]] = {}
reward_tensors: List[Any] = []
for idx_reward, (reward_fn, reward_weight) in enumerate(
zip(reward_spec.reward_funcs, reward_spec.reward_weights)
):
reward_key = f"reward_{idx_reward}"
reward_values = [
float(val)
for val in _call_reward_fn(
reward_fn, completion_batch, flat_answers, is_eval=False, split="train"
)
]
per_reward_values[reward_key] = reward_values
reward_tensor = torch.tensor(
reward_values, dtype=getattr(torch, "float32", None)
)
if reward_weight != 1.0:
reward_tensor = reward_tensor * float(reward_weight)
reward_tensors.append(reward_tensor)
if reward_tensors:
stack_fn = getattr(torch, "stack", None)
if callable(stack_fn):
stacked = stack_fn(reward_tensors, dim=0)
try:
total_tensor = torch.nansum(stacked, dim=0)
except AttributeError:
total_tensor = torch.sum(torch.nan_to_num(stacked, nan=0.0), dim=0)
total_utils = [float(val) for val in total_tensor.tolist()]
else:
# Fallback path used by lightweight torch doubles in unit tests.
total_utils = [0.0] * len(completion_batch)
for tensor in reward_tensors:
tolist_fn = getattr(tensor, "tolist", None)
values = tolist_fn() if callable(tolist_fn) else tensor
for idx, raw in enumerate(values):
try:
val = float(raw)
except (TypeError, ValueError):
val = 0.0
if math.isnan(val) or math.isinf(val):
val = 0.0
total_utils[idx] += val
return total_utils, per_reward_values
[docs]
def reward_moments(
total_utils: List[float], device: TorchDevice
) -> Tuple[float, float]:
"""Compute reward mean/std on CPU or current accelerator device.
:param total_utils: Flattened reward totals per completion.
:type total_utils: list[float]
:param device: Device used for tensor computations.
:type device: ``torch.device``
:returns: Tuple containing ``(mean, std)`` rewards.
:rtype: tuple[float, float]
"""
if not total_utils:
return 0.0, 0.0
try:
mean_val = sum(total_utils) / len(total_utils)
train_reward_mean = float(mean_val)
if len(total_utils) > 1:
var = sum((u - mean_val) ** 2 for u in total_utils) / len(total_utils)
train_reward_std = float(math.sqrt(var))
else:
train_reward_std = 0.0
return train_reward_mean, train_reward_std
except (TypeError, ZeroDivisionError, ValueError, OverflowError, RuntimeError):
torch_mod = require_torch("training_rewards")
utils_tensor = torch_mod.tensor(
total_utils,
dtype=getattr(torch_mod, "float32", None),
device=(
device
if getattr(device, "type", "cpu") != "cpu"
else torch_mod.device("cpu")
),
)
mean_val = getattr(utils_tensor, "mean", lambda *a, **k: utils_tensor)()
train_reward_mean = float(getattr(mean_val, "item", lambda: mean_val)())
if utils_tensor.numel() > 1:
std_val = getattr(utils_tensor, "std", lambda *a, **k: utils_tensor)(
unbiased=False
)
train_reward_std = float(getattr(std_val, "item", lambda: std_val)())
else:
train_reward_std = 0.0
return train_reward_mean, train_reward_std
def _seed_extract_answer(text: str) -> Optional[str]:
"""Return the raw final answer string used for SEED-GRPO clustering."""
boxed = _extract_boxed_answer(text)
if boxed is not None:
answer = str(boxed).strip()
return answer or None
match = _answer_pat.search(text)
if match is None:
return None
answer = str(match.group(1)).strip()
return answer or None
def _seed_semantic_ids_by_answers(answers_list: List[str]) -> List[int]:
"""Match the official SEED-GRPO exact-answer clustering rule."""
answer_to_id: Dict[str, int] = {}
semantic_ids: List[int] = []
for answer in answers_list:
if answer not in answer_to_id:
answer_to_id[answer] = len(answer_to_id)
semantic_ids.append(answer_to_id[answer])
return semantic_ids
def _seed_logsumexp(values: List[float]) -> float:
"""Return a numerically stable log-sum-exp over ``values``."""
if not values:
return float("-inf")
max_val = max(values)
if not math.isfinite(max_val):
return max_val
total = sum(math.exp(val - max_val) for val in values)
if total <= 0.0:
return float("-inf")
return max_val + math.log(total)
def _seed_logsumexp_by_id(
semantic_ids: List[int], log_likelihoods: List[float]
) -> List[float]:
"""Aggregate normalized cluster log-mass by semantic id."""
if not semantic_ids or not log_likelihoods or len(semantic_ids) != len(log_likelihoods):
return []
norm = _seed_logsumexp(log_likelihoods)
if not math.isfinite(norm):
raise ValueError("SEED-GRPO requires finite completion log-likelihoods.")
unique_ids = sorted(set(int(uid) for uid in semantic_ids))
cluster_log_probs: List[float] = []
for uid in unique_ids:
cluster_vals = [
float(log_likelihoods[idx])
for idx, semantic_id in enumerate(semantic_ids)
if int(semantic_id) == uid
]
cluster_log_probs.append(_seed_logsumexp(cluster_vals) - norm)
return cluster_log_probs
def _seed_predictive_entropy_rao(cluster_log_probs: List[float]) -> float:
"""Return Rao-style predictive entropy over normalized cluster mass."""
entropy = 0.0
for log_prob in cluster_log_probs:
if not math.isfinite(log_prob):
continue
prob = math.exp(log_prob)
entropy -= prob * log_prob
return float(entropy)
def _compute_seed_grpo_statistics(
gen_batch: GenerationBatch,
*,
alpha: float,
normalize_by_max_entropy: bool,
length_normalize_logprobs: bool,
num_generations: Optional[int],
) -> Tuple[List[float], List[float], float, float]:
"""Return per-prompt semantic entropies and advantage scales for SEED-GRPO."""
grouped_comps = list(getattr(gen_batch, "grouped_completions", []) or [])
grouped_meta = list(getattr(gen_batch, "grouped_ref_meta", []) or [])
if not grouped_comps:
return [], [], 0.0, 0.0
if len(grouped_meta) < len(grouped_comps):
raise ValueError(
"SEED-GRPO requires generation logprob metadata for every prompt group."
)
generation_count = int(num_generations or 0)
if generation_count <= 0:
generation_count = max((len(group) for group in grouped_comps), default=0)
max_possible_entropy = math.log(generation_count) if generation_count > 1 else 0.0
effective_alpha = float(alpha)
if normalize_by_max_entropy and max_possible_entropy > 0.0:
effective_alpha = effective_alpha / max_possible_entropy
entropies: List[float] = []
scales: List[float] = []
for prompt_idx, comp_group in enumerate(grouped_comps):
meta_group = grouped_meta[prompt_idx]
if not isinstance(meta_group, list) or len(meta_group) < len(comp_group):
raise ValueError(
"SEED-GRPO requires generation logprob metadata aligned with completions."
)
question_answers: List[str] = []
log_liks: List[float] = []
for comp_text, meta_entry in zip(comp_group, meta_group):
answer = _seed_extract_answer(str(comp_text))
question_answers.append(answer if answer is not None else "NO_ANSWER_FOUND")
logprob_sum, token_count = _extract_ref_logprob_fields(meta_entry)
if logprob_sum is None:
raise ValueError(
"SEED-GRPO requires per-completion generation logprob metadata."
)
try:
log_lik = float(logprob_sum)
if length_normalize_logprobs:
denom = float(token_count) if token_count is not None else 0.0
if math.isfinite(denom) and denom > 0.0:
log_lik = log_lik / denom
log_liks.append(log_lik)
except (TypeError, ValueError) as exc:
raise ValueError(
"SEED-GRPO requires finite generation log-probabilities."
) from exc
if all(answer == "NO_ANSWER_FOUND" for answer in question_answers):
semantic_ids = list(range(len(question_answers)))
else:
no_answer_indices = [
idx
for idx, answer in enumerate(question_answers)
if answer == "NO_ANSWER_FOUND"
]
valid_answers = [
answer for answer in question_answers if answer != "NO_ANSWER_FOUND"
]
valid_indices = [
idx
for idx, answer in enumerate(question_answers)
if answer != "NO_ANSWER_FOUND"
]
if no_answer_indices:
valid_semantic_ids = _seed_semantic_ids_by_answers(valid_answers)
semantic_ids = [-1] * len(question_answers)
for valid_pos, question_idx in enumerate(valid_indices):
semantic_ids[question_idx] = valid_semantic_ids[valid_pos]
max_id = max(valid_semantic_ids, default=-1)
for extra_offset, question_idx in enumerate(no_answer_indices):
semantic_ids[question_idx] = max_id + 1 + extra_offset
else:
semantic_ids = _seed_semantic_ids_by_answers(question_answers)
cluster_log_probs = _seed_logsumexp_by_id(semantic_ids, log_liks)
entropy = _seed_predictive_entropy_rao(cluster_log_probs)
scale = 1.0 / (1.0 + effective_alpha * entropy)
entropies.append(float(entropy))
scales.append(float(scale))
return entropies, scales, float(effective_alpha), float(max_possible_entropy)
def _apply_group_scales(
advantage_grouped: List[List[float]],
group_scales: Optional[List[float]],
) -> Tuple[List[List[float]], List[float]]:
"""Scale grouped advantages by per-prompt multipliers."""
if not group_scales:
flat: List[float] = []
for group in advantage_grouped:
flat.extend(group)
return advantage_grouped, flat
scaled_grouped: List[List[float]] = []
scaled_flat: List[float] = []
for idx, group in enumerate(advantage_grouped):
scale = (
float(group_scales[idx])
if idx < len(group_scales) and math.isfinite(float(group_scales[idx]))
else 1.0
)
scaled_group = [float(scale) * float(value) for value in group]
scaled_grouped.append(scaled_group)
scaled_flat.extend(scaled_group)
return scaled_grouped, scaled_flat
[docs]
def group_advantages(
grouped_comps: List[List[str]],
total_utils: List[float],
*,
scale_rewards: bool = True,
) -> Tuple[List[List[float]], List[float]]:
"""Return normalized advantages per prompt group and flattened samples.
:param grouped_comps: Completions grouped by prompt.
:type grouped_comps: list[list[str]]
:param total_utils: Flattened utilities aligned with completions.
:type total_utils: list[float]
:param scale_rewards: Whether to divide by group std (TRL default).
:type scale_rewards: bool
:returns: Tuple of grouped advantages and flattened advantage samples.
:rtype: tuple[list[list[float]], list[float]]
"""
advantage_grouped: List[List[float]] = []
eps = 1e-4
idx_utils = 0
for comp_group in grouped_comps:
size = len(comp_group)
group_vals = total_utils[idx_utils : idx_utils + size]
if size > 0:
baseline = float(sum(group_vals) / size)
if scale_rewards and size > 1:
var = sum((val - baseline) ** 2 for val in group_vals) / float(size - 1)
std = math.sqrt(var)
else:
std = 0.0
if scale_rewards:
denom = std + eps
adv_vals = [(val - baseline) / denom for val in group_vals]
else:
adv_vals = [val - baseline for val in group_vals]
advantage_grouped.append(adv_vals)
else:
adv_vals = []
advantage_grouped.append(adv_vals)
idx_utils += size
advantage_samples: List[float] = []
for adv_vals in advantage_grouped:
advantage_samples.extend(adv_vals)
return advantage_grouped, advantage_samples
[docs]
def prepare_generation_batch(
batch: Dict[str, List[str]],
generator: GenerationFn[Any],
generation_stats: Dict[str, int],
expected_generations: int,
max_retry_rounds: Optional[int] = None,
) -> Optional[GenerationBatch]:
"""Generate completions and retry prompts that initially returned nothing.
:param batch: Mini-batch containing ``prompt``/``answer`` lists.
:type batch: dict[str, list[str]]
:param generator: Callable that produces grouped completions and metadata.
:type generator: :class:`~training.types.GenerationFn`
:param generation_stats: Mutable statistics dictionary updated in-place.
:type generation_stats: dict[str, int]
:param expected_generations: Desired completions per prompt.
:type expected_generations: int
:param max_retry_rounds: Optional cap overriding the default retry limit.
:type max_retry_rounds: int | None
:returns: Populated :class:`~training.types.GenerationBatch` or ``None`` if
generation fails after retries.
:rtype: :class:`~training.types.GenerationBatch` | None
"""
prompts: List[str] = batch["prompt"]
answers: List[str] = batch["answer"]
if not prompts:
LOG.debug("Generation skipped | %s | reason=empty_prompts", _rank_tag())
return None
LOG.debug(
"Starting completion generation | %s | prompts=%d | expected_generations=%d",
_rank_tag(),
len(prompts),
expected_generations,
)
def _call_generator(
prompt_batch: List[str],
expected: int,
per_prompt_counts: Optional[List[int]] = None,
) -> Any:
import inspect
per_prompt_repr = "none"
if per_prompt_counts is not None:
try:
per_prompt_repr = (
f"len={len(per_prompt_counts)} first3={list(per_prompt_counts)[:3]}"
)
except (TypeError, ValueError):
per_prompt_repr = str(per_prompt_counts)
LOG.debug(
"Invoking generator | prompts=%d | expected=%d | per_prompt_counts=%s",
len(prompt_batch),
expected,
per_prompt_repr,
)
try:
signature = inspect.signature(generator)
except (TypeError, ValueError):
signature = None
def _supports_positional_counts() -> bool:
if signature is None:
return True
params = list(signature.parameters.values())
if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params):
return True
positional = [
p
for p in params
if p.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
return len(positional) >= 3
def _supports_keyword_counts() -> bool:
if signature is None:
return False
if "per_prompt_counts" in signature.parameters:
return True
return any(
p.kind == inspect.Parameter.VAR_KEYWORD
for p in signature.parameters.values()
)
try:
if per_prompt_counts is None:
if _supports_positional_counts():
result = generator(prompt_batch, expected, None)
else:
result = generator(prompt_batch, expected)
elif _supports_positional_counts():
result = generator(prompt_batch, expected, per_prompt_counts)
elif _supports_keyword_counts():
result = generator(
prompt_batch,
expected,
per_prompt_counts=per_prompt_counts,
)
else:
LOG.debug(
"Generator does not accept per_prompt_counts; invoking without it."
)
result = generator(prompt_batch, expected)
LOG.debug(
"Generator returned | result_type=%s",
type(result).__name__,
)
return result
except GenerationServiceError as exc:
log_generation_service_error(LOG, "training", exc)
raise
gen_result = _call_generator(prompts, expected_generations)
if gen_result is None:
LOG.warning(
"Generation skipped | %s | reason=generator_returned_none",
_rank_tag(),
)
return None
if isinstance(gen_result, GenerationBatch):
grouped_comps = gen_result.grouped_completions
grouped_meta = getattr(gen_result, "grouped_ref_meta", None)
elif hasattr(gen_result, "grouped_completions"):
grouped_comps = getattr(gen_result, "grouped_completions")
grouped_meta = getattr(gen_result, "grouped_ref_meta", None)
else:
grouped_comps, grouped_meta = gen_result
LOG.debug(
"Generator output unpacked | %s | grouped_type=%s | meta_present=%s",
_rank_tag(),
type(grouped_comps).__name__,
grouped_meta is not None,
)
LOG.debug(
"Generation finished | %s | prompts=%d | groups_returned=%d",
_rank_tag(),
len(prompts),
len(grouped_comps) if grouped_comps is not None else 0,
)
prompt_count = len(prompts)
aggregated_comps, aggregated_meta = seed_generation_groups(
prompt_count,
grouped_comps,
grouped_meta,
)
aggregated_state = AggregatedGenerationState(aggregated_comps, aggregated_meta)
LOG.debug(
"Retrying incomplete prompts | %s | initial_groups=%d | expected=%d",
_rank_tag(),
len(aggregated_comps) if aggregated_comps is not None else 0,
expected_generations,
)
aggregated_state = retry_incomplete_prompts(
prompts,
_call_generator,
expected_generations,
aggregated_state,
max_retry_rounds,
)
aggregated_comps, aggregated_meta = (
aggregated_state.completions,
aggregated_state.metadata,
)
LOG.debug(
"Retries done | %s | prompts=%d | groups=%d",
_rank_tag(),
len(prompts),
len(aggregated_comps) if aggregated_comps is not None else 0,
)
pre_prompt_count = len(prompts)
pre_group_count = len(aggregated_comps) if aggregated_comps is not None else 0
pre_total_comps = (
sum(len(group) for group in aggregated_comps) if aggregated_comps else 0
)
pre_empty_groups = (
sum(1 for group in aggregated_comps if not group) if aggregated_comps else 0
)
LOG.debug(
"Generation pre-filter | %s | prompts=%d | groups=%d | total_completions=%d | empty_groups=%d",
_rank_tag(),
pre_prompt_count,
pre_group_count,
pre_total_comps,
pre_empty_groups,
)
prompts, answers, aggregated_comps, aggregated_meta = drop_empty_prompt_groups(
prompts,
answers,
aggregated_comps,
aggregated_meta,
generation_stats,
)
aggregated_comps, aggregated_meta, partial_count = truncate_to_expected_counts(
aggregated_comps,
aggregated_meta,
expected_generations,
)
if partial_count > 0:
generation_stats.setdefault("partial_prompts", 0)
generation_stats["partial_prompts"] += int(partial_count)
prompts, answers, aggregated_comps, aggregated_meta, mismatch_count = (
drop_incomplete_prompt_groups(
prompts,
answers,
aggregated_comps,
aggregated_meta,
expected_generations,
generation_stats,
)
)
post_prompt_count = len(prompts)
post_group_count = len(aggregated_comps) if aggregated_comps is not None else 0
post_total_comps = (
sum(len(group) for group in aggregated_comps) if aggregated_comps else 0
)
post_empty_groups = (
sum(1 for group in aggregated_comps if not group) if aggregated_comps else 0
)
LOG.debug(
"Generation post-filter | %s | prompts=%d | groups=%d | total_completions=%d | empty_groups=%d | dropped_prompts=%d",
_rank_tag(),
post_prompt_count,
post_group_count,
post_total_comps,
post_empty_groups,
max(pre_prompt_count - post_prompt_count, 0),
)
if not aggregated_comps:
LOG.warning(
"Generation skipped | %s | reason=no_completions_after_filter | prompts=%d",
_rank_tag(),
post_prompt_count,
)
return None
if mismatch_count > 0:
LOG.debug(
"Dropped incomplete groups | %s | prompts=%d | expected=%d | dropped=%d",
_rank_tag(),
len(prompts),
expected_generations,
mismatch_count,
)
completion_info: List[List[dict]] = [
[{} for _ in group] for group in aggregated_comps
]
if aggregated_meta is not None:
# Propagate token-id metadata (when available) into completion_info.
# Some generation backends include token ids or other structured info
# alongside reference-logprob summaries; keeping token ids here lets
# downstream scoring avoid re-tokenizing long completions.
def _meta_to_dict(entry: Any) -> Optional[Dict[str, Any]]:
if entry is None:
return None
if hasattr(entry, "to_trl_payload"):
try:
value = entry.to_trl_payload()
return value if isinstance(value, dict) else None
except (AttributeError, TypeError, ValueError):
return None
return entry if isinstance(entry, dict) else None
def _extract_token_ids(
entry_dict: Optional[Dict[str, Any]],
) -> Optional[List[int]]:
if not entry_dict:
return None
token_ids = entry_dict.get("token_ids")
if token_ids is None and isinstance(entry_dict.get("raw_output"), dict):
raw = entry_dict["raw_output"]
token_ids = raw.get("token_ids") or raw.get("output_token_ids")
if token_ids is None:
return None
if hasattr(token_ids, "tolist"):
try:
token_ids = token_ids.tolist()
except (AttributeError, TypeError, ValueError) as exc:
LOG.debug("Failed to coerce token_ids to list: %s", exc)
if (
isinstance(token_ids, list)
and token_ids
and isinstance(token_ids[0], list)
):
token_ids = token_ids[0]
if not isinstance(token_ids, list):
return None
coerced: List[int] = []
for val in token_ids:
try:
coerced.append(int(val))
except (TypeError, ValueError):
return None
return coerced
for prompt_idx, comp_group in enumerate(aggregated_comps):
if prompt_idx >= len(aggregated_meta):
continue
meta_group = aggregated_meta[prompt_idx]
if not isinstance(meta_group, list) or not meta_group:
continue
for comp_idx in range(len(comp_group)):
meta_entry = (
meta_group[comp_idx] if comp_idx < len(meta_group) else None
)
meta_dict = _meta_to_dict(meta_entry)
token_ids = _extract_token_ids(meta_dict)
if token_ids is not None:
completion_info[prompt_idx][comp_idx]["token_ids"] = token_ids
runtime_info = _extract_completion_runtime_info(meta_dict)
if runtime_info:
completion_info[prompt_idx][comp_idx].update(runtime_info)
return GenerationBatch(
prompts=prompts,
answers=answers,
grouped_completions=aggregated_comps,
grouped_ref_meta=aggregated_meta,
grouped_completion_info=completion_info,
)
def _group_q_distribution(
grouped_comps: List[List[str]],
total_utils: List[float],
temperature: float,
epsilon: float,
) -> Tuple[List[List[float]], List[float]]:
"""Return per-group q distributions derived from listwise utilities.
:param grouped_comps: Completion groups per prompt.
:type grouped_comps: list[list[str]]
:param total_utils: Flattened utility values aligned with completions.
:type total_utils: list[float]
:param temperature: Softmax temperature for listwise distribution.
:type temperature: float
:param epsilon: Minimum support value to ensure non-zero probabilities.
:type epsilon: float
:returns: Tuple of grouped q-values and flattened q-samples.
:rtype: tuple[list[list[float]], list[float]]
"""
LOG.debug(
"Computing q distribution | groups=%d | total_utils=%d | temperature=%.4f | epsilon=%.2e",
len(grouped_comps),
len(total_utils),
temperature,
epsilon,
)
q_grouped: List[List[float]] = []
q_samples: List[float] = []
idx_utils = 0
for group_idx, comp_group in enumerate(grouped_comps):
size = len(comp_group)
group_vals = total_utils[idx_utils : idx_utils + size]
if LOG.isEnabledFor(logging.DEBUG) and group_idx < 5:
LOG.debug(
"Softmax sampler | group_idx=%d | size=%d | temp=%.4f | eps=%.2e | util_sample=%s",
group_idx,
size,
temperature,
epsilon,
group_vals[: min(3, len(group_vals))],
)
if size > 0 and group_vals:
q_vals = _group_softmax(
group_vals,
temperature=max(temperature, 1e-8),
eps=epsilon,
)
else:
q_vals = []
q_grouped.append(q_vals)
q_samples.extend(q_vals)
idx_utils += size
return q_grouped, q_samples
[docs]
def compute_reward_statistics(
gen_batch: GenerationBatch,
reward_spec: RewardSpec,
device: TorchDevice,
q_temperature: float,
q_epsilon: float,
controller_beta: Optional[float] = None,
controller_tau: Optional[float] = None,
scale_rewards: bool = True,
zero_truncated_completion_rewards: bool = False,
max_completion_len: Optional[int] = None,
seed_grpo_enabled: bool = False,
seed_grpo_alpha: float = 0.0417,
seed_grpo_alpha_normalize_by_max_entropy: bool = True,
seed_grpo_length_normalize_logprobs: bool = True,
seed_grpo_num_generations: Optional[int] = None,
) -> Optional[RewardComputation]:
"""Compute utilities, q-distributions, and flattened prompt/completion pairs.
:param gen_batch: Generation batch containing grouped completions/meta.
:type gen_batch: :class:`~training.types.GenerationBatch`
:param reward_spec: Reward configuration (functions + weights).
:type reward_spec: RewardSpec
:param device: Torch device used for reward moment computations.
:type device: ``torch.device``
:param q_temperature: Temperature used when forming q-distributions.
:type q_temperature: float
:param q_epsilon: Epsilon floor ensuring full support in q-distribution.
:type q_epsilon: float
:param controller_beta: Optional KL controller beta logged with stats.
:type controller_beta: float | None
:param controller_tau: Optional controller tau logged alongside q temp.
:type controller_tau: float | None
:returns: Populated :class:`~maxent_grpo.training.types.rewards.RewardComputation` or ``None``
when inputs are empty.
:rtype: :class:`~maxent_grpo.training.types.rewards.RewardComputation` | None
"""
grouped_comps = gen_batch.grouped_completions
if not grouped_comps:
LOG.warning(
"Reward stats skipped | %s | reason=empty_grouped_completions",
_rank_tag(),
)
return None
pair_batch, flat_answers = _flatten_prompt_completions(gen_batch)
if not pair_batch.completions:
total_groups = len(grouped_comps)
total_comps = sum(len(group) for group in grouped_comps)
LOG.warning(
"Reward stats skipped | %s | reason=empty_flat_completions | groups=%d | total_completions=%d",
_rank_tag(),
total_groups,
total_comps,
)
return None
LOG.debug(
"Reward stats inputs | %s | prompts=%d | completions=%d",
_rank_tag(),
len(grouped_comps),
len(pair_batch.completions),
)
completion_metadata = getattr(pair_batch, "metadata", None)
total_utils, per_reward_values = compute_reward_totals(
reward_spec,
pair_batch.completions,
flat_answers,
)
if zero_truncated_completion_rewards:
total_utils = _zero_truncated_completion_rewards(
total_utils,
completion_metadata,
max_completion_len=max_completion_len,
)
moments = RewardMoments(*reward_moments(total_utils, device))
advantage_grouped, advantage_samples = group_advantages(
grouped_comps, total_utils, scale_rewards=scale_rewards
)
seed_semantic_entropies: Optional[List[float]] = None
seed_advantage_scales: Optional[List[float]] = None
seed_alpha_effective: Optional[float] = None
seed_max_possible_entropy: Optional[float] = None
if seed_grpo_enabled:
(
seed_semantic_entropies,
seed_advantage_scales,
seed_alpha_effective,
seed_max_possible_entropy,
) = _compute_seed_grpo_statistics(
gen_batch,
alpha=seed_grpo_alpha,
normalize_by_max_entropy=seed_grpo_alpha_normalize_by_max_entropy,
length_normalize_logprobs=seed_grpo_length_normalize_logprobs,
num_generations=seed_grpo_num_generations,
)
advantage_grouped, advantage_samples = _apply_group_scales(
advantage_grouped,
seed_advantage_scales,
)
advantage_stats = AdvantageStats(advantage_grouped, advantage_samples)
q_distribution = QDistribution(
*_group_q_distribution(
grouped_comps,
total_utils,
q_temperature,
q_epsilon,
)
)
flat_ref_meta = _sanitize_ref_logprob_meta(
_flatten_ref_metadata(grouped_comps, gen_batch.grouped_ref_meta),
len(pair_batch.completions),
)
if LOG.isEnabledFor(logging.DEBUG):
if isinstance(flat_ref_meta, list):
flat_meta_list: List[Optional[Any]] = list(flat_ref_meta)
else:
flat_meta_list = []
meta_len = len(flat_meta_list)
sample = flat_meta_list[: min(2, meta_len)] if meta_len else None
LOG.debug(
"Ref metadata flatten | grouped_meta=%s | entries=%d | sample=%s",
"none" if gen_batch.grouped_ref_meta is None else "present",
meta_len,
sample,
)
q_samples = q_distribution.samples or []
if q_samples:
q_min = min(q_samples)
q_max = max(q_samples)
else:
q_min = q_max = 0.0
beta_repr = "nan"
try:
if controller_beta is not None:
beta_repr = f"{float(controller_beta):.4f}"
except (TypeError, ValueError) as exc:
LOG.debug("Failed to format controller_beta for logging: %s", exc)
tau_repr = "nan"
try:
if controller_tau is not None:
tau_repr = f"{float(controller_tau):.4f}"
except (TypeError, ValueError) as exc:
LOG.debug("Failed to format controller_tau for logging: %s", exc)
LOG.debug(
"Reward computation | %s | prompts=%d | completions=%d | reward_mean=%.4f | reward_std=%.4f | q_range=[%.4f, %.4f] | q_temperature=%.3f | controller_tau=%s | beta=%s | eps=%.2e",
_rank_tag(),
len(grouped_comps),
len(pair_batch.completions),
moments.mean,
moments.std,
q_min,
q_max,
q_temperature,
tau_repr,
beta_repr,
q_epsilon,
)
return RewardComputation(
total_utils=total_utils,
per_reward_values=per_reward_values,
advantage=advantage_stats,
pairs=pair_batch,
q_distribution=q_distribution,
moments=moments,
ref_logprob_meta=flat_ref_meta,
completion_metadata=completion_metadata,
seed_semantic_entropies=seed_semantic_entropies,
seed_advantage_scales=seed_advantage_scales,
seed_alpha_effective=seed_alpha_effective,
seed_max_possible_entropy=seed_max_possible_entropy,
)
def _coerce_reward_names(raw_names: Any) -> List[str]:
"""Return a list of reward identifiers from arbitrary inputs."""
if not raw_names:
return []
if isinstance(raw_names, str):
return [raw_names]
try:
sequence = list(raw_names)
except TypeError:
return [str(raw_names)]
names: List[str] = []
for name in sequence:
if name is None:
continue
names.append(str(name))
return names
def _has_recipe_path(obj: Any) -> bool:
"""Return ``True`` when the object carries a recipe path marker."""
return bool(getattr(obj, "recipe_path", None))
def _build_reward_proxy(source: Any, reward_names: List[str]) -> RewardConfig:
"""Preserve source config attributes when instantiating reward helpers."""
proxy_data: Dict[str, Any] = {"reward_funcs": list(reward_names)}
if source is not None:
try:
source_data = vars(source)
except TypeError:
source_data = None
if isinstance(source_data, dict):
proxy_data.update(source_data)
proxy_data["reward_funcs"] = list(reward_names)
return cast(RewardConfig, SimpleNamespace(**proxy_data))
[docs]
def load_reward_functions(
script_args: Any, tokenizer: Any, training_args: Any = None
) -> Tuple[list, list]:
"""Resolve reward functions/weights from script or training args.
:param script_args: Script arguments carrying reward names/weights.
:param tokenizer: Tokenizer passed to reward function factory helpers.
:param training_args: Optional training config that can override script rewards.
:returns: Tuple of ``(reward_funcs, reward_weights)``.
:rtype: tuple[list, list]
"""
def _resolve_rewards(source: Any) -> Tuple[List[str], Optional[List[float]]]:
if source is None:
return [], None
names = _coerce_reward_names(getattr(source, "reward_funcs", None))
weights = getattr(source, "reward_weights", None)
return names, weights
script_names, script_weights = _resolve_rewards(script_args)
training_names, training_weights = _resolve_rewards(training_args)
use_training = False
if training_names:
if not script_names or training_names != script_names:
use_training = True
if use_training:
reward_names = training_names
weight_source = training_weights
proxy_source = training_args
elif script_names:
reward_names = script_names
weight_source = script_weights
proxy_source = script_args
else:
reward_names = ["pure_accuracy_math"]
weight_source = None
proxy_source = script_args if script_args is not None else training_args
proxy = _build_reward_proxy(proxy_source, reward_names)
reward_funcs = get_reward_funcs(proxy, None, tokenizer)
reward_weights = weight_source
if reward_weights is None or len(reward_weights) != len(reward_funcs):
reward_weights = [1.0] * len(reward_funcs)
return reward_funcs, reward_weights
[docs]
def load_eval_reward_functions(
script_args: Any, tokenizer: Any, training_args: Any = None
) -> Tuple[list, list]:
"""Resolve eval reward functions/weights, defaulting to training rewards.
:param script_args: Script arguments containing eval-specific reward settings.
:param tokenizer: Tokenizer passed to reward function factory helpers.
:param training_args: Optional training config with reward overrides.
:returns: Tuple of ``(reward_funcs, reward_weights)`` for evaluation.
:rtype: tuple[list, list]
"""
script_eval_names = _coerce_reward_names(
getattr(script_args, "eval_reward_funcs", None)
)
script_eval_weights = getattr(script_args, "eval_reward_weights", None)
script_train_names = _coerce_reward_names(
getattr(script_args, "reward_funcs", None)
)
script_train_weights = getattr(script_args, "reward_weights", None)
training_names = (
_coerce_reward_names(getattr(training_args, "reward_funcs", None))
if training_args is not None
else []
)
training_weights = (
getattr(training_args, "reward_weights", None)
if training_args is not None
else None
)
if script_eval_names:
reward_names = script_eval_names
weight_source = script_eval_weights
proxy_source = script_args
else:
use_training = False
if training_names:
if not script_train_names or training_names != script_train_names:
use_training = True
if script_train_names and not use_training:
reward_names = script_train_names
weight_source = script_train_weights
proxy_source = script_args
elif training_names:
reward_names = training_names
weight_source = training_weights
proxy_source = training_args
else:
reward_names = ["pure_accuracy_math"]
weight_source = None
proxy_source = script_args if script_args is not None else training_args
proxy = _build_reward_proxy(proxy_source, reward_names)
reward_funcs = get_reward_funcs(proxy, None, tokenizer)
reward_weights = weight_source
if reward_weights is None or len(reward_weights) != len(reward_funcs):
reward_weights = [1.0] * len(reward_funcs)
return reward_funcs, reward_weights
__all__ = [
"compute_reward_statistics",
"prepare_generation_batch",
"load_reward_functions",
"load_eval_reward_functions",
]