"""
Shared generation utilities used by the training stack.
This module contains the small, dependency-light helpers for grouping,
retrying, and trimming completions so higher layers can import a single source
of truth instead of maintaining divergent copies.
"""
from __future__ import annotations
# pylint: disable=broad-exception-caught
from dataclasses import dataclass
import logging
from typing import Callable, Dict, List, Optional, Protocol, Tuple, runtime_checkable
torch = None  # Placeholder; rebound below when the optional import succeeds.
try:  # Optional: keep module importable in minimal environments.
    import torch
except Exception:  # pragma: no cover - torch optional in some utilities
    pass  # ``torch`` keeps its ``None`` placeholder.

# Retry budget used by ``determine_retry_limit`` when neither an explicit
# retry cap nor a positive expected-generation count is supplied.
_DEFAULT_RETRY_LIMIT = 3
LOG = logging.getLogger(__name__)
@runtime_checkable
class _TrlPayloadConvertible(Protocol):
    """Structural type for objects exposing a ``to_trl_payload()`` method.

    Because the protocol is ``@runtime_checkable`` it can be used with
    ``isinstance``; such checks only verify that the method is present, not
    its signature or return type.
    """

    def to_trl_payload(self) -> object:
        """Return the payload form of this object (name suggests it feeds TRL)."""
# A single metadata entry is kept opaque: any object is accepted.
MetadataEntry = object
# A flat (single-level) list of optional metadata entries.
FlatMetadata = List[Optional[MetadataEntry]]
# Metadata for one prompt: one optional entry per completion of that prompt.
MetadataGroup = List[Optional[MetadataEntry]]
# Metadata for a whole batch, one group per prompt (parallel to completions).
GroupedMetadata = List[MetadataGroup]
@dataclass
class AggregatedGenerationState:
    """Bundle per-prompt completion groups with their optional metadata.

    The container is a plain mutable dataclass; helpers such as
    ``retry_incomplete_prompts`` update its fields in place.

    :param completions: One inner list of completion strings per prompt index.
    :type completions: list[list[str]]
    :param metadata: Optional metadata groups parallel to ``completions``;
        ``None`` when metadata is not tracked.
    :type metadata: list[list[object | None]] | None
    """

    # Completion strings grouped by prompt index.
    completions: List[List[str]]
    # Metadata grouped like ``completions``; ``None`` disables tracking.
    metadata: Optional[GroupedMetadata] = None
def append_completion_group(
    grouped_comps: List[List[str]],
    grouped_meta: Optional[GroupedMetadata],
    prompt_idx: int,
    completions: Optional[List[str]],
    meta_group: Optional[MetadataGroup],
) -> Optional[GroupedMetadata]:
    """Append completions (and metadata) for a specific prompt index.

    Completions are appended to ``grouped_comps[prompt_idx]`` in place. When
    metadata is tracked, the metadata slots covering the newly appended
    completions in ``grouped_meta[prompt_idx]`` are written (padding the group
    with ``None`` first if it is too short), so metadata stays aligned with
    completions on both the "with metadata" and "without metadata" paths.
    A metadata buffer is created on demand the first time ``meta_group`` is
    provided while ``grouped_meta`` is ``None``.

    :param grouped_comps: Existing grouped completions buffer (mutated).
    :type grouped_comps: list[list[str]]
    :param grouped_meta: Existing grouped metadata buffer (mutated); can be
        ``None`` if metadata is not tracked. If it has fewer groups than
        ``grouped_comps``, ``None``-filled groups are appended as needed.
    :type grouped_meta: list[list[object | None]] | None
    :param prompt_idx: Index of the prompt whose completions are being appended.
    :type prompt_idx: int
    :param completions: New completions to append; empty or ``None`` is a no-op.
    :type completions: list[str] | None
    :param meta_group: Metadata aligned to ``completions``. Excess entries are
        trimmed and missing entries are padded with ``None``. ``None`` means
        no metadata accompanies these completions.
    :type meta_group: list[object | None] | None
    :returns: Updated grouped metadata (may be newly created), or ``None`` when
        metadata tracking is disabled.
    :rtype: list[list[object | None]] | None
    :raises IndexError: If ``prompt_idx`` is out of range for ``grouped_comps``.
    """
    if not completions:
        return grouped_meta
    entries = list(completions)
    start = len(grouped_comps[prompt_idx])
    end = start + len(entries)
    grouped_comps[prompt_idx].extend(entries)

    if meta_group is None:
        if grouped_meta is None:
            # Metadata is not tracked at all; nothing to align.
            return None
        meta_entries: MetadataGroup = [None] * len(entries)
    else:
        if grouped_meta is None:
            # First metadata seen: build a buffer aligned to every group. The
            # completions appended above are already counted in ``group``.
            grouped_meta = [[None] * len(group) for group in grouped_comps]
        meta_entries = list(meta_group)[: len(entries)]
        meta_entries.extend([None] * (len(entries) - len(meta_entries)))

    # Tolerate a metadata buffer that has fewer groups than completions.
    while len(grouped_meta) <= prompt_idx:
        grouped_meta.append([None] * len(grouped_comps[len(grouped_meta)]))

    current_meta = grouped_meta[prompt_idx]
    if len(current_meta) < end:
        current_meta.extend([None] * (end - len(current_meta)))
    # Write into the new completions' slots rather than blindly appending, so
    # a previously misaligned group is realigned on both code paths.
    current_meta[start:end] = meta_entries
    return grouped_meta
def seed_generation_groups(
    prompt_count: int,
    grouped_comps: Optional[List[List[str]]],
    grouped_meta: Optional[GroupedMetadata],
) -> Tuple[List[List[str]], Optional[GroupedMetadata]]:
    """Build fresh completion/metadata buffers with one slot per prompt.

    Every prompt position in ``range(prompt_count)`` gets a new completion
    list seeded with a copy of any prior completions for that position.
    Metadata is merged through ``append_completion_group``, so the metadata
    buffer stays ``None`` until some prompt with completions supplies a
    metadata group. Input buffers are not modified.

    :param prompt_count: Number of prompts that will be processed.
    :type prompt_count: int
    :param grouped_comps: Optional prior completions grouped per prompt;
        missing, ``None`` or empty groups seed an empty list.
    :type grouped_comps: list[list[str]] | None
    :param grouped_meta: Optional prior metadata grouped per prompt.
    :type grouped_meta: list[list[object | None]] | None
    :returns: ``(completions, metadata)`` where ``completions`` has exactly
        ``prompt_count`` groups and ``metadata`` is aligned with it or ``None``.
    :rtype: tuple[list[list[str]], list[list[object | None]] | None]
    """
    seeded_comps: List[List[str]] = [[] for _ in range(prompt_count)]
    seeded_meta: Optional[GroupedMetadata] = None
    prior_comps = grouped_comps or []
    prior_meta = [] if grouped_meta is None else grouped_meta
    for position in range(prompt_count):
        existing = prior_comps[position] if position < len(prior_comps) else None
        meta_slice = prior_meta[position] if position < len(prior_meta) else None
        seeded_meta = append_completion_group(
            seeded_comps,
            seeded_meta,
            position,
            list(existing) if existing else [],
            meta_slice,
        )
    return seeded_comps, seeded_meta
def pending_generation_indices(
    aggregated_comps: List[List[str]],
    expected_generations: int,
) -> List[int]:
    """List the prompt positions that have fewer completions than requested.

    :param aggregated_comps: Completions grouped per prompt.
    :type aggregated_comps: list[list[str]]
    :param expected_generations: Desired number of completions per prompt;
        values ``<= 0`` mean nothing is pending.
    :type expected_generations: int
    :returns: Ascending indices whose group size is below
        ``expected_generations``.
    :rtype: list[int]
    """
    if expected_generations <= 0:
        return []
    group_sizes = map(len, aggregated_comps)
    return [
        position
        for position, size in enumerate(group_sizes)
        if size < expected_generations
    ]
def determine_retry_limit(
    expected_generations: int,
    max_retry_rounds: Optional[int],
) -> int:
    """Pick how many retry rounds a batch may use.

    Precedence: a positive ``max_retry_rounds`` wins; otherwise a positive
    ``expected_generations`` is used; otherwise ``_DEFAULT_RETRY_LIMIT``.

    :param expected_generations: Desired completions per prompt, used as the
        fallback retry budget.
    :type expected_generations: int
    :param max_retry_rounds: Explicit retry cap; ignored when ``None`` or
        not positive.
    :type max_retry_rounds: int | None
    :returns: The selected retry limit.
    :rtype: int
    """
    explicit_cap = max_retry_rounds or 0
    if explicit_cap > 0:
        return explicit_cap
    return expected_generations if expected_generations > 0 else _DEFAULT_RETRY_LIMIT
def _any_rank_has_pending(local_pending: bool) -> bool:
    """Return whether any rank still has prompts awaiting completions.

    Returns ``local_pending`` unchanged when torch is missing or
    ``torch.distributed`` is not both available and initialized. Otherwise
    every rank contributes a 0/1 flag to a MAX ``all_reduce``. This is a
    collective call, so all ranks must reach it. The flag tensor lives on CUDA
    only when CUDA is available and the backend is ``"nccl"``; otherwise it
    lives on the CPU.
    """
    if torch is None:
        return local_pending
    dist = getattr(torch, "distributed", None)
    if dist is None:
        return local_pending
    is_available = getattr(dist, "is_available", None)
    is_initialized = getattr(dist, "is_initialized", None)
    if not (callable(is_available) and callable(is_initialized)):
        return local_pending
    if not is_available() or not is_initialized():
        return local_pending
    try:
        backend = dist.get_backend()
    except Exception:
        backend = ""
    cuda_ready = bool(getattr(torch, "cuda", None)) and torch.cuda.is_available()
    target = "cuda" if cuda_ready and backend == "nccl" else "cpu"
    flag = torch.tensor(
        [1 if local_pending else 0],
        device=torch.device(target),
        dtype=torch.int32,
    )
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())


def retry_incomplete_prompts(
    prompts: List[str],
    generator: Callable[
        [List[str], int, Optional[List[int]]],
        Tuple[List[List[str]], Optional[GroupedMetadata]],
    ],
    expected_generations: int,
    aggregated: AggregatedGenerationState,
    max_retry_rounds: Optional[int],
) -> AggregatedGenerationState:
    """Re-run generation for prompts that are still short of completions.

    Each round calls ``generator(batch_prompts, expected_generations,
    deficits)``. ``batch_prompts`` holds the still-pending prompts and
    ``deficits`` gives how many completions each one is missing. The returned
    groups (and metadata, if a list is returned) are merged into
    ``aggregated`` in place. Rounds stop once the retry limit from
    ``determine_retry_limit`` is used up, or once no rank has pending prompts.
    With ``torch.distributed`` initialized, a rank with nothing pending still
    calls ``generator`` (with empty lists) while other ranks keep retrying.
    That keeps the per-round ``all_reduce`` calls matched across ranks.

    :param prompts: Original prompt strings.
    :type prompts: list[str]
    :param generator: Callable performing generation for a batch of prompts,
        returning grouped completions and optional grouped metadata.
    :type generator: Callable[[list[str], int, list[int] | None], tuple[list[list[str]], list[list[object | None]] | None]]
    :param expected_generations: Number of completions requested per prompt.
    :type expected_generations: int
    :param aggregated: State whose ``completions``/``metadata`` are updated.
    :type aggregated: AggregatedGenerationState
    :param max_retry_rounds: Explicit retry cap; defaults derive from
        ``expected_generations`` when omitted.
    :type max_retry_rounds: int | None
    :returns: The same ``aggregated`` object after retries finish.
    :rtype: AggregatedGenerationState
    """
    pending = pending_generation_indices(aggregated.completions, expected_generations)
    round_budget = determine_retry_limit(expected_generations, max_retry_rounds)
    rounds_used = 0
    while rounds_used < round_budget and _any_rank_has_pending(bool(pending)):
        rounds_used += 1
        batch_prompts = [prompts[idx] for idx in pending]
        deficits = [
            max(expected_generations - len(aggregated.completions[idx]), 0)
            for idx in pending
        ]
        new_groups, new_meta = generator(batch_prompts, expected_generations, deficits)
        if not new_groups:
            # Nothing usable came back; treat every pending prompt as empty.
            new_groups = [[] for _ in pending]
        meta_batch = new_meta if isinstance(new_meta, list) else None
        for offset, prompt_idx in enumerate(pending):
            group = new_groups[offset] if offset < len(new_groups) else []
            meta_group = None
            if meta_batch is not None and offset < len(meta_batch):
                meta_group = meta_batch[offset]
            aggregated.metadata = append_completion_group(
                aggregated.completions,
                aggregated.metadata,
                prompt_idx,
                group,
                meta_group,
            )
        pending = pending_generation_indices(aggregated.completions, expected_generations)
    return aggregated
def drop_empty_prompt_groups(
    prompts: List[str],
    answers: List[str],
    aggregated_comps: List[List[str]],
    aggregated_meta: Optional[GroupedMetadata],
    generation_stats: Dict[str, int],
) -> Tuple[
    List[str],
    List[str],
    List[List[str]],
    Optional[GroupedMetadata],
]:
    """Remove prompts that never yielded completions.

    Any prompt whose completion group is empty is removed from all aligned
    structures. The ``dropped_prompts`` counter in ``generation_stats`` is
    incremented by the number removed, and is created at ``0`` if absent.
    When nothing is dropped, the inputs are returned unchanged (same objects)
    and ``generation_stats`` is not touched. Otherwise, new filtered lists are
    returned and the input lists are left unmodified.

    :param prompts: Prompt texts aligned to ``answers`` and grouped completions.
    :type prompts: list[str]
    :param answers: Reference answers aligned to prompts.
    :type answers: list[str]
    :param aggregated_comps: Grouped completions per prompt.
    :type aggregated_comps: list[list[str]]
    :param aggregated_meta: Optional grouped metadata per prompt (assumed to
        hold one group per prompt).
    :type aggregated_meta: list[list[object | None]] | None
    :param generation_stats: Mutable statistics dictionary for counters.
    :type generation_stats: dict[str, int]
    :returns: Filtered prompts, answers, completions, and metadata aligned to
        the remaining prompts.
    :rtype: tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None]
    """
    drop_indices = [idx for idx, comps in enumerate(aggregated_comps) if not comps]
    if not drop_indices:
        return prompts, answers, aggregated_comps, aggregated_meta
    # Match ``drop_incomplete_prompt_groups``: a fresh stats dict must not
    # raise KeyError.
    generation_stats.setdefault("dropped_prompts", 0)
    generation_stats["dropped_prompts"] += len(drop_indices)
    missing_set = set(drop_indices)
    keep_indices = [idx for idx in range(len(prompts)) if idx not in missing_set]
    prompts = [prompts[idx] for idx in keep_indices]
    answers = [answers[idx] for idx in keep_indices]
    aggregated_comps = [aggregated_comps[idx] for idx in keep_indices]
    if aggregated_meta is not None:
        aggregated_meta = [aggregated_meta[idx] for idx in keep_indices]
    return prompts, answers, aggregated_comps, aggregated_meta
def _trim_metadata_to_completions(
    aggregated_comps: List[List[str]],
    aggregated_meta: GroupedMetadata,
) -> None:
    """Shorten list-valued metadata groups that outnumber their completions.

    Replaces entries of ``aggregated_meta`` in place; shorter or non-list
    groups are left untouched, and groups past the shorter sequence are skipped.
    """
    for idx, (comps, meta_group) in enumerate(zip(aggregated_comps, aggregated_meta)):
        if isinstance(meta_group, list) and len(meta_group) > len(comps):
            aggregated_meta[idx] = meta_group[: len(comps)]


def drop_incomplete_prompt_groups(
    prompts: List[str],
    answers: List[str],
    aggregated_comps: List[List[str]],
    aggregated_meta: Optional[GroupedMetadata],
    expected_generations: int,
    generation_stats: Dict[str, int],
) -> Tuple[
    List[str],
    List[str],
    List[List[str]],
    Optional[GroupedMetadata],
    int,
]:
    """Drop prompt groups that do not match the expected completion count.

    TRL assumes a fixed number of completions per prompt. Any prompt group
    whose completion count differs from ``expected_generations`` (too few or
    too many) is removed from the aligned prompt/answer/completion/metadata
    lists, and ``generation_stats["dropped_prompts"]`` is incremented. Every
    surviving metadata group that is longer than its completion group is then
    trimmed. Note: when nothing is dropped, that trim is applied in place to
    the caller's ``aggregated_meta`` list.

    :param prompts: Prompt texts aligned to ``answers`` and grouped completions.
    :type prompts: list[str]
    :param answers: Reference answers aligned to prompts.
    :type answers: list[str]
    :param aggregated_comps: Grouped completions per prompt (not modified).
    :type aggregated_comps: list[list[str]]
    :param aggregated_meta: Optional grouped metadata per prompt.
    :type aggregated_meta: list[list[object | None]] | None
    :param expected_generations: Required completions per prompt; values
        ``<= 0`` disable filtering.
    :type expected_generations: int
    :param generation_stats: Mutable statistics dictionary for counters.
    :type generation_stats: dict[str, int]
    :returns: Filtered prompts, answers, completions, metadata, and dropped count.
    :rtype: tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None, int]
    """
    if expected_generations <= 0:
        return prompts, answers, aggregated_comps, aggregated_meta, 0
    drop_indices = [
        idx
        for idx, comps in enumerate(aggregated_comps)
        if len(comps) != expected_generations
    ]
    if drop_indices:
        generation_stats.setdefault("dropped_prompts", 0)
        generation_stats["dropped_prompts"] += len(drop_indices)
        drop_set = set(drop_indices)
        keep_indices = [idx for idx in range(len(prompts)) if idx not in drop_set]
        prompts = [prompts[idx] for idx in keep_indices]
        answers = [answers[idx] for idx in keep_indices]
        aggregated_comps = [aggregated_comps[idx] for idx in keep_indices]
        if aggregated_meta is not None:
            aggregated_meta = [aggregated_meta[idx] for idx in keep_indices]
    if aggregated_meta is not None:
        # Trim metadata defensively to match completion counts.
        _trim_metadata_to_completions(aggregated_comps, aggregated_meta)
    return prompts, answers, aggregated_comps, aggregated_meta, len(drop_indices)
def truncate_to_expected_counts(
    aggregated_comps: List[List[str]],
    aggregated_meta: Optional[GroupedMetadata],
    expected_generations: int,
) -> Tuple[
    List[List[str]],
    Optional[GroupedMetadata],
    int,
]:
    """Count partially filled prompts and trim metadata to completion counts.

    A prompt counts as partial when its group is non-empty but its size
    differs from ``expected_generations``. List-valued metadata groups longer
    than their completion group are shortened, in place on ``aggregated_meta``.

    NOTE(review): despite the name, completion groups themselves are never
    shortened. A group with more than ``expected_generations`` completions is
    returned as-is and counted as partial. Confirm this is the intended
    contract.

    :param aggregated_comps: Grouped completions per prompt (returned as-is).
    :type aggregated_comps: list[list[str]]
    :param aggregated_meta: Optional grouped metadata per prompt.
    :type aggregated_meta: list[list[object | None]] | None
    :param expected_generations: Desired completions per prompt; values <= 0
        skip all work.
    :type expected_generations: int
    :returns: ``(aggregated_comps, aggregated_meta, partial_count)``.
    :rtype: tuple[list[list[str]], list[list[object | None]] | None, int]
    """
    if expected_generations <= 0:
        return aggregated_comps, aggregated_meta, 0
    partial_count = sum(
        1
        for size in map(len, aggregated_comps)
        if 0 < size != expected_generations
    )
    if aggregated_meta is not None:
        for position, (group, meta_group) in enumerate(
            zip(aggregated_comps, aggregated_meta)
        ):
            if isinstance(meta_group, list) and len(meta_group) > len(group):
                aggregated_meta[position] = meta_group[: len(group)]
    return aggregated_comps, aggregated_meta, partial_count
# Public API of this module; also the names exported by ``from ... import *``.
__all__ = [
"AggregatedGenerationState",
"append_completion_group",
"determine_retry_limit",
"drop_empty_prompt_groups",
"drop_incomplete_prompt_groups",
# NOTE(review): ``flatten_ref_metadata`` is not defined in the visible source
# of this module. If it is genuinely absent, ``from <module> import *`` raises
# AttributeError. Confirm it exists elsewhere or drop this entry. The unused
# ``FlatMetadata`` alias and ``_TrlPayloadConvertible`` protocol suggest it
# may have been removed.
"flatten_ref_metadata",
"pending_generation_indices",
"retry_incomplete_prompts",
"seed_generation_groups",
"truncate_to_expected_counts",
]