maxent_grpo.training.scoring_batching

Batch construction and slice materialization helpers for scoring.

Functions

_apply_eos_completion_mask(completion_ids, ...)

Mask completion tokens after the first EOS token (TRL-style).

_collect_prompt_entries(prompt_batch, ...)

Resolve cached prompt tokenization for a batch of strings.

_completion_tensors_from_token_ids(...)

Build completion tensors from pre-tokenized token-id sequences.

_prepare_prompt_slice(prompt_slice, ...)

Materialize prompt tensors for one scoring slice.

_slice_tail_window(start_idx, input_ids, ...)

Slice tail tokens without closing over loop variables.

_tokenize_completions(completion_batch, ...)

Tokenize completions into padded tensors.

build_score_batch(reward_comp, tokenizer, ...)

Tokenize prompt+completion pairs and prepare masks/labels.

iter_batch_slices(score_batch, device, *[, ...])

Yield scoring slices for a batch, assembling prompt tensors on demand.

iter_batch_slices_trl(score_batch, runtime, ...)

Yield prompt+completion slices for TRL-style logprob computation.

summarize_completion_lengths(ref_stats, ...)

Summarize completion lengths for metrics.

token_counts_from_score_batch(score_batch, ...)

Compute per-sequence token counts from the score batch labels mask.
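
_slice_tail_window exists so that slicing does not close over loop variables. The following minimal sketch, in plain Python, shows the late-binding pitfall it avoids: a lambda created inside a loop reads the loop variable when it is called, not when it is created. The names make_windows_buggy and make_windows_fixed are hypothetical, and slice_tail_window here is a simplified stand-in that takes only (start_idx, input_ids) on a list. It is not the module's actual signature.

```python
import functools
from typing import Callable, List, Sequence


def slice_tail_window(start_idx: int, input_ids: List[int]) -> List[int]:
    """Return the tokens from start_idx to the end (simplified stand-in)."""
    return input_ids[start_idx:]


def make_windows_buggy(starts: Sequence[int], input_ids: List[int]) -> List[Callable[[], List[int]]]:
    windows = []
    for start in starts:
        # Bug: the lambda reads `start` when it is called, so every
        # callable sees the value left over from the last iteration.
        windows.append(lambda: input_ids[start:])
    return windows


def make_windows_fixed(starts: Sequence[int], input_ids: List[int]) -> List[Callable[[], List[int]]]:
    windows = []
    for start in starts:
        # functools.partial stores the current value of `start` as an
        # argument, so each callable keeps its own start index.
        windows.append(functools.partial(slice_tail_window, start, input_ids))
    return windows


tokens = [10, 11, 12, 13]
print([w() for w in make_windows_buggy([1, 3], tokens)])  # [[13], [13]]
print([w() for w in make_windows_fixed([1, 3], tokens)])  # [[11, 12, 13], [13]]
```

Passing the loop value as an explicit argument to a module-level helper, as the fixed variant does, binds it at creation time.
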

Classes

CompletionTensors(ids, mask)

Completion token IDs and masks.

_PromptCacheConfig(prompt_length_cache_get)

_SliceState(total_sequences, slice_size, ...)

Cached tensors and metadata required for batch slicing.
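
_PromptCacheConfig carries a prompt_length_cache_get callable, and _collect_prompt_entries resolves cached prompt tokenization for a batch of strings. The sketch below assumes, hypothetically, that such a getter maps a prompt string to a cached record and tokenizes only on a cache miss. PromptEntry, make_prompt_cache_get, collect_prompt_entries, and toy_tokenize are illustrative names, and functools.lru_cache stands in for whatever cache the module actually uses.

```python
import functools
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple


@dataclass(frozen=True)
class PromptEntry:
    """Hypothetical cached record for one prompt string."""
    text: str
    input_ids: Tuple[int, ...]

    @property
    def length(self) -> int:
        return len(self.input_ids)


def make_prompt_cache_get(tokenize: Callable[[str], Sequence[int]]) -> Callable[[str], PromptEntry]:
    """Build a memoized getter: each distinct prompt is tokenized once."""

    @functools.lru_cache(maxsize=None)
    def cache_get(prompt: str) -> PromptEntry:
        return PromptEntry(prompt, tuple(tokenize(prompt)))

    return cache_get


def collect_prompt_entries(prompt_batch: Sequence[str], cache_get: Callable[[str], PromptEntry]) -> List[PromptEntry]:
    # Repeated prompts (e.g. several completions per prompt) hit the cache.
    return [cache_get(prompt) for prompt in prompt_batch]


calls: List[str] = []


def toy_tokenize(text: str) -> List[int]:
    # Record each call so cache hits are observable.
    calls.append(text)
    return [ord(ch) for ch in text]


cache_get = make_prompt_cache_get(toy_tokenize)
entries = collect_prompt_entries(["ab", "cde", "ab"], cache_get)
print([e.length for e in entries])  # [2, 3, 2]
print(len(calls))  # 2
```
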

class maxent_grpo.training.scoring_batching.CompletionTensors(ids, mask)

Bases: object

Completion token IDs and masks.

Parameters:
  • ids (torch.Tensor)

  • mask (torch.Tensor)

ids: torch.Tensor
mask: torch.Tensor
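
CompletionTensors pairs padded completion ids with a mask. A minimal sketch of how _completion_tensors_from_token_ids could build that pair, assuming right padding with the tokenizer's pad id and an optional truncation length. It uses nested lists instead of torch.Tensor so it runs without PyTorch, and CompletionLists, completion_lists_from_token_ids, and max_len are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class CompletionLists:
    """List-based stand-in for CompletionTensors(ids, mask)."""
    ids: List[List[int]]
    mask: List[List[int]]


def completion_lists_from_token_ids(
    token_ids: Sequence[Sequence[int]],
    pad_token_id: int,
    max_len: Optional[int] = None,
) -> CompletionLists:
    # Optionally truncate each completion to max_len tokens.
    rows = [list(seq if max_len is None else seq[:max_len]) for seq in token_ids]
    width = max((len(row) for row in rows), default=0)
    # Right-pad ids with pad_token_id; mask is 1 for real tokens, 0 for padding.
    ids = [row + [pad_token_id] * (width - len(row)) for row in rows]
    mask = [[1] * len(row) + [0] * (width - len(row)) for row in rows]
    return CompletionLists(ids=ids, mask=mask)


out = completion_lists_from_token_ids([[5, 6, 7], [8]], pad_token_id=0)
print(out.ids)   # [[5, 6, 7], [8, 0, 0]]
print(out.mask)  # [[1, 1, 1], [1, 0, 0]]
```

With PyTorch available, wrapping each field with torch.tensor(...) would give tensors shaped like the ids and mask attributes above.
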

maxent_grpo.training.scoring_batching.iter_batch_slices(score_batch, device, *, eos_token_id=None, apply_eos_mask=False)

Yield scoring slices for a batch, assembling prompt tensors on demand.

Parameters:
  • score_batch (ScoreBatch) – Prepared prompt/completion tensors and metadata.

  • device (torch.device) – Device where tensors should be materialized.

  • eos_token_id (int | None) – Optional EOS token id for TRL-style completion masking.

  • apply_eos_mask (bool) – When True, apply EOS-aware completion masks.

Yields:

Tuples of (input_ids, attention_mask, labels) per slice.

Return type:

Iterator[tuple[Tensor, Tensor, Tensor]]
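
When apply_eos_mask is True, completion positions after the first EOS are masked out, TRL-style, and the batch is consumed in fixed-size slices. The following sketch shows both steps on nested lists, assuming the EOS token itself stays unmasked (as in TRL's GRPO trainer) and that rows without an EOS keep their existing mask. It is simplified: the real iterator takes a ScoreBatch and a device and yields (input_ids, attention_mask, labels), while the hypothetical iter_completion_slices yields only completion ids and masks.

```python
from typing import Iterator, List, Optional, Tuple

Rows = List[List[int]]


def apply_eos_completion_mask(completion_ids: Rows, mask: Rows, eos_token_id: int) -> Rows:
    """Zero the mask after the first EOS in each row; the EOS itself stays visible."""
    masked = []
    for ids_row, mask_row in zip(completion_ids, mask):
        # Rows without an EOS keep their existing mask unchanged.
        first_eos = ids_row.index(eos_token_id) if eos_token_id in ids_row else len(ids_row)
        masked.append([m if j <= first_eos else 0 for j, m in enumerate(mask_row)])
    return masked


def iter_completion_slices(
    completion_ids: Rows,
    mask: Rows,
    slice_size: int,
    eos_token_id: Optional[int] = None,
    apply_eos_mask: bool = False,
) -> Iterator[Tuple[Rows, Rows]]:
    if apply_eos_mask and eos_token_id is not None:
        mask = apply_eos_completion_mask(completion_ids, mask, eos_token_id)
    # Yield at most slice_size rows at a time.
    for start in range(0, len(completion_ids), slice_size):
        yield completion_ids[start:start + slice_size], mask[start:start + slice_size]


ids = [[5, 2, 7, 7], [5, 6, 7, 8], [2, 9, 9, 9]]
ones = [[1] * 4 for _ in ids]
for id_slice, mask_slice in iter_completion_slices(ids, ones, slice_size=2, eos_token_id=2, apply_eos_mask=True):
    print(mask_slice)
# [[1, 1, 0, 0], [1, 1, 1, 1]]
# [[1, 0, 0, 0]]
```

Because slicing happens after masking, every slice sees the same EOS-aware mask it would have had in a single full-batch pass.
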

maxent_grpo.training.scoring_batching.build_score_batch(reward_comp, tokenizer, generation_cfg, batching_cfg)

Tokenize prompt+completion pairs and prepare masks/labels.

Parameters:
  • reward_comp (RewardComputation) – Reward computation payload containing prompts and completions.

  • tokenizer (PreTrainedTokenizer) – Tokenizer used to encode completions and determine padding.

  • generation_cfg (GenerationSettings) – Generation settings (max lengths, etc.).

  • batching_cfg (BatchingSettings) – Batching settings controlling scoring slice sizes.

Returns:

Prepared ScoreBatch or None when no sequences are available.

Return type:

ScoreBatch | None
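
build_score_batch pairs each prompt with its completion and derives masks and labels. A minimal sketch of that step, assuming prompt and completion token ids are already available as lists, completions are truncated to a maximum length, and labels use -100 (the default ignore_index of torch.nn.CrossEntropyLoss) on prompt and padding positions. ToyScoreBatch and build_toy_score_batch are hypothetical. The real ScoreBatch keeps prompt tensors separately and assembles them per slice, whereas this sketch concatenates eagerly for clarity.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


@dataclass
class ToyScoreBatch:
    input_ids: List[List[int]]
    attention_mask: List[List[int]]
    labels: List[List[int]]


def build_toy_score_batch(
    prompt_ids: Sequence[Sequence[int]],
    completion_ids: Sequence[Sequence[int]],
    pad_token_id: int,
    max_completion_len: int,
) -> Optional[ToyScoreBatch]:
    rows = []
    for prompt, completion in zip(prompt_ids, completion_ids):
        completion = list(completion[:max_completion_len])
        # Only completion tokens are scored; prompt positions are ignored.
        rows.append((list(prompt) + completion, [IGNORE_INDEX] * len(prompt) + completion))
    if not rows:
        return None  # no sequences available
    width = max(len(ids) for ids, _ in rows)
    input_ids, attention_mask, labels = [], [], []
    for ids, row_labels in rows:
        pad = width - len(ids)
        # Right-pad every row to the longest prompt+completion length.
        input_ids.append(ids + [pad_token_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        labels.append(row_labels + [IGNORE_INDEX] * pad)
    return ToyScoreBatch(input_ids, attention_mask, labels)


batch = build_toy_score_batch([[1, 2], [3]], [[7, 8, 9], [6]], pad_token_id=0, max_completion_len=2)
print(batch.labels)  # [[-100, -100, 7, 8], [-100, 6, -100, -100]]
```
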

maxent_grpo.training.scoring_batching.iter_batch_slices_trl(score_batch, runtime, eos_token_id)

Yield prompt+completion slices for TRL-style logprob computation.

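Parameters:
  • score_batch (ScoreBatch) – Prepared scoring batch.

  • runtime (RuntimeHandles) – Runtime handles exposing device/accelerator.

  • eos_token_id – EOS token id used for TRL-style completion masking.
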
Return type:

Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]]
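
For TRL-style logprob computation, the logits at position t score the token at position t + 1, so the final logit row is dropped and only the trailing completion positions are kept. The sketch below computes those per-token log-probabilities for one sequence in plain Python. completion_token_logprobs and num_completion_tokens are hypothetical names (TRL's GRPO trainer uses a similar count called logits_to_keep), and the sketch does not assume what the trailing int in the yielded tuple represents.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    peak = max(row)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
    return [x - log_norm for x in row]


def completion_token_logprobs(
    logits: Sequence[Sequence[float]],
    input_ids: Sequence[int],
    num_completion_tokens: int,
) -> List[float]:
    """Log-probability of each completion token under the model's logits."""
    if num_completion_tokens <= 0:
        return []
    # logits[t] predicts input_ids[t + 1]; the last row predicts past the end.
    shifted_logits = logits[:-1]
    targets = input_ids[1:]
    per_token = [log_softmax(row)[tok] for row, tok in zip(shifted_logits, targets)]
    # Keep only the trailing positions that belong to the completion.
    return per_token[-num_completion_tokens:]


# Toy sequence: one prompt token followed by two completion tokens.
logits = [[0.0, 0.0], [10.0, 0.0], [0.0, 0.0]]
input_ids = [0, 0, 1]
print(completion_token_logprobs(logits, input_ids, num_completion_tokens=2))
```

In PyTorch the same steps are usually written with torch.log_softmax over the vocabulary dimension followed by Tensor.gather on the shifted targets.
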

maxent_grpo.training.scoring_batching.token_counts_from_score_batch(score_batch, runtime, batching_cfg)

Compute per-sequence token counts from the score batch labels mask.

Parameters:
  • score_batch (ScoreBatch) – Prepared scoring batch.

  • runtime (RuntimeHandles) – Runtime handles exposing device/accelerator.

  • batching_cfg (BatchingSettings) – Batching config controlling slice sizes.

Returns:

1D tensor of token counts per sequence.

Return type:

Tensor
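
token_counts_from_score_batch reduces the labels mask to one count per sequence. A minimal sketch, assuming labels mark non-completion positions with -100 (a hypothetical choice matching the Hugging Face convention) and processing rows slice by slice as batching_cfg would. The slice size affects peak memory only, not the counts. token_counts_from_labels is a hypothetical name.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # hypothetical marker for non-completion positions


def token_counts_from_labels(labels: Sequence[Sequence[int]], slice_size: int) -> List[int]:
    counts: List[int] = []
    for start in range(0, len(labels), slice_size):
        # Count scored (non-ignored) positions row by row within the slice.
        for row in labels[start:start + slice_size]:
            counts.append(sum(1 for tok in row if tok != IGNORE_INDEX))
    return counts


labels = [[-100, -100, 7, 8], [-100, 6, -100, -100], [-100, 4, 5, 9]]
print(token_counts_from_labels(labels, slice_size=2))  # [2, 1, 3]
```
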

maxent_grpo.training.scoring_batching.summarize_completion_lengths(ref_stats, max_completion_len)

Summarize completion lengths for metrics.

Parameters:
  • ref_stats (ReferenceLogprobs) – Reference log-prob stats containing token counts.

  • max_completion_len (int) – Maximum completion length used for clipping stats.

Returns:

Tuple of (completion_lengths, length_stats, total_tokens).

Return type:

tuple[Tensor, LengthStats, float]
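
summarize_completion_lengths turns per-sequence token counts into metrics, with max_completion_len used for clipping. The sketch below assumes, hypothetically, that lengths are clipped to max_completion_len and that a completion counts as clipped when it reaches that limit (similar in spirit to TRL's completions/clipped_ratio metric). ToyLengthStats and its fields are illustrative and are not the real LengthStats.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class ToyLengthStats:
    """Hypothetical stand-in for LengthStats."""
    mean_length: float
    min_length: float
    max_length: float
    clipped_ratio: float


def summarize_lengths(
    token_counts: Sequence[int], max_completion_len: int
) -> Tuple[List[float], ToyLengthStats, float]:
    # Clip each length so no value exceeds the generation budget.
    lengths = [float(min(count, max_completion_len)) for count in token_counts]
    if not lengths:
        return lengths, ToyLengthStats(0.0, 0.0, 0.0, 0.0), 0.0
    # A completion that reached the limit was likely truncated.
    clipped = sum(1 for count in token_counts if count >= max_completion_len)
    stats = ToyLengthStats(
        mean_length=sum(lengths) / len(lengths),
        min_length=min(lengths),
        max_length=max(lengths),
        clipped_ratio=clipped / len(lengths),
    )
    return lengths, stats, float(sum(lengths))


lengths, stats, total = summarize_lengths([3, 8, 10], max_completion_len=8)
print(lengths, total)  # [3.0, 8.0, 8.0] 19.0
```
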