# Source code for maxent_grpo.training.scoring_batching

# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

"""Batch construction and slice materialization helpers for scoring."""

from __future__ import annotations
# pylint: disable=broad-exception-caught

from dataclasses import dataclass as _dataclass
from functools import lru_cache
import numbers
import sys
from typing import Callable, Iterator, List, Optional, Tuple, cast

import numpy as np

from .scoring_common import (
    LOG,
    TorchDType,
    TorchDevice,
    _LongDTypeProxy,
    _SCORING_EXCEPTIONS,
    _TorchModuleLike,
    _maybe_long_tensor,
    _prefetch_iterator,
    _refresh_torch,
    _resolve_dtype,
    _to_numpy_array,
)
from .types import (
    BatchingSettings,
    GenerationSettings,
    LengthStats,
    PreTrainedTokenizer,
    PromptCacheEntry,
    ReferenceLogprobs,
    RewardComputation,
    RuntimeHandles,
    ScoreBatch,
    Tensor,
)

# Public alias for the ``dataclasses.dataclass`` decorator (imported above as
# ``_dataclass``); the ``@dataclass`` decorators in this module resolve via it.
dataclass = _dataclass
# Module-level torch handle resolved once at import time. ``_tokenize_completions``
# falls back to it when ``sys.modules`` has no "torch" entry; most functions in
# this module call ``_refresh_torch()`` again at call time instead.
torch = _refresh_torch()


@dataclass
class CompletionTensors:
    """Container pairing completion token IDs with their attention mask.

    Instances are produced by the completion-tokenization helpers in this
    module and consumed when building a ``ScoreBatch``.
    """

    # Completion token IDs.
    ids: Tensor
    # Attention mask aligned element-wise with ``ids``.
    mask: Tensor
@dataclass class _SliceState: """Cached tensors and metadata required for batch slicing.""" total_sequences: int slice_size: int completion_ids: Tensor completion_mask: Tensor prompt_entries: List[PromptCacheEntry] pad_token_id: int max_prompt_len: int score_tail_tokens: Optional[int] = None @classmethod def from_score_batch(cls, score_batch: ScoreBatch) -> "_SliceState": """Build a state snapshot derived from a score batch.""" total_sequences = score_batch.total_sequences if total_sequences == 0: slice_size = 0 else: slice_size = ( score_batch.slice_size if score_batch.slice_size > 0 else total_sequences ) return cls( total_sequences=total_sequences, slice_size=slice_size, completion_ids=score_batch.completion_ids, completion_mask=score_batch.completion_attention_mask, prompt_entries=score_batch.prompt_entries, pad_token_id=score_batch.pad_token_id, max_prompt_len=max(1, score_batch.max_prompt_len), score_tail_tokens=getattr(score_batch, "score_tail_tokens", None), ) @dataclass class _PromptCacheConfig: prompt_length_cache_get: Optional[Callable[[str], PromptCacheEntry]] prompt_cache_size: int = 0 def _collect_prompt_entries( prompt_batch: List[str], batching_cfg: _PromptCacheConfig, ) -> Optional[List[PromptCacheEntry]]: """Resolve cached prompt tokenization for a batch of strings. :param prompt_batch: Raw prompt strings to fetch from the cache. :type prompt_batch: list[str] :param batching_cfg: Prompt cache configuration containing the cache getter. :type batching_cfg: _PromptCacheConfig :returns: Cached prompt entries or ``None`` when the batch is empty. 
:rtype: list[PromptCacheEntry] | None """ cache_size = getattr(batching_cfg, "prompt_cache_size", 0) or 0 prompt_fn = getattr(batching_cfg, "prompt_length_cache_get", None) if cache_size > 0 and callable(prompt_fn): cached = getattr(batching_cfg, "_cached_prompt_lookup", None) underlying = getattr(batching_cfg, "_cached_prompt_source", None) if cached is None or underlying is not prompt_fn: cached = lru_cache(maxsize=cache_size)(prompt_fn) setattr(batching_cfg, "_cached_prompt_lookup", cached) setattr(batching_cfg, "_cached_prompt_source", prompt_fn) prompt_fn = cached if not callable(prompt_fn): return None prompt_entries = cast( List[PromptCacheEntry], [prompt_fn(prompt) for prompt in prompt_batch] ) if not prompt_entries: return None return prompt_entries def _tokenize_completions( completion_batch: List[str], tokenizer: PreTrainedTokenizer, generation_cfg: GenerationSettings, ) -> CompletionTensors: """Tokenize completions into padded tensors. :param completion_batch: Completion strings aligned with prompts. :type completion_batch: list[str] :param tokenizer: Tokenizer used to encode the completions. :type tokenizer: transformers.PreTrainedTokenizer :param generation_cfg: Generation settings (controls max length). :type generation_cfg: GenerationSettings :returns: Completion token IDs and attention masks. 
:rtype: CompletionTensors """ _refresh_torch() old_padding_side = getattr(tokenizer, "padding_side", None) try: try: if old_padding_side is not None: tokenizer.padding_side = "right" completion_enc = tokenizer( completion_batch, return_tensors="pt", padding=True, truncation=True, max_length=generation_cfg.max_completion_len, add_special_tokens=False, ) except TypeError: completion_enc = tokenizer(completion_batch) finally: if old_padding_side is not None: try: tokenizer.padding_side = old_padding_side except Exception: pass torch_mod = cast(_TorchModuleLike, sys.modules.get("torch", torch)) ids = _maybe_long_tensor(completion_enc["input_ids"], torch_mod) mask = _maybe_long_tensor(completion_enc["attention_mask"], torch_mod) return CompletionTensors( ids=ids, mask=mask, ) def _completion_tensors_from_token_ids( token_ids: List[List[int]], *, pad_token_id: int, max_length: int, ) -> CompletionTensors: """Build completion tensors from pre-tokenized token-id sequences.""" torch_mod = _refresh_torch() limit = int(max_length or 0) clipped: List[List[int]] = [] for seq in token_ids: seq_list = list(seq) if limit > 0: seq_list = seq_list[:limit] clipped.append(seq_list) max_len = max((len(seq) for seq in clipped), default=0) batch = len(clipped) ids_arr = np.full((batch, max_len), int(pad_token_id), dtype=np.int64) mask_arr = np.zeros((batch, max_len), dtype=np.int64) for row, seq in enumerate(clipped): if not seq: continue ids_arr[row, : len(seq)] = np.asarray(seq, dtype=np.int64) mask_arr[row, : len(seq)] = 1 ids = torch_mod.tensor(ids_arr, dtype=getattr(torch_mod, "long", None)) mask = torch_mod.tensor(mask_arr, dtype=getattr(torch_mod, "long", None)) return CompletionTensors(ids=ids, mask=mask) def _prepare_prompt_slice( prompt_slice: List[PromptCacheEntry], max_prompt_len: int, pad_token_id: int, ids_dtype: TorchDType, mask_dtype: TorchDType, ) -> Tuple[Tensor, Tensor, List[int]]: """Materialize prompt tensors for one scoring slice. 
:param prompt_slice: Cached prompt entries for the current slice. :type prompt_slice: list[PromptCacheEntry] :param max_prompt_len: Maximum prompt length to materialize. :type max_prompt_len: int :param pad_token_id: Token ID used to pad prompts. :type pad_token_id: int :param ids_dtype: Dtype for the generated ID tensor. :type ids_dtype: torch.dtype :param mask_dtype: Dtype for the generated attention mask tensor. :type mask_dtype: torch.dtype :returns: Tuple of (prompt_ids, prompt_mask, prompt_lengths). :rtype: tuple[Tensor, Tensor, list[int]] """ torch_mod = _refresh_torch() ids_dtype = getattr(ids_dtype, "np_dtype", ids_dtype) mask_dtype = getattr(mask_dtype, "np_dtype", mask_dtype) def _coerce_np_dtype(dtype: object) -> np.dtype | type[np.generic]: # Catch torch-style dtype strings/objects early. if isinstance(dtype, str) and dtype.startswith("torch"): return np.int64 dtype_str = str(dtype) if dtype_str.startswith("torch."): return np.int64 resolved = _resolve_dtype(dtype) if resolved is None: name_attr = getattr(dtype, "name", None) if isinstance(name_attr, str): try: resolved = np.dtype(name_attr) except (TypeError, ValueError): resolved = None if resolved is None: return np.int64 try: return np.dtype(resolved) except (TypeError, ValueError): return np.int64 ids_np_dtype = _coerce_np_dtype(ids_dtype) or np.int64 mask_np_dtype = _coerce_np_dtype(mask_dtype) or np.int64 prompt_lengths = [min(entry.length, max_prompt_len) for entry in prompt_slice] max_prompt_tokens = max(prompt_lengths) if prompt_lengths else 0 batch_size = len(prompt_slice) if max_prompt_tokens > 0: prompt_ids_arr = np.full( (batch_size, max_prompt_tokens), pad_token_id, dtype=ids_np_dtype, ) prompt_mask_arr = np.zeros( (batch_size, max_prompt_tokens), dtype=mask_np_dtype, ) for row, (entry, length) in enumerate(zip(prompt_slice, prompt_lengths)): if length == 0: continue start = max_prompt_tokens - length prompt_ids_arr[row, start : start + length] = entry.input_ids[:length] 
prompt_mask_arr[row, start : start + length] = entry.attention_mask[:length] def _safe_dtype(dtype: object) -> object | None: return ( None if ( isinstance(dtype, str) or str(dtype).startswith("torch.") or isinstance(dtype, np.dtype) ) else dtype ) tensor_ids_dtype = _safe_dtype(ids_dtype) tensor_mask_dtype = _safe_dtype(mask_dtype) prompt_ids = torch_mod.tensor(prompt_ids_arr, dtype=tensor_ids_dtype) prompt_mask = torch_mod.tensor(prompt_mask_arr, dtype=tensor_mask_dtype) else: # Avoid torch.empty here because minimal stubs may omit it. def _safe_dtype(dtype: object) -> object | None: return ( None if ( isinstance(dtype, str) or str(dtype).startswith("torch.") or isinstance(dtype, np.dtype) ) else dtype ) tensor_ids_dtype = _safe_dtype(ids_dtype) tensor_mask_dtype = _safe_dtype(mask_dtype) prompt_ids = torch_mod.zeros((batch_size, 0), dtype=tensor_ids_dtype) prompt_mask = torch_mod.zeros((batch_size, 0), dtype=tensor_mask_dtype) return prompt_ids, prompt_mask, prompt_lengths def _slice_tail_window( start_idx: int, input_ids: Tensor, attention_mask: Tensor, labels: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: """Slice tail tokens without closing over loop variables.""" if start_idx <= 0: return input_ids, attention_mask, labels return ( input_ids[:, start_idx:], attention_mask[:, start_idx:], labels[:, start_idx:], )
def iter_batch_slices(
    score_batch: ScoreBatch,
    device: TorchDevice,  # kept for API symmetry with callers
    *,
    eos_token_id: Optional[int] = None,
    apply_eos_mask: bool = False,
) -> Iterator[Tuple[Tensor, Tensor, Tensor]]:
    """Yield scoring slices for a batch, assembling prompt tensors on demand.

    Each slice concatenates the left-padded prompt tensors built by
    ``_prepare_prompt_slice`` with the matching completion tensors. Completion
    columns that are padding for every row are dropped first.

    Labels are a copy of the input ids with two regions set to ``-100``:

    * every prompt column (both prompt padding and prompt tokens);
    * every completion position whose mask is 0 (after the optional EOS
      masking).

    When ``score_batch.score_tail_tokens`` is a positive int smaller than the
    sequence width, each slice is trimmed to its trailing columns.

    :param score_batch: Prepared prompt/completion tensors and metadata.
    :type score_batch: ScoreBatch
    :param device: Device where tensors should be materialized.
    :type device: torch.device
    :param eos_token_id: Optional EOS token id for TRL-style completion masking.
    :type eos_token_id: int | None
    :param apply_eos_mask: When ``True``, apply EOS-aware completion masks.
    :type apply_eos_mask: bool
    :yields: Tuples of ``(input_ids, attention_mask, labels)`` per slice.
    :rtype: Iterator[tuple[Tensor, Tensor, Tensor]]
    :raises AttributeError: If the torch module exposes neither ``as_tensor``
        nor ``tensor``.
    """
    torch_mod = _refresh_torch()
    state = _SliceState.from_score_batch(score_batch)
    if state.total_sequences == 0 or state.slice_size <= 0:
        return
    as_tensor = getattr(torch_mod, "as_tensor", getattr(torch_mod, "tensor", None))
    if as_tensor is None:
        raise AttributeError("torch.as_tensor (or tensor) is required for scoring.")
    as_tensor_fn = cast(Callable[..., Tensor], as_tensor)

    def _ensure_tensor(obj: object, *, target_device: object | None = None) -> Tensor:
        """Best-effort conversion that tolerates numpy arrays/stubs."""
        is_tensor_fn = getattr(torch_mod, "is_tensor", None)
        try:
            if callable(is_tensor_fn) and is_tensor_fn(obj):
                return cast(Tensor, obj)
        except _SCORING_EXCEPTIONS as exc:  # pragma: no cover - defensive
            LOG.debug("torch.is_tensor check failed; continuing: %s", exc)
        tensor_type = getattr(torch_mod, "Tensor", None)
        if tensor_type is not None and isinstance(obj, tensor_type):
            return cast(Tensor, obj)
        tensor_ctor = getattr(torch_mod, "tensor", None)
        if callable(tensor_ctor):
            data = getattr(obj, "arr", None)
            if data is None:
                data = obj
            try:
                return cast(
                    Tensor,
                    tensor_ctor(
                        np.asarray(data),
                        device=target_device,
                        dtype=getattr(obj, "dtype", None),
                    ),
                )
            except TypeError:
                return cast(Tensor, tensor_ctor(np.asarray(data)))
        return cast(Tensor, obj)

    def as_tensor_typed(*args: object, **kwargs: object) -> Tensor:
        return cast(Tensor, as_tensor_fn(*args, **kwargs))

    for start in range(0, state.total_sequences, state.slice_size):
        end = min(start + state.slice_size, state.total_sequences)
        prompt_slice = state.prompt_entries[start:end]
        comp_ids_slice = state.completion_ids[start:end]
        comp_mask_slice = state.completion_mask[start:end]
        if device is not None:
            try:
                comp_ids_slice = comp_ids_slice.to(device)
                comp_mask_slice = comp_mask_slice.to(device)
            except (AttributeError, TypeError, ValueError) as exc:
                # Some lightweight torch stubs treat the ``device`` argument as
                # a dtype. When that happens we leave tensors on their current
                # device and rely on ``as_tensor`` below to normalize types.
                LOG.debug("Failed to move completion tensors to device: %s", exc)
        batch_size = len(prompt_slice)
        if batch_size == 0:
            continue
        prompt_ids, prompt_mask, _prompt_lengths = _prepare_prompt_slice(
            prompt_slice,
            state.max_prompt_len,
            state.pad_token_id,
            comp_ids_slice.dtype,
            comp_mask_slice.dtype,
        )
        if device is not None:
            try:
                prompt_ids = prompt_ids.to(device)
                prompt_mask = prompt_mask.to(device)
            except (AttributeError, TypeError, ValueError) as exc:
                LOG.debug("Failed to move prompt tensors to device: %s", exc)
        # Ensure tensors before concatenation (protect against stubs/numpy).
        prompt_ids = as_tensor_typed(prompt_ids, device=device)
        prompt_mask = as_tensor_typed(prompt_mask, device=device)
        comp_ids_slice = as_tensor_typed(comp_ids_slice, device=device)
        comp_mask_slice = as_tensor_typed(comp_mask_slice, device=device)
        if apply_eos_mask:
            completion_mask = _apply_eos_completion_mask(
                comp_ids_slice, eos_token_id, completion_mask=comp_mask_slice
            )
            try:
                completion_mask = completion_mask * comp_mask_slice
            except _SCORING_EXCEPTIONS:
                comp_arr = np.asarray(getattr(comp_mask_slice, "arr", comp_mask_slice))
                eos_arr = np.asarray(getattr(completion_mask, "arr", completion_mask))
                completion_mask = as_tensor_typed(
                    comp_arr * eos_arr, device=getattr(comp_mask_slice, "device", None)
                )
            comp_mask_slice = completion_mask

        # Drop completion columns that are padding for every sequence so tail-only
        # scoring keeps real tokens instead of global pad regions.
        comp_tokens_present = None
        active_comp_columns: List[int] = []
        try:
            comp_tokens_present = (comp_mask_slice != 0).any(dim=0)
        except _SCORING_EXCEPTIONS as exc:
            LOG.debug("Failed to compute completion token presence mask: %s", exc)
        if comp_tokens_present is None:
            comp_tokens_arr = (
                np.asarray(getattr(comp_mask_slice, "arr", comp_mask_slice)) != 0
            )
            col_activity = comp_tokens_arr.any(axis=0)
            active_idx_np = np.nonzero(col_activity)[0]
            active_comp_columns = [int(idx) for idx in active_idx_np.tolist()]
            last_valid_idx = active_comp_columns[-1] + 1 if active_comp_columns else 0
        else:
            try:
                nonzero_cols = torch_mod.nonzero(
                    comp_tokens_present, as_tuple=False
                ).view(-1)
                active_comp_columns = [int(idx.item()) for idx in nonzero_cols]
                last_valid_idx = (
                    active_comp_columns[-1] + 1 if active_comp_columns else 0
                )
            except _SCORING_EXCEPTIONS:
                active_comp_columns = []
                last_valid_idx = 0
        if last_valid_idx <= 0:
            comp_ids_slice = comp_ids_slice[:, :0]
            comp_mask_slice = comp_mask_slice[:, :0]
        elif last_valid_idx < getattr(comp_ids_slice, "shape", [0, 0])[1]:
            comp_ids_slice = comp_ids_slice[:, :last_valid_idx]
            comp_mask_slice = comp_mask_slice[:, :last_valid_idx]

        full_input_ids = torch_mod.cat([prompt_ids, comp_ids_slice], dim=1)
        full_attention_mask = torch_mod.cat([prompt_mask, comp_mask_slice], dim=1)
        prompt_width = (
            getattr(prompt_ids, "shape", [0, 0])[1] if prompt_ids is not None else 0
        )
        comp_width = (
            getattr(comp_ids_slice, "shape", [0, 0])[1]
            if comp_ids_slice is not None
            else 0
        )

        labels_tensor = full_input_ids.clone()
        # Prompts are left-padded (right-aligned) by ``_prepare_prompt_slice``,
        # so a row's real prompt tokens occupy the LAST ``plen`` prompt columns.
        # Mask the entire prompt region -- padding and prompt tokens alike.
        # Masking only ``[:plen]`` per row would leave part of every shorter
        # prompt labelled, inflating label-based token counts.
        for idx in range(batch_size):
            labels_tensor[idx, :prompt_width] = -100

        if comp_width > 0:
            comp_slice = slice(prompt_width, prompt_width + comp_width)
            comp_labels = labels_tensor[:, comp_slice]
            updated_labels = comp_labels
            pad_mask = None
            pad_mask_arr = None
            has_padding = False
            try:
                pad_mask = comp_mask_slice == 0
                has_padding = bool(pad_mask.any())
            except _SCORING_EXCEPTIONS:
                pad_mask = None
            if pad_mask is None:
                try:
                    pad_mask_arr = (
                        np.asarray(getattr(comp_mask_slice, "arr", comp_mask_slice))
                        == 0
                    )
                    has_padding = bool(pad_mask_arr.any())
                except _SCORING_EXCEPTIONS:
                    pad_mask_arr = None
                    has_padding = False
            if has_padding:
                try:
                    if pad_mask is None:
                        raise AttributeError
                    updated_labels = comp_labels.masked_fill(pad_mask, -100)
                except _SCORING_EXCEPTIONS:
                    if pad_mask_arr is None:
                        pad_mask_arr = (
                            np.asarray(
                                getattr(comp_mask_slice, "arr", comp_mask_slice)
                            )
                            == 0
                        )
                    comp_arr = np.asarray(getattr(comp_labels, "arr", comp_labels))
                    comp_arr[pad_mask_arr] = -100
                    updated_labels = as_tensor_typed(
                        comp_arr, device=getattr(labels_tensor, "device", None)
                    )
            labels_tensor[:, comp_slice] = updated_labels

        full_labels = labels_tensor
        input_ids = full_input_ids
        attention_mask = full_attention_mask
        tail_tokens = getattr(state, "score_tail_tokens", None)
        if tail_tokens is not None:
            try:
                tail_tokens = int(tail_tokens)
            except (TypeError, ValueError):
                tail_tokens = None
        if tail_tokens is not None and tail_tokens > 0:
            max_len = int(getattr(full_input_ids, "shape", [0, 0])[1] or 0)
            tail_tokens = min(tail_tokens, max_len) if max_len > 0 else 0
            if tail_tokens > 0 and tail_tokens < max_len:
                slice_start = max(0, max_len - tail_tokens)
                first_comp_global = None
                last_comp_global = None
                if active_comp_columns:
                    first_comp_global = prompt_width + active_comp_columns[0]
                    last_comp_global = prompt_width + active_comp_columns[-1] + 1
                input_ids, attention_mask, labels_tensor = _slice_tail_window(
                    slice_start,
                    full_input_ids,
                    full_attention_mask,
                    full_labels,
                )
                # Safety net: never let the tail window start past the last
                # active completion column.
                if (
                    first_comp_global is not None
                    and last_comp_global is not None
                    and slice_start >= last_comp_global
                ):
                    safe_start = max(first_comp_global, last_comp_global - tail_tokens)
                    input_ids, attention_mask, labels_tensor = _slice_tail_window(
                        safe_start,
                        full_input_ids,
                        full_attention_mask,
                        full_labels,
                    )

        # Materialize tensors for the ref model; keep device parity with completions.
        target_device = device or getattr(comp_ids_slice, "device", None)
        target_dtype = getattr(torch_mod, "long", None)
        input_ids_out = as_tensor_typed(
            input_ids, device=target_device, dtype=target_dtype
        )
        attention_mask_out = as_tensor_typed(
            attention_mask, device=target_device, dtype=target_dtype
        )
        labels_out = as_tensor_typed(
            labels_tensor, device=target_device, dtype=target_dtype
        )
        input_ids_out = _ensure_tensor(input_ids_out, target_device=target_device)
        attention_mask_out = _ensure_tensor(
            attention_mask_out, target_device=target_device
        )
        labels_out = _ensure_tensor(labels_out, target_device=target_device)

        # Some lightweight stubs do not preserve the requested dtype object on
        # the Tensor wrapper (they store only the underlying numpy dtype). For
        # tests that compare against ``torch.long`` we make a best-effort pass
        # to align the exposed ``dtype`` attribute with the module constant.
        long_dtype = getattr(torch_mod, "long", None)
        if long_dtype is not None:
            proxy = _LongDTypeProxy(long_dtype)
            for _tensor in (input_ids_out, labels_out):
                try:  # pragma: no cover - exercised via stubbed environments
                    setattr(_tensor, "dtype", proxy)
                except _SCORING_EXCEPTIONS as exc:
                    LOG.debug("Unable to patch tensor dtype proxy: %s", exc)
        yield (input_ids_out, attention_mask_out, labels_out)
def build_score_batch(
    reward_comp: RewardComputation,
    tokenizer: PreTrainedTokenizer,
    generation_cfg: GenerationSettings,
    batching_cfg: BatchingSettings,
) -> Optional[ScoreBatch]:
    """Tokenize prompt+completion pairs and prepare masks/labels.

    Prompt tokenization comes from ``batching_cfg.prompt_length_cache_get``.
    If that attribute is missing, a callable ``batching_cfg`` itself is used;
    otherwise an empty-entry stub is used.

    Completion token ids are read from ``reward_comp.completion_metadata[i]
    ["token_ids"]`` when every entry provides them. Otherwise the completions
    are re-tokenized.

    :param reward_comp: Reward computation payload containing prompts and completions.
    :param tokenizer: Tokenizer used to encode completions and determine padding.
    :param generation_cfg: Generation settings (max lengths, etc.).
    :param batching_cfg: Batching settings controlling scoring slice sizes.
    :returns: Prepared ``ScoreBatch`` or ``None`` when no sequences are available.
    :rtype: ScoreBatch | None
    """
    # NOTE(review): falls back to the completions list when ``pairs`` has no
    # ``prompts`` attribute -- confirm this fallback is intended.
    prompt_batch = getattr(reward_comp.pairs, "prompts", reward_comp.pairs.completions)
    completion_batch = reward_comp.pairs.completions
    total_sequences = len(prompt_batch)
    if total_sequences == 0:
        return None

    def _coerce_int(value: object, default: int = 0) -> int:
        """Best-effort ``int`` conversion returning ``default`` on failure."""
        if isinstance(value, numbers.Integral):
            return int(value)
        if isinstance(value, numbers.Real):
            try:
                return int(float(value))
            except (TypeError, ValueError, OverflowError):
                # NaN raises ValueError; +/-inf raises OverflowError.
                return default
        if isinstance(value, (str, bytes, bytearray)):
            try:
                return int(value)
            except (TypeError, ValueError):
                return default
        try:
            return int(str(value))
        except (TypeError, ValueError):
            return default

    def _resolve_pad_token_id() -> int:
        """Pick a pad id: pad -> eos -> 0.

        When the tokenizer exposes an integral ``vocab_size`` and the chosen id
        is ``>= vocab_size``, fall back to eos (or ``vocab_size - 1``).
        """
        pad_token_raw = tokenizer.pad_token_id
        if pad_token_raw is None:
            pad_token_raw = tokenizer.eos_token_id or 0
        pad_id = _coerce_int(pad_token_raw, 0)
        vocab_size = getattr(tokenizer, "vocab_size", None)
        if isinstance(vocab_size, numbers.Integral):
            vocab_size_int = int(vocab_size)
            if pad_id >= vocab_size_int:
                fallback = tokenizer.eos_token_id
                if fallback is None:
                    fallback = vocab_size_int - 1
                pad_id = _coerce_int(fallback, 0)
        return pad_id

    def _metadata_token_ids() -> Optional[List[List[int]]]:
        """Return per-completion token ids from metadata, or ``None``.

        Returns ``None`` if the metadata is missing, misaligned with the
        completions, or any entry lacks a usable ``token_ids`` list.
        """
        completion_meta = getattr(reward_comp, "completion_metadata", None)
        if not (
            completion_meta
            and isinstance(completion_meta, list)
            and len(completion_meta) == len(completion_batch)
        ):
            return None
        token_ids: List[List[int]] = []
        for entry in completion_meta:
            if not isinstance(entry, dict):
                return None
            raw_ids = entry.get("token_ids")
            if raw_ids is None:
                return None
            if hasattr(raw_ids, "tolist"):
                try:
                    raw_ids = raw_ids.tolist()
                except _SCORING_EXCEPTIONS as exc:
                    LOG.debug(
                        "Failed to convert completion metadata token_ids: %s", exc
                    )
            # A nested list (e.g. a (1, seq) array) is reduced to its first row.
            if isinstance(raw_ids, list) and raw_ids and isinstance(raw_ids[0], list):
                raw_ids = raw_ids[0]
            if not isinstance(raw_ids, list):
                return None
            try:
                token_ids.append([_coerce_int(val) for val in raw_ids])
            except (TypeError, ValueError):
                return None
        return token_ids

    prompt_length_cache_fn: Callable[[str], PromptCacheEntry]
    prompt_length_cache = getattr(batching_cfg, "prompt_length_cache_get", None)
    if not callable(prompt_length_cache) and callable(batching_cfg):
        prompt_length_cache = batching_cfg
    if callable(prompt_length_cache):
        prompt_length_cache_fn = cast(
            Callable[[str], PromptCacheEntry], prompt_length_cache
        )
    else:

        def _default_prompt_length_cache(_p: str) -> PromptCacheEntry:
            return PromptCacheEntry(input_ids=[], attention_mask=[])

        prompt_length_cache_fn = _default_prompt_length_cache
    cache_cfg = _PromptCacheConfig(
        prompt_length_cache_get=prompt_length_cache_fn,
        prompt_cache_size=int(getattr(batching_cfg, "prompt_cache_size", 0) or 0),
    )
    prompt_entries = _collect_prompt_entries(prompt_batch, cache_cfg)
    if prompt_entries is None:
        return None

    meta_token_ids = _metadata_token_ids()
    # Resolved once and shared by the pre-tokenized path and the ScoreBatch.
    pad_token_id = _resolve_pad_token_id()

    completion_tensors: Optional[CompletionTensors] = None
    if meta_token_ids is not None:
        completion_tensors = _completion_tensors_from_token_ids(
            meta_token_ids,
            pad_token_id=pad_token_id,
            max_length=_coerce_int(getattr(generation_cfg, "max_completion_len", 0), 0),
        )
        LOG.debug(
            "Using pre-tokenized completion token_ids from completion_metadata | sequences=%d",
            len(meta_token_ids),
        )
    if completion_tensors is None:
        completion_tensors = _tokenize_completions(
            completion_batch,
            tokenizer,
            generation_cfg,
        )

    slice_size = (
        batching_cfg.score_slice if batching_cfg.score_slice > 0 else total_sequences
    )
    slice_size = max(1, slice_size)
    return ScoreBatch(
        prompt_entries=prompt_entries,
        completion_ids=completion_tensors.ids,
        completion_attention_mask=completion_tensors.mask,
        pad_token_id=pad_token_id,
        max_prompt_len=generation_cfg.max_prompt_len,
        slice_size=slice_size,
        total_sequences=total_sequences,
        score_tail_tokens=getattr(batching_cfg, "score_tail_tokens", None),
    )
def _apply_eos_completion_mask(
    completion_ids: Tensor,
    eos_token_id: Optional[int],
    completion_mask: Optional[Tensor] = None,
) -> Tensor:
    """Build a TRL-style mask keeping tokens up to and including the first EOS.

    Without an ``eos_token_id``, ``completion_mask`` is returned unchanged, or
    an all-ones long mask when it is ``None``.

    Otherwise the mask is derived from ``completion_ids`` alone and
    ``completion_mask`` is ignored. Torch ops are tried first; if they raise one
    of ``_SCORING_EXCEPTIONS``, an equivalent numpy loop is used instead. The
    torch path assumes 2-D ``(batch, seq_len)`` ids.

    :param completion_ids: Completion token ids.
    :param eos_token_id: EOS token id, or ``None`` to skip EOS masking.
    :param completion_mask: Existing mask returned as-is when no EOS id is given.
    :returns: Long 0/1 mask (1 = keep).
    """
    torch_mod = _refresh_torch()
    long_dtype = getattr(torch_mod, "long", None)

    if eos_token_id is None:
        if completion_mask is None:
            return cast(
                Tensor,
                torch_mod.ones_like(completion_ids, dtype=long_dtype),
            )
        return completion_mask

    def _mask_with_torch() -> Tensor:
        """Vectorized mask: position <= index of first EOS (or row length)."""
        eos_hits = completion_ids == eos_token_id
        n_rows = int(eos_hits.size(0))
        n_cols = int(eos_hits.size(1))
        # Rows without EOS keep every column (cutoff = n_cols).
        cutoff = torch_mod.full(
            (n_rows,),
            n_cols,
            dtype=long_dtype,
            device=getattr(completion_ids, "device", None),
        )
        rows_with_eos = eos_hits.any(dim=1)
        if bool(rows_with_eos.any()):
            first_eos = eos_hits.int().argmax(dim=1)
            cutoff = cutoff.clone()
            cutoff[rows_with_eos] = first_eos[rows_with_eos]
        positions = torch_mod.arange(
            n_cols, device=getattr(completion_ids, "device", None)
        ).unsqueeze(0)
        positions = positions.expand(n_rows, -1)
        keep = positions <= cutoff.unsqueeze(1)
        converter = getattr(keep, "to", None)
        if callable(converter):
            keep = converter(dtype=long_dtype)
        return cast(Tensor, keep)

    def _mask_with_numpy() -> Tensor:
        """Row-by-row fallback zeroing every column after the first EOS."""
        ids_arr = _to_numpy_array(completion_ids)
        keep_arr = np.ones_like(ids_arr, dtype=np.int64)
        for row_idx, row in enumerate(ids_arr):
            hits = np.where(row == eos_token_id)[0]
            if not hits.size:
                continue
            stop = int(hits[0]) + 1
            if stop < keep_arr.shape[1]:
                keep_arr[row_idx, stop:] = 0
        return cast(
            Tensor,
            torch_mod.tensor(
                keep_arr,
                dtype=long_dtype,
                device=getattr(completion_ids, "device", None),
            ),
        )

    try:
        return _mask_with_torch()
    except _SCORING_EXCEPTIONS:
        return _mask_with_numpy()
def iter_batch_slices_trl(
    score_batch: ScoreBatch,
    runtime: RuntimeHandles,
    eos_token_id: Optional[int],
) -> Iterator[Tuple[Tensor, Tensor, Tensor, int]]:
    """Yield prompt+completion slices for TRL-style logprob computation.

    Prompts are materialized with ``_prepare_prompt_slice`` and concatenated
    with the completion ids. The completion mask is the EOS mask from
    ``_apply_eos_completion_mask`` multiplied by the completion padding mask.

    :param score_batch: Prepared prompt/completion tensors and metadata.
    :param runtime: Runtime handles; only ``device`` is read (may be absent).
    :param eos_token_id: EOS token id used for completion masking, or ``None``.
    :yields: ``(input_ids, attention_mask, completion_mask, logits_to_keep)``
        per slice, where ``logits_to_keep`` is the completion width.
    :raises AttributeError: If the torch module exposes neither ``as_tensor``
        nor ``tensor``.
    """
    torch_mod = _refresh_torch()
    state = _SliceState.from_score_batch(score_batch)
    if state.total_sequences == 0 or state.slice_size <= 0:
        return
    device = getattr(runtime, "device", None)
    as_tensor = getattr(torch_mod, "as_tensor", getattr(torch_mod, "tensor", None))
    if as_tensor is None:
        raise AttributeError("torch.as_tensor (or tensor) is required for scoring.")
    as_tensor_fn = cast(Callable[..., Tensor], as_tensor)

    def _ensure_tensor(obj: object, *, target_device: object | None = None) -> Tensor:
        """Best-effort conversion that tolerates numpy arrays/stubs."""
        is_tensor_fn = getattr(torch_mod, "is_tensor", None)
        try:
            if callable(is_tensor_fn) and is_tensor_fn(obj):
                return cast(Tensor, obj)
        except _SCORING_EXCEPTIONS as exc:  # pragma: no cover - defensive
            LOG.debug("torch.is_tensor check failed; continuing: %s", exc)
        tensor_type = getattr(torch_mod, "Tensor", None)
        if tensor_type is not None and isinstance(obj, tensor_type):
            return cast(Tensor, obj)
        tensor_ctor = getattr(torch_mod, "tensor", None)
        if callable(tensor_ctor):
            data = getattr(obj, "arr", None)
            if data is None:
                data = obj
            try:
                return cast(
                    Tensor,
                    tensor_ctor(
                        np.asarray(data),
                        device=target_device,
                        dtype=getattr(obj, "dtype", None),
                    ),
                )
            except TypeError:
                return cast(Tensor, tensor_ctor(np.asarray(data)))
        return cast(Tensor, obj)

    def as_tensor_typed(*args: object, **kwargs: object) -> Tensor:
        return cast(Tensor, as_tensor_fn(*args, **kwargs))

    for start in range(0, state.total_sequences, state.slice_size):
        end = min(start + state.slice_size, state.total_sequences)
        prompt_slice = state.prompt_entries[start:end]
        comp_ids_slice = state.completion_ids[start:end]
        comp_mask_slice = state.completion_mask[start:end]
        if device is not None:
            try:
                comp_ids_slice = comp_ids_slice.to(device)
                comp_mask_slice = comp_mask_slice.to(device)
            except (AttributeError, TypeError, ValueError) as exc:
                # Lightweight stubs may reject ``.to(device)``; keep the tensors
                # where they are and let ``as_tensor`` normalize them below.
                LOG.debug("Failed to move completion tensors to device: %s", exc)
        batch_size = len(prompt_slice)
        if batch_size == 0:
            continue
        prompt_ids, prompt_mask, _ = _prepare_prompt_slice(
            prompt_slice,
            state.max_prompt_len,
            state.pad_token_id,
            comp_ids_slice.dtype,
            comp_mask_slice.dtype,
        )
        if device is not None:
            try:
                prompt_ids = prompt_ids.to(device)
                prompt_mask = prompt_mask.to(device)
            except (AttributeError, TypeError, ValueError) as exc:
                LOG.debug("Failed to move prompt tensors to device: %s", exc)
        # Ensure tensors before concatenation (protect against stubs/numpy).
        prompt_ids = as_tensor_typed(prompt_ids, device=device)
        prompt_mask = as_tensor_typed(prompt_mask, device=device)
        comp_ids_slice = as_tensor_typed(comp_ids_slice, device=device)
        comp_mask_slice = as_tensor_typed(comp_mask_slice, device=device)
        full_input_ids = torch_mod.cat([prompt_ids, comp_ids_slice], dim=1)
        # EOS mask is computed from ids alone, then combined with padding mask.
        completion_mask = _apply_eos_completion_mask(
            comp_ids_slice, eos_token_id, completion_mask=None
        )
        completion_mask = _ensure_tensor(
            completion_mask, target_device=getattr(comp_ids_slice, "device", None)
        )
        try:
            completion_mask = completion_mask * comp_mask_slice
        except _SCORING_EXCEPTIONS:
            comp_arr = np.asarray(getattr(comp_mask_slice, "arr", comp_mask_slice))
            eos_arr = np.asarray(getattr(completion_mask, "arr", completion_mask))
            completion_mask = as_tensor_typed(
                comp_arr * eos_arr, device=getattr(comp_mask_slice, "device", None)
            )
        full_attention_mask = torch_mod.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = int(getattr(comp_ids_slice, "shape", [0, 0])[1] or 0)
        yield full_input_ids, full_attention_mask, completion_mask, logits_to_keep
def token_counts_from_score_batch(
    score_batch: ScoreBatch,
    runtime: RuntimeHandles,
    batching_cfg: BatchingSettings,
) -> Tensor:
    """Count labelled tokens (``labels != -100``) for every sequence in a batch.

    Slices come from ``iter_batch_slices`` with EOS masking enabled, using the
    runtime tokenizer's ``eos_token_id`` when one is available. Each count is
    clamped to a minimum of 1 and cast to float32.

    :param score_batch: Prepared scoring batch.
    :param runtime: Runtime handles; ``device`` and the optional ``tokenizer``
        are read.
    :param batching_cfg: Batching config; ``slice_prefetch`` controls prefetching.
    :returns: 1D tensor of token counts per sequence (empty when no slices).
    :rtype: Tensor
    """
    torch_mod = _refresh_torch()
    eos_id = getattr(getattr(runtime, "tokenizer", None), "eos_token_id", None)
    target_device = runtime.device
    float32 = getattr(torch_mod, "float32", None)

    slices = _prefetch_iterator(
        iter_batch_slices(
            score_batch,
            target_device,
            eos_token_id=eos_id,
            apply_eos_mask=True,
        ),
        getattr(batching_cfg, "slice_prefetch", 0),
    )

    per_slice_counts: List[Tensor] = []
    for _ids, _mask, slice_labels in slices:
        counts = (slice_labels != -100).sum(dim=1).clamp(min=1)
        converter = getattr(counts, "to", None)
        if callable(converter):
            counts = converter(dtype=float32)
        per_slice_counts.append(cast(Tensor, counts))

    if not per_slice_counts:
        try:
            return cast(
                Tensor,
                torch_mod.zeros((0,), dtype=float32, device=target_device),
            )
        except _SCORING_EXCEPTIONS:
            return cast(Tensor, torch_mod.tensor([], dtype=float32))

    try:
        merged = torch_mod.cat(per_slice_counts, dim=0)
    except _SCORING_EXCEPTIONS:
        # Some stubs only concatenate pairs; fold the chunks one at a time.
        merged = per_slice_counts[0]
        for extra in per_slice_counts[1:]:
            merged = torch_mod.cat([merged, extra], dim=0)

    mover = getattr(merged, "to", None)
    if not callable(mover):
        return cast(Tensor, merged)
    try:
        return cast(Tensor, mover(device=target_device))
    except _SCORING_EXCEPTIONS as exc:
        LOG.debug("Failed to move token counts to runtime device: %s", exc)
        return cast(Tensor, merged)
def summarize_completion_lengths(
    ref_stats: ReferenceLogprobs,
    max_completion_len: int,
) -> Tuple[Tensor, LengthStats, float]:
    """Summarize completion lengths for metrics.

    A completion counts as "clipped" when its length is ``>= max_completion_len``.
    The remaining completions are "terminated". Every statistic over an empty
    set is reported as ``0.0``.

    :param ref_stats: Reference log-prob stats containing token counts
        (``ref_tok_counts``).
    :param max_completion_len: Maximum completion length used for clipping stats.
    :returns: Tuple of ``(completion_lengths, length_stats, total_tokens)``, where
        ``completion_lengths`` is a float32 tensor of the per-sequence lengths.
    :rtype: tuple[Tensor, LengthStats, float]
    """
    torch_mod = _refresh_torch()
    lengths = _to_numpy_array(ref_stats.ref_tok_counts).astype(float)
    has_lengths = bool(lengths.size)
    total_tokens = float(lengths.sum()) if has_lengths else 0.0
    hit_limit = lengths >= float(max_completion_len)

    def _min_mean_max(values: np.ndarray) -> Tuple[float, float, float]:
        """Return ``(min, mean, max)`` as floats, or zeros for an empty array."""
        if values.size == 0:
            return 0.0, 0.0, 0.0
        return float(values.min()), float(values.mean()), float(values.max())

    min_length, mean_length, max_length = _min_mean_max(lengths)
    clipped_ratio = float(hit_limit.mean()) if has_lengths else 0.0

    terminated = lengths[~hit_limit] if has_lengths else np.asarray([])
    min_terminated, mean_terminated, max_terminated = _min_mean_max(terminated)

    length_stats = LengthStats(
        min_length=min_length,
        mean_length=mean_length,
        max_length=max_length,
        clipped_ratio=clipped_ratio,
        min_terminated=min_terminated,
        mean_terminated=mean_terminated,
        max_terminated=max_terminated,
    )
    lengths_tensor = cast(
        Tensor,
        torch_mod.tensor(lengths, dtype=getattr(torch_mod, "float32", None)),
    )
    return lengths_tensor, length_stats, total_tokens