maxent_grpo.training.rollout.helpers

Completion generation helpers for the MaxEnt-GRPO runner.

Functions

_broadcast_object_list(accelerator, payload, *)

_broadcast_object_list_wrapper(accelerator, ...)

_ensure_dist(dist_obj)

_gather_object_list(accelerator, value)

_gather_object_list_wrapper(accelerator, value)

_import_vllm_client_cls([import_fn])

Import the TRL VLLMClient using the caller-provided optional import hook.

_refresh_vllm_globals()

Keep vLLM adapter globals in sync with test monkeypatches.

_scatter_object(accelerator, input_list, *)

_scatter_object_wrapper(accelerator, ...[, src])

Classes

CompletionGenerator(ctx)

Stateful helper that handles both local HF and vLLM completions.

_DistFallback()

Minimal torch.distributed stand-in for single-process tests.

class maxent_grpo.training.rollout.helpers.VLLMGenerationHelper(ctx, fallback_generate)

Bases: VLLMWeightSyncMixin, VLLMRequestMixin, VLLMDistributedMixin

Encapsulate vLLM-specific logic so CompletionGenerator stays lean.

Parameters:
property vllm_client: Any
property vllm_sync_ready: bool
property last_vllm_synced_step: int | None
property fsdp_cls: Any
generate(prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None)

Generate completions for prompts via vLLM, optionally deduplicating.

The helper handles optional weight synchronization, deduplicates repeated prompts when enabled, and retries requests up to the configured round limit. Results are expanded back to the original prompt ordering.

Parameters:
  • prompts (list[str]) – Prompts to generate completions for.

  • num_samples (int) – Requested completions per prompt.

  • per_prompt_counts (list[int] | None) – Optional per-prompt completion counts; when provided, each entry overrides num_samples for the corresponding prompt.

  • ensure_client (Callable[[], bool] | None) – Optional callable to guarantee the vLLM client is ready before issuing requests.

  • sync_model (Callable[[Any], None] | None) – Optional callable to push model weights before generation.

Returns:

Grouped completions per prompt and optional grouped logprob metadata when enabled.

Return type:

tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]
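The deduplicate-then-expand flow described above can be sketched as follows. This is a minimal illustration, not the helper's actual implementation: `generate_with_dedup` and `backend` are hypothetical names, and the choice to request enough completions for every occurrence of a repeated prompt (rather than reusing the same completions) is an assumption of the sketch.

```python
from typing import Callable, Optional


def generate_with_dedup(
    prompts: list[str],
    num_samples: int,
    per_prompt_counts: Optional[list[int]],
    backend: Callable[[str, int], list[str]],
) -> list[list[str]]:
    """Generate once per unique prompt, then expand to the original order.

    `backend(prompt, k)` is a hypothetical stand-in for the vLLM request and
    is expected to return k completions for the prompt.
    """
    if per_prompt_counts is None:
        counts = [num_samples] * len(prompts)
    else:
        counts = list(per_prompt_counts)
    if len(counts) != len(prompts):
        raise ValueError("per_prompt_counts must match prompts in length")

    # A repeated prompt needs enough completions to serve all its occurrences.
    needed: dict[str, int] = {}
    for prompt, count in zip(prompts, counts):
        needed[prompt] = needed.get(prompt, 0) + count

    # One backend request per unique prompt.
    pools = {prompt: list(backend(prompt, total)) for prompt, total in needed.items()}

    # Expand back to the caller's ordering, consuming each pool in turn.
    grouped: list[list[str]] = []
    for prompt, count in zip(prompts, counts):
        pool = pools[prompt]
        grouped.append(pool[:count])
        del pool[:count]
    return grouped
```

With three prompts of which two are identical, the backend in this sketch is called twice rather than three times, and the returned list still has one group per input prompt.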

generate_collective(prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None)

Gather prompts from all ranks, generate once via vLLM, and scatter the results back.

Prompts from every rank are gathered on the main process, generated in a single vLLM call, and the outputs are scattered back to each rank with metadata preserved when available.

Parameters:
  • prompts (list[str]) – Local prompts for the current rank.

  • num_samples (int) – Requested completions per prompt.

  • per_prompt_counts (list[int] | None) – Optional per-prompt completion counts.

  • ensure_client (Callable[[], bool] | None) – Optional callable ensuring the vLLM client is ready on the main process.

  • sync_model (Callable[[Any], None] | None) – Optional callable to push model weights before generation.

Returns:

Grouped completions and optional metadata corresponding to the current rank’s prompts.

Return type:

tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]
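The gather, generate, scatter pattern described above can be illustrated in a single process. This is a sketch under stated assumptions: `collective_generate` and `generate_fn` are hypothetical names, ranks are simulated as a list of prompt lists, and no real collective communication (or the private gather/scatter helpers listed in this module) is used.

```python
from typing import Callable


def collective_generate(
    prompts_by_rank: list[list[str]],
    generate_fn: Callable[[list[str]], list[list[str]]],
) -> list[list[list[str]]]:
    """Simulate gather -> one generate call -> scatter across ranks."""
    # "Gather": flatten every rank's prompts, remembering each rank's slice.
    flat: list[str] = []
    bounds: list[tuple[int, int]] = []
    for rank_prompts in prompts_by_rank:
        start = len(flat)
        flat.extend(rank_prompts)
        bounds.append((start, len(flat)))

    # A single backend call, as would happen on the main process.
    grouped = generate_fn(flat)
    if len(grouped) != len(flat):
        raise RuntimeError("backend returned a different number of groups")

    # "Scatter": each rank receives only the groups for its own prompts.
    return [grouped[start:end] for start, end in bounds]
```

Recording the slice boundaries at gather time is what lets each rank get back exactly the groups for its own prompts, including ranks that contributed no prompts.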

class maxent_grpo.training.rollout.helpers.CompletionGenerator(ctx)

Bases: LocalGenerationMixin, VLLMGenerationMixin

Stateful helper that handles both local HF and vLLM completions.

Parameters:

ctx (GenerationContext)

class maxent_grpo.training.rollout.helpers.GenerationContext(max_prompt_len, max_completion_len, gen_temperature, gen_top_p, use_vllm, vllm, accelerator, model, tokenizer, generation_stats, device, penalty=<factory>, prompt_char_limit=None, *, vllm_mode='server')

Bases: GenerationPenaltyPassthroughMixin, GenerationSamplingConfig

Configuration required to produce completions for each training batch.

Parameters:
accelerator: TypesAccelerator
model: TypesPreTrainedModel
tokenizer: TypesPreTrainedTokenizer
generation_stats: Dict[str, int]
device: Any
penalty: GenerationPenaltyConfig
prompt_char_limit: int | None = None
as_dict()

Return a lightweight representation useful for logging/debugging.

Return type:

Dict[str, Any]

max_prompt_len: int
max_completion_len: int
gen_temperature: float
gen_top_p: float
use_vllm: bool
vllm: VLLMClientConfig
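One way a "lightweight representation" for logging could look is sketched below. `ContextSketch` is a hypothetical, reduced stand-in for GenerationContext; which fields the real as_dict() includes, and how it summarizes heavy objects such as the model, is not documented here, so the behavior shown is an assumption.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ContextSketch:
    """Hypothetical, reduced stand-in for GenerationContext."""

    max_prompt_len: int
    max_completion_len: int
    gen_temperature: float
    gen_top_p: float
    use_vllm: bool
    model: Any = None
    tokenizer: Any = None

    def as_dict(self) -> Dict[str, Any]:
        """Keep scalar settings; replace heavy objects with their type name."""
        summary: Dict[str, Any] = {}
        for field in fields(self):
            value = getattr(self, field.name)
            if value is None or isinstance(value, (bool, int, float, str)):
                summary[field.name] = value
            else:
                # Assumption: logging only needs to know what kind of object it is.
                summary[field.name] = type(value).__name__
        return summary
```

Summarizing non-scalar fields by type name keeps the dictionary safe to serialize and avoids dumping model weights or tokenizer state into logs.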
maxent_grpo.training.rollout.helpers.require_torch(context)

Return the torch module or raise a helpful RuntimeError.

Parameters:

context (str)

Return type:

Any
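The "import or raise a helpful RuntimeError" pattern can be sketched generically. `require_module` below is a hypothetical helper, not the library's require_torch; the message wording and the extra `name` parameter are assumptions made so the sketch runs without torch installed.

```python
import importlib
from typing import Any


def require_module(name: str, context: str) -> Any:
    """Import `name`, or raise a RuntimeError that names the caller's context."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        # Chain the original error so the underlying import failure stays visible.
        raise RuntimeError(
            f"{name} is required for {context}; install it to use this feature."
        ) from exc
```

A call such as `require_module("torch", "rollout generation")` would return the module when it is installed and otherwise fail with a message that says which feature needed it.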

maxent_grpo.training.rollout.helpers.safe_generate(*, prompts, url='http://localhost:8000/generate', max_tokens=256, temperature=0.7, top_p=0.9, top_k=None, n=1, stream=False, tokenizer=None, best_of=None, frequency_penalty=None, presence_penalty=None, stop=None, include_stop_str_in_output=False, logit_bias=None, allowed_token_ids=None, blocked_token_ids=None, guided_json=None, guided_regex=None, seed=None, request_id=None, request_id_prefix=None, timeout=None, max_retries=None, backoff=None, backoff_multiplier=None, return_logprobs=False, service_model=None, metadata=None, client_tag=None)

POST to /generate with retries and schema-agnostic decoding of the response.

Parameters:
  • prompts (list[str]) – Input prompts (batch) to generate from.

  • url (str) – Full URL of the /generate route.

  • max_tokens (int) – Maximum tokens to generate per completion.

  • temperature (float) – Sampling temperature.

  • top_p (float) – Nucleus sampling p.

  • top_k (int | None) – Optional top-k cutoff applied during sampling.

  • n (int) – Number of completions per prompt.

  • stream (bool) – Whether to use chunked streaming responses.

  • tokenizer (Any) – Optional tokenizer to decode token ID arrays.

  • best_of (int | None) – vLLM best_of parameter to sample more than n candidates.

  • frequency_penalty (float | None) – Frequency penalty forwarded to vLLM sampling.

  • presence_penalty (float | None) – Presence penalty forwarded to vLLM sampling.

  • stop (list[str] | None) – Stop sequences used to truncate completions.

  • include_stop_str_in_output (bool) – Whether matched stop strings should remain in the returned text.

  • logit_bias (dict[str, float] | None) – Token-level logit bias forwarded to vLLM.

  • allowed_token_ids (list[int] | None) – Optional hard allowlist of token IDs forwarded to vLLM.

  • blocked_token_ids (list[int] | None) – Optional hard denylist of token IDs forwarded to vLLM.

  • guided_json (str | None) – Optional JSON schema string for constrained decoding.

  • guided_regex (str | None) – Optional regex constraint for decoding.

  • seed (int | None) – Optional deterministic sampling seed forwarded to vLLM.

  • request_id (str | None) – Explicit request identifier to forward to vLLM.

  • request_id_prefix (str | None) – Prefix used when auto-generating request_id.

  • max_retries (int | None) – Number of attempts before surfacing the error.

  • backoff (float | None) – Base backoff in seconds; grows exponentially across attempts.

  • timeout (float | None) – Per-request timeout in seconds.

  • return_logprobs (bool) – Whether to request log-prob metadata from vLLM.

  • service_model (str | None) – Optional identifier for the served model (used in error payloads).

  • metadata (dict[str, Any] | None) – Optional structured context (dataset/model) copied into error payloads.

  • client_tag (str | None) – Optional client/rank identifier forwarded via headers/payload.

  • backoff_multiplier (float | None) – Factor by which the backoff grows between attempts.

Returns:

Tuple of grouped texts, optional log-prob metadata, and latency in milliseconds.

Return type:

tuple[list[list[str]], list[list[VLLMLogprobResult]] | None, float]

Raises:

GenerationServiceError – When the request still fails after all retries are exhausted.
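The interaction of max_retries, backoff, and backoff_multiplier can be sketched as a generic retry loop. This is an illustration under assumptions, not safe_generate itself: `call_with_retries` and `send` are hypothetical names, the default values are made up for the sketch, and a plain RuntimeError stands in for GenerationServiceError.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def call_with_retries(
    send: Callable[[], T],
    max_retries: int = 3,
    backoff: float = 1.0,
    backoff_multiplier: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call `send` up to `max_retries` times with exponentially growing waits."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    delay = backoff
    last_error: Optional[BaseException] = None
    for attempt in range(1, max_retries + 1):
        try:
            return send()
        except Exception as exc:  # a real client would catch narrower errors
            last_error = exc
            if attempt == max_retries:
                break
            sleep(delay)
            delay *= backoff_multiplier  # e.g. 1.0, 2.0, 4.0, ...

    # RuntimeError stands in for the library's GenerationServiceError.
    raise RuntimeError(f"request failed after {max_retries} attempts") from last_error
```

Here max_retries counts total attempts, matching the parameter description above, so no wait follows the final failure. Passing `sleep` as a parameter makes the waits observable in tests without actually pausing.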