maxent_grpo.training.generation.common

Shared generation utilities used by the training stack.

This module collects small, dependency-light helpers for grouping, retrying, and trimming completions, so higher layers can import a single source of truth instead of maintaining divergent copies.

Functions

append_completion_group(grouped_comps, ...)

Append completions (and metadata) for a specific prompt index.

determine_retry_limit(expected_generations, ...)

Return the maximum number of retry rounds allowed for a batch.

drop_empty_prompt_groups(prompts, answers, ...)

Remove prompts that never yielded completions.

drop_incomplete_prompt_groups(prompts, ...)

Drop prompt groups that do not match the expected completion count.

flatten_ref_metadata(grouped_comps, grouped_meta)

Flatten metadata to align with the flattened completions list.

pending_generation_indices(aggregated_comps, ...)

Return prompt indices that still need completions.

retry_incomplete_prompts(prompts, generator, ...)

Retry prompts missing completions until every prompt is complete or the retry limit is reached.

seed_generation_groups(prompt_count, ...)

Return initial completion/meta buffers aligned with prompts.

truncate_to_expected_counts(...)

Trim completions/meta to requested counts and track partial prompts.

Classes

AggregatedGenerationState(completions[, ...])

Mutable container for grouped completions and optional metadata.

_TrlPayloadConvertible(*args, **kwargs)

class maxent_grpo.training.generation.common.AggregatedGenerationState(completions, metadata=None)

Bases: object

Mutable container for grouped completions and optional metadata.

Parameters:
  • completions (list[list[str]]) – Nested list of completions grouped per prompt index.

  • metadata (list[list[object | None]] | None) – Optional nested metadata aligned to completions.

completions: List[List[str]]
metadata: List[List[object | None]] | None = None
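As a rough illustration of the shape described above, the container can be modeled as a plain dataclass. This is a sketch that assumes the class behaves like an ordinary mutable dataclass; the name GenerationStateSketch is hypothetical and is not the library's class.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class GenerationStateSketch:
    """Hypothetical stand-in mirroring the two documented fields."""

    completions: List[List[str]]
    metadata: Optional[List[List[Optional[object]]]] = None


# One inner list per prompt index; metadata is omitted, so it stays None.
state = GenerationStateSketch(completions=[["a1", "a2"], []])
state.completions[1].append("b1")  # the buffers are mutated in place
```

Here prompt 0 already holds two completions and prompt 1 receives its first one by in-place mutation.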
maxent_grpo.training.generation.common.append_completion_group(grouped_comps, grouped_meta, prompt_idx, completions, meta_group)

Append completions (and metadata) for a specific prompt index.

Completions and metadata are extended in place, creating a fresh metadata structure when needed. Missing metadata is padded with None so list lengths stay aligned with completions.

Parameters:
  • grouped_comps (list[list[str]]) – Existing grouped completions buffer.

  • grouped_meta (list[list[object | None]] | None) – Existing grouped metadata buffer; can be None if metadata is not tracked.

  • prompt_idx (int) – Index of the prompt whose completions are being appended.

  • completions (list[str] | None) – New completions to append for the prompt.

  • meta_group (list[object | None] | None) – Metadata aligned to completions. Excess entries are trimmed and missing entries are padded with None.

Returns:

Updated grouped metadata (may be newly created), or None when metadata tracking is disabled.

Return type:

list[list[object | None]] | None
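The trim-and-pad behavior described for meta_group can be sketched as follows. This is not the library's code: append_group_sketch is a hypothetical re-implementation, and it assumes that a freshly created metadata buffer back-fills None for completions that were already present.

```python
from typing import List, Optional

Meta = Optional[List[List[Optional[object]]]]


def append_group_sketch(
    grouped_comps: List[List[str]],
    grouped_meta: Meta,
    prompt_idx: int,
    completions: Optional[List[str]],
    meta_group: Optional[List[Optional[object]]],
) -> Meta:
    """Hypothetical sketch of the documented append contract."""
    if not completions:
        return grouped_meta
    if grouped_meta is None and meta_group is not None:
        # Assumption: a fresh buffer back-fills None for existing completions.
        grouped_meta = [[None] * len(group) for group in grouped_comps]
    grouped_comps[prompt_idx].extend(completions)
    if grouped_meta is None:
        return None  # metadata tracking stays disabled
    # Trim excess entries, then pad with None up to len(completions).
    aligned = list(meta_group or [])[: len(completions)]
    aligned.extend([None] * (len(completions) - len(aligned)))
    grouped_meta[prompt_idx].extend(aligned)
    return grouped_meta


comps = [["x"], []]
meta = append_group_sketch(comps, None, 1, ["y1", "y2"], [{"lp": -0.1}])
```

After the call, the single metadata entry is paired with "y1" and a None placeholder is paired with "y2", so both buffers keep the same shape.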

maxent_grpo.training.generation.common.determine_retry_limit(expected_generations, max_retry_rounds)

Return the maximum number of retry rounds allowed for a batch.

Parameters:
  • expected_generations (int) – Desired completions per prompt. Used as a fallback retry budget when explicit retries are not provided.

  • max_retry_rounds (int | None) – Explicit retry cap; overrides the defaults when greater than 0.

Returns:

Retry limit: max_retry_rounds when it is greater than 0, otherwise expected_generations, falling back to _DEFAULT_RETRY_LIMIT when neither input is set.

Return type:

int

maxent_grpo.training.generation.common.drop_empty_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, generation_stats)

Remove prompts that never yielded completions.

Any prompt lacking completions is removed from all aligned structures and a dropped_prompts counter in generation_stats is incremented.

Parameters:
  • prompts (list[str]) – Prompt texts aligned to answers and grouped completions.

  • answers (list[str]) – Reference answers aligned to prompts.

  • aggregated_comps (list[list[str]]) – Grouped completions per prompt (mutable).

  • aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.

  • generation_stats (dict[str, int]) – Mutable statistics dictionary for counters.

Returns:

Filtered prompts, answers, completions, and metadata aligned to the remaining prompts.

Return type:

tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None]
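A minimal sketch of this filtering step is shown below. drop_empty_sketch is a hypothetical name, and the sketch assumes the dropped_prompts counter grows by the number of removed prompts.

```python
from typing import Dict, List, Optional, Tuple

Meta = Optional[List[List[Optional[object]]]]


def drop_empty_sketch(
    prompts: List[str],
    answers: List[str],
    comps: List[List[str]],
    meta: Meta,
    stats: Dict[str, int],
) -> Tuple[List[str], List[str], List[List[str]], Meta]:
    """Hypothetical sketch: keep only prompts with at least one completion."""
    keep = [i for i, group in enumerate(comps) if group]
    dropped = len(comps) - len(keep)
    if dropped:
        # Assumption: the counter is incremented once per removed prompt.
        stats["dropped_prompts"] = stats.get("dropped_prompts", 0) + dropped
    return (
        [prompts[i] for i in keep],
        [answers[i] for i in keep],
        [comps[i] for i in keep],
        None if meta is None else [meta[i] for i in keep],
    )


stats: Dict[str, int] = {}
kept = drop_empty_sketch(["p0", "p1"], ["a0", "a1"], [["c"], []], [["m"], []], stats)
```

In this example the second prompt has no completions, so it disappears from all four returned structures and the counter records one drop.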

maxent_grpo.training.generation.common.drop_incomplete_prompt_groups(prompts, answers, aggregated_comps, aggregated_meta, expected_generations, generation_stats)

Drop prompt groups that do not match the expected completion count.

TRL assumes a fixed number of completions per prompt. Any prompt group whose completion count differs from expected_generations is removed from the aligned prompt/answer/completion lists.

Parameters:
  • prompts (list[str]) – Prompt texts aligned to answers and grouped completions.

  • answers (list[str]) – Reference answers aligned to prompts.

  • aggregated_comps (list[list[str]]) – Grouped completions per prompt (mutable).

  • aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.

  • expected_generations (int) – Required completions per prompt.

  • generation_stats (dict[str, int]) – Mutable statistics dictionary for counters.

Returns:

Filtered prompts, answers, completions, and metadata, followed by the number of prompt groups that were dropped.

Return type:

tuple[list[str], list[str], list[list[str]], list[list[object | None]] | None, int]
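The exact-count filter can be sketched as follows. drop_incomplete_sketch is a hypothetical name, and the generation_stats argument is left out because this page does not name the counter it updates.

```python
from typing import List, Optional, Tuple

Meta = Optional[List[List[Optional[object]]]]


def drop_incomplete_sketch(
    prompts: List[str],
    answers: List[str],
    comps: List[List[str]],
    meta: Meta,
    expected: int,
) -> Tuple[List[str], List[str], List[List[str]], Meta, int]:
    """Hypothetical sketch: keep groups with exactly `expected` completions."""
    keep = [i for i, group in enumerate(comps) if len(group) == expected]
    dropped = len(comps) - len(keep)
    kept_meta = None if meta is None else [meta[i] for i in keep]
    return (
        [prompts[i] for i in keep],
        [answers[i] for i in keep],
        [comps[i] for i in keep],
        kept_meta,
        dropped,
    )


out = drop_incomplete_sketch(
    ["p0", "p1", "p2"],
    ["a0", "a1", "a2"],
    [["x", "y"], ["z"], ["u", "v", "w"]],
    None,
    expected=2,
)
```

Both the short group and the oversized group are removed here, since the rule is an exact match rather than a minimum.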

maxent_grpo.training.generation.common.flatten_ref_metadata(grouped_comps, grouped_meta)

Flatten metadata to align with the flattened completions list.

Metadata entries exposing to_trl_payload are converted before being appended. Missing metadata is filled with None.

Parameters:
  • grouped_comps (list[list[str]]) – Grouped completions per prompt.

  • grouped_meta (list[list[object | None]] | None) – Grouped metadata aligned to grouped_comps.

Returns:

Flattened metadata aligned to a flattened completions list, or None when no metadata exists.

Return type:

list[object | None] | None
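The conversion and padding rules can be illustrated with the sketch below. LogprobMeta and flatten_meta_sketch are hypothetical names, and the sketch assumes that "no metadata exists" means grouped_meta is None.

```python
from typing import List, Optional


class LogprobMeta:
    """Hypothetical metadata object exposing to_trl_payload."""

    def __init__(self, logprob_sum: float) -> None:
        self.logprob_sum = logprob_sum

    def to_trl_payload(self) -> dict:
        return {"logprob_sum": self.logprob_sum}


def flatten_meta_sketch(
    grouped_comps: List[List[str]],
    grouped_meta: Optional[List[List[Optional[object]]]],
) -> Optional[List[Optional[object]]]:
    """Hypothetical sketch: one flat metadata entry per completion."""
    if grouped_meta is None:
        return None
    flat: List[Optional[object]] = []
    for idx, comps in enumerate(grouped_comps):
        metas = grouped_meta[idx] if idx < len(grouped_meta) else []
        for pos in range(len(comps)):
            entry = metas[pos] if pos < len(metas) else None  # pad missing slots
            convert = getattr(entry, "to_trl_payload", None)
            flat.append(convert() if callable(convert) else entry)
    return flat


flat_meta = flatten_meta_sketch(
    [["a", "b"], ["c"]],
    [[LogprobMeta(-1.5)], ["raw"]],
)
```

The result has three entries, one per completion: the converted payload, a None placeholder for the completion without metadata, and the plain value passed through unchanged.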

maxent_grpo.training.generation.common.pending_generation_indices(aggregated_comps, expected_generations)

Return prompt indices that still need completions.

Parameters:
  • aggregated_comps (list[list[str]]) – Completions grouped per prompt.

  • expected_generations (int) – Desired number of completions per prompt.

Returns:

Indices whose completion count is below expected_generations.

Return type:

list[int]
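The documented rule amounts to a single comparison per prompt, sketched here under the hypothetical name pending_indices_sketch.

```python
from typing import List


def pending_indices_sketch(
    aggregated_comps: List[List[str]], expected_generations: int
) -> List[int]:
    """Hypothetical sketch: indices whose group is still below the target size."""
    return [
        idx
        for idx, group in enumerate(aggregated_comps)
        if len(group) < expected_generations
    ]


pending = pending_indices_sketch([["a", "b"], ["c"], []], 2)
```

Prompt 0 already has two completions, so only indices 1 and 2 are reported as pending.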

maxent_grpo.training.generation.common.retry_incomplete_prompts(prompts, generator, expected_generations, aggregated, max_retry_rounds)

Retry prompts missing completions until every prompt is complete or the retry limit is reached.

The generator callback is invoked with the list of prompts still missing completions, along with per-prompt deficits. Metadata returned by the generator is merged if available.

Parameters:
  • prompts (list[str]) – Original prompt strings.

  • generator (Callable[[list[str], int, list[int] | None], tuple[list[list[str]], list[list[object | None]] | None]]) – Callable performing generation for a batch of prompts. It should accept prompts, expected_generations, and optionally a list of per-prompt counts, returning grouped completions and metadata.

  • expected_generations (int) – Number of completions requested per prompt.

  • aggregated (AggregatedGenerationState) – Aggregated state containing completions and metadata to be updated in place.

  • max_retry_rounds (int | None) – Explicit retry cap; when omitted, the limit is derived from expected_generations.

Returns:

Updated aggregated generation state after retries are exhausted or all prompts are complete.

Return type:

AggregatedGenerationState
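The retry loop and the callback shape can be sketched as below. This is a simplified illustration, not the library's implementation: retry_sketch and flaky_generator are hypothetical, the state is a plain list of groups rather than an AggregatedGenerationState, and metadata merging is omitted.

```python
from typing import Callable, List, Optional, Tuple

Groups = List[List[str]]
Meta = Optional[List[List[Optional[object]]]]
Generator = Callable[[List[str], int, Optional[List[int]]], Tuple[Groups, Meta]]


def retry_sketch(
    prompts: List[str],
    generator: Generator,
    expected: int,
    comps: Groups,
    max_rounds: int,
) -> Groups:
    """Hypothetical retry loop over prompts that are still short."""
    for _ in range(max_rounds):
        pending = [i for i, group in enumerate(comps) if len(group) < expected]
        if not pending:
            break  # every prompt is complete
        deficits = [expected - len(comps[i]) for i in pending]
        new_groups, _meta = generator([prompts[i] for i in pending], expected, deficits)
        for idx, group in zip(pending, new_groups):
            comps[idx].extend(group)  # update the buffers in place
    return comps


calls: List[Tuple[List[str], List[int]]] = []


def flaky_generator(
    batch: List[str], expected: int, counts: Optional[List[int]]
) -> Tuple[Groups, Meta]:
    """Toy generator that yields one completion per prompt per call."""
    calls.append((list(batch), list(counts or [])))
    return [[f"{prompt}-try{len(calls)}"] for prompt in batch], None


result = retry_sketch(["p0", "p1"], flaky_generator, 2, [["p0-a", "p0-b"], []], max_rounds=3)
```

Only "p1" is ever passed to the generator, first with a deficit of 2 and then with a deficit of 1, and the loop stops after the second round because nothing is pending.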

maxent_grpo.training.generation.common.seed_generation_groups(prompt_count, grouped_comps, grouped_meta)

Return initial completion/meta buffers aligned with prompts.

The helper normalizes partially filled buffers into fresh lists sized to prompt_count and ensures metadata stays aligned with completions.

Parameters:
  • prompt_count (int) – Number of prompts that will be processed.

  • grouped_comps (list[list[str]] | None) – Optional preexisting completions grouped per prompt.

  • grouped_meta (list[list[object | None]] | None) – Optional preexisting metadata grouped per prompt.

Returns:

Tuple of initialized completions buffer and optional metadata buffer, both sized to prompt_count.

Return type:

tuple[list[list[str]], list[list[object | None]] | None]
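One plausible reading of this normalization is sketched below. seed_groups_sketch is a hypothetical name; the sketch assumes that existing groups are copied, that missing groups become empty lists, and that metadata is trimmed or padded with None to match each group.

```python
from typing import List, Optional, Tuple

Meta = Optional[List[List[Optional[object]]]]


def seed_groups_sketch(
    prompt_count: int,
    grouped_comps: Optional[List[List[str]]],
    grouped_meta: Meta,
) -> Tuple[List[List[str]], Meta]:
    """Hypothetical sketch: fresh buffers sized to prompt_count."""
    comps: List[List[str]] = [[] for _ in range(prompt_count)]
    for idx, group in enumerate((grouped_comps or [])[:prompt_count]):
        comps[idx] = list(group)  # copy so the caller's lists are untouched
    if grouped_meta is None:
        return comps, None  # assumption: metadata tracking stays disabled
    meta: List[List[Optional[object]]] = []
    for idx, group in enumerate(comps):
        src = list(grouped_meta[idx]) if idx < len(grouped_meta) else []
        src = src[: len(group)]  # trim excess entries
        src.extend([None] * (len(group) - len(src)))  # pad missing entries
        meta.append(src)
    return comps, meta


seeded_comps, seeded_meta = seed_groups_sketch(3, [["a"]], [["m", "extra"]])
```

The partially filled input is widened to three groups, and the surplus metadata entry is trimmed so each metadata list matches its completion list.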

maxent_grpo.training.generation.common.truncate_to_expected_counts(aggregated_comps, aggregated_meta, expected_generations)

Trim completions/meta to requested counts and track partial prompts.

Parameters:
  • aggregated_comps (list[list[str]]) – Grouped completions per prompt.

  • aggregated_meta (list[list[object | None]] | None) – Optional grouped metadata per prompt.

  • expected_generations (int) – Desired completions per prompt; values <= 0 skip trimming.

Returns:

Tuple of grouped completions, grouped metadata, and the number of non-empty prompts whose completion count differs from the requested value.

Return type:

tuple[list[list[str]], list[list[object | None]] | None, int]
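The trimming and the partial-prompt count can be sketched as follows. truncate_sketch is a hypothetical name, and two details are assumptions: the count is taken after trimming, and it is 0 when trimming is skipped.

```python
from typing import List, Optional, Tuple

Meta = Optional[List[List[Optional[object]]]]


def truncate_sketch(
    comps: List[List[str]], meta: Meta, expected: int
) -> Tuple[List[List[str]], Meta, int]:
    """Hypothetical sketch: cap each group at `expected` completions."""
    if expected <= 0:
        # Assumption: nothing is counted when trimming is skipped.
        return comps, meta, 0
    trimmed = [group[:expected] for group in comps]
    trimmed_meta = None if meta is None else [group[:expected] for group in meta]
    # Non-empty groups that still fall short of the requested count.
    partial = sum(1 for group in trimmed if group and len(group) != expected)
    return trimmed, trimmed_meta, partial


trimmed_out = truncate_sketch([["a", "b", "c"], ["d"], []], None, 2)
```

The first group is cut to two completions, the second is counted as partial, and the empty third group is ignored by the count.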