maxent_grpo.training.pipeline

Helpers for preparing generation/scoring artifacts used by the training loop.

The training loop expects a consistent set of artifacts for every batch. Three components produce them:

PreparedBatch

Bundles grouped completions, reward statistics, reference log-probability tensors, weighting diagnostics, and derived scores.

_collect_batch_stats

Bridges generation/reward outputs with the scoring stack by building ScoreBatch objects, gathering reference log-probs when necessary, and computing weighting and length summaries.

prepare_training_batch

High-level orchestration that runs the generation function, computes rewards, fetches reference log-probs, scores the policy, and returns a PreparedBatch instance to the optimizer.

The helpers raise the internal _SkipBatch exception when any step fails; prepare_training_batch() catches it and returns None so the caller can skip the problematic batch gracefully.

Functions

_behavior_logp_tensor_from_meta(flat_meta, ...)

Return a tensor of behavior log-prob sums derived from metadata.

_coerce_token_logprob_value(value)

_collect_batch_stats(ctx, gen_batch, ...[, ...])

Gather scoring, reference, and weighting artifacts for a batch.

_completion_diversity_metrics(...[, ...])

Return coarse diversity metrics for grouped completions.

_deepspeed_zero_stage(accelerator)

Return DeepSpeed ZeRO stage from Accelerate plugin state when present.

_dist_any_flag(accelerator, flag)

Return True if flag is True on any rank (best-effort, object gather).

_extract_token_logprob_seq(entry)

_maybe_apply_entropy_bonus(ctx, gen_batch, ...)

Optionally add a policy-entropy bonus to rewards and refresh stats.

_mean(values)

Return the arithmetic mean for a non-empty list, else 0.0.

_progress_log_enabled()

_rank_tag([accelerator])

Return best-effort rank string for logging.

_reference_stats_from_meta(flat_meta, ...)

Return reference stats when metadata fully covers all sequences.

_require_artifact(value, stage)

Return value or raise the internal _SkipBatch sentinel.

_resolve_weighting_value(ctx, attribute[, ...])

Return a weighting attribute with graceful fallbacks.

_token_logp_tensor_from_meta(flat_meta, ...)

Return per-token log-prob tensor derived from vLLM metadata when available.

_tokenize_for_diversity(text[, tokenizer])

Tokenize a completion for diversity metrics.

_weighted_mean(values, weights)

Return the weighted mean or 0.0 when weights are empty.

prepare_training_batch(ctx, generator, batch)

Return a PreparedBatch or None when any stage fails.
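The metadata helpers (_behavior_logp_tensor_from_meta, _coerce_token_logprob_value, _extract_token_logprob_seq) are only summarized above. The following is a minimal sketch of the idea, under several assumptions. Each entry of flat_meta is assumed to be a dict (or None) holding per-token log-probs under a hypothetical "token_logprobs" key. Each value is assumed to be either a number or an object with a numeric logprob attribute, as vLLM's Logprob objects have. The coercion step is an assumption about what _coerce_token_logprob_value does. The sketch returns plain floats instead of a tensor to stay dependency-free, and it returns None when any sequence lacks metadata, mirroring the "fully covers all sequences" condition documented for _reference_stats_from_meta.

```python
from typing import Any, List, Optional, Sequence


def coerce_logprob(value: Any) -> Optional[float]:
    """Accept a plain number or an object with a numeric `logprob` attribute."""
    if isinstance(value, (int, float)):
        return float(value)
    inner = getattr(value, "logprob", None)
    return float(inner) if isinstance(inner, (int, float)) else None


def behavior_logp_sums(flat_meta: Sequence[Optional[dict]]) -> Optional[List[float]]:
    """Per-sequence sums of behavior log-probs, or None if any sequence lacks them."""
    sums: List[float] = []
    for entry in flat_meta:
        # "token_logprobs" is a hypothetical key, not the module's real schema.
        tokens = (entry or {}).get("token_logprobs")
        if not tokens:
            return None
        values = [coerce_logprob(v) for v in tokens]
        if any(v is None for v in values):
            return None
        sums.append(sum(values))
    return sums
```

Wrapping the result with torch.tensor(sums) would give the tensor form the documented helper returns.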
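_completion_diversity_metrics and _tokenize_for_diversity are described only as producing "coarse diversity metrics". One common choice, assumed here rather than taken from the module, combines two measures. The first is distinct-n, the number of unique n-grams divided by the total number of n-grams. The second is the fraction of unique completions per prompt. The metric names are hypothetical. The tokenizer fallback assumes a Hugging Face-style object with a tokenize(text) method and otherwise splits on whitespace.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple


def tokenize_for_diversity(text: str, tokenizer: Optional[Any] = None) -> List[str]:
    """Use tokenizer.tokenize when available, else whitespace tokens."""
    if tokenizer is not None and hasattr(tokenizer, "tokenize"):
        try:
            return [str(tok) for tok in tokenizer.tokenize(text)]
        except Exception:
            pass  # best-effort: fall back to whitespace splitting
    return text.split()


def _ngrams(tokens: Sequence[str], n: int) -> List[Tuple[str, ...]]:
    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]


def completion_diversity_metrics(
    grouped: Sequence[Sequence[str]], tokenizer: Optional[Any] = None
) -> Dict[str, float]:
    """Distinct-1/2 over all completions plus mean unique-completion fraction per prompt."""
    token_lists = [tokenize_for_diversity(t, tokenizer) for g in grouped for t in g]
    metrics: Dict[str, float] = {}
    for n in (1, 2):
        grams = [ng for toks in token_lists for ng in _ngrams(toks, n)]
        metrics[f"distinct_{n}"] = len(set(grams)) / len(grams) if grams else 0.0
    fracs = [len(set(g)) / len(g) for g in grouped if g]
    metrics["unique_completion_frac"] = sum(fracs) / len(fracs) if fracs else 0.0
    return metrics
```

For example, the input [["a b", "a b"], ["a c"]] yields distinct_1 = 0.5 (3 unique of 6 unigrams), distinct_2 = 2/3, and unique_completion_frac = 0.75 (the mean of 0.5 and 1.0).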
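The accelerator helpers (_deepspeed_zero_stage, _dist_any_flag, _rank_tag) are all described as best-effort. The following is a sketch built on attributes that Hugging Face Accelerate provides: Accelerator.num_processes, Accelerator.process_index, accelerator.state.deepspeed_plugin.zero_stage, and accelerate.utils.gather_object. The "rankN" tag format and the fallback values are assumptions, not the module's confirmed behavior.

```python
import os
from typing import Any, Optional


def deepspeed_zero_stage(accelerator: Any) -> Optional[int]:
    """ZeRO stage from accelerator.state.deepspeed_plugin.zero_stage, else None."""
    plugin = getattr(getattr(accelerator, "state", None), "deepspeed_plugin", None)
    stage = getattr(plugin, "zero_stage", None)
    try:
        return int(stage) if stage is not None else None
    except (TypeError, ValueError):
        return None


def dist_any_flag(accelerator: Any, flag: bool) -> bool:
    """True if `flag` is set on any rank; falls back to the local value."""
    local = bool(flag)
    if getattr(accelerator, "num_processes", 1) <= 1:
        return local
    try:
        # Collective call: every rank must reach this line.
        from accelerate.utils import gather_object

        return any(bool(v) for v in gather_object([local]))
    except Exception:
        return local  # best-effort: trust only this rank's flag


def rank_tag(accelerator: Optional[Any] = None) -> str:
    """Log prefix such as 'rank0', from process_index or the RANK env var."""
    index = getattr(accelerator, "process_index", None)
    if index is None:
        index = os.environ.get("RANK")
    return f"rank{index}" if index is not None else "rank?"
```

The local fallback in dist_any_flag is only safe if all ranks fail or succeed together. A failure on a single rank could leave the others blocked in the gather, which this sketch does not guard against.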
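_mean, _weighted_mean, and _maybe_apply_entropy_bonus are simple enough to sketch from their one-line descriptions. Several details below are assumptions: the zero-total-weight guard, the coef parameter, and refreshing only the mean. The documented entropy helper instead takes the training context and a generation batch and refreshes the reward statistics.

```python
from typing import List, Sequence, Tuple


def mean(values: Sequence[float]) -> float:
    """Arithmetic mean, or 0.0 for an empty sequence."""
    return sum(values) / len(values) if values else 0.0


def weighted_mean(values: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted mean; 0.0 when weights are empty or (assumed) sum to zero."""
    if len(values) != len(weights):
        raise ValueError("values and weights must have the same length")
    total = sum(weights)
    if not weights or total == 0:
        return 0.0
    return sum(v * w for v, w in zip(values, weights)) / total


def apply_entropy_bonus(
    rewards: Sequence[float], entropies: Sequence[float], coef: float
) -> Tuple[List[float], float]:
    """Add coef * entropy to each reward; return new rewards and their refreshed mean."""
    if len(rewards) != len(entropies):
        raise ValueError("rewards and entropies must align")
    if coef == 0.0:
        bonused = list(rewards)  # bonus disabled: rewards pass through unchanged
    else:
        bonused = [r + coef * h for r, h in zip(rewards, entropies)]
    return bonused, mean(bonused)
```

With uniform weights, weighted_mean reduces to mean. For example, apply_entropy_bonus([1.0, 0.0], [0.5, 2.0], 0.5) returns ([1.25, 1.0], 1.125).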

Classes

PreparedBatch(grouped_completions, ...[, ...])

Artifacts required to run optimization for a training batch.

_BatchStats(score_batch, ref_stats, ...)

Aggregated batch statistics before building losses.

_TraceCounter(limit)

Stateful helper to guard noisy tracebacks.
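_TraceCounter is described only as a stateful guard for noisy tracebacks. The following minimal sketch assumes it allows full tracebacks for the first `limit` failures and suppresses them afterwards. The should_log method name and the log_stage_failure helper are hypothetical. The result feeds the standard logging exc_info argument, which accepts an exception instance or False.

```python
import logging


class TraceCounter:
    """Permit full tracebacks for the first `limit` events, then suppress them."""

    def __init__(self, limit: int) -> None:
        self.limit = max(0, int(limit))
        self.count = 0

    def should_log(self) -> bool:
        """Record one event and report whether its traceback should be emitted."""
        self.count += 1
        return self.count <= self.limit


LOG = logging.getLogger("pipeline_sketch")
_TRACES = TraceCounter(limit=2)


def log_stage_failure(stage: str, exc: BaseException) -> None:
    # Attach the traceback only while the counter allows it.
    exc_info = exc if _TRACES.should_log() else False
    LOG.warning("batch skipped at %s: %s", stage, exc, exc_info=exc_info)
```

The one-line warning is always emitted, so skipped batches stay visible after the traceback budget is spent.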

Exceptions

_SkipBatch(stage)

Internal control-flow exception to skip invalid batches.

class maxent_grpo.training.pipeline.PreparedBatch(grouped_completions, reward_comp, batch_stats, total_input_tokens, scores, diversity_metrics=None)

Bases: object

Artifacts required to run optimization for a training batch.

Parameters:
  • grouped_completions (list[list[str]]) – Nested list of completions per prompt.

  • reward_comp (RewardComputation) – Reward statistics computed by training.rewards.compute_reward_statistics().

  • batch_stats (_BatchStats) – Auxiliary scoring/weighting artifacts built by _collect_batch_stats().

  • total_input_tokens (float) – Prompt + completion token count used for throughput logging.

  • scores (SequenceScores) – Structure containing current-model log-probabilities aligned with the reference statistics.

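  • diversity_metrics (Dict[str, float] | None) – Optional coarse diversity metrics for the grouped completions (see _completion_diversity_metrics()); defaults to None.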

grouped_completions: List[List[str]]
reward_comp: RewardComputation
batch_stats: _BatchStats
total_input_tokens: float
scores: SequenceScores
diversity_metrics: Dict[str, float] | None = None
property weight_stats: WeightStats

Shortcut to the batch weighting statistics.

property ref_stats: ReferenceLogprobs

Return reference log-probability statistics for the batch.

property length_stats: LengthStats

Return sequence length statistics computed for the batch.

property num_completion_tokens: float

Return total completion token count used to build the batch.
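The properties read as thin accessors over batch_stats. The following sketch of that delegation assumes PreparedBatch is a dataclass. It also assumes _BatchStats exposes weight_stats, length_stats, and num_completion_tokens attributes; only ref_stats appears in its documented signature. BatchStatsStub and PreparedBatchSketch are hypothetical, and reward_comp and scores are omitted for brevity.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class BatchStatsStub:
    """Hypothetical stand-in for _BatchStats with only the fields used below."""

    ref_stats: Any
    weight_stats: Any
    length_stats: Any
    num_completion_tokens: float


@dataclass
class PreparedBatchSketch:
    grouped_completions: List[List[str]]
    batch_stats: BatchStatsStub
    total_input_tokens: float
    diversity_metrics: Optional[Dict[str, float]] = None

    @property
    def weight_stats(self) -> Any:
        return self.batch_stats.weight_stats

    @property
    def ref_stats(self) -> Any:
        return self.batch_stats.ref_stats

    @property
    def length_stats(self) -> Any:
        return self.batch_stats.length_stats

    @property
    def num_completion_tokens(self) -> float:
        return self.batch_stats.num_completion_tokens
```

Keeping the statistics in one nested object and exposing read-only shortcuts avoids duplicating state between PreparedBatch and _BatchStats.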

maxent_grpo.training.pipeline.prepare_training_batch(ctx, generator, batch)

Return a PreparedBatch or None when any stage fails.

Parameters:
  • ctx (training.types.TrainingLoopContext) – Full training context containing generation/scoring configs.

  • generator (training.types.GenerationFn) – Callable that produces grouped completions (typically from training.rollout.CompletionGenerator).

  • batch (dict[str, list[str]]) – Mini-batch produced by the training dataloader.

Returns:

Fully populated batch artifacts, or None if generation, reward computation, reference scoring, or policy scoring fails.

Return type:

PreparedBatch | None
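Because prepare_training_batch returns None instead of raising, the caller only needs a None check. The following is a minimal sketch of such a loop. run_epoch, prepare, and optimize are hypothetical; in real code, prepare would wrap prepare_training_batch(ctx, generator, batch). In multi-process training the ranks may also need to agree on a skip, as the module's _dist_any_flag helper suggests. This sketch ignores that.

```python
from typing import Any, Callable, Iterable, Optional


def run_epoch(
    batches: Iterable[dict],
    prepare: Callable[[dict], Optional[Any]],
    optimize: Callable[[Any], None],
) -> int:
    """Feed prepared batches to `optimize`; return how many batches were skipped."""
    skipped = 0
    for batch in batches:
        prepared = prepare(batch)
        if prepared is None:
            skipped += 1  # some stage failed for this batch; move on
            continue
        optimize(prepared)
    return skipped
```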