Source code for maxent_grpo.training.rollout.vllm_colocate

"""In-process (colocated) vLLM generation helpers for the custom loop."""

from __future__ import annotations
# pylint: disable=broad-exception-caught,protected-access

import atexit
import faulthandler
import gc
import inspect
import logging
import multiprocessing as mp
import os
import socket
import sys
import tempfile
import threading
import time
import traceback
from typing import Any, Callable, Dict, List, Optional, Tuple

from maxent_grpo.training.generation.vocab_guard import (
    merge_invalid_token_block_logit_bias,
    resolve_blocked_token_ids,
    resolve_allowed_token_ids,
)
from maxent_grpo.training.patches.vllm import VLLMLogprobResult
from maxent_grpo.training.runtime.prompts import _truncate_prompt, PROMPT_CHAR_LIMIT
from maxent_grpo.utils.imports import optional_import

LOG = logging.getLogger(__name__)


def _parse_log_level(raw: str) -> Optional[int]:
    """Translate a log-level setting into a numeric ``logging`` level.

    Args:
        raw: Either an integer literal (``"10"``) or a level name such as
            ``"debug"`` (matched case-insensitively). Surrounding whitespace
            is ignored.

    Returns:
        The numeric level, or ``None`` when ``raw`` is blank or does not name
        a registered level.
    """
    value = raw.strip()
    if not value:
        return None
    try:
        return int(value)
    except ValueError:
        pass
    # ``logging.getLevelName`` is the public name -> number lookup: it returns
    # the int for a registered level name and a placeholder *string* for
    # anything else, so the private ``logging._nameToLevel`` table is not needed.
    resolved = logging.getLevelName(value.upper())
    return resolved if isinstance(resolved, int) else None


def _configure_colocate_logging() -> None:
    """Set this module's logger level from ``MAXENT_VLLM_COLOCATE_LOG_LEVEL``.

    With the variable unset or empty the logger is set to ``WARNING``. A value
    that cannot be parsed produces a warning and leaves the level untouched.
    """
    configured = os.getenv("MAXENT_VLLM_COLOCATE_LOG_LEVEL")
    if not configured:
        # Default to warnings-only to avoid excessive colocate logging once stable.
        LOG.setLevel(logging.WARNING)
        return
    parsed = _parse_log_level(configured)
    if not isinstance(parsed, int):
        LOG.warning(
            "Invalid MAXENT_VLLM_COLOCATE_LOG_LEVEL=%r; ignoring.", configured
        )
        return
    LOG.setLevel(parsed)


# Runs once at import time so the module logger's level is set before use.
_configure_colocate_logging()


def _env_float(name: str) -> Optional[float]:
    """Read environment variable ``name`` as a float.

    Returns ``None`` when the variable is unset or its value cannot be parsed
    (the latter is logged at debug level).
    """
    text = os.getenv(name)
    if text is not None:
        try:
            parsed = float(text)
        except (TypeError, ValueError):
            LOG.debug("Invalid %s=%r; ignoring.", name, text)
        else:
            return parsed
    return None


def _env_int(name: str) -> Optional[int]:
    """Read environment variable ``name`` as an integer.

    Returns ``None`` when the variable is unset or its value cannot be parsed
    (the latter is logged at debug level).
    """
    text = os.getenv(name)
    if text is not None:
        try:
            parsed = int(text)
        except (TypeError, ValueError):
            LOG.debug("Invalid %s=%r; ignoring.", name, text)
        else:
            return parsed
    return None


def _sync_chunk_bytes() -> int:
    """Return the max payload size for colocate sync batches (bytes).

    The size is read in mebibytes from ``MAXENT_VLLM_COLOCATE_SYNC_CHUNK_MB``;
    an unset, unparsable or non-positive value falls back to 64.
    """
    fallback_mb = 64
    setting = os.getenv("MAXENT_VLLM_COLOCATE_SYNC_CHUNK_MB", "64")
    try:
        megabytes = int(setting)
    except (TypeError, ValueError):
        megabytes = fallback_mb
    if megabytes <= 0:
        megabytes = fallback_mb
    return megabytes * (1024 * 1024)


def _filter_kwargs(callable_obj: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Drop keyword arguments that ``callable_obj`` cannot accept by name.

    Args:
        callable_obj: Function or class whose signature is inspected.
        kwargs: Candidate keyword arguments.

    Returns:
        ``kwargs`` itself when the signature cannot be inspected or declares
        ``**kwargs``; otherwise a new dict holding only the entries whose key
        names a parameter that can be passed by keyword.
    """
    try:
        sig = inspect.signature(callable_obj)
    except (TypeError, ValueError):
        return kwargs
    params = sig.parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs
    # Positional-only parameters and ``*args`` share the signature namespace
    # but raise TypeError when supplied by keyword, so they must not count.
    by_keyword = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {name for name, p in params.items() if p.kind in by_keyword}
    return {key: val for key, val in kwargs.items() if key in accepted}


def _resolve_model_id(ctx: Any) -> Optional[str]:
    """Find a model identifier on ``ctx`` by probing well-known attributes.

    Lookup order: identifier attributes on ``ctx`` itself, then
    ``ctx.model`` (``name_or_path``, then its ``config``), then
    ``ctx.training_args``. Only non-empty strings are accepted.

    Returns:
        The first identifier found, or ``None``.
    """

    def _first_nonempty_str(holder: Any, attr_names: Tuple[str, ...]) -> Optional[str]:
        # Return the first attribute of ``holder`` that is a non-empty string.
        for attr in attr_names:
            candidate = getattr(holder, attr, None)
            if isinstance(candidate, str) and candidate:
                return candidate
        return None

    found = _first_nonempty_str(
        ctx,
        (
            "vllm_model_id",
            "served_model_id",
            "model_name",
            "model_id",
            "hub_model_id",
            "model_name_or_path",
        ),
    )
    if found is not None:
        return found

    model = getattr(ctx, "model", None)
    if model is not None:
        found = _first_nonempty_str(model, ("name_or_path",))
        if found is not None:
            return found
        config = getattr(model, "config", None)
        config_name = getattr(config, "name_or_path", None) or getattr(
            config, "_name_or_path", None
        )
        if isinstance(config_name, str) and config_name:
            return config_name

    training_args = getattr(ctx, "training_args", None)
    if training_args is None:
        return None
    return _first_nonempty_str(
        training_args, ("model_name_or_path", "hub_model_id", "model_id")
    )


def _env_bool(name: str) -> Optional[bool]:
    """Read environment variable ``name`` as a boolean flag.

    Accepted spellings (case-insensitive, whitespace-trimmed) are
    ``1/true/yes/y/on`` and ``0/false/no/n/off``.

    Returns:
        The parsed flag, or ``None`` when the variable is unset or holds an
        unrecognised value (the latter is logged at debug level).
    """
    raw = os.getenv(name)
    if raw is None:
        return None
    token = raw.strip().lower()
    for outcome, spellings in (
        (True, ("1", "true", "yes", "y", "on")),
        (False, ("0", "false", "no", "n", "off")),
    ):
        if token in spellings:
            return outcome
    LOG.debug("Invalid %s=%r; ignoring.", name, raw)
    return None


def _use_vllm_collective() -> bool:
    """Report whether ``MAXENT_VLLM_COLLECTIVE`` leaves the collective path on.

    The result is ``True`` unless the variable is set to one of
    ``0/false/no/off`` (case-insensitive, whitespace-trimmed).
    """
    setting = os.getenv("MAXENT_VLLM_COLLECTIVE")
    if setting is not None and setting.strip().lower() in (
        "0",
        "false",
        "no",
        "off",
    ):
        return False
    return True


def _resolve_dtype(ctx: Any) -> Optional[str]:
    """Map the precision flags on ``ctx.training_args`` to a dtype name.

    Returns:
        ``"float16"`` when ``fp16`` is truthy, else ``"bfloat16"`` when
        ``bf16`` is truthy, else ``None`` (also when there are no
        ``training_args``).
    """
    args = getattr(ctx, "training_args", None)
    if args is not None:
        # fp16 is checked first, so it wins if both flags are set.
        for flag, dtype_name in (("fp16", "float16"), ("bf16", "bfloat16")):
            if getattr(args, flag, False):
                return dtype_name
    return None


def _init_mode() -> str:
    """Normalise ``MAXENT_VLLM_COLOCATE_INIT_MODE`` to a canonical mode name.

    Returns:
        ``"subprocess"``, ``"thread"`` or ``"blocking"`` for a recognised
        alias, otherwise ``"auto"``.
    """
    aliases = {
        "subprocess": "subprocess",
        "process": "subprocess",
        "proc": "subprocess",
        "thread": "thread",
        "async": "thread",
        "background": "thread",
        "blocking": "blocking",
        "sync": "blocking",
        "foreground": "blocking",
    }
    requested = os.getenv("MAXENT_VLLM_COLOCATE_INIT_MODE", "").strip().lower()
    return aliases.get(requested, "auto")


def _dist_initialized() -> bool:
    """Report whether ``torch.distributed`` is available and initialised.

    Any failure (including torch being unimportable) yields ``False``.
    """
    try:
        import torch

        dist = getattr(torch, "distributed", None)
        if not dist:
            return False
        if not (hasattr(dist, "is_available") and hasattr(dist, "is_initialized")):
            return False
        return bool(dist.is_available() and dist.is_initialized())
    except Exception:
        return False


def _log_env_snapshot(keys: List[str]) -> None:
    """Log the environment variables from ``keys`` that are currently set.

    Emits one info record of space-separated ``KEY=value`` pairs; nothing is
    logged when none of the keys is set.
    """
    rendered = [
        f"{key}={value}" for key in keys if (value := os.getenv(key)) is not None
    ]
    if not rendered:
        return
    LOG.info("vLLM colocate env snapshot | %s", " ".join(rendered))


def _log_torch_snapshot() -> None:
    """Log CUDA availability, current device and memory figures from torch.

    If torch cannot be imported a "skipped" record is logged instead. Errors
    while querying CUDA leave the affected fields as ``None``.
    """
    try:
        import torch
    except Exception as exc:
        LOG.info("vLLM colocate torch snapshot skipped | error=%s", exc)
        return

    try:
        has_cuda = bool(torch.cuda.is_available())
    except Exception:
        has_cuda = False

    ordinal = None
    name = None
    total_bytes = None
    free_bytes = None
    if has_cuda:
        # Each query may fail independently; fields already gathered are kept.
        try:
            ordinal = torch.cuda.current_device()
            name = torch.cuda.get_device_name(ordinal)
            if hasattr(torch.cuda, "mem_get_info"):
                free_bytes, total_bytes = torch.cuda.mem_get_info(ordinal)
        except Exception:
            pass

    LOG.info(
        "vLLM colocate torch | cuda_available=%s device=%s name=%s free_mem=%s total_mem=%s",
        has_cuda,
        ordinal,
        name,
        free_bytes,
        total_bytes,
    )


def _log_process_snapshot() -> None:
    """Log the current PID, thread name and hostname at info level."""

    def _attempt(probe: Callable[[], Any]) -> Any:
        # Run ``probe`` and map any failure to ``None``.
        try:
            return probe()
        except Exception:
            return None

    process_id = _attempt(os.getpid)
    hostname = _attempt(socket.gethostname)
    LOG.info(
        "vLLM colocate process | pid=%s thread=%s host=%s",
        process_id,
        threading.current_thread().name,
        hostname,
    )


def _log_runtime_snapshot() -> None:
    """Log the interpreter version (newlines flattened) and executable path."""
    try:
        version_text = " ".join(sys.version.split("\n"))
    except Exception:
        version_text = None
    interpreter = getattr(sys, "executable", None)
    LOG.info(
        "vLLM colocate runtime | python=%s executable=%s",
        version_text,
        interpreter,
    )


def _extract_logprob_sequence(raw: Any) -> Optional[List[float]]:
    """Pull one float log-probability per usable entry out of ``raw``.

    ``raw`` must be a list. Each entry may be a number, a non-empty dict
    (its first value is inspected), or an object/dict carrying ``logprob``
    (dicts may also use ``log_prob``). Entries that yield no number are
    dropped.

    Returns:
        The collected floats, or ``None`` when ``raw`` is not a list or
        nothing usable was found.
    """
    if not isinstance(raw, list):
        return None

    def _scalar(item: Any) -> Optional[float]:
        # Reduce one entry to a float, or None if it carries no number.
        if isinstance(item, (int, float)):
            return float(item)
        probe = next(iter(item.values())) if isinstance(item, dict) and item else item
        if isinstance(probe, dict):
            probe = probe.get("logprob", probe.get("log_prob"))
        elif hasattr(probe, "logprob"):
            probe = getattr(probe, "logprob", None)
        return float(probe) if isinstance(probe, (int, float)) else None

    numbers = [num for num in map(_scalar, raw) if num is not None]
    return numbers or None


def _sum_logprobs(values: Optional[List[float]]) -> Optional[float]:
    """Add up ``values`` left to right as floats.

    Returns:
        The running total, or ``None`` when ``values`` is ``None`` or empty.
    """
    if values:
        running = 0.0
        for item in values:
            running = running + float(item)
        return running
    return None


def _coerce_logprob_payload(
    payload: Optional[List[List[Optional[Dict[str, Any]]]]],
) -> Optional[List[List[Optional[VLLMLogprobResult]]]]:
    """Turn nested metadata dicts into ``VLLMLogprobResult`` objects.

    The nesting of ``payload`` is preserved. Existing ``VLLMLogprobResult``
    entries pass through unchanged, dict entries are converted field by
    field, and anything else (including ``None``) becomes ``None``.

    Returns:
        The converted structure, or ``None`` when ``payload`` is ``None``.
    """
    if payload is None:
        return None

    def _as_result(item: Any) -> Optional[VLLMLogprobResult]:
        # Convert a single entry; unsupported types map to None.
        if isinstance(item, VLLMLogprobResult):
            return item
        if not isinstance(item, dict):
            return None
        return VLLMLogprobResult(
            logprob_sum=item.get("logprob_sum"),
            token_count=item.get("token_count"),
            token_logprobs=item.get("token_logprobs"),
            raw_output=item.get("raw_output"),
        )

    return [[_as_result(item) for item in group] for group in payload]


def _outputs_to_payload(
    outputs: Any, want_logprobs: bool
) -> Tuple[List[List[str]], Optional[List[List[Optional[Dict[str, Any]]]]]]:
    """Flatten generation outputs into plain lists and dicts.

    Args:
        outputs: Iterable of request outputs. Each one is read through its
            ``outputs`` attribute, whose items are read through ``text``,
            ``cumulative_logprob``, ``token_ids``/``output_token_ids``,
            ``logprobs``, ``finish_reason`` and ``stop_reason``.
        want_logprobs: When false, ``logprob_sum`` and ``token_logprobs`` are
            reported as ``None`` in the metadata. When true and no cumulative
            logprob is present, the sum of the per-token logprobs is used.

    Returns:
        ``(texts, meta)``: ``texts`` holds one list of completion strings per
        request. ``meta`` mirrors that nesting with a dict (or ``None``) per
        completion, and is ``None`` altogether when no completion produced
        any metadata.
    """
    grouped: List[List[str]] = []
    grouped_meta: List[List[Optional[Dict[str, Any]]]] = []
    for output in outputs:
        seqs = getattr(output, "outputs", None) or []
        group: List[str] = []
        meta_group: List[Optional[Dict[str, Any]]] = []
        for seq in seqs:
            text = getattr(seq, "text", None)
            if text is None:
                # A missing or ``None`` text is an empty completion. The
                # previous ``str(...)`` coercion turned ``None`` into the
                # literal string "None".
                text = ""
            group.append(text)
            logprob_sum = getattr(seq, "cumulative_logprob", None)
            # NOTE(review): ``or`` also falls through on an *empty* token_ids
            # sequence, so a zero-token completion reports no token metadata;
            # kept as-is pending confirmation of what consumers expect.
            token_ids = getattr(seq, "token_ids", None) or getattr(
                seq, "output_token_ids", None
            )
            if token_ids is not None:
                try:
                    token_ids = [int(token_id) for token_id in token_ids]
                except (TypeError, ValueError):
                    token_ids = None
            token_count = len(token_ids) if token_ids is not None else None
            token_logprobs = _extract_logprob_sequence(getattr(seq, "logprobs", None))
            finish_reason = getattr(seq, "finish_reason", None)
            stop_reason = getattr(seq, "stop_reason", None)
            if logprob_sum is None and want_logprobs:
                logprob_sum = _sum_logprobs(token_logprobs)
            raw_output: Optional[Dict[str, Any]] = None
            if token_ids is not None:
                raw_output = {
                    "token_ids": list(token_ids),
                    "token_count": int(token_count or 0),
                }
                if finish_reason is not None:
                    raw_output["finish_reason"] = finish_reason
                if stop_reason is not None:
                    raw_output["stop_reason"] = stop_reason
            if (
                logprob_sum is None
                and token_count is None
                and token_logprobs is None
                and raw_output is None
            ):
                meta_group.append(None)
            else:
                meta_group.append(
                    {
                        "logprob_sum": logprob_sum if want_logprobs else None,
                        "token_count": token_count,
                        "token_logprobs": token_logprobs if want_logprobs else None,
                        "raw_output": raw_output,
                    }
                )
        grouped.append(group)
        grouped_meta.append(meta_group)
    # Collapse the metadata to None when not a single completion carried any.
    if not any(any(entry is not None for entry in group) for group in grouped_meta):
        return grouped, None
    return grouped, grouped_meta


# Name prefixes used by _lookup_param when an exact parameter name is not in
# the index: each prefix is tried both stripped from and prepended to the name.
_PARAM_NAME_ALIASES = (
    "model.",
    "module.",
    "base_model.model.",
)


def _candidate_children(obj: Any) -> List[Any]:
    """List the nested engine/executor/runner objects reachable from ``obj``.

    A fixed set of attribute names is probed in order. ``None`` values and
    attributes whose access raises are skipped; list or tuple values
    contribute their non-``None`` items individually.
    """
    if obj is None:
        return []
    probe_order = (
        "model",
        "llm_engine",
        "_engine",
        "engine",
        "model_executor",
        "executor",
        "driver_worker",
        "model_runner",
        "runner",
    )
    found: List[Any] = []
    for attr_name in probe_order:
        if hasattr(obj, attr_name):
            try:
                child = getattr(obj, attr_name)
            except Exception:
                continue
            if child is None:
                continue
            if isinstance(child, (list, tuple)):
                found += [member for member in child if member is not None]
            else:
                found.append(child)
    return found


def _resolve_llm_model(llm: Any) -> Optional[Any]:
    """Search the object graph under ``llm`` for something with ``named_parameters()``.

    The walk is depth-first (last discovered child is visited first) and each
    object is visited at most once, tracked by ``id``.

    Returns:
        The first visited object whose ``named_parameters`` is callable, or
        ``None`` if there is none (or ``llm`` is ``None``).
    """
    if llm is None:
        return None
    visited: set[int] = set()
    pending: List[Any] = [llm]
    while pending:
        node = pending.pop()
        if node is None or id(node) in visited:
            continue
        visited.add(id(node))
        if callable(getattr(node, "named_parameters", None)):
            return node
        pending.extend(_candidate_children(node))
    return None


def _build_param_index(model: Any) -> Dict[str, Any]:
    """Build a name -> parameter mapping from ``model.named_parameters()``.

    Entries with an empty name or a ``None`` parameter are left out. If the
    iteration raises, whatever was collected so far is returned. A model
    without a callable ``named_parameters`` yields an empty dict.
    """
    collected: Dict[str, Any] = {}
    if model is None:
        return collected
    if not callable(getattr(model, "named_parameters", None)):
        return collected
    try:
        for param_name, parameter in model.named_parameters():
            if not param_name or parameter is None:
                continue
            collected[param_name] = parameter
    except Exception:
        pass
    return collected


def _lookup_param(name: str, index: Dict[str, Any]) -> Optional[Any]:
    """Find ``name`` in ``index``, tolerating known wrapper prefixes.

    An exact match wins. Otherwise, for each prefix in
    ``_PARAM_NAME_ALIASES`` in turn, the name with that prefix stripped (if
    present) and then the name with that prefix prepended are tried.

    Returns:
        The matching index value, or ``None`` when nothing matches.
    """
    if name in index:
        return index[name]
    for prefix in _PARAM_NAME_ALIASES:
        variants: List[str] = []
        if name.startswith(prefix):
            variants.append(name[len(prefix) :])
        variants.append(prefix + name)
        for variant in variants:
            if variant in index:
                return index[variant]
    return None


def _apply_param_updates(
    index: Dict[str, Any],
    updates: List[Tuple[str, Any]],
    missing: Optional[set[str]] = None,
    log_fn: Optional[Callable[[str], None]] = None,
) -> Tuple[int, int]:
    """Copy updated tensors into the matching parameters in place.

    Args:
        index: Parameter lookup table (name -> parameter) searched through
            ``_lookup_param``.
        updates: ``(name, tensor)`` pairs to apply.
        missing: Set of names already reported as absent; it is extended so a
            missing name is only logged once. A fresh set is used when omitted.
        log_fn: Optional sink for diagnostic messages.

    Returns:
        ``(applied, skipped)`` counts. An update is skipped when its parameter
        is unknown, its payload is not a tensor (torch available), its shape
        differs from the target, or the copy raises.
    """
    if missing is None:
        missing = set()

    def _report(message: str) -> None:
        # Best-effort diagnostics: a failing sink must not abort the sync.
        if log_fn is None:
            return
        try:
            log_fn(message)
        except Exception:
            pass

    applied = 0
    skipped = 0
    try:
        import torch
    except Exception:
        torch = None  # type: ignore[assignment]
    for name, tensor in updates:
        param = _lookup_param(name, index)
        if param is None:
            skipped += 1
            if name not in missing and log_fn is not None:
                missing.add(name)
                log_fn(f"worker missing param | name={name}")
            continue
        if torch is not None and not isinstance(tensor, torch.Tensor):
            skipped += 1
            # Previously dropped without a trace; surface it for debugging.
            _report(
                f"worker param payload is not a tensor | name={name} "
                f"type={type(tensor).__name__}"
            )
            continue
        try:
            if torch is not None:
                with torch.no_grad():
                    target = param.data
                    src = tensor.detach()
                    if target.shape != src.shape:
                        skipped += 1
                        if log_fn is not None:
                            log_fn(
                                f"worker param shape mismatch | name={name} "
                                f"target={tuple(target.shape)} src={tuple(src.shape)}"
                            )
                        continue
                    # Align dtype and device with the target before copying.
                    if target.dtype != src.dtype:
                        src = src.to(dtype=target.dtype)
                    if target.device != src.device:
                        src = src.to(device=target.device)
                    target.copy_(src)
            else:
                param.data.copy_(tensor)
            applied += 1
        except Exception as exc:
            skipped += 1
            # Keep going with the remaining updates, but say why this one failed.
            _report(f"worker param update failed | name={name} error={exc}")
            continue
    return applied, skipped


def _reset_prefix_cache_llm(llm: Any) -> None:
    """Call ``reset_prefix_cache()`` on ``llm`` or, failing that, its engine.

    ``llm`` itself is tried first, then its ``llm_engine``, ``_engine`` and
    ``engine`` attributes in that order. The first call that completes
    without raising ends the search; failures are ignored.
    """
    if llm is None:
        return

    def _holders() -> Any:
        # Yield the candidates lazily so engines are only looked up if needed.
        yield llm
        for engine_attr in ("llm_engine", "_engine", "engine"):
            engine = getattr(llm, engine_attr, None)
            if engine is not None:
                yield engine

    for holder in _holders():
        reset = getattr(holder, "reset_prefix_cache", None)
        if not callable(reset):
            continue
        try:
            reset()
        except Exception:
            continue
        return


def _vllm_colocate_worker(conn: Any) -> None:
    """Subprocess worker for vLLM colocate init/generate.

    Message protocol over ``conn`` (as implemented below):

    * Progress is reported with ``{"type": "log", "message": ...}`` messages.
    * The first message received must be a dict with ``type == "init"``; its
      ``model_id``, ``llm_kwargs`` and ``request_logprobs`` entries are read.
      ``{"ok": True}`` is sent once the LLM has been constructed.
    * Afterwards the loop serves ``shutdown`` (no reply),
      ``reset_prefix_cache``, ``update_params`` and ``generate`` messages;
      anything else is ignored.
    * Any exception ends the worker after a best-effort
      ``{"ok": False, "error": ..., "traceback": ...}`` reply.

    Args:
        conn: Connection-like object providing ``send``, ``recv`` and
            ``close`` (presumably one end of a ``multiprocessing`` pipe;
            confirm against the caller).
    """
    try:

        def _send_log(message: str) -> None:
            # Logging is best-effort: a broken connection must not kill the worker.
            try:
                conn.send({"type": "log", "message": message})
            except Exception:
                pass

        try:
            _send_log(
                f"worker boot | pid={os.getpid()} python={sys.version.replace(os.linesep, ' ')}"
            )
        except Exception:
            pass
        # Block until the init request arrives; anything else aborts the worker.
        init_msg = conn.recv()
        if not isinstance(init_msg, dict) or init_msg.get("type") != "init":
            conn.send({"ok": False, "error": "Invalid init payload"})
            return
        model_id = init_msg.get("model_id")
        llm_kwargs = init_msg.get("llm_kwargs") or {}
        request_logprobs_default = bool(init_msg.get("request_logprobs", False))
        _send_log(
            f"worker init payload | model_id={model_id} request_logprobs={request_logprobs_default} llm_kwargs={llm_kwargs}"
        )

        # Isolate the worker from the training process' distributed context.
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_RANK"] = "0"
        os.environ["SLURM_LOCALID"] = "0"
        os.environ["LOCAL_WORLD_SIZE"] = "1"
        os.environ["VLLM_DP_SIZE"] = "1"
        os.environ["VLLM_DP_RANK"] = "0"
        os.environ["VLLM_DP_RANK_LOCAL"] = "0"
        _send_log(
            "worker env override | RANK=0 WORLD_SIZE=1 LOCAL_RANK=0 SLURM_LOCALID=0"
        )

        # If a specific CUDA device was requested, remap visibility *before*
        # any torch import so the worker sees the requested GPU. By default we
        # append to existing CUDA_VISIBLE_DEVICES to preserve training ordinals
        # (needed for CUDA IPC during weight sync).
        device = llm_kwargs.get("device")
        if isinstance(device, str) and device.startswith("cuda:"):
            idx = device.split(":", 1)[1]
            if idx.isdigit():
                # Remapping is on unless MAXENT_VLLM_COLOCATE_REMAP_DEVICE disables it.
                remap = _env_bool("MAXENT_VLLM_COLOCATE_REMAP_DEVICE")
                if remap is None:
                    remap = True
                if remap:
                    visible = os.getenv("CUDA_VISIBLE_DEVICES", "")
                    visible = visible.strip()
                    tokens = []
                    if visible and visible.lower() != "none":
                        tokens = [
                            tok.strip() for tok in visible.split(",") if tok.strip()
                        ]
                    if idx not in tokens:
                        tokens.append(idx)
                    # The device ordinal becomes the position of ``idx`` in the
                    # (possibly extended) visibility list.
                    if tokens:
                        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(tokens)
                        new_ord = tokens.index(idx)
                    else:
                        os.environ["CUDA_VISIBLE_DEVICES"] = idx
                        new_ord = 0
                    # Copy before mutating so the received payload dict is left alone.
                    llm_kwargs = dict(llm_kwargs)
                    llm_kwargs["device"] = f"cuda:{new_ord}"
                    _send_log(
                        "worker CUDA remap | CUDA_VISIBLE_DEVICES=%s device=cuda:%s (orig=%s)"
                        % (os.getenv("CUDA_VISIBLE_DEVICES"), new_ord, idx)
                    )
                else:
                    _send_log(
                        "worker CUDA remap skipped | device=%s CUDA_VISIBLE_DEVICES=%s"
                        % (device, os.getenv("CUDA_VISIBLE_DEVICES"))
                    )
        # Rendezvous address/port for this worker's single-process group.
        master_addr = os.getenv("MAXENT_VLLM_COLOCATE_MASTER_ADDR") or "127.0.0.1"
        master_port = os.getenv("MAXENT_VLLM_COLOCATE_MASTER_PORT")
        if not master_port:
            # Bind to port 0 to let the OS choose a free port; fall back to
            # 29501 if that fails. If ``socket.socket`` itself raised, ``sock``
            # is unbound and the resulting NameError is swallowed below.
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.bind((master_addr, 0))
                master_port = str(sock.getsockname()[1])
            except Exception:
                master_port = "29501"
            finally:
                try:
                    sock.close()
                except Exception:
                    pass
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["VLLM_DP_MASTER_IP"] = master_addr
        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
        _send_log(
            f"worker dist env | MASTER_ADDR={master_addr} MASTER_PORT={master_port} LOCAL_WORLD_SIZE=1"
        )
        # Optionally create a world_size=1 NCCL process group up front
        # (enabled unless MAXENT_VLLM_COLOCATE_PREINIT_DIST turns it off).
        preinit = _env_bool("MAXENT_VLLM_COLOCATE_PREINIT_DIST")
        if preinit is None:
            preinit = True
        if preinit:
            try:
                import torch

                dist = getattr(torch, "distributed", None)
                if (
                    dist is not None
                    and dist.is_available()
                    and not dist.is_initialized()
                ):
                    init_method = os.getenv("MAXENT_VLLM_COLOCATE_INIT_METHOD")
                    if not init_method:
                        # Default to a file:// rendezvous on a fresh temp file.
                        store_dir = (
                            os.getenv("MAXENT_VLLM_COLOCATE_STORE_DIR")
                            or tempfile.gettempdir()
                        )
                        fd, path = tempfile.mkstemp(prefix="vllm-init-", dir=store_dir)
                        os.close(fd)
                        init_method = f"file://{path}"
                    _send_log(
                        f"worker dist preinit | method={init_method} backend=nccl"
                    )
                    dist.init_process_group(
                        backend="nccl",
                        init_method=init_method,
                        world_size=1,
                        rank=0,
                    )
                    _send_log("worker dist preinit done | initialized=True")
            except Exception as exc:
                # Pre-init is best-effort; the worker carries on without it.
                _send_log(f"worker dist preinit failed | error={exc}")

        _send_log("worker import vllm start")
        vllm_mod = optional_import("vllm")
        if vllm_mod is None:
            raise RuntimeError("vllm is not installed")
        llm_cls = getattr(vllm_mod, "LLM", None)
        if llm_cls is None:
            raise RuntimeError("vllm.LLM is unavailable")
        _send_log(
            f"worker vllm module | version={getattr(vllm_mod, '__version__', None)} path={getattr(vllm_mod, '__file__', None)}"
        )
        # Drop kwargs the installed LLM constructor does not declare.
        llm_kwargs = _filter_kwargs(llm_cls, llm_kwargs)
        device = llm_kwargs.get("device")
        if isinstance(device, str) and device.startswith("cuda:"):
            try:
                import torch

                torch.cuda.set_device(int(device.split(":")[1]))
                _send_log(f"worker torch cuda set_device | device={device}")
            except Exception:
                pass
        # Diagnostics only: report torch / CUDA / cuDNN versions.
        try:
            import torch

            torch_version = getattr(torch, "__version__", None)
            cuda_version = getattr(torch, "version", None)
            cuda_version = (
                getattr(cuda_version, "cuda", None)
                if cuda_version is not None
                else None
            )
            cudnn_version = None
            if hasattr(torch, "backends") and hasattr(torch.backends, "cudnn"):
                cudnn_version = getattr(torch.backends.cudnn, "version", None)
                if callable(cudnn_version):
                    cudnn_version = cudnn_version()
            _send_log(
                f"worker torch | version={torch_version} cuda_version={cuda_version} cudnn={cudnn_version}"
            )
        except Exception:
            pass
        # Diagnostics only: GPU memory before the engine is built.
        try:
            import torch

            if torch.cuda.is_available() and hasattr(torch.cuda, "mem_get_info"):
                free_mem, total_mem = torch.cuda.mem_get_info()
                _send_log(
                    f"worker cuda mem pre-init | free_mem={free_mem} total_mem={total_mem}"
                )
        except Exception:
            pass
        _send_log(f"worker LLM init start | model={model_id} kwargs={llm_kwargs}")
        # Optional watchdog: while LLM construction is running, dump all thread
        # stacks to stderr every MAXENT_VLLM_COLOCATE_INIT_STACK_S seconds.
        stack_interval = _env_float("MAXENT_VLLM_COLOCATE_INIT_STACK_S")
        watchdog_stop = threading.Event()

        def _stack_watchdog() -> None:
            if not stack_interval or stack_interval <= 0:
                return
            # ``wait`` returns False on timeout, i.e. init is still in progress.
            while not watchdog_stop.wait(stack_interval):
                _send_log("worker LLM init still running; dumping stack traces")
                try:
                    faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
                except Exception:
                    pass

        watchdog = None
        if stack_interval and stack_interval > 0:
            watchdog = threading.Thread(
                target=_stack_watchdog, name="vllm-colocate-init-watchdog", daemon=True
            )
            watchdog.start()
        started = time.time()
        llm = llm_cls(model=model_id, **llm_kwargs)
        watchdog_stop.set()
        if watchdog is not None:
            try:
                watchdog.join(timeout=1.0)
            except Exception:
                pass
        elapsed = time.time() - started
        _send_log(f"worker LLM init done | elapsed_s={elapsed:.2f}")
        # Diagnostics only: GPU memory after the engine is built.
        try:
            import torch

            if torch.cuda.is_available() and hasattr(torch.cuda, "mem_get_info"):
                free_mem, total_mem = torch.cuda.mem_get_info()
                _send_log(
                    f"worker cuda mem post-init | free_mem={free_mem} total_mem={total_mem}"
                )
        except Exception:
            pass
        params_cls = getattr(vllm_mod, "SamplingParams", None)
        if params_cls is None:
            raise RuntimeError("vllm.SamplingParams is unavailable")
        _send_log("worker SamplingParams available")
        # Signal the parent that initialisation succeeded.
        conn.send({"ok": True})

        # Built lazily on the first ``update_params`` request.
        param_index: Optional[Dict[str, Any]] = None
        # Names already reported as missing, so each is logged only once.
        missing_params: set[str] = set()

        # Request loop: non-dict messages and unknown types are skipped silently.
        while True:
            msg = conn.recv()
            if not isinstance(msg, dict):
                continue
            if msg.get("type") == "shutdown":
                break
            if msg.get("type") == "reset_prefix_cache":
                _reset_prefix_cache_llm(llm)
                conn.send({"ok": True})
                continue
            if msg.get("type") == "update_params":
                updates = msg.get("params") or []
                if param_index is None:
                    model_for_params = _resolve_llm_model(llm)
                    param_index = _build_param_index(model_for_params)
                    if not param_index:
                        _send_log("worker param index empty; update may be skipped")
                applied, skipped = _apply_param_updates(
                    param_index or {}, updates, missing_params, _send_log
                )
                conn.send({"ok": True, "applied": applied, "skipped": skipped})
                continue
            if msg.get("type") != "generate":
                continue
            prompts = msg.get("prompts") or []
            params_kwargs = msg.get("params_kwargs") or {}
            # A per-request flag overrides the default captured at init.
            request_logprobs = bool(
                msg.get("request_logprobs", request_logprobs_default)
            )
            params_kwargs = _filter_kwargs(params_cls, params_kwargs)
            params = params_cls(**params_kwargs)
            outputs = llm.generate(prompts, params)
            grouped, grouped_meta = _outputs_to_payload(outputs, request_logprobs)
            conn.send({"ok": True, "grouped": grouped, "meta": grouped_meta})
    except Exception as exc:
        # Report the failure to the parent if the connection still works.
        try:
            conn.send(
                {
                    "ok": False,
                    "error": str(exc),
                    "traceback": traceback.format_exc(),
                }
            )
        except Exception:
            pass
    finally:
        try:
            conn.close()
        except Exception:
            pass


class ColocateVLLMEngine:
    """Lazy vLLM engine wrapper used for colocated generation.

    The engine is created on first use, either inside the current process
    (``thread`` / ``blocking`` init modes) or inside a spawned worker process
    that is driven over a pipe (``subprocess`` mode).  When generation fails
    and local fallback is permitted, requests are routed to
    ``fallback_generate`` instead.
    """

    def __init__(self, ctx: Any, fallback_generate: Any) -> None:
        """Store the generation context and the local fallback callable.

        :param ctx: Generation context object; attributes are read lazily via
            ``getattr`` throughout this class.
        :param fallback_generate: Callable invoked as
            ``fallback_generate(prompts, request_count, None)`` on failure.
        """
        self.ctx = ctx
        self._fallback_generate = fallback_generate
        self._llm: Any = None
        self._init_failed = False
        self._worker_proc: Optional[mp.Process] = None
        self._worker_conn: Optional[Any] = None
        self._sync_client: Optional["ColocateVLLMClient"] = None
        self._param_index: Optional[Dict[str, Any]] = None
        self._missing_params: set[str] = set()
        # Guards against registering the atexit hook once per worker restart.
        self._atexit_registered = False

    # ------------------------------------------------------------------
    # Fallback handling
    # ------------------------------------------------------------------
    def _local_fallback_allowed(self) -> bool:
        """Return True if local fallback generation is allowed."""
        raw = os.getenv("MAXENT_VLLM_DISABLE_LOCAL_FALLBACK")
        if isinstance(raw, str) and raw.strip().lower() in {"1", "true", "yes", "on"}:
            return False
        if bool(getattr(self.ctx, "vllm_disable_local_fallback", False)):
            return False
        return True

    def _fallback_or_raise(
        self,
        truncated: List[str],
        request_count: int,
        reason: Exception | str,
    ) -> Tuple[
        Optional[List[List[str]]],
        Optional[List[List[Optional[VLLMLogprobResult]]]],
    ]:
        """Run the local fallback generator, or raise if fallback is disabled.

        In subprocess mode the worker is shut down first (best effort) so the
        next request starts from a fresh worker.

        :param truncated: Already-truncated prompts to hand to the fallback.
        :param request_count: Number of completions requested per prompt.
        :param reason: Exception or message describing the colocate failure.
        :returns: Whatever ``fallback_generate`` returns.
        :raises RuntimeError: If local fallback is disabled; chained to
            ``reason`` when ``reason`` is an exception.
        """
        if self._resolve_init_mode() == "subprocess":
            try:
                self._shutdown_worker()
            except Exception:
                # Cleanup is best effort; the original failure matters more.
                LOG.debug("vLLM colocate worker shutdown failed", exc_info=True)
        if self._local_fallback_allowed():
            LOG.warning(
                "vLLM colocate generate failed (%s); falling back to local generation.",
                reason,
            )
            return self._fallback_generate(truncated, request_count, None)
        cause = reason if isinstance(reason, BaseException) else None
        raise RuntimeError(
            f"vLLM colocate generate failed and local fallback disabled: {reason}"
        ) from cause

    # ------------------------------------------------------------------
    # Engine configuration
    # ------------------------------------------------------------------
    def _resolve_init_mode(self) -> str:
        """Return the effective init mode, resolving ``auto``.

        ``auto`` becomes ``subprocess`` when ``_dist_initialized()`` is true
        and ``thread`` otherwise; any other mode is returned unchanged.
        """
        mode = _init_mode()
        if mode == "auto":
            return "subprocess" if _dist_initialized() else "thread"
        return mode

    def _configure_vllm_env(self) -> None:
        """Set vLLM environment variables from ``MAXENT_VLLM_COLOCATE_*`` overrides.

        Variables the user already exported (``VLLM_ENABLE_V1_MULTIPROCESSING``,
        ``VLLM_USE_V1``) are never overwritten.
        """
        backend_override = os.getenv("MAXENT_VLLM_COLOCATE_ATTENTION_BACKEND")
        if backend_override:
            os.environ["VLLM_ATTENTION_BACKEND"] = backend_override
            LOG.info(
                "vLLM colocate attention backend override | backend=%s",
                backend_override,
            )
        if os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING") is None:
            mp_override = _env_bool("MAXENT_VLLM_COLOCATE_V1_MULTIPROCESSING")
            if not mp_override:
                # Unset or false: multiprocessing is disabled by default.
                os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
                LOG.info(
                    "Disabling vLLM V1 multiprocessing for colocate. "
                    "Set MAXENT_VLLM_COLOCATE_V1_MULTIPROCESSING=1 to re-enable."
                )
            else:
                os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "1"
        if os.getenv("VLLM_USE_V1") is None:
            force_v0 = _env_bool("MAXENT_VLLM_COLOCATE_FORCE_V0")
            if force_v0:
                os.environ["VLLM_USE_V1"] = "0"
                LOG.info(
                    "Forcing vLLM V0 engine for colocate "
                    "(MAXENT_VLLM_COLOCATE_FORCE_V0=1)."
                )
        LOG.info(
            "vLLM colocate env | VLLM_USE_V1=%s VLLM_ENABLE_V1_MULTIPROCESSING=%s "
            "CUDA_VISIBLE_DEVICES=%s LOCAL_RANK=%s SLURM_LOCALID=%s",
            os.getenv("VLLM_USE_V1"),
            os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING"),
            os.getenv("CUDA_VISIBLE_DEVICES"),
            os.getenv("LOCAL_RANK"),
            os.getenv("SLURM_LOCALID"),
        )

    def _build_llm_kwargs(self) -> Dict[str, Any]:
        """Assemble keyword arguments for ``vllm.LLM`` from env vars and ctx.

        :returns: Unfiltered kwargs; ``_build_llm`` passes them through
            ``_filter_kwargs`` before constructing the engine.
        """
        llm_kwargs: Dict[str, Any] = {}

        dtype_override = _resolve_dtype(self.ctx)
        if dtype_override:
            llm_kwargs["dtype"] = dtype_override

        prefix_caching = _env_bool("MAXENT_VLLM_COLOCATE_ENABLE_PREFIX_CACHING")
        if prefix_caching is not None:
            llm_kwargs["enable_prefix_caching"] = prefix_caching

        sleep_mode = _env_bool("MAXENT_VLLM_COLOCATE_ENABLE_SLEEP_MODE")
        if sleep_mode is not None:
            llm_kwargs["enable_sleep_mode"] = sleep_mode

        # The environment variable wins over training_args.
        gpu_util = _env_float("MAXENT_VLLM_COLOCATE_GPU_UTIL")
        if gpu_util is None:
            training_args = getattr(self.ctx, "training_args", None)
            cfg_gpu_util = getattr(training_args, "vllm_gpu_memory_utilization", None)
            if cfg_gpu_util is not None:
                try:
                    gpu_util = float(cfg_gpu_util)
                except (TypeError, ValueError):
                    LOG.warning(
                        "Invalid training_args.vllm_gpu_memory_utilization=%r; ignoring.",
                        cfg_gpu_util,
                    )
                    gpu_util = None
        if gpu_util is not None and not (0.0 < gpu_util <= 1.0):
            LOG.warning(
                "Invalid gpu_memory_utilization=%r; expected (0, 1]; ignoring.",
                gpu_util,
            )
            gpu_util = None
        if gpu_util is not None:
            llm_kwargs["gpu_memory_utilization"] = gpu_util

        tp_size = _env_int("MAXENT_VLLM_COLOCATE_TP")
        if tp_size is not None:
            llm_kwargs["tensor_parallel_size"] = tp_size

        max_model_len = _env_int("MAXENT_VLLM_COLOCATE_MAX_MODEL_LEN")
        if max_model_len is not None:
            llm_kwargs["max_model_len"] = max_model_len

        # Eager execution is the default unless explicitly overridden.
        enforce_eager = _env_bool("MAXENT_VLLM_COLOCATE_ENFORCE_EAGER")
        if enforce_eager is None:
            enforce_eager = True
        llm_kwargs["enforce_eager"] = enforce_eager

        device_override = os.getenv("MAXENT_VLLM_COLOCATE_DEVICE")
        if device_override and _use_vllm_collective():
            llm_kwargs["device"] = device_override
        else:
            local_rank = os.getenv("LOCAL_RANK") or os.getenv("SLURM_LOCALID")
            if local_rank is not None and str(local_rank).isdigit():
                llm_kwargs["device"] = f"cuda:{int(local_rank)}"
            else:
                try:
                    import torch

                    if torch.cuda.is_available():
                        llm_kwargs["device"] = f"cuda:{torch.cuda.current_device()}"
                except Exception:
                    # torch missing or CUDA query failed: leave the device unset.
                    LOG.debug(
                        "vLLM colocate device auto-detection failed", exc_info=True
                    )

        trust_remote = getattr(self.ctx, "trust_remote_code", None)
        if trust_remote is None:
            training_args = getattr(self.ctx, "training_args", None)
            trust_remote = getattr(training_args, "trust_remote_code", None)
        if trust_remote is not None:
            llm_kwargs["trust_remote_code"] = bool(trust_remote)
        return llm_kwargs

    def _build_llm_spec(self) -> Tuple[str, Dict[str, Any]]:
        """Log diagnostics, configure the vLLM env and resolve the engine spec.

        :returns: ``(model_id, llm_kwargs)``; ``model_id`` falls back to the
            placeholder ``"unknown-model"`` when it cannot be resolved.
        """
        LOG.info("vLLM colocate build spec start")
        _log_process_snapshot()
        _log_runtime_snapshot()
        _log_env_snapshot(
            [
                "CUDA_VISIBLE_DEVICES",
                "CUDA_DEVICE_ORDER",
                "LOCAL_RANK",
                "RANK",
                "WORLD_SIZE",
                "SLURM_LOCALID",
                "SLURM_PROCID",
                "SLURM_NTASKS",
                "MASTER_ADDR",
                "MASTER_PORT",
                "NCCL_P2P_DISABLE",
                "NCCL_IB_DISABLE",
                "NCCL_SOCKET_IFNAME",
                "NCCL_DEBUG",
                "NCCL_DEBUG_SUBSYS",
                "VLLM_USE_V1",
                "VLLM_ENABLE_V1_MULTIPROCESSING",
                "VLLM_ATTENTION_BACKEND",
                "VLLM_LOGGING_LEVEL",
                "MAXENT_VLLM_COLOCATE_GPU_UTIL",
                "MAXENT_VLLM_COLOCATE_TP",
                "MAXENT_VLLM_COLOCATE_MAX_MODEL_LEN",
                "MAXENT_VLLM_COLOCATE_ENFORCE_EAGER",
                "MAXENT_VLLM_COLOCATE_ENABLE_PREFIX_CACHING",
                "MAXENT_VLLM_COLOCATE_ENABLE_SLEEP_MODE",
                "MAXENT_VLLM_COLOCATE_FORCE_V0",
                "MAXENT_VLLM_COLOCATE_DEVICE",
                "MAXENT_VLLM_COLOCATE_INIT_TIMEOUT_S",
                "MAXENT_VLLM_COLOCATE_INIT_MODE",
                "MAXENT_VLLM_COLOCATE_INIT_HEARTBEAT_S",
                "MAXENT_VLLM_COLOCATE_INIT_STACK_S",
                "MAXENT_VLLM_COLOCATE_MASTER_ADDR",
                "MAXENT_VLLM_COLOCATE_MASTER_PORT",
                "MAXENT_VLLM_COLOCATE_PREINIT_DIST",
                "MAXENT_VLLM_COLOCATE_INIT_METHOD",
                "MAXENT_VLLM_COLOCATE_STORE_DIR",
                "MAXENT_VLLM_COLOCATE_ATTENTION_BACKEND",
                "MAXENT_VLLM_COLOCATE_V1_MULTIPROCESSING",
            ]
        )
        _log_torch_snapshot()
        self._configure_vllm_env()
        model_id = _resolve_model_id(self.ctx)
        if not model_id:
            model_id = "unknown-model"
            LOG.warning(
                "Unable to resolve model ID for vLLM colocate; using placeholder '%s'.",
                model_id,
            )
        LOG.info("vLLM colocate resolved model_id=%s", model_id)
        llm_kwargs = self._build_llm_kwargs()
        LOG.info("vLLM colocate llm_kwargs pre-filter | %s", llm_kwargs)
        LOG.info(
            "vLLM colocate device selection | device=%s torch_device=%s",
            llm_kwargs.get("device", "auto"),
            getattr(getattr(self, "ctx", None), "device", None),
        )
        LOG.info(
            "vLLM colocate compile mode | enforce_eager=%s",
            llm_kwargs.get("enforce_eager"),
        )
        LOG.info(
            "vLLM colocate init | model=%s | dtype=%s | tp=%s | gpu_util=%s",
            model_id,
            llm_kwargs.get("dtype", "default"),
            llm_kwargs.get("tensor_parallel_size", "auto"),
            llm_kwargs.get("gpu_memory_utilization", "default"),
        )
        return model_id, llm_kwargs

    def _build_llm(self) -> Any:
        """Import vLLM and construct ``vllm.LLM`` in the current process.

        :raises RuntimeError: If ``vllm`` or ``vllm.LLM`` is unavailable.
        """
        model_id, llm_kwargs = self._build_llm_spec()
        LOG.info("vLLM colocate import vllm start")
        vllm_mod = optional_import("vllm")
        if vllm_mod is None:
            raise RuntimeError("vllm is not installed")
        LOG.info(
            "vLLM colocate vllm module | version=%s path=%s",
            getattr(vllm_mod, "__version__", None),
            getattr(vllm_mod, "__file__", None),
        )
        llm_cls = getattr(vllm_mod, "LLM", None)
        if llm_cls is None:
            raise RuntimeError("vllm.LLM is unavailable")
        pre_filter = dict(llm_kwargs)
        llm_kwargs = _filter_kwargs(llm_cls, llm_kwargs)
        removed = sorted(set(pre_filter) - set(llm_kwargs))
        if removed:
            LOG.info("vLLM colocate llm_kwargs filtered | removed=%s", removed)
        LOG.info(
            "vLLM colocate LLM init start | model=%s kwargs=%s", model_id, llm_kwargs
        )
        llm = llm_cls(model=model_id, **llm_kwargs)
        LOG.info("vLLM colocate LLM init done")
        return llm

    # ------------------------------------------------------------------
    # Subprocess worker lifecycle
    # ------------------------------------------------------------------
    def _shutdown_worker(self) -> None:
        """Best-effort teardown of the worker pipe and process; idempotent."""
        conn = getattr(self, "_worker_conn", None)
        proc = getattr(self, "_worker_proc", None)
        # Clear the handles first so a repeated call becomes a no-op.
        self._worker_conn = None
        self._worker_proc = None
        if conn is not None:
            try:
                conn.send({"type": "shutdown"})
            except Exception:
                pass
            try:
                conn.close()
            except Exception:
                pass
        if proc is not None:
            try:
                if proc.is_alive():
                    proc.terminate()
            except Exception:
                pass
            try:
                proc.join(timeout=5)
            except Exception:
                pass

    @staticmethod
    def _recv_worker_response(conn: Any) -> Any:
        """Return the next non-log message from ``conn``.

        Messages shaped ``{"type": "log", ...}`` are forwarded to the logger
        and skipped, so every request path handles them the same way.
        """
        while True:
            resp = conn.recv()
            if isinstance(resp, dict) and resp.get("type") == "log":
                LOG.info("vLLM colocate worker | %s", resp.get("message"))
                continue
            return resp

    def _await_worker_init(self, conn: Any, proc: Any, timeout_s: float) -> Any:
        """Wait for the worker's reply to the ``init`` message.

        :param conn: Parent end of the worker pipe.
        :param proc: Worker process handle (used for liveness checks).
        :param timeout_s: Hard deadline in seconds; ``<= 0`` waits forever.
        :returns: The first non-log message, or a synthetic
            ``{"ok": False, ...}`` dict if the worker exited silently.
        :raises RuntimeError: On timeout (also sets ``_init_failed``).
        """
        heartbeat_s = _env_float("MAXENT_VLLM_COLOCATE_INIT_HEARTBEAT_S")
        if heartbeat_s is None or heartbeat_s <= 0:
            heartbeat_s = 30.0
        bounded = bool(timeout_s and timeout_s > 0)
        # Monotonic clock: wall-clock adjustments must not affect the deadline.
        start = time.monotonic()
        while True:
            elapsed = time.monotonic() - start
            if bounded and elapsed >= timeout_s:
                self._init_failed = True
                LOG.warning(
                    "vLLM colocate subprocess init timed out | elapsed_s=%.1f alive=%s exitcode=%s",
                    elapsed,
                    proc.is_alive(),
                    proc.exitcode,
                )
                raise RuntimeError(
                    f"vLLM colocate init timed out after {timeout_s:.1f}s"
                )
            poll_timeout = heartbeat_s
            if bounded:
                poll_timeout = min(heartbeat_s, max(timeout_s - elapsed, 0.0))
            if conn.poll(poll_timeout):
                resp = conn.recv()
                if isinstance(resp, dict) and resp.get("type") == "log":
                    LOG.info("vLLM colocate worker | %s", resp.get("message"))
                    continue
                return resp
            LOG.info(
                "vLLM colocate init waiting on subprocess | elapsed_s=%.1f alive=%s exitcode=%s",
                elapsed,
                proc.is_alive(),
                proc.exitcode,
            )
            if not proc.is_alive() and not conn.poll(0.0):
                return {
                    "ok": False,
                    "error": f"vLLM colocate subprocess exited (exitcode={proc.exitcode})",
                }

    def _ensure_worker(self) -> None:
        """Start and initialise the worker subprocess if none is running.

        :raises RuntimeError: If a previous init failed, the init handshake
            times out, or the worker reports failure.
        """
        if self._worker_conn is not None and self._worker_proc is not None:
            if self._worker_proc.is_alive():
                return
            self._shutdown_worker()
        if self._init_failed:
            raise RuntimeError("vLLM colocate initialization previously failed")
        timeout_s = _env_float("MAXENT_VLLM_COLOCATE_INIT_TIMEOUT_S")
        if timeout_s is None:
            timeout_s = 0.0
        LOG.info(
            "vLLM colocate init start | timeout_s=%.1f | mode=subprocess",
            float(timeout_s),
        )
        model_id, llm_kwargs = self._build_llm_spec()
        ctx = mp.get_context("spawn")
        try:
            start_method = ctx.get_start_method()
        except Exception:
            start_method = "spawn"
        LOG.info("vLLM colocate subprocess | start_method=%s", start_method)
        parent_conn, child_conn = ctx.Pipe()
        proc = ctx.Process(
            target=_vllm_colocate_worker,
            args=(child_conn,),
            name="vllm-colocate-worker",
            daemon=True,
        )
        proc.start()
        LOG.info("vLLM colocate subprocess started | pid=%s", proc.pid)
        child_conn.close()
        # Track the handles before the handshake so that _shutdown_worker()
        # can actually reap the process and close the pipe if init fails.
        self._worker_conn = parent_conn
        self._worker_proc = proc
        try:
            parent_conn.send(
                {
                    "type": "init",
                    "model_id": model_id,
                    "llm_kwargs": llm_kwargs,
                    "request_logprobs": bool(
                        getattr(self.ctx, "vllm_request_logprobs", False)
                    ),
                }
            )
            init_resp = self._await_worker_init(parent_conn, proc, timeout_s)
            if init_resp is None:
                self._init_failed = True
                raise RuntimeError(
                    "vLLM colocate subprocess init produced no response"
                )
            if not isinstance(init_resp, dict) or not init_resp.get("ok"):
                self._init_failed = True
                err = init_resp.get("error") if isinstance(init_resp, dict) else None
                tb = (
                    init_resp.get("traceback")
                    if isinstance(init_resp, dict)
                    else None
                )
                if tb:
                    LOG.warning("vLLM colocate subprocess init traceback:\n%s", tb)
                raise RuntimeError(err or "vLLM colocate subprocess init failed")
        except BaseException:
            # Never leave a half-initialised worker (and its GPU memory) behind.
            self._shutdown_worker()
            raise
        if not getattr(self, "_atexit_registered", False):
            atexit.register(self._shutdown_worker)
            self._atexit_registered = True

    # ------------------------------------------------------------------
    # In-process engine lifecycle
    # ------------------------------------------------------------------
    def _build_llm_with_timeout(self) -> Any:
        """Build the engine while honouring ``MAXENT_VLLM_COLOCATE_INIT_TIMEOUT_S``.

        Without a positive timeout the engine is built synchronously.  With
        one, the behaviour depends on the init mode: ``subprocess`` delegates
        to the worker and returns ``None``; ``blocking`` builds synchronously
        and only warns when the deadline is exceeded; any other mode builds
        on a daemon thread and raises on timeout.

        :raises RuntimeError: On thread-mode timeout.
        """
        timeout_s = _env_float("MAXENT_VLLM_COLOCATE_INIT_TIMEOUT_S")
        if timeout_s is None or timeout_s <= 0:
            return self._build_llm()
        mode = self._resolve_init_mode()
        if mode == "subprocess":
            self._ensure_worker()
            return None
        LOG.info("vLLM colocate init start | timeout_s=%.1f | mode=%s", timeout_s, mode)
        if mode == "blocking":
            started = time.monotonic()
            llm = self._build_llm()
            elapsed = time.monotonic() - started
            if elapsed > timeout_s:
                LOG.warning(
                    "vLLM colocate init exceeded timeout (%.1fs > %.1fs). "
                    "Set MAXENT_VLLM_COLOCATE_INIT_MODE=thread to enforce a hard timeout.",
                    elapsed,
                    timeout_s,
                )
            return llm

        result: Dict[str, Any] = {}
        error: Dict[str, Exception] = {}
        done = threading.Event()

        def _runner() -> None:
            try:
                result["llm"] = self._build_llm()
            except Exception as exc:
                error["exc"] = exc
            finally:
                done.set()

        thread = threading.Thread(
            target=_runner,
            name="vllm-colocate-init",
            daemon=True,
        )
        thread.start()
        heartbeat_s = _env_float("MAXENT_VLLM_COLOCATE_INIT_HEARTBEAT_S")
        if heartbeat_s is None or heartbeat_s <= 0:
            heartbeat_s = 30.0
        start = time.monotonic()
        while True:
            elapsed = time.monotonic() - start
            remaining = timeout_s - elapsed
            if remaining <= 0:
                LOG.warning(
                    "vLLM colocate init timed out after %.1fs; disabling colocate engine.",
                    timeout_s,
                )
                # The init thread cannot be cancelled; it is a daemon thread
                # and is simply abandoned here.
                LOG.warning(
                    "vLLM colocate init thread still running | elapsed_s=%.1f alive=%s",
                    elapsed,
                    thread.is_alive(),
                )
                raise RuntimeError(
                    f"vLLM colocate init timed out after {timeout_s:.1f}s"
                )
            if done.wait(min(heartbeat_s, remaining)):
                break
            LOG.info(
                "vLLM colocate init still running | elapsed_s=%.1f thread_alive=%s",
                elapsed,
                thread.is_alive(),
            )
        if "exc" in error:
            raise error["exc"]
        return result["llm"]

    def _get_llm(self) -> Any:
        """Return the in-process engine, building it on first use.

        :raises RuntimeError: If init previously failed, subprocess mode is
            active, or building the engine fails (chained to the cause).
        """
        if self._llm is not None:
            return self._llm
        if self._init_failed:
            raise RuntimeError("vLLM colocate initialization previously failed")
        if self._resolve_init_mode() == "subprocess":
            raise RuntimeError("vLLM colocate subprocess mode is active")
        try:
            self._llm = self._build_llm_with_timeout()
        except Exception as exc:
            self._init_failed = True
            raise RuntimeError(str(exc)) from exc
        return self._llm

    def sync_client(self) -> "ColocateVLLMClient":
        """Return the lazily created, cached weight-sync client."""
        if self._sync_client is None:
            self._sync_client = ColocateVLLMClient(self)
        return self._sync_client

    def _stabilize_parent_cuda_state(self) -> None:
        """Best-effort parent-side cleanup before vLLM memory profiling starts."""
        try:
            gc.collect()
        except Exception:
            pass
        try:
            import torch

            cuda_mod = getattr(torch, "cuda", None)
            if cuda_mod is None or not cuda_mod.is_available():
                return
            steps = [
                step
                for step in (
                    getattr(cuda_mod, "synchronize", None),
                    getattr(cuda_mod, "empty_cache", None),
                    getattr(cuda_mod, "ipc_collect", None),
                )
                if callable(step)
            ]
            # Two passes with a short pause between them.
            for _ in range(2):
                for step in steps:
                    try:
                        step()
                    except Exception:
                        pass
                time.sleep(0.05)
        except Exception:
            pass

    def ensure_ready(self) -> None:
        """Initialize the colocate worker/engine before parameter streaming begins."""
        self._stabilize_parent_cuda_state()
        if self._resolve_init_mode() == "subprocess":
            self._ensure_worker()
            return
        self._get_llm()

    # ------------------------------------------------------------------
    # Weight sync / cache control
    # ------------------------------------------------------------------
    def _apply_param_updates(self, updates: List[Tuple[str, Any]]) -> None:
        """Push ``(name, tensor)`` updates into the engine.

        In subprocess mode the updates are sent over the worker pipe;
        otherwise they are applied to the in-process model through the
        module-level ``_apply_param_updates`` helper.

        :raises RuntimeError: If the worker is unavailable or reports failure.
        """
        if not updates:
            return
        if self._resolve_init_mode() == "subprocess":
            self._ensure_worker()
            conn = self._worker_conn
            if conn is None:
                raise RuntimeError("vLLM colocate worker is unavailable")
            conn.send({"type": "update_params", "params": updates})
            resp = self._recv_worker_response(conn)
            if not isinstance(resp, dict) or not resp.get("ok"):
                err = resp.get("error") if isinstance(resp, dict) else None
                raise RuntimeError(err or "vLLM colocate param update failed")
            applied = resp.get("applied")
            skipped = resp.get("skipped")
            if applied is not None or skipped is not None:
                LOG.info(
                    "vLLM colocate sync (subprocess) | applied=%s skipped=%s",
                    applied,
                    skipped,
                )
            return
        llm = self._get_llm()
        if self._param_index is None:
            model_for_params = _resolve_llm_model(llm)
            self._param_index = _build_param_index(model_for_params)
            if not self._param_index:
                LOG.warning("vLLM colocate param index empty; updates may be skipped.")
        # Inside a method this bare name resolves to the module-level function,
        # not to this method.
        applied, skipped = _apply_param_updates(
            self._param_index or {},
            updates,
            self._missing_params,
            lambda msg: LOG.warning("vLLM colocate | %s", msg),
        )
        LOG.info(
            "vLLM colocate sync (in-process) | applied=%s skipped=%s",
            applied,
            skipped,
        )

    def _reset_prefix_cache(self) -> None:
        """Ask the engine to drop its prefix cache.

        In subprocess mode failures are logged at debug level and swallowed,
        and the call is a no-op when no worker is running.
        """
        if self._resolve_init_mode() == "subprocess":
            conn = self._worker_conn
            if conn is None:
                return
            try:
                conn.send({"type": "reset_prefix_cache"})
                resp = self._recv_worker_response(conn)
                if not isinstance(resp, dict) or not resp.get("ok"):
                    LOG.debug("vLLM colocate reset_prefix_cache failed: %s", resp)
            except Exception:
                LOG.debug("vLLM colocate reset_prefix_cache failed")
            return
        _reset_prefix_cache_llm(self._get_llm())

    # ------------------------------------------------------------------
    # Sampling parameters
    # ------------------------------------------------------------------
    def _ctx_first(self, primary: str, secondary: str) -> Any:
        """Return ``ctx.<primary>`` unless it is missing/None, else ``ctx.<secondary>``."""
        value = getattr(self.ctx, primary, None)
        if value is not None:
            return value
        return getattr(self.ctx, secondary, None)

    def _build_sampling_params_kwargs(self, request_count: int) -> Dict[str, Any]:
        """Collect sampling options from ``ctx`` into a plain kwargs dict.

        The dict includes a ``blocked_token_ids`` entry that
        ``_build_sampling_params`` pops before constructing ``SamplingParams``.

        :param request_count: Number of completions per prompt (``n``).
        """
        stop_sequences = self._ctx_first("gen_stop_sequences", "vllm_stop_sequences")
        top_k = self._ctx_first("gen_top_k", "vllm_top_k")
        if top_k is None or top_k == 0:
            # Missing or zero top_k is sent as -1.
            top_k = -1
        best_of = self._ctx_first("gen_best_of", "vllm_best_of")
        logit_bias = merge_invalid_token_block_logit_bias(
            self.ctx,
            getattr(self.ctx, "vllm_logit_bias", None),
        )
        allowed_token_ids = resolve_allowed_token_ids(self.ctx)
        blocked_token_ids = resolve_blocked_token_ids(self.ctx)
        params_kwargs: Dict[str, Any] = {
            "temperature": self.ctx.gen_temperature,
            "top_p": self.ctx.gen_top_p,
            "top_k": top_k,
            "min_p": getattr(self.ctx, "gen_min_p", 0.0),
            "max_tokens": self.ctx.max_completion_len,
            "n": int(request_count),
            "best_of": best_of,
            "repetition_penalty": getattr(self.ctx, "gen_repetition_penalty", 1.0),
            "frequency_penalty": getattr(self.ctx, "gen_frequency_penalty", 0.0),
            "presence_penalty": getattr(self.ctx, "gen_presence_penalty", 0.0),
            "stop": stop_sequences,
            "include_stop_str_in_output": bool(
                getattr(self.ctx, "vllm_include_stop_str_in_output", False)
            ),
            "logit_bias": logit_bias,
            "allowed_token_ids": allowed_token_ids,
            "blocked_token_ids": blocked_token_ids,
            "guided_json": getattr(self.ctx, "vllm_guided_json", None),
            "guided_regex": getattr(self.ctx, "vllm_guided_regex", None),
        }
        if bool(getattr(self.ctx, "vllm_request_logprobs", False)):
            params_kwargs["logprobs"] = 1
        return params_kwargs

    def _build_sampling_params(self, request_count: int) -> Any:
        """Construct a ``vllm.SamplingParams`` for in-process generation.

        Blocked token ids are attached afterwards as single-token
        ``_bad_words_token_ids`` entries.

        :raises RuntimeError: If ``vllm`` or ``SamplingParams`` is unavailable.
        """
        vllm_mod = optional_import("vllm")
        if vllm_mod is None:
            raise RuntimeError("vllm is not installed")
        params_cls = getattr(vllm_mod, "SamplingParams", None)
        if params_cls is None:
            raise RuntimeError("vllm.SamplingParams is unavailable")
        params_kwargs = self._build_sampling_params_kwargs(request_count)
        blocked_token_ids = list(params_kwargs.pop("blocked_token_ids", []) or [])
        params_kwargs = _filter_kwargs(params_cls, params_kwargs)
        params = params_cls(**params_kwargs)
        if blocked_token_ids:
            bad_words = [[int(token_id)] for token_id in blocked_token_ids]
            try:
                setattr(params, "_bad_words_token_ids", bad_words)
            except Exception:
                pass
            kwargs = getattr(params, "kwargs", None)
            if isinstance(kwargs, dict):
                kwargs["_bad_words_token_ids"] = bad_words
        return params

    def _record_latency(self, latency_ms: float) -> None:
        """Accumulate latency counters in ``ctx.generation_stats`` when it is a dict."""
        stats = getattr(self.ctx, "generation_stats", None)
        if not isinstance(stats, dict):
            return
        stats["vllm_last_latency_ms"] = float(latency_ms)
        stats["vllm_latency_total_ms"] = float(
            stats.get("vllm_latency_total_ms", 0.0)
        ) + float(latency_ms)
        stats["vllm_latency_calls"] = int(stats.get("vllm_latency_calls", 0)) + 1

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------
    def _generate_via_worker(
        self,
        truncated: List[str],
        params_kwargs: Dict[str, Any],
        request_count: int,
        request_logprobs: bool,
    ) -> Tuple[Any, Any]:
        """Send one ``generate`` request to the worker and decode its reply.

        :returns: ``(grouped, grouped_meta)``.
        :raises RuntimeError: If the worker is unavailable or reports failure.
        """
        LOG.info(
            "vLLM colocate generate (subprocess) | prompts=%d request_count=%d",
            len(truncated),
            request_count,
        )
        conn = self._worker_conn
        if conn is None:
            raise RuntimeError("vLLM colocate worker is unavailable")
        conn.send(
            {
                "type": "generate",
                "prompts": truncated,
                "params_kwargs": params_kwargs,
                "request_logprobs": request_logprobs,
            }
        )
        # Skip interleaved log messages, as the other request paths do.
        resp = self._recv_worker_response(conn)
        if not isinstance(resp, dict) or not resp.get("ok"):
            err = resp.get("error") if isinstance(resp, dict) else None
            raise RuntimeError(err or "vLLM colocate subprocess request failed")
        grouped = resp.get("grouped") or []
        grouped_meta = _coerce_logprob_payload(resp.get("meta"))
        return grouped, grouped_meta

    def request_batch(
        self,
        prompts: List[str],
        request_count: int,
    ) -> Tuple[
        Optional[List[List[str]]],
        Optional[List[List[Optional[VLLMLogprobResult]]]],
    ]:
        """Generate ``request_count`` completions for each prompt.

        Prompts are truncated first.  Any failure (engine setup, generation,
        or a group-count mismatch) is routed through ``_fallback_or_raise``.

        :param prompts: Raw prompts; an empty list returns ``([], None)``.
        :param request_count: Number of completions per prompt.
        :returns: ``(grouped, grouped_meta)`` from the engine or the fallback.
        :raises RuntimeError: If generation fails and fallback is disabled.
        """
        if not prompts:
            return [], None
        char_limit = getattr(self.ctx, "prompt_char_limit", None)
        if char_limit is None:
            char_limit = PROMPT_CHAR_LIMIT
        tokenizer = getattr(self.ctx, "tokenizer", None)
        max_tokens = getattr(self.ctx, "max_prompt_len", None)
        truncated = [
            _truncate_prompt(
                prompt,
                char_limit,
                tokenizer=tokenizer,
                max_tokens=max_tokens,
            )
            for prompt in prompts
        ]
        request_logprobs = bool(getattr(self.ctx, "vllm_request_logprobs", False))
        subprocess_mode = self._resolve_init_mode() == "subprocess"

        # Phase 1: make sure an engine exists and build sampling params.
        llm: Any = None
        try:
            if subprocess_mode:
                self._ensure_worker()
                params: Any = self._build_sampling_params_kwargs(request_count)
            else:
                llm = self._get_llm()
                params = self._build_sampling_params(request_count)
        except Exception as exc:
            return self._fallback_or_raise(truncated, request_count, exc)

        # Phase 2: timed generation.
        start = time.time()
        try:
            if subprocess_mode:
                grouped, grouped_meta = self._generate_via_worker(
                    truncated, params, request_count, request_logprobs
                )
            else:
                LOG.info(
                    "vLLM colocate generate | prompts=%d request_count=%d",
                    len(truncated),
                    request_count,
                )
                outputs = llm.generate(truncated, params)
                grouped, grouped_meta_payload = _outputs_to_payload(
                    outputs, request_logprobs
                )
                grouped_meta = _coerce_logprob_payload(grouped_meta_payload)
        except Exception as exc:
            return self._fallback_or_raise(truncated, request_count, exc)
        latency_ms = (time.time() - start) * 1000.0
        LOG.info("vLLM colocate generate done | latency_ms=%.2f", latency_ms)
        self._record_latency(latency_ms)

        if len(grouped) != len(prompts):
            LOG.warning(
                "vLLM colocate returned %d groups for %d prompts; falling back to local generation.",
                len(grouped),
                len(prompts),
            )
            return self._fallback_or_raise(
                truncated,
                request_count,
                f"group_count={len(grouped)} prompts={len(prompts)}",
            )
        return grouped, grouped_meta
class ColocateVLLMClient:
    """Local client adapter that mimics TRL's VLLMClient interface.

    Parameter updates are buffered and handed to the engine in chunks once the
    buffered byte count reaches the configured chunk size.
    """

    def __init__(self, engine: ColocateVLLMEngine) -> None:
        """Bind the client to ``engine`` and start with an empty buffer."""
        self._engine = engine
        self._buffer: List[Tuple[str, Any]] = []
        self._buffer_bytes = 0
        self._chunk_bytes = _sync_chunk_bytes()
        self._lock = threading.Lock()

    def _tensor_bytes(self, tensor: Any) -> int:
        """Return ``numel * element_size`` for ``tensor``, or 0 if that fails."""
        try:
            element_count = int(tensor.numel())
            element_width = int(tensor.element_size())
        except Exception:
            return 0
        return element_count * element_width

    def _flush_locked(self) -> None:
        """Send buffered updates to the engine; the caller must hold ``_lock``."""
        pending = self._buffer
        if pending:
            # Reset the buffer before the engine call, so a failing call does
            # not leave the same updates queued.
            self._buffer, self._buffer_bytes = [], 0
            self._engine._apply_param_updates(pending)

    def update_named_param(self, name: str, param: Any) -> None:
        """Queue one named parameter, flushing when the chunk budget is reached.

        ``None`` parameters are ignored.  Objects exposing a callable
        ``detach`` are detached first; anything else is queued as given.
        """
        if param is None:
            return
        detach = getattr(param, "detach", None)
        if callable(detach):
            tensor = detach()
        else:
            tensor = param
        nbytes = self._tensor_bytes(tensor)
        with self._lock:
            overflows = (
                bool(self._buffer)
                and self._buffer_bytes + nbytes > self._chunk_bytes
            )
            if overflows:
                # Ship what is already queued so this tensor starts a new chunk.
                self._flush_locked()
            self._buffer.append((name, tensor))
            self._buffer_bytes += nbytes
            if self._buffer_bytes >= self._chunk_bytes:
                self._flush_locked()

    def ensure_ready(self) -> None:
        """Initialise the underlying engine while holding the client lock."""
        with self._lock:
            self._engine.ensure_ready()

    def flush(self) -> None:
        """Send any buffered parameter updates to the engine now."""
        with self._lock:
            self._flush_locked()

    def reset_prefix_cache(self) -> None:
        """Flush pending updates, then ask the engine to reset its prefix cache."""
        self.flush()
        self._engine._reset_prefix_cache()
# Public API of this module: the colocated engine and its client adapter.
__all__ = ["ColocateVLLMEngine", "ColocateVLLMClient"]