maxent_grpo.training.scoring_reference
Reference-logprob and vLLM metadata scoring helpers.
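Functions
| | Return |
| | Best-effort conversion of a token logprob payload into a float. |
| | Return the first matching field from a metadata entry. |
| | Return the sum of per-token logprobs when the payload is parseable. |
| finalize_reference_stats | Build a ReferenceLogprobs object and derived scalars. |
| gather_reference_logprobs | Compute log-probabilities by running the frozen reference model. |
| reference_from_model | Run the frozen reference model to compute log-probs. |
| reference_from_model_trl | Run the frozen reference model using TRL-style log-prob computations. |
| reference_from_vllm_meta | Convert flattened vLLM log-prob metadata into ReferenceLogprobs. |
| reference_stats_from_policy_logprobs | Build ReferenceLogprobs assuming reference == current policy (KL ~= 0). |
| | Return True when vLLM metadata includes per-completion logprob info. |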
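The payload helpers summarized above can be pictured with a minimal sketch. The names `coerce_logprob` and `sum_token_logprobs`, and the accepted payload shapes (a bare number, a dict with a `logprob` key, or an object with a `logprob` attribute), are assumptions made for illustration; they are not taken from the module.

```python
import math
from typing import Any, Optional, Sequence


def coerce_logprob(payload: Any) -> Optional[float]:
    """Best-effort conversion of one token logprob payload into a float."""
    # Hypothetical payload shapes: unwrap a dict or an attribute-style object.
    if isinstance(payload, dict):
        payload = payload.get("logprob")
    elif hasattr(payload, "logprob"):
        payload = payload.logprob
    # bool is a subclass of int, so reject it explicitly.
    if payload is None or isinstance(payload, bool):
        return None
    try:
        value = float(payload)
    except (TypeError, ValueError):
        return None
    # Reject NaN and infinities so they cannot poison a later sum.
    return value if math.isfinite(value) else None


def sum_token_logprobs(payloads: Optional[Sequence[Any]]) -> Optional[float]:
    """Sum per-token logprobs, or return None if any entry is unparseable."""
    if not payloads:
        return None
    total = 0.0
    for item in payloads:
        value = coerce_logprob(item)
        if value is None:
            return None
        total += value
    return total
```

Returning `None` instead of raising keeps the caller free to fall back to another scoring path when a payload is malformed.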
- maxent_grpo.training.scoring_reference.reference_from_model_trl(score_batch, runtime, batching_cfg, *, temperature=None)
Run the frozen reference model using TRL-style log-prob computations.
- Parameters:
score_batch (ScoreBatch)
runtime (RuntimeHandles)
batching_cfg (BatchingSettings)
temperature (float | None)
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None] | None
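As an illustration of what temperature scaling of log-probs usually means, the sketch below divides logits by the temperature before taking a log-softmax. This is an assumption about the TRL-style behaviour, written in plain Python rather than torch; `token_logprob` is a hypothetical helper, not part of this module.

```python
import math
from typing import List, Optional


def token_logprob(
    logits: List[float], token_id: int, temperature: Optional[float] = None
) -> float:
    """Log-probability of ``token_id`` under softmax(logits / temperature)."""
    if temperature is not None:
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        logits = [x / temperature for x in logits]
    # Subtract the max before exponentiating for numerical stability.
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return logits[token_id] - log_norm
```

With `temperature=None` the logits are used unchanged, which mirrors the optional parameter in the signature above.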
- maxent_grpo.training.scoring_reference.reference_from_model(score_batch, runtime, batching_cfg)
Run the frozen reference model to compute log-probs.
- Parameters:
score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.
runtime (RuntimeHandles) – Runtime handles exposing device and reference model.
batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.
- Returns:
Tuple of (ref_logp_sum, ref_token_counts) or None on failure.
- Return type:
tuple[Tensor, Tensor] | None
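The two returned quantities can be illustrated with a small sketch that reduces per-token log-probs to a per-sequence sum and a completion token count. It uses plain lists instead of tensors, and `sum_completion_logprobs` together with its mask convention (1 for completion tokens, 0 otherwise) is an assumption made for illustration.

```python
from typing import List, Tuple


def sum_completion_logprobs(
    token_logp: List[List[float]], completion_mask: List[List[int]]
) -> Tuple[List[float], List[int]]:
    """Return per-sequence log-prob sums and completion token counts."""
    sums: List[float] = []
    counts: List[int] = []
    for row, mask in zip(token_logp, completion_mask):
        if len(row) != len(mask):
            raise ValueError("log-prob row and mask must have equal length")
        # Only positions flagged by the mask contribute to the sum.
        sums.append(sum(lp for lp, keep in zip(row, mask) if keep))
        counts.append(sum(1 for keep in mask if keep))
    return sums, counts
```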
- maxent_grpo.training.scoring_reference.gather_reference_logprobs(score_batch, runtime, batching_cfg, *, trl_reference_scoring=False, temperature=None)
Compute log-probabilities by running the frozen reference model.
This function handles distributed preflight checks to avoid ZeRO hangs and aggregates reference statistics into a ReferenceLogprobs object.
- Parameters:
score_batch (ScoreBatch) – Prepared scoring batch with prompts/completions.
runtime (RuntimeHandles) – Runtime handles exposing device, accelerator, and models.
batching_cfg (BatchingSettings) – Batching config controlling logprob chunking.
trl_reference_scoring (bool) – When True, use TRL/open-r1 reference scoring logic.
temperature (float | None) – Optional temperature for TRL-style logprob scaling.
- Returns:
ReferenceLogprobs or None when reference scoring fails.
- Return type:
ReferenceLogprobs | None
- maxent_grpo.training.scoring_reference.finalize_reference_stats(ref_logp_sum, ref_tok_counts, *, ref_token_logp=None, ref_token_mask=None)
Build a ReferenceLogprobs object and derived scalars.
- Parameters:
ref_logp_sum (torch.Tensor) – Per-sequence sum of reference log-probabilities.
ref_tok_counts (torch.Tensor) – Per-sequence token counts.
ref_token_logp (torch.Tensor | None) – Optional per-token reference log-probs (completion-only).
ref_token_mask (torch.Tensor | None) – Optional per-token completion mask aligned to ref_token_logp.
- Returns:
Normalized reference stats and summary scalars.
- Return type:
ReferenceLogprobs
- Raises:
ValueError – If log-prob tensors cannot be safely normalized.
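A rough sketch of this kind of normalization follows: divide each sum by its token count, validate the inputs, and report summary scalars. The scalar names, the clamp of counts to at least 1, and the exact validation rules are assumptions; the real ReferenceLogprobs fields are not shown in this page.

```python
import math
from typing import Dict, List


def summarize_reference(
    logp_sums: List[float], tok_counts: List[int]
) -> Dict[str, float]:
    """Normalize log-prob sums by token count and derive summary scalars."""
    if not logp_sums or len(logp_sums) != len(tok_counts):
        raise ValueError("sums and counts must be non-empty and aligned")
    if any(not math.isfinite(v) for v in logp_sums):
        raise ValueError("log-prob sums contain NaN or infinite values")
    # Clamp counts to 1 so empty completions do not divide by zero.
    per_token = [s / max(c, 1) for s, c in zip(logp_sums, tok_counts)]
    return {
        "mean_logp_per_token": sum(per_token) / len(per_token),
        "mean_tokens": sum(tok_counts) / len(tok_counts),
    }
```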
- maxent_grpo.training.scoring_reference.reference_stats_from_policy_logprobs(cur_logp_sum, tok_counts)
Build ReferenceLogprobs assuming reference == current policy (KL ~= 0).
- Parameters:
cur_logp_sum (torch.Tensor) – Current policy log-prob sums per sequence.
tok_counts (torch.Tensor) – Token counts per sequence.
- Returns:
Reference stats derived directly from the current policy.
- Return type:
ReferenceLogprobs
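To see why reusing the policy log-probs as the reference gives a KL of about zero, consider the common k3 estimator exp(d) - d - 1 with d = ref_logp - cur_logp. Whether this module uses that particular estimator is an assumption; `k3_kl` is a hypothetical helper used only to show the arithmetic.

```python
import math
from typing import List


def k3_kl(cur_logp: List[float], ref_logp: List[float]) -> List[float]:
    """Per-sequence k3 KL estimate between the policy and the reference."""
    estimates = []
    for cur, ref in zip(cur_logp, ref_logp):
        delta = ref - cur
        # exp(d) - d - 1 is non-negative and is exactly 0 when d == 0.
        estimates.append(math.exp(delta) - delta - 1.0)
    return estimates


cur = [-1.5, -0.2]
print(k3_kl(cur, cur))  # identical log-probs give a zero estimate per sequence
```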
- maxent_grpo.training.scoring_reference.reference_from_vllm_meta(flat_meta, total_sequences, device)
Convert flattened vLLM log-prob metadata into ReferenceLogprobs.
- Parameters:
- Returns:
ReferenceLogprobs or None when metadata is incomplete.
- Return type:
ReferenceLogprobs | None
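The all-or-nothing contract (a result only when every sequence has usable metadata) might look like the sketch below. The entry keys `logprob_sum` and `token_count` and the helper name `sums_from_flat_meta` are hypothetical; the actual vLLM metadata layout is not documented here, and the `device` argument is omitted because the sketch uses plain lists.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple


def sums_from_flat_meta(
    flat_meta: Optional[Sequence[Optional[Dict[str, Any]]]],
    total_sequences: int,
) -> Optional[Tuple[List[float], List[int]]]:
    """Return (sums, counts), or None when any sequence lacks metadata."""
    # One metadata entry is expected per sequence.
    if not flat_meta or len(flat_meta) != total_sequences:
        return None
    sums: List[float] = []
    counts: List[int] = []
    for entry in flat_meta:
        if not entry:
            return None
        logp_sum = entry.get("logprob_sum")  # hypothetical key
        tok_count = entry.get("token_count")  # hypothetical key
        if logp_sum is None or tok_count is None:
            return None
        sums.append(float(logp_sum))
        counts.append(int(tok_count))
    return sums, counts
```

A caller that receives `None` can then fall back to running the frozen reference model instead.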