maxent_grpo.training.patches.vllm

Robust helpers for a vLLM /generate server.

This module provides a resilient safe_generate function that retries transient errors with backoff and decodes the response schemas used by different vLLM versions (OpenAI-compatible choices, results, batched text, and legacy completion_ids when a tokenizer is provided). It also supports streaming responses by collecting chunked JSON lines.
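The schema-agnostic decoding described above can be sketched as a small dispatcher over response keys. This is an illustration under stated assumptions, not the module's `_parse_nonstream_json`: the helper name `normalize_texts` is hypothetical, and the payload shapes (a flat `choices` list of objects with `text`, a `results` list with one `outputs` list per prompt, a flat `text` list, and `completion_ids` token-ID lists) are assumed rather than taken from a specific vLLM release. Flat lists are assumed to be prompt-major, so they are cut into groups of `n`.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


def _chunk(flat: List[str], n: int) -> List[List[str]]:
    """Split a flat, prompt-major list into groups of n completions."""
    return [flat[i : i + n] for i in range(0, len(flat), n)]


def normalize_texts(
    data: Dict[str, Any],
    n: int = 1,
    decode: Optional[Callable[[Sequence[int]], str]] = None,
) -> List[List[str]]:
    """Hypothetical helper: return one list of completions per prompt.

    The payload shapes handled here are illustrative assumptions, not a
    specification of any vLLM release.
    """
    if "choices" in data:  # assumed OpenAI-style flat list of choice objects
        return _chunk([c["text"] for c in data["choices"]], n)
    if "results" in data:  # assumed: one entry per prompt, each with outputs
        return [[o["text"] for o in r["outputs"]] for r in data["results"]]
    if "text" in data:  # assumed batched text: a flat list (or a single string)
        texts = data["text"]
        return _chunk([texts] if isinstance(texts, str) else list(texts), n)
    if "completion_ids" in data:  # legacy token IDs need a decoder
        if decode is None:
            raise ValueError("completion_ids present but no decoder was given")
        return _chunk([decode(ids) for ids in data["completion_ids"]], n)
    raise ValueError(f"unrecognized response keys: {sorted(data)}")


payload = {"choices": [{"text": "A"}, {"text": "B"}, {"text": "C"}, {"text": "D"}]}
print(normalize_texts(payload, n=2))  # [['A', 'B'], ['C', 'D']]
```

The order of the checks matters when a payload carries several of these keys; the sketch simply takes the first match.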

Key functions

  • safe_request: Simple GET with retries/backoff.

  • safe_generate: POST to /generate with schema‑agnostic decoding and optional streaming support.

License

Copyright 2025 Liv d’Aliberti

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the specific language governing permissions and limitations under the License.

Functions

_build_vllm_headers(prompts[, client_tag])

Construct optional headers used by TRL's vLLM RPC helpers.

_clean_logprob_seq(candidate)

Normalize various logprob containers into a float list.

_collect_stream_texts(response, num_prompts)

Collect and join streaming response chunks per prompt index.

_extract_logprob_info(entry)

Convert schema-specific logprob payloads into VLLMLogprobResult.

_filter_response_for_client_tag(data, client_tag)

Remove prompt groups that do not match client_tag when provided.

_filter_result_outputs_for_tag(entry, client_tag)

Filter per-output metadata for a single prompt entry.

_find_client_tag(candidate[, depth])

Traverse a limited portion of candidate to locate client_tag.

_has_token_logprobs(entry)

Return True when a logprob entry exposes per-token logprobs.

_infer_token_count(entry, seq)

Best-effort token-count heuristic for varied response schemas.

_log_vllm_info(template, *args)

_metadata_from_token_ids(token_ids)

Return metadata carrying token IDs without logprob fields.

_mirror_vllm_log(message)

_parse_nonstream_json(data[, tokenizer, ...])

Normalize vLLM JSON response into grouped texts (+ optional logprobs).

_regroup_flat_single_completion_groups(...)

Regroup flat one-completion entries into prompt-major groups.

_summarize_logprob_entry(entry)

Return a (logprob_sum, token_count) tuple when available.

safe_generate(*, prompts[, url, max_tokens, ...])

Robust POST to /generate with retry + schema-agnostic decoding.

safe_request(url[, max_retries, backoff, ...])

GET JSON with basic retry/backoff.
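`_collect_stream_texts` joins streaming chunks per prompt index. A minimal sketch of that idea follows, assuming newline-delimited JSON where each chunk carries an `index` and a `text` fragment. Those field names and the helper `collect_stream_texts` (which takes an iterable of lines instead of a response object) are hypothetical.

```python
import json
from typing import Iterable, List, Union


def collect_stream_texts(lines: Iterable[Union[str, bytes]], num_prompts: int) -> List[str]:
    """Hypothetical sketch: concatenate streamed text fragments per prompt index.

    Assumes each non-empty line is a JSON object such as {"index": 0, "text": "..."}.
    """
    parts: List[List[str]] = [[] for _ in range(num_prompts)]
    for raw in lines:
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        line = line.strip()
        if not line:
            continue  # skip blank keep-alive lines
        chunk = json.loads(line)
        idx = int(chunk.get("index", 0))
        if 0 <= idx < num_prompts:  # ignore chunks for unknown prompts
            parts[idx].append(chunk.get("text", ""))
    return ["".join(p) for p in parts]


stream = ['{"index": 0, "text": "Hel"}', '{"index": 1, "text": "Bon"}',
          "", '{"index": 0, "text": "lo"}', '{"index": 1, "text": "jour"}']
print(collect_stream_texts(stream, num_prompts=2))  # ['Hello', 'Bonjour']
```

With the `requests` library, `Response.iter_lines()` yields lines as bytes by default, which is why the sketch accepts both `str` and `bytes`.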

Classes

GenerationLogprobEntry

alias of VLLMLogprobResult

TokenizerLike(*args, **kwargs)

Protocol for objects that can decode token IDs to text.

VLLMLogprobResult(logprob_sum, token_count)

Aggregate (and optionally raw) log-probability info for one completion.

class maxent_grpo.training.patches.vllm.VLLMLogprobResult(logprob_sum, token_count, token_logprobs=None, raw_output=None)

Bases: object

Aggregate (and optionally raw) log-probability info for one completion.

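Parameters:
  • logprob_sum (float | None) – Sum of the completion's per-token log-probabilities, if available.

  • token_count (int | None) – Number of tokens covered by logprob_sum, if available.

  • token_logprobs (List[float] | None) – Optional per-token log-probabilities.

  • raw_output (Dict[str, Any] | None) – Optional raw output payload returned by the server.

to_trl_payload()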

Return a dict compatible with TRL’s refinement metadata.

Returns:

Dictionary describing the logprob sum, token count, and raw output.

Return type:

dict[str, Any]
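The aggregate fields are presumably derived from the per-token values: logprob_sum as the sum of token_logprobs and token_count as their number. The standalone dataclass below mirrors the documented fields to show that relationship. `LogprobResultSketch`, its `from_token_logprobs` constructor, and the `mean_logprob` property are hypothetical and not part of the module.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence


@dataclass
class LogprobResultSketch:
    """Hypothetical stand-in mirroring VLLMLogprobResult's documented fields."""

    logprob_sum: Optional[float]
    token_count: Optional[int]
    token_logprobs: Optional[List[float]] = None
    raw_output: Optional[Dict[str, Any]] = None

    @classmethod
    def from_token_logprobs(cls, seq: Sequence[float]) -> "LogprobResultSketch":
        # Hypothetical constructor: derive both aggregates from per-token values.
        values = [float(x) for x in seq]
        return cls(logprob_sum=sum(values), token_count=len(values), token_logprobs=values)

    @property
    def mean_logprob(self) -> Optional[float]:
        # Length-normalized score; None when the aggregates are unavailable.
        if self.logprob_sum is None or not self.token_count:
            return None
        return self.logprob_sum / self.token_count


r = LogprobResultSketch.from_token_logprobs([-0.5, -1.0, -1.5])
print(r.logprob_sum, r.token_count, r.mean_logprob)  # -3.0 3 -1.0
```

A length-normalized score such as mean_logprob is useful when comparing completions of different lengths, since raw sums of negative log-probabilities penalize longer outputs.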

maxent_grpo.training.patches.vllm.GenerationLogprobEntry

alias of VLLMLogprobResult

class maxent_grpo.training.patches.vllm.TokenizerLike(*args, **kwargs)

Bases: Protocol

Protocol for objects that can decode token IDs to text.

decode(token_ids, **kwargs)

Return the decoded string for token_ids.

Parameters:
  • token_ids – Token IDs to decode.

  • **kwargs – Implementation-specific decoding options.

Return type:

str
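Because TokenizerLike is a `typing.Protocol`, any object with a compatible decode method satisfies it structurally, without subclassing. The sketch below restates the protocol under a hypothetical name and pairs it with a toy byte-level tokenizer (also hypothetical) that treats each token ID as one UTF-8 byte.

```python
from typing import Any, List, Protocol, Sequence


class TokenizerLikeSketch(Protocol):
    """Hypothetical restatement: anything with a compatible decode() qualifies."""

    def decode(self, token_ids: Sequence[int], **kwargs: Any) -> str: ...


class ByteTokenizer:
    """Hypothetical toy tokenizer: each token ID is a single UTF-8 byte."""

    def decode(self, token_ids: Sequence[int], **kwargs: Any) -> str:
        return bytes(token_ids).decode("utf-8", errors="replace")


def decode_all(tok: TokenizerLikeSketch, batches: Sequence[Sequence[int]]) -> List[str]:
    # A type checker accepts ByteTokenizer here even though it never subclasses the protocol.
    return [tok.decode(ids) for ids in batches]


print(decode_all(ByteTokenizer(), [[104, 105], [111, 107]]))  # ['hi', 'ok']
```

A Hugging Face tokenizer's `decode(token_ids, **kwargs)` has the same shape, so it would also satisfy the protocol.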

maxent_grpo.training.patches.vllm.safe_request(url, max_retries=3, backoff=1.0, timeout=10.0)

GET JSON with basic retry/backoff.

Parameters:
  • url (str) – Endpoint to query.

  • max_retries (int) – Number of attempts before surfacing the error.

  • backoff (float) – Base backoff in seconds; exponential across attempts.

  • timeout (float) – Per‑request timeout in seconds.

Returns:

Parsed JSON payload.

Return type:

dict

Raises:

RuntimeError – If the request ultimately fails (HTTP error or repeated connection/timeouts).
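The retry/backoff policy can be illustrated without any HTTP. This sketch assumes the delay before retry k (counting from 0) is `backoff * 2**k`, one common reading of "exponential across attempts"; the exact schedule in safe_request may differ. `call_with_retries` is a hypothetical helper that retries only on `ConnectionError`/`TimeoutError` and raises `RuntimeError` once attempts are exhausted, mirroring the documented failure mode. Its `sleep` parameter is injectable so the schedule can be inspected.

```python
import time
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    max_retries: int = 3,
    backoff: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Hypothetical helper: call fn up to max_retries times with exponential backoff."""
    last_exc: Optional[BaseException] = None
    for attempt in range(max_retries):
        try:
            return fn()
        except (ConnectionError, TimeoutError) as exc:  # treat only these as transient
            last_exc = exc
            if attempt < max_retries - 1:
                sleep(backoff * (2 ** attempt))  # assumed schedule: backoff * 2**k
    raise RuntimeError(f"request failed after {max_retries} attempts") from last_exc


delays: List[float] = []
attempts = {"count": 0}


def flaky() -> dict:
    # Fails twice with a transient error, then succeeds.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("transient")
    return {"status": "ok"}


result = call_with_retries(flaky, max_retries=3, backoff=0.5, sleep=delays.append)
print(result, delays)  # {'status': 'ok'} [0.5, 1.0]
```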

maxent_grpo.training.patches.vllm.safe_generate(*, prompts, url='http://localhost:8000/generate', max_tokens=256, temperature=0.7, top_p=0.9, top_k=None, n=1, stream=False, tokenizer=None, best_of=None, frequency_penalty=None, presence_penalty=None, stop=None, include_stop_str_in_output=False, logit_bias=None, allowed_token_ids=None, blocked_token_ids=None, guided_json=None, guided_regex=None, seed=None, request_id=None, request_id_prefix=None, timeout=None, max_retries=None, backoff=None, backoff_multiplier=None, return_logprobs=False, service_model=None, metadata=None, client_tag=None)

Robust POST to /generate with retry + schema-agnostic decoding.

Parameters:
  • prompts (list[str]) – Input prompts (batch) to generate from.

  • url (str) – Full URL of the /generate endpoint, including the /generate path.

  • max_tokens (int) – Maximum tokens to generate per completion.

  • temperature (float) – Sampling temperature.

  • top_p (float) – Nucleus sampling p.

  • top_k (int | None) – Optional top-k cutoff applied during sampling.

  • n (int) – Number of completions per prompt.

  • stream (bool) – Whether to use chunked streaming responses.

  • tokenizer (Any) – Optional tokenizer to decode token ID arrays.

  • best_of (int | None) – vLLM best_of parameter: number of candidates sampled per prompt, from which the top n are returned.

  • frequency_penalty (float | None) – Frequency penalty forwarded to vLLM sampling.

  • presence_penalty (float | None) – Presence penalty forwarded to vLLM sampling.

  • stop (list[str] | None) – Stop sequences used to truncate completions.

  • include_stop_str_in_output (bool) – Whether matched stop strings should remain in the returned text.

  • logit_bias (dict[str, float] | None) – Token-level logit bias forwarded to vLLM.

  • allowed_token_ids (list[int] | None) – Optional hard allowlist of token IDs forwarded to vLLM.

  • blocked_token_ids (list[int] | None) – Optional hard denylist of token IDs forwarded to vLLM.

  • guided_json (str | None) – Optional JSON schema string for constrained decoding.

  • guided_regex (str | None) – Optional regex constraint for decoding.

  • seed (int | None) – Optional deterministic sampling seed forwarded to vLLM.

  • request_id (str | None) – Explicit request identifier to forward to vLLM.

  • request_id_prefix (str | None) – Prefix used when auto-generating request_id.

  • max_retries (int | None) – Number of attempts before surfacing the error.

  • backoff (float | None) – Base backoff in seconds; the delay grows exponentially across attempts.

  • timeout (float | None) – Per‑request timeout in seconds.

  • return_logprobs (bool) – Whether to request log-prob metadata from vLLM.

  • service_model (str | None) – Optional identifier for the served model (used in error payloads).

  • metadata (dict[str, Any] | None) – Optional structured context (dataset/model) copied into error payloads.

  • client_tag (str | None) – Optional client/rank identifier forwarded via headers/payload.

  • backoff_multiplier (float | None) – Optional growth factor applied to the backoff delay between successive attempts.

Returns:

Tuple of grouped texts, optional log-prob metadata, and latency in milliseconds.

Return type:

tuple[list[list[str]], Optional[list[list[VLLMLogprobResult]]], float]

Raises:

GenerationServiceError – When the server responds with repeated errors after exhausting retries.
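The returned texts and log-prob metadata share a prompt-major `list[list[...]]` nesting, so they can presumably be zipped group by group. As an illustration, the sketch below picks, for each prompt, the completion with the highest mean per-token log-probability. It does not contact a server: the result tuple is built by hand, `Result` is a minimal stand-in for VLLMLogprobResult, and `pick_best` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Result:
    """Hypothetical minimal stand-in for VLLMLogprobResult (aggregate fields only)."""

    logprob_sum: Optional[float]
    token_count: Optional[int]


def _mean_logprob(lp: Result) -> float:
    # Missing metadata ranks last.
    if lp.logprob_sum is None or not lp.token_count:
        return float("-inf")
    return lp.logprob_sum / lp.token_count


def pick_best(output: Tuple[List[List[str]], Optional[List[List[Result]]], float]) -> List[str]:
    """Hypothetical helper: one completion per prompt, preferring the highest mean logprob."""
    texts, logprobs, _latency_ms = output
    if logprobs is None:  # e.g. return_logprobs=False: fall back to the first sample
        return [group[0] for group in texts]
    best: List[str] = []
    for group_texts, group_lps in zip(texts, logprobs):
        scores = [_mean_logprob(lp) for lp in group_lps]
        best.append(group_texts[scores.index(max(scores))])
    return best


fake = (
    [["4", "five"], ["Paris", "Lyon"]],
    [[Result(-0.2, 1), Result(-3.0, 2)], [Result(-0.3, 3), Result(-2.0, 2)]],
    42.0,
)
print(pick_best(fake))  # ['4', 'Paris']
```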