maxent_grpo.training.scoring

Compatibility facade for scoring helpers.

Implementation is split across focused modules:
  • scoring_batching (prompt/completion batching + slices)

  • scoring_logprob (policy/reference logprob kernels + sequence scores)

  • scoring_reference (reference-model/vLLM metadata paths)

class maxent_grpo.training.scoring.CompletionTensors(ids, mask)

Bases: object

Completion token IDs and masks.

Parameters:
  • ids (torch.Tensor)

  • mask (torch.Tensor)

ids: torch.Tensor
mask: torch.Tensor

maxent_grpo.training.scoring.build_score_batch(reward_comp, tokenizer, generation_cfg, batching_cfg)

Tokenize prompt+completion pairs and prepare masks/labels.

Parameters:
  • reward_comp (RewardComputation) – Reward computation payload containing prompts and completions.

  • tokenizer (PreTrainedTokenizer) – Tokenizer used to encode completions and determine padding.

  • generation_cfg (GenerationSettings) – Generation settings (max lengths, etc.).

  • batching_cfg (BatchingSettings) – Batching settings controlling scoring slice sizes.

Returns:

Prepared ScoreBatch or None when no sequences are available.

Return type:

ScoreBatch | None
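
The following is a minimal sketch of the kind of preparation described above, not the library's implementation. It assumes a toy whitespace tokenizer (toy_encode), right padding with a hypothetical PAD_ID, and the common convention of marking prompt and padding positions with -100 in the labels so that only completion tokens are scored. None of these details are confirmed by this page, and plain Python lists stand in for tensors.

```python
from typing import Dict, List, Optional

PAD_ID = 0            # hypothetical padding id
IGNORE_INDEX = -100   # assumed label value for positions that are not scored


def toy_encode(text: str, vocab: Dict[str, int]) -> List[int]:
    """Toy whitespace tokenizer standing in for a real tokenizer."""
    return [vocab.setdefault(tok, len(vocab) + 1) for tok in text.split()]


def sketch_score_batch(
    prompts: List[str], completions: List[str], max_len: int
) -> Optional[dict]:
    """Build padded input ids, attention masks, and completion-only labels."""
    if not prompts:
        return None  # mirrors the documented "None when no sequences are available"
    vocab: Dict[str, int] = {}
    rows = []
    for prompt, completion in zip(prompts, completions):
        p_ids = toy_encode(prompt, vocab)
        c_ids = toy_encode(completion, vocab)
        ids = (p_ids + c_ids)[:max_len]
        # Prompt positions are ignored; completion positions keep their token id.
        labels = ([IGNORE_INDEX] * len(p_ids) + c_ids)[:max_len]
        rows.append((ids, labels))
    width = max(len(ids) for ids, _ in rows)
    input_ids, attention_mask, all_labels = [], [], []
    for ids, labels in rows:
        pad = width - len(ids)
        input_ids.append(ids + [PAD_ID] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        all_labels.append(labels + [IGNORE_INDEX] * pad)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": all_labels,
    }


batch = sketch_score_batch(["what is 2 + 2 ?", "hi"], ["4", "hello there"], max_len=16)
```

In this sketch the second row ends up with three real tokens followed by four padding positions, and only its two completion tokens carry labels.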

maxent_grpo.training.scoring.build_sequence_scores(cur_logp_sum, ref_stats, pooled_hidden=None, *, behavior_logp_sum=None, policy_entropy_sum=None, token_logp=None, token_mask=None, old_token_logp=None)

Return SequenceScores built from current and reference log-probs.

Parameters:
  • cur_logp_sum (torch.Tensor) – Current policy log-prob sums per sequence.

  • ref_stats (ReferenceLogprobs) – Reference log-prob stats used for KL and weighting.

  • pooled_hidden (torch.Tensor | None) – Optional pooled hidden states for auxiliary losses.

  • behavior_logp_sum (torch.Tensor | None) – Optional behavior-policy log-probs for off-policy scoring.

  • policy_entropy_sum (torch.Tensor | None)

  • token_logp (torch.Tensor | None)

  • token_mask (torch.Tensor | None)

  • old_token_logp (torch.Tensor | None)

Returns:

SequenceScores dataclass with normalized log-probs and KL terms.

Return type:

SequenceScores
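
To make "normalized log-probs and KL terms" concrete, here is a hedged sketch using plain Python lists in place of tensors. It assumes that normalization means dividing each per-sequence log-prob sum by its token count, and that the KL-style term is the difference between the normalized current and reference log-probs. SketchScores and sketch_sequence_scores are hypothetical names; the real SequenceScores fields may differ.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SketchScores:
    """Hypothetical stand-in for SequenceScores."""
    cur_logp_mean: List[float]
    ref_logp_mean: List[float]
    log_ratio: List[float]


def sketch_sequence_scores(
    cur_logp_sum: List[float],
    ref_logp_sum: List[float],
    tok_counts: List[int],
) -> SketchScores:
    # Clamp counts to at least 1 so empty completions do not divide by zero.
    counts = [max(int(n), 1) for n in tok_counts]
    cur_mean = [s / n for s, n in zip(cur_logp_sum, counts)]
    ref_mean = [s / n for s, n in zip(ref_logp_sum, counts)]
    # Per-sequence log-ratio; zero wherever the two models agree.
    log_ratio = [c - r for c, r in zip(cur_mean, ref_mean)]
    return SketchScores(cur_mean, ref_mean, log_ratio)


scores = sketch_sequence_scores([-6.0, -2.0], [-8.0, -2.0], [4, 2])
```

The second sequence has identical current and reference sums, so its log-ratio is zero, which is the situation reference_stats_from_policy_logprobs describes as KL ~= 0.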

maxent_grpo.training.scoring.finalize_reference_stats(ref_logp_sum, ref_tok_counts, *, ref_token_logp=None, ref_token_mask=None)

Build a ReferenceLogprobs object and derived scalars.

Parameters:
  • ref_logp_sum (torch.Tensor) – Per-sequence sum of reference log-probabilities.

  • ref_tok_counts (torch.Tensor) – Per-sequence token counts.

  • ref_token_logp (torch.Tensor | None) – Optional per-token reference log-probs (completion-only).

  • ref_token_mask (torch.Tensor | None) – Optional per-token completion mask aligned to ref_token_logp.

Returns:

Normalized reference stats and summary scalars.

Return type:

ReferenceLogprobs

Raises:

ValueError – If log-prob tensors cannot be safely normalized.
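
A rough sketch of the validate-then-normalize pattern this function describes, using plain lists instead of tensors. The exact conditions under which the real function raises ValueError are not listed on this page; the checks below (mismatched lengths, empty input, non-finite sums, non-positive counts) are assumptions, and sketch_finalize with its dictionary keys is a hypothetical stand-in for ReferenceLogprobs.

```python
import math
from typing import Dict, List


def sketch_finalize(ref_logp_sum: List[float], ref_tok_counts: List[int]) -> Dict[str, object]:
    if len(ref_logp_sum) != len(ref_tok_counts):
        raise ValueError("log-prob sums and token counts must have equal length")
    if not ref_logp_sum:
        raise ValueError("no sequences to normalize")
    means = []
    for logp, count in zip(ref_logp_sum, ref_tok_counts):
        if not math.isfinite(logp):
            raise ValueError("non-finite reference log-prob")
        if count <= 0:
            raise ValueError("token count must be positive")
        means.append(logp / count)  # per-token mean log-prob
    return {
        "ref_logp_sum": list(ref_logp_sum),
        "ref_tok_counts": list(ref_tok_counts),
        "ref_logp_mean": means,
        "avg_completion_tokens": sum(ref_tok_counts) / len(ref_tok_counts),
    }


stats = sketch_finalize([-4.0, -9.0], [2, 3])
```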

maxent_grpo.training.scoring.gather_reference_logprobs(score_batch, runtime, batching_cfg, *, trl_reference_scoring=False, temperature=None)

Compute log-probabilities by running the frozen reference model.

This function handles distributed preflight checks to avoid ZeRO hangs and aggregates reference statistics into a ReferenceLogprobs object.

Parameters:
  • score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.

  • runtime (RuntimeHandles) – Runtime handles exposing device, accelerator, and models.

  • batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.

  • trl_reference_scoring (bool) – When True, use TRL/open-r1 reference scoring logic.

  • temperature (float | None) – Optional temperature for TRL-style logprob scaling.

Returns:

ReferenceLogprobs or None when reference scoring fails.

Return type:

ReferenceLogprobs | None

maxent_grpo.training.scoring.iter_batch_slices(score_batch, device, *, eos_token_id=None, apply_eos_mask=False)

Yield scoring slices for a batch, assembling prompt tensors on demand.

Parameters:
  • score_batch (ScoreBatch) – Prepared prompt/completion tensors and metadata.

  • device (torch.device) – Device where tensors should be materialized.

  • eos_token_id (int | None) – Optional EOS token id for TRL-style completion masking.

  • apply_eos_mask (bool) – When True, apply EOS-aware completion masks.

Yields:

Tuples of (input_ids, attention_mask, labels) per slice.

Return type:

Iterator[tuple[Tensor, Tensor, Tensor]]
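
The slicing behaviour can be illustrated with a small generator. This is a sketch under assumptions: lists of rows stand in for tensors, the slice size is passed explicitly (the real function presumably reads it from the score batch), and the EOS-aware mask keeps tokens up to and including the first EOS, which is the usual TRL-style convention. sketch_iter_slices and eos_aware_mask are hypothetical names.

```python
from typing import Iterator, List, Tuple

Rows = List[List[int]]


def sketch_iter_slices(
    input_ids: Rows, attention_mask: Rows, labels: Rows, slice_size: int
) -> Iterator[Tuple[Rows, Rows, Rows]]:
    """Yield (input_ids, attention_mask, labels) chunks of at most slice_size rows."""
    if slice_size <= 0:
        raise ValueError("slice_size must be positive")
    for start in range(0, len(input_ids), slice_size):
        end = start + slice_size
        yield input_ids[start:end], attention_mask[start:end], labels[start:end]


def eos_aware_mask(completion_ids: List[int], eos_token_id: int) -> List[int]:
    """Mask that keeps tokens through the first EOS and drops everything after it."""
    mask, seen_eos = [], False
    for token in completion_ids:
        mask.append(0 if seen_eos else 1)
        if token == eos_token_id:
            seen_eos = True
    return mask


rows = [[i, i + 1] for i in range(5)]
slices = list(sketch_iter_slices(rows, rows, rows, slice_size=2))
mask = eos_aware_mask([5, 6, 2, 7, 2], eos_token_id=2)
```

Yielding slices lazily keeps only one chunk materialized at a time, which is the point of assembling tensors on demand.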

maxent_grpo.training.scoring.iter_batch_slices_trl(score_batch, runtime, eos_token_id)

Yield prompt+completion slices for TRL-style logprob computation.

Parameters:
  • score_batch

  • runtime

  • eos_token_id

Return type:

Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]]

maxent_grpo.training.scoring.reference_from_model(score_batch, runtime, batching_cfg)

Run the frozen reference model to compute log-probs.

Parameters:
  • score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.

  • runtime (RuntimeHandles) – Runtime handles exposing device and reference model.

  • batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.

Returns:

Tuple of (ref_logp_sum, ref_token_counts) or None on failure.

Return type:

tuple[Tensor, Tensor] | None

maxent_grpo.training.scoring.reference_from_model_trl(score_batch, runtime, batching_cfg, *, temperature=None)

Run the frozen reference model using TRL-style log-prob computations.

Parameters:
  • score_batch

  • runtime

  • batching_cfg

  • temperature

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None] | None

maxent_grpo.training.scoring.reference_from_vllm_meta(flat_meta, total_sequences, device)

Convert flattened vLLM log-prob metadata into ReferenceLogprobs.

Parameters:
  • flat_meta (Sequence[object | None]) – Flat list of vLLM metadata entries (one per completion).

  • total_sequences (int) – Expected number of sequences in the batch.

  • device (Any) – Device for the resulting tensors.

Returns:

ReferenceLogprobs or None when metadata is incomplete.

Return type:

ReferenceLogprobs | None

maxent_grpo.training.scoring.reference_stats_from_policy_logprobs(cur_logp_sum, tok_counts)

Build ReferenceLogprobs assuming reference == current policy (KL ~= 0).

Parameters:
  • cur_logp_sum (torch.Tensor) – Current policy log-prob sums per sequence.

  • tok_counts (torch.Tensor) – Token counts per sequence.

Returns:

Reference stats derived directly from the current policy.

Return type:

ReferenceLogprobs

maxent_grpo.training.scoring.score_model_outputs(model, score_batch, batching_cfg, runtime, *, return_hidden=False, pooling='mean', return_entropy=False, entropy_mode='exact', return_token_logp=False)

Compute current model log-probs for the batch and optional pooled states.

Parameters:
  • model (PreTrainedModel) – Current policy model used for scoring.

  • score_batch (ScoreBatch) – Prepared scoring batch.

  • batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.

  • runtime (RuntimeHandles) – Runtime handles providing device and accelerator state.

  • return_hidden (bool) – When True, also return pooled hidden states.

  • pooling (str) – Pooling strategy applied to hidden states.

  • return_entropy (bool)

  • entropy_mode (str)

  • return_token_logp (bool)

Returns:

Tuple of (cur_logp_sum, pooled_hidden[, policy_entropy_sum][, token_logp, token_mask]) or None if empty.

Return type:

tuple[Tensor, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None, Tensor | None, Tensor | None] | None
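
Because the returned tuple has 2, 3, or 5 elements depending on the flags, callers may find it convenient to normalize it into one fixed shape. The helper below is a hypothetical sketch based only on the return type annotation above; ScoredOutputs and unpack_scores are not part of the module.

```python
from typing import Any, NamedTuple, Optional, Tuple


class ScoredOutputs(NamedTuple):
    """Hypothetical fixed-shape view of the variable-length return tuple."""
    cur_logp_sum: Any
    pooled_hidden: Optional[Any] = None
    policy_entropy_sum: Optional[Any] = None
    token_logp: Optional[Any] = None
    token_mask: Optional[Any] = None


def unpack_scores(result: Optional[Tuple[Any, ...]]) -> Optional[ScoredOutputs]:
    if result is None:
        return None  # documented case: the batch was empty
    if len(result) not in (2, 3, 5):
        raise ValueError(f"unexpected result length: {len(result)}")
    # Missing trailing entries fall back to the None defaults.
    return ScoredOutputs(*result)


basic = unpack_scores(("logp", None))
with_tokens = unpack_scores(("logp", None, "entropy", "tok_logp", "tok_mask"))
```

Downstream code can then read fields by name instead of branching on tuple length.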

maxent_grpo.training.scoring.selective_log_softmax(logits, index)

Memory-efficient log_softmax + gather (TRL-style).

Parameters:
  • logits (torch.Tensor)

  • index (torch.Tensor)

Return type:

torch.Tensor
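
The memory saving in this kind of kernel typically comes from the identity log_softmax(x)[i] = x[i] - logsumexp(x): only the selected logit and one scalar per row are needed, rather than the full log-softmax matrix. The sketch below demonstrates that identity with plain Python lists; it is an illustration of the idea, not the module's code, and whether the actual function uses exactly this formulation is an assumption.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Reference implementation that materializes the whole row."""
    peak = max(row)
    lse = peak + math.log(sum(math.exp(x - peak) for x in row))
    return [x - lse for x in row]


def selective_log_softmax_sketch(logits: List[List[float]], index: List[int]) -> List[float]:
    """Log-prob of the selected token per row, without building the full row."""
    out = []
    for row, i in zip(logits, index):
        peak = max(row)  # subtract the max for numerical stability
        lse = peak + math.log(sum(math.exp(x - peak) for x in row))
        out.append(row[i] - lse)
    return out


logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
index = [0, 2]
selected = selective_log_softmax_sketch(logits, index)
```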

maxent_grpo.training.scoring.summarize_completion_lengths(ref_stats, max_completion_len)

Summarize completion lengths for metrics.

Parameters:
  • ref_stats (ReferenceLogprobs) – Reference log-prob stats containing token counts.

  • max_completion_len (int) – Maximum completion length used for clipping stats.

Returns:

Tuple of (completion_lengths, length_stats, total_tokens).

Return type:

tuple[Tensor, LengthStats, float]
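
As an illustration of the three-part return value, here is a sketch that summarizes a list of token counts. The fields of the real LengthStats are not documented on this page, so SketchLengthStats and its fields (min, mean, max, and the fraction of completions that reached the length cap) are assumptions chosen for the example.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SketchLengthStats:
    """Hypothetical stand-in for LengthStats."""
    min_length: float
    mean_length: float
    max_length: float
    clipped_ratio: float


def sketch_summarize_lengths(
    tok_counts: List[int], max_completion_len: int
) -> Tuple[List[float], SketchLengthStats, float]:
    lengths = [float(n) for n in tok_counts]
    if not lengths:
        return lengths, SketchLengthStats(0.0, 0.0, 0.0, 0.0), 0.0
    # A completion is treated as clipped when it reached the length cap.
    clipped = sum(1 for n in lengths if n >= max_completion_len)
    stats = SketchLengthStats(
        min_length=min(lengths),
        mean_length=sum(lengths) / len(lengths),
        max_length=max(lengths),
        clipped_ratio=clipped / len(lengths),
    )
    return lengths, stats, sum(lengths)


lengths, length_stats, total_tokens = sketch_summarize_lengths([3, 8, 5], max_completion_len=8)
```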

maxent_grpo.training.scoring.token_counts_from_score_batch(score_batch, runtime, batching_cfg)

Compute per-sequence token counts from the score batch labels mask.

Parameters:
  • score_batch (ScoreBatch) – Prepared scoring batch.

  • runtime (RuntimeHandles) – Runtime handles exposing device/accelerator.

  • batching_cfg (BatchingSettings) – Batching config controlling slice sizes.

Returns:

1D tensor of token counts per sequence.

Return type:

Tensor
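
Counting tokens from a labels mask can be sketched as follows, assuming that unscored positions are marked with -100 (a common convention, not confirmed by this page) and using nested lists in place of tensors. sketch_token_counts is a hypothetical name.

```python
from typing import List

IGNORE_INDEX = -100  # assumed marker for positions excluded from scoring


def sketch_token_counts(labels: List[List[int]]) -> List[int]:
    """Return the number of scored (non-ignored) positions in each row."""
    return [sum(1 for tok in row if tok != IGNORE_INDEX) for row in labels]


counts = sketch_token_counts([
    [IGNORE_INDEX, 5, 6],
    [IGNORE_INDEX, IGNORE_INDEX, 7],
    [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX],
])
```

A row with no scored positions yields a count of zero, which callers would need to guard against before dividing by the counts.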

maxent_grpo.training.scoring.vllm_meta_has_logprobs(flat_meta, total_sequences=None)

Return True when vLLM metadata includes per-completion logprob info.

Parameters:
  • flat_meta (Sequence[object | None] | None) – Flat list of vLLM metadata entries.

  • total_sequences (int | None) – Optional expected length used for sanity checks.

Returns:

True when logprob metadata appears complete.

Return type:

bool
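
A completeness check of this kind might look like the sketch below. The layout of the vLLM metadata entries is not described on this page, so the example assumes each entry is a dict with a hypothetical "logprob_sum" key; sketch_meta_has_logprobs is likewise a hypothetical name.

```python
from typing import Optional, Sequence


def sketch_meta_has_logprobs(
    flat_meta: Optional[Sequence[Optional[dict]]],
    total_sequences: Optional[int] = None,
) -> bool:
    if not flat_meta:
        return False  # None or empty metadata cannot be complete
    if total_sequences is not None and len(flat_meta) != total_sequences:
        return False  # sanity check against the expected batch size
    # Every completion needs an entry that carries a log-prob value.
    return all(
        isinstance(entry, dict) and entry.get("logprob_sum") is not None
        for entry in flat_meta
    )


complete = sketch_meta_has_logprobs([{"logprob_sum": -3.2}, {"logprob_sum": -1.1}], 2)
partial = sketch_meta_has_logprobs([{"logprob_sum": -3.2}, None], 2)
```

When a check like this fails, a caller would presumably fall back to running the reference model instead of reusing the generation-time metadata.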