# Source code for maxent_grpo.training.rollout.vllm_adapter

"""vLLM-focused helpers split away from the local generation path."""

from __future__ import annotations
# pylint: disable=broad-exception-caught

import importlib
import sys
import logging
import os
import time
import numbers
from contextlib import AbstractContextManager, contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TYPE_CHECKING,
    cast,
)

from maxent_grpo.training.generation.common import (
    AggregatedGenerationState as _AggregatedGenerationState,
    retry_incomplete_prompts as _retry_incomplete_prompts_impl,
    seed_generation_groups as _seed_generation_groups_impl,
)
from maxent_grpo.training.generation.vllm import (
    VLLMGenerationHelper,
    _VLLMGenerationState as _BaseVLLMGenerationState,
)
from maxent_grpo.training.generation.vllm_utils import (
    import_vllm_client_cls as _shared_import_vllm_client_cls,
    init_vllm_client_communicator as _shared_init_vllm_client_communicator,
    zero3_gather_factory as _shared_zero3_gather_factory,
)
from maxent_grpo.training.patches.vllm import VLLMLogprobResult, safe_generate
from maxent_grpo.training.runtime import require_accelerator, require_torch
from maxent_grpo.training.runtime.prompts import _truncate_prompt

from .context import GenerationContext

# Runtime handles resolved once at import time via the shared runtime helpers
# (both requests are tagged with the "generation" feature name).
torch = require_torch("generation")
Accelerator = require_accelerator("generation")
# ``torch.distributed`` may be missing on some builds; fall back to ``None``.
dist = torch.distributed if hasattr(torch, "distributed") else None
LOG = logging.getLogger(__name__)


def _progress_log_enabled() -> bool:
    raw = os.getenv("MAXENT_PROGRESS_LOG")
    if raw is None or not str(raw).strip():
        return False
    return str(raw).strip().lower() not in {"0", "false", "no", "off"}


if not TYPE_CHECKING:
    # At runtime the annotation-only alias degrades to ``Any`` so this module
    # does not import ``accelerate`` just for type hints.
    AcceleratorLike = Any
else:
    from accelerate import Accelerator as AcceleratorLike


def _optional_import(module_name: str) -> Any:
    """Import a module if available without triggering import errors."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


def _env_flag(name: str, default: bool = False) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


def _env_int(name: str) -> Optional[int]:
    raw = os.environ.get(name)
    if raw is None:
        return None
    try:
        return int(raw)
    except (TypeError, ValueError):
        return None


def _use_vllm_collective() -> bool:
    raw = os.getenv("MAXENT_VLLM_COLLECTIVE")
    if raw is None:
        return True
    return raw.strip().lower() not in {"0", "false", "no", "off"}


def _zero3_gather_factory(
    accelerator: AcceleratorLike,
) -> Callable[[Sequence[Any]], AbstractContextManager[Any]]:
    """Return the shared gather-context factory bound to ``accelerator``.

    Delegates to ``vllm_utils.zero3_gather_factory``. This module's tolerant
    importer is injected so that optional packages that fail to import resolve
    to ``None`` rather than raising.
    """
    importer = _optional_import
    return _shared_zero3_gather_factory(accelerator, import_fn=importer)


def _import_vllm_client_cls(
    import_fn: Optional[Callable[[str], Any]] = None,
) -> Optional[type]:
    """Return TRL's VLLMClient class, or ``None`` when it cannot be resolved.

    ``import_fn`` is forwarded to the shared resolver; a missing/falsy value
    falls back to this module's ``_optional_import``.
    """
    resolver = import_fn if import_fn else _optional_import
    return _shared_import_vllm_client_cls(resolver)


def _resolve_vllm_group_port() -> Optional[int]:
    for key in ("VLLM_GROUP_PORT", "PORT_FOR_COMMUNICATION"):
        value = os.environ.get(key)
        if not value:
            continue
        try:
            return int(value)
        except ValueError:
            LOG.warning("Invalid %s=%r; expected an integer.", key, value)
    return None


@contextmanager
def _temporary_env(overrides: Dict[str, str]) -> Iterable[None]:
    if not overrides:
        yield
        return
    previous: Dict[str, Optional[str]] = {}
    for key, value in overrides.items():
        previous[key] = os.environ.get(key)
        os.environ[key] = value
    try:
        yield
    finally:
        for key, prior in previous.items():
            if prior is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = prior


def _loopback_host(base_url: str) -> bool:
    try:
        from urllib.parse import urlparse

        parsed = urlparse(base_url)
        host = parsed.hostname or ""
    except Exception:
        host = ""
    if not host:
        host = base_url
    host = host.strip().lower()
    return host in {"localhost", "127.0.0.1", "::1"}


def _vllm_client_nccl_overrides(base_url: str) -> Dict[str, str]:
    overrides: Dict[str, str] = {}
    enable_overrides = str(
        os.getenv("MAXENT_VLLM_CLIENT_NCCL_OVERRIDES", "0")
    ).strip().lower() in {"1", "true", "yes", "on"}
    if not enable_overrides:
        return overrides

    if not _loopback_host(base_url):
        explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME")
        if explicit and "NCCL_SOCKET_IFNAME" not in os.environ:
            overrides["NCCL_SOCKET_IFNAME"] = explicit
        explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_P2P_DISABLE")
        if explicit and "NCCL_P2P_DISABLE" not in os.environ:
            overrides["NCCL_P2P_DISABLE"] = explicit
        explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_IB_DISABLE")
        if explicit and "NCCL_IB_DISABLE" not in os.environ:
            overrides["NCCL_IB_DISABLE"] = explicit
        return overrides
    if "NCCL_SOCKET_IFNAME" not in os.environ:
        overrides["NCCL_SOCKET_IFNAME"] = os.getenv(
            "MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME", "lo"
        )
    if "NCCL_P2P_DISABLE" not in os.environ:
        overrides["NCCL_P2P_DISABLE"] = os.getenv(
            "MAXENT_VLLM_CLIENT_NCCL_P2P_DISABLE", "1"
        )
    if "NCCL_IB_DISABLE" not in os.environ:
        overrides["NCCL_IB_DISABLE"] = os.getenv(
            "MAXENT_VLLM_CLIENT_NCCL_IB_DISABLE", "1"
        )
    return overrides


def _is_peft_model_safe(target: Any) -> bool:
    """Return ``accelerate.utils.is_peft_model(target)`` as a bool, never raising.

    Yields False when ``accelerate.utils`` cannot be imported, lacks a callable
    ``is_peft_model``, or the check raises ``TypeError``, ``AttributeError``, or
    ``ValueError``.
    """
    utils_mod = _optional_import("accelerate.utils")
    checker = None if utils_mod is None else getattr(utils_mod, "is_peft_model", None)
    if not callable(checker):
        return False
    try:
        return bool(checker(target))
    except (TypeError, AttributeError, ValueError):
        return False


# Module-level alias for the shared vLLM generation state class (imported above
# as ``_BaseVLLMGenerationState``); the mixin's method signatures annotate with
# this name. Presumably also kept so existing imports of
# ``_VLLMGenerationState`` from this module continue to work — confirm callers.
_VLLMGenerationState = _BaseVLLMGenerationState


[docs] class VLLMGenerationMixin: """All vLLM-specific plumbing extracted from the main generator.""" # Access to helper internals is intentional for tests/patching. ctx: GenerationContext def __init__(self, ctx: GenerationContext) -> None: self.ctx = ctx self._vllm_helper = VLLMGenerationHelper(ctx, self._generate_local) self._vllm_colocate_engine = None # Surface patchable hooks for tests so monkeypatched helpers.* propagate. if hasattr(self._vllm_helper, "set_safe_generate"): self._vllm_helper.set_safe_generate(safe_generate) else: setattr(self._vllm_helper, "_safe_generate", safe_generate) setattr(self._vllm_helper, "_scatter_object", _scatter_object) if hasattr(self._vllm_helper, "set_time_provider"): self._vllm_helper.set_time_provider(time) else: setattr(self._vllm_helper, "_time", time) # pragma: no cover - legacy stubs setattr(self._vllm_helper, "_is_peft_model_safe", _is_peft_model_safe) if hasattr(self._vllm_helper, "set_fallback_generate"): self._vllm_helper.set_fallback_generate(self._generate_local) else: setattr(self._vllm_helper, "_fallback_generate", self._generate_local) self._configure_vllm_mode() def _configure_vllm_mode(self) -> None: if not getattr(self.ctx, "use_vllm", False): return mode = str(getattr(self.ctx, "vllm_mode", "server") or "server").strip().lower() if mode in {"inline", "local", "inprocess", "in-process"}: mode = "colocate" if mode != "colocate": return try: from .vllm_colocate import ColocateVLLMEngine except ImportError as exc: LOG.warning("vLLM colocate requested but unavailable: %s", exc) return is_main = bool( getattr(getattr(self.ctx, "accelerator", None), "is_main_process", True) ) if is_main: device_hint = os.getenv("MAXENT_VLLM_COLOCATE_DEVICE", "") if not device_hint: local_rank = os.getenv("LOCAL_RANK") or os.getenv("SLURM_LOCALID") if local_rank is not None and str(local_rank).isdigit(): device_hint = f"cuda:{int(local_rank)}" else: try: if torch.cuda.is_available(): device_hint = 
f"cuda:{torch.cuda.current_device()}" except Exception: device_hint = "" LOG.info( "vLLM pre-init | CUDA_VISIBLE_DEVICES=%s MAXENT_VLLM_COLOCATE_DEVICE=%s " "LOCAL_RANK=%s SLURM_LOCALID=%s", os.getenv("CUDA_VISIBLE_DEVICES"), device_hint or "auto", os.getenv("LOCAL_RANK"), os.getenv("SLURM_LOCALID"), ) configured_sync = bool(getattr(self.ctx, "vllm_sync_weights", False)) raw_sync_override = os.getenv("MAXENT_VLLM_COLOCATE_SYNC") if raw_sync_override is None: sync_enabled = configured_sync else: sync_enabled = _env_flag("MAXENT_VLLM_COLOCATE_SYNC", False) if sync_enabled != configured_sync: LOG.warning( "vLLM colocate sync override changed vllm_sync_weights from %s to %s " "via MAXENT_VLLM_COLOCATE_SYNC=%r.", configured_sync, sync_enabled, raw_sync_override, ) try: setattr(self.ctx, "vllm_sync_weights", sync_enabled) except Exception: pass self._vllm_colocate_engine = ColocateVLLMEngine(self.ctx, self._generate_local) batcher = getattr(self._vllm_helper, "set_request_batcher", None) if callable(batcher): batcher(self._vllm_colocate_engine.request_batch) else: setattr( self._vllm_helper, "_request_vllm_batch", self._vllm_colocate_engine.request_batch, ) if sync_enabled: if not getattr(self.ctx, "vllm_sync_weights", False): setattr(self.ctx, "vllm_sync_weights", True) LOG.info("vLLM colocate sync enabled.") sync_interval = _env_int("MAXENT_VLLM_COLOCATE_SYNC_INTERVAL") if sync_interval is not None: if sync_interval < 0: LOG.warning( "Invalid MAXENT_VLLM_COLOCATE_SYNC_INTERVAL=%s; ignoring.", sync_interval, ) else: setattr(self.ctx, "vllm_sync_interval_steps", sync_interval) else: current_interval = getattr(self.ctx, "vllm_sync_interval_steps", None) if current_interval is None: setattr(self.ctx, "vllm_sync_interval_steps", 10) LOG.info( "vLLM colocate sync interval defaulted to 10 steps. " "Override via MAXENT_VLLM_COLOCATE_SYNC_INTERVAL." 
) is_main = bool( getattr(getattr(self.ctx, "accelerator", None), "is_main_process", True) ) if is_main: try: client = self._vllm_colocate_engine.sync_client() self._vllm_client = client self._vllm_sync_ready = True except Exception as exc: LOG.warning("vLLM colocate sync client unavailable: %s", exc) self._vllm_client = None self._vllm_sync_ready = False else: # Avoid initializing colocate sync clients on non-main ranks. self._vllm_client = None self._vllm_sync_ready = False LOG.info("Skipping vLLM colocate sync client init on non-main rank.") elif is_main: LOG.warning( "vLLM colocate weight sync is disabled; rollouts will use stale vLLM " "weights and can drift off-policy." ) LOG.info("vLLM colocate mode enabled; using in-process vLLM engine.") def _generate_local( self, prompts: List[str], num_samples: int, per_prompt_counts: Optional[List[int]] = None, ) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]: raise NotImplementedError("Subclasses must implement _generate_local().") def _prompt_char_limit(self) -> int: raise NotImplementedError("Subclasses must implement _prompt_char_limit().") @property def _vllm_client(self) -> Any: client = getattr(self._vllm_helper, "vllm_client", None) if client is None: client = getattr(self._vllm_helper, "_vllm_client", None) return client @_vllm_client.setter def _vllm_client(self, value: Any) -> None: setattr(self._vllm_helper, "vllm_client", value) setattr(self._vllm_helper, "_vllm_client", value) @property def _vllm_sync_ready(self) -> bool: if hasattr(self._vllm_helper, "vllm_sync_ready"): return bool(getattr(self._vllm_helper, "vllm_sync_ready")) return bool(getattr(self._vllm_helper, "_vllm_sync_ready", False)) @_vllm_sync_ready.setter def _vllm_sync_ready(self, value: bool) -> None: setattr(self._vllm_helper, "vllm_sync_ready", value) setattr(self._vllm_helper, "_vllm_sync_ready", value) @property def _last_vllm_synced_step(self) -> Optional[int]: step = getattr(self._vllm_helper, 
"last_vllm_synced_step", None) if step is None: step = getattr(self._vllm_helper, "_last_vllm_synced_step", None) return step @_last_vllm_synced_step.setter def _last_vllm_synced_step(self, value: Optional[int]) -> None: setattr(self._vllm_helper, "last_vllm_synced_step", value) setattr(self._vllm_helper, "_last_vllm_synced_step", value) @property def _fsdp_cls(self) -> Any: fsdp = getattr(self._vllm_helper, "fsdp_cls", None) if fsdp is None: fsdp = getattr(self._vllm_helper, "_fsdp_cls", None) return fsdp @_fsdp_cls.setter def _fsdp_cls(self, value: Any) -> None: setattr(self._vllm_helper, "fsdp_cls", value) setattr(self._vllm_helper, "_fsdp_cls", value) def _vllm_base_url(self, url: str) -> str: """Delegate to the shared vLLM helper to normalize the base URL.""" base_url_fn_obj = getattr(self._vllm_helper, "vllm_base_url", None) def _fallback_normalized(value: str) -> str: resolved = self._invoke_helper("_vllm_base_url", value) return str(resolved) if resolved is not None else value normalized_fn: Callable[[str], str] = ( cast(Callable[[str], str], base_url_fn_obj) if callable(base_url_fn_obj) else _fallback_normalized ) return normalized_fn(url) def _ensure_vllm_client(self) -> bool: """Instantiate the TRL VLLMClient when weight sync is enabled.""" try: helpers_mod = sys.modules.get( type(self).__module__ ) or importlib.import_module("maxent_grpo.training.rollout.helpers") except ImportError: helpers_mod = None import_fn = getattr( self, "_import_vllm_client_cls", getattr(helpers_mod, "_import_vllm_client_cls", _import_vllm_client_cls), ) ctx = self.ctx if not getattr(ctx, "vllm_sync_weights", False) or not getattr( ctx.accelerator, "is_main_process", False ): return False if self._vllm_client is not None and self._vllm_sync_ready: return True client_cls = import_fn() if client_cls is None: try: client_cls = import_fn(_optional_import) except TypeError: client_cls = None if client_cls is None or not callable(client_cls): self._vllm_sync_ready = False return 
False try: base_url = self._vllm_base_url(ctx.vllm_url) LOG.info( "vLLM client NCCL config | base_url=%s | vllm_url=%s | group_port=%s | NCCL_SOCKET_IFNAME=%s | MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME=%s | NCCL_P2P_DISABLE=%s | NCCL_IB_DISABLE=%s", base_url, ctx.vllm_url, _resolve_vllm_group_port(), os.getenv("NCCL_SOCKET_IFNAME", ""), os.getenv("MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME", ""), os.getenv("NCCL_P2P_DISABLE", ""), os.getenv("NCCL_IB_DISABLE", ""), ) try: group_port = _resolve_vllm_group_port() client_kwargs = {"base_url": base_url} if group_port is not None: client_kwargs["group_port"] = group_port client = client_cls(**client_kwargs) except TypeError: try: client = client_cls(base_url=base_url) except TypeError: client = client_cls() init = getattr(client, "init_communicator", None) if callable(init): overrides = _vllm_client_nccl_overrides(base_url) if overrides: LOG.info( "vLLM client NCCL overrides applied | %s", ", ".join(f"{k}={v}" for k, v in overrides.items()), ) else: LOG.info("vLLM client NCCL overrides applied | none") with _temporary_env(overrides): LOG.info( "vLLM client NCCL env effective | NCCL_SOCKET_IFNAME=%s | NCCL_P2P_DISABLE=%s | NCCL_IB_DISABLE=%s", os.getenv("NCCL_SOCKET_IFNAME", ""), os.getenv("NCCL_P2P_DISABLE", ""), os.getenv("NCCL_IB_DISABLE", ""), ) _shared_init_vllm_client_communicator(client, log=LOG.info) self._vllm_client = client self._vllm_sync_ready = True return True except ( OSError, RuntimeError, ValueError, TypeError, ) as exc: # pragma: no cover - defensive LOG.warning("Failed to initialize vLLMClient for weight sync: %s", exc) self._vllm_client = None self._vllm_sync_ready = False helper = getattr(self, "_vllm_helper", None) if helper is not None: try: setattr(helper, "_vllm_last_sync_ok", False) except Exception: pass return False def _maybe_sync_vllm_weights(self) -> None: """Push current model weights to the vLLM server.""" accelerator = self.ctx.accelerator progress_log = _progress_log_enabled() rank = 
getattr(accelerator, "process_index", None) world = getattr(accelerator, "num_processes", None) sync_start = time.monotonic() if progress_log: LOG.info("vLLM weight sync start | rank=%s/%s", rank, world) try: try: self._vllm_helper.maybe_sync_weights( ensure_client=self._ensure_vllm_client, sync_model=lambda model: self._sync_model_params_to_vllm( model, accelerator ), ) except TypeError: # Allow lightweight stubs without keyword support. self._vllm_helper.maybe_sync_weights() except Exception as exc: if progress_log: LOG.warning( "vLLM weight sync failed | rank=%s/%s | seconds=%.2f | error=%s", rank, world, time.monotonic() - sync_start, exc, ) raise if progress_log: LOG.info( "vLLM weight sync done | rank=%s/%s | seconds=%.2f", rank, world, time.monotonic() - sync_start, ) def _invoke_helper(self, attr: str, *args: Any, **kwargs: Any) -> Any: """Call a helper attribute if present, preferring public names when available.""" helper = getattr(self, "_vllm_helper", None) if helper is None: return None fn = getattr(helper, attr, None) if not callable(fn) and attr.startswith("_"): fn = getattr(helper, attr.lstrip("_"), None) if callable(fn): return fn(*args, **kwargs) return None def _sync_model_params_to_vllm( self, model: Any, accelerator: AcceleratorLike, ) -> None: """Best-effort parameter broadcast mirroring HF GRPO's vLLM path.""" del accelerator # handled internally by the shared helper result = self._invoke_helper("sync_model_params_to_vllm", model) if result is None: self._invoke_helper("_sync_model_params_to_vllm", model) def _push_param_to_vllm(self, name: str, param: Any) -> None: """Send a single parameter tensor to the vLLM client.""" self._invoke_helper("_push_param_to_vllm", name, param) def _reset_vllm_cache(self) -> None: """Reset prefix caches when the vLLM client exposes the helper.""" self._invoke_helper("_reset_vllm_cache") def _sync_fsdp_params(self, model: Any) -> None: """Iterate FSDP shards and push full parameters to vLLM.""" 
self._invoke_helper("_sync_fsdp_params", model) def _sync_peft_params( self, model: Any, gather_factory: Callable[[Sequence[Any]], AbstractContextManager[Any]], ) -> None: """Push merged PEFT adapter weights to vLLM.""" self._invoke_helper("_sync_peft_params", model, gather_factory) def _sync_standard_params( self, model: Any, gather_factory: Callable[[Sequence[Any]], AbstractContextManager[Any]], ) -> None: """Push standard (non-PEFT/FSDP) parameters to vLLM.""" self._invoke_helper("_sync_standard_params", model, gather_factory) def _resolve_vllm_round_limit(self, requested_n: int) -> int: """Decide how many vLLM rounds to run for the current request.""" result = self._invoke_helper("_resolve_vllm_round_limit", requested_n) if isinstance(result, numbers.Real): try: return int(float(result)) except (TypeError, ValueError): return requested_n return requested_n @staticmethod def _seed_generation_groups( prompt_count: int, grouped_comps: Optional[List[List[str]]], grouped_meta: Optional[List[List[Optional[Any]]]], ) -> Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]]: """Compatibility wrapper for older tests expecting this helper.""" return _seed_generation_groups_impl(prompt_count, grouped_comps, grouped_meta) @staticmethod def _retry_incomplete_prompts( helper: "VLLMGenerationMixin", prompts: List[str], generator: Callable[ [List[str], int, Optional[List[int]]], Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]], ], expected_generations: int, aggregated_comps: List[List[str]], aggregated_meta: Optional[List[List[Optional[Any]]]], max_retry_rounds: Optional[int], ) -> Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]]: """Retry helpers retained for backwards compatibility with older tests.""" del helper # helper is unused but kept for signature compatibility. 
state = _AggregatedGenerationState(aggregated_comps, aggregated_meta) updated = _retry_incomplete_prompts_impl( prompts, generator, expected_generations, state, max_retry_rounds, ) return updated.completions, updated.metadata @staticmethod def _summarize_grouped(groups: List[List[str]], limit: int = 8) -> str: """Return a compact preview of grouped completions.""" summary_fn = getattr(VLLMGenerationHelper, "_summarize_grouped", None) if callable(summary_fn): return str(summary_fn(groups, limit)) truncated = groups[:limit] parts = [f"{idx}:{len(group)}" for idx, group in enumerate(truncated)] if len(groups) > limit: parts.append(f"+{len(groups) - limit} more") return " | ".join(parts) def _request_vllm_batch( self, pending_prompts: List[str], request_count: int, ) -> Tuple[ Optional[List[List[str]]], Optional[List[List[Optional[VLLMLogprobResult]]]], ]: """Request completions from vLLM for a subset of prompts.""" char_limit = self._prompt_char_limit() tokenizer = getattr(self.ctx, "tokenizer", None) max_tokens = getattr(self.ctx, "max_prompt_len", None) truncated = [ _truncate_prompt( prompt, char_limit, tokenizer=tokenizer, max_tokens=max_tokens, ) for prompt in pending_prompts ] response = self._invoke_vllm_requests(truncated, request_count) if response is None: return None, None grouped, grouped_meta, latency_ms = response self._record_vllm_latency(latency_ms) pending_count = len(pending_prompts) raw_group_count = len(grouped) if raw_group_count != pending_count: LOG.warning( "vLLM raw groups=%d for %d prompts (req_n=%d) | per-group preview: %s", raw_group_count, pending_count, request_count, self._summarize_grouped(grouped), ) coalesce_fn = getattr( self._vllm_helper, "_coalesce_grouped_outputs", self._coalesce_grouped_outputs, ) grouped, grouped_meta = coalesce_fn( grouped, pending_count, request_count, meta=grouped_meta, ) if len(grouped) == pending_count: LOG.warning( ( "vLLM grouped outputs normalized to %d prompts " "(req_n=%d) | per-prompt lengths=%s" ), 
len(grouped), request_count, [len(entry) for entry in grouped], ) return grouped, grouped_meta LOG.warning( "vLLM grouped outputs len=%d vs pending=%d | per-prompt lengths=%s", len(grouped), pending_count, [len(entry) for entry in grouped], ) return None, None def _record_vllm_latency(self, latency_ms: float) -> None: """Track latency metrics for successful vLLM invocations.""" self._invoke_helper("_record_vllm_latency", latency_ms) def _build_vllm_request_kwargs( self, prompts: List[str], request_count: int, ) -> Dict[str, Any]: """Assemble keyword arguments for ``safe_generate`` requests.""" kwargs = self._invoke_helper( "_build_vllm_request_kwargs", prompts, request_count ) return kwargs if isinstance(kwargs, dict) else {} def _invoke_vllm_requests( self, prompts: List[str], request_count: int, ) -> Optional[ Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]], float] ]: """Call vLLM with retries by splitting large prompt batches.""" try: helpers_mod = sys.modules.get( type(self).__module__ ) or importlib.import_module("maxent_grpo.training.rollout.helpers") except ImportError: helpers_mod = None safe_gen = getattr(helpers_mod, "safe_generate", safe_generate) set_safe = getattr(self._vllm_helper, "set_safe_generate", None) if callable(set_safe): set_safe(safe_gen) else: setattr(self._vllm_helper, "_safe_generate", safe_gen) set_time = getattr(self._vllm_helper, "set_time_provider", None) if callable(set_time): set_time(getattr(helpers_mod, "time", time)) else: setattr( self._vllm_helper, "_time", getattr(helpers_mod, "time", time), ) result = self._invoke_helper("_invoke_vllm_requests", prompts, request_count) return cast( Optional[ Tuple[ List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]], float, ] ], result, ) def _merge_vllm_results( self, state: _VLLMGenerationState, grouped: List[List[str]], grouped_meta: Optional[List[List[Optional[VLLMLogprobResult]]]], pending_indices: List[int], ) -> None: """Append vLLM outputs into 
the shared state aggregates.""" self._vllm_helper.merge_vllm_results( state, grouped, grouped_meta, pending_indices, ) def _backfill_missing( self, state: _VLLMGenerationState, missing_indices: List[int], ) -> None: """Generate missing completions locally when vLLM under-delivers.""" self._vllm_helper.set_fallback_generate(self._generate_local) self._vllm_helper.backfill_missing(state, missing_indices) def _record_vllm_failure( self, state: _VLLMGenerationState, missing_indices: List[int], ) -> None: """Log a warning when vLLM fails to deliver even after retries/backfill.""" self._vllm_helper.record_vllm_failure(state, missing_indices) @staticmethod def _coalesce_grouped_outputs( groups: List[List[str]], prompt_count: int, requested_n: int, meta: Optional[List[List[Optional[VLLMLogprobResult]]]] = None, ) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]: """Normalize grouped outputs when vLLM returns per-sample lists.""" return VLLMGenerationHelper.coalesce_grouped_outputs( groups, prompt_count, requested_n, meta ) @staticmethod def _merge_group_chunk( chunk: List[List[str]], meta_chunk: Optional[List[List[Optional[VLLMLogprobResult]]]], requested_n: int, ) -> Tuple[List[str], Optional[List[Optional[VLLMLogprobResult]]]]: """Merge consecutive micro-groups back into per-prompt lists.""" return VLLMGenerationHelper.merge_group_chunk(chunk, meta_chunk, requested_n) def _prepare_vllm_targets( self, prompts: List[str], num_samples: int, per_prompt_counts: Optional[List[int]], ) -> Tuple[List[str], List[int], Optional[List[int]]]: """Resolve target counts and optional dedup mapping for vLLM.""" return self._vllm_helper.prepare_vllm_targets( prompts, num_samples, per_prompt_counts ) def _run_vllm_rounds(self, state: _VLLMGenerationState) -> None: """Iteratively request completions until targets are satisfied.""" try: helpers_mod = sys.modules.get(type(self).__module__) if helpers_mod is None or not hasattr(helpers_mod, "time"): helpers_mod = 
importlib.import_module( "maxent_grpo.training.rollout.helpers" ) except ImportError: helpers_mod = helpers_mod if "helpers_mod" in locals() else None set_time = getattr(self._vllm_helper, "set_time_provider", None) if callable(set_time): set_time(getattr(helpers_mod, "time", time)) else: setattr(self._vllm_helper, "_time", getattr(helpers_mod, "time", time)) # Allow monkeypatched generator hooks to propagate into the helper. helper_exec = getattr(self._vllm_helper, "_execute_vllm_request", None) helper_exec_name = getattr( getattr(helper_exec, "__func__", helper_exec), "__name__", "" ) if not callable(helper_exec) or helper_exec_name == "_execute_vllm_request": set_exec = getattr(self._vllm_helper, "set_request_executor", None) if callable(set_exec): set_exec(self._execute_vllm_request) else: setattr( self._vllm_helper, "_execute_vllm_request", self._execute_vllm_request, ) helper_batch = getattr(self._vllm_helper, "_request_vllm_batch", None) helper_batch_name = getattr( getattr(helper_batch, "__func__", helper_batch), "__name__", "" ) if not callable(helper_batch) or helper_batch_name == "_request_vllm_batch": set_batcher = getattr(self._vllm_helper, "set_request_batcher", None) if callable(set_batcher): set_batcher(self._request_vllm_batch) else: setattr( self._vllm_helper, "_request_vllm_batch", self._request_vllm_batch ) set_fallback = getattr(self._vllm_helper, "set_fallback_generate", None) if callable(set_fallback): set_fallback(self._generate_local) else: setattr(self._vllm_helper, "_fallback_generate", self._generate_local) run_rounds = getattr(self._vllm_helper, "run_vllm_rounds", None) if callable(run_rounds): run_rounds(state) else: self._invoke_helper("_run_vllm_rounds", state) @staticmethod def _expand_dedup_results( grouped: List[List[str]], meta: Optional[List[List[Optional[VLLMLogprobResult]]]], mapping: Optional[List[int]], ) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]: """Expand de-duplicated prompts back to the 
original ordering.""" return VLLMGenerationHelper.expand_dedup_results(grouped, meta, mapping) def _generate_with_vllm( self, prompts: List[str], num_samples: int, per_prompt_counts: Optional[List[int]] = None, *, skip_sync: bool = False, ) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]: """Generate completions via vLLM, with dedupe/backoff handling.""" if not prompts: return [], None # Keep prompt truncation aligned with the legacy helper implementation. self.ctx.prompt_char_limit = self._prompt_char_limit() accelerator = self.ctx.accelerator set_fallback = getattr(self._vllm_helper, "set_fallback_generate", None) if callable(set_fallback): set_fallback(self._generate_local) else: setattr(self._vllm_helper, "_fallback_generate", self._generate_local) generate_fn = getattr(self._vllm_helper, "generate", None) if not callable(generate_fn): return [], None if skip_sync: helper = getattr(self, "_vllm_helper", None) sentinel = object() prev_sync = sentinel swapped = False if helper is not None: prev_sync = getattr(helper, "__dict__", {}).get( "maybe_sync_weights", sentinel ) if callable(getattr(helper, "maybe_sync_weights", None)): setattr(helper, "maybe_sync_weights", lambda *args, **kwargs: None) swapped = True try: result = generate_fn( prompts, num_samples, per_prompt_counts, ensure_client=None, sync_model=None, ) finally: if swapped and helper is not None: if prev_sync is sentinel: try: delattr(helper, "maybe_sync_weights") except AttributeError: pass else: setattr(helper, "maybe_sync_weights", prev_sync) else: result = generate_fn( prompts, num_samples, per_prompt_counts, ensure_client=self._ensure_vllm_client, sync_model=lambda model: self._sync_model_params_to_vllm( model, accelerator ), ) if isinstance(result, tuple) and len(result) == 2: grouped, meta = result if grouped is None: grouped = [] if isinstance(grouped, list): return cast(List[List[str]], grouped), cast( Optional[List[List[Optional[VLLMLogprobResult]]]], meta ) return 
[], None


def _execute_vllm_request(
    self,
    state: _VLLMGenerationState,
    pending_indices: List[int],
) -> bool:
    """Request completions for ``pending_indices`` through the vLLM helper.

    Delegates to ``self._vllm_helper._execute_vllm_request`` when the helper
    exposes it (the helper is described as grouping prompts by need bucket;
    that is not verifiable from here). Returns ``False`` when the hook is
    missing.
    """
    exec_fn = getattr(self._vllm_helper, "_execute_vllm_request", None)
    if callable(exec_fn):
        return bool(exec_fn(state, pending_indices))
    return False


def _flatten_prompts_for_broadcast(
    self,
    prompts: List[str],
    per_prompt_counts: Optional[List[int]] = None,
) -> Tuple[List[str], List[int], Optional[List[int]]]:
    """Flatten per-rank prompts for the single rank-0 vLLM call.

    Returns the helper's ``(flat_prompts, offsets, flat_counts)`` when it has
    the expected shape, otherwise ``(prompts, [], None)``.
    """
    result = self._invoke_helper(
        "_flatten_prompts_for_broadcast", prompts, per_prompt_counts
    )
    if isinstance(result, tuple) and len(result) == 3:
        flat_prompts, offsets, flat_counts = result
        if (
            isinstance(flat_prompts, list)
            and isinstance(offsets, list)
            and (flat_counts is None or isinstance(flat_counts, list))
        ):
            return (
                cast(List[str], flat_prompts),
                cast(List[int], offsets),
                cast(Optional[List[int]], flat_counts),
            )
    return prompts, [], None


def _broadcast_vllm_payload(
    self,
    flat_prompts: List[str],
    payload: List[Any],
) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]:
    """Broadcast ``[grouped_all, meta_all]`` from rank 0 and unpack it.

    ``payload`` is modified in place by the broadcast and must hold exactly
    two entries. A missing ``grouped_all`` becomes one empty list per flat
    prompt.
    """
    _broadcast_object_list(self.ctx.accelerator, payload, src=0)
    grouped_all, meta_all = payload
    if grouped_all is None:
        grouped_all = [[] for _ in flat_prompts]
    return grouped_all, meta_all


def _scatter_vllm_payload(
    self,
    flat_prompts: List[str],
    offsets: List[int],
    grouped_all: Optional[List[List[str]]],
    meta_all: Optional[List[List[Optional[VLLMLogprobResult]]]],
) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]:
    """Scatter per-rank slices instead of broadcasting full completions.

    Delegates to the helper. A ``None`` grouped result becomes ``[]``, and a
    malformed helper result becomes ``([], None)``.
    """
    result = self._invoke_helper(
        "_scatter_vllm_payload", flat_prompts, offsets, grouped_all, meta_all
    )
    if isinstance(result, tuple) and len(result) == 2:
        grouped, meta = result
        if grouped is None:
            return [], cast(Optional[List[List[Optional[VLLMLogprobResult]]]], meta)
        if isinstance(grouped, list):
            return cast(List[List[str]], grouped), cast(
                Optional[List[List[Optional[VLLMLogprobResult]]]], meta
            )
        return grouped, meta
    return [], None


def _pluck_rank_outputs(
    self,
    grouped_all: List[List[str]],
    meta_all: Optional[List[List[Optional[VLLMLogprobResult]]]],
    offsets: List[int],
    prompts: List[str],
) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]:
    """Select this rank's outputs from the full gathered result via the helper.

    A ``None`` grouped result becomes ``[]``, and a malformed helper result
    becomes ``([], None)``.
    """
    result = self._invoke_helper(
        "_pluck_rank_outputs", grouped_all, meta_all, offsets, prompts
    )
    if isinstance(result, tuple) and len(result) == 2:
        grouped, meta = result
        if grouped is None:
            return [], cast(Optional[List[List[Optional[VLLMLogprobResult]]]], meta)
        if isinstance(grouped, list):
            return cast(List[List[str]], grouped), cast(
                Optional[List[List[Optional[VLLMLogprobResult]]]], meta
            )
        return grouped, meta
    return [], None


def _generate_vllm_collective(
    self,
    prompts: List[str],
    num_samples: int,
    per_prompt_counts: Optional[List[int]] = None,
) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]:
    """Run vLLM once on rank 0 and scatter results back to all ranks.

    In multi-process runs every rank must call this method. It issues
    collectives that all processes join:

    - a broadcast of rank 0's weight-sync decision;
    - the optional weight sync;
    - a broadcast of rank 0's generation status;
    - the result scatter.

    While the method runs, ``ctx.vllm_disable_local_fallback`` is forced to
    ``True``. A vLLM failure on rank 0 then surfaces as an exception instead
    of a rank-0-only local fallback. The caller's original value is restored
    afterwards. It is also the value consulted when deciding whether every
    rank may fall back to local generation.

    :returns: ``(grouped, meta)`` for this rank's ``prompts``.
    :raises RuntimeError: if rank 0's vLLM call failed and the caller had
        disabled local fallback.
    """
    self.ctx.prompt_char_limit = self._prompt_char_limit()
    accelerator = self.ctx.accelerator
    if getattr(accelerator, "num_processes", 1) <= 1:
        return self._generate_with_vllm(prompts, num_samples, per_prompt_counts)

    def _ready_dist() -> Any:
        """Return ``torch.distributed`` when it can broadcast objects, else ``None``."""
        dist_mod = getattr(torch, "distributed", None)
        if (
            dist_mod is not None
            and callable(getattr(dist_mod, "is_available", None))
            and callable(getattr(dist_mod, "is_initialized", None))
            and dist_mod.is_available()
            and dist_mod.is_initialized()
            and callable(getattr(dist_mod, "broadcast_object_list", None))
        ):
            return dist_mod
        return None

    is_main = bool(getattr(accelerator, "is_main_process", False))

    # Remember the caller's setting, distinguishing "absent" from "None", so
    # it can be restored exactly and used for the all-rank fallback decision.
    missing = object()
    prev_disable_fallback = getattr(self.ctx, "vllm_disable_local_fallback", missing)
    caller_disabled_fallback = prev_disable_fallback is not missing and bool(
        prev_disable_fallback
    )
    forced_flag = False
    try:
        # Make the helper raise instead of generating locally on rank 0 only;
        # a failure is then handled below consistently on every rank.
        setattr(self.ctx, "vllm_disable_local_fallback", True)
        forced_flag = True
    except Exception:
        pass

    try:
        # Weight sync under ZeRO-3 uses collective gathers; run it on all ranks
        # before non-main processes block waiting for the scatter payload.
        vllm_helper = getattr(self, "_vllm_helper", None)
        maybe_sync = (
            getattr(vllm_helper, "maybe_sync_weights", None) if vllm_helper else None
        )
        configured_sync = bool(getattr(self.ctx, "vllm_sync_weights", False))
        sync_weights = configured_sync
        dist_mod = _ready_dist()
        if dist_mod is not None:
            # Rank 0's decision wins so every rank agrees on entering the sync.
            payload = [sync_weights] if is_main else [False]
            dist_mod.broadcast_object_list(payload, src=0)
            sync_weights = bool(payload[0])
        if sync_weights != configured_sync:
            try:
                setattr(self.ctx, "vllm_sync_weights", sync_weights)
            except Exception:
                pass
        if callable(maybe_sync) and sync_weights:
            try:
                maybe_sync(
                    ensure_client=self._ensure_vllm_client,
                    sync_model=lambda model: self._sync_model_params_to_vllm(
                        model, accelerator
                    ),
                )
            except TypeError:
                # Helper variant without the keyword hooks.
                maybe_sync()

        flat_prompts, offsets, flat_counts = self._flatten_prompts_for_broadcast(
            prompts,
            per_prompt_counts,
        )
        grouped_all = None
        meta_all = None
        status_ok = True
        status_err: Optional[str] = None
        if is_main:
            try:
                grouped_all, meta_all = self._generate_with_vllm(
                    flat_prompts,
                    num_samples,
                    flat_counts,
                    # Skip helper-side sync in collective mode; we already handle
                    # weight sync (or intentionally skip it) above on all ranks.
                    skip_sync=True,
                )
            except Exception as exc:
                if isinstance(exc, TypeError) and "skip_sync" in str(exc):
                    # Helper does not accept ``skip_sync``; retry without it.
                    try:
                        grouped_all, meta_all = self._generate_with_vllm(
                            flat_prompts,
                            num_samples,
                            flat_counts,
                        )
                    except Exception as retry_exc:
                        status_ok = False
                        status_err = str(retry_exc)
                else:
                    status_ok = False
                    status_err = str(exc)

        # Share rank 0's outcome so every rank takes the same branch below.
        dist_mod = _ready_dist()
        if dist_mod is not None:
            status_payload = (
                [{"ok": status_ok, "error": status_err}]
                if is_main
                else [{"ok": False, "error": None}]
            )
            dist_mod.broadcast_object_list(status_payload, src=0)
            status_ok = bool(status_payload[0].get("ok", False))
            status_err = status_payload[0].get("error")

        if not status_ok:
            # Consult the caller's own setting, not the value forced above;
            # otherwise the all-rank fallback below would be unreachable.
            if caller_disabled_fallback:
                msg = (
                    "vLLM collective generate failed and local fallback disabled: "
                    f"{status_err}"
                )
                if is_main:
                    LOG.error(msg)
                raise RuntimeError(msg)
            if is_main:
                LOG.warning(
                    "vLLM collective generate failed on rank 0; "
                    "falling back to local generation on all ranks: %s",
                    status_err,
                )
            return self._generate_local(prompts, num_samples, per_prompt_counts)

        scatter_result = self._scatter_vllm_payload(
            flat_prompts, offsets, grouped_all, meta_all
        )
        if isinstance(scatter_result, tuple):
            if len(scatter_result) == 2:
                grouped_res, meta_res = scatter_result
            else:
                grouped_res, meta_res = [], None
        else:
            grouped_res, meta_res = scatter_result, None
        if grouped_res is None:
            grouped_res = [[] for _ in prompts]
        return grouped_res, meta_res
    finally:
        # Only undo what we changed; restore the exact prior state.
        if forced_flag:
            try:
                if prev_disable_fallback is missing:
                    delattr(self.ctx, "vllm_disable_local_fallback")
                else:
                    setattr(
                        self.ctx, "vllm_disable_local_fallback", prev_disable_fallback
                    )
            except Exception:
                pass


def _vllm_collective_enabled(self) -> bool:
    """Return whether collective generation is enabled, per ``_use_vllm_collective()``."""
    return _use_vllm_collective()
def generate(
    self,
    prompts: List[str],
    num_samples: int,
    per_prompt_counts: Optional[List[int]] = None,
) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]:
    """Produce completions, preferring vLLM when configured.

    :param prompts: prompts to complete.
    :param num_samples: completions requested per prompt.
    :param per_prompt_counts: optional per-prompt overrides; must have the
        same length as ``prompts``.
    :returns: ``(grouped_completions, optional_metadata)``.
    :raises ValueError: if ``per_prompt_counts`` and ``prompts`` differ in length.

    The vLLM paths are taken even for an empty ``prompts`` list. Only the
    local path short-circuits on empty input.
    """
    counts_mismatch = per_prompt_counts is not None and len(per_prompt_counts) != len(
        prompts
    )
    if counts_mismatch:
        raise ValueError(
            "per_prompt_counts length must match prompts length in generate()"
        )
    if not self.ctx.use_vllm:
        if prompts:
            return self._generate_local(prompts, num_samples, per_prompt_counts)
        return [], None
    runner = (
        self._generate_vllm_collective
        if self._vllm_collective_enabled()
        else self._generate_with_vllm
    )
    return runner(prompts, num_samples, per_prompt_counts)
def _gather_object_list( accelerator: AcceleratorLike, value: List[Any] ) -> List[List[Any]]: """Gather Python lists across ranks with graceful Accelerate fallbacks.""" gather_fn = getattr(accelerator, "gather_object", None) if callable(gather_fn): gathered_obj: Any = gather_fn(value) if isinstance(gathered_obj, list): return cast(List[List[Any]], gathered_obj) return [value] if dist is not None and dist.is_available() and dist.is_initialized(): world_size = dist.get_world_size() gathered: List[List[str]] = [[] for _ in range(world_size)] dist.all_gather_object(gathered, value) return gathered # Single-process fallback return [value] def _broadcast_object_list( accelerator: AcceleratorLike, payload: List[Any], *, src: int = 0 ) -> None: """Broadcast python objects even when Accelerate lacks the helper.""" broadcast_fn = getattr(accelerator, "broadcast_object_list", None) if callable(broadcast_fn): broadcast_fn(payload, src) return if dist is not None and dist.is_available() and dist.is_initialized(): broadcast = getattr(dist, "broadcast_object_list", None) if callable(broadcast): broadcast(payload, src) def _scatter_object( accelerator: AcceleratorLike, input_list: Optional[List[Any]], *, src: int = 0, ) -> Any: """Scatter python objects from src to all ranks.""" mode = os.getenv("MAXENT_SCATTER_MODE", "").strip().lower() prefer_broadcast = mode in {"broadcast", "bcast", "broadcast_object_list"} if accelerator.num_processes <= 1: if input_list is None: return None return input_list[0] idx = getattr(accelerator, "process_index", None) try: if input_list is not None and isinstance(idx, int) and idx >= len(input_list): return None except (TypeError, ValueError): return None if dist is not None and dist.is_available() and dist.is_initialized(): # Prefer broadcast-based scatter when possible. Some environments # intermittently hang inside scatter_object_list; broadcasting the full # payload is slower but tends to be more reliable. 
broadcast_fn = getattr(dist, "broadcast_object_list", None) if prefer_broadcast and callable(broadcast_fn): try: world_size = int(dist.get_world_size()) except (RuntimeError, TypeError, ValueError): world_size = int(getattr(accelerator, "num_processes", 1) or 1) list_ok = input_list is None or ( isinstance(input_list, list) and len(input_list) == world_size ) if world_size > 0 and list_ok: payload = ( input_list if accelerator.process_index == src and input_list is not None else [None for _ in range(world_size)] ) try: broadcast_fn(payload, src) return payload[int(getattr(accelerator, "process_index", 0))] except (RuntimeError, TypeError, ValueError): return None scatter_fn = getattr(accelerator, "scatter_object", None) if callable(scatter_fn): return scatter_fn( input_list if accelerator.process_index == src else None, src=src, ) if dist is not None and dist.is_available() and dist.is_initialized(): broadcast_fn = getattr(dist, "broadcast_object_list", None) if callable(broadcast_fn): try: world_size = int(dist.get_world_size()) except (RuntimeError, TypeError, ValueError): world_size = int(getattr(accelerator, "num_processes", 1) or 1) list_ok = input_list is None or ( isinstance(input_list, list) and len(input_list) == world_size ) if world_size > 0 and list_ok: payload = ( input_list if accelerator.process_index == src and input_list is not None else [None for _ in range(world_size)] ) try: broadcast_fn(payload, src) return payload[int(getattr(accelerator, "process_index", 0))] except (RuntimeError, TypeError, ValueError): return None scatter_fn = getattr(dist, "scatter_object_list", None) if callable(scatter_fn): output: List[Any] = [None] try: scatter_fn( output, input_list if accelerator.process_index == src else None, src, ) except (RuntimeError, ValueError, TypeError): return None return output[0] return None # Fallback to best-effort local selection if no distributed backend is initialized. 
if input_list is None: return None if idx is None: return None try: if idx >= len(input_list): return None except (TypeError, ValueError): return None try: return input_list[idx] except (IndexError, TypeError): return None
def gather_object_list(
    accelerator: AcceleratorLike, value: List[Any]
) -> List[List[Any]]:
    """Collect ``value`` from every rank.

    Public entry point; all work is delegated to ``_gather_object_list``.
    """
    gathered = _gather_object_list(accelerator, value)
    return gathered
def broadcast_object_list(
    accelerator: AcceleratorLike, payload: List[Any], *, src: int = 0
) -> None:
    """Broadcast ``payload`` in place from rank ``src``.

    Public entry point delegating to ``_broadcast_object_list``. It always
    returns ``None``.
    """
    _broadcast_object_list(accelerator, payload, src=src)
def scatter_object(
    accelerator: AcceleratorLike,
    input_list: Optional[List[Any]],
    *,
    src: int = 0,
) -> Any:
    """Return this rank's entry of ``input_list`` as sent from rank ``src``.

    Public entry point; delegates to ``_scatter_object``.
    """
    result = _scatter_object(accelerator, input_list, src=src)
    return result
# Names re-exported by ``from ... import *``; order is kept stable.
__all__ = [
    "VLLMGenerationMixin",
    "_VLLMGenerationState",
    # Object collectives: private implementations and public aliases.
    "_broadcast_object_list",
    "broadcast_object_list",
    "_gather_object_list",
    "gather_object_list",
    "_import_vllm_client_cls",
    "_is_peft_model_safe",
    "dist",
    "_optional_import",
    "_scatter_object",
    "scatter_object",
    "_zero3_gather_factory",
]