maxent_grpo.training.scoring_batching
Batch construction and slice materialization helpers for scoring.
Functions
- Mask completion tokens after the first EOS token (TRL-style).
- Resolve cached prompt tokenization for a batch of strings.
- Build completion tensors from pre-tokenized token-id sequences.
- Materialize prompt tensors for one scoring slice.
- Slice tail tokens without closing over loop variables.
- Tokenize completions into padded tensors.
- build_score_batch – Tokenize prompt+completion pairs and prepare masks/labels.
- iter_batch_slices – Yield scoring slices for a batch, assembling prompt tensors on demand.
- iter_batch_slices_trl – Yield prompt+completion slices for TRL-style logprob computation.
- summarize_completion_lengths – Summarize completion lengths for metrics.
- token_counts_from_score_batch – Compute per-sequence token counts from the score batch labels mask.
Classes
- CompletionTensors – Completion token IDs and masks.
- Cached tensors and metadata required for batch slicing.
- class maxent_grpo.training.scoring_batching.CompletionTensors(ids, mask)
Bases: object
Completion token IDs and masks.
- Parameters:
ids (torch.Tensor)
mask (torch.Tensor)
- ids: torch.Tensor
- mask: torch.Tensor
- maxent_grpo.training.scoring_batching.iter_batch_slices(score_batch, device, *, eos_token_id=None, apply_eos_mask=False)
Yield scoring slices for a batch, assembling prompt tensors on demand.
- Parameters:
score_batch (ScoreBatch) – Prepared prompt/completion tensors and metadata.
device (torch.device) – Device where tensors should be materialized.
eos_token_id (int | None) – Optional EOS token id for TRL-style completion masking.
apply_eos_mask (bool) – When True, apply EOS-aware completion masks.
- Yields:
Tuples of (input_ids, attention_mask, labels) per slice.
- Return type:
Iterator[tuple[Tensor, Tensor, Tensor]]
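The EOS-aware masking mentioned for apply_eos_mask can be illustrated with a framework-free sketch. This is not the module's implementation: the helper name eos_completion_mask is hypothetical, plain lists stand in for tensors, and keeping the first EOS token itself inside the mask is an assumption about what "TRL-style" means here.

```python
from typing import List, Optional


def eos_completion_mask(
    completion_ids: List[List[int]], eos_token_id: Optional[int]
) -> List[List[int]]:
    """Return a 0/1 mask keeping tokens up to and including the first EOS.

    Hypothetical helper; lists stand in for the real torch tensors.
    """
    masks = []
    for row in completion_ids:
        if eos_token_id is None or eos_token_id not in row:
            # No EOS configured or found: keep every token in the row.
            masks.append([1] * len(row))
            continue
        first_eos = row.index(eos_token_id)
        # Assumption: positions 0..first_eos are kept, later ones are zeroed.
        masks.append([1 if i <= first_eos else 0 for i in range(len(row))])
    return masks


ids = [[5, 6, 2, 7, 2], [8, 9, 10, 11, 12]]
mask = eos_completion_mask(ids, eos_token_id=2)
```

In this sketch the first row is cut after its first EOS (token id 2), while the second row, which contains no EOS, is kept whole.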
- maxent_grpo.training.scoring_batching.build_score_batch(reward_comp, tokenizer, generation_cfg, batching_cfg)
Tokenize prompt+completion pairs and prepare masks/labels.
- Parameters:
reward_comp (RewardComputation) – Reward computation payload containing prompts and completions.
tokenizer (PreTrainedTokenizer) – Tokenizer used to encode completions and determine padding.
generation_cfg (GenerationSettings) – Generation settings (max lengths, etc.).
batching_cfg (BatchingSettings) – Batching settings controlling scoring slice sizes.
- Returns:
Prepared ScoreBatch, or None when no sequences are available.
- Return type:
ScoreBatch | None
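As a rough picture of what "prepare masks/labels" could involve, the sketch below concatenates prompt and completion ids, right-pads them, and builds an attention mask and labels. Everything in it is an assumption for illustration: pad_pairs is a hypothetical name, right-padding and the -100 ignore value follow a common Hugging Face convention rather than anything this page confirms, and the real function returns a ScoreBatch or None, not lists.

```python
from typing import List, Tuple

# Assumption: conventional "ignore" label value; not confirmed by this page.
IGNORE_INDEX = -100


def pad_pairs(
    prompt_ids: List[List[int]],
    completion_ids: List[List[int]],
    pad_token_id: int,
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    """Concatenate prompt+completion rows and right-pad to a common width."""
    rows = [p + c for p, c in zip(prompt_ids, completion_ids)]
    width = max((len(r) for r in rows), default=0)
    input_ids, attention_mask, labels = [], [], []
    for p, c, row in zip(prompt_ids, completion_ids, rows):
        pad = width - len(row)
        input_ids.append(row + [pad_token_id] * pad)
        attention_mask.append([1] * len(row) + [0] * pad)
        # Only completion tokens carry labels; prompt and padding are ignored.
        labels.append([IGNORE_INDEX] * len(p) + c + [IGNORE_INDEX] * pad)
    return input_ids, attention_mask, labels


input_ids, attention_mask, labels = pad_pairs(
    [[1, 2], [3]], [[4], [5, 6, 7]], pad_token_id=0
)
```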
- maxent_grpo.training.scoring_batching.iter_batch_slices_trl(score_batch, runtime, eos_token_id)
Yield prompt+completion slices for TRL-style logprob computation.
- Parameters:
score_batch (ScoreBatch)
runtime (RuntimeHandles)
eos_token_id (int | None)
- Return type:
Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]]
- maxent_grpo.training.scoring_batching.token_counts_from_score_batch(score_batch, runtime, batching_cfg)
Compute per-sequence token counts from the score batch labels mask.
- Parameters:
score_batch (ScoreBatch) – Prepared scoring batch.
runtime (RuntimeHandles) – Runtime handles exposing device/accelerator.
batching_cfg (BatchingSettings) – Batching config controlling slice sizes.
- Returns:
1D tensor of token counts per sequence.
- Return type:
Tensor
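Counting tokens from a labels mask, one slice at a time, might look like the sketch below. It assumes that ignored label positions hold -100 and that a single slice_size integer stands in for the batching settings; both names, like iter_slices and token_counts, are hypothetical, and lists replace tensors.

```python
from typing import Iterator, List

# Assumption: ignored label positions are marked with -100.
IGNORE_INDEX = -100


def iter_slices(rows: List[List[int]], slice_size: int) -> Iterator[List[List[int]]]:
    """Yield consecutive chunks of at most slice_size rows."""
    if slice_size <= 0:
        raise ValueError("slice_size must be positive")
    for start in range(0, len(rows), slice_size):
        yield rows[start:start + slice_size]


def token_counts(labels: List[List[int]], slice_size: int) -> List[int]:
    """Count non-ignored label positions per sequence, slice by slice."""
    counts: List[int] = []
    for chunk in iter_slices(labels, slice_size):
        counts.extend(sum(1 for t in row if t != IGNORE_INDEX) for row in chunk)
    return counts


labels = [[-100, -100, 4, -100], [-100, 5, 6, 7], [-100, -100, -100, -100]]
counts = token_counts(labels, slice_size=2)
```

The result has one entry per input row regardless of the slice size, matching the documented 1D shape.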
- maxent_grpo.training.scoring_batching.summarize_completion_lengths(ref_stats, max_completion_len)
Summarize completion lengths for metrics.
- Parameters:
ref_stats (ReferenceLogprobs) – Reference log-prob stats containing token counts.
max_completion_len (int) – Maximum completion length used for clipping stats.
- Returns:
Tuple of (completion_lengths, length_stats, total_tokens).
- Return type:
tuple[Tensor, LengthStats, float]
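A minimal sketch of a length summary with the same three-part return shape follows. The fields of LengthStats are not shown on this page, so LengthSummary and its fields are hypothetical stand-ins, and treating any count at or above max_completion_len as clipped is likewise an assumption.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class LengthSummary:
    """Hypothetical stand-in; the real LengthStats fields are not documented here."""
    min_length: float
    mean_length: float
    max_length: float
    clipped_fraction: float


def summarize_lengths(
    token_counts: List[int], max_completion_len: int
) -> Tuple[List[int], LengthSummary, float]:
    """Clip per-sequence counts and summarize them for metrics."""
    if not token_counts:
        return [], LengthSummary(0.0, 0.0, 0.0, 0.0), 0.0
    lengths = [min(n, max_completion_len) for n in token_counts]
    # Assumption: a sequence counts as clipped once it reaches the maximum.
    clipped = sum(1 for n in token_counts if n >= max_completion_len)
    stats = LengthSummary(
        min_length=float(min(lengths)),
        mean_length=sum(lengths) / len(lengths),
        max_length=float(max(lengths)),
        clipped_fraction=clipped / len(lengths),
    )
    return lengths, stats, float(sum(lengths))


lengths, stats, total = summarize_lengths([3, 8, 12], max_completion_len=8)
```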