maxent_grpo.training.generation.vllm_helper

Assemble the VLLMGenerationHelper from dedicated mixins.

Functions

_seed_stats_metadata(stats, ctx)

Ensure dataset/model identifiers are stored on generation stats.
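A minimal sketch of what seeding identifiers onto a stats object could look like. The GenerationStats and Ctx dataclasses, their dataset_name/model_id fields, and the "fill only when missing" policy are all assumptions for illustration, not the module's actual types or behavior.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GenerationStats:
    """Hypothetical stand-in for the generation stats object."""
    dataset_name: Optional[str] = None
    model_id: Optional[str] = None


@dataclass
class Ctx:
    """Hypothetical stand-in for the training context."""
    dataset_name: str
    model_id: str


def seed_stats_metadata(stats: GenerationStats, ctx: Ctx) -> None:
    # Assumed policy: only fill identifiers that are still missing,
    # so values already recorded on the stats object are preserved.
    if stats.dataset_name is None:
        stats.dataset_name = ctx.dataset_name
    if stats.model_id is None:
        stats.model_id = ctx.model_id
```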

Classes

VLLMGenerationHelper(ctx, fallback_generate)

Encapsulate vLLM-specific logic so CompletionGenerator stays lean.

class maxent_grpo.training.generation.vllm_helper.VLLMGenerationHelper(ctx, fallback_generate)

Bases: VLLMWeightSyncMixin, VLLMRequestMixin, VLLMDistributedMixin
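The sketch below illustrates the general mixin-composition pattern with three hypothetical stand-in mixins; their method names are invented, and the real mixins' methods are not documented on this page. Each mixin reads state that the composed class sets in __init__, and Python's method resolution order follows the order in which the bases are listed.

```python
class WeightSyncMixin:
    """Hypothetical stand-in for VLLMWeightSyncMixin."""
    def sync_weights(self) -> str:
        return f"sync via {self.client_name}"


class RequestMixin:
    """Hypothetical stand-in for VLLMRequestMixin."""
    def request(self, prompt: str) -> str:
        return f"{self.client_name}:{prompt}"


class DistributedMixin:
    """Hypothetical stand-in for VLLMDistributedMixin."""
    def is_main_process(self) -> bool:
        return self.rank == 0


class GenerationHelper(WeightSyncMixin, RequestMixin, DistributedMixin):
    """Composes behavior from the mixins; shared state lives on the helper."""
    def __init__(self, client_name: str, rank: int = 0) -> None:
        self.client_name = client_name
        self.rank = rank
```

Keeping each concern in its own mixin lets the composed class stay small while every piece can be read and tested in isolation.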

Encapsulate vLLM-specific logic so CompletionGenerator stays lean.

Parameters:
property vllm_client: Any
property vllm_sync_ready: bool
property last_vllm_synced_step: int | None
property fsdp_cls: Any
generate(prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None)

Generate completions for prompts via vLLM, optionally deduplicating.

The helper handles optional weight synchronization, deduplicates repeated prompts when enabled, and retries requests up to the configured round limit. Results are expanded back to the original prompt ordering.
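A simplified, single-process sketch of the deduplicate, retry and expand flow. request_fn and max_rounds are hypothetical names. The sketch assumes request_fn returns one list of completions per pending prompt, or None for a prompt that failed in that round. Weight synchronization is omitted.

```python
from typing import Callable, Optional


def generate_deduplicated(
    prompts: list[str],
    counts: list[int],
    request_fn: Callable[[list[str], list[int]], list[Optional[list[str]]]],
    max_rounds: int = 3,
) -> list[list[str]]:
    # Collapse repeated prompts, summing the completions each one needs.
    unique: list[str] = []
    needed: dict[str, int] = {}
    for prompt, count in zip(prompts, counts):
        if prompt not in needed:
            unique.append(prompt)
            needed[prompt] = 0
        needed[prompt] += count

    # Retry rounds: re-request only prompts that are still short.
    collected: dict[str, list[str]] = {p: [] for p in unique}
    for _ in range(max_rounds):
        pending = [p for p in unique if len(collected[p]) < needed[p]]
        if not pending:
            break
        missing = [needed[p] - len(collected[p]) for p in pending]
        results = request_fn(pending, missing)
        for prompt, outputs in zip(pending, results):
            if outputs:  # None or [] means this prompt failed in this round
                room = needed[prompt] - len(collected[prompt])
                collected[prompt].extend(outputs[:room])

    # Expand back to the original ordering, handing out completions in turn.
    cursor = {p: 0 for p in unique}
    grouped: list[list[str]] = []
    for prompt, count in zip(prompts, counts):
        start = cursor[prompt]
        grouped.append(collected[prompt][start : start + count])
        cursor[prompt] = start + count
    return grouped
```

With prompts ["a", "b", "a"] and counts [2, 1, 2], the sketch issues a single request for four "a" completions instead of two requests of two each, then splits them back between the two original positions.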

Parameters:
  • prompts (list[str]) – Prompts to generate completions for.

  • num_samples (int) – Requested completions per prompt.

  • per_prompt_counts (list[int] | None) – Optional per-prompt completion counts; when provided, each entry overrides num_samples for the corresponding prompt.

  • ensure_client (Callable[[], bool] | None) – Optional callable to guarantee the vLLM client is ready before issuing requests.

  • sync_model (Callable[[Any], None] | None) – Optional callable to push model weights before generation.
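One plausible way per_prompt_counts and num_samples could combine into a count for each prompt. resolve_counts is a hypothetical helper, and the length check is an assumption about how mismatched inputs might be rejected; the documented helper may handle this differently.

```python
from typing import Optional


def resolve_counts(
    prompts: list[str], num_samples: int, per_prompt_counts: Optional[list[int]]
) -> list[int]:
    # Without explicit counts, every prompt gets num_samples completions.
    if per_prompt_counts is None:
        return [num_samples] * len(prompts)
    # Assumed validation: one count per prompt.
    if len(per_prompt_counts) != len(prompts):
        raise ValueError("per_prompt_counts must have one entry per prompt")
    return [int(c) for c in per_prompt_counts]
```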

Returns:

Completions grouped per prompt in the original prompt order, and logprob metadata grouped the same way when logprob collection is enabled (otherwise None).

Return type:

tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]
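To illustrate this return shape, the sketch below splits flat outputs into per-prompt groups and applies the same grouping to optional logprob metadata. regroup and build_result are hypothetical helpers, and plain floats stand in for VLLMLogprobResult.

```python
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def regroup(flat: Sequence[T], counts: Sequence[int]) -> list[list[T]]:
    # Slice a flat sequence into consecutive groups of the given sizes.
    if sum(counts) != len(flat):
        raise ValueError("counts do not match the number of outputs")
    groups: list[list[T]] = []
    offset = 0
    for count in counts:
        groups.append(list(flat[offset : offset + count]))
        offset += count
    return groups


def build_result(
    completions: Sequence[str],
    logprobs: Optional[Sequence[Optional[float]]],
    counts: Sequence[int],
) -> tuple[list[list[str]], Optional[list[list[Optional[float]]]]]:
    grouped = regroup(completions, counts)
    # Metadata is grouped identically, or left as None when not collected.
    grouped_meta = regroup(logprobs, counts) if logprobs is not None else None
    return grouped, grouped_meta
```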

generate_collective(prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None)

Gather prompts from all ranks, generate them in a single vLLM call, and scatter the results back to each rank.

Prompts from every rank are gathered on the main process, generated in a single vLLM call, and the outputs are scattered back to each rank with metadata preserved when available.
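A single-process simulation of the gather, generate and scatter pattern, where each inner list stands for one rank's prompts. collective_generate and generate_fn are hypothetical names. In a real multi-process run, the gather and scatter steps would go through collective operations (for example torch.distributed.gather_object and scatter_object_list); this page does not say which primitives the helper actually uses.

```python
from typing import Callable


def collective_generate(
    prompts_per_rank: list[list[str]],
    generate_fn: Callable[[list[str]], list[list[str]]],
) -> list[list[list[str]]]:
    # "Gather": flatten every rank's prompts on the main process,
    # remembering how many prompts each rank contributed.
    sizes = [len(p) for p in prompts_per_rank]
    all_prompts = [p for rank_prompts in prompts_per_rank for p in rank_prompts]

    # One generation call covers the prompts of all ranks.
    all_groups = generate_fn(all_prompts)

    # "Scatter": slice the grouped outputs back into per-rank chunks.
    per_rank: list[list[list[str]]] = []
    offset = 0
    for size in sizes:
        per_rank.append(all_groups[offset : offset + size])
        offset += size
    return per_rank
```

Recording each rank's prompt count before flattening is what lets the outputs be routed back correctly, even when some ranks contribute no prompts.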

Parameters:
  • prompts (list[str]) – Local prompts for the current rank.

  • num_samples (int) – Requested completions per prompt.

  • per_prompt_counts (list[int] | None) – Optional per-prompt completion counts.

  • ensure_client (Callable[[], bool] | None) – Optional callable ensuring the vLLM client is ready on the main process.

  • sync_model (Callable[[Any], None] | None) – Optional callable to push model weights before generation.

Returns:

Completions grouped per prompt for the current rank’s prompts only, and matching grouped metadata when available (otherwise None).

Return type:

tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]