"""Weight synchronization helpers split out from the main vLLM helper."""
from __future__ import annotations
# pylint: disable=broad-exception-caught
import inspect
import logging
import os
import socket
import time
from contextlib import contextmanager, nullcontext
from types import SimpleNamespace as _SimpleNamespace
from typing import Any, Callable, Iterable, Optional, Sequence, TYPE_CHECKING, cast
from maxent_grpo.training.generation import vllm_utils as _vllm_utils
from maxent_grpo.utils.imports import optional_import
from maxent_grpo.training.runtime import require_accelerator, require_torch
# Resolve torch and Accelerate through the project's runtime helpers rather
# than importing them directly. NOTE(review): the string argument looks like
# a caller label used for diagnostics -- confirm against the helpers.
torch = require_torch("generation_vllm")
Accelerator = require_accelerator("generation_vllm")
if TYPE_CHECKING: # pragma: no cover - hints only
    from accelerate import Accelerator as AcceleratorType
else: # pragma: no cover - runtime fallback
    # At runtime the alias is only used in annotations, so ``Any`` suffices.
    AcceleratorType = Any
# Module-level logger used by the helpers in this file.
LOG = logging.getLogger(__name__)
def _env_flag(name: str, default: bool) -> bool:
raw = os.getenv(name)
if raw is None:
return default
return raw.strip().lower() not in {"0", "false", "no", "off"}
def _env_int(name: str, default: int) -> int:
raw = os.getenv(name)
if raw is None:
return default
try:
return int(raw)
except (TypeError, ValueError):
return default
# Public alias of ``types.SimpleNamespace`` (imported above under a private
# name) so it is reachable as an attribute of this module.
SimpleNamespace = _SimpleNamespace # Exposed for tests that monkeypatch this module
def _mirror_log(message: str) -> None:
path = os.environ.get("MAXENT_VLLM_LOG_MIRROR_FILE")
if not path:
return
timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
try:
with open(path, "a", encoding="utf-8") as handle:
handle.write(f"[mirror {timestamp}] {message}\n")
except OSError:
LOG.debug("Failed to mirror vLLM sync logs to %s", path)
def _log_sync_info(template: str, *args: Any) -> None:
    """Emit an INFO record and copy the rendered text to the mirror file.

    :param template: ``%``-style format string.
    :type template: str
    :param args: Values interpolated into ``template``.
    :type args: Any
    """
    LOG.info(template, *args)
    try:
        mirrored = template % args
    except Exception:
        # Template/argument mismatch: mirror the raw template instead.
        mirrored = template
    _mirror_log(mirrored)
def _log_sync_warning(template: str, *args: Any) -> None:
    """Emit a WARNING record and copy the rendered text to the mirror file.

    :param template: ``%``-style format string.
    :type template: str
    :param args: Values interpolated into ``template``.
    :type args: Any
    """
    LOG.warning(template, *args)
    try:
        mirrored = template % args
    except Exception:
        # Template/argument mismatch: mirror the raw template instead.
        mirrored = template
    _mirror_log(mirrored)
@contextmanager
def _temporary_env(overrides: dict[str, str]) -> Iterable[None]:
if not overrides:
yield
return
previous: dict[str, Optional[str]] = {}
for key, value in overrides.items():
previous[key] = os.environ.get(key)
os.environ[key] = value
try:
yield
finally:
for key, prior in previous.items():
if prior is None:
os.environ.pop(key, None)
else:
os.environ[key] = prior
def _loopback_host(base_url: str) -> bool:
try:
from urllib.parse import urlparse
parsed = urlparse(base_url)
host = parsed.hostname or ""
except Exception:
host = ""
if not host:
host = base_url
host = host.strip().lower()
return host in {"localhost", "127.0.0.1", "::1"}
def _vllm_client_nccl_overrides(base_url: str) -> dict[str, str]:
overrides: dict[str, str] = {}
enable_overrides = str(
os.getenv("MAXENT_VLLM_CLIENT_NCCL_OVERRIDES", "0")
).strip().lower() in {"1", "true", "yes", "on"}
if not enable_overrides:
return overrides
if not _loopback_host(base_url):
explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME")
if explicit and "NCCL_SOCKET_IFNAME" not in os.environ:
overrides["NCCL_SOCKET_IFNAME"] = explicit
explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_P2P_DISABLE")
if explicit and "NCCL_P2P_DISABLE" not in os.environ:
overrides["NCCL_P2P_DISABLE"] = explicit
explicit = os.getenv("MAXENT_VLLM_CLIENT_NCCL_IB_DISABLE")
if explicit and "NCCL_IB_DISABLE" not in os.environ:
overrides["NCCL_IB_DISABLE"] = explicit
return overrides
if "NCCL_SOCKET_IFNAME" not in os.environ:
overrides["NCCL_SOCKET_IFNAME"] = os.getenv(
"MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME", "lo"
)
if "NCCL_P2P_DISABLE" not in os.environ:
overrides["NCCL_P2P_DISABLE"] = os.getenv(
"MAXENT_VLLM_CLIENT_NCCL_P2P_DISABLE", "1"
)
if "NCCL_IB_DISABLE" not in os.environ:
overrides["NCCL_IB_DISABLE"] = os.getenv(
"MAXENT_VLLM_CLIENT_NCCL_IB_DISABLE", "1"
)
return overrides
def _resolve_vllm_group_port() -> Optional[int]:
for key in ("VLLM_GROUP_PORT", "PORT_FOR_COMMUNICATION"):
value = os.environ.get(key)
if not value:
continue
try:
return int(value)
except ValueError:
LOG.warning("Invalid %s=%r; expected an integer.", key, value)
return None
def _optional_import(module_name: str) -> Any:
    """Resolve ``module_name`` through the shared optional-import helper.

    Kept as a module-level indirection so the other helpers in this file
    route all optional imports through one place.

    :param module_name: Dotted module path to import.
    :type module_name: str
    :returns: Whatever ``optional_import`` returns (the module, or ``None``
        when it is unavailable).
    :rtype: Any
    """
    loaded = optional_import(module_name)
    return loaded
def _import_vllm_client_cls(
    import_fn: Optional[Callable[[str], Any]] = None,
) -> Optional[type]:
    """Look up TRL's ``VLLMClient`` class via ``vllm_utils``.

    :param import_fn: Import helper handed to
        ``vllm_utils.import_vllm_client_cls``; a falsy value selects this
        module's ``_optional_import``.
    :type import_fn: Callable[[str], Any] | None
    :returns: The client class when the lookup succeeds, otherwise ``None``.
    :rtype: type | None
    """
    importer = import_fn if import_fn else _optional_import
    return _vllm_utils.import_vllm_client_cls(importer)
def _zero3_gather_factory(
    accelerator: AcceleratorType,
) -> Callable[[Sequence[Any]], Any]:
    """Build the parameter-gather context factory for ``accelerator``.

    Delegates to ``vllm_utils.zero3_gather_factory``, supplying this module's
    ``_optional_import`` as the import helper.

    :param accelerator: Accelerate instance exposing ``state.deepspeed_plugin``.
    :type accelerator: accelerate.Accelerator
    :returns: Callable that maps a sequence of parameters to a context
        manager (a gather context under ZeRO-3, a no-op otherwise).
    :rtype: Callable[[Sequence[Any]], Any]
    """
    factory = _vllm_utils.zero3_gather_factory(
        accelerator,
        import_fn=_optional_import,
    )
    return factory
def _is_peft_model_safe(target: Any) -> bool:
    """Ask ``accelerate.utils.is_peft_model`` whether ``target`` uses PEFT.

    Never raises for the expected failure modes: a missing ``accelerate``
    package, a missing/non-callable ``is_peft_model`` or a ``TypeError`` /
    ``AttributeError`` / ``ValueError`` from the check all yield ``False``.

    :param target: Model instance to inspect.
    :type target: Any
    :returns: Whether the model is reported as PEFT-wrapped.
    :rtype: bool
    """
    utils_mod = _optional_import("accelerate.utils")
    detector = None if utils_mod is None else getattr(utils_mod, "is_peft_model", None)
    if not callable(detector):
        return False
    try:
        verdict = detector(target)
        return bool(verdict)
    except (TypeError, AttributeError, ValueError):
        return False
class _ClientCallable:
"""Lightweight callable wrapper to keep static analyzers satisfied."""
def __init__(self, func: Callable[..., Any]) -> None:
"""Wrap a callable to guard attribute access in static typing.
:param func: Callable to wrap.
:type func: Callable[..., Any]
"""
self._func = func
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Forward the call to the wrapped function.
:param args: Positional arguments forwarded to ``func``.
:type args: tuple
:param kwargs: Keyword arguments forwarded to ``func``.
:type kwargs: dict
:returns: Result of the wrapped function.
:rtype: Any
"""
return self._func(*args, **kwargs)
[docs]
class VLLMWeightSyncMixin:
"""Group weight sync helpers separately from retry/resilience logic."""
_vllm_client: Any
_vllm_sync_ready: bool
_last_vllm_synced_step: Optional[int]
_last_vllm_param_version: Optional[int]
_fsdp_cls: Any
_gather_factory: Any
ctx: Any
@staticmethod
def _zero3_status_name(param: Any) -> Optional[str]:
"""Best-effort extraction of DeepSpeed ZeRO-3 status for a parameter."""
status = getattr(param, "ds_status", None)
if status is None:
return None
name = getattr(status, "name", None)
if isinstance(name, str) and name:
return name
try:
text = str(status)
except (TypeError, ValueError, RuntimeError):
return None
if "." in text:
text = text.rsplit(".", 1)[-1]
return text or None
def _zero3_param_ready_without_gather(self, param: Any) -> bool:
"""Return True if ZeRO-3 param is already available (or actively held)."""
status_name = self._zero3_status_name(param)
if status_name == "AVAILABLE":
return True
active = getattr(param, "ds_active_sub_modules", None)
try:
return bool(active)
except (TypeError, ValueError, RuntimeError):
return False
def _zero3_params_to_gather(self, params: Sequence[Any]) -> list[Any]:
"""Filter ZeRO-3 params that actually require a GatheredParameters context."""
to_gather: list[Any] = []
for param in params:
if param is None or not hasattr(param, "ds_id"):
continue
status_name = self._zero3_status_name(param)
if status_name == "INFLIGHT":
continue
if self._zero3_param_ready_without_gather(param):
continue
to_gather.append(param)
return to_gather
def _vllm_base_url(self, url: str) -> str:
"""Strip common ``/generate`` suffixes from the vLLM endpoint.
:param url: Full URL configured for the vLLM server.
:type url: str
:returns: Base URL without trailing ``/generate`` paths.
:rtype: str
"""
from urllib.parse import urlparse
try:
parsed = urlparse(url)
except ValueError:
parsed = None
if parsed is not None and parsed.scheme and parsed.netloc:
base = f"{parsed.scheme}://{parsed.netloc}"
return base.rstrip("/")
if "/generate" in url:
return url.split("/generate", 1)[0].rstrip("/")
return url.rstrip("/")
def _ensure_vllm_client(
    self, import_vllm_client_cls: Optional[Callable[[], Any]] = None
) -> bool:
    """Return True when the TRL VLLMClient is ready for weight sync.

    Returns ``False`` immediately when sync has been disabled on this
    instance, when ``ctx.vllm_sync_weights`` is falsy, or on non-main
    processes. An already initialized client is reused. Otherwise the client
    class is imported, instantiated and its communicator initialized under
    the temporary NCCL environment overrides.

    On ``OSError`` / ``RuntimeError`` / ``ValueError`` the failure is
    recorded in ``_vllm_last_sync_ok``, ``_vllm_last_sync_error`` and
    ``_vllm_sync_failures``; once the failure count reaches
    ``MAXENT_VLLM_SYNC_MAX_FAILURES`` (default 2, minimum 1) and
    ``MAXENT_VLLM_SYNC_DISABLE_ON_FAILURE`` is enabled (the default), sync is
    disabled for this instance and on ``ctx``.

    :param import_vllm_client_cls: Optional callable that imports and
        returns the TRL ``VLLMClient`` class.
    :type import_vllm_client_cls: Callable[[], Any] | None
    :returns: Whether weight sync is ready to proceed on this rank.
    :rtype: bool
    """
    ctx = self.ctx
    if getattr(self, "_vllm_sync_disabled", False):
        return False
    if not ctx.vllm_sync_weights or not ctx.accelerator.is_main_process:
        return False
    # Fast path: a previous call already built and initialized the client.
    if self._vllm_client is not None and self._vllm_sync_ready:
        return True
    # Importer precedence: explicit argument, then an attribute on the
    # instance, then the module-level helper.
    import_fn = import_vllm_client_cls or getattr(
        self, "_import_vllm_client_cls", _import_vllm_client_cls
    )
    client_cls = import_fn()
    if client_cls is None:
        LOG.warning(
            "vLLM weight sync requested but TRL VLLMClient is unavailable; skipping."
        )
        self._vllm_sync_ready = False
        return False
    if not callable(client_cls):
        self._vllm_client = None
        self._vllm_sync_ready = False
        return False
    try:
        base_url = self._vllm_base_url(ctx.vllm_url)
        # NOTE(review): the group port is resolved here for logging and again
        # just below, so an invalid value is warned about twice.
        LOG.info(
            "vLLM client NCCL config | base_url=%s | vllm_url=%s | group_port=%s | NCCL_SOCKET_IFNAME=%s | MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME=%s | NCCL_P2P_DISABLE=%s | NCCL_IB_DISABLE=%s",
            base_url,
            ctx.vllm_url,
            _resolve_vllm_group_port(),
            os.getenv("NCCL_SOCKET_IFNAME", ""),
            os.getenv("MAXENT_VLLM_CLIENT_NCCL_SOCKET_IFNAME", ""),
            os.getenv("NCCL_P2P_DISABLE", ""),
            os.getenv("NCCL_IB_DISABLE", ""),
        )
        group_port = _resolve_vllm_group_port()
        # Try progressively simpler constructor signatures; a TypeError is
        # taken to mean the keyword arguments are not accepted.
        try:
            client_kwargs = {"base_url": base_url}
            if group_port is not None:
                client_kwargs["group_port"] = group_port
            self._vllm_client = client_cls(**client_kwargs)
        except TypeError:
            try:
                self._vllm_client = client_cls(base_url=base_url)
            except TypeError:
                self._vllm_client = client_cls()
        # A client without ``init_communicator`` cannot be used for sync.
        init_attr = getattr(self._vllm_client, "init_communicator", None)
        if not callable(init_attr):
            LOG.warning(
                "vLLM weight sync requested but TRL VLLMClient is unavailable; skipping."
            )
            self._vllm_client = None
            self._vllm_sync_ready = False
            return False
        # Optional host override, resolved via DNS when possible.
        override_host = os.getenv("MAXENT_VLLM_CLIENT_HOST", "").strip()
        if override_host:
            try:
                resolved_override = socket.gethostbyname(override_host)
            except Exception:
                # Resolution failed: fall back to the name as given.
                resolved_override = override_host
            current_host = getattr(self._vllm_client, "host", None)
            if resolved_override != current_host:
                _log_sync_info(
                    "Overriding vLLM client host for init_communicator | current=%s | override=%s",
                    current_host,
                    resolved_override,
                )
                try:
                    setattr(self._vllm_client, "host", resolved_override)
                except Exception:
                    LOG.debug(
                        "Failed to override vLLM client host; proceeding with %s",
                        current_host,
                    )
        overrides = _vllm_client_nccl_overrides(base_url)
        if overrides:
            _log_sync_info(
                "vLLM client NCCL overrides applied | %s",
                ", ".join(f"{k}={v}" for k, v in overrides.items()),
            )
        else:
            _log_sync_info("vLLM client NCCL overrides applied | none")
        # The overrides are only in effect while the communicator is set up.
        with _temporary_env(overrides):
            _log_sync_info(
                "vLLM client NCCL env effective | NCCL_SOCKET_IFNAME=%s | NCCL_P2P_DISABLE=%s | NCCL_IB_DISABLE=%s",
                os.getenv("NCCL_SOCKET_IFNAME", ""),
                os.getenv("NCCL_P2P_DISABLE", ""),
                os.getenv("NCCL_IB_DISABLE", ""),
            )
            _vllm_utils.init_vllm_client_communicator(
                self._vllm_client, log=_log_sync_info
            )
        # Success: mark ready and clear the failure bookkeeping.
        self._vllm_sync_ready = True
        setattr(self, "_vllm_last_sync_ok", True)
        setattr(self, "_vllm_last_sync_error", None)
        setattr(self, "_vllm_sync_failures", 0)
        return True
    except (
        OSError,
        RuntimeError,
        ValueError,
    ) as exc: # pragma: no cover - network dependent
        LOG.warning("Failed to initialize vLLMClient for weight sync: %s", exc)
        setattr(self, "_vllm_last_sync_ok", False)
        setattr(self, "_vllm_last_sync_error", str(exc))
        failures = int(getattr(self, "_vllm_sync_failures", 0) or 0) + 1
        setattr(self, "_vllm_sync_failures", failures)
        if _env_flag("MAXENT_VLLM_SYNC_DISABLE_ON_FAILURE", True):
            max_fails = max(1, _env_int("MAXENT_VLLM_SYNC_MAX_FAILURES", 2))
            if failures >= max_fails:
                # Warn only on the transition to the disabled state.
                if not getattr(self, "_vllm_sync_disabled", False):
                    LOG.warning(
                        "Disabling vLLM weight sync after %d init failures.",
                        failures,
                    )
                setattr(self, "_vllm_sync_disabled", True)
                try:
                    ctx.vllm_sync_weights = False
                except Exception:
                    LOG.debug("Failed to update ctx.vllm_sync_weights; proceeding.")
        # Drop the half-initialized client so the next call starts fresh.
        self._vllm_client = None
        self._vllm_sync_ready = False
        return False
[docs]
def maybe_sync_weights(
    self,
    ensure_client: Optional[Callable[[], bool]] = None,
    sync_model: Optional[Callable[..., None]] = None,
) -> None:
    """Synchronize weights to the vLLM server if configured.

    Outline of the visible control flow:

    1. Return early when the model exposes parameters but none of them has
       ``requires_grad``.
    2. In multi-process runs, broadcast rank 0's "sync enabled" and "sync
       disabled" flags so every rank takes the same branch.
    3. Return early for a non-positive ``vllm_sync_interval_steps``.
    4. Rank 0 decides whether a sync is due (step interval, last synced step,
       parameter version signature); the decision is broadcast.
    5. Rank 0 prepares the client. Ranks without a ready client either run
       the gather-only path (ZeRO-3 with a ``sync_model`` callable) or just
       wait at the barrier; a rank with a ready client pushes the weights
       and then waits at the barrier.

    ``_vllm_sync_attempted`` is reset to ``False`` on entry and set to
    ``True`` on the main process once the client preparation starts.

    :param ensure_client: Optional callable that prepares the vLLM client.
    :type ensure_client: Callable[[], bool] | None
    :param sync_model: Optional callable invoked to push model weights.
    :type sync_model: Callable[[Any], None] | None
    """
    ctx = self.ctx
    sync_enabled = bool(getattr(ctx, "vllm_sync_weights", False))
    setattr(self, "_vllm_sync_attempted", False)
    model_obj = getattr(ctx, "model", None)
    params_fn = getattr(model_obj, "parameters", None)
    # Step 1: nothing to sync when every parameter is frozen. A parameter
    # without a ``requires_grad`` attribute counts as trainable.
    if callable(params_fn):
        try:
            params_iter = params_fn()
            if isinstance(params_iter, Iterable):
                params = [param for param in params_iter if param is not None]
                if params and not any(
                    bool(getattr(param, "requires_grad", True)) for param in params
                ):
                    # Warn once per instance, then drop to debug level.
                    if not getattr(self, "_vllm_warn_no_trainable", False):
                        LOG.warning(
                            "Skipping vLLM weight sync: no trainable parameters detected."
                        )
                        setattr(self, "_vllm_warn_no_trainable", True)
                    else:
                        LOG.debug(
                            "Skipping vLLM weight sync: no trainable parameters detected."
                        )
                    return
            else:
                LOG.debug(
                    "Skipping vLLM trainable-parameter check: model.parameters() returned non-iterable."
                )
        except (AttributeError, RuntimeError, TypeError, ValueError) as exc:
            LOG.debug(
                "Failed to inspect trainable parameters for vLLM sync: %s", exc
            )
    accelerator = ctx.accelerator
    dist = getattr(torch, "distributed", None)
    # Rank / world size are only used for log messages below; prefer the
    # accelerator's view and fall back to torch.distributed.
    rank = getattr(accelerator, "process_index", None)
    local_rank = getattr(accelerator, "local_process_index", None)
    if (
        rank is None
        and dist is not None
        and dist.is_available()
        and dist.is_initialized()
    ):
        try:
            rank = dist.get_rank()
        except Exception:
            rank = None
    world_size = getattr(accelerator, "num_processes", None)
    if (
        not world_size
        and dist is not None
        and dist.is_available()
        and dist.is_initialized()
    ):
        try:
            world_size = dist.get_world_size()
        except Exception:
            world_size = None
    if not world_size:
        world_size = 1
    is_main = getattr(accelerator, "is_main_process", True)
    # Keep vLLM sync enabled/disabled consistent across ranks to avoid
    # mismatched collective calls (e.g., during eval).
    if (
        getattr(accelerator, "num_processes", 1) > 1
        and dist is not None
        and callable(getattr(dist, "is_available", None))
        and callable(getattr(dist, "is_initialized", None))
        and dist.is_available()
        and dist.is_initialized()
        and callable(getattr(dist, "broadcast_object_list", None))
    ):
        # Non-main ranks send a placeholder; rank 0's value overwrites it.
        payload = [bool(sync_enabled)] if is_main else [False]
        dist.broadcast_object_list(payload, src=0)
        sync_enabled = bool(payload[0])
    if not sync_enabled:
        return
    # Same broadcast pattern for the "disabled after failures" flag.
    disable_sync = bool(getattr(self, "_vllm_sync_disabled", False))
    if (
        getattr(accelerator, "num_processes", 1) > 1
        and dist is not None
        and callable(getattr(dist, "is_available", None))
        and callable(getattr(dist, "is_initialized", None))
        and dist.is_available()
        and dist.is_initialized()
        and callable(getattr(dist, "broadcast_object_list", None))
    ):
        payload = [bool(disable_sync)] if is_main else [False]
        dist.broadcast_object_list(payload, src=0)
        disable_sync = bool(payload[0])
    # Remember rank 0's decision locally on ranks that had not seen it yet.
    if disable_sync and not getattr(self, "_vllm_sync_disabled", False):
        setattr(self, "_vllm_sync_disabled", True)
    if disable_sync:
        return
    current_step = ctx.generation_stats.get("current_step")
    # The interval may live on ctx directly or on ctx.training_args.
    sync_interval = getattr(ctx, "vllm_sync_interval_steps", None)
    if sync_interval is None:
        training_args = getattr(ctx, "training_args", None)
        sync_interval = getattr(training_args, "vllm_sync_interval_steps", None)
    if sync_interval is not None:
        try:
            sync_interval = int(sync_interval)
        except (TypeError, ValueError):
            LOG.warning(
                "Invalid vllm_sync_interval_steps=%s; ignoring.",
                sync_interval,
            )
            sync_interval = None
    # A non-positive interval turns syncing off (warn once, then debug).
    if sync_interval is not None and sync_interval <= 0:
        if not getattr(self, "_vllm_warn_sync_interval", False):
            LOG.warning(
                "Skipping vLLM weight sync: vllm_sync_interval_steps=%s.",
                sync_interval,
            )
            setattr(self, "_vllm_warn_sync_interval", True)
        else:
            LOG.debug(
                "Skipping vLLM weight sync: vllm_sync_interval_steps=%s.",
                sync_interval,
            )
        return
    ensure_fn = ensure_client or self._ensure_vllm_client
    # Decide once (on rank 0) whether to run the ZeRO-3 gather + sync path,
    # then broadcast to all ranks. This avoids deadlocks where non-main
    # ranks enter the gather path while rank 0 returns early.
    should_sync = True
    current_version_sig: Optional[int] = None
    if is_main and current_step is not None:
        try:
            last_synced = self._last_vllm_synced_step
            current_step_int = int(current_step)
            if (
                sync_interval is not None
                and last_synced is not None
                and current_step_int - int(last_synced) < sync_interval
            ):
                # Fewer than ``sync_interval`` steps since the last sync.
                should_sync = False
            else:
                # Otherwise sync unless this exact step was already synced.
                should_sync = last_synced != current_step_int
        except (TypeError, ValueError):
            # Step values that are not int-convertible: sync to be safe.
            should_sync = True
    if is_main and should_sync:
        # Skip when the parameter version signature has not changed since
        # the last successful push.
        current_version_sig = self._param_version_signature(model_obj)
        last_sig = getattr(self, "_last_vllm_param_version", None)
        if (
            current_version_sig is not None
            and last_sig is not None
            and current_version_sig == last_sig
        ):
            LOG.debug(
                "Skipping vLLM weight sync: parameter version signature unchanged."
            )
            should_sync = False
    if (
        getattr(accelerator, "num_processes", 1) > 1
        and dist is not None
        and callable(getattr(dist, "is_available", None))
        and callable(getattr(dist, "is_initialized", None))
        and dist.is_available()
        and dist.is_initialized()
        and callable(getattr(dist, "broadcast_object_list", None))
    ):
        payload = [bool(should_sync)]
        dist.broadcast_object_list(payload, src=0)
        should_sync = bool(payload[0])
    if not should_sync:
        return
    _log_sync_info(
        "vLLM weight sync check | step=%s | rank=%s/%s | local_rank=%s | is_main=%s | should_sync=%s",
        current_step,
        rank,
        world_size,
        local_rank,
        is_main,
        should_sync,
    )
    def _is_zero3(accel: Any) -> bool:
        # True when accel.state.deepspeed_plugin reports zero_stage == 3.
        ds_plugin = getattr(getattr(accel, "state", None), "deepspeed_plugin", None)
        try:
            return int(getattr(ds_plugin, "zero_stage", 0) or 0) == 3
        except (TypeError, ValueError):
            return False
    # Only the main process should talk to the vLLM HTTP client. Other ranks
    # still participate in ZeRO gathers so parameter states stay aligned.
    ready = False
    if is_main:
        setattr(self, "_vllm_sync_attempted", True)
        ensure_start = time.monotonic()
        ready = bool(ensure_fn())
        _log_sync_info(
            "vLLM weight sync ensure_client done | step=%s | rank=%s/%s | ready=%s | seconds=%.2f",
            current_step,
            rank,
            world_size,
            ready,
            time.monotonic() - ensure_start,
        )
    if not ready:
        # In collective vLLM generation, only rank 0 issues the vLLM request,
        # but ZeRO-3 parameter gathering is a collective op that requires
        # participation from every rank. Allow non-main ranks to run the
        # gather-only path (no client updates) to avoid deadlocks.
        if _is_zero3(accelerator) and callable(sync_model):
            try:
                try:
                    model = accelerator.unwrap_model(ctx.model)
                except (AttributeError, TypeError):
                    model = ctx.model
                start = time.monotonic()
                _log_sync_info(
                    "vLLM weight sync gather-only start | step=%s | rank=%s/%s | local_rank=%s",
                    current_step,
                    rank,
                    world_size,
                    local_rank,
                )
                sync_model(model)
                _log_sync_info(
                    "vLLM weight sync gather-only done | step=%s | rank=%s/%s | seconds=%.2f",
                    current_step,
                    rank,
                    world_size,
                    time.monotonic() - start,
                )
            except (RuntimeError, ValueError, TypeError) as exc:
                _log_sync_warning("vLLM weight sync (gather-only) failed: %s", exc)
            # Barrier so every rank leaves the sync step together.
            wait_for_all = getattr(accelerator, "wait_for_everyone", None)
            if callable(wait_for_all):
                wait_start = time.monotonic()
                _log_sync_info(
                    "vLLM weight sync gather-only wait_for_everyone start | step=%s | rank=%s/%s",
                    current_step,
                    rank,
                    world_size,
                )
                wait_for_all()
                _log_sync_info(
                    "vLLM weight sync gather-only wait_for_everyone done | step=%s | rank=%s/%s | seconds=%.2f",
                    current_step,
                    rank,
                    world_size,
                    time.monotonic() - wait_start,
                )
            if is_main:
                _log_sync_info(
                    "vLLM weight sync barrier complete | step=%s | rank=%s/%s | next=vLLM request",
                    current_step,
                    rank,
                    world_size,
                )
        else:
            # No gather needed on this rank: only join the barrier.
            wait_for_all = getattr(accelerator, "wait_for_everyone", None)
            if callable(wait_for_all):
                wait_start = time.monotonic()
                _log_sync_info(
                    "vLLM weight sync wait_for_everyone start | step=%s | rank=%s/%s",
                    current_step,
                    rank,
                    world_size,
                )
                wait_for_all()
                _log_sync_info(
                    "vLLM weight sync wait_for_everyone done | step=%s | rank=%s/%s | seconds=%.2f",
                    current_step,
                    rank,
                    world_size,
                    time.monotonic() - wait_start,
                )
        return
    # NOTE(review): this early return skips the barrier below while other
    # ranks may be waiting in it; it looks unreachable in practice because
    # ``should_sync`` already covers the same condition -- confirm.
    if current_step is not None and self._last_vllm_synced_step == int(
        current_step
    ):
        return
    start = time.monotonic()
    _log_sync_info(
        "vLLM weight sync push start | step=%s | rank=%s/%s | local_rank=%s",
        current_step,
        rank,
        world_size,
        local_rank,
    )
    try:
        model = accelerator.unwrap_model(ctx.model)
    except (AttributeError, TypeError):
        model = ctx.model
    sync_fn = sync_model or self._sync_model_params_to_vllm
    visited: set[str] = set()
    try:
        # Pass ``visited`` only to sync callables that declare it.
        try:
            sig = inspect.signature(sync_fn)
            accepts_visited = "visited" in sig.parameters
        except (TypeError, ValueError):
            accepts_visited = False
        if accepts_visited:
            sync_fn(model, visited=visited)
        else:
            sync_fn(model)
        # Bookkeeping happens only after a push that did not raise.
        stats = ctx.generation_stats
        stats["vllm_weight_syncs"] = int(stats.get("vllm_weight_syncs", 0)) + 1
        if current_step is not None:
            self._last_vllm_synced_step = int(current_step)
        if is_main and current_version_sig is not None:
            self._last_vllm_param_version = current_version_sig
    except (
        RuntimeError,
        ValueError,
    ) as exc: # pragma: no cover - runtime dependent
        _log_sync_warning("Skipping vLLM weight sync due to error: %s", exc)
    else:
        elapsed = time.monotonic() - start
        _log_sync_info(
            "vLLM weight sync push done | step=%s | rank=%s/%s | seconds=%.2f",
            current_step,
            rank,
            world_size,
            elapsed,
        )
    # Join the same barrier the other ranks reach in the ``not ready`` branch.
    wait_for_all = getattr(accelerator, "wait_for_everyone", None)
    if callable(wait_for_all):
        wait_start = time.monotonic()
        _log_sync_info(
            "vLLM weight sync wait_for_everyone start | step=%s | rank=%s/%s",
            current_step,
            rank,
            world_size,
        )
        wait_for_all()
        _log_sync_info(
            "vLLM weight sync wait_for_everyone done | step=%s | rank=%s/%s | seconds=%.2f",
            current_step,
            rank,
            world_size,
            time.monotonic() - wait_start,
        )
    if is_main:
        _log_sync_info(
            "vLLM weight sync barrier complete | step=%s | rank=%s/%s | next=vLLM request",
            current_step,
            rank,
            world_size,
        )
def _param_version_signature(self, model: Any) -> Optional[int]:
"""Return a cheap signature based on torch parameter version counters."""
params_fn = getattr(model, "parameters", None)
if not callable(params_fn):
return None
total = 0
count = 0
try:
for param in cast(Iterable[Any], params_fn()):
if param is None:
continue
version = getattr(param, "_version", None)
if isinstance(version, int):
total += version
count += 1
except (AttributeError, RuntimeError, TypeError, ValueError):
return None
if count <= 0:
return None
return total
def _sync_log_param(self, name: str, param: Any) -> bool:
state = getattr(self, "_vllm_sync_log_state", None)
if not isinstance(state, dict):
return False
state["push_count"] = int(state.get("push_count", 0)) + 1
idx = int(state["push_count"])
log_every = int(state.get("log_every", 50))
if log_every > 0 and idx % log_every == 0:
_log_sync_info(
"vLLM weight sync param push | idx=%d | name=%s | shape=%s | dtype=%s",
idx,
name,
getattr(param, "shape", None),
getattr(param, "dtype", None),
)
max_params = state.get("max_params")
if isinstance(max_params, int) and max_params > 0 and idx >= max_params:
if not state.get("stop"):
_log_sync_warning(
"vLLM weight sync early stop | max_params=%d reached",
max_params,
)
state["stop"] = True
return True
return bool(state.get("stop"))
def _sync_log_should_stop(self) -> bool:
state = getattr(self, "_vllm_sync_log_state", None)
if not isinstance(state, dict):
return False
return bool(state.get("stop"))
def _client_callable(self, attr_name: str) -> Optional[_ClientCallable]:
"""Return a callable attribute from the vLLM client if available.
:param attr_name: Attribute name to fetch from the client.
:type attr_name: str
:returns: Wrapped callable or ``None`` when missing.
:rtype: _ClientCallable | None
"""
client = self._vllm_client
if client is None:
return None
candidate = getattr(client, attr_name, None)
if not callable(candidate):
return None
return _ClientCallable(candidate)
def _flush_vllm_client(self) -> None:
"""Flush buffered vLLM updates when the client supports it."""
client = self._vllm_client
if client is None:
return
flush = getattr(client, "flush", None)
if not callable(flush):
return
try:
flush_fn = cast(Callable[[], None], flush)
flush_fn() # pylint: disable=not-callable
except (RuntimeError, ValueError, AttributeError, TypeError) as exc:
LOG.debug("Failed to flush vLLM client updates: %s", exc)
def _sync_model_params_to_vllm(
    self,
    model: Any,
    visited: Optional[set[str]] = None,
) -> None:
    """Push model parameters to the vLLM side, handling FSDP/PEFT cases.

    The strategy is chosen in this order:

    1. ``model`` is an instance of the detected FSDP class: push the
       parameters of each direct child (or of the model itself when it has
       no children).
    2. ``model`` exposes a callable ``summon_full_params``: walk the module
       tree recursively, then make a final pass over the root parameters.
    3. The PEFT check returns true: delegate to ``_sync_peft_params``.
    4. Otherwise delegate to ``_sync_standard_params``.

    Wrapper prefixes (``_fsdp_wrapped_module.``,
    ``_checkpoint_wrapped_module.``) are stripped from parameter names, and
    ``visited`` prevents pushing the same cleaned name twice. Progress
    logging and the optional early stop are driven by
    ``MAXENT_VLLM_SYNC_LOG_EVERY`` and ``MAXENT_VLLM_SYNC_MAX_PARAMS`` via a
    temporary ``_vllm_sync_log_state`` attribute that is removed on exit.

    :param model: Model instance whose parameters should be synchronized.
    :type model: Any
    :param visited: Optional set of parameter names that have already been
        synchronized (used for recursion).
    :type visited: set[str] | None
    """
    fsdp_cls = self._fsdp_cls
    visited = visited if visited is not None else set()
    # Logging cadence: at least 1, default 50 when unset or unparseable.
    log_every_raw = os.getenv("MAXENT_VLLM_SYNC_LOG_EVERY", "50")
    try:
        log_every = max(1, int(log_every_raw))
    except (TypeError, ValueError):
        log_every = 50
    # Optional cap on the number of pushed parameters (None = unlimited).
    max_params_raw = os.getenv("MAXENT_VLLM_SYNC_MAX_PARAMS", "")
    max_params: Optional[int]
    if max_params_raw:
        try:
            max_params = max(1, int(max_params_raw))
        except (TypeError, ValueError):
            max_params = None
    else:
        max_params = None
    # Per-call state consumed by _sync_log_param / _sync_log_should_stop.
    self._vllm_sync_log_state = {
        "log_every": log_every,
        "max_params": max_params,
        "push_count": 0,
        "stop": False,
    }
    try:
        def _has_summon_full_params(target: Any) -> bool:
            # Attribute access itself may raise on wrapped modules.
            try:
                return callable(getattr(target, "summon_full_params"))
            except (AttributeError, RuntimeError, TypeError, ValueError):
                return False
        class_summon = getattr(type(model), "summon_full_params", None)
        # No cached FSDP class: look it up on torch.distributed.fsdp.
        if fsdp_cls is None:
            fsdp_mod = getattr(getattr(torch, "distributed", None), "fsdp", None)
            fsdp_cls = (
                getattr(fsdp_mod, "FullyShardedDataParallel", None)
                if fsdp_mod
                else None
            )
        has_summon = _has_summon_full_params(model)
        # A model type that defines summon_full_params itself is treated as
        # the FSDP class when it is not an instance of the known one.
        if has_summon and callable(class_summon):
            if fsdp_cls is None or not isinstance(model, fsdp_cls):
                fsdp_cls = type(model)
        # Cache the detected class for later calls.
        if fsdp_cls is not None and (
            self._fsdp_cls is None or not isinstance(model, self._fsdp_cls)
        ):
            self._fsdp_cls = fsdp_cls
        if fsdp_cls is not None and isinstance(model, fsdp_cls):
            named_children = getattr(model, "named_children", None)
            children = (
                list(cast(Iterable[tuple[str, Any]], named_children()))
                if callable(named_children)
                else []
            )
            # Fall back to the model itself when it has no children.
            modules_to_sync = children or [("", model)]
            for base_name, base_module in modules_to_sync:
                named_params = getattr(base_module, "named_parameters", None)
                if not callable(named_params):
                    continue
                for pname, param in cast(Iterable[tuple[str, Any]], named_params()):
                    full_name = f"{base_name}.{pname}" if base_name else pname
                    for extra in (
                        "_fsdp_wrapped_module.",
                        "_checkpoint_wrapped_module.",
                    ):
                        full_name = full_name.replace(extra, "")
                    if full_name in visited:
                        continue
                    visited.add(full_name)
                    self._sync_log_param(full_name, param)
                    self._push_param_to_vllm(full_name, param)
                    if self._sync_log_should_stop():
                        self._reset_vllm_cache()
                        return
            self._reset_vllm_cache()
            return
        # NOTE(review): the next three statements largely repeat the FSDP
        # detection and push loop above; the second loop only runs when the
        # class defines summon_full_params but the instance check above did
        # not promote it -- confirm whether this branch is still needed.
        if not has_summon:
            has_summon = _has_summon_full_params(model)
        if callable(class_summon) and (
            fsdp_cls is None or not isinstance(model, fsdp_cls)
        ):
            fsdp_cls = type(model)
            if self._fsdp_cls is None or not isinstance(model, self._fsdp_cls):
                self._fsdp_cls = fsdp_cls
        if fsdp_cls is not None and isinstance(model, fsdp_cls):
            named_children = getattr(model, "named_children", None)
            children = (
                list(cast(Iterable[tuple[str, Any]], named_children()))
                if callable(named_children)
                else []
            )
            modules_to_sync = children or [("", model)]
            for base_name, base_module in modules_to_sync:
                named_params = getattr(base_module, "named_parameters", None)
                if not callable(named_params):
                    continue
                for pname, param in cast(Iterable[tuple[str, Any]], named_params()):
                    full_name = f"{base_name}.{pname}" if base_name else pname
                    for extra in (
                        "_fsdp_wrapped_module.",
                        "_checkpoint_wrapped_module.",
                    ):
                        full_name = full_name.replace(extra, "")
                    if full_name in visited:
                        continue
                    visited.add(full_name)
                    self._sync_log_param(full_name, param)
                    self._push_param_to_vllm(full_name, param)
                    if self._sync_log_should_stop():
                        self._reset_vllm_cache()
                        return
            self._reset_vllm_cache()
            return
        if has_summon:
            def _walk(module: Any, prefix: str = "") -> None:
                # Depth-first push of a module's parameters, then its children.
                named_children = getattr(module, "named_children", None)
                children = (
                    list(cast(Iterable[tuple[str, Any]], named_children()))
                    if callable(named_children)
                    else []
                )
                named_params = getattr(module, "named_parameters", None)
                if callable(named_params):
                    for raw_name, param in cast(
                        Iterable[tuple[str, Any]], named_params()
                    ):
                        if param is None:
                            continue
                        clean = raw_name
                        for extra in (
                            "_fsdp_wrapped_module.",
                            "_checkpoint_wrapped_module.",
                        ):
                            clean = clean.replace(extra, "")
                        full_name = f"{prefix}.{clean}" if prefix else clean
                        # On modules with children, names that start with a
                        # wrapper prefix are skipped at this level.
                        if children and any(
                            raw_name.startswith(extra)
                            for extra in (
                                "_fsdp_wrapped_module.",
                                "_checkpoint_wrapped_module.",
                            )
                        ):
                            continue
                        if full_name in visited:
                            continue
                        visited.add(full_name)
                        self._sync_log_param(full_name, param)
                        self._push_param_to_vllm(full_name, param)
                        if self._sync_log_should_stop():
                            return
                for child_name, child in children:
                    child_prefix = (
                        f"{prefix}.{child_name}" if prefix else child_name
                    )
                    _walk(child, child_prefix)
            _walk(model)
            # Final pass over the root's parameters for names not yet pushed.
            root_params = getattr(model, "named_parameters", None)
            if callable(root_params):
                for raw_name, param in cast(
                    Iterable[tuple[str, Any]], root_params()
                ):
                    clean = raw_name
                    for extra in (
                        "_fsdp_wrapped_module.",
                        "_checkpoint_wrapped_module.",
                    ):
                        clean = clean.replace(extra, "")
                    if any(
                        raw_name.startswith(extra)
                        for extra in (
                            "_fsdp_wrapped_module.",
                            "_checkpoint_wrapped_module.",
                        )
                    ):
                        continue
                    if clean in visited:
                        continue
                    visited.add(clean)
                    self._sync_log_param(clean, param)
                    self._push_param_to_vllm(clean, param)
                    if self._sync_log_should_stop():
                        self._reset_vllm_cache()
                        return
            self._reset_vllm_cache()
            return
        # An instance-level ``_is_peft_model_safe`` takes precedence over
        # the module-level helper.
        is_peft_fn = getattr(self, "_is_peft_model_safe", _is_peft_model_safe)
        if is_peft_fn(model):
            self._sync_peft_params(model)
            self._reset_vllm_cache()
            return
        self._sync_standard_params(model)
        self._reset_vllm_cache()
    finally:
        # Always drop the per-call logging state; if the attribute cannot
        # be deleted, neutralize it instead.
        try:
            delattr(self, "_vllm_sync_log_state")
        except AttributeError:
            self._vllm_sync_log_state = None
def _push_param_to_vllm(self, name: str, param: Any) -> None:
"""Send a single parameter tensor to the vLLM client if available.
:param name: Fully qualified parameter name.
:type name: str
:param param: Tensor to push to the vLLM server.
:type param: Any
"""
if param is None:
return
update_fn = self._client_callable("update_named_param")
if update_fn is None:
return
try:
update_fn(name, param.data)
except (
RuntimeError,
ValueError,
) as exc: # pragma: no cover - network dependent
LOG.warning("Failed to push param %s to vLLM: %s", name, exc)
[docs]
def push_param_to_vllm(self, name: str, param: Any) -> None:
    """Public entry point for pushing a single parameter to vLLM.

    :param name: Fully qualified parameter name.
    :type name: str
    :param param: Parameter handed on to ``_push_param_to_vllm``.
    :type param: Any
    """
    self._push_param_to_vllm(name=name, param=param)
def _reset_vllm_cache(self) -> None:
"""Reset the vLLM prefix cache if the client exposes the hook."""
self._flush_vllm_client()
reset_fn = self._client_callable("reset_prefix_cache")
if reset_fn is None:
return
try:
reset_fn()
except (RuntimeError, ValueError, AttributeError):
return
[docs]
def reset_vllm_cache(self) -> None:
    """Reset the vLLM prefix cache by delegating to ``_reset_vllm_cache``."""
    reset = self._reset_vllm_cache
    reset()
def _sync_standard_params(
self,
model: Any,
gather_factory: Optional[Callable[[Sequence[Any]], Any]] = None,
prefix: str = "",
visited: Optional[set[str]] = None,
) -> None:
"""Synchronize standard (non-FSDP/PEFT) model parameters.
:param model: Model instance whose parameters are being pushed.
:type model: Any
:param gather_factory: Optional context manager factory for ZeRO-3.
:type gather_factory: Callable[[Sequence[Any]], Any] | None
:param prefix: Name prefix to prepend to parameters.
:type prefix: str
:param visited: Parameter names already synced to avoid duplicates.
:type visited: set[str] | None
"""
factory = gather_factory or self._gather_factory or (lambda _p: nullcontext())
visited = visited if visited is not None else set()
named_params = getattr(model, "named_parameters", None)
iterator: list[tuple[str, Any]] = []
if callable(named_params):
try:
iterator = list(
cast(Iterable[tuple[str, Any]], named_params(recurse=False))
)
except TypeError:
iterator = list(cast(Iterable[tuple[str, Any]], named_params()))
params = [param for _, param in iterator if param is not None]
params_to_gather = self._zero3_params_to_gather(params)
if not params_to_gather and params:
saw_zero3 = any(self._zero3_status_name(p) is not None for p in params)
if not saw_zero3:
params_to_gather = params
try:
ctx = factory(params_to_gather)
except NameError:
ctx = nullcontext()
gather_start = time.monotonic()
_log_sync_info(
"vLLM weight sync gather ctx start | params=%d | prefix=%s",
len(params_to_gather),
prefix,
)
with ctx:
for name, param in iterator:
if param is None:
continue
if self._zero3_status_name(param) == "INFLIGHT":
continue
clean = f"{prefix}{name}" if prefix else name
for extra in (
"_fsdp_wrapped_module.",
"_checkpoint_wrapped_module.",
):
clean = clean.replace(extra, "")
if clean in visited:
continue
visited.add(clean)
self._sync_log_param(clean, param)
self._push_param_to_vllm(clean, param)
if self._sync_log_should_stop():
return
_log_sync_info(
"vLLM weight sync gather ctx done | params=%d | prefix=%s | seconds=%.2f",
len(params_to_gather),
prefix,
time.monotonic() - gather_start,
)
named_children = getattr(model, "named_children", None)
if callable(named_children):
for child_name, child in cast(Iterable[tuple[str, Any]], named_children()):
child_prefix = f"{prefix}{child_name}."
self._sync_standard_params(
child, gather_factory, child_prefix, visited=visited
)
def _sync_peft_params(
self,
model: Any,
gather_factory: Optional[Callable[[Sequence[Any]], Any]] = None,
) -> None:
"""Synchronize PEFT adapter parameters to the vLLM server.
:param model: PEFT model instance.
:type model: Any
:param gather_factory: Optional context manager factory for ZeRO-3.
:type gather_factory: Callable[[Sequence[Any]], Any] | None
"""
merge_fn = getattr(model, "merge_adapter", None)
unmerge_fn = getattr(model, "unmerge_adapter", None)
params = list(model.parameters())
factory = gather_factory or self._gather_factory or (lambda _p: nullcontext())
params_to_gather = self._zero3_params_to_gather(params)
if not params_to_gather and params:
saw_zero3 = any(self._zero3_status_name(p) is not None for p in params)
if not saw_zero3:
params_to_gather = params
gather_start = time.monotonic()
_log_sync_info(
"vLLM weight sync gather ctx start | params=%d | peft=true",
len(params_to_gather),
)
with factory(params_to_gather):
if callable(merge_fn):
merge_fn()
for name, param in model.named_parameters():
clean = (
name.replace("modules_to_save.default.", "")
.replace("base_model.model.", "")
.replace(".base_layer", "")
)
if getattr(model, "prefix", None) and str(model.prefix) in clean:
continue
if "original_module" in clean:
continue
self._sync_log_param(clean, param)
self._push_param_to_vllm(clean, param)
if self._sync_log_should_stop():
return
if callable(unmerge_fn):
unmerge_fn()
_log_sync_info(
"vLLM weight sync gather ctx done | params=%d | peft=true | seconds=%.2f",
len(params_to_gather),
time.monotonic() - gather_start,
)
def _sync_fsdp_params(
self,
module: Any,
gather_factory: Optional[Callable[[Sequence[Any]], Any]] = None,
prefix: str = "",
fsdp_cls: Any = None,
visited: Optional[set[str]] = None,
) -> None:
"""Synchronize parameters for FSDP-wrapped modules.
:param module: Module wrapped by FullyShardedDataParallel.
:type module: Any
:param gather_factory: Optional context manager factory for ZeRO-3.
:type gather_factory: Callable[[Sequence[Any]], Any] | None
:param prefix: Prefix to prepend to parameter names.
:type prefix: str
:param fsdp_cls: FSDP class used to detect wrapped modules.
:type fsdp_cls: Any
:param visited: Parameter names already synced to avoid duplicates.
:type visited: set[str] | None
"""
fsdp_cls = fsdp_cls or self._fsdp_cls
if fsdp_cls is None:
return
try:
params = list(module.parameters()) if hasattr(module, "parameters") else []
except AttributeError:
params = []
visited = visited or set()
factory = gather_factory or self._gather_factory or (lambda _p: nullcontext())
params_to_gather = self._zero3_params_to_gather(params)
if not params_to_gather and params:
saw_zero3 = any(self._zero3_status_name(p) is not None for p in params)
if not saw_zero3:
params_to_gather = params
gather_start = time.monotonic()
_log_sync_info(
"vLLM weight sync gather ctx start | params=%d | fsdp=true | prefix=%s",
len(params_to_gather),
prefix,
)
with factory(params_to_gather):
named_params = getattr(module, "named_parameters", None)
if callable(named_params):
for name, param in cast(Iterable[tuple[str, Any]], named_params()):
full_name = f"{prefix}{name}" if prefix else name
for extra in (
"_fsdp_wrapped_module.",
"_checkpoint_wrapped_module.",
):
full_name = full_name.replace(extra, "")
if full_name in visited:
continue
visited.add(full_name)
self._sync_log_param(full_name, param)
self._push_param_to_vllm(full_name, param)
if self._sync_log_should_stop():
return
named_children = getattr(module, "named_children", None)
if callable(named_children):
for child_name, child in cast(
Iterable[tuple[str, Any]], named_children()
):
child_prefix = f"{prefix}{child_name}."
self._sync_fsdp_params(
child,
gather_factory=gather_factory,
prefix=child_prefix,
fsdp_cls=fsdp_cls,
visited=visited,
)
_log_sync_info(
"vLLM weight sync gather ctx done | params=%d | fsdp=true | prefix=%s | seconds=%.2f",
len(params_to_gather),
prefix,
time.monotonic() - gather_start,
)
if isinstance(module, fsdp_cls) and callable(
getattr(module, "named_parameters", None)
):
summon_start = time.monotonic()
_log_sync_info(
"vLLM weight sync summon_full_params start | prefix=%s",
prefix,
)
with fsdp_cls.summon_full_params(module, recurse=False, writeback=False):
for pname, param in module.named_parameters():
full_name = f"{prefix}{pname}" if prefix else pname
for extra in (
"_fsdp_wrapped_module.",
"_checkpoint_wrapped_module.",
):
full_name = full_name.replace(extra, "")
if full_name in visited:
continue
visited.add(full_name)
self._sync_log_param(full_name, param)
self._push_param_to_vllm(full_name, param)
if self._sync_log_should_stop():
return
_log_sync_info(
"vLLM weight sync summon_full_params done | prefix=%s | seconds=%.2f",
prefix,
time.monotonic() - summon_start,
)
[docs]
def sync_fsdp_params(self, module: Any) -> None:
    """Synchronize FSDP parameters to vLLM via ``_sync_fsdp_params``.

    :param module: Module forwarded unchanged to the protected helper.
    :type module: Any
    """
    delegate = self._sync_fsdp_params
    delegate(module)
# Names exported by a star-import of this module. The underscore-prefixed
# helpers are listed explicitly because ``import *`` skips such names unless
# they appear in ``__all__`` (presumably so tests and sibling modules can
# reach them; confirm against the importing code).
__all__ = [
    "Accelerator",
    "VLLMWeightSyncMixin",
    "_ClientCallable",
    "_import_vllm_client_cls",
    "_is_peft_model_safe",
    "_optional_import",
    "_zero3_gather_factory",
]