Source code for maxent_grpo.training.generation.vllm_requests

"""Request/retry helpers separated from vLLM weight sync and scatter logic."""

from __future__ import annotations
# pylint: disable=broad-exception-caught

import hashlib
import logging
import os
import sys
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from urllib.parse import urlparse

from maxent_grpo.training.generation.errors import (
    GenerationServiceError,
    ServiceErrorPayload,
)
from maxent_grpo.training.generation.vocab_guard import (
    merge_invalid_token_block_logit_bias,
    resolve_blocked_token_ids,
    resolve_allowed_token_ids,
)
from maxent_grpo.training.patches.vllm import VLLMLogprobResult, safe_generate
from maxent_grpo.training.runtime.prompts import _truncate_prompt
from maxent_grpo.training.runtime.logging import _wandb_error_types
from .vllm_state import _VLLMGenerationState

# Exception types caught (and downgraded to a warning) when pushing retry
# metrics to W&B in ``_log_wandb_metrics``. ``_wandb_error_types()`` presumably
# returns W&B-specific error classes — confirm; ``OSError`` is appended so
# I/O-level failures during logging are tolerated as well.
_WANDB_LOG_EXCEPTIONS = _wandb_error_types() + (OSError,)

# Static fallback prompt-length cap used by ``_resolve_default_limit``. It also
# acts as a floor there (``max(baseline, default)``); a non-positive value
# disables the static fallback entirely.
_DEFAULT_PROMPT_CHAR_LIMIT = 2048

LOG = logging.getLogger(__name__)
# Environment variable read by ``_client_tag_fail_fast_enabled`` when no config
# flag parses; fail-fast is enabled unless it is set to 0/false/no/off.
_CLIENT_TAG_FAIL_FAST_ENV = "MAXENT_VLLM_CLIENT_TAG_FAIL_FAST"


def _coerce_bool(value: Any) -> Optional[bool]:
    if value is None:
        return None
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return bool(value)
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"1", "true", "yes", "on"}:
            return True
        if normalized in {"0", "false", "no", "off"}:
            return False
    return None


def _client_tag_fail_fast_enabled(ctx: Any | None = None) -> bool:
    """Decide whether a client-tag mismatch should abort vLLM retries.

    Lookup order: ``ctx.vllm_client_tag_fail_fast``, then
    ``ctx.training_args.vllm_client_tag_fail_fast``. If neither parses as a
    boolean, the ``MAXENT_VLLM_CLIENT_TAG_FAIL_FAST`` environment variable is
    used. It defaults to ``"1"`` and counts as enabled unless set to
    ``0``/``false``/``no``/``off``.

    :param ctx: Optional generation context carrying the config flag.
    :returns: ``True`` when fail-fast is enabled.
    """
    configured = None
    if ctx is not None:
        configured = getattr(ctx, "vllm_client_tag_fail_fast", None)
        if configured is None:
            configured = getattr(
                getattr(ctx, "training_args", None),
                "vllm_client_tag_fail_fast",
                None,
            )
    from_config = _coerce_bool(configured)
    if from_config is not None:
        return from_config
    env_value = str(os.environ.get(_CLIENT_TAG_FAIL_FAST_ENV, "1")).strip().lower()
    return env_value not in ("0", "false", "no", "off")


def _is_client_tag_error(err: BaseException) -> bool:
    """Return ``True`` when ``err`` looks like a vLLM client-tag rejection.

    For :class:`GenerationServiceError` the structured payload's
    ``exception_message`` is inspected. Otherwise, or when that message is
    empty, ``str(err)`` is used. The check is a case-insensitive search for
    ``client_tag`` or ``client tag``.
    """
    text = ""
    if isinstance(err, GenerationServiceError):
        text = getattr(err.payload, "exception_message", "") or ""
    text = (text or str(err)).lower()
    return any(marker in text for marker in ("client_tag", "client tag"))


def _record_logprob_status(ctx: Any, has_payload: bool) -> None:
    stats = getattr(ctx, "generation_stats", None)
    if not isinstance(stats, dict):
        return
    stats.setdefault("vllm_logprobs_missing_rounds", 0)
    stats.setdefault("vllm_logprobs_missing_consecutive", 0)
    stats.setdefault("vllm_logprobs_present_rounds", 0)
    if has_payload:
        stats["vllm_logprobs_present_rounds"] += 1
        stats["vllm_logprobs_missing_consecutive"] = 0
    else:
        stats["vllm_logprobs_missing_rounds"] += 1
        stats["vllm_logprobs_missing_consecutive"] += 1


def _hash_prompts(prompts: List[str]) -> str:
    """Return a stable identifier for the pending prompt batch."""

    if not prompts:
        return "0" * 12
    try:
        joined = "\u241e".join(prompts)
    except TypeError:
        joined = "\u241e".join(str(prompt) for prompt in prompts)
    digest = hashlib.sha256(joined.encode("utf-8", errors="ignore")).hexdigest()
    return digest[:12]


def _resolve_served_model_id(ctx: Any) -> Optional[str]:
    """Best-effort resolution of the external model identifier."""

    direct_keys = (
        "vllm_model_id",
        "served_model_id",
        "model_name",
        "model_id",
        "hub_model_id",
    )
    for key in direct_keys:
        value = getattr(ctx, key, None)
        if isinstance(value, str) and value:
            return value
    model_obj = getattr(ctx, "model", None)
    if model_obj is not None:
        name = getattr(model_obj, "name_or_path", None)
        if isinstance(name, str) and name:
            return name
        cfg = getattr(model_obj, "config", None)
        cfg_name = getattr(cfg, "name_or_path", None) or getattr(
            cfg, "_name_or_path", None
        )
        if isinstance(cfg_name, str) and cfg_name:
            return cfg_name
    training_args = getattr(ctx, "training_args", None)
    hub_id = getattr(training_args, "hub_model_id", None)
    if isinstance(hub_id, str) and hub_id:
        return hub_id
    return None


def _resolve_dataset_label(ctx: Any) -> Optional[str]:
    """Return the dataset label stored on the context or stats."""

    stats = getattr(ctx, "generation_stats", None)
    if isinstance(stats, dict):
        label = stats.get("dataset_name")
        if isinstance(label, str) and label:
            return label
    label = getattr(ctx, "dataset_name", None)
    if isinstance(label, str) and label:
        return label
    training_args = getattr(ctx, "training_args", None)
    label = getattr(training_args, "dataset_name", None)
    if isinstance(label, str) and label:
        return label
    return None


def _resolve_client_tag(ctx: Any) -> Optional[str]:
    """Return a stable client tag for this trainer rank if available."""

    explicit = getattr(ctx, "vllm_client_tag", None)
    if isinstance(explicit, str) and explicit.strip():
        return explicit.strip()

    env_tag = os.environ.get("VLLM_CLIENT_TAG")
    if isinstance(env_tag, str) and env_tag.strip():
        return env_tag.strip()

    accelerator = getattr(ctx, "accelerator", None)
    rank: Optional[int] = None
    world: Optional[int] = None
    if accelerator is not None:
        rank = getattr(accelerator, "process_index", None)
        world = getattr(accelerator, "num_processes", None)
    if rank is None:
        for key in ("RANK", "LOCAL_RANK", "SLURM_PROCID"):
            raw = os.environ.get(key)
            if raw is not None:
                try:
                    rank = int(raw)
                    break
                except (TypeError, ValueError):
                    continue
    if rank is None:
        return None
    if world is None:
        for key in ("WORLD_SIZE", "SLURM_NTASKS"):
            raw = os.environ.get(key)
            if raw is not None:
                try:
                    world = int(raw)
                    break
                except (TypeError, ValueError):
                    continue
    if world is not None and world > 0:
        return f"rank-{rank}-of-{world}"
    return f"rank-{rank}"


def _resolve_default_limit() -> int:
    """Return the default prompt-length cap applied before calling vLLM.

    Resolution order:

    1. ``MAX_PROMPT_TOKENS``, or ``MAX_PROMPT_CHARS`` only when the former is
       unset, if the value parses as an integer.
    2. ``_DEFAULT_PROMPT_CHAR_LIMIT`` itself when it is non-positive, which
       disables the static fallback.
    3. Otherwise the larger of ``prompts.PROMPT_CHAR_LIMIT``, when that module
       is importable, and ``_DEFAULT_PROMPT_CHAR_LIMIT``.

    :returns: The resolved limit.
    """
    env_override = os.environ.get(
        "MAX_PROMPT_TOKENS", os.environ.get("MAX_PROMPT_CHARS")
    )
    if env_override is not None:
        try:
            return int(env_override)
        except (TypeError, ValueError):
            LOG.debug(
                "Invalid MAX_PROMPT_TOKENS/MAX_PROMPT_CHARS=%s; using defaults.",
                env_override,
            )

    floor = int(_DEFAULT_PROMPT_CHAR_LIMIT)
    if floor <= 0:
        # A non-positive default explicitly disables the static fallback limit.
        return floor

    try:
        from maxent_grpo.training.runtime import prompts as prompts_mod
    except (ImportError, AttributeError):
        baseline = _DEFAULT_PROMPT_CHAR_LIMIT
    else:
        baseline = getattr(prompts_mod, "PROMPT_CHAR_LIMIT", _DEFAULT_PROMPT_CHAR_LIMIT)

    try:
        return max(int(baseline), floor)
    except (TypeError, ValueError):
        # Unparseable baseline: fall back to the raw static default.
        return _DEFAULT_PROMPT_CHAR_LIMIT


def _normalize_vllm_url(raw_url: Optional[str]) -> str:
    """Return a normalized vLLM /generate endpoint URL or raise on invalid input."""
    if raw_url is None:
        return ""
    raw = str(raw_url).strip()
    if not raw:
        return ""
    trimmed = raw.rstrip("/")
    if trimmed.endswith("/generate"):
        return f"{trimmed}/"
    parsed = None
    try:
        parsed = urlparse(trimmed)
    except ValueError:
        parsed = None
    if parsed and parsed.scheme and parsed.netloc:
        base = f"{parsed.scheme}://{parsed.netloc}"
        if parsed.path in ("", "/"):
            normalized = f"{base}/generate/"
            LOG.warning("vllm_url missing /generate; normalizing to %s", normalized)
            return normalized
        if parsed.path in ("/v1", "/v1/"):
            normalized = f"{base}/generate/"
            LOG.warning(
                "vllm_url points to OpenAI-style /v1; expected /generate. "
                "Normalizing to %s",
                normalized,
            )
            return normalized
    raise ValueError(f"vllm_url must point to the /generate endpoint (got {raw!r}).")


[docs] class VLLMRequestMixin: """Mix-in that isolates request building, retries, and aggregation.""" ctx: Any _safe_generate: Any _time: Any _fallback_generate: Any
[docs] def set_safe_generate(self, safe_fn: Callable[..., Any]) -> None: """Allow callers to override the vLLM ``safe_generate`` hook. :param safe_fn: Callable matching the ``safe_generate`` signature. :type safe_fn: Callable[..., Any] """ self._safe_generate = safe_fn
[docs] def set_time_provider(self, time_mod: Any) -> None: """Allow callers to override the time module for sleep/now calls. :param time_mod: Replacement module or object exposing ``sleep`` and ``time`` as needed. :type time_mod: Any """ self._time = time_mod
[docs] def set_fallback_generate(self, fallback_fn: Callable[..., Any]) -> None: """Allow callers to override the local fallback generation hook. :param fallback_fn: Callable invoked when vLLM cannot provide outputs. :type fallback_fn: Callable[..., Any] """ self._fallback_generate = fallback_fn
[docs] def set_request_executor( self, executor_fn: Callable[["_VLLMGenerationState", List[int]], bool] ) -> None: """Allow callers to override the vLLM request executor. :param executor_fn: Function that performs one vLLM request round for pending indices and returns ``True`` on success. :type executor_fn: Callable[[maxent_grpo.training.generation.vllm_state._VLLMGenerationState, list[int]], bool] """ setattr(self, "_execute_vllm_request", executor_fn)
[docs] def set_request_batcher( self, batcher_fn: Callable[ [list[str], int], Tuple[ Optional[List[List[str]]], Optional[List[List[Optional[VLLMLogprobResult]]]], ], ], ) -> None: """Allow callers to override the vLLM batch request helper. :param batcher_fn: Callable used to build and dispatch a single vLLM request for a list of prompts and a target count. :type batcher_fn: Callable[[list[str], int], tuple[list[list[str]] | None, list[list[VLLMLogprobResult | None]] | None]] """ setattr(self, "_request_vllm_batch", batcher_fn)
[docs] def run_vllm_rounds(self, state: _VLLMGenerationState) -> None: """Public entry point for executing vLLM retry rounds. :param state: Mutable vLLM generation state tracked across retries. :type state: _VLLMGenerationState """ self._run_vllm_rounds(state)
[docs] @staticmethod def expand_dedup_results( grouped: List[List[str]], meta: Optional[List[List[Optional[VLLMLogprobResult]]]], mapping: Optional[List[int]], ) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]: """Public wrapper for expanding de-duplicated results. :param grouped: Grouped completions for unique prompts. :type grouped: list[list[str]] :param meta: Optional grouped metadata for unique prompts. :type meta: list[list[VLLMLogprobResult | None]] | None :param mapping: Mapping from original prompt indices to unique indices. :type mapping: list[int] | None :returns: Grouped completions and metadata expanded to the original prompt ordering. :rtype: tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None] """ return VLLMRequestMixin._expand_dedup_results(grouped, meta, mapping)
def _resolve_vllm_round_limit(self, requested_n: int) -> int: ctx = self.ctx rounds_cfg = getattr(ctx, "vllm_rounds_cfg", 0) or 0 if rounds_cfg > 0: return max(1, rounds_cfg) return max(1, requested_n) def _prepare_vllm_targets( self, prompts: List[str], num_samples: int, per_prompt_counts: Optional[List[int]], ) -> Tuple[List[str], List[int], Optional[List[int]]]: """Resolve per-prompt targets and deduplication mapping. :param prompts: Original prompt list. :type prompts: list[str] :param num_samples: Global completion target per prompt. :type num_samples: int :param per_prompt_counts: Optional per-prompt completion overrides. :type per_prompt_counts: list[int] | None :returns: Tuple containing prompts to request, target counts per prompt, and an optional mapping back to the original indices. :rtype: tuple[list[str], list[int], list[int] | None] :raises ValueError: If ``per_prompt_counts`` length does not match ``prompts`` length. """ if per_prompt_counts is not None and len(per_prompt_counts) != len(prompts): raise ValueError("per_prompt_counts length must match prompts length") target_counts = ( [max(0, int(count)) for count in per_prompt_counts] if per_prompt_counts is not None else [max(0, int(num_samples))] * len(prompts) ) dedupe_enabled = os.environ.get("MAXENT_VLLM_DEDUP", "0").lower() in { "1", "true", "yes", } if not dedupe_enabled: return list(prompts), target_counts, None seen: Dict[str, int] = {} unique_prompts: List[str] = [] unique_counts: List[int] = [] mapping: List[int] = [] for prompt, count in zip(prompts, target_counts): if prompt in seen: existing_idx = seen[prompt] mapping.append(existing_idx) continue seen[prompt] = len(unique_prompts) mapping.append(seen[prompt]) unique_prompts.append(prompt) unique_counts.append(count) if len(unique_prompts) == len(prompts): return unique_prompts, unique_counts, None return unique_prompts, unique_counts, mapping
[docs] def prepare_vllm_targets( self, prompts: List[str], num_samples: int, per_prompt_counts: Optional[List[int]], ) -> Tuple[List[str], List[int], Optional[List[int]]]: """Public wrapper for resolving vLLM targets/dedup mapping. :param prompts: Original prompt list. :type prompts: list[str] :param num_samples: Global completion target per prompt. :type num_samples: int :param per_prompt_counts: Optional per-prompt completion overrides. :type per_prompt_counts: list[int] | None :returns: Tuple of deduplicated prompts, target counts, and mapping back to the original order when deduplication occurs. :rtype: tuple[list[str], list[int], list[int] | None] """ return self._prepare_vllm_targets(prompts, num_samples, per_prompt_counts)
@staticmethod def _expand_dedup_results( grouped: List[List[str]], meta: Optional[List[List[Optional[VLLMLogprobResult]]]], mapping: Optional[List[int]], ) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]: """Expand grouped completions back to match the original prompt ordering. :param grouped: Grouped completions for unique prompts. :type grouped: list[list[str]] :param meta: Optional grouped metadata for unique prompts. :type meta: list[list[VLLMLogprobResult | None]] | None :param mapping: Mapping from original prompt indices to unique indices. :type mapping: list[int] | None :returns: Grouped completions and metadata aligned to the original prompt list. :rtype: tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None] """ if mapping is None: return grouped, meta expanded: List[List[str]] = [] expanded_meta: Optional[List[List[Optional[VLLMLogprobResult]]]] = ( [] if meta is not None else None ) for idx in mapping: expanded.append(list(grouped[idx]) if idx < len(grouped) else []) if expanded_meta is None: continue expanded_meta.append( list(meta[idx]) if meta is not None and idx < len(meta) else [] ) return expanded, expanded_meta def _run_vllm_rounds(self, state: _VLLMGenerationState) -> None: """Execute vLLM request rounds until targets are met or retries are exhausted. :param state: Mutable generation state containing prompts, targets, and aggregation buffers. 
:type state: _VLLMGenerationState """ ctx = self.ctx attempt = 0 inner_retries_override = None disable_inner = os.environ.get( "MAXENT_VLLM_DISABLE_INNER_RETRIES", "1" ).strip().lower() not in { "0", "false", "no", "off", } if disable_inner: try: current_retries = int(getattr(ctx, "vllm_max_retries", 0) or 0) except (TypeError, ValueError): current_retries = 0 if current_retries > 1 and state.round_limit > 1: inner_retries_override = current_retries try: setattr(ctx, "vllm_max_retries", 1) except Exception: inner_retries_override = None else: LOG.debug( "Disabling inner vLLM retries (max_retries=%d) because outer rounds=%d.", current_retries, state.round_limit, ) LOG.info( "vLLM rounds loop start | prompts=%d | round_limit=%d", len(state.prompts), state.round_limit, ) try: while attempt < state.round_limit: pending_indices = state.pending_indices() if not pending_indices: break attempt += 1 remaining_counts = state.remaining_counts(pending_indices) LOG.info( "vLLM round attempt | attempt=%d/%d | pending_prompts=%d", attempt, state.round_limit, len(pending_indices), ) LOG.debug( ( "vLLM round dispatch | attempt=%d/%d | pending_prompts=%d " "| remaining_counts_sample=%s | prompt_hash_sample=%s" ), attempt, state.round_limit, len(pending_indices), remaining_counts[:8], _hash_prompts([state.prompts[idx] for idx in pending_indices[:4]]), ) if attempt > 1: ctx.generation_stats["vllm_retry_rounds"] += 1 try: success = self._execute_vllm_request(state, pending_indices) except RuntimeError as err: if _is_client_tag_error(err) and _client_tag_fail_fast_enabled(ctx): stats = getattr(self.ctx, "generation_stats", None) if isinstance(stats, dict): stats["vllm_client_tag_errors"] = ( int(stats.get("vllm_client_tag_errors", 0)) + 1 ) LOG.error( "vLLM client_tag mismatch detected; aborting retries: %s", err, ) raise pending_count = len(pending_indices) LOG.warning( "vLLM attempt %d/%d for %d prompts failed (policy=%s): %s", attempt, state.round_limit, pending_count, 
self._format_retry_policy(), err, ) status_code = self._status_code_from_error(err) self._record_retry_attempt_metric( status_code, attempt, pending_count, ) if attempt >= state.round_limit: payload = self._build_vllm_failure_payload( state, pending_indices, attempt, err, ) self._log_retry_exhausted_metric(payload) self._log_structured_vllm_failure(payload) raise GenerationServiceError( f"vLLM retries exhausted for batch: {err}", payload, ) from err self._sleep_before_retry() continue if not success: self._record_retry_attempt_metric( None, attempt, len(pending_indices), reason="no_response", ) self._sleep_before_retry() continue missing_indices = state.pending_indices() if missing_indices: self._backfill_missing(state, missing_indices) remaining = state.pending_indices() if remaining: self._record_vllm_failure(state, remaining) LOG.info( "vLLM rounds loop done | remaining_prompts=%d", len(state.pending_indices()), ) finally: if inner_retries_override is not None: try: setattr(ctx, "vllm_max_retries", inner_retries_override) except Exception: pass def _sleep_before_retry(self) -> None: """Sleep between retries when ``vllm_retry_sleep`` is positive.""" retry_sleep = float(getattr(self.ctx, "vllm_retry_sleep", 0.0) or 0.0) if retry_sleep <= 0: return sleep_mod = getattr(self, "_time", time) if hasattr(sleep_mod, "sleep"): sleep_mod.sleep(retry_sleep) def _wandb_run(self) -> Optional[Any]: wandb_mod = sys.modules.get("wandb") if wandb_mod is None: return None return getattr(wandb_mod, "run", None) def _log_wandb_metrics(self, metrics: Dict[str, Any]) -> None: run = self._wandb_run() if run is None or not metrics: return stats = getattr(self.ctx, "generation_stats", {}) or {} step = int(stats.get("current_step") or 0) try: run.log(metrics, step=step) except _WANDB_LOG_EXCEPTIONS as exc: # pragma: no cover - defensive LOG.warning("Failed to log retry metrics to W&B: %s", exc) def _retry_policy_details(self) -> Dict[str, Any]: ctx = self.ctx return { "backoff_initial": 
float(getattr(ctx, "vllm_backoff", 1.0) or 0.0), "backoff_multiplier": float( getattr(ctx, "vllm_backoff_multiplier", 2.0) or 1.0 ), "retry_sleep": float(getattr(ctx, "vllm_retry_sleep", 0.0) or 0.0), "max_retries": int(getattr(ctx, "vllm_max_retries", 3) or 0), } def _format_retry_policy(self) -> str: policy = self._retry_policy_details() return ( "initial={backoff_initial:.3f},multiplier={backoff_multiplier:.3f}," "sleep={retry_sleep:.3f},max_retries={max_retries}" ).format(**policy) def _record_retry_attempt_metric( self, status_code: Optional[int], attempt: int, pending_count: int, reason: Optional[str] = None, ) -> None: metrics: Dict[str, Any] = { "generation/retry_attempts": 1, "generation/retry_status_code": status_code if status_code is not None else -1, "generation/retry_attempt_index": attempt, "generation/retry_pending": pending_count, } dataset_label = _resolve_dataset_label(self.ctx) if dataset_label: metrics["generation/retry_dataset"] = dataset_label if reason: metrics["generation/retry_reason"] = reason self._log_wandb_metrics(metrics) def _log_retry_exhausted_metric(self, payload: ServiceErrorPayload) -> None: metrics: Dict[str, Any] = { "generation/retry_exhausted": 1, "generation/retry_exhausted_status": payload.status_code if payload.status_code is not None else -1, "generation/retry_prompt_count": payload.prompt_count, "generation/retry_attempt_index": payload.attempt, } prompt_hash = payload.extra.get("prompt_hash") if payload.extra else None if prompt_hash: metrics["generation/retry_prompt_hash"] = prompt_hash dataset_label = payload.extra.get("dataset") if payload.extra else None if not dataset_label: dataset_label = _resolve_dataset_label(self.ctx) if dataset_label: metrics["generation/retry_dataset"] = dataset_label model_label = payload.extra.get("model_id") if payload.extra else None if not model_label: stats = getattr(self.ctx, "generation_stats", None) if isinstance(stats, dict): model_label = stats.get("model_id") if not model_label: 
model_label = _resolve_served_model_id(self.ctx) if model_label: metrics["generation/retry_model_id"] = model_label self._log_wandb_metrics(metrics) def _status_code_from_error(self, err: BaseException) -> Optional[int]: if isinstance(err, GenerationServiceError): return err.payload.status_code message = str(err) if message.startswith("HTTP "): parts = message.split(" ", 2) if len(parts) >= 2: try: return int(parts[1].rstrip(":")) except ValueError: return None return None def _build_vllm_failure_payload( self, state: _VLLMGenerationState, pending_indices: List[int], attempt: int, err: RuntimeError, ) -> ServiceErrorPayload: """Construct structured metadata for an exhausted retry batch.""" prompts = [state.prompts[idx] for idx in pending_indices] remaining = state.remaining_counts(pending_indices) stats = getattr(self.ctx, "generation_stats", {}) or {} accelerator = getattr(self.ctx, "accelerator", None) rank = int(getattr(accelerator, "process_index", 0)) if accelerator else 0 world = int(getattr(accelerator, "num_processes", 1)) if accelerator else 1 step = stats.get("current_step") if isinstance(stats, dict) else None extra = { "prompt_hash": _hash_prompts(prompts), "pending_indices": list(pending_indices), "remaining_need": remaining, "rank": rank, "world_size": world, "step": int(step) if isinstance(step, (int, float)) else None, "request_id_prefix": getattr(self.ctx, "vllm_request_id_prefix", None), "round_limit": state.round_limit, "backfill_enabled": bool(getattr(self.ctx, "vllm_backfill_local", False)), } extra.update(self._retry_policy_details()) dataset_label = _resolve_dataset_label(self.ctx) if dataset_label: extra["dataset"] = dataset_label model_label = None if isinstance(stats, dict): model_label = stats.get("model_id") if not model_label: model_label = _resolve_served_model_id(self.ctx) if model_label: extra["model_id"] = model_label base_payload = None if isinstance(err, GenerationServiceError): base_payload = err.payload if base_payload is not 
None: return base_payload.copy_with( attempt=attempt, max_attempts=state.round_limit, exception_type=type(err).__name__, exception_message=str(err), extra=extra, ) payload_chars = sum(len(prompt) for prompt in prompts) payload_size_bytes = sum( len(prompt.encode("utf-8", errors="ignore")) for prompt in prompts ) return ServiceErrorPayload( service="vllm", endpoint=getattr(self.ctx, "vllm_url", ""), model=_resolve_served_model_id(self.ctx), prompt_count=len(prompts), payload_chars=payload_chars, payload_size_bytes=payload_size_bytes, status_code=None, attempt=attempt, max_attempts=state.round_limit, exception_type=type(err).__name__, exception_message=str(err), request_id=None, extra=extra, ) def _log_structured_vllm_failure(self, payload: ServiceErrorPayload) -> None: """Emit structured metrics/logging for retry exhaustion.""" stats = getattr(self.ctx, "generation_stats", None) if isinstance(stats, dict): stats["vllm_retry_failures"] = int(stats.get("vllm_retry_failures", 0)) + 1 stats["vllm_last_error"] = payload.to_dict() LOG.error("vLLM retry exhaustion | payload=%s", payload.to_json()) def _execute_vllm_request( self, state: _VLLMGenerationState, pending_indices: List[int], ) -> bool: """Issue batched vLLM requests for the pending prompts. :param state: Mutable generation state to update with results. :type state: _VLLMGenerationState :param pending_indices: Prompt indices that still need completions. :type pending_indices: list[int] :returns: ``True`` if all batches were accepted by vLLM, ``False`` on a recoverable failure. 
:rtype: bool """ remaining_counts = state.remaining_counts(pending_indices) grouped_indices: Dict[int, List[int]] = {} for prompt_idx, need in zip(pending_indices, remaining_counts): if need <= 0: continue grouped_indices.setdefault(need, []).append(prompt_idx) for need, indices in grouped_indices.items(): pending_prompts = [state.prompts[idx] for idx in indices] grouped, grouped_meta = self._request_vllm_batch( pending_prompts, need, ) if grouped is None: return False if state.aggregated_meta is not None: has_logprob_payload = False if grouped_meta is not None: for group in grouped_meta: if not group: continue for entry in group: if entry is None: continue if isinstance(entry, dict): if ( "logprob_sum" in entry or "cumulative_logprob" in entry or "token_logprobs" in entry or "logprobs" in entry ): has_logprob_payload = True break else: logprob_sum = getattr(entry, "logprob_sum", None) token_logprobs = getattr(entry, "token_logprobs", None) if logprob_sum is not None or token_logprobs: has_logprob_payload = True break if has_logprob_payload: break if bool(getattr(self.ctx, "vllm_request_logprobs", False)): _record_logprob_status(self.ctx, has_logprob_payload) if not has_logprob_payload: warned = getattr(self.ctx, "_vllm_logprobs_missing_warned", False) if not warned and bool( getattr(self.ctx, "vllm_request_logprobs", False) ): LOG.warning( "vLLM logprob metadata requested but missing from responses; " "continuing without vLLM logprobs. " "If you are using `trl.scripts.vllm_serve`, it does not return logprobs." ) setattr(self.ctx, "_vllm_logprobs_missing_warned", True) self._merge_vllm_results( state, grouped, grouped_meta, indices, ) return True def _prompt_char_limit(self) -> int: """Return the maximum prompt length (tokens) enforced before calling vLLM. The limit prefers ``ctx.prompt_char_limit`` when set, otherwise derives from ``max_prompt_len`` or the static default constant. :returns: Maximum number of tokens to send per prompt. 
:rtype: int """ base_limit = _resolve_default_limit() limit_override = getattr(self.ctx, "prompt_char_limit", None) if isinstance(limit_override, int) and limit_override > 0: return limit_override max_len = getattr(self.ctx, "max_prompt_len", None) approx_limit = int(max_len) * 4 if isinstance(max_len, int) and max_len > 0 else 0 try: from maxent_grpo.training.generation import vllm as _vllm_mod limit_const = getattr(_vllm_mod, "PROMPT_CHAR_LIMIT", base_limit) except (ImportError, AttributeError, RuntimeError): limit_const = base_limit env_limit = int(limit_const) if isinstance(limit_const, int) else int(base_limit) if env_limit <= 0 and approx_limit <= 0: return 0 if env_limit <= 0: return approx_limit if approx_limit <= 0: return env_limit return max(env_limit, approx_limit) def _request_vllm_batch( self, pending_prompts: List[str], request_count: int, invoke_fn: Optional[ Callable[ [List[str], int], Optional[ Tuple[ List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]], float, ] ], ] ] = None, ) -> Tuple[ Optional[List[List[str]]], Optional[List[List[Optional[VLLMLogprobResult]]]], ]: """Build and dispatch a vLLM batch request for the given prompts. :param pending_prompts: Prompts to send to vLLM. :type pending_prompts: list[str] :param request_count: Number of completions to request per prompt. :type request_count: int :param invoke_fn: Optional callable to override the actual request invocation (useful for testing). :type invoke_fn: Callable[[list[str], int], tuple[list[list[str]], list[list[VLLMLogprobResult | None]], float] | None] | None :returns: Tuple of grouped completions and optional metadata; ``None`` values indicate a hard failure that should trigger retries. 
:rtype: tuple[list[list[str]] | None, list[list[VLLMLogprobResult | None]] | None] """ char_limit = self._prompt_char_limit() tokenizer = getattr(self.ctx, "tokenizer", None) max_tokens = getattr(self.ctx, "max_prompt_len", None) truncated = [ _truncate_prompt( prompt, char_limit, tokenizer=tokenizer, max_tokens=max_tokens, ) for prompt in pending_prompts ] request_impl = invoke_fn or self._invoke_vllm_requests response = request_impl(truncated, request_count) if response is None: return None, None grouped, grouped_meta, latency_ms = response self._record_vllm_latency(latency_ms) pending_count = len(pending_prompts) raw_group_count = len(grouped) if raw_group_count != pending_count: LOG.warning( "vLLM raw groups=%d for %d prompts (req_n=%d) | per-group preview: %s", raw_group_count, pending_count, request_count, self._summarize_grouped(grouped), ) grouped, grouped_meta = self._coalesce_grouped_outputs( grouped, pending_count, request_count, grouped_meta, ) if len(grouped) == pending_count: LOG.warning( ( "vLLM grouped outputs normalized to %d prompts " "(req_n=%d) | per-prompt lengths=%s" ), len(grouped), request_count, [len(entry) for entry in grouped], ) return grouped, grouped_meta LOG.warning( "vLLM grouped outputs len=%d vs pending=%d | per-prompt lengths=%s", len(grouped), pending_count, [len(entry) for entry in grouped], ) return None, None def _record_vllm_latency(self, latency_ms: float) -> None: """Record latency stats for a vLLM request round. :param latency_ms: Observed latency in milliseconds. :type latency_ms: float """ stats = self.ctx.generation_stats stats["vllm_last_latency_ms"] = float(latency_ms) stats["vllm_latency_total_ms"] = float( stats.get("vllm_latency_total_ms", 0.0) ) + float(latency_ms) stats["vllm_latency_calls"] = int(stats.get("vllm_latency_calls", 0)) + 1 def _build_vllm_request_kwargs( self, prompts: List[str], request_count: int, ) -> Dict[str, Any]: """Assemble keyword arguments passed to ``safe_generate``. 
:param prompts: Prompt texts to send. :type prompts: list[str] :param request_count: Number of completions requested per prompt. :type request_count: int :returns: Keyword arguments for the vLLM request. :rtype: dict[str, Any] """ ctx = self.ctx stop_sequences = ( ctx.gen_stop_sequences if ctx.gen_stop_sequences is not None else ctx.vllm_stop_sequences ) top_k = ctx.gen_top_k if ctx.gen_top_k is not None else ctx.vllm_top_k if top_k == 0: top_k = -1 best_of = ctx.gen_best_of if ctx.gen_best_of is not None else ctx.vllm_best_of backoff_multiplier = getattr(ctx, "vllm_backoff_multiplier", 2.0) stats = getattr(self.ctx, "generation_stats", {}) or {} stream = _coerce_bool(getattr(ctx, "vllm_stream", None)) if stream is None: stream = _coerce_bool(os.getenv("MAXENT_VLLM_STREAM")) if stream is None: stream = False metadata: Dict[str, Any] = {} dataset_label = _resolve_dataset_label(self.ctx) if dataset_label: metadata["dataset"] = dataset_label model_label = stats.get("model_id") if not model_label: model_label = _resolve_served_model_id(self.ctx) if model_label: metadata["model_id"] = model_label client_tag = _resolve_client_tag(ctx) url = str(getattr(ctx, "vllm_url", "") or "") logit_bias = merge_invalid_token_block_logit_bias( ctx, getattr(ctx, "vllm_logit_bias", None), ) allowed_token_ids = resolve_allowed_token_ids(ctx) blocked_token_ids = resolve_blocked_token_ids(ctx) request_kwargs: Dict[str, Any] = { "prompts": prompts, "url": url, "max_tokens": ctx.max_completion_len, "temperature": ctx.gen_temperature, "top_p": ctx.gen_top_p, "top_k": top_k, "n": request_count, "best_of": best_of, "frequency_penalty": ctx.gen_frequency_penalty, "presence_penalty": ctx.gen_presence_penalty, "stop": stop_sequences, "include_stop_str_in_output": bool( getattr(ctx, "vllm_include_stop_str_in_output", False) ), "logit_bias": logit_bias, "allowed_token_ids": allowed_token_ids, "blocked_token_ids": blocked_token_ids, "guided_json": ctx.vllm_guided_json, "guided_regex": 
ctx.vllm_guided_regex, "request_id_prefix": ctx.vllm_request_id_prefix, "stream": bool(stream), "tokenizer": ctx.tokenizer, "timeout": ctx.vllm_timeout, "max_retries": ctx.vllm_max_retries, "backoff": ctx.vllm_backoff, "backoff_multiplier": backoff_multiplier, "return_logprobs": ctx.vllm_request_logprobs, "service_model": _resolve_served_model_id(self.ctx), } if metadata: request_kwargs["metadata"] = metadata if client_tag: request_kwargs["client_tag"] = client_tag return request_kwargs def _invoke_vllm_requests( self, prompts: List[str], request_count: int, ) -> Optional[ Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]], float] ]: """Invoke vLLM requests with recursive fallback on failures. :param prompts: Prompt texts to send. :type prompts: list[str] :param request_count: Number of completions requested per prompt. :type request_count: int :returns: Tuple of grouped completions, grouped metadata when enabled, and latency in milliseconds, or ``None`` if the request fails. 
:rtype: tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None, float] | None """ try: request_kwargs = self._build_vllm_request_kwargs(prompts, request_count) safe_gen = cast( Callable[ ..., Tuple[ List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]], float, ], ], getattr(self, "_safe_generate", safe_generate), ) LOG.debug( ( "Dispatching vLLM request | prompts=%d | req_n=%d | prompt_hash=%s " "| prompt_lens_sample=%s | total_chars=%d | timeout=%s | " "max_retries=%s | backoff=%s | url=%s" ), len(prompts), request_count, _hash_prompts(prompts), [len(prompt) for prompt in prompts[:8]], sum(len(prompt) for prompt in prompts), request_kwargs.get("timeout"), request_kwargs.get("max_retries"), request_kwargs.get("backoff"), request_kwargs.get("url"), ) LOG.info( "vLLM request start | prompts=%d | req_n=%d | timeout=%s | url=%s", len(prompts), request_count, request_kwargs.get("timeout"), request_kwargs.get("url"), ) grouped, grouped_meta, latency_ms = safe_gen(**request_kwargs) LOG.info( "vLLM request done | prompts=%d | req_n=%d | latency_ms=%.2f", len(prompts), request_count, float(latency_ms), ) return grouped, grouped_meta, latency_ms except RuntimeError as err: if len(prompts) <= 1: LOG.warning("vLLM request failed for single prompt: %s", err) return None mid = max(1, len(prompts) // 2) left_result = self._invoke_vllm_requests(prompts[:mid], request_count) right_result = self._invoke_vllm_requests(prompts[mid:], request_count) if left_result is None or right_result is None: return None combined_meta = None if left_result[1] is not None and right_result[1] is not None: combined_meta = left_result[1] + right_result[1] return ( left_result[0] + right_result[0], combined_meta, left_result[2] + right_result[2], ) @staticmethod def _summarize_grouped(groups: List[List[str]], limit: int = 8) -> str: """Return a concise string summary of grouped completions. :param groups: Grouped completions to summarize. 
:type groups: list[list[str]] :param limit: Maximum number of groups to include in the summary. :type limit: int :returns: Human-readable summary suitable for logging. :rtype: str """ summary_parts: List[str] = [] for idx, entry in enumerate(groups[:limit]): if isinstance(entry, list): preview = entry[0][:32] if entry else "" summary_parts.append(f"{idx}:len={len(entry)} sample={preview!r}") else: summary_parts.append(f"{idx}:type={type(entry).__name__}") if len(groups) > limit: summary_parts.append(f"...(+{len(groups) - limit})") return "; ".join(summary_parts) def _merge_vllm_results( self, state: _VLLMGenerationState, grouped: List[List[str]], grouped_meta: Optional[List[List[Optional[VLLMLogprobResult]]]], pending_indices: List[int], ) -> None: """Merge vLLM responses into the aggregated state with overflow trimming. :param state: Mutable generation state to update. :type state: _VLLMGenerationState :param grouped: Generated completions aligned to ``pending_indices``. :type grouped: list[list[str]] :param grouped_meta: Optional metadata aligned to ``pending_indices``. :type grouped_meta: list[list[VLLMLogprobResult | None]] | None :param pending_indices: Prompt indices associated with ``grouped``. 
:type pending_indices: list[int] """ aggregated = state.aggregated aggregated_meta = state.aggregated_meta stats = self.ctx.generation_stats for idx, prompt_idx in enumerate(pending_indices): aggregated[prompt_idx].extend(grouped[idx]) if aggregated_meta is not None and grouped_meta is not None: aggregated_meta[prompt_idx].extend(grouped_meta[idx]) target = state.target_counts[prompt_idx] overflow = 0 if 0 < target < len(aggregated[prompt_idx]): overflow = len(aggregated[prompt_idx]) - target aggregated[prompt_idx] = aggregated[prompt_idx][:target] if aggregated_meta is not None: aggregated_meta[prompt_idx] = aggregated_meta[prompt_idx][:target] if overflow > 0: stats["vllm_excess_prompts"] = stats.get("vllm_excess_prompts", 0) + 1 stats["vllm_excess_completions"] = ( stats.get( "vllm_excess_completions", 0, ) + overflow )
def merge_vllm_results(
    self,
    state: _VLLMGenerationState,
    grouped: List[List[str]],
    grouped_meta: Optional[List[List[Optional[VLLMLogprobResult]]]],
    pending_indices: List[int],
) -> None:
    """Fold a batch of vLLM completions into ``state`` (public entry point).

    This is a thin delegate: all of the work, including trimming prompts
    that overshoot their target count, happens in ``_merge_vllm_results``.

    :param state: Generation state that receives the completions.
    :type state: _VLLMGenerationState
    :param grouped: Completion groups, one per entry of ``pending_indices``.
    :type grouped: list[list[str]]
    :param grouped_meta: Optional per-completion metadata grouped the same
        way as ``grouped``.
    :type grouped_meta: list[list[VLLMLogprobResult | None]] | None
    :param pending_indices: Positions in ``state`` that the groups belong to.
    :type pending_indices: list[int]
    """
    merge_impl = self._merge_vllm_results
    merge_impl(state, grouped, grouped_meta, pending_indices)
def _backfill_missing(
    self,
    state: _VLLMGenerationState,
    missing_indices: List[int],
) -> None:
    """Generate missing completions locally when vLLM fails.

    Nothing happens when the context sets ``vllm_disable_local_fallback`` or
    leaves ``vllm_backfill_local`` falsy. Otherwise:

    * the ``vllm_backfilled_prompts`` counter is increased by the number of
      missing prompts;
    * the prompts at ``missing_indices`` are passed to
      ``self._fallback_generate`` together with ``state.requested_n`` and the
      per-prompt remaining counts;
    * each prompt's aggregated completions are extended with at most as many
      local completions as it still lacks relative to its target count;
    * ``state.drop_meta()`` is called.

    :param state: Mutable generation state.
    :type state: _VLLMGenerationState
    :param missing_indices: Prompt indices still missing completions.
    :type missing_indices: list[int]
    """
    ctx = self.ctx
    if bool(getattr(ctx, "vllm_disable_local_fallback", False)):
        return
    if not bool(getattr(ctx, "vllm_backfill_local", False)):
        return
    missing_prompts = [state.prompts[idx] for idx in missing_indices]
    needed_per_prompt = state.remaining_counts(missing_indices)
    stats = ctx.generation_stats
    # Use ``get`` rather than ``+=`` so a stats dict that was never seeded
    # with this counter does not raise KeyError (same pattern as the other
    # counters updated by this mixin).
    stats["vllm_backfilled_prompts"] = stats.get(
        "vllm_backfilled_prompts", 0
    ) + len(missing_indices)
    max_need = max(needed_per_prompt) if needed_per_prompt else 0
    LOG.warning(
        (
            "Backfilling %d/%d prompts locally because vLLM failed to "
            "return the remaining completions (max_need=%d) after %d attempts."
        ),
        len(missing_indices),
        len(state.prompts),
        max_need,
        state.round_limit,
    )
    # The second value returned by the local generator is not used.
    local_groups, _ = self._fallback_generate(
        missing_prompts,
        state.requested_n,
        needed_per_prompt,
    )
    aggregated = state.aggregated
    for local_idx, prompt_idx in enumerate(missing_indices):
        target = state.target_counts[prompt_idx]
        # Cap each extension at the prompt's current shortfall so the
        # target count is not exceeded.
        needed = max(0, target - len(aggregated[prompt_idx]))
        if needed <= 0:
            continue
        aggregated[prompt_idx].extend(local_groups[local_idx][:needed])
    # NOTE(review): presumably metadata is dropped because locally generated
    # completions have no vLLM metadata aligned with them — confirm.
    state.drop_meta()
def backfill_missing(
    self,
    state: _VLLMGenerationState,
    missing_indices: List[int],
) -> None:
    """Fill incomplete prompts via local generation (public entry point).

    Delegates unchanged to ``_backfill_missing``.

    :param state: Generation state whose prompts should be topped up.
    :type state: _VLLMGenerationState
    :param missing_indices: Indices of prompts that still lack completions.
    :type missing_indices: list[int]
    """
    backfill_impl = self._backfill_missing
    backfill_impl(state, missing_indices)
def _record_vllm_failure(
    self,
    state: _VLLMGenerationState,
    missing_indices: List[int],
) -> None:
    """Record metrics and warnings when vLLM could not satisfy requests.

    Increments the ``vllm_failed_prompts`` counter by the number of missing
    prompts and logs a warning with the largest remaining need. The
    " + local fallback" suffix is added only when local fallback is actually
    active: ``vllm_backfill_local`` is truthy and ``vllm_disable_local_fallback``
    is not.

    :param state: Generation state containing prompt counts.
    :type state: _VLLMGenerationState
    :param missing_indices: Indices that remain incomplete.
    :type missing_indices: list[int]
    """
    ctx = self.ctx
    missing_count = len(missing_indices)
    stats = ctx.generation_stats
    # ``get`` avoids a KeyError when the counter was never seeded.
    stats["vllm_failed_prompts"] = stats.get("vllm_failed_prompts", 0) + missing_count
    # Match the gating in ``_backfill_missing``: an explicit
    # ``vllm_disable_local_fallback`` suppresses local generation even when
    # ``vllm_backfill_local`` is set, so the log must not claim it ran.
    fallback_active = bool(getattr(ctx, "vllm_backfill_local", False)) and not bool(
        getattr(ctx, "vllm_disable_local_fallback", False)
    )
    suffix = " + local fallback" if fallback_active else ""
    remaining = state.remaining_counts(missing_indices)
    max_need = max(remaining) if remaining else 0
    LOG.warning(
        (
            "Unable to obtain the remaining completions (max_need=%d) for %d/%d prompts even "
            "after %d attempts%s."
        ),
        max_need,
        missing_count,
        len(state.prompts),
        state.round_limit,
        suffix,
    )
def record_vllm_failure(
    self,
    state: _VLLMGenerationState,
    missing_indices: List[int],
) -> None:
    """Report prompts that vLLM could not complete (public entry point).

    Delegates unchanged to ``_record_vllm_failure``.

    :param state: Generation state holding per-prompt counts.
    :type state: _VLLMGenerationState
    :param missing_indices: Indices of prompts left incomplete.
    :type missing_indices: list[int]
    """
    failure_recorder = self._record_vllm_failure
    failure_recorder(state, missing_indices)
@staticmethod def _coalesce_grouped_outputs( groups: List[List[str]], prompt_count: int, requested_n: int, meta: Optional[List[List[Optional[VLLMLogprobResult]]]] = None, ) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]: """Normalize grouped outputs when vLLM returns multiple slices per prompt. :param groups: Raw grouped completions returned by vLLM. :type groups: list[list[str]] :param prompt_count: Number of prompts originally requested. :type prompt_count: int :param requested_n: Target completions per prompt. :type requested_n: int :param meta: Optional grouped metadata aligned with ``groups``. :type meta: list[list[VLLMLogprobResult | None]] | None :returns: Regrouped completions and metadata aligned to prompts. If regrouping is not possible, metadata may be dropped. :rtype: tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None] """ if not groups or prompt_count <= 0: return groups, meta total_groups = len(groups) if total_groups == prompt_count: return groups, meta if total_groups % prompt_count != 0: return groups, None per_prompt = total_groups // prompt_count if ( per_prompt <= 1 or (requested_n > 0 and per_prompt != requested_n) or not all(len(entry) <= 1 for entry in groups) ): return groups, meta regrouped: List[List[str]] = [] regrouped_meta: Optional[List[List[Optional[VLLMLogprobResult]]]] = ( [] if meta is not None else None ) for chunk_start in range(0, total_groups, per_prompt): chunk = groups[chunk_start : chunk_start + per_prompt] meta_slice = ( meta[chunk_start : chunk_start + per_prompt] if meta is not None else None ) merged, merged_meta = VLLMRequestMixin._merge_group_chunk( chunk, meta_slice, requested_n, ) regrouped.append(merged) if regrouped_meta is None: continue regrouped_meta.append(merged_meta if merged_meta is not None else []) return regrouped, regrouped_meta
@staticmethod
def coalesce_grouped_outputs(
    groups: List[List[str]],
    prompt_count: int,
    requested_n: int,
    meta: Optional[List[List[Optional[VLLMLogprobResult]]]] = None,
) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]:
    """Regroup vLLM outputs to one group per prompt (public entry point).

    Delegates unchanged to ``VLLMRequestMixin._coalesce_grouped_outputs``.

    :param groups: Raw grouped completions returned by vLLM.
    :type groups: list[list[str]]
    :param prompt_count: Number of prompts originally requested.
    :type prompt_count: int
    :param requested_n: Target completions per prompt.
    :type requested_n: int
    :param meta: Optional grouped metadata aligned with ``groups``.
    :type meta: list[list[VLLMLogprobResult | None]] | None
    :returns: Regrouped completions and metadata aligned to prompts. If
        regrouping is not possible, metadata may be dropped.
    :rtype: tuple[list[list[str]], list[list[VLLMLogprobResult | None]] | None]
    """
    coalesce_impl = VLLMRequestMixin._coalesce_grouped_outputs
    return coalesce_impl(groups, prompt_count, requested_n, meta)
@staticmethod def _merge_group_chunk( chunk: List[List[str]], meta_chunk: Optional[List[List[Optional[VLLMLogprobResult]]]], requested_n: int, ) -> Tuple[List[str], Optional[List[Optional[VLLMLogprobResult]]]]: """Merge a contiguous chunk of grouped outputs for a single prompt. :param chunk: Subset of grouped outputs belonging to one prompt. :type chunk: list[list[str]] :param meta_chunk: Optional metadata aligned to ``chunk``. :type meta_chunk: list[list[VLLMLogprobResult | None]] | None :param requested_n: Target number of completions for the prompt. :type requested_n: int :returns: Flattened completions and optional flattened metadata trimmed to ``requested_n``. :rtype: tuple[list[str], list[VLLMLogprobResult | None] | None] """ merged: List[str] = [] merged_meta: Optional[List[Optional[VLLMLogprobResult]]] = ( [] if meta_chunk is not None else None ) for idx, entry in enumerate(chunk): merged.extend(entry) if merged_meta is None: continue if meta_chunk is None or idx >= len(meta_chunk): merged_meta = None continue merged_meta.extend(meta_chunk[idx]) if requested_n > 0: merged = merged[:requested_n] if merged_meta is not None: merged_meta = merged_meta[:requested_n] return merged, merged_meta
@staticmethod
def merge_group_chunk(
    chunk: List[List[str]],
    meta_chunk: Optional[List[List[Optional[VLLMLogprobResult]]]],
    requested_n: int,
) -> Tuple[List[str], Optional[List[Optional[VLLMLogprobResult]]]]:
    """Flatten one prompt's grouped outputs (public entry point).

    Delegates unchanged to ``VLLMRequestMixin._merge_group_chunk``.

    :param chunk: Subset of grouped outputs belonging to one prompt.
    :type chunk: list[list[str]]
    :param meta_chunk: Optional metadata aligned to ``chunk``.
    :type meta_chunk: list[list[VLLMLogprobResult | None]] | None
    :param requested_n: Target number of completions for the prompt.
    :type requested_n: int
    :returns: Flattened completions and optional flattened metadata trimmed
        to ``requested_n``.
    :rtype: tuple[list[str], list[VLLMLogprobResult | None] | None]
    """
    chunk_merger = VLLMRequestMixin._merge_group_chunk
    return chunk_merger(chunk, meta_chunk, requested_n)
# Public API of this module: ``from ... import *`` exports only the mixin.
__all__ = ["VLLMRequestMixin"]