maxent_grpo.training.pipeline
Helpers for preparing generation/scoring artifacts used by the training loop.
The training loop expects a consistent set of artifacts for every batch:
- PreparedBatch – Bundles grouped completions, reward statistics, reference log-probability tensors, weighting diagnostics, and derived scores.
- _collect_batch_stats – Bridges generation/reward outputs with the scoring stack by building ScoreBatch objects, gathering reference log-probs when necessary, and computing weighting and length summaries.
- prepare_training_batch – High-level orchestration that runs the generation function, computes rewards, fetches reference log-probs, scores the policy, and returns a PreparedBatch instance to the optimizer.
The helpers raise the internal _SkipBatch exception when any step
fails; prepare_training_batch() catches it and returns None so the
caller can skip the problematic batch gracefully.
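The skip-on-failure flow described above can be sketched as follows. This is a minimal illustration, not the module's code: `_SkipBatch` here is a local stand-in, the two stages are toy placeholders, and `prepare_batch_sketch`, `_generate_or_skip`, and `_toy_rewards` are hypothetical names. The assumed validity rule (one non-empty group per prompt) is only an example of a condition that could trigger a skip.

```python
from typing import Callable, List, Optional

# Hypothetical generator signature: prompts in, one group of completions per prompt out.
GenerateFn = Callable[[List[str]], List[List[str]]]


class _SkipBatch(Exception):
    """Hypothetical stand-in for the module's internal control-flow exception."""


def _generate_or_skip(prompts: List[str], generator: GenerateFn) -> List[List[str]]:
    grouped = generator(prompts)
    # Assumed validity rule: exactly one non-empty group of completions per prompt.
    if len(grouped) != len(prompts) or any(not group for group in grouped):
        raise _SkipBatch("generation returned an empty or misaligned group")
    return grouped


def _toy_rewards(grouped: List[List[str]]) -> List[List[float]]:
    # Toy reward (completion length); a real stage would call the reward functions.
    return [[float(len(c)) for c in group] for group in grouped]


def prepare_batch_sketch(prompts: List[str], generator: GenerateFn) -> Optional[dict]:
    """Run the stages in order; any _SkipBatch turns the whole batch into None."""
    try:
        grouped = _generate_or_skip(prompts, generator)
        rewards = _toy_rewards(grouped)
    except _SkipBatch:
        # The caller sees None and moves on to the next batch.
        return None
    return {"grouped_completions": grouped, "rewards": rewards}
```

One reason to prefer an exception here over threading `Optional` returns through every helper is that the stage functions stay linear; only the orchestrator has to translate a failure into "skip this batch".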
Functions
- Return a tensor of behavior log-prob sums derived from metadata.
- Gather scoring, reference, and weighting artifacts for a batch.
- Return coarse diversity metrics for grouped completions.
- Return the DeepSpeed ZeRO stage from the Accelerate plugin state when present.
- Return True if flag is True on any rank (best-effort, object gather).
- Optionally add a policy-entropy bonus to rewards and refresh stats.
- Return the arithmetic mean for a non-empty list, else 0.0.
- Return a best-effort rank string for logging.
- Return reference stats when metadata fully covers all sequences.
- Return
- Return a weighting attribute with graceful fallbacks.
- Return a per-token log-prob tensor derived from vLLM metadata when available.
- Tokenize a completion for diversity metrics.
- Return the weighted mean, or 0.0 when weights are empty.
- Return a PreparedBatch or None when any stage fails.
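Several of the helpers listed above have small, checkable contracts: the mean and weighted-mean helpers fall back to 0.0, and the diversity helpers tokenize completions and summarize grouped outputs. The sketch below illustrates those contracts under stated assumptions. All function names are hypothetical, the guard against a zero total weight is an addition not described on this page, and because the actual diversity metrics and tokenizer are not documented here, a distinct-completion fraction and distinct-1 over lowercase whitespace tokens stand in for them.

```python
from typing import Dict, List, Sequence


def mean_or_zero(values: Sequence[float]) -> float:
    """Arithmetic mean of a non-empty sequence, else 0.0."""
    return sum(values) / len(values) if values else 0.0


def weighted_mean_or_zero(values: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted mean; 0.0 when weights are empty (and, as an extra guard, when they sum to 0)."""
    if not weights:
        return 0.0
    if len(values) != len(weights):
        raise ValueError("values and weights must have the same length")
    total = sum(weights)
    if total == 0:
        return 0.0
    return sum(v * w for v, w in zip(values, weights)) / total


def tokenize_for_diversity(completion: str) -> List[str]:
    # Assumption: lowercase whitespace tokens are good enough for a *coarse* metric.
    return completion.lower().split()


def diversity_metrics(grouped: List[List[str]]) -> Dict[str, float]:
    """Per-group distinct-completion fraction and distinct-1, averaged over groups."""
    distinct_completions: List[float] = []
    distinct_tokens: List[float] = []
    for group in grouped:
        if not group:
            continue  # empty groups contribute nothing
        distinct_completions.append(len(set(group)) / len(group))
        tokens = [tok for c in group for tok in tokenize_for_diversity(c)]
        if tokens:
            distinct_tokens.append(len(set(tokens)) / len(tokens))
    return {
        "distinct_completion_frac": mean_or_zero(distinct_completions),
        "distinct_1": mean_or_zero(distinct_tokens),
    }
```

For example, a group of identical completions scores 0.5 on the distinct-completion fraction when it has two members, while a group of fully distinct completions scores 1.0; averaging per group keeps one large group from dominating the batch-level number.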
Classes
- Artifacts required to run optimization for a training batch.
- Aggregated batch statistics before building losses.
- Stateful helper to guard noisy tracebacks.
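The traceback guard is described only as a stateful helper. One plausible reading, sketched here purely as an assumption, is a helper that logs the full traceback the first time a given stage fails with a given exception type and emits a one-line warning for repeats, so a recurring failure does not flood the logs once per batch. `TracebackGuard` and `report()` are hypothetical names, and keying on (context, exception type) is a design choice of this sketch.

```python
import logging
import traceback
from typing import Set, Tuple


class TracebackGuard:
    """Hypothetical sketch: full traceback once per (context, exception type),
    then a one-line warning for repeats."""

    def __init__(self, logger: logging.Logger) -> None:
        self.logger = logger
        self._seen: Set[Tuple[str, str]] = set()

    def report(self, context: str, exc: BaseException) -> bool:
        """Log `exc`; return True if the full traceback was emitted."""
        key = (context, type(exc).__name__)
        if key not in self._seen:
            self._seen.add(key)
            tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
            self.logger.warning("%s failed:\n%s", context, tb)
            return True
        # Repeat of a known failure: keep the log line short.
        self.logger.warning(
            "%s failed again (%s: %s); traceback suppressed",
            context, type(exc).__name__, exc,
        )
        return False
```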
Exceptions
- Internal control-flow exception to skip invalid batches.
- class maxent_grpo.training.pipeline.PreparedBatch(grouped_completions, reward_comp, batch_stats, total_input_tokens, scores, diversity_metrics=None)
Bases: object
Artifacts required to run optimization for a training batch.
- Parameters:
grouped_completions (list[list[str]]) – Nested list of completions per prompt.
reward_comp (RewardComputation) – Reward statistics computed by training.rewards.compute_reward_statistics().
batch_stats (_BatchStats) – Auxiliary scoring/weighting artifacts built by _collect_batch_stats().
total_input_tokens (float) – Prompt + completion token count used for throughput logging.
scores (SequenceScores) – Structure containing current-model log-probabilities aligned with the reference statistics.
- reward_comp: RewardComputation
- batch_stats: _BatchStats
- scores: SequenceScores
- property weight_stats: WeightStats
Shortcut to the batch weighting statistics.
- property ref_stats: ReferenceLogprobs
Return reference log-probability statistics for the batch.
- property length_stats: LengthStats
Return sequence length statistics computed for the batch.
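The three properties read like thin delegations to `batch_stats`. A minimal sketch of that shape follows, assuming (not confirmed by this page) that `_BatchStats` exposes attributes with the same names and that `PreparedBatch` is a dataclass. `BatchStatsSketch` and `PreparedBatchSketch` are hypothetical stand-ins, and `Any` replaces the real `RewardComputation`, `SequenceScores`, and stats types.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class BatchStatsSketch:
    """Hypothetical stand-in for _BatchStats; the attribute names are assumptions."""
    weight_stats: Any
    ref_stats: Any
    length_stats: Any


@dataclass
class PreparedBatchSketch:
    grouped_completions: List[List[str]]
    reward_comp: Any
    batch_stats: BatchStatsSketch
    total_input_tokens: float
    scores: Any
    diversity_metrics: Optional[Dict[str, float]] = None

    @property
    def weight_stats(self) -> Any:
        # Shortcut so loss code does not need to reach into batch_stats.
        return self.batch_stats.weight_stats

    @property
    def ref_stats(self) -> Any:
        return self.batch_stats.ref_stats

    @property
    def length_stats(self) -> Any:
        return self.batch_stats.length_stats
```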
- maxent_grpo.training.pipeline.prepare_training_batch(ctx, generator, batch)
Return a PreparedBatch or None when any stage fails.
- Parameters:
ctx (training.types.TrainingLoopContext) – Full training context containing generation/scoring configs.
generator (training.types.GenerationFn) – Callable that produces grouped completions (typically from training.rollout.CompletionGenerator).
batch (dict[str, list[str]]) – Mini-batch produced by the training dataloader.
- Returns:
Fully populated batch artifacts, or None if generation, reward computation, reference scoring, or policy scoring fails.
- Return type:
PreparedBatch | None
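On the caller side, a `None` return simply means "skip this batch". The sketch below shows one way a loop might consume it while also using `total_input_tokens` for throughput logging, as the `PreparedBatch` parameters describe. `run_epoch_sketch`, `HasTokenCount`, and the `optimize` callback are hypothetical; the real training loop is not shown on this page.

```python
import time
from typing import Callable, Iterable, Optional, Protocol


class HasTokenCount(Protocol):
    total_input_tokens: float


def run_epoch_sketch(
    batches: Iterable[dict],
    prepare: Callable[[dict], Optional[HasTokenCount]],
    optimize: Callable[[HasTokenCount], None],
) -> dict:
    """Skip batches for which `prepare` returns None; count tokens for throughput."""
    skipped = 0
    tokens = 0.0
    start = time.perf_counter()
    for batch in batches:
        prepared = prepare(batch)
        if prepared is None:
            skipped += 1  # generation, rewards, or scoring failed: move on
            continue
        optimize(prepared)
        tokens += prepared.total_input_tokens
    elapsed = max(time.perf_counter() - start, 1e-9)  # avoid division by zero
    return {"skipped": skipped, "tokens": tokens, "tokens_per_sec": tokens / elapsed}
```

With the documented signature, `prepare` could be `lambda b: prepare_training_batch(ctx, generator, b)`, where `ctx` and `generator` come from the surrounding training setup.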