maxent_grpo.training.generation.helpers
Copyright 2025 Liv d’Aliberti
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Functions
flatten_prompt_completions(gen_batch) | Return flattened prompt/completion pairs and aligned answers.
- class maxent_grpo.training.generation.helpers.AggregatedGenerationState(completions, metadata=None)
Bases: object
Mutable container for grouped completions and optional metadata.
- Parameters:
- maxent_grpo.training.generation.helpers.append_completion_group(grouped_comps, grouped_meta, prompt_idx, completions, meta_group)
Append completions (and metadata) for a specific prompt index.
Completions and metadata are extended in place, creating a fresh metadata structure when needed. Missing metadata is padded with None so list lengths stay aligned with completions.
- Parameters:
grouped_comps (list[list[str]]) – Existing grouped completions buffer.
grouped_meta (list[list[object | None]] | None) – Existing grouped metadata buffer; can be None if metadata is not tracked.
prompt_idx (int) – Index of the prompt whose completions are being appended.
completions (list[str] | None) – New completions to append for the prompt.
meta_group (list[object | None] | None) – Metadata aligned to completions. Excess entries are trimmed and missing entries are padded with None.
- Returns:
Updated grouped metadata (may be newly created), or None when metadata tracking is disabled.
- Return type:
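The trim-and-pad behaviour described for append_completion_group can be sketched as follows. This is a minimal illustration written from the description above, not the library's implementation: the name append_group_sketch is hypothetical, and back-filling earlier completions with None when the metadata buffer is first created is an assumption.

```python
from typing import Optional

Meta = Optional[object]


def append_group_sketch(
    grouped_comps: list[list[str]],
    grouped_meta: Optional[list[list[Meta]]],
    prompt_idx: int,
    completions: Optional[list[str]],
    meta_group: Optional[list[Meta]],
) -> Optional[list[list[Meta]]]:
    """Illustrative stand-in for append_completion_group (not the real code)."""
    if not completions:
        return grouped_meta  # nothing to append
    if grouped_meta is None and meta_group is not None:
        # Assumption: a fresh metadata buffer is back-filled with None for
        # completions stored before any metadata was seen.
        grouped_meta = [[None] * len(group) for group in grouped_comps]
    grouped_comps[prompt_idx].extend(completions)
    if grouped_meta is not None:
        aligned = list(meta_group or [])[: len(completions)]  # trim excess
        aligned += [None] * (len(completions) - len(aligned))  # pad missing
        grouped_meta[prompt_idx].extend(aligned)
    return grouped_meta


comps = [["a"], []]
meta = append_group_sketch(comps, None, 1, ["b", "c"], [{"id": 1}])
```

After the call, comps and meta have the same shape: every completion has exactly one metadata slot, which keeps later flattening straightforward.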
- maxent_grpo.training.generation.helpers.determine_retry_limit(expected_generations, max_retry_rounds)
Return the number of retry rounds required for a batch.
- Parameters:
- Returns:
Retry limit, defaulting to expected_generations or _DEFAULT_RETRY_LIMIT when neither input is set.
- Return type:
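One plausible reading of the fallback order is sketched below. Several things here are assumptions: the function name retry_limit_sketch is hypothetical, the value 3 for _DEFAULT_RETRY_LIMIT is a placeholder because the reference does not state the real constant, and treating non-positive inputs as "not set" is a guess.

```python
from typing import Optional

# Placeholder value: the real constant is not given in the reference.
_DEFAULT_RETRY_LIMIT = 3


def retry_limit_sketch(expected_generations: int, max_retry_rounds: Optional[int]) -> int:
    """Illustrative fallback chain: explicit cap, then expected count, then default."""
    if max_retry_rounds is not None and max_retry_rounds > 0:
        return max_retry_rounds  # explicit cap wins
    if expected_generations > 0:
        return expected_generations  # derive the cap from the batch size
    return _DEFAULT_RETRY_LIMIT  # neither input is usable
```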
- maxent_grpo.training.generation.helpers.drop_empty_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, generation_stats)
Remove prompts that never yielded completions.
Any prompt lacking completions is removed from all aligned structures and a dropped_prompts counter in generation_stats is incremented.
- Parameters:
prompts (list[str]) – Prompt texts aligned to answers and grouped completions.
aggregated_comps (list[list[str]]) – Grouped completions per prompt (mutable).
aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.
generation_stats (dict[str, int]) – Mutable statistics dictionary for counters.
- Returns:
Filtered prompts, answers, completions, and metadata aligned to the remaining prompts.
- Return type:
tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None]
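The filtering described for drop_empty_prompt_groups amounts to selecting the indices with at least one completion and applying that selection to every aligned list. A minimal sketch, assuming the counter grows by the number of removed prompts (the reference only says it is incremented); drop_empty_sketch is a hypothetical name:

```python
def drop_empty_sketch(prompts, answers, comps, meta, stats):
    """Illustrative version: keep only prompts with at least one completion."""
    keep = [i for i, group in enumerate(comps) if group]
    dropped = len(comps) - len(keep)
    if dropped:
        # Assumption: the counter increases by the number of dropped prompts.
        stats["dropped_prompts"] = stats.get("dropped_prompts", 0) + dropped
    return (
        [prompts[i] for i in keep],
        [answers[i] for i in keep],
        [comps[i] for i in keep],
        None if meta is None else [meta[i] for i in keep],
    )


stats = {}
result = drop_empty_sketch(
    ["p0", "p1", "p2"],
    ["a0", "a1", "a2"],
    [["x"], [], ["y", "z"]],
    [[None], [], [1, 2]],
    stats,
)
```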
- maxent_grpo.training.generation.helpers.drop_incomplete_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, expected_generations, generation_stats)
Drop prompt groups that do not match the expected completion count.
TRL assumes a fixed number of completions per prompt. Any prompt group whose completion count differs from expected_generations is removed from the aligned prompt/answer/completion lists.
- Parameters:
prompts (list[str]) – Prompt texts aligned to answers and grouped completions.
aggregated_comps (list[list[str]]) – Grouped completions per prompt (mutable).
aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.
expected_generations (int) – Required completions per prompt.
generation_stats (dict[str, int]) – Mutable statistics dictionary for counters.
- Returns:
Filtered prompts, answers, completions, metadata, and dropped count.
- Return type:
tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None, int]
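The exact-count rule can be illustrated with a short sketch. It is written from the description above rather than from the library source; drop_incomplete_sketch is a hypothetical name, and the sketch leaves out generation_stats because the reference does not name the counter that is updated.

```python
def drop_incomplete_sketch(prompts, answers, comps, meta, expected_generations):
    """Illustrative version: keep groups with exactly the expected count."""
    keep = [i for i, group in enumerate(comps) if len(group) == expected_generations]
    dropped = len(comps) - len(keep)
    return (
        [prompts[i] for i in keep],
        [answers[i] for i in keep],
        [comps[i] for i in keep],
        None if meta is None else [meta[i] for i in keep],
        dropped,
    )


kept = drop_incomplete_sketch(
    ["p0", "p1", "p2"],
    ["a0", "a1", "a2"],
    [["a", "b"], ["c"], ["d", "e", "f"]],
    None,
    expected_generations=2,
)
```

Note that a group with too many completions is dropped as well as one with too few, since the rule is a count mismatch rather than a shortfall.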
- maxent_grpo.training.generation.helpers.flatten_prompt_completions(gen_batch)
Return flattened prompt/completion pairs and aligned answers.
- Parameters:
gen_batch (GenerationBatch) – Aggregated generation results.
- Returns:
Tuple of flattened prompt/completion batch and answer list.
- Return type:
- maxent_grpo.training.generation.helpers.flatten_ref_metadata(grouped_comps, grouped_meta)
Flatten metadata to align with the flattened completions list.
Metadata entries exposing to_trl_payload are converted before being appended. Missing metadata is filled with None.
- Parameters:
- Returns:
Flattened metadata aligned to a flattened completions list, or None when no metadata exists.
- Return type:
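A rough sketch of the flattening step follows. It assumes to_trl_payload is a method that takes no arguments, which the reference does not confirm; flatten_meta_sketch and the _Ref class are hypothetical names used only for the demonstration.

```python
def flatten_meta_sketch(grouped_comps, grouped_meta):
    """Illustrative version: one metadata slot per flattened completion."""
    if grouped_meta is None:
        return None
    flat = []
    for idx, comps in enumerate(grouped_comps):
        metas = grouped_meta[idx] if idx < len(grouped_meta) else []
        for j in range(len(comps)):
            entry = metas[j] if j < len(metas) else None  # fill gaps with None
            # Assumption: to_trl_payload is callable without arguments.
            to_payload = getattr(entry, "to_trl_payload", None)
            flat.append(to_payload() if callable(to_payload) else entry)
    return flat


class _Ref:
    """Hypothetical metadata object used only for this demo."""

    def to_trl_payload(self):
        return {"logprob": -1.5}


flat_meta = flatten_meta_sketch([["a", "b"], ["c"]], [[_Ref()], [None]])
```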
- maxent_grpo.training.generation.helpers.pending_generation_indices(aggregated_comps, expected_generations)
Return prompt indices that still need completions.
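As a small illustration, "still needs completions" can be read as "has fewer than expected_generations completions". That reading is an assumption, and pending_indices_sketch is a hypothetical name:

```python
def pending_indices_sketch(aggregated_comps, expected_generations):
    """Illustrative version: indices of groups that are still short."""
    return [
        idx
        for idx, group in enumerate(aggregated_comps)
        if len(group) < expected_generations
    ]


pending = pending_indices_sketch([["a", "b"], ["c"], []], 2)
```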
- maxent_grpo.training.generation.helpers.retry_incomplete_prompts(prompts, generator, expected_generations, aggregated, max_retry_rounds)
Retry prompts missing completions until limits are hit.
The generator callback is invoked with the list of prompts still missing completions, along with per-prompt deficits. Metadata returned by the generator is merged if available.
- Parameters:
generator (Callable[[list[str], int, list[int] | None], tuple[list[list[str]], list[list[object | None]] | None]]) – Callable performing generation for a batch of prompts. It should accept prompts, expected_generations, and optionally a list of per-prompt counts, returning grouped completions and metadata.
expected_generations (int) – Number of completions requested per prompt.
aggregated (AggregatedGenerationState) – Aggregated state containing completions and metadata to be updated in place.
max_retry_rounds (int | None) – Explicit retry cap; defaults derive from expected_generations when omitted.
- Returns:
Updated aggregated generation state after retries are exhausted or all prompts are complete.
- Return type:
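The retry loop described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the library's code: StateSketch stands in for AggregatedGenerationState, retry_sketch and flaky_generator are hypothetical names, the retry cap is passed in already resolved, and metadata is merged only when the state already tracks it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StateSketch:
    """Hypothetical stand-in for AggregatedGenerationState."""

    completions: list
    metadata: Optional[list] = None


def retry_sketch(prompts, generator, expected_generations, state, max_rounds):
    """Illustrative loop: regenerate only for prompts that are still short."""
    for _ in range(max_rounds):
        pending = [
            i for i, g in enumerate(state.completions) if len(g) < expected_generations
        ]
        if not pending:
            break  # every prompt has enough completions
        deficits = [expected_generations - len(state.completions[i]) for i in pending]
        batch = [prompts[i] for i in pending]
        new_comps, new_meta = generator(batch, expected_generations, deficits)
        for pos, idx in enumerate(pending):
            group = new_comps[pos] if pos < len(new_comps) else []
            state.completions[idx].extend(group)
            if state.metadata is not None:
                has_meta = new_meta is not None and pos < len(new_meta)
                metas = list(new_meta[pos] or []) if has_meta else []
                # Trim or pad so metadata stays aligned with the new completions.
                metas = metas[: len(group)] + [None] * max(0, len(group) - len(metas))
                state.metadata[idx].extend(metas)
    return state


def flaky_generator(batch, expected, counts):
    """Hypothetical backend that yields one completion per prompt per call."""
    return [[f"{p}-extra"] for p in batch], None


state = StateSketch(completions=[["p0-a", "p0-b"], []])
retry_sketch(["p0", "p1"], flaky_generator, 2, state, max_rounds=3)
```

In this demo the second prompt needs two rounds to reach two completions, and the third round is skipped because nothing is pending.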
- maxent_grpo.training.generation.helpers.seed_generation_groups(prompt_count, grouped_comps, grouped_meta)
Return initial completion/meta buffers aligned with prompts.
The helper normalizes partially filled buffers into fresh lists sized to prompt_count and ensures metadata stays aligned with completions.
- Parameters:
- Returns:
Tuple of initialized completions buffer and optional metadata buffer, both sized to prompt_count.
- Return type:
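The normalization step might look like the sketch below. It assumes that missing groups become empty lists and that metadata is trimmed or padded with None to match each group; seed_groups_sketch is a hypothetical name and the input types are inferred from the other helpers on this page.

```python
def seed_groups_sketch(prompt_count, grouped_comps, grouped_meta):
    """Illustrative version: copy partial buffers into fresh, aligned lists."""
    comps = [[] for _ in range(prompt_count)]
    meta = None if grouped_meta is None else [[] for _ in range(prompt_count)]
    for idx in range(min(prompt_count, len(grouped_comps or []))):
        group = list(grouped_comps[idx] or [])
        comps[idx] = group
        if meta is not None:
            src = list(grouped_meta[idx] or []) if idx < len(grouped_meta) else []
            # Assumption: metadata is trimmed or None-padded to the group length.
            meta[idx] = src[: len(group)] + [None] * max(0, len(group) - len(src))
    return comps, meta


seeded = seed_groups_sketch(3, [["a", "b"]], [[{"k": 1}]])
```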
- maxent_grpo.training.generation.helpers.truncate_to_expected_counts(aggregated_comps, aggregated_meta, expected_generations)
Trim completions/meta to requested counts and track partial prompts.
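The reference does not document the return value of this helper, so the following is only a guess at its shape. The sketch trims each group to expected_generations and reports how many prompts ended up with some, but not enough, completions. The name truncate_sketch, the three-element return tuple, and the definition of "partial" are all assumptions.

```python
def truncate_sketch(aggregated_comps, aggregated_meta, expected_generations):
    """Illustrative version: trim groups and count partially filled prompts."""
    comps = [group[:expected_generations] for group in aggregated_comps]
    meta = (
        None
        if aggregated_meta is None
        else [group[:expected_generations] for group in aggregated_meta]
    )
    # Assumption: "partial" means non-empty but below the expected count.
    partial = sum(1 for group in comps if 0 < len(group) < expected_generations)
    return comps, meta, partial


trimmed = truncate_sketch([["a", "b", "c"], ["d"], []], None, 2)
```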