maxent_grpo.training.generation.vllm
vLLM-specific helpers extracted from the MaxEnt-GRPO generation module.
Functions
|
Return a no-op context manager. |
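As an illustration of what a no-op context manager is for, here is a minimal sketch built on the standard library's `contextlib.nullcontext`. The names `maybe_guard` and `RecordingGuard` are hypothetical and do not come from this module; the module's own helper may be implemented differently.

```python
from contextlib import nullcontext

events = []


class RecordingGuard:
    """Hypothetical guard that records when it is entered and exited."""

    def __enter__(self):
        events.append("enter")
        return self

    def __exit__(self, exc_type, exc, tb):
        events.append("exit")
        return False  # never swallow exceptions


def maybe_guard(enabled, guard_factory):
    """Return guard_factory() when enabled, otherwise a no-op context manager."""
    return guard_factory() if enabled else nullcontext()


# Callers use the same `with` statement whether or not a real guard is active.
with maybe_guard(False, RecordingGuard):
    events.append("body-noop")

with maybe_guard(True, RecordingGuard):
    events.append("body-guarded")
```

The benefit of the pattern is that call sites stay free of `if`/`else` branches around the `with` block.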
- class maxent_grpo.training.generation.vllm.VLLMGenerationHelper(ctx, fallback_generate)
Bases: VLLMWeightSyncMixin, VLLMRequestMixin, VLLMDistributedMixin
Encapsulate vLLM-specific logic so CompletionGenerator stays lean.
- Parameters:
- generate(prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None)
Generate completions for prompts via vLLM, optionally deduplicating.
The helper handles optional weight synchronization, deduplicates repeated prompts when enabled, and retries requests up to the configured round limit. Results are expanded back to the original prompt ordering.
- Parameters:
num_samples (int) – Requested completions per prompt.
per_prompt_counts (list[int] | None) – Optional per-prompt completion counts; when provided, overrides num_samples on a per-prompt basis.
ensure_client (Callable[[], bool] | None) – Optional callable to guarantee the vLLM client is ready before issuing requests.
sync_model (Callable[[Any], None] | None) – Optional callable to push model weights before generation.
- Returns:
Grouped completions per prompt and optional grouped logprob metadata when enabled.
- Return type:
tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]
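The deduplicate, retry, and expand flow described for `generate` can be sketched without vLLM. This is not the helper's actual implementation: `generate_deduplicated`, `fake_backend`, and `max_rounds` are hypothetical names, and the choice to request the summed count for a repeated prompt (so each occurrence receives distinct samples) is an assumption made for the illustration.

```python
from typing import Callable, Dict, List, Optional, Sequence


def generate_deduplicated(
    prompts: Sequence[str],
    num_samples: int,
    per_prompt_counts: Optional[Sequence[int]],
    backend: Callable[[List[str], List[int]], List[List[str]]],
    max_rounds: int = 3,
) -> List[List[str]]:
    """Deduplicate prompts, retry short results, expand to the original order."""
    # Per-prompt counts override num_samples when provided.
    if per_prompt_counts is None:
        counts = [num_samples] * len(prompts)
    else:
        counts = list(per_prompt_counts)
    if len(counts) != len(prompts):
        raise ValueError("per_prompt_counts must match prompts in length")

    # Collapse repeated prompts; assumption: a repeated prompt needs the
    # sum of the counts requested by all of its occurrences.
    unique: List[str] = []
    index_of: Dict[str, int] = {}
    needed: List[int] = []
    for prompt, count in zip(prompts, counts):
        if prompt not in index_of:
            index_of[prompt] = len(unique)
            unique.append(prompt)
            needed.append(0)
        needed[index_of[prompt]] += count

    # Re-request only what is still missing, up to max_rounds rounds.
    pooled: List[List[str]] = [[] for _ in unique]
    for _ in range(max_rounds):
        pending = [i for i, n in enumerate(needed) if len(pooled[i]) < n]
        if not pending:
            break
        wanted = [needed[i] - len(pooled[i]) for i in pending]
        batch = backend([unique[i] for i in pending], wanted)
        for i, outputs in zip(pending, batch):
            pooled[i].extend(outputs[: needed[i] - len(pooled[i])])

    # Expand back to the original prompt ordering.
    cursor = [0] * len(unique)
    grouped: List[List[str]] = []
    for prompt, count in zip(prompts, counts):
        i = index_of[prompt]
        grouped.append(pooled[i][cursor[i]: cursor[i] + count])
        cursor[i] += count
    return grouped


# Stand-in backend that returns at most two completions per prompt per call.
calls: List[List[str]] = []
_serial: Dict[str, int] = {}


def fake_backend(batch: List[str], wanted: List[int]) -> List[List[str]]:
    calls.append(list(batch))
    results = []
    for prompt, n in zip(batch, wanted):
        texts = []
        for _ in range(min(n, 2)):
            _serial[prompt] = _serial.get(prompt, 0) + 1
            texts.append(f"{prompt}#{_serial[prompt]}")
        results.append(texts)
    return results


result = generate_deduplicated(["a", "b", "a"], 2, None, fake_backend)
uneven = generate_deduplicated(["x", "y"], 1, [1, 2], fake_backend)
```

In this sketch the prompt `"a"` is sent once per round rather than once per occurrence, and the second round asks only for the completions the first round failed to deliver.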
- generate_collective(prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None)
Broadcast prompts across ranks and gather vLLM generations collectively.
Prompts from every rank are gathered on the main process, generated in a single vLLM call, and the outputs are scattered back to each rank with metadata preserved when available.
- Parameters:
num_samples (int) – Requested completions per prompt.
per_prompt_counts (list[int] | None) – Optional per-prompt completion counts.
ensure_client (Callable[[], bool] | None) – Optional callable ensuring the vLLM client is ready on the main process.
sync_model (Callable[[Any], None] | None) – Optional callable to push model weights before generation.
- Returns:
Grouped completions and optional metadata corresponding to the current rank’s prompts.
- Return type:
tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]
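The gather, generate once, scatter pattern described for `generate_collective` can be simulated in a single process by modelling each rank as a list of prompts. This is a sketch under that assumption: `collective_generate` and `fake_generate` are hypothetical names, and a real multi-process run would use a distributed collective in place of the plain list operations shown here.

```python
from typing import Callable, List, Sequence


def collective_generate(
    prompts_by_rank: Sequence[Sequence[str]],
    generate_on_main: Callable[[List[str]], List[List[str]]],
) -> List[List[List[str]]]:
    """Gather prompts from all ranks, generate once, scatter results back."""
    # Gather: remember how many prompts each rank contributed.
    sizes = [len(rank_prompts) for rank_prompts in prompts_by_rank]
    flat = [p for rank_prompts in prompts_by_rank for p in rank_prompts]

    # Single generation call on the (simulated) main process.
    grouped = generate_on_main(flat)
    if len(grouped) != len(flat):
        raise RuntimeError("backend returned a different number of groups")

    # Scatter: slice the flat results back into per-rank shards.
    shards: List[List[List[str]]] = []
    start = 0
    for size in sizes:
        shards.append(grouped[start: start + size])
        start += size
    return shards


call_count = 0


def fake_generate(batch: List[str]) -> List[List[str]]:
    """Stand-in for the vLLM call: two completions per prompt."""
    global call_count
    call_count += 1
    return [[f"{prompt}-0", f"{prompt}-1"] for prompt in batch]


shards = collective_generate([["p0", "p1"], [], ["p2"]], fake_generate)
```

Recording the per-rank sizes before flattening is what lets each rank receive exactly the groups for its own prompts, including ranks that contributed none.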