maxent_grpo.training.scoring
Compatibility facade for scoring helpers.
The implementation is split across focused modules:
- scoring_batching (prompt/completion batching + slices)
- scoring_logprob (policy/reference logprob kernels + sequence scores)
- scoring_reference (reference-model/vLLM metadata paths)
- class maxent_grpo.training.scoring.CompletionTensors(ids, mask)
Bases: object
Completion token IDs and masks.
- Parameters:
ids (torch.Tensor)
mask (torch.Tensor)
- ids: torch.Tensor
- mask: torch.Tensor
- maxent_grpo.training.scoring.build_score_batch(reward_comp, tokenizer, generation_cfg, batching_cfg)
Tokenize prompt+completion pairs and prepare masks/labels.
- Parameters:
reward_comp (RewardComputation) – Reward computation payload containing prompts and completions.
tokenizer (PreTrainedTokenizer) – Tokenizer used to encode completions and determine padding.
generation_cfg (GenerationSettings) – Generation settings (max lengths, etc.).
batching_cfg (BatchingSettings) – Batching settings controlling scoring slice sizes.
- Returns:
Prepared ScoreBatch, or None when no sequences are available.
- Return type:
ScoreBatch | None
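A dependency-free sketch of the prompt+completion packing that build_score_batch describes, assuming right padding and the Hugging Face convention of marking unscored label positions with -100. The helper pack_prompt_completion and its pad_id/max_completion_len arguments are hypothetical; the real function works from a tokenizer and torch tensors.

```python
from typing import List, Optional, Tuple

IGNORE_INDEX = -100  # assumed Hugging Face convention: label positions excluded from scoring

Packed = Tuple[List[List[int]], List[List[int]], List[List[int]]]


def pack_prompt_completion(
    prompts: List[List[int]],
    completions: List[List[int]],
    pad_id: int,
    max_completion_len: int,
) -> Optional[Packed]:
    """Hypothetical helper: right-pad prompt+completion ids and build masks/labels."""
    rows = []
    for prompt, completion in zip(prompts, completions):
        completion = completion[:max_completion_len]  # clip to the generation budget
        # Prompt positions get IGNORE_INDEX so only completion tokens are scored.
        rows.append((prompt + completion, [IGNORE_INDEX] * len(prompt) + completion))
    if not rows:
        return None  # mirrors "None when no sequences are available"
    width = max(len(ids) for ids, _ in rows)
    input_ids, attention_mask, labels = [], [], []
    for ids, lab in rows:
        pad = width - len(ids)
        input_ids.append(ids + [pad_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        labels.append(lab + [IGNORE_INDEX] * pad)
    return input_ids, attention_mask, labels
```

With prompts [[5, 6], [7]], completions [[8, 9, 10], [11]] and max_completion_len=2, the first completion is clipped to [8, 9] and the second row is padded to width 4.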
- maxent_grpo.training.scoring.build_sequence_scores(cur_logp_sum, ref_stats, pooled_hidden=None, *, behavior_logp_sum=None, policy_entropy_sum=None, token_logp=None, token_mask=None, old_token_logp=None)
Return SequenceScores built from current and reference log-probs.
- Parameters:
cur_logp_sum (torch.Tensor) – Current policy log-prob sums per sequence.
ref_stats (ReferenceLogprobs) – Reference log-prob stats used for KL and weighting.
pooled_hidden (torch.Tensor | None) – Optional pooled hidden states for auxiliary losses.
behavior_logp_sum (torch.Tensor | None) – Optional behavior-policy log-probs for off-policy scoring.
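policy_entropy_sum (torch.Tensor | None) – Optional per-sequence policy entropy sums.
token_logp (torch.Tensor | None) – Optional per-token current-policy log-probs.
token_mask (torch.Tensor | None) – Optional per-token completion mask aligned to token_logp.
old_token_logp (torch.Tensor | None) – Optional per-token log-probs from the old policy.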
- Returns:
SequenceScores dataclass with normalized log-probs and KL terms.
- Return type:
SequenceScores
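The KL terms can be illustrated with one common choice: the non-negative per-token estimator exp(ref - cur) - (ref - cur) - 1, which TRL's GRPO trainer uses, averaged over completion tokens. This sketch assumes that estimator and a masked mean; which estimator and normalization build_sequence_scores actually applies is not stated here, and masked_mean_kl is a hypothetical name.

```python
import math
from typing import List


def masked_mean_kl(cur_token_logp: List[float], ref_token_logp: List[float], mask: List[int]) -> float:
    """Per-sequence KL estimate averaged over masked (completion) tokens.

    Assumes the estimator k3 = exp(ref - cur) - (ref - cur) - 1, which is >= 0.
    """
    total, count = 0.0, 0
    for cur, ref, m in zip(cur_token_logp, ref_token_logp, mask):
        if not m:
            continue  # skip prompt/padding positions
        delta = ref - cur
        total += math.exp(delta) - delta - 1.0
        count += 1
    return total / max(count, 1)  # guard against empty completions
```

The estimator is exactly zero when the two policies agree on every token, which is also why reference_stats_from_policy_logprobs can describe its output as KL ≈ 0.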
- maxent_grpo.training.scoring.finalize_reference_stats(ref_logp_sum, ref_tok_counts, *, ref_token_logp=None, ref_token_mask=None)
Build a ReferenceLogprobs object and derived scalars.
- Parameters:
ref_logp_sum (torch.Tensor) – Per-sequence sum of reference log-probabilities.
ref_tok_counts (torch.Tensor) – Per-sequence token counts.
ref_token_logp (torch.Tensor | None) – Optional per-token reference log-probs (completion-only).
ref_token_mask (torch.Tensor | None) – Optional per-token completion mask aligned to ref_token_logp.
- Returns:
Normalized reference stats and summary scalars.
- Return type:
- Raises:
ValueError – If log-prob tensors cannot be safely normalized.
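A sketch of the normalization finalize_reference_stats implies: divide each per-sequence log-prob sum by its token count and refuse non-finite input. The name normalize_reference, the clamping of zero counts to 1, and the exact error conditions are assumptions; the real function returns a ReferenceLogprobs object plus summary scalars.

```python
import math
from typing import List, Tuple


def normalize_reference(ref_logp_sum: List[float], ref_tok_counts: List[int]) -> Tuple[List[float], float]:
    """Hypothetical: per-sequence mean log-probs plus a batch-level mean."""
    if len(ref_logp_sum) != len(ref_tok_counts):
        raise ValueError("log-prob sums and token counts must have the same length")
    means = []
    for total, count in zip(ref_logp_sum, ref_tok_counts):
        if not math.isfinite(total):
            raise ValueError(f"non-finite reference log-prob sum: {total}")
        means.append(total / max(count, 1))  # clamp so empty sequences do not divide by zero
    batch_mean = sum(means) / len(means) if means else 0.0
    return means, batch_mean
```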
- maxent_grpo.training.scoring.gather_reference_logprobs(score_batch, runtime, batching_cfg, *, trl_reference_scoring=False, temperature=None)
Compute log-probabilities by running the frozen reference model.
This function handles distributed preflight checks to avoid ZeRO hangs and aggregates reference statistics into a ReferenceLogprobs object.
- Parameters:
score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.
runtime (RuntimeHandles) – Runtime handles exposing device, accelerator, and models.
batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.
trl_reference_scoring (bool) – When True, use TRL/open-r1 reference scoring logic.
temperature (float | None) – Optional temperature for TRL-style logprob scaling.
- Returns:
ReferenceLogprobs, or None when reference scoring fails.
- Return type:
ReferenceLogprobs | None
- maxent_grpo.training.scoring.iter_batch_slices(score_batch, device, *, eos_token_id=None, apply_eos_mask=False)
Yield scoring slices for a batch, assembling prompt tensors on demand.
- Parameters:
score_batch (ScoreBatch) – Prepared prompt/completion tensors and metadata.
device (torch.device) – Device where tensors should be materialized.
eos_token_id (int | None) – Optional EOS token id for TRL-style completion masking.
apply_eos_mask (bool) – When True, apply EOS-aware completion masks.
- Yields:
Tuples of (input_ids, attention_mask, labels) per slice.
- Return type:
Iterator[tuple[Tensor, Tensor, Tensor]]
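EOS-aware masking and slicing can be illustrated without torch. The mask below follows the rule TRL's GRPO trainer applies (keep completion tokens through the first EOS, zero everything after); whether iter_batch_slices masks exactly this way is an assumption, and eos_completion_mask, iter_row_slices and slice_size are hypothetical names.

```python
from typing import Iterator, List, Optional


def eos_completion_mask(completion_ids: List[int], eos_token_id: Optional[int]) -> List[int]:
    """Keep tokens up to and including the first EOS; zero out everything after it."""
    if eos_token_id is None or eos_token_id not in completion_ids:
        return [1] * len(completion_ids)
    first_eos = completion_ids.index(eos_token_id)
    return [1 if i <= first_eos else 0 for i in range(len(completion_ids))]


def iter_row_slices(rows: List[List[int]], slice_size: int) -> Iterator[List[List[int]]]:
    """Yield consecutive chunks of at most slice_size rows."""
    if slice_size <= 0:
        raise ValueError("slice_size must be positive")
    for start in range(0, len(rows), slice_size):
        yield rows[start:start + slice_size]
```

Slicing bounds peak memory: only slice_size sequences are materialized on the device at once.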
- maxent_grpo.training.scoring.iter_batch_slices_trl(score_batch, runtime, eos_token_id)
Yield prompt+completion slices for TRL-style logprob computation.
- Parameters:
score_batch (ScoreBatch)
runtime (RuntimeHandles)
eos_token_id (int | None)
- Return type:
Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]]
- maxent_grpo.training.scoring.reference_from_model(score_batch, runtime, batching_cfg)
Run the frozen reference model to compute log-probs.
- Parameters:
score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.
runtime (RuntimeHandles) – Runtime handles exposing device and reference model.
batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.
- Returns:
Tuple of (ref_logp_sum, ref_token_counts), or None on failure.
- Return type:
tuple[Tensor, Tensor] | None
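The (ref_logp_sum, ref_token_counts) pair is conventionally computed by scoring each label against the previous position's next-token distribution (the causal shift). This sketch assumes the -100 ignore-index convention and takes precomputed per-position log-probs; in practice they come from a log-softmax over the reference model's logits. sequence_logp_and_count is a hypothetical name.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # assumed label value for prompt/padding positions


def sequence_logp_and_count(next_token_logp: List[List[float]], labels: List[int]) -> Tuple[float, int]:
    """Sum log p(label | prefix) over the scored tokens of one sequence.

    next_token_logp[t][v] is the log-prob assigned at position t to vocabulary
    item v as the NEXT token, so it is scored against labels[t + 1].
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]  # causal shift: position t predicts t + 1
        if target == IGNORE_INDEX:
            continue  # prompt or padding: not scored
        total += next_token_logp[t][target]
        count += 1
    return total, count
```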
- maxent_grpo.training.scoring.reference_from_model_trl(score_batch, runtime, batching_cfg, *, temperature=None)
Run the frozen reference model using TRL-style log-prob computations.
- Parameters:
score_batch (ScoreBatch)
runtime (RuntimeHandles)
batching_cfg (BatchingSettings)
temperature (float | None)
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None] | None
- maxent_grpo.training.scoring.reference_from_vllm_meta(flat_meta, total_sequences, device)
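Convert flattened vLLM log-prob metadata into ReferenceLogprobs.
- Parameters:
flat_meta – Flattened vLLM log-prob metadata.
total_sequences – Number of sequences the metadata is expected to cover.
device – Device where tensors should be materialized.
- Returns:
ReferenceLogprobs, or None when metadata is incomplete.
- Return type: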
ReferenceLogprobs | None
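The "None when metadata is incomplete" contract can be sketched as below. The layout of flat_meta is not documented here, so this assumes a hypothetical one-entry-per-sequence list of token log-probs (None for a missing entry); sums_from_flat_meta is likewise a hypothetical name.

```python
from typing import List, Optional, Sequence, Tuple


def sums_from_flat_meta(
    flat_meta: Sequence[Optional[Sequence[float]]], total_sequences: int
) -> Optional[Tuple[List[float], List[int]]]:
    """Hypothetical layout: one entry per sequence holding its token log-probs (or None)."""
    if len(flat_meta) != total_sequences:
        return None  # metadata does not cover every sequence
    sums, counts = [], []
    for entry in flat_meta:
        if entry is None:
            return None  # any missing entry makes the batch unusable
        sums.append(float(sum(entry)))
        counts.append(len(entry))
    return sums, counts
```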
- maxent_grpo.training.scoring.reference_stats_from_policy_logprobs(cur_logp_sum, tok_counts)
Build ReferenceLogprobs assuming reference == current policy (KL ≈ 0).
- Parameters:
cur_logp_sum (torch.Tensor) – Current policy log-prob sums per sequence.
tok_counts (torch.Tensor) – Token counts per sequence.
- Returns:
Reference stats derived directly from the current policy.
- Return type:
ReferenceLogprobs
- maxent_grpo.training.scoring.score_model_outputs(model, score_batch, batching_cfg, runtime, *, return_hidden=False, pooling='mean', return_entropy=False, entropy_mode='exact', return_token_logp=False)
Compute current model log-probs for the batch and optional pooled states.
- Parameters:
model (PreTrainedModel) – Current policy model used for scoring.
score_batch (ScoreBatch) – Prepared scoring batch.
batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.
runtime (RuntimeHandles) – Runtime handles providing device and accelerator state.
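return_hidden (bool) – When True, also return pooled hidden states.
pooling (str) – Pooling strategy applied to hidden states.
return_entropy (bool) – When True, also return per-sequence policy entropy sums.
entropy_mode (str) – Entropy computation mode.
return_token_logp (bool) – When True, also return per-token log-probs and the completion token mask.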
- Returns:
Tuple of (cur_logp_sum, pooled_hidden[, policy_entropy_sum][, token_logp, token_mask]), or None if the batch is empty.
- Return type:
tuple[Tensor, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None, Tensor | None, Tensor | None] | None
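Two of the optional outputs can be sketched in plain Python: masked mean pooling (the default pooling='mean') and an exact softmax entropy for one position. Reading entropy_mode='exact' as the full-vocabulary entropy, and summing per-token entropies into policy_entropy_sum, are assumptions; masked_mean_pool and exact_entropy are hypothetical names.

```python
import math
from typing import List


def masked_mean_pool(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Average hidden vectors over positions where mask == 1."""
    acc = [0.0] * len(hidden[0])
    n = 0
    for vec, m in zip(hidden, mask):
        if m:
            acc = [a + v for a, v in zip(acc, vec)]
            n += 1
    return [a / max(n, 1) for a in acc]


def exact_entropy(logits: List[float]) -> float:
    """Exact entropy -sum(p * log p) of softmax(logits), computed stably."""
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return -sum(math.exp(x - log_z) * (x - log_z) for x in logits)
```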
- maxent_grpo.training.scoring.selective_log_softmax(logits, index)
Memory-efficient log_softmax followed by a gather at index (TRL-style).
- Parameters:
logits (torch.Tensor)
index (torch.Tensor)
- Return type:
torch.Tensor
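The identity behind the memory saving: log_softmax(x)[i] = x[i] - logsumexp(x), so only the selected logit and one logsumexp per row are needed instead of a full log-softmax table. This plain-Python sketch (selective_log_softmax_sketch is a hypothetical name) shows the math only; TRL's torch version gathers the selected logits and computes logsumexp row by row.

```python
import math
from typing import List


def logsumexp(row: List[float]) -> float:
    """Numerically stable log(sum(exp(row)))."""
    peak = max(row)
    return peak + math.log(sum(math.exp(x - peak) for x in row))


def selective_log_softmax_sketch(logits: List[List[float]], index: List[int]) -> List[float]:
    """log_softmax(logits[t])[index[t]] for each position t."""
    return [row[idx] - logsumexp(row) for row, idx in zip(logits, index)]
```

In TRL, logits are divided by the sampling temperature before this step, which is the kind of scaling the temperature argument of the TRL-style reference paths refers to.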
- maxent_grpo.training.scoring.summarize_completion_lengths(ref_stats, max_completion_len)
Summarize completion lengths for metrics.
- Parameters:
ref_stats (ReferenceLogprobs) – Reference log-prob stats containing token counts.
max_completion_len (int) – Maximum completion length used for clipping stats.
- Returns:
Tuple of (completion_lengths, length_stats, total_tokens).
- Return type:
tuple[Tensor, LengthStats, float]
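A sketch of the kind of metrics summarize_completion_lengths reports. The fields of LengthStats are not listed here, so the dictionary keys below are hypothetical, as is counting a completion as clipped when its length reaches max_completion_len.

```python
from typing import Dict, List, Tuple


def summarize_lengths(
    token_counts: List[int], max_completion_len: int
) -> Tuple[List[float], Dict[str, float], float]:
    """Hypothetical: (lengths, stats, total_tokens) in the same order as the documented tuple."""
    lengths = [float(c) for c in token_counts]
    if not lengths:
        return lengths, {"mean": 0.0, "min": 0.0, "max": 0.0, "clipped_ratio": 0.0}, 0.0
    stats = {
        "mean": sum(lengths) / len(lengths),
        "min": min(lengths),
        "max": max(lengths),
        # fraction of completions that hit the generation budget (likely truncated)
        "clipped_ratio": sum(1 for x in lengths if x >= max_completion_len) / len(lengths),
    }
    return lengths, stats, float(sum(lengths))
```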
- maxent_grpo.training.scoring.token_counts_from_score_batch(score_batch, runtime, batching_cfg)
Compute per-sequence token counts from the score batch labels mask.
- Parameters:
score_batch (ScoreBatch) – Prepared scoring batch.
runtime (RuntimeHandles) – Runtime handles exposing device/accelerator.
batching_cfg (BatchingSettings) – Batching config controlling slice sizes.
- Returns:
1D tensor of token counts per sequence.
- Return type:
Tensor
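Counting tokens from the labels mask reduces to counting label positions that are not ignored. This sketch assumes the -100 ignore-index convention for prompt and padding positions; token_counts is a hypothetical name, and the real function returns a 1D tensor rather than a list.

```python
from typing import List

IGNORE_INDEX = -100  # assumed label value for non-completion positions


def token_counts(labels: List[List[int]]) -> List[int]:
    """Count scored (non-ignored) label positions per sequence."""
    return [sum(1 for tok in row if tok != IGNORE_INDEX) for row in labels]
```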