maxent_grpo.training.generation.helpers

Copyright 2025 Liv d’Aliberti

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Functions

flatten_prompt_completions(gen_batch)

Return flattened prompt/completion pairs and aligned answers.

class maxent_grpo.training.generation.helpers.AggregatedGenerationState(completions, metadata=None)[source]

Bases: object

Mutable container for grouped completions and optional metadata.

Parameters:
  • completions (list[list[str]]) – Nested list of completions grouped per prompt index.

  • metadata (list[list[object | None]] | None) – Optional nested metadata aligned to completions.

completions: List[List[str]]
metadata: List[List[object | None]] | None = None

maxent_grpo.training.generation.helpers.append_completion_group(grouped_comps, grouped_meta, prompt_idx, completions, meta_group)[source]

Append completions (and metadata) for a specific prompt index.

Completions and metadata are extended in place, creating a fresh metadata structure when needed. Missing metadata is padded with None so list lengths stay aligned with completions.

Parameters:
  • grouped_comps (list[list[str]]) – Existing grouped completions buffer.

  • grouped_meta (list[list[object | None]] | None) – Existing grouped metadata buffer; can be None if metadata is not tracked.

  • prompt_idx (int) – Index of the prompt whose completions are being appended.

  • completions (list[str] | None) – New completions to append for the prompt.

  • meta_group (list[object | None] | None) – Metadata aligned to completions. Excess entries are trimmed and missing entries are padded with None.

Returns:

Updated grouped metadata (may be newly created), or None when metadata tracking is disabled.

Return type:

list[list[object | None]] | None
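
The pad-and-trim behaviour described above can be illustrated with a minimal sketch. This is not the library's implementation: the name append_group_sketch is hypothetical, and the sketch assumes that "metadata tracking is disabled" means both grouped_meta and meta_group are None.

```python
from typing import List, Optional

MetaGroups = Optional[List[List[Optional[object]]]]


def append_group_sketch(
    grouped_comps: List[List[str]],
    grouped_meta: MetaGroups,
    prompt_idx: int,
    completions: Optional[List[str]],
    meta_group: Optional[List[Optional[object]]],
) -> MetaGroups:
    """Hypothetical re-creation of the documented behaviour."""
    if not completions:
        return grouped_meta
    # Remember how many completions the prompt had before this call.
    start = len(grouped_comps[prompt_idx])
    grouped_comps[prompt_idx].extend(completions)
    # Assumption: tracking is "disabled" when no metadata exists anywhere.
    if grouped_meta is None and meta_group is None:
        return None
    if grouped_meta is None:
        # Fresh structure: earlier completions get None placeholders.
        grouped_meta = [[None] * len(group) for group in grouped_comps]
        grouped_meta[prompt_idx] = [None] * start
    # Trim excess metadata, then pad missing entries with None.
    aligned = list(meta_group or [])[: len(completions)]
    aligned += [None] * (len(completions) - len(aligned))
    grouped_meta[prompt_idx].extend(aligned)
    return grouped_meta


comps = [["a"], []]
meta = append_group_sketch(comps, None, 1, ["b", "c"], [{"k": 1}])
```

After the call, comps[1] holds both new completions and meta[1] has the same length, with the second entry padded with None.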

maxent_grpo.training.generation.helpers.determine_retry_limit(expected_generations, max_retry_rounds)[source]

Return the number of retry rounds required for a batch.

Parameters:
  • expected_generations (int) – Desired completions per prompt. Used as a fallback retry budget when explicit retries are not provided.

  • max_retry_rounds (int | None) – Explicit retry cap; overrides defaults when > 0.

Returns:

Retry limit: max_retry_rounds when it is > 0, otherwise expected_generations, falling back to _DEFAULT_RETRY_LIMIT when neither input is set.

Return type:

int
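
The fallback order can be sketched as follows. The function name and the constant value 3 are hypothetical; the real value of _DEFAULT_RETRY_LIMIT is not stated on this page, and treating "not set" as "<= 0 or None" is an assumption.

```python
from typing import Optional

# Hypothetical stand-in; the real _DEFAULT_RETRY_LIMIT value is not documented here.
DEFAULT_RETRY_LIMIT_SKETCH = 3


def retry_limit_sketch(expected_generations: int, max_retry_rounds: Optional[int]) -> int:
    """Pick a retry budget following the documented precedence."""
    # An explicit positive cap wins.
    if max_retry_rounds is not None and max_retry_rounds > 0:
        return max_retry_rounds
    # Otherwise fall back to the per-prompt completion target.
    if expected_generations > 0:
        return expected_generations
    # Neither input is usable: use the module default.
    return DEFAULT_RETRY_LIMIT_SKETCH
```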

maxent_grpo.training.generation.helpers.drop_empty_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, generation_stats)[source]

Remove prompts that never yielded completions.

Any prompt lacking completions is removed from all aligned structures and a dropped_prompts counter in generation_stats is incremented.

Parameters:
  • prompts (list[str]) – Prompt texts aligned to answers and grouped completions.

  • answers (list[str]) – Reference answers aligned to prompts.

  • aggregated_comps (list[list[str]]) – Grouped completions per prompt (mutable).

  • aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.

  • generation_stats (dict[str, int]) – Mutable statistics dictionary for counters.

Returns:

Filtered prompts, answers, completions, and metadata aligned to the remaining prompts.

Return type:

tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None]
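
A minimal sketch of the filtering step, assuming the dropped_prompts counter grows by the number of prompts removed (the page only says it is incremented). The name drop_empty_sketch is hypothetical, and the sketch assumes metadata, when present, has one entry per prompt.

```python
from typing import Dict, List, Optional, Tuple

MetaGroups = Optional[List[List[Optional[object]]]]


def drop_empty_sketch(
    prompts: List[str],
    answers: List[str],
    comps: List[List[str]],
    meta: MetaGroups,
    stats: Dict[str, int],
) -> Tuple[List[str], List[str], List[List[str]], MetaGroups]:
    """Keep only prompts that have at least one completion."""
    keep = [i for i, group in enumerate(comps) if group]
    dropped = len(comps) - len(keep)
    if dropped:
        # Assumption: the counter is increased by the number of dropped prompts.
        stats["dropped_prompts"] = stats.get("dropped_prompts", 0) + dropped
    kept_meta = None if meta is None else [meta[i] for i in keep]
    return (
        [prompts[i] for i in keep],
        [answers[i] for i in keep],
        [comps[i] for i in keep],
        kept_meta,
    )


stats: Dict[str, int] = {}
result = drop_empty_sketch(["p", "q"], ["1", "2"], [["a"], []], [[None], []], stats)
```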

maxent_grpo.training.generation.helpers.drop_incomplete_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, expected_generations, generation_stats)[source]

Drop prompt groups that do not match the expected completion count.

TRL assumes a fixed number of completions per prompt. Any prompt group whose completion count differs from expected_generations is removed from the aligned prompt/answer/completion lists.

Parameters:
  • prompts (list[str]) – Prompt texts aligned to answers and grouped completions.

  • answers (list[str]) – Reference answers aligned to prompts.

  • aggregated_comps (list[list[str]]) – Grouped completions per prompt (mutable).

  • aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.

  • expected_generations (int) – Required completions per prompt.

  • generation_stats (dict[str, int]) – Mutable statistics dictionary for counters.

Returns:

Filtered prompts, answers, completions, and metadata, plus the number of prompt groups that were dropped.

Return type:

tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None, int]
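
The exact-count filter can be sketched like this. The name drop_incomplete_sketch is hypothetical, and the sketch leaves out generation_stats because the page does not say which counters the real helper updates.

```python
from typing import List, Optional, Tuple

MetaGroups = Optional[List[List[Optional[object]]]]


def drop_incomplete_sketch(
    prompts: List[str],
    answers: List[str],
    comps: List[List[str]],
    meta: MetaGroups,
    expected_generations: int,
) -> Tuple[List[str], List[str], List[List[str]], MetaGroups, int]:
    """Keep only groups holding exactly expected_generations completions."""
    keep = [i for i, group in enumerate(comps) if len(group) == expected_generations]
    dropped = len(comps) - len(keep)
    kept_meta = None if meta is None else [meta[i] for i in keep]
    return (
        [prompts[i] for i in keep],
        [answers[i] for i in keep],
        [comps[i] for i in keep],
        kept_meta,
        dropped,
    )


out = drop_incomplete_sketch(
    ["p", "q", "r"], ["1", "2", "3"], [["a", "b"], ["c"], []], None, 2
)
```

Here only the first prompt has exactly two completions, so the other two groups are dropped and the final tuple element is 2.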

maxent_grpo.training.generation.helpers.flatten_prompt_completions(gen_batch)[source]

Return flattened prompt/completion pairs and aligned answers.

Parameters:

gen_batch (GenerationBatch) – Aggregated generation results.

Returns:

Tuple of flattened prompt/completion batch and answer list.

Return type:

tuple[PromptCompletionBatch, list[str]]
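
Flattening repeats each prompt and answer once per completion, so all three flat lists share one index. The sketch below uses hypothetical stand-in dataclasses; the field names of the real GenerationBatch and PromptCompletionBatch are not given on this page and are assumed here.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class BatchSketch:
    """Hypothetical stand-in for GenerationBatch (field names assumed)."""
    prompts: List[str]
    answers: List[str]
    grouped_completions: List[List[str]]


@dataclass
class PairsSketch:
    """Hypothetical stand-in for PromptCompletionBatch (field names assumed)."""
    prompts: List[str]
    completions: List[str]


def flatten_sketch(batch: BatchSketch) -> Tuple[PairsSketch, List[str]]:
    flat_prompts: List[str] = []
    flat_completions: List[str] = []
    flat_answers: List[str] = []
    for prompt, answer, group in zip(
        batch.prompts, batch.answers, batch.grouped_completions
    ):
        for completion in group:
            # One row per completion; prompt and answer are repeated.
            flat_prompts.append(prompt)
            flat_completions.append(completion)
            flat_answers.append(answer)
    return PairsSketch(flat_prompts, flat_completions), flat_answers


pairs, flat_answers = flatten_sketch(
    BatchSketch(["p", "q"], ["x", "y"], [["a1", "a2"], ["b1"]])
)
```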

maxent_grpo.training.generation.helpers.flatten_ref_metadata(grouped_comps, grouped_meta)[source]

Flatten metadata to align with the flattened completions list.

Metadata entries exposing to_trl_payload are converted before being appended. Missing metadata is filled with None.

Parameters:
  • grouped_comps (list[list[str]]) – Grouped completions per prompt.

  • grouped_meta (list[list[object | None]] | None) – Grouped metadata aligned to grouped_comps.

Returns:

Flattened metadata aligned to a flattened completions list, or None when no metadata exists.

Return type:

list[object | None] | None
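
A minimal sketch of the flattening and conversion step. The name flatten_meta_sketch and the MetaSketch class are hypothetical, and the sketch assumes to_trl_payload is a zero-argument method, which this page does not confirm.

```python
from typing import List, Optional


class MetaSketch:
    """Hypothetical metadata object exposing to_trl_payload."""

    def __init__(self, value: int) -> None:
        self.value = value

    def to_trl_payload(self) -> dict:
        return {"value": self.value}


def flatten_meta_sketch(
    grouped_comps: List[List[str]],
    grouped_meta: Optional[List[List[Optional[object]]]],
) -> Optional[List[Optional[object]]]:
    if grouped_meta is None:
        return None
    flat: List[Optional[object]] = []
    for idx, comps in enumerate(grouped_comps):
        meta = grouped_meta[idx] if idx < len(grouped_meta) else []
        for pos in range(len(comps)):
            # Missing entries become None so lengths match the completions.
            entry = meta[pos] if pos < len(meta) else None
            convert = getattr(entry, "to_trl_payload", None)
            if callable(convert):
                entry = convert()  # assumed zero-argument call
            flat.append(entry)
    return flat


flat = flatten_meta_sketch([["a", "b"], ["c"]], [[MetaSketch(1)], ["raw"]])
```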

maxent_grpo.training.generation.helpers.pending_generation_indices(aggregated_comps, expected_generations)[source]

Return prompt indices that still need completions.

Parameters:
  • aggregated_comps (list[list[str]]) – Completions grouped per prompt.

  • expected_generations (int) – Desired number of completions per prompt.

Returns:

Indices whose completion count is below expected_generations.

Return type:

list[int]
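
The selection rule amounts to a single comparison per prompt. A minimal sketch with a hypothetical function name:

```python
from typing import List


def pending_indices_sketch(
    aggregated_comps: List[List[str]], expected_generations: int
) -> List[int]:
    """Return indices of prompts with fewer completions than requested."""
    return [
        idx
        for idx, group in enumerate(aggregated_comps)
        if len(group) < expected_generations
    ]


pending = pending_indices_sketch([["a", "b"], ["c"], []], 2)
```

With a target of two completions, the second and third prompts are still pending.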

maxent_grpo.training.generation.helpers.retry_incomplete_prompts(prompts, generator, expected_generations, aggregated, max_retry_rounds)[source]

Retry prompts missing completions until limits are hit.

The generator callback is invoked with the list of prompts still missing completions, along with per-prompt deficits. Metadata returned by the generator is merged if available.

Parameters:
  • prompts (list[str]) – Original prompt strings.

  • generator (Callable[[list[str], int, list[int] | None], tuple[list[list[str]], list[list[object | None]] | None]]) – Callable performing generation for a batch of prompts. It should accept prompts, expected_generations, and optionally a list of per-prompt counts, returning grouped completions and metadata.

  • expected_generations (int) – Number of completions requested per prompt.

  • aggregated (AggregatedGenerationState) – Aggregated state containing completions and metadata to be updated in place.

  • max_retry_rounds (int | None) – Explicit retry cap; defaults derive from expected_generations when omitted.

Returns:

Updated aggregated generation state after retries are exhausted or all prompts are complete.

Return type:

AggregatedGenerationState
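
The retry loop can be sketched as below. This is an illustration under assumptions, not the library's code: retry_sketch is a hypothetical name, it works on a plain completions list rather than AggregatedGenerationState, it ignores metadata for brevity, and it takes an already resolved round limit.

```python
from typing import Callable, List, Optional, Tuple

Generator = Callable[
    [List[str], int, Optional[List[int]]],
    Tuple[List[List[str]], Optional[List[List[Optional[object]]]]],
]


def retry_sketch(
    prompts: List[str],
    generator: Generator,
    expected_generations: int,
    comps: List[List[str]],
    retry_limit: int,
) -> List[List[str]]:
    for _ in range(retry_limit):
        pending = [i for i, g in enumerate(comps) if len(g) < expected_generations]
        if not pending:
            break  # every prompt is complete
        # Per-prompt deficits tell the generator how many completions are missing.
        deficits = [expected_generations - len(comps[i]) for i in pending]
        new_groups, _meta = generator(
            [prompts[i] for i in pending], expected_generations, deficits
        )
        for i, group in zip(pending, new_groups):
            comps[i].extend(group)
    return comps


calls = []


def one_per_call(batch, expected, counts=None):
    """Toy generator that yields a single completion per prompt per call."""
    calls.append((list(batch), list(counts or [])))
    return [[p + "!"] for p in batch], None


filled = retry_sketch(["p", "q"], one_per_call, 2, [["a", "b"], []], 3)
```

The toy generator needs two rounds to fill the second prompt, so the loop stops before the third round.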

maxent_grpo.training.generation.helpers.seed_generation_groups(prompt_count, grouped_comps, grouped_meta)[source]

Return initial completion/meta buffers aligned with prompts.

The helper normalizes partially filled buffers into fresh lists sized to prompt_count and ensures metadata stays aligned with completions.

Parameters:
  • prompt_count (int) – Number of prompts that will be processed.

  • grouped_comps (list[list[str]] | None) – Optional preexisting completions grouped per prompt.

  • grouped_meta (list[list[object | None]] | None) – Optional preexisting metadata grouped per prompt.

Returns:

Tuple of initialized completions buffer and optional metadata buffer, both sized to prompt_count.

Return type:

tuple[list[list[str]], list[list[object | None]] | None]
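
One way to picture the normalization is the sketch below. The name seed_sketch is hypothetical, and returning None for metadata whenever grouped_meta is None is an assumption about the real helper.

```python
from typing import List, Optional, Tuple

MetaGroups = Optional[List[List[Optional[object]]]]


def seed_sketch(
    prompt_count: int,
    grouped_comps: Optional[List[List[str]]],
    grouped_meta: MetaGroups,
) -> Tuple[List[List[str]], MetaGroups]:
    # Fresh buffer with one (possibly empty) group per prompt.
    comps: List[List[str]] = [[] for _ in range(prompt_count)]
    for i, group in enumerate((grouped_comps or [])[:prompt_count]):
        comps[i] = list(group)
    if grouped_meta is None:
        return comps, None
    meta: List[List[Optional[object]]] = []
    for i in range(prompt_count):
        existing = list(grouped_meta[i]) if i < len(grouped_meta) else []
        # Trim or pad so metadata length matches the completions.
        existing = existing[: len(comps[i])]
        existing += [None] * (len(comps[i]) - len(existing))
        meta.append(existing)
    return comps, meta


seeded_comps, seeded_meta = seed_sketch(3, [["a", "b"]], [["m"]])
```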

maxent_grpo.training.generation.helpers.truncate_to_expected_counts(aggregated_comps, aggregated_meta, expected_generations)[source]

Trim completions/meta to requested counts and track partial prompts.

Parameters:
  • aggregated_comps (list[list[str]]) – Grouped completions per prompt.

  • aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.

  • expected_generations (int) – Desired completions per prompt; values <= 0 skip trimming.

Returns:

Tuple of grouped completions, grouped metadata, and the number of non-empty prompts whose completion count differs from the requested value.

Return type:

tuple[list[list[str]], list[list[object | None]] | None, int]
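
A minimal sketch of the trimming and partial-count logic. The name truncate_sketch is hypothetical, and returning a partial count of 0 when trimming is skipped is an assumption.

```python
from typing import List, Optional, Tuple

MetaGroups = Optional[List[List[Optional[object]]]]


def truncate_sketch(
    aggregated_comps: List[List[str]],
    aggregated_meta: MetaGroups,
    expected_generations: int,
) -> Tuple[List[List[str]], MetaGroups, int]:
    if expected_generations <= 0:
        # Trimming is skipped; the 0 count here is an assumption.
        return aggregated_comps, aggregated_meta, 0
    trimmed = [group[:expected_generations] for group in aggregated_comps]
    trimmed_meta = (
        None
        if aggregated_meta is None
        else [group[:expected_generations] for group in aggregated_meta]
    )
    # Non-empty groups that still fall short of the target count as partial.
    partial = sum(
        1 for group in trimmed if group and len(group) != expected_generations
    )
    return trimmed, trimmed_meta, partial


t_comps, t_meta, t_partial = truncate_sketch(
    [["a", "b", "c"], ["d"], []], [[1, 2, 3], [4], []], 2
)
```

The first group is cut to two entries, the single-completion group is counted as partial, and the empty group is ignored.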