maxent_grpo.training.scoring_logprob

Model logprob computation and sequence-score assembly helpers.

Functions

_as_torch_tensor(torch_mod, value, *, ...)

Best-effort conversion of value into a torch tensor on device.

_chunked_sequence_logprobs(model, *, ...[, ...])

Compute summed log-probabilities per sequence, with optional chunking, pooled hidden states, and entropy.

_match_tensor_length(torch_mod, tensor, ...)

Return tensor reshaped/padded to target_len elements.

_summon_fsdp_full_param_context(model)

Return a context manager that gathers FSDP parameters when available.

_trl_get_per_token_logps(model, input_ids, ...)

TRL-style per-token log-probabilities for completion tokens.

build_sequence_scores(cur_logp_sum, ref_stats)

Return SequenceScores built from current and reference log-probs.

score_model_outputs(model, score_batch, ...)

Compute current model log-probs for the batch and optional pooled states.

selective_log_softmax(logits, index)

Memory-efficient log_softmax + gather (TRL-style).
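The private helpers `_match_tensor_length` and `_summon_fsdp_full_param_context` each get only a one-line description. Below is a minimal sketch of behavior consistent with those descriptions. It is not the module's code. The names `match_tensor_length` and `summon_full_params_if_fsdp`, the flatten-then-truncate-or-pad policy, and the `pad_value` parameter are all hypothetical. For the FSDP case, the sketch returns `FullyShardedDataParallel.summon_full_params(model)` when FSDP is importable and the model is wrapped in it, and returns a no-op `contextlib.nullcontext()` otherwise.

```python
import contextlib

import torch
import torch.distributed as dist


def match_tensor_length(tensor: torch.Tensor, target_len: int, pad_value: float = 0.0) -> torch.Tensor:
    """Hypothetical stand-in for _match_tensor_length: flatten, then truncate or pad."""
    flat = tensor.reshape(-1)
    if flat.numel() >= target_len:
        # Assumption: surplus elements are dropped from the end.
        return flat[:target_len]
    # Assumption: missing elements are appended as pad_value.
    pad = flat.new_full((target_len - flat.numel(),), pad_value)
    return torch.cat([flat, pad])


def summon_full_params_if_fsdp(model: torch.nn.Module):
    """Hypothetical stand-in for _summon_fsdp_full_param_context."""
    if not dist.is_available():
        # torch was built without distributed support, so FSDP cannot be in use.
        return contextlib.nullcontext()
    try:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    except ImportError:
        # Older torch releases ship without FSDP.
        return contextlib.nullcontext()
    if isinstance(model, FSDP):
        # Gather the full (unsharded) parameters for the duration of the with-block.
        return FSDP.summon_full_params(model)
    # Plain modules already hold full parameters: nothing to gather.
    return contextlib.nullcontext()
```

Returning `contextlib.nullcontext()` in the fallback cases lets callers use a single `with` statement whether or not the model is sharded.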

maxent_grpo.training.scoring_logprob.selective_log_softmax(logits, index)

Memory-efficient log_softmax + gather (TRL-style).

Parameters:
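  • logits (torch.Tensor) – Unnormalized scores; the log-softmax is taken over the last (vocabulary) dimension.

  • index (torch.Tensor) – Token ids to gather from the last dimension of logits.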

Return type:

torch.Tensor
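Calling `log_softmax` on a full `[batch, seq, vocab]` tensor and then gathering materializes a second vocabulary-sized tensor. The following is a sketch of the TRL-style alternative, using the hypothetical name `selective_log_softmax_sketch`. For float32/float64 inputs, it gathers the selected logits first and subtracts a per-row `logsumexp`, using the identity log p(i) = logit_i - logsumexp(logits). For lower-precision dtypes, it falls back to a row-by-row `log_softmax`. This mirrors the approach of TRL's `selective_log_softmax`. Whether this module's implementation matches it line for line is an assumption.

```python
import torch
import torch.nn.functional as F


def selective_log_softmax_sketch(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Return log_softmax(logits) gathered at `index` along the last dimension.

    logits: [..., V]; index: [...] integer ids. Output has the shape of `index`.
    """
    if logits.dtype in (torch.float32, torch.float64):
        # Gather the selected logits first, then normalize: logit_i - logsumexp.
        selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        # Looping over the leading dim never keeps a full log-softmax tensor around.
        lse = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
        return selected - lse
    # Lower-precision dtypes: compute log_softmax one row at a time for stability.
    out = []
    for row_logits, row_index in zip(logits, index):
        row_logp = F.log_softmax(row_logits, dim=-1)
        out.append(row_logp.gather(-1, row_index.unsqueeze(-1)).squeeze(-1))
    return torch.stack(out)
```

The float path trades a Python loop over the leading dimension for lower peak memory, which matters most when the vocabulary is large.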

maxent_grpo.training.scoring_logprob.score_model_outputs(model, score_batch, batching_cfg, runtime, *, return_hidden=False, pooling='mean', return_entropy=False, entropy_mode='exact', return_token_logp=False)

Compute current model log-probs for the batch and optional pooled states.

Parameters:
  • model (PreTrainedModel) – Current policy model used for scoring.

  • score_batch (ScoreBatch) – Prepared scoring batch.

  • batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.

  • runtime (RuntimeHandles) – Runtime handles providing device and accelerator state.

  • return_hidden (bool) – When True, also return pooled hidden states.

  • pooling (str) – Pooling strategy applied to hidden states.

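  • return_entropy (bool) – When True, also return per-sequence policy entropy sums.

  • entropy_mode (str) – Entropy computation mode used when return_entropy is True.

  • return_token_logp (bool) – When True, also return per-token log-probs and the matching token mask.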

Returns:

Tuple of (cur_logp_sum, pooled_hidden[, policy_entropy_sum][, token_logp, token_mask]), or None if the batch is empty.

Return type:

tuple[Tensor, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None, Tensor | None, Tensor | None] | None
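The following is a minimal sketch of the kind of computation `score_model_outputs` and `_chunked_sequence_logprobs` describe: summed per-sequence log-probs computed over batch chunks, plus optional entropy and mean-pooled hidden states. Everything here is an assumption made for illustration:

- The `forward(input_ids) -> (logits, hidden)` interface is assumed.
- The function name `chunked_sequence_scores` and the `completion_mask` argument are hypothetical.
- "exact" entropy is taken to mean full-vocabulary entropy at each scored position.
- Mean pooling runs over completion positions only.

The causal shift (logits at position t score the token at t + 1) is the standard convention for decoder-only language models.

```python
import torch
import torch.nn.functional as F


def chunked_sequence_scores(forward, input_ids, completion_mask, chunk_size=2):
    """Hypothetical sketch of chunked per-sequence scoring.

    forward(input_ids) -> (logits [B, T, V], hidden [B, T, H]) is an assumed interface.
    completion_mask [B, T] marks tokens whose log-probs are scored.
    Returns (logp_sum [B], entropy_sum [B], pooled_hidden [B, H]).
    """
    logp_sums, entropy_sums, pooled = [], [], []
    for start in range(0, input_ids.size(0), chunk_size):
        ids = input_ids[start:start + chunk_size]
        mask = completion_mask[start:start + chunk_size].to(torch.float32)
        logits, hidden = forward(ids)
        # Causal shift: logits at position t predict the token at position t + 1.
        logp = F.log_softmax(logits[:, :-1].float(), dim=-1)
        targets = ids[:, 1:]
        tok_mask = mask[:, 1:]
        tok_logp = logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        logp_sums.append((tok_logp * tok_mask).sum(dim=1))
        # Full-vocabulary entropy at each scored position (assumed meaning of "exact").
        tok_entropy = -(logp.exp() * logp).sum(dim=-1)
        entropy_sums.append((tok_entropy * tok_mask).sum(dim=1))
        # Mean pooling of hidden states over completion positions (assumption).
        denom = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
        pooled.append((hidden * mask.unsqueeze(-1)).sum(dim=1) / denom)
    return torch.cat(logp_sums), torch.cat(entropy_sums), torch.cat(pooled)
```

Chunking bounds the peak size of the `[chunk, T, V]` log-softmax tensor. The results should not depend on `chunk_size`.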

maxent_grpo.training.scoring_logprob.build_sequence_scores(cur_logp_sum, ref_stats, pooled_hidden=None, *, behavior_logp_sum=None, policy_entropy_sum=None, token_logp=None, token_mask=None, old_token_logp=None)

Return SequenceScores built from current and reference log-probs.

Parameters:
  • cur_logp_sum (torch.Tensor) – Current policy log-prob sums per sequence.

  • ref_stats (ReferenceLogprobs) – Reference log-prob stats used for KL and weighting.

  • pooled_hidden (torch.Tensor | None) – Optional pooled hidden states for auxiliary losses.

  • behavior_logp_sum (torch.Tensor | None) – Optional behavior-policy log-probs for off-policy scoring.

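  • policy_entropy_sum (torch.Tensor | None) – Optional per-sequence policy entropy sums.

  • token_logp (torch.Tensor | None) – Optional per-token log-probs under the current policy.

  • token_mask (torch.Tensor | None) – Optional mask marking which token positions are valid.

  • old_token_logp (torch.Tensor | None) – Optional per-token log-probs from the previous (old) policy.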

Returns:

SequenceScores dataclass with normalized log-probs and KL terms.

Return type:

SequenceScores
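The return description mentions normalized log-probs and KL terms without defining them. Below is a hedged sketch of one common way to assemble such scores. The real fields of `SequenceScores` and the layout of `ReferenceLogprobs` are not documented on this page, so the following are all hypothetical:

- the `SequenceScoresSketch` container;
- the token-level reference inputs;
- length normalization as the "normalization";
- the k3 estimator of KL(π_θ ‖ π_ref), which is r − 1 − log r with log r = log π_ref − log π_θ, as the "KL term".

The optional behavior-policy argument yields a per-sequence importance ratio π_θ / π_behavior.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SequenceScoresSketch:
    """Hypothetical container; SequenceScores' real fields are not listed on this page."""
    cur_logp_sum: torch.Tensor  # [B] summed current-policy log-probs
    cur_logp_mean: torch.Tensor  # [B] length-normalized current log-probs
    kl_per_token: torch.Tensor  # [B] masked mean of the per-token k3 KL estimate
    importance_ratio: Optional[torch.Tensor] = None  # [B] pi_theta / pi_behavior


def build_sequence_scores_sketch(cur_token_logp, ref_token_logp, token_mask, behavior_logp_sum=None):
    """Assemble per-sequence scores from [B, T] token log-probs and a [B, T] mask."""
    mask = token_mask.to(cur_token_logp.dtype)
    counts = mask.sum(dim=1).clamp(min=1.0)
    cur_sum = (cur_token_logp * mask).sum(dim=1)
    # k3 estimator of KL(pi_theta || pi_ref): r - 1 - log r, with log r = ref - cur.
    log_r = ref_token_logp - cur_token_logp
    k3 = torch.exp(log_r) - 1.0 - log_r
    kl = (k3 * mask).sum(dim=1) / counts
    ratio = None
    if behavior_logp_sum is not None:
        # Off-policy importance weight computed in log space, then exponentiated.
        ratio = torch.exp(cur_sum - behavior_logp_sum)
    return SequenceScoresSketch(cur_sum, cur_sum / counts, kl, ratio)
```

The k3 estimator is non-negative for every token, which makes it a stable per-token penalty. It is zero exactly when the current and reference log-probs agree.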