maxent_grpo.training.scoring_logprob
Model logprob computation and sequence-score assembly helpers.
Functions
Best-effort conversion of
Compute summed log-probabilities per sequence with optional chunking/pooled states/entropy.
Return
Return a context manager that gathers FSDP parameters when available.
TRL-style per-token log-probabilities for completion tokens.
Return
Compute current model log-probs for the batch and optional pooled states.
Memory-efficient log_softmax + gather (TRL-style).
- maxent_grpo.training.scoring_logprob.selective_log_softmax(logits, index)
Memory-efficient log_softmax + gather (TRL-style).
- Parameters:
logits (torch.Tensor)
index (torch.Tensor)
- Return type:
torch.Tensor
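The arithmetic behind a log_softmax + gather shortcut can be sketched in plain Python. This is only an illustration of the identity log_softmax(x)[i] = x[i] - logsumexp(x), not the module's tensor implementation; the name selective_log_softmax_rows and the list-based inputs are hypothetical.

```python
import math


def selective_log_softmax_rows(logits, index):
    """Return log_softmax(logits[i])[index[i]] for each row i.

    Hypothetical pure-Python stand-in, for illustration only.
    """
    out = []
    for row, idx in zip(logits, index):
        # Stable logsumexp: subtract the row maximum before exponentiating.
        row_max = max(row)
        lse = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        # Only the selected logit is normalized, so the full
        # log_softmax row is never materialized.
        out.append(row[idx] - lse)
    return out


logits = [[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]]
index = [0, 2]
selected = selective_log_softmax_rows(logits, index)
```

The saving comes from keeping one scalar per row instead of a full vocabulary-sized row of log-probabilities; a tensor version would presumably apply the same identity in batch.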
- maxent_grpo.training.scoring_logprob.score_model_outputs(model, score_batch, batching_cfg, runtime, *, return_hidden=False, pooling='mean', return_entropy=False, entropy_mode='exact', return_token_logp=False)
Compute current model log-probs for the batch and optional pooled states.
- Parameters:
model (PreTrainedModel) – Current policy model used for scoring.
score_batch (ScoreBatch) – Prepared scoring batch.
batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.
runtime (RuntimeHandles) – Runtime handles providing device and accelerator state.
return_hidden (bool) – When True, also return pooled hidden states.
pooling (str) – Pooling strategy applied to hidden states.
return_entropy (bool)
entropy_mode (str)
return_token_logp (bool)
- Returns:
Tuple of (cur_logp_sum, pooled_hidden[, policy_entropy_sum][, token_logp, token_mask]) or None if empty.
- Return type:
tuple[Tensor, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None, Tensor | None, Tensor | None] | None
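Because the result of score_model_outputs varies in length, callers may find it convenient to normalize it into one fixed shape. The sketch below is a hypothetical helper, not part of the module: ScoreOutputs and unpack_score_outputs are illustrative names, and it assumes, from the return type annotation alone, that only tuples of length 2, 3, or 5 (or None) occur.

```python
from typing import Any, NamedTuple, Optional, Sequence


class ScoreOutputs(NamedTuple):
    """Hypothetical fixed-shape view of the documented return tuple."""

    cur_logp_sum: Any
    pooled_hidden: Optional[Any] = None
    policy_entropy_sum: Optional[Any] = None
    token_logp: Optional[Any] = None
    token_mask: Optional[Any] = None


def unpack_score_outputs(result: Optional[Sequence[Any]]) -> Optional[ScoreOutputs]:
    """Pad a 2-, 3-, or 5-element result with None; pass None through."""
    if result is None:
        # Documented case: the batch was empty.
        return None
    if len(result) not in (2, 3, 5):
        # Assumption: other lengths are not listed in the annotation.
        raise ValueError(f"unexpected result length: {len(result)}")
    # Missing trailing fields fall back to the NamedTuple defaults (None).
    return ScoreOutputs(*result)
```

Downstream code can then read fields such as policy_entropy_sum by name and check them for None, rather than branching on which flags were passed.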
- maxent_grpo.training.scoring_logprob.build_sequence_scores(cur_logp_sum, ref_stats, pooled_hidden=None, *, behavior_logp_sum=None, policy_entropy_sum=None, token_logp=None, token_mask=None, old_token_logp=None)
Return SequenceScores built from current and reference log-probs.
- Parameters:
cur_logp_sum (torch.Tensor) – Current policy log-prob sums per sequence.
ref_stats (ReferenceLogprobs) – Reference log-prob stats used for KL and weighting.
pooled_hidden (torch.Tensor | None) – Optional pooled hidden states for auxiliary losses.
behavior_logp_sum (torch.Tensor | None) – Optional behavior-policy log-probs for off-policy scoring.
policy_entropy_sum (torch.Tensor | None)
token_logp (torch.Tensor | None)
token_mask (torch.Tensor | None)
old_token_logp (torch.Tensor | None)
- Returns:
SequenceScores dataclass with normalized log-probs and KL terms.
- Return type: