maxent_grpo.training.patches.vllm
Robust helpers for a vLLM /generate server.
This module provides a resilient safe_generate that handles transient
errors with retries/backoff and decodes multiple response schemas across vLLM
versions (OpenAI‑compatible choices, results, batched text, and
legacy completion_ids when a tokenizer is provided). It also supports
streaming responses by collecting chunked JSON lines.
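The schema-agnostic decoding described above can be sketched as follows. This is a minimal illustration, not this module's implementation. The function name normalize_texts is hypothetical, and the four response shapes it recognizes are assumptions chosen for illustration that may not match any specific vLLM version: a results list with per-prompt outputs, a choices list, a flat text list, and completion_ids token arrays. Flat completions are assumed to arrive prompt-major, n per prompt.

```python
from typing import Any, Optional


def normalize_texts(
    payload: dict[str, Any],
    num_prompts: int,
    n: int,
    tokenizer: Optional[Any] = None,
) -> list[list[str]]:
    """Return prompt-major groups of completion texts from an assumed schema."""
    if "results" in payload:
        # Assumed shape: {"results": [{"outputs": [{"text": ...}, ...]}, ...]},
        # already grouped per prompt.
        return [[out["text"] for out in r["outputs"]] for r in payload["results"]]
    if "choices" in payload:
        # Assumed OpenAI-style shape: {"choices": [{"text": ...}, ...]}
        flat = [choice["text"] for choice in payload["choices"]]
    elif "text" in payload:
        # Assumed batched shape: {"text": ["...", "..."]}
        flat = list(payload["text"])
    elif "completion_ids" in payload and tokenizer is not None:
        # Assumed legacy shape: {"completion_ids": [[id, ...], ...]}; the
        # tokenizer is any object with decode(list[int]) -> str.
        flat = [tokenizer.decode(ids) for ids in payload["completion_ids"]]
    else:
        raise ValueError(f"unrecognized response keys: {sorted(payload)}")
    if len(flat) != num_prompts * n:
        raise ValueError(f"expected {num_prompts * n} completions, got {len(flat)}")
    # Flat position k is assumed to belong to prompt k // n.
    return [flat[p * n:(p + 1) * n] for p in range(num_prompts)]
```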
Key functions
- safe_request: Simple GET with retries/backoff.
- safe_generate: POST to /generate with schema-agnostic decoding and optional streaming support.
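Collecting a streamed response can be sketched in the same spirit. The helper collect_stream and the chunk fields prompt_index and text are hypothetical; the sketch assumes the server emits one JSON object per line and simply concatenates the text fragments for each prompt index.

```python
import json
from collections import defaultdict
from typing import Iterable, Union


def collect_stream(lines: Iterable[Union[str, bytes]], num_prompts: int) -> list[str]:
    """Join streamed text fragments per prompt index (assumed chunk schema)."""
    pieces: dict[int, list[str]] = defaultdict(list)
    for raw in lines:
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        line = line.strip()
        if not line:
            continue  # skip blank keep-alive lines between chunks
        chunk = json.loads(line)
        # Assumed chunk shape: {"prompt_index": int, "text": str}
        pieces[int(chunk.get("prompt_index", 0))].append(chunk.get("text", ""))
    # Prompts that never received a chunk yield an empty string.
    return ["".join(pieces[i]) for i in range(num_prompts)]
```

With the requests library, calling Response.iter_lines() on a response fetched with stream=True yields byte lines that could be fed directly to a helper like this.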
License
Copyright 2025 Liv d’Aliberti
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the specific language governing permissions and limitations under the License.
Functions
- Construct optional headers used by TRL's vLLM RPC helpers.
- Normalize various logprob containers into a float list.
- Collect and join streaming response chunks per prompt index.
- Convert schema-specific logprob payloads into …
- Remove prompt groups that do not match …
- Filter per-output metadata for a single prompt entry.
- Traverse a limited portion of …
- Return True when a logprob entry exposes per-token logprobs.
- Best-effort token-count heuristic for varied response schemas.
- Return metadata carrying token IDs without logprob fields.
- Normalize vLLM JSON response into grouped texts (+ optional logprobs).
- Regroup flat one-completion entries into prompt-major groups.
- Return a (logprob_sum, token_count) tuple when available.
- safe_generate: Robust POST to /generate with retry + schema-agnostic decoding.
- safe_request: GET JSON with basic retry/backoff.
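Several of the helpers listed above normalize logprob containers and reduce them to a (logprob_sum, token_count) tuple. Here is a minimal sketch under stated assumptions. The names to_float_list and logprob_summary are hypothetical, and the accepted shapes are illustrative rather than the exact formats this module handles: a plain list of numbers, an OpenAI completions-style object with token_logprobs, and a list of chat-style entries carrying a logprob key.

```python
import math
from typing import Any, Optional


def to_float_list(container: Any) -> list[float]:
    """Flatten a few assumed logprob container shapes into a list of floats."""
    if container is None:
        return []
    if isinstance(container, dict) and "token_logprobs" in container:
        # OpenAI completions-style object: {"token_logprobs": [...], ...}
        container = container["token_logprobs"]
    values: list[float] = []
    for item in container:
        if isinstance(item, dict):
            item = item.get("logprob")  # chat-style entry: {"token": ..., "logprob": x}
        if item is None:
            continue  # e.g. an echoed first prompt token carries no logprob
        values.append(float(item))
    return values


def logprob_summary(container: Any) -> Optional[tuple[float, int]]:
    """Return (logprob_sum, token_count), or None when no per-token values exist."""
    values = to_float_list(container)
    if not values:
        return None
    return math.fsum(values), len(values)
```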
Classes
- GenerationLogprobEntry: alias of VLLMLogprobResult.
- TokenizerLike: Protocol for objects that can decode token IDs to text.
- VLLMLogprobResult: Aggregate (and optionally raw) log-probability info for one completion.
- class maxent_grpo.training.patches.vllm.VLLMLogprobResult(logprob_sum, token_count, token_logprobs=None, raw_output=None)
Bases: object
Aggregate (and optionally raw) log-probability info for one completion.
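To show how the aggregate fields relate to per-token values, here is a hypothetical stand-in that mirrors the documented constructor of VLLMLogprobResult. LogprobResultSketch and its from_token_logprobs, mean_logprob, and perplexity helpers are illustrations only and are not part of the documented class.

```python
import math
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class LogprobResultSketch:
    """Hypothetical stand-in mirroring VLLMLogprobResult's documented fields."""

    logprob_sum: float
    token_count: int
    token_logprobs: Optional[list[float]] = None
    raw_output: Optional[Any] = None

    @classmethod
    def from_token_logprobs(cls, values: list[float], raw_output: Any = None) -> "LogprobResultSketch":
        # Aggregate fields are derived from the per-token values.
        return cls(math.fsum(values), len(values), list(values), raw_output)

    def mean_logprob(self) -> float:
        # Length-normalized log-probability; 0.0 for an empty completion.
        return self.logprob_sum / self.token_count if self.token_count else 0.0

    def perplexity(self) -> float:
        # Per-token perplexity: exp(-mean log-probability).
        return math.exp(-self.mean_logprob())
```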
- maxent_grpo.training.patches.vllm.GenerationLogprobEntry
Alias of VLLMLogprobResult.
- class maxent_grpo.training.patches.vllm.TokenizerLike(*args, **kwargs)
Bases: Protocol
Protocol for objects that can decode token IDs to text.
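Because TokenizerLike is a Protocol, any object with a suitable decode method can be passed without subclassing. The sketch below assumes the protocol's method is decode(token_ids) -> str. DecoderLike, CharTokenizer, and decode_all are hypothetical names, and the real protocol may declare a different signature.

```python
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class DecoderLike(Protocol):
    """Hypothetical protocol: anything with decode(token_ids) -> str qualifies."""

    def decode(self, token_ids: Sequence[int]) -> str: ...


class CharTokenizer:
    """Toy tokenizer mapping token ID i to chr(i); it does not inherit DecoderLike."""

    def decode(self, token_ids: Sequence[int]) -> str:
        return "".join(chr(i) for i in token_ids)


def decode_all(tokenizer: DecoderLike, batches: list[list[int]]) -> list[str]:
    # Structural typing: any object with a compatible decode() is accepted.
    return [tokenizer.decode(ids) for ids in batches]
```

Hugging Face tokenizers expose a decode method taking a sequence of token IDs, so they would likely satisfy a protocol of this shape. A runtime_checkable isinstance check only confirms that the method exists, not that its signature matches.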
- maxent_grpo.training.patches.vllm.safe_request(url, max_retries=3, backoff=1.0, timeout=10.0)
GET JSON with basic retry/backoff.
- Returns:
Parsed JSON payload.
- Raises:
RuntimeError – If the request ultimately fails (HTTP error, or repeated connection errors or timeouts).
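The retry/backoff behaviour of safe_request can be approximated by the following sketch. It is not the module's code. retry_get_json is a hypothetical helper, and fetch stands in for whatever HTTP client call performs the GET; it is assumed to raise the built-in ConnectionError or TimeoutError on transient failures. The doubling delay schedule is also an assumption. HTTP error statuses, which the documented function also reports, are not modelled here.

```python
import time
from typing import Any, Callable, Optional


def retry_get_json(
    fetch: Callable[[str, float], Any],
    url: str,
    max_retries: int = 3,
    backoff: float = 1.0,
    timeout: float = 10.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Any:
    """Call fetch(url, timeout) until it succeeds or max_retries attempts are used."""
    last_error: Optional[Exception] = None
    for attempt in range(max_retries):
        try:
            return fetch(url, timeout)
        except (ConnectionError, TimeoutError) as exc:
            last_error = exc
            if attempt < max_retries - 1:
                # Assumed exponential schedule: backoff, 2*backoff, 4*backoff, ...
                sleep(backoff * (2 ** attempt))
    raise RuntimeError(f"GET {url} failed after {max_retries} attempts") from last_error
```

Injecting sleep lets a test record the delays instead of actually waiting.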
- maxent_grpo.training.patches.vllm.safe_generate(*, prompts, url='http://localhost:8000/generate', max_tokens=256, temperature=0.7, top_p=0.9, top_k=None, n=1, stream=False, tokenizer=None, best_of=None, frequency_penalty=None, presence_penalty=None, stop=None, include_stop_str_in_output=False, logit_bias=None, allowed_token_ids=None, blocked_token_ids=None, guided_json=None, guided_regex=None, seed=None, request_id=None, request_id_prefix=None, timeout=None, max_retries=None, backoff=None, backoff_multiplier=None, return_logprobs=False, service_model=None, metadata=None, client_tag=None)
Robust POST to /generate with retry + schema-agnostic decoding.
- Parameters:
prompts (list[str]) – Input prompts (batch) to generate from.
url (str) – Full URL of the /generate route.
max_tokens (int) – Maximum tokens to generate per completion.
temperature (float) – Sampling temperature.
top_p (float) – Nucleus sampling p.
top_k (int | None) – Optional top-k cutoff applied during sampling.
n (int) – Number of completions per prompt.
stream (bool) – Whether to use chunked streaming responses.
tokenizer (Any) – Optional tokenizer to decode token ID arrays.
best_of (int | None) – vLLM best_of parameter, used to sample more than n candidates.
frequency_penalty (float | None) – Frequency penalty forwarded to vLLM sampling.
presence_penalty (float | None) – Presence penalty forwarded to vLLM sampling.
stop (list[str] | None) – Stop sequences used to truncate completions.
include_stop_str_in_output (bool) – Whether matched stop strings should remain in the returned text.
logit_bias (dict[str, float] | None) – Token-level logit bias forwarded to vLLM.
allowed_token_ids (list[int] | None) – Optional hard allowlist of token IDs forwarded to vLLM.
blocked_token_ids (list[int] | None) – Optional hard denylist of token IDs forwarded to vLLM.
guided_json (str | None) – Optional JSON schema string for constrained decoding.
guided_regex (str | None) – Optional regex constraint for decoding.
seed (int | None) – Optional deterministic sampling seed forwarded to vLLM.
request_id (str | None) – Explicit request identifier to forward to vLLM.
request_id_prefix (str | None) – Prefix used when auto-generating request_id.
max_retries (int) – Number of attempts before surfacing the error.
backoff (float) – Base backoff in seconds; exponential across attempts.
timeout (float) – Per‑request timeout in seconds.
return_logprobs (bool) – Whether to request log-prob metadata from vLLM.
service_model (str | None) – Optional identifier for the served model (used in error payloads).
metadata (dict[str, Any] | None) – Optional structured context (dataset/model) copied into error payloads.
client_tag (str | None) – Optional client/rank identifier forwarded via headers/payload.
backoff_multiplier (float | None)
- Returns:
Tuple of grouped texts, optional log-prob metadata, and latency in milliseconds.
- Return type:
tuple[list[list[str]], Optional[list[list[VLLMLogprobResult]]], float]
- Raises:
GenerationServiceError – When the server responds with repeated errors after exhausting retries.
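The following sketch shows one way the many optional sampling arguments might be turned into a request body. build_generate_payload, the JSON field names, and the request-ID format derived from request_id_prefix are all assumptions for illustration; check the server's request schema before relying on any of them.

```python
import uuid
from typing import Any, Optional


def build_generate_payload(
    prompts: list[str],
    *,
    max_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    n: int = 1,
    request_id: Optional[str] = None,
    request_id_prefix: Optional[str] = None,
    **optional: Any,
) -> dict[str, Any]:
    """Assemble a /generate request body (field names are assumptions)."""
    payload: dict[str, Any] = {
        "prompts": list(prompts),
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "n": n,
    }
    # Forward optional sampling knobs (top_k, seed, stop, ...) only when set.
    payload.update({k: v for k, v in optional.items() if v is not None})
    if request_id is None and request_id_prefix is not None:
        # Hypothetical auto-generation scheme: prefix plus a random hex suffix.
        request_id = f"{request_id_prefix}-{uuid.uuid4().hex}"
    if request_id is not None:
        payload["request_id"] = request_id
    return payload
```

Dropping unset (None) options rather than sending explicit nulls leaves the server free to apply its own defaults.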