"""Distributed helpers used by the vLLM generation helper."""
from __future__ import annotations
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, cast
import os
import sys
from maxent_grpo.training.runtime import require_accelerator, require_torch
torch = require_torch("generation_vllm_dist")
Accelerator = require_accelerator("generation_vllm_dist")
if TYPE_CHECKING: # pragma: no cover - hints only
from accelerate import Accelerator as AcceleratorType
else: # pragma: no cover - runtime fallback
AcceleratorType = Any
def _current_torch() -> Any:
    """Resolve the torch module to use for distributed calls.

    The already-imported ``maxent_grpo.training.generation.vllm`` module is
    consulted first: when it is present in ``sys.modules`` and exposes a
    non-``None`` ``torch`` attribute (for example a shim patched in by tests),
    that attribute wins. Otherwise the module-level ``torch`` is returned.

    :returns: The torch module (possibly a shim injected by tests).
    :rtype: Any
    """
    shim_host = sys.modules.get("maxent_grpo.training.generation.vllm")
    patched = None if shim_host is None else getattr(shim_host, "torch", None)
    if patched is None:
        # Nothing patched in: use the torch bound at import time of this module.
        return torch
    return patched
def _gather_object_list(
    accelerator: AcceleratorType, value: List[Any]
) -> List[List[Any]]:
    """Collect a python list from every rank.

    Resolution order:

    1. ``accelerator.gather_object`` when the accelerator exposes a callable
       of that name. Its result is returned only if it is a ``list``;
       any other result is replaced by ``[value]``.
    2. ``torch.distributed.all_gather_object`` when the distributed package
       is available and initialised.
    3. ``[value]`` (local value only) otherwise.

    :param accelerator: Accelerate instance providing distributed utilities.
    :type accelerator: accelerate.Accelerator
    :param value: Local python list contributed by this process.
    :type value: list[Any]
    :returns: List of lists containing gathered values per rank.
    :rtype: list[list[Any]]
    """
    accel_gather = getattr(accelerator, "gather_object", None)
    if callable(accel_gather):
        collected = accel_gather(value)
        # A non-list result is not trusted; fall back to the local value.
        return (
            cast(List[List[Any]], collected)
            if isinstance(collected, list)
            else [value]
        )

    dist = getattr(_current_torch(), "distributed", None)
    dist_ready = (
        dist is not None
        and hasattr(dist, "is_available")
        and hasattr(dist, "is_initialized")
        and dist.is_available()
        and dist.is_initialized()
    )
    if not dist_ready:
        return [value]

    # all_gather_object fills a pre-sized output list, one slot per rank.
    per_rank: List[List[Any]] = [[] for _ in range(dist.get_world_size())]
    dist.all_gather_object(per_rank, value)
    return per_rank
def _initialized_dist() -> Any:
    """Return ``torch.distributed`` when a process group is live, else ``None``."""
    dist = getattr(_current_torch(), "distributed", None)
    if dist is None:
        return None
    # Same defensive probing as ``_gather_object_list``: a torch shim may not
    # expose the availability helpers at all.
    if not (hasattr(dist, "is_available") and hasattr(dist, "is_initialized")):
        return None
    if dist.is_available() and dist.is_initialized():
        return dist
    return None


def _scatter_via_broadcast(
    accelerator: AcceleratorType,
    dist: Any,
    input_list: Optional[List[Any]],
    src: int,
) -> Tuple[bool, Any]:
    """Emulate scatter by broadcasting the full list and indexing locally.

    :returns: ``(True, value)`` when the broadcast was performed, or
        ``(False, None)`` when the inputs do not allow it (non-positive world
        size, or a local ``input_list`` that is not a list of exactly
        ``world_size`` items).
    :rtype: tuple[bool, Any]
    """
    try:
        world_size = int(dist.get_world_size())
    except (AttributeError, TypeError, ValueError, RuntimeError):
        world_size = int(getattr(accelerator, "num_processes", 1) or 1)
    list_ok = input_list is None or (
        isinstance(input_list, list) and len(input_list) == world_size
    )
    if world_size <= 0 or not list_ok:
        return False, None
    # Only the source rank sends real data; every other rank supplies a
    # placeholder list of matching length for broadcast_object_list to fill.
    if accelerator.process_index == src and input_list is not None:
        payload = input_list
    else:
        payload = [None] * world_size
    dist.broadcast_object_list(payload, src=src)
    try:
        return True, payload[int(accelerator.process_index)]
    except (IndexError, TypeError, ValueError):
        return True, None


def _scatter_object(
    accelerator: AcceleratorType,
    input_list: Optional[List[Any]],
    *,
    src: int = 0,
) -> Any:
    """Scatter python objects from ``src`` rank to every other process.

    Strategy, in order:

    1. Single process: return ``input_list[0]``.
    2. If ``MAXENT_SCATTER_MODE`` is ``broadcast``, ``bcast`` or
       ``broadcast_object_list`` and a process group is initialised, use a
       broadcast-based scatter.
    3. ``accelerator.scatter_object`` when the accelerator provides one.
    4. With an initialised process group: broadcast-based scatter when
       possible, otherwise ``torch.distributed.scatter_object_list``.
    5. No process group: index ``input_list`` with the local process index.

    :param accelerator: Accelerate instance providing distributed utilities.
    :type accelerator: accelerate.Accelerator
    :param input_list: Sequence of objects to scatter; only required on the
        source rank.
    :type input_list: list[Any] | None
    :param src: Source rank that owns ``input_list``.
    :type src: int
    :returns: Object slice corresponding to the current rank, or ``None`` when
        ``input_list`` is missing, empty, or has no entry for this rank.
    :rtype: Any
    """
    if accelerator.num_processes <= 1:
        # ``None`` and an empty list both mean "nothing to hand out".
        if not input_list:
            return None
        return input_list[0]

    mode = os.getenv("MAXENT_SCATTER_MODE", "").strip().lower()
    prefer_broadcast = mode in {"broadcast", "bcast", "broadcast_object_list"}

    dist = _initialized_dist()
    can_broadcast = dist is not None and callable(
        getattr(dist, "broadcast_object_list", None)
    )

    # Some torch/backend combinations have flaky support for
    # scatter_object_list, whereas broadcast_object_list tends to be more
    # reliable, so it can be requested ahead of the accelerator's own scatter.
    if prefer_broadcast and can_broadcast:
        done, value = _scatter_via_broadcast(accelerator, dist, input_list, src)
        if done:
            return value

    scatter_fn = getattr(accelerator, "scatter_object", None)
    if callable(scatter_fn):
        return scatter_fn(
            input_list if accelerator.process_index == src else None,
            src=src,
        )

    if dist is not None:
        if can_broadcast:
            done, value = _scatter_via_broadcast(
                accelerator, dist, input_list, src
            )
            if done:
                return value
        output: List[Any] = [None]
        dist.scatter_object_list(
            output,
            input_list if accelerator.process_index == src else None,
            src=src,
        )
        return output[0]

    # No process group at all: best-effort local indexing.
    if input_list is None:
        return None
    try:
        return input_list[accelerator.process_index]
    except IndexError:
        # Mirror the broadcast path, which maps a missing slot to ``None``.
        return None
[docs]
class VLLMDistributedMixin:
    """Scatter/gather helpers factored out of the vLLM generation helper.

    The host object is expected to provide ``ctx`` with an ``accelerator``
    attribute; every method here reads the accelerator through it.
    """

    ctx: Any

    def _flatten_prompts_for_broadcast(
        self,
        prompts: List[str],
        per_prompt_counts: Optional[List[int]] = None,
    ) -> Tuple[List[str], List[int], Optional[List[int]]]:
        """Gather every rank's prompts (and counts) into flat lists.

        :param prompts: Local prompt list for the current rank.
        :type prompts: list[str]
        :param per_prompt_counts: Optional completion counts aligned to
            ``prompts``.
        :type per_prompt_counts: list[int] | None
        :returns: Tuple of flattened prompts, the start offset of each rank's
            slice within them, and flattened counts (``None`` when no counts
            were given).
        :rtype: tuple[list[str], list[int], list[int] | None]
        """
        accel = self.ctx.accelerator
        flat_prompts: List[str] = []
        offsets: List[int] = []
        for rank_prompts in _gather_object_list(accel, prompts):
            # A rank's slice starts where the previously collected prompts end.
            offsets.append(len(flat_prompts))
            flat_prompts.extend(rank_prompts)

        if per_prompt_counts is None:
            return flat_prompts, offsets, None

        flat_counts: List[int] = [
            int(count)
            for rank_counts in _gather_object_list(accel, per_prompt_counts)
            for count in rank_counts
        ]
        return flat_prompts, offsets, flat_counts

    def _build_scatter_payload(
        self,
        offsets: List[int],
        world_size: int,
        flat_prompts: List[str],
        grouped_all: Optional[List[List[str]]],
        meta_all: Optional[List[List[Optional[Any]]]],
    ) -> List[Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]]]:
        """Cut the global outputs into one ``(grouped, meta)`` pair per rank.

        :param offsets: Offsets computed by ``_flatten_prompts_for_broadcast``.
        :type offsets: list[int]
        :param world_size: Total number of ranks.
        :type world_size: int
        :param flat_prompts: Flattened prompt list across all ranks.
        :type flat_prompts: list[str]
        :param grouped_all: Grouped completions aligned to ``flat_prompts``.
        :type grouped_all: list[list[str]] | None
        :param meta_all: Grouped metadata aligned to ``flat_prompts``.
        :type meta_all: list[list[object | None]] | None
        :returns: One tuple per rank holding its grouped completions (an empty
            list when ``grouped_all`` is ``None``) and its metadata slice
            (``None`` when ``meta_all`` is ``None``).
        :rtype: list[tuple[list[list[str]], list[list[object | None]] | None]]
        """
        prompt_total = len(flat_prompts)
        known_offsets = len(offsets)
        per_rank: List[
            Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]]
        ] = []
        for rank in range(world_size):
            lo = offsets[rank]
            following = rank + 1
            # The last known rank runs to the end of the flattened prompts.
            hi = offsets[following] if following < known_offsets else prompt_total
            window = slice(lo, hi)
            completions = grouped_all[window] if grouped_all is not None else []
            metadata = meta_all[window] if meta_all is not None else None
            per_rank.append((completions, metadata))
        return per_rank

    def _scatter_vllm_payload(
        self,
        flat_prompts: List[str],
        offsets: List[int],
        grouped_all: Optional[List[List[str]]],
        meta_all: Optional[List[List[Optional[Any]]]],
    ) -> Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]]:
        """Hand each rank its share of the outputs produced on the main process.

        With a single process the local slice is taken directly via
        ``_pluck_rank_outputs``. Otherwise the main process builds the
        per-rank payload and it is scattered from rank 0, using
        ``self._scatter_object`` when the instance defines one and the
        module-level ``_scatter_object`` otherwise.

        :param flat_prompts: Flattened prompts gathered across ranks.
        :type flat_prompts: list[str]
        :param offsets: Per-rank offsets into ``flat_prompts``.
        :type offsets: list[int]
        :param grouped_all: Grouped completions generated on main process.
        :type grouped_all: list[list[str]] | None
        :param meta_all: Grouped metadata generated on main process.
        :type meta_all: list[list[object | None]] | None
        :returns: Local grouped completions and metadata for the current rank;
            ``([], None)`` when nothing was received.
        :rtype: tuple[list[list[str]], list[list[object | None]] | None]
        """
        accel = self.ctx.accelerator
        n_ranks = accel.num_processes
        if n_ranks <= 1:
            return self._pluck_rank_outputs(
                grouped_all or [], meta_all, offsets, flat_prompts
            )

        # Only the main process owns the generated outputs to distribute.
        outgoing: Optional[
            List[Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]]]
        ] = (
            self._build_scatter_payload(
                offsets, n_ranks, flat_prompts, grouped_all, meta_all
            )
            if accel.is_main_process
            else None
        )

        do_scatter = getattr(self, "_scatter_object", _scatter_object)
        received = do_scatter(accel, outgoing, src=0)
        if received is None:
            return [], None

        local_groups, local_meta = received
        if local_groups is None:
            return [], None
        # Missing per-prompt groups are normalised to empty lists.
        normalised: List[List[str]] = [
            entry if entry is not None else [] for entry in local_groups or []
        ]
        return normalised, local_meta

    def _pluck_rank_outputs(
        self,
        grouped_all: List[List[str]],
        meta_all: Optional[List[List[Optional[Any]]]],
        offsets: List[int],
        prompts: List[str],
    ) -> Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]]:
        """Slice out the current rank's portion of the global outputs.

        The slice starts at ``offsets[process_index]`` and spans
        ``len(prompts)`` entries.

        :param grouped_all: Grouped completions for every prompt across ranks.
        :type grouped_all: list[list[str]]
        :param meta_all: Grouped metadata for every prompt across ranks.
        :type meta_all: list[list[object | None]] | None
        :param offsets: Offsets produced by ``_flatten_prompts_for_broadcast``.
        :type offsets: list[int]
        :param prompts: Prompts owned by the current rank.
        :type prompts: list[str]
        :returns: Grouped completions (``None`` groups replaced by empty
            lists) and metadata for the current rank.
        :rtype: tuple[list[list[str]], list[list[object | None]] | None]
        """
        lo = offsets[self.ctx.accelerator.process_index]
        window = slice(lo, lo + len(prompts))
        local_groups: List[List[str]] = [
            entry if entry is not None else [] for entry in grouped_all[window]
        ]
        local_meta = meta_all[window] if meta_all is not None else None
        return local_groups, local_meta
# Explicit export list. Besides the mixin it names the two underscore-prefixed
# module helpers, so a star-import of this module exposes them as well.
__all__ = ["VLLMDistributedMixin", "_gather_object_list", "_scatter_object"]