maxent_grpo.training.generation.common
Shared generation utilities used by the training stack.
This module contains the small, dependency-light helpers for grouping, retrying, and trimming completions so higher layers can import a single source of truth instead of maintaining divergent copies.
Functions
| append_completion_group | Append completions (and metadata) for a specific prompt index. |
| determine_retry_limit | Return the number of retry rounds required for a batch. |
| drop_empty_prompt_groups | Remove prompts that never yielded completions. |
| drop_incomplete_prompt_groups | Drop prompt groups that do not match the expected completion count. |
| flatten_ref_metadata | Flatten metadata to align with the flattened completions list. |
| pending_generation_indices | Return prompt indices that still need completions. |
| retry_incomplete_prompts | Retry prompts missing completions until limits are hit. |
| seed_generation_groups | Return initial completion/meta buffers aligned with prompts. |
| truncate_to_expected_counts | Trim completions/meta to requested counts and track partial prompts. |
Classes
| AggregatedGenerationState | Mutable container for grouped completions and optional metadata. |
|
- class maxent_grpo.training.generation.common.AggregatedGenerationState(completions, metadata=None)
Bases: object
Mutable container for grouped completions and optional metadata.
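- Parameters:
completions – Grouped completions, one list per prompt.
metadata – Optional grouped metadata aligned to completions; None (the default) when metadata is not tracked.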
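The container can be pictured as a small dataclass. This is a sketch assuming the field shapes used by the sibling helpers (one completion list per prompt, metadata aligned entry by entry); it is not the module's actual definition.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class AggregatedGenerationState:
    """Illustrative stand-in: grouped completions plus optional metadata."""

    # Assumed shape: completions[i] holds every completion produced for prompt i.
    completions: list[list[str]]
    # Assumed shape: metadata[i][j] describes completions[i][j]; None disables tracking.
    metadata: list[list[object | None]] | None = None


# Usage: the container is mutable, so helpers can extend groups in place.
state = AggregatedGenerationState(completions=[["a"], []])
state.completions[1].append("b")
```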
- maxent_grpo.training.generation.common.append_completion_group(grouped_comps, grouped_meta, prompt_idx, completions, meta_group)
Append completions (and metadata) for a specific prompt index.
Completions and metadata are extended in place, creating a fresh metadata structure when needed. Missing metadata is padded with None so list lengths stay aligned with completions.
- Parameters:
grouped_comps (list[list[str]]) – Existing grouped completions buffer.
grouped_meta (list[list[object | None]] | None) – Existing grouped metadata buffer; can be None if metadata is not tracked.
prompt_idx (int) – Index of the prompt whose completions are being appended.
completions (list[str] | None) – New completions to append for the prompt.
meta_group (list[object | None] | None) – Metadata aligned to completions. Excess entries are trimmed and missing entries are padded with None.
- Returns:
Updated grouped metadata (may be newly created), or None when metadata tracking is disabled.
- Return type:
list[list[object | None]] | None
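A minimal sketch of the append contract, assuming that a fresh metadata buffer is created only when grouped_meta is None but meta_group is provided, and that an empty completions argument is a no-op. It reproduces the documented trimming and None padding, not the module's actual source.

```python
from __future__ import annotations


def append_completion_group(
    grouped_comps: list[list[str]],
    grouped_meta: list[list[object | None]] | None,
    prompt_idx: int,
    completions: list[str] | None,
    meta_group: list[object | None] | None,
) -> list[list[object | None]] | None:
    """Sketch of the documented contract, not the module's source."""
    if not completions:
        # Assumption: nothing to append leaves both buffers untouched.
        return grouped_meta
    if grouped_meta is None and meta_group is None:
        # Metadata tracking is disabled: only completions are extended.
        grouped_comps[prompt_idx].extend(completions)
        return None
    if grouped_meta is None:
        # Fresh metadata buffer, padding existing completions with None.
        grouped_meta = [[None] * len(group) for group in grouped_comps]
    grouped_comps[prompt_idx].extend(completions)
    # Trim excess metadata, then pad so lengths match the new completions.
    aligned = list(meta_group or [])[: len(completions)]
    aligned.extend([None] * (len(completions) - len(aligned)))
    grouped_meta[prompt_idx].extend(aligned)
    return grouped_meta


comps = [["x"], []]
meta = append_completion_group(comps, None, 1, ["a", "b"], [{"id": 1}])
```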
- maxent_grpo.training.generation.common.determine_retry_limit(expected_generations, max_retry_rounds)
Return the number of retry rounds required for a batch.
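- Parameters:
expected_generations (int) – Required completions per prompt.
max_retry_rounds (int | None) – Explicit retry cap.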
- Returns:
Retry limit: max_retry_rounds when set, otherwise expected_generations, falling back to _DEFAULT_RETRY_LIMIT when neither input is set.
- Return type:
int
- maxent_grpo.training.generation.common.drop_empty_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, generation_stats)
Remove prompts that never yielded completions.
Any prompt lacking completions is removed from all aligned structures and a dropped_prompts counter in generation_stats is incremented.
- Parameters:
prompts (list[str]) – Prompt texts aligned to answers and grouped completions.
aggregated_comps (list[list[str]]) – Grouped completions per prompt (mutable).
aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.
generation_stats (dict[str, int]) – Mutable statistics dictionary for counters.
- Returns:
Filtered prompts, answers, completions, and metadata aligned to the remaining prompts.
- Return type:
tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None]
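A sketch of the filtering, assuming the dropped_prompts counter is incremented by the number of prompts removed in the call (the reference says only that it is incremented):

```python
from __future__ import annotations


def drop_empty_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, generation_stats):
    """Sketch: keep only prompts whose completion group is non-empty."""
    keep = [i for i, group in enumerate(aggregated_comps) if group]
    dropped = len(aggregated_comps) - len(keep)
    # Assumption: the counter grows by the number of prompts removed in this call.
    generation_stats["dropped_prompts"] = generation_stats.get("dropped_prompts", 0) + dropped
    meta = None if aggregated_meta is None else [aggregated_meta[i] for i in keep]
    return (
        [prompts[i] for i in keep],
        [answers[i] for i in keep],
        [aggregated_comps[i] for i in keep],
        meta,
    )


stats: dict[str, int] = {}
kept = drop_empty_prompt_groups(["p0", "p1"], ["a0", "a1"], [["c0"], []], None, stats)
```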
- maxent_grpo.training.generation.common.drop_incomplete_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, expected_generations, generation_stats)
Drop prompt groups that do not match the expected completion count.
TRL assumes a fixed number of completions per prompt, so any prompt group whose completion count differs from expected_generations is removed from the aligned prompt, answer, completion, and metadata lists.
- Parameters:
prompts (list[str]) – Prompt texts aligned to answers and grouped completions.
aggregated_comps (list[list[str]]) – Grouped completions per prompt (mutable).
aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.
expected_generations (int) – Required completions per prompt.
generation_stats (dict[str, int]) – Mutable statistics dictionary for counters.
- Returns:
Filtered prompts, answers, completions, metadata, and dropped count.
- Return type:
tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None, int]
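The same filtering pattern applies with an exact-count test. The statistics key below (incomplete_prompts) is hypothetical, since the reference does not say which counter this helper updates.

```python
from __future__ import annotations


def drop_incomplete_prompt_groups(
    prompts, answers, aggregated_comps, aggregated_meta, expected_generations, generation_stats
):
    """Sketch: keep only groups with exactly expected_generations completions."""
    keep = [
        i for i, group in enumerate(aggregated_comps) if len(group) == expected_generations
    ]
    dropped = len(aggregated_comps) - len(keep)
    # Hypothetical counter key; the reference does not name the key that is updated.
    generation_stats["incomplete_prompts"] = generation_stats.get("incomplete_prompts", 0) + dropped
    meta = None if aggregated_meta is None else [aggregated_meta[i] for i in keep]
    return (
        [prompts[i] for i in keep],
        [answers[i] for i in keep],
        [aggregated_comps[i] for i in keep],
        meta,
        dropped,
    )


stats: dict[str, int] = {}
result = drop_incomplete_prompt_groups(
    ["p0", "p1", "p2"], ["a0", "a1", "a2"],
    [["x", "y"], ["z"], ["u", "v", "w"]], None, 2, stats,
)
```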
- maxent_grpo.training.generation.common.flatten_ref_metadata(grouped_comps, grouped_meta)
Flatten metadata to align with the flattened completions list.
Metadata entries exposing to_trl_payload are converted before being appended. Missing metadata is filled with None.
- Parameters:
- Returns:
Flattened metadata aligned to a flattened completions list, or None when no metadata exists.
- Return type:
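A sketch of the flattening, assuming to_trl_payload is a zero-argument method and that None is returned only when grouped_meta itself is None. The _RefMeta class is a hypothetical stand-in for a metadata type that exposes the method.

```python
from __future__ import annotations


def flatten_ref_metadata(grouped_comps, grouped_meta):
    """Sketch: one metadata slot per completion, in flattened order."""
    if grouped_meta is None:
        return None
    flat = []
    for idx, group in enumerate(grouped_comps):
        meta_group = (grouped_meta[idx] if idx < len(grouped_meta) else None) or []
        for pos in range(len(group)):
            entry = meta_group[pos] if pos < len(meta_group) else None
            # Assumption: to_trl_payload is a zero-argument method.
            convert = getattr(entry, "to_trl_payload", None)
            flat.append(convert() if callable(convert) else entry)
    return flat


class _RefMeta:
    """Hypothetical metadata type used only for this illustration."""

    def to_trl_payload(self):
        return {"ref": 1}


flat = flatten_ref_metadata([["a", "b"], ["c"]], [[_RefMeta()], []])
```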
- maxent_grpo.training.generation.common.pending_generation_indices(aggregated_comps, expected_generations)
Return prompt indices that still need completions.
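The reference gives no further detail here; a plausible one-line sketch, assuming "still need" means the group holds fewer than expected_generations completions:

```python
def pending_generation_indices(aggregated_comps, expected_generations):
    """Sketch: indices of prompts holding fewer than expected_generations completions."""
    return [
        idx for idx, group in enumerate(aggregated_comps) if len(group) < expected_generations
    ]


pending = pending_generation_indices([["a", "b"], ["c"], []], 2)
```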
- maxent_grpo.training.generation.common.retry_incomplete_prompts(prompts, generator, expected_generations, aggregated, max_retry_rounds)
Retry prompts missing completions until limits are hit.
The generator callback is invoked with the list of prompts still missing completions, along with per-prompt deficits. Metadata returned by the generator is merged if available.
- Parameters:
generator (Callable[[list[str], int, list[int] | None], tuple[list[list[str]], list[list[object | None]] | None]]) – Callable performing generation for a batch of prompts. It should accept prompts, expected_generations, and optionally a list of per-prompt counts, returning grouped completions and metadata.
expected_generations (int) – Number of completions requested per prompt.
aggregated (AggregatedGenerationState) – Aggregated state containing completions and metadata to be updated in place.
max_retry_rounds (int | None) – Explicit retry cap; defaults derive from expected_generations when omitted.
- Returns:
Updated aggregated generation state after retries are exhausted or all prompts are complete.
- Return type:
AggregatedGenerationState
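The retry loop can be sketched as below. It assumes the per-prompt counts passed to generator are the remaining deficits, that the round cap falls back to expected_generations (omitting the module-level default), and that metadata is merged only when the state already tracks it. The AggregatedGenerationState here is a minimal stand-in, and one_per_call is a hypothetical toy generator.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass
class AggregatedGenerationState:
    """Minimal stand-in for the documented container (assumed field shapes)."""

    completions: list[list[str]]
    metadata: list[list[object | None]] | None = None


def retry_incomplete_prompts(
    prompts: list[str],
    generator: Callable[
        [list[str], int, list[int] | None],
        tuple[list[list[str]], list[list[object | None]] | None],
    ],
    expected_generations: int,
    aggregated: AggregatedGenerationState,
    max_retry_rounds: int | None,
) -> AggregatedGenerationState:
    """Sketch of the documented retry loop, not the module's source."""
    # Assumption: mirrors determine_retry_limit without the module-level default.
    limit = max_retry_rounds or expected_generations
    for _ in range(limit):
        pending = [
            i for i, group in enumerate(aggregated.completions)
            if len(group) < expected_generations
        ]
        if not pending:
            break
        deficits = [expected_generations - len(aggregated.completions[i]) for i in pending]
        new_comps, new_meta = generator([prompts[i] for i in pending], expected_generations, deficits)
        for pos, idx in enumerate(pending):
            group = new_comps[pos] if pos < len(new_comps) else []
            aggregated.completions[idx].extend(group)
            if aggregated.metadata is not None:
                # Merge metadata when tracked, trimming or padding with None.
                source = new_meta[pos] if new_meta and pos < len(new_meta) else None
                meta = list(source or [])[: len(group)]
                meta.extend([None] * (len(group) - len(meta)))
                aggregated.metadata[idx].extend(meta)
    return aggregated


def one_per_call(batch, expected, deficits):
    # Toy generator: returns a single completion per pending prompt.
    return [[f"{p}-out"] for p in batch], None


state = retry_incomplete_prompts(
    ["p0", "p1"], one_per_call, 2, AggregatedGenerationState([["x"], []]), None
)
```

Because each round requests only the current deficits, a generator that under-delivers is simply called again on the next round until every group is full or the cap is reached.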
- maxent_grpo.training.generation.common.seed_generation_groups(prompt_count, grouped_comps, grouped_meta)
Return initial completion/meta buffers aligned with prompts.
The helper normalizes partially filled buffers into fresh lists sized to prompt_count and ensures metadata stays aligned with completions.
- Parameters:
- Returns:
Tuple of initialized completions buffer and optional metadata buffer, both sized to prompt_count.
- Return type:
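A sketch of the normalization, assuming that missing or None entries in either input become empty groups and that metadata is trimmed or padded with None to match each completion group:

```python
from __future__ import annotations


def seed_generation_groups(prompt_count, grouped_comps, grouped_meta):
    """Sketch: fresh per-prompt buffers sized to prompt_count."""
    comps = []
    for idx in range(prompt_count):
        # Assumption: missing or None groups become empty lists.
        source = grouped_comps[idx] if grouped_comps and idx < len(grouped_comps) else None
        comps.append(list(source or []))
    if grouped_meta is None:
        return comps, None
    meta = []
    for idx, group in enumerate(comps):
        source = grouped_meta[idx] if idx < len(grouped_meta) else None
        aligned = list(source or [])[: len(group)]
        # Pad with None so each metadata group matches its completion group.
        aligned.extend([None] * (len(group) - len(aligned)))
        meta.append(aligned)
    return comps, meta


comps, meta = seed_generation_groups(3, [["a", "b"]], [[{"k": 1}]])
```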
- maxent_grpo.training.generation.common.truncate_to_expected_counts(aggregated_comps, aggregated_meta, expected_generations)
Trim completions/meta to requested counts and track partial prompts.
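The reference stops at the summary line, so the following is only a guess at the behavior it names: each group is trimmed in place to expected_generations, and prompts left short are recorded. The three-element return value is a hypothetical shape, not the documented one.

```python
from __future__ import annotations


def truncate_to_expected_counts(aggregated_comps, aggregated_meta, expected_generations):
    """Sketch: cap each group at expected_generations and record short groups.

    The (completions, metadata, partial_indices) return shape is hypothetical.
    """
    partial = []
    for idx, group in enumerate(aggregated_comps):
        if len(group) < expected_generations:
            partial.append(idx)
        # Trim in place; deleting past the end is a no-op for short groups.
        del group[expected_generations:]
        if aggregated_meta is not None:
            del aggregated_meta[idx][expected_generations:]
    return aggregated_comps, aggregated_meta, partial


comps, meta, partial = truncate_to_expected_counts(
    [["a", "b", "c"], ["d"]], [[1, 2, 3], [4]], 2
)
```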