maxent_grpo.training.scoring_reference

Reference-logprob and vLLM metadata scoring helpers.

Functions

_coerce_int_optional(value)

Return int(value) when possible, otherwise None.

_coerce_logprob_value(value)

Best-effort conversion of a token logprob payload into a float.

_meta_field(entry, *names)

Return the first matching field from a metadata entry.

_sum_token_logprobs(token_logprobs)

Return the sum of per-token logprobs when the payload is parseable.

finalize_reference_stats(ref_logp_sum, ...)

Build a ReferenceLogprobs object and derived scalars.

gather_reference_logprobs(score_batch, ...)

Compute log-probabilities by running the frozen reference model.

reference_from_model(score_batch, runtime, ...)

Run the frozen reference model to compute log-probs.

reference_from_model_trl(score_batch, ...[, ...])

Run the frozen reference model using TRL-style log-prob computations.

reference_from_vllm_meta(flat_meta, ...)

Convert flattened vLLM log-prob metadata into ReferenceLogprobs.

reference_stats_from_policy_logprobs(...)

Build ReferenceLogprobs assuming reference == current policy (KL ~= 0).

vllm_meta_has_logprobs(flat_meta[, ...])

Return True when vLLM metadata includes per-completion logprob info.
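The four private helpers in the summary above are documented only by one-line descriptions. The following is a minimal pure-Python sketch of how such helpers could behave. The accepted payload shapes are assumptions: a bare number, an object with a `.logprob` attribute (like vLLM's `Logprob`), or a dict with a "logprob" key. The underscore-free names are hypothetical stand-ins, not the module's code.

```python
import math
from collections.abc import Mapping
from typing import Any, Optional


def coerce_int_optional(value: Any) -> Optional[int]:
    """Return int(value) when possible, otherwise None."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return None


def coerce_logprob_value(value: Any) -> Optional[float]:
    """Best-effort float conversion of one token logprob payload (assumed shapes)."""
    if isinstance(value, Mapping) and "logprob" in value:
        value = value["logprob"]
    elif hasattr(value, "logprob"):  # e.g. an object like vLLM's Logprob
        value = value.logprob
    try:
        out = float(value)
    except (TypeError, ValueError):
        return None
    # Assumption: non-finite values (NaN, inf) are treated as unparseable.
    return out if math.isfinite(out) else None


def meta_field(entry: Any, *names: str) -> Any:
    """Return the first non-None field in `names` from a dict or object entry."""
    for name in names:
        if isinstance(entry, Mapping):
            if entry.get(name) is not None:
                return entry[name]
        else:
            attr = getattr(entry, name, None)
            if attr is not None:
                return attr
    return None


def sum_token_logprobs(token_logprobs: Any) -> Optional[float]:
    """Sum per-token logprobs; return None if the payload or any entry is unparseable."""
    if token_logprobs is None or isinstance(token_logprobs, (str, bytes)):
        return None
    try:
        items = list(token_logprobs)
    except TypeError:
        return None
    total = 0.0
    for item in items:
        val = coerce_logprob_value(item)
        if val is None:
            return None
        total += val
    return total
```

Returning None rather than raising lets a caller treat "metadata unusable" as a signal to fall back to another scoring path.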

maxent_grpo.training.scoring_reference.reference_from_model_trl(score_batch, runtime, batching_cfg, *, temperature=None)

Run the frozen reference model using TRL-style log-prob computations.

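Parameters:
  • score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.

  • runtime (RuntimeHandles) – Runtime handles exposing device and reference model.

  • batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.

  • temperature (float | None) – Optional temperature for TRL-style logprob scaling.
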
Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None] | None
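"TRL-style" log-prob computation commonly means: divide logits by the sampling temperature, take a log-softmax over the vocabulary, and read off the log-prob of each completion token using the causal shift (logits at position t score the token at t + 1). The following is a pure-Python sketch of that recipe for a single sequence. It is not TRL's or this module's code, and `completion_token_logprobs` and `prompt_len` are hypothetical names.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    # Numerically stable: subtract the max before exponentiating.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def completion_token_logprobs(
    logits: Sequence[Sequence[float]],
    input_ids: Sequence[int],
    prompt_len: int,
    temperature: float = 1.0,
) -> List[float]:
    """Per-token log-probs of the completion tokens of ONE sequence.

    logits[t] scores the token at position t + 1, so the completion token
    at position `pos` is scored by logits[pos - 1].
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if not 1 <= prompt_len <= len(input_ids):
        raise ValueError("prompt_len must be in [1, len(input_ids)]")
    out: List[float] = []
    for pos in range(prompt_len, len(input_ids)):
        scaled = [x / temperature for x in logits[pos - 1]]
        out.append(log_softmax(scaled)[input_ids[pos]])
    return out
```

Scaling by the same temperature used at sampling time keeps reference and policy log-probs on the same distribution the completions were drawn from.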

maxent_grpo.training.scoring_reference.reference_from_model(score_batch, runtime, batching_cfg)

Run the frozen reference model to compute log-probs.

Parameters:
  • score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.

  • runtime (RuntimeHandles) – Runtime handles exposing device and reference model.

  • batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.

Returns:

Tuple of (ref_logp_sum, ref_token_counts) or None on failure.

Return type:

tuple[Tensor, Tensor] | None
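The following sketch shows how chunked scoring could produce the documented `(ref_logp_sum, ref_token_counts)` pair, or None on failure. It assumes a hypothetical `score_chunk` callable that returns per-token log-probs and a 0/1 completion mask per sequence. `chunk_size` stands in for whatever field of `BatchingSettings` controls chunking. Plain lists replace tensors so the example runs without PyTorch.

```python
from typing import Callable, List, Optional, Sequence, Tuple

# Hypothetical scorer: for a chunk of sequences, return per-token log-probs
# and a 0/1 completion mask for each sequence in the chunk.
ChunkScorer = Callable[[Sequence[object]], Tuple[List[List[float]], List[List[int]]]]


def reference_sums_chunked(
    sequences: Sequence[object],
    score_chunk: ChunkScorer,
    chunk_size: int,
) -> Optional[Tuple[List[float], List[int]]]:
    """Return (per-sequence logprob sums, per-sequence token counts) or None."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    sums: List[float] = []
    counts: List[int] = []
    try:
        for start in range(0, len(sequences), chunk_size):
            token_logps, masks = score_chunk(sequences[start : start + chunk_size])
            for logps, mask in zip(token_logps, masks):
                # Only completion tokens (mask == 1) contribute to the sum.
                sums.append(sum(lp * m for lp, m in zip(logps, mask)))
                counts.append(sum(mask))
    except RuntimeError:
        # Mirror the documented contract: failure yields None, not a crash.
        return None
    if len(sums) != len(sequences):
        return None  # the scorer dropped sequences; treat as failure
    return sums, counts
```

Smaller chunks bound peak activation memory at the cost of more forward passes.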

maxent_grpo.training.scoring_reference.gather_reference_logprobs(score_batch, runtime, batching_cfg, *, trl_reference_scoring=False, temperature=None)

Compute log-probabilities by running the frozen reference model.

This function runs distributed preflight checks to avoid hangs under DeepSpeed ZeRO, then aggregates the reference statistics into a ReferenceLogprobs object.

Parameters:
  • score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.

  • runtime (RuntimeHandles) – Runtime handles exposing device, accelerator, and models.

  • batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.

  • trl_reference_scoring (bool) – When True, use TRL/open-r1 reference scoring logic.

  • temperature (float | None) – Optional temperature for TRL-style logprob scaling.

Returns:

ReferenceLogprobs or None when reference scoring fails.

Return type:

ReferenceLogprobs | None
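Under ZeRO stage 3, a model forward pass issues collective parameter gathers. If one rank skips the reference forward while others enter it, the others block. The following is one plausible shape for a preflight plus dispatch, with the collective injected as a hypothetical `all_reduce_min` callable. The scorer and finalizer callables are likewise placeholders for `reference_from_model`, `reference_from_model_trl`, and `finalize_reference_stats`. The module's actual checks are not documented here.

```python
from typing import Any, Callable, Optional


def gather_reference_sketch(
    local_batch_size: int,
    all_reduce_min: Callable[[int], int],
    score_plain: Callable[[], Optional[Any]],
    score_trl: Callable[[], Optional[Any]],
    finalize: Callable[[Any], Any],
    *,
    trl_reference_scoring: bool = False,
) -> Optional[Any]:
    # Preflight: each rank reports whether it can run the reference forward.
    # A MIN reduction makes every rank take the same branch, so no rank is
    # left waiting inside a collective that the others never enter.
    ready = 1 if local_batch_size > 0 else 0
    if all_reduce_min(ready) == 0:
        return None
    raw = score_trl() if trl_reference_scoring else score_plain()
    if raw is None:
        return None
    return finalize(raw)
```

With `torch.distributed`, `all_reduce_min` would wrap `dist.all_reduce(flag_tensor, op=dist.ReduceOp.MIN)` on a one-element tensor. In a single process it can simply be the identity.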

maxent_grpo.training.scoring_reference.finalize_reference_stats(ref_logp_sum, ref_tok_counts, *, ref_token_logp=None, ref_token_mask=None)

Build a ReferenceLogprobs object and derived scalars.

Parameters:
  • ref_logp_sum (torch.Tensor) – Per-sequence sum of reference log-probabilities.

  • ref_tok_counts (torch.Tensor) – Per-sequence token counts.

  • ref_token_logp (torch.Tensor | None) – Optional per-token reference log-probs (completion-only).

  • ref_token_mask (torch.Tensor | None) – Optional per-token completion mask aligned to ref_token_logp.

Returns:

Normalized reference stats and summary scalars.

Return type:

ReferenceLogprobs

Raises:

ValueError – If log-prob tensors cannot be safely normalized.
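The following is a minimal sketch of what "normalized reference stats and summary scalars" could involve. It covers validating shapes, rejecting non-finite sums with ValueError, computing per-token averages with counts clamped to at least 1, and computing batch means. `RefStatsSketch` and its field names are hypothetical stand-ins, because the fields of ReferenceLogprobs are not documented here. Lists replace tensors.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class RefStatsSketch:
    # Hypothetical stand-in for ReferenceLogprobs.
    ref_logp_sum: List[float]
    ref_tok_counts: List[int]
    ref_logp_per_token: List[float]
    ref_logp_mean: float
    avg_completion_tokens: float
    ref_token_logp: Optional[List[List[float]]] = None
    ref_token_mask: Optional[List[List[int]]] = None


def finalize_stats_sketch(
    ref_logp_sum: Sequence[float],
    ref_tok_counts: Sequence[int],
    *,
    ref_token_logp: Optional[List[List[float]]] = None,
    ref_token_mask: Optional[List[List[int]]] = None,
) -> RefStatsSketch:
    if not ref_logp_sum or len(ref_logp_sum) != len(ref_tok_counts):
        raise ValueError("sums and counts must be non-empty and equal length")
    if not all(math.isfinite(x) for x in ref_logp_sum):
        raise ValueError("non-finite reference log-prob sum")
    if (ref_token_logp is None) != (ref_token_mask is None):
        raise ValueError("per-token log-probs and mask must be given together")
    if ref_token_logp is not None and ref_token_mask is not None:
        if len(ref_token_logp) != len(ref_token_mask) or any(
            len(a) != len(b) for a, b in zip(ref_token_logp, ref_token_mask)
        ):
            raise ValueError("per-token log-probs and mask are misaligned")
    # Clamp counts to >= 1 so empty completions do not divide by zero.
    per_token = [s / max(c, 1) for s, c in zip(ref_logp_sum, ref_tok_counts)]
    return RefStatsSketch(
        ref_logp_sum=list(ref_logp_sum),
        ref_tok_counts=list(ref_tok_counts),
        ref_logp_per_token=per_token,
        ref_logp_mean=sum(per_token) / len(per_token),
        avg_completion_tokens=sum(ref_tok_counts) / len(ref_tok_counts),
        ref_token_logp=ref_token_logp,
        ref_token_mask=ref_token_mask,
    )
```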

maxent_grpo.training.scoring_reference.reference_stats_from_policy_logprobs(cur_logp_sum, tok_counts)

Build ReferenceLogprobs assuming reference == current policy (KL ~= 0).

Parameters:
  • cur_logp_sum (torch.Tensor) – Current policy log-prob sums per sequence.

  • tok_counts (torch.Tensor) – Token counts per sequence.

Returns:

Reference stats derived directly from the current policy.

Return type:

ReferenceLogprobs
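Why "KL ~= 0": if the reference log-probs are copied from the policy, any KL estimator evaluated on them returns zero. The sketch below uses the k3 estimator exp(r - p) - (r - p) - 1, which TRL's GRPOTrainer applies per token. Whether this trainer uses the same estimator is an assumption. `reference_from_policy_sketch` is a hypothetical list-based stand-in; a tensor version would presumably detach the copy so no gradient flows through the "reference".

```python
import math
from typing import List, Sequence, Tuple


def k3_kl(policy_logp: Sequence[float], ref_logp: Sequence[float]) -> List[float]:
    """k3 KL estimate per element: exp(r - p) - (r - p) - 1, always >= 0."""
    return [math.exp(r - p) - (r - p) - 1.0 for p, r in zip(policy_logp, ref_logp)]


def reference_from_policy_sketch(
    cur_logp_sum: Sequence[float], tok_counts: Sequence[int]
) -> Tuple[List[float], List[int]]:
    # Reference == policy: reuse the policy values unchanged.
    return list(cur_logp_sum), list(tok_counts)
```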

maxent_grpo.training.scoring_reference.reference_from_vllm_meta(flat_meta, total_sequences, device)

Convert flattened vLLM log-prob metadata into ReferenceLogprobs.

Parameters:
  • flat_meta (Sequence[object | None]) – Flat list of vLLM metadata entries (one per completion).

  • total_sequences (int) – Expected number of sequences in the batch.

  • device (Any) – Device for the resulting tensors.

Returns:

ReferenceLogprobs or None when metadata is incomplete.

Return type:

ReferenceLogprobs | None
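The following sketch shows the conversion contract: one entry per completion, and None as soon as the list length or any entry is incomplete. The field names "logprob_sum", "token_count", and "token_logprobs" are hypothetical, because the metadata schema is not documented here. The sketch returns plain lists instead of device tensors.

```python
from collections.abc import Mapping
from typing import Any, List, Optional, Sequence, Tuple


def _field(entry: Any, name: str) -> Any:
    if isinstance(entry, Mapping):
        return entry.get(name)
    return getattr(entry, name, None)


def vllm_meta_to_sums(
    flat_meta: Sequence[Any], total_sequences: int
) -> Optional[Tuple[List[float], List[int]]]:
    """Return (logprob sums, token counts), or None if metadata is incomplete."""
    if len(flat_meta) != total_sequences:
        return None
    sums: List[float] = []
    counts: List[int] = []
    for entry in flat_meta:
        if entry is None:
            return None
        # Hypothetical fields: prefer a precomputed sum, else sum per-token values.
        token_logps = _field(entry, "token_logprobs")
        logp_sum = _field(entry, "logprob_sum")
        if logp_sum is None and token_logps is not None:
            logp_sum = sum(float(x) for x in token_logps)
        count = _field(entry, "token_count")
        if count is None and token_logps is not None:
            count = len(token_logps)
        if logp_sum is None or count is None:
            return None
        sums.append(float(logp_sum))
        counts.append(int(count))
    return sums, counts
```

In the real function the lists would become tensors on `device`, for example via `torch.tensor(sums, dtype=torch.float32, device=device)`.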

maxent_grpo.training.scoring_reference.vllm_meta_has_logprobs(flat_meta, total_sequences=None)

Return True when vLLM metadata includes per-completion logprob info.

Parameters:
  • flat_meta (Sequence[object | None] | None) – Flat list of vLLM metadata entries.

  • total_sequences (int | None) – Optional expected length used for sanity checks.

Returns:

True when logprob metadata appears complete.

Return type:

bool
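The following is a minimal sketch of the completeness check. The list must be non-empty, must match `total_sequences` when given, and every entry must carry some logprob field. The field names in `LOGPROB_KEYS` are hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Optional, Sequence

LOGPROB_KEYS = ("logprob_sum", "token_logprobs")  # hypothetical field names


def has_logprobs_sketch(
    flat_meta: Optional[Sequence[Any]], total_sequences: Optional[int] = None
) -> bool:
    if not flat_meta:
        return False
    if total_sequences is not None and len(flat_meta) != total_sequences:
        return False
    for entry in flat_meta:
        if entry is None:
            return False
        if isinstance(entry, Mapping):
            present = any(entry.get(k) is not None for k in LOGPROB_KEYS)
        else:
            present = any(getattr(entry, k, None) is not None for k in LOGPROB_KEYS)
        if not present:
            return False
    return True
```

One plausible use is as a cheap gate: a caller could try `reference_from_vllm_meta` only when this returns True, and otherwise score with the frozen reference model.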