maxent_grpo.training.rollout.generator
Public CompletionGenerator that wires local and vLLM helpers together.
Classes
CompletionGenerator | Stateful helper that handles both local HF and vLLM completions.
- class maxent_grpo.training.rollout.generator.CompletionGenerator(ctx)
Bases: LocalGenerationMixin, VLLMGenerationMixin
Stateful helper that handles both local HF and vLLM completions.
- Parameters:
ctx (GenerationContext)
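The mixin composition above can be pictured with a small stand-in. This is a sketch under stated assumptions, not the library's code: SketchContext, SketchLocalMixin, SketchVLLMMixin, SketchGenerator, the generate method, and the dispatch on use_vllm are all hypothetical names and behavior chosen for illustration. Only the idea of one class that inherits a local and a vLLM mixin and holds a context comes from the entry above.

```python
from dataclasses import dataclass


@dataclass
class SketchContext:
    """Hypothetical stand-in for GenerationContext (one field kept)."""
    use_vllm: bool


class SketchLocalMixin:
    """Hypothetical local (HF-style) path; returns placeholder strings."""

    def _generate_local(self, prompts, n):
        return [[f"local:{p}:{i}" for i in range(n)] for p in prompts]


class SketchVLLMMixin:
    """Hypothetical vLLM path; returns placeholder strings."""

    def _generate_vllm(self, prompts, n):
        return [[f"vllm:{p}:{i}" for i in range(n)] for p in prompts]


class SketchGenerator(SketchLocalMixin, SketchVLLMMixin):
    """Holds a context and picks a backend per call (assumed behavior)."""

    def __init__(self, ctx):
        self.ctx = ctx

    def generate(self, prompts, n=1):
        # Assumption: the use_vllm flag on the context selects the backend.
        if self.ctx.use_vllm:
            return self._generate_vllm(prompts, n)
        return self._generate_local(prompts, n)


gen = SketchGenerator(SketchContext(use_vllm=False))
grouped = gen.generate(["a", "b"], n=2)  # one inner list per prompt
```

Keeping the backend choice in the context rather than in the caller means the training loop can call one method regardless of where completions come from.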
- class maxent_grpo.training.rollout.generator.GenerationContext(max_prompt_len, max_completion_len, gen_temperature, gen_top_p, use_vllm, vllm, accelerator, model, tokenizer, generation_stats, device, penalty=<factory>, prompt_char_limit=None, *, vllm_mode='server')
Bases: GenerationPenaltyPassthroughMixin, GenerationSamplingConfig
Configuration required to produce completions for each training batch.
- Parameters:
max_prompt_len (int)
max_completion_len (int)
gen_temperature (float)
gen_top_p (float)
use_vllm (bool)
vllm (VLLMClientConfig)
accelerator (Accelerator)
model (PreTrainedModel)
tokenizer (PreTrainedTokenizer)
device (Any)
penalty (GenerationPenaltyConfig)
prompt_char_limit (int | None)
vllm_mode (str)
- accelerator: TypesAccelerator
- model: TypesPreTrainedModel
- tokenizer: TypesPreTrainedTokenizer
- device: Any
- penalty: GenerationPenaltyConfig
- vllm: VLLMClientConfig
- maxent_grpo.training.rollout.generator.safe_generate(*, prompts, url='http://localhost:8000/generate', max_tokens=256, temperature=0.7, top_p=0.9, top_k=None, n=1, stream=False, tokenizer=None, best_of=None, frequency_penalty=None, presence_penalty=None, stop=None, include_stop_str_in_output=False, logit_bias=None, allowed_token_ids=None, blocked_token_ids=None, guided_json=None, guided_regex=None, seed=None, request_id=None, request_id_prefix=None, timeout=None, max_retries=None, backoff=None, backoff_multiplier=None, return_logprobs=False, service_model=None, metadata=None, client_tag=None)
Robust POST to /generate with retry + schema-agnostic decoding.
- Parameters:
prompts (list[str]) – Input prompts (batch) to generate from.
url (str) – URL of the /generate route.
max_tokens (int) – Maximum tokens to generate per completion.
temperature (float) – Sampling temperature.
top_p (float) – Nucleus sampling p.
top_k (int | None) – Optional top-k cutoff applied during sampling.
n (int) – Number of completions per prompt.
stream (bool) – Whether to use chunked streaming responses.
tokenizer (Any) – Optional tokenizer to decode token ID arrays.
best_of (int | None) – vLLM best_of parameter to sample more than n candidates.
frequency_penalty (float | None) – Frequency penalty forwarded to vLLM sampling.
presence_penalty (float | None) – Presence penalty forwarded to vLLM sampling.
stop (list[str] | None) – Stop sequences used to truncate completions.
include_stop_str_in_output (bool) – Whether matched stop strings should remain in the returned text.
logit_bias (dict[str, float] | None) – Token-level logit bias forwarded to vLLM.
allowed_token_ids (list[int] | None) – Optional hard allowlist of token IDs forwarded to vLLM.
blocked_token_ids (list[int] | None) – Optional hard denylist of token IDs forwarded to vLLM.
guided_json (str | None) – Optional JSON schema string for constrained decoding.
guided_regex (str | None) – Optional regex constraint for decoding.
seed (int | None) – Optional deterministic sampling seed forwarded to vLLM.
request_id (str | None) – Explicit request identifier to forward to vLLM.
request_id_prefix (str | None) – Prefix used when auto-generating request_id.
max_retries (int) – Number of attempts before surfacing the error.
backoff (float) – Base backoff in seconds; exponential across attempts.
timeout (float) – Per‑request timeout in seconds.
return_logprobs (bool) – Whether to request log-prob metadata from vLLM.
service_model (str | None) – Optional identifier for the served model (used in error payloads).
metadata (dict[str, Any] | None) – Optional structured context (dataset/model) copied into error payloads.
client_tag (str | None) – Optional client/rank identifier forwarded via headers/payload.
backoff_multiplier (float | None)
- Returns:
Tuple of grouped texts, optional log-prob metadata, and latency in milliseconds.
- Return type:
tuple[list[list[str]], Optional[list[list[VLLMLogprobResult]]], float]
- Raises:
GenerationServiceError – When the server responds with repeated errors after exhausting retries.
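The max_retries, backoff, and backoff_multiplier parameters above describe an exponential retry schedule. A minimal sketch of such a schedule follows, assuming the delay before retry k is backoff * backoff_multiplier**k. The exact formula, the defaults, and any jitter used by safe_generate are not stated in this reference, and retry_delays and call_with_retries are hypothetical helpers, not part of the library.

```python
import time


def retry_delays(max_retries, backoff, backoff_multiplier=2.0):
    """Delays slept between attempts (assumed: backoff * multiplier**k)."""
    # max_retries attempts leave room for at most max_retries - 1 sleeps.
    return [backoff * backoff_multiplier ** k
            for k in range(max(max_retries - 1, 0))]


def call_with_retries(fn, max_retries=3, backoff=0.5,
                      backoff_multiplier=2.0, sleep=time.sleep):
    """Call fn() until it succeeds or attempts run out, then re-raise."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    delays = retry_delays(max_retries, backoff, backoff_multiplier)
    last_error = None
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as exc:  # a real client would catch narrower errors
            last_error = exc
            # Sleep only if another attempt follows.
            if attempt < len(delays):
                sleep(delays[attempt])
    raise last_error
```

Passing sleep as a parameter lets tests record the delays instead of waiting for them. A caller of the real function would instead catch GenerationServiceError once the retries are exhausted.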