maxent_grpo.training.rollout.helpers
Completion generation helpers for the MaxEnt-GRPO runner.
Functions
- Import the TRL VLLMClient using the caller-provided optional import hook.
- Keep vLLM adapter globals in sync with test monkeypatches.
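A minimal sketch of importing VLLMClient through an optional import hook. It assumes the hook is a callable that takes a dotted module path and returns the module, or None when the import is unavailable. The names import_vllm_client and default_import_hook and the hook signature are hypothetical; only the TRL module path trl.extras.vllm_client comes from TRL itself.

```python
import importlib
from types import ModuleType
from typing import Any, Callable, Optional

# Hypothetical hook type: takes a dotted module path, returns the module or None.
ImportHook = Callable[[str], Optional[ModuleType]]


def default_import_hook(module_path: str) -> Optional[ModuleType]:
    """Import a module, returning None instead of raising when it is missing."""
    try:
        return importlib.import_module(module_path)
    except ImportError:
        return None


def import_vllm_client(import_hook: Optional[ImportHook] = None) -> Optional[Any]:
    """Return TRL's VLLMClient class, or None when it cannot be imported."""
    hook = import_hook or default_import_hook
    module = hook("trl.extras.vllm_client")
    if module is None:
        return None
    # getattr keeps this tolerant of module versions that lack the class.
    return getattr(module, "VLLMClient", None)
```

Injecting the hook lets tests substitute a fake module without touching sys.modules.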
Classes
- Stateful helper that handles both local HF and vLLM completions.
- Minimal torch.distributed stand-in for single-process tests.
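A sketch of what a single-process torch.distributed stand-in might look like. Which functions the real stand-in implements is not documented here; this version assumes callers only need get_rank, get_world_size, all_gather_object, broadcast_object_list and barrier, whose names mirror torch.distributed. The class name SingleProcessDist is hypothetical.

```python
from typing import Any, List


class SingleProcessDist:
    """Hypothetical single-process stand-in mirroring a few torch.distributed calls."""

    def is_available(self) -> bool:
        return True

    def is_initialized(self) -> bool:
        return True

    def get_rank(self, group: Any = None) -> int:
        return 0

    def get_world_size(self, group: Any = None) -> int:
        return 1

    def all_gather_object(self, object_list: List[Any], obj: Any, group: Any = None) -> None:
        # torch.distributed fills object_list in place; with one rank, slot 0 is ours.
        object_list[0] = obj

    def broadcast_object_list(self, object_list: List[Any], src: int = 0, group: Any = None) -> None:
        # With a single process the source already holds the data, so nothing changes.
        return None

    def barrier(self, group: Any = None) -> None:
        return None
```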
- class maxent_grpo.training.rollout.helpers.VLLMGenerationHelper(ctx, fallback_generate)
Bases: VLLMWeightSyncMixin, VLLMRequestMixin, VLLMDistributedMixin
Encapsulate vLLM-specific logic so CompletionGenerator stays lean.
- Parameters:
- generate(prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None)
Generate completions for prompts via vLLM, optionally deduplicating.
The helper handles optional weight synchronization, deduplicates repeated prompts when enabled, and retries requests up to the configured round limit. Results are expanded back to the original prompt ordering.
- Parameters:
num_samples (int) – Requested completions per prompt.
per_prompt_counts (list[int] | None) – Optional per-prompt completion counts; when provided, overrides num_samples on a per-prompt basis.
ensure_client (Callable[[], bool] | None) – Optional callable to guarantee the vLLM client is ready before issuing requests.
sync_model (Callable[[Any], None] | None) – Optional callable to push model weights before generation.
- Returns:
Grouped completions per prompt and optional grouped logprob metadata when enabled.
- Return type:
tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]
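The deduplicate, retry and expand flow described for generate can be sketched with a plain callable standing in for vLLM. This assumes duplicated prompts should each still receive distinct completions, so requested counts are summed per unique prompt and split back afterwards, and that a "round" re-requests only the completions still missing. generate_dedup, the backend callable and max_rounds are hypothetical names, not the helper's API.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical backend: (prompts, counts) -> completions per prompt (may return fewer).
Backend = Callable[[List[str], List[int]], List[List[str]]]


def generate_dedup(
    prompts: List[str],
    num_samples: int,
    backend: Backend,
    per_prompt_counts: Optional[List[int]] = None,
    max_rounds: int = 3,
) -> List[List[str]]:
    """Sketch: dedupe prompts, retry short results, expand to original order."""
    if per_prompt_counts is not None:
        counts = per_prompt_counts
    else:
        counts = [num_samples] * len(prompts)
    # Sum requested counts per unique prompt, preserving first-seen order.
    wanted: Dict[str, int] = {}
    for prompt, count in zip(prompts, counts):
        wanted[prompt] = wanted.get(prompt, 0) + count
    pooled: Dict[str, List[str]] = {p: [] for p in wanted}
    for _ in range(max_rounds):
        missing = {p: n - len(pooled[p]) for p, n in wanted.items() if len(pooled[p]) < n}
        if not missing:
            break
        batch = list(missing)
        results = backend(batch, [missing[p] for p in batch])
        for prompt, outs in zip(batch, results):
            pooled[prompt].extend(outs[: missing[prompt]])
    # Hand each original position its own slice of the pooled completions.
    grouped: List[List[str]] = []
    offsets: Dict[str, int] = {p: 0 for p in wanted}
    for prompt, count in zip(prompts, counts):
        start = offsets[prompt]
        grouped.append(pooled[prompt][start : start + count])
        offsets[prompt] = start + count
    return grouped
```

For example, prompts ["a", "b", "a"] with num_samples=2 issue a single backend request for ["a", "b"] with counts [4, 2].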
- generate_collective(prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None)
Broadcast prompts across ranks and gather vLLM generations collectively.
Prompts from every rank are gathered on the main process, generated in a single vLLM call, and the outputs are scattered back to each rank with metadata preserved when available.
- Parameters:
num_samples (int) – Requested completions per prompt.
per_prompt_counts (list[int] | None) – Optional per-prompt completion counts.
ensure_client (Callable[[], bool] | None) – Optional callable ensuring the vLLM client is ready on the main process.
sync_model (Callable[[Any], None] | None) – Optional callable to push model weights before generation.
- Returns:
Grouped completions and optional metadata corresponding to the current rank’s prompts.
- Return type:
tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]
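The gather, generate-once, scatter flow of generate_collective can be sketched in plain Python. The per-rank prompt lists are treated as already gathered (in a real run, for example, via torch.distributed.all_gather_object), and the function returns the slice each rank would receive (for example via torch.distributed.scatter_object_list). collective_generate and the backend callable are hypothetical.

```python
from typing import Callable, List, Sequence

# Hypothetical single-call backend: flat prompts -> one completion group per prompt.
FlatBackend = Callable[[List[str]], List[List[str]]]


def collective_generate(
    prompts_per_rank: Sequence[List[str]], backend: FlatBackend
) -> List[List[List[str]]]:
    """Sketch of the main-process step: flatten, generate once, split per rank."""
    flat: List[str] = []
    sizes: List[int] = []
    for rank_prompts in prompts_per_rank:
        flat.extend(rank_prompts)
        sizes.append(len(rank_prompts))
    outputs = backend(flat)  # stands in for the single vLLM call
    if len(outputs) != len(flat):
        raise ValueError("backend must return one group per prompt")
    # Recording sizes lets ranks with zero prompts receive an empty slice.
    per_rank: List[List[List[str]]] = []
    start = 0
    for size in sizes:
        per_rank.append(outputs[start : start + size])
        start += size
    return per_rank
```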
- class maxent_grpo.training.rollout.helpers.CompletionGenerator(ctx)
Bases: LocalGenerationMixin, VLLMGenerationMixin
Stateful helper that handles both local HF and vLLM completions.
- Parameters:
ctx (GenerationContext)
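One way to picture a helper that serves both backends is a dispatcher that routes on the context's use_vllm flag. This is a minimal sketch under that assumption; MiniContext, MiniCompletionGenerator and the two generate callables are hypothetical and do not reflect how the mixins are actually composed.

```python
from dataclasses import dataclass
from typing import Callable, List

# Hypothetical signature shared by both backends: (prompts, n) -> groups.
Generate = Callable[[List[str], int], List[List[str]]]


@dataclass
class MiniContext:
    use_vllm: bool  # mirrors GenerationContext.use_vllm


class MiniCompletionGenerator:
    """Hypothetical dispatcher: vLLM when enabled, local HF generation otherwise."""

    def __init__(self, ctx: MiniContext, local_generate: Generate, vllm_generate: Generate):
        self.ctx = ctx
        self._local = local_generate
        self._vllm = vllm_generate

    def generate(self, prompts: List[str], num_samples: int) -> List[List[str]]:
        if not prompts:
            return []
        backend = self._vllm if self.ctx.use_vllm else self._local
        return backend(prompts, num_samples)
```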
- class maxent_grpo.training.rollout.helpers.GenerationContext(max_prompt_len, max_completion_len, gen_temperature, gen_top_p, use_vllm, vllm, accelerator, model, tokenizer, generation_stats, device, penalty=<factory>, prompt_char_limit=None, *, vllm_mode='server')
Bases: GenerationPenaltyPassthroughMixin, GenerationSamplingConfig
Configuration required to produce completions for each training batch.
- Parameters:
max_prompt_len (int)
max_completion_len (int)
gen_temperature (float)
gen_top_p (float)
use_vllm (bool)
vllm (VLLMClientConfig)
accelerator (Accelerator)
model (PreTrainedModel)
tokenizer (PreTrainedTokenizer)
device (Any)
penalty (GenerationPenaltyConfig)
prompt_char_limit (int | None)
vllm_mode (str)
- accelerator: TypesAccelerator
- model: TypesPreTrainedModel
- tokenizer: TypesPreTrainedTokenizer
- device: Any
- penalty: GenerationPenaltyConfig
- vllm: VLLMClientConfig
- maxent_grpo.training.rollout.helpers.require_torch(context)
Return the torch module or raise a helpful RuntimeError.
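A sketch of the import-or-explain pattern that require_torch describes, assuming context is a short string naming the caller. require_module and require_torch_sketch are hypothetical names; chaining with "from exc" keeps the original ImportError available as __cause__ for debugging.

```python
import importlib
from types import ModuleType


def require_module(name: str, context: str) -> ModuleType:
    """Hypothetical generalisation: import a module or raise a helpful RuntimeError."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise RuntimeError(
            f"{context} requires the '{name}' package; install it to continue."
        ) from exc


def require_torch_sketch(context: str) -> ModuleType:
    """Return the torch module, or raise RuntimeError naming the caller."""
    return require_module("torch", context)
```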
- maxent_grpo.training.rollout.helpers.safe_generate(*, prompts, url='http://localhost:8000/generate', max_tokens=256, temperature=0.7, top_p=0.9, top_k=None, n=1, stream=False, tokenizer=None, best_of=None, frequency_penalty=None, presence_penalty=None, stop=None, include_stop_str_in_output=False, logit_bias=None, allowed_token_ids=None, blocked_token_ids=None, guided_json=None, guided_regex=None, seed=None, request_id=None, request_id_prefix=None, timeout=None, max_retries=None, backoff=None, backoff_multiplier=None, return_logprobs=False, service_model=None, metadata=None, client_tag=None)
Robust POST to /generate with retry and schema-agnostic decoding.
- Parameters:
prompts (list[str]) – Input prompts (batch) to generate from.
url (str) – Full URL of the /generate route (the default already includes the /generate path).
max_tokens (int) – Maximum tokens to generate per completion.
temperature (float) – Sampling temperature.
top_p (float) – Nucleus sampling p.
top_k (int | None) – Optional top-k cutoff applied during sampling.
n (int) – Number of completions per prompt.
stream (bool) – Whether to use chunked streaming responses.
tokenizer (Any) – Optional tokenizer to decode token ID arrays.
best_of (int | None) – vLLM best_of parameter, used to sample more than n candidates.
frequency_penalty (float | None) – Frequency penalty forwarded to vLLM sampling.
presence_penalty (float | None) – Presence penalty forwarded to vLLM sampling.
stop (list[str] | None) – Stop sequences used to truncate completions.
include_stop_str_in_output (bool) – Whether matched stop strings should remain in the returned text.
logit_bias (dict[str, float] | None) – Token-level logit bias forwarded to vLLM.
allowed_token_ids (list[int] | None) – Optional hard allowlist of token IDs forwarded to vLLM.
blocked_token_ids (list[int] | None) – Optional hard denylist of token IDs forwarded to vLLM.
guided_json (str | None) – Optional JSON schema string for constrained decoding.
guided_regex (str | None) – Optional regex constraint for decoding.
seed (int | None) – Optional deterministic sampling seed forwarded to vLLM.
request_id (str | None) – Explicit request identifier to forward to vLLM.
request_id_prefix (str | None) – Prefix used when auto-generating request_id.
max_retries (int) – Number of attempts before surfacing the error.
backoff (float) – Base backoff in seconds; exponential across attempts.
timeout (float) – Per‑request timeout in seconds.
return_logprobs (bool) – Whether to request log-prob metadata from vLLM.
service_model (str | None) – Optional identifier for the served model (used in error payloads).
metadata (dict[str, Any] | None) – Optional structured context (dataset/model) copied into error payloads.
client_tag (str | None) – Optional client/rank identifier forwarded via headers/payload.
backoff_multiplier (float | None)
- Returns:
Tuple of grouped texts, optional log-prob metadata, and latency in milliseconds.
- Return type:
tuple[list[list[str]], Optional[list[list[VLLMLogprobResult]]], float]
- Raises:
GenerationServiceError – When the server responds with repeated errors after exhausting retries.
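A minimal sketch of the retry-and-decode pattern that safe_generate describes, with the HTTP call abstracted as an injected post callable so it runs offline. The response shapes accepted by decode_texts, the assumption that the server returns len(prompts) * n texts in prompt-major order, the delay formula backoff * backoff_multiplier**attempt, and all names (post_with_retry, decode_texts, GenerationServiceErrorSketch) are illustrative assumptions, not the module's actual behaviour.

```python
import time
from typing import Any, Callable, Dict, List, Optional


class GenerationServiceErrorSketch(RuntimeError):
    """Hypothetical stand-in for GenerationServiceError."""


def decode_texts(payload: Dict[str, Any], tokenizer: Any = None) -> List[str]:
    """Accept a few assumed response shapes and return a flat list of strings."""
    if "choices" in payload:  # OpenAI-style: [{"text": ...}, ...]
        return [choice["text"] for choice in payload["choices"]]
    if "text" in payload:  # plain list of strings
        return list(payload["text"])
    if "token_ids" in payload and tokenizer is not None:
        return [tokenizer.decode(ids) for ids in payload["token_ids"]]
    raise ValueError(f"unrecognised response keys: {sorted(payload)}")


def post_with_retry(
    post: Callable[[Dict[str, Any]], Dict[str, Any]],
    prompts: List[str],
    n: int = 1,
    max_retries: int = 3,
    backoff: float = 1.0,
    backoff_multiplier: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
    tokenizer: Any = None,
) -> List[List[str]]:
    """Retry post with exponential backoff, then group texts n per prompt."""
    last_error: Optional[Exception] = None
    for attempt in range(max_retries):
        try:
            texts = decode_texts(post({"prompts": prompts, "n": n}), tokenizer)
            # Assumed prompt-major ordering: regroup into n texts per prompt.
            return [texts[i * n : (i + 1) * n] for i in range(len(prompts))]
        except (ConnectionError, TimeoutError, ValueError) as exc:
            last_error = exc
            if attempt + 1 < max_retries:
                sleep(backoff * backoff_multiplier**attempt)
    raise GenerationServiceErrorSketch(
        f"generation failed after {max_retries} attempts"
    ) from last_error
```

Injecting sleep keeps tests fast and makes the backoff schedule observable: with backoff=1.0 and the default multiplier, three failed attempts sleep 1.0 s and then 2.0 s before the error surfaces.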