maxent_grpo.training.rollout.generator

Public CompletionGenerator that wires local and vLLM helpers together.

Classes

CompletionGenerator(ctx)

Stateful helper that handles both local HF and vLLM completions.

class maxent_grpo.training.rollout.generator.CompletionGenerator(ctx)

Bases: LocalGenerationMixin, VLLMGenerationMixin

Stateful helper that handles both local HF and vLLM completions.

Parameters:

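ctx (GenerationContext) – Generation settings plus the runtime handles (accelerator, model, tokenizer, vLLM client configuration) used to produce completions.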

describe()

Expose the underlying generation configuration for logging.

Return type:

Dict[str, Any]

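generate(prompts, num_samples, per_prompt_counts=None)

Produce completions, preferring vLLM when configured.

Parameters:
  • prompts – Prompts to complete.

  • num_samples – Number of completions to request per prompt.

  • per_prompt_counts – Optional per-prompt completion counts that override num_samples for individual prompts.

Returns:

Completions grouped by prompt, plus matching vLLM log-prob metadata (None, or None entries, when metadata is unavailable).

Return type: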

Tuple[List[List[str]], List[List[VLLMLogprobResult | None]] | None]
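Callers of generate() usually need to pair each completion with its prompt and its log-prob entry. The sketch below uses a hypothetical helper, flatten_generation, and mocked data instead of a real CompletionGenerator (which needs a model, tokenizer and accelerator). It assumes each inner metadata list lines up one-to-one with the completions for the same prompt.

```python
from typing import Any, List, Optional, Sequence, Tuple


def flatten_generation(
    prompts: Sequence[str],
    grouped_texts: List[List[str]],
    grouped_meta: Optional[List[List[Any]]] = None,
) -> List[Tuple[str, str, Any]]:
    """Pair each completion with its prompt and its log-prob entry (or None)."""
    if len(grouped_texts) != len(prompts):
        raise ValueError("expected one completion group per prompt")
    rows: List[Tuple[str, str, Any]] = []
    for idx, (prompt, texts) in enumerate(zip(prompts, grouped_texts)):
        # Without metadata, pair every completion with None.
        metas = grouped_meta[idx] if grouped_meta is not None else [None] * len(texts)
        if len(metas) != len(texts):
            raise ValueError(f"prompt {idx}: {len(texts)} texts but {len(metas)} metadata entries")
        for text, meta in zip(texts, metas):
            rows.append((prompt, text, meta))
    return rows


# Mocked return value standing in for generate(prompts, num_samples=2).
prompts = ["2+2=", "Capital of France?"]
grouped_texts = [["4", "four"], ["Paris", "paris"]]
rows = flatten_generation(prompts, grouped_texts, None)
print(len(rows))  # 4
print(rows[0])    # ('2+2=', '4', None)
```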

class maxent_grpo.training.rollout.generator.GenerationContext(max_prompt_len, max_completion_len, gen_temperature, gen_top_p, use_vllm, vllm, accelerator, model, tokenizer, generation_stats, device, penalty=<factory>, prompt_char_limit=None, *, vllm_mode='server')

Bases: GenerationPenaltyPassthroughMixin, GenerationSamplingConfig

Configuration required to produce completions for each training batch.

Parameters:

max_prompt_len: int
max_completion_len: int
gen_temperature: float
gen_top_p: float
use_vllm: bool
vllm: VLLMClientConfig
accelerator: TypesAccelerator
model: TypesPreTrainedModel
tokenizer: TypesPreTrainedTokenizer
generation_stats: Dict[str, int]
device: Any
penalty: GenerationPenaltyConfig (defaults to a fresh object from a dataclass default_factory)
prompt_char_limit: int | None = None
vllm_mode = 'server' (keyword-only)

as_dict()

Return a lightweight representation useful for logging/debugging.

Return type:

Dict[str, Any]
maxent_grpo.training.rollout.generator.safe_generate(*, prompts, url='http://localhost:8000/generate', max_tokens=256, temperature=0.7, top_p=0.9, top_k=None, n=1, stream=False, tokenizer=None, best_of=None, frequency_penalty=None, presence_penalty=None, stop=None, include_stop_str_in_output=False, logit_bias=None, allowed_token_ids=None, blocked_token_ids=None, guided_json=None, guided_regex=None, seed=None, request_id=None, request_id_prefix=None, timeout=None, max_retries=None, backoff=None, backoff_multiplier=None, return_logprobs=False, service_model=None, metadata=None, client_tag=None)

Robust POST to the /generate route, with retries and schema-agnostic response decoding.

Parameters:
  • prompts (list[str]) – Input prompts (batch) to generate from.

  • url (str) – Full URL of the /generate route.

  • max_tokens (int) – Maximum tokens to generate per completion.

  • temperature (float) – Sampling temperature.

  • top_p (float) – Nucleus sampling p.

  • top_k (int | None) – Optional top-k cutoff applied during sampling.

  • n (int) – Number of completions per prompt.

  • stream (bool) – Whether to use chunked streaming responses.

  • tokenizer (Any) – Optional tokenizer used to decode token-ID arrays in the response into text.

  • best_of (int | None) – vLLM best_of parameter: number of candidates sampled per prompt, from which the top n are returned.

  • frequency_penalty (float | None) – Frequency penalty forwarded to vLLM sampling.

  • presence_penalty (float | None) – Presence penalty forwarded to vLLM sampling.

  • stop (list[str] | None) – Stop sequences used to truncate completions.

  • include_stop_str_in_output (bool) – Whether matched stop strings should remain in the returned text.

  • logit_bias (dict[str, float] | None) – Token-level logit bias forwarded to vLLM.

  • allowed_token_ids (list[int] | None) – Optional hard allowlist of token IDs forwarded to vLLM.

  • blocked_token_ids (list[int] | None) – Optional hard denylist of token IDs forwarded to vLLM.

  • guided_json (str | None) – Optional JSON schema string for constrained decoding.

  • guided_regex (str | None) – Optional regex constraint for decoding.

  • seed (int | None) – Optional deterministic sampling seed forwarded to vLLM.

  • request_id (str | None) – Explicit request identifier to forward to vLLM.

  • request_id_prefix (str | None) – Prefix used when auto-generating request_id.

  • timeout (float | None) – Per-request timeout in seconds.

  • max_retries (int | None) – Number of attempts before surfacing the error.

  • backoff (float | None) – Base backoff in seconds; the delay grows exponentially across attempts.

  • backoff_multiplier (float | None) – Growth factor applied to the backoff after each failed attempt.

  • return_logprobs (bool) – Whether to request log-prob metadata from vLLM.

  • service_model (str | None) – Optional identifier for the served model (used in error payloads).

  • metadata (dict[str, Any] | None) – Optional structured context (dataset/model) copied into error payloads.

  • client_tag (str | None) – Optional client/rank identifier forwarded via headers/payload.
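The retry parameters can be pictured as a plain exponential-backoff loop. This is a minimal sketch, not safe_generate's actual implementation: it assumes max_retries counts total attempts, that the wait before retry k+1 is backoff * backoff_multiplier**k, and that exhaustion raises an error; call_with_retries, backoff_delays and ServiceErrorSketch are hypothetical names, and ServiceErrorSketch only stands in for GenerationServiceError. The sleep hook is injectable so the schedule can be checked without waiting.

```python
import time
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


class ServiceErrorSketch(RuntimeError):
    """Hypothetical stand-in for GenerationServiceError."""


def backoff_delays(backoff: float, multiplier: float, max_retries: int) -> List[float]:
    # Wait before retry k+1 is backoff * multiplier**k; nothing follows the final attempt.
    return [backoff * multiplier**k for k in range(max_retries - 1)]


def call_with_retries(
    fn: Callable[[], T],
    max_retries: int = 3,
    backoff: float = 1.0,
    backoff_multiplier: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn up to max_retries times, sleeping with exponential backoff between failures."""
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")
    delays = backoff_delays(backoff, backoff_multiplier, max_retries)
    last_exc: Optional[Exception] = None
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as exc:  # e.g. a connection error or an HTTP 5xx
            last_exc = exc
            if attempt < max_retries - 1:
                sleep(delays[attempt])
    raise ServiceErrorSketch(f"gave up after {max_retries} attempts") from last_exc


calls = {"n": 0}


def flaky() -> str:
    # Fails twice, then succeeds.
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("server busy")
    return "ok"


slept: List[float] = []
print(call_with_retries(flaky, max_retries=3, backoff=0.5, sleep=slept.append))  # ok
print(slept)  # [0.5, 1.0]
```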

Returns:

Tuple of (completions grouped by prompt, optional log-prob metadata, request latency in milliseconds).

Return type:

tuple[list[list[str]], Optional[list[list[VLLMLogprobResult]]], float]
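The list[list[str]] shape holds n completions per prompt. If a server returned completions as one flat, prompt-major list, they could be regrouped as below; that flat ordering is an assumption for illustration, and group_completions is a hypothetical helper rather than part of this module.

```python
from typing import List, Sequence


def group_completions(flat: Sequence[str], num_prompts: int, n: int) -> List[List[str]]:
    """Split a flat, prompt-major list of completions into one list of n per prompt."""
    if len(flat) != num_prompts * n:
        raise ValueError(f"expected {num_prompts * n} completions, got {len(flat)}")
    return [list(flat[i * n:(i + 1) * n]) for i in range(num_prompts)]


print(group_completions(["a1", "a2", "b1", "b2", "c1", "c2"], num_prompts=3, n=2))
# [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
```

Validating the total length up front turns a silent misalignment between prompts and completions into an explicit error.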

Raises:

GenerationServiceError – When the server still returns errors after all retry attempts are exhausted.
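"Schema-agnostic decoding" means accepting more than one response layout. The sketch below shows the idea under assumed layouts: the keys "text", "choices", "outputs" and "token_ids" are illustrative guesses, not the documented /generate schema, and decode_outputs is a hypothetical helper. decode_ids plays the role of the tokenizer parameter; with a Hugging Face tokenizer one could pass its decode method.

```python
from typing import Any, Callable, Dict, List, Optional


def decode_outputs(
    payload: Dict[str, Any],
    decode_ids: Optional[Callable[[List[int]], str]] = None,
) -> List[str]:
    """Extract completion strings from one of several assumed response shapes."""
    if "text" in payload:  # e.g. {"text": ["...", "..."]}
        items = payload["text"]
    elif "choices" in payload:  # e.g. {"choices": [{"text": "..."}]}
        items = [c.get("text", c.get("token_ids")) for c in payload["choices"]]
    elif "outputs" in payload:  # e.g. {"outputs": [[12, 7, 99]]} (token IDs)
        items = payload["outputs"]
    else:
        raise ValueError(f"unrecognised response keys: {sorted(payload)}")
    if isinstance(items, str):  # a single completion not wrapped in a list
        items = [items]
    texts: List[str] = []
    for item in items:
        if isinstance(item, str):
            texts.append(item)
        elif isinstance(item, list) and all(isinstance(t, int) for t in item):
            if decode_ids is None:
                raise ValueError("response contains token IDs but no decoder was given")
            texts.append(decode_ids(item))
        else:
            raise TypeError(f"cannot decode completion of type {type(item).__name__}")
    return texts


vocab = {1: "Hel", 2: "lo"}
print(decode_outputs({"choices": [{"text": "4"}, {"text": "four"}]}))  # ['4', 'four']
print(decode_outputs({"outputs": [[1, 2]]}, decode_ids=lambda ids: "".join(vocab[i] for i in ids)))  # ['Hello']
```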