# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Completion generation helpers for the MaxEnt-GRPO runner."""
from __future__ import annotations
import importlib
import sys
import time
from typing import Any, Callable
import maxent_grpo.training.rollout.vllm_adapter as _vllm_adapter
from maxent_grpo.training.patches.vllm import safe_generate
from maxent_grpo.training.generation.vllm import VLLMGenerationHelper
from maxent_grpo.training.runtime import require_torch
from maxent_grpo.training.runtime.prompts import PROMPT_CHAR_LIMIT, _truncate_prompt
from maxent_grpo.training.generation.common import (
AggregatedGenerationState as _AggregatedGenerationState,
append_completion_group as _append_completion_group,
determine_retry_limit as _determine_retry_limit,
pending_generation_indices as _pending_generation_indices,
retry_incomplete_prompts as _retry_incomplete_prompts,
seed_generation_groups as _seed_generation_groups_impl,
)
from .context import GenerationContext
from .local import LocalGenerationMixin
from .vllm_adapter import (
VLLMGenerationMixin,
_VLLMGenerationState,
_is_peft_model_safe,
_optional_import,
_zero3_gather_factory,
_import_vllm_client_cls as _adapter_import_vllm_client_cls,
)
# Obtain torch through the runtime helper instead of importing it directly.
torch = require_torch("generation")
# Module-level handle on the shared retry helper; `_refresh_vllm_globals`
# re-binds this name to `_retry_incomplete_prompts` every time it runs.
_retry_incomplete_prompts_impl = _retry_incomplete_prompts
class _DistFallback:
"""Minimal torch.distributed stand-in for single-process tests."""
def is_available(self) -> bool:
return False
def is_initialized(self) -> bool:
return False
def get_world_size(self) -> int:
return 1
def all_gather_object(self, output: Any, value: Any) -> None:
if isinstance(output, list):
if output:
output[0] = value
else:
output.append(value)
def broadcast_object_list(self, _payload: Any, _src: int = 0) -> None:
return None
def scatter_object_list(
self, output: Any, input_list: Any = None, src: int = 0
) -> None:
if isinstance(output, list):
if output:
if isinstance(input_list, list) and 0 <= src < len(input_list):
output[0] = input_list[src]
else:
output[0] = None
else:
output.append(None)
def _ensure_dist(dist_obj: Any) -> Any:
if dist_obj is None:
return _DistFallback()
required = (
"is_available",
"is_initialized",
"get_world_size",
"all_gather_object",
"broadcast_object_list",
)
if any(not hasattr(dist_obj, name) for name in required):
return _DistFallback()
return dist_obj
# Use ``torch.distributed`` when torch has it and it exposes every attribute
# ``_ensure_dist`` requires; otherwise ``dist`` is a ``_DistFallback`` instance.
dist = _ensure_dist(getattr(torch, "distributed", None))
def _refresh_vllm_globals() -> None:
    """Push this module's current bindings onto the vLLM adapter module.

    Copies the present values of ``dist`` and ``safe_generate`` to
    ``_vllm_adapter`` so that replacements made on this module (for example by
    test monkeypatches) are what the adapter sees. ``_vllm_adapter.time`` is
    set to whatever ``importlib.import_module("time")`` returns, and this
    module's ``_retry_incomplete_prompts_impl`` is re-bound to
    ``_retry_incomplete_prompts``.
    """
    synced = {
        "dist": dist,
        "safe_generate": safe_generate,
        "time": importlib.import_module("time"),
    }
    for attr_name, attr_value in synced.items():
        setattr(_vllm_adapter, attr_name, attr_value)
    globals()["_retry_incomplete_prompts_impl"] = _retry_incomplete_prompts
def _gather_object_list_wrapper(accelerator: Any, value: Any) -> Any:
    """Refresh the adapter globals, then gather ``value`` through the adapter.

    Args:
        accelerator: Passed through unchanged to the adapter.
        value: Passed through unchanged to the adapter.

    Returns:
        Whatever ``_vllm_adapter.gather_object_list`` returns.
    """
    _refresh_vllm_globals()
    # Look the function up only after the refresh, as the original call did.
    gather = _vllm_adapter.gather_object_list
    return gather(accelerator, value)
def _broadcast_object_list_wrapper(
    accelerator: Any, payload: Any, *, src: int = 0
) -> Any:
    """Refresh the adapter globals, then broadcast ``payload`` through the adapter.

    Args:
        accelerator: Passed through unchanged to the adapter.
        payload: Passed through unchanged to the adapter.
        src: Keyword-only; forwarded as ``src``.

    Returns:
        Whatever ``_vllm_adapter.broadcast_object_list`` returns.
    """
    _refresh_vllm_globals()
    # Look the function up only after the refresh, as the original call did.
    broadcast = _vllm_adapter.broadcast_object_list
    return broadcast(accelerator, payload, src=src)
def _scatter_object_wrapper(accelerator: Any, input_list: Any, *, src: int = 0) -> Any:
    """Refresh the adapter globals, then scatter ``input_list`` through the adapter.

    Args:
        accelerator: Passed through unchanged to the adapter.
        input_list: Passed through unchanged to the adapter.
        src: Keyword-only; forwarded as ``src``.

    Returns:
        Whatever ``_vllm_adapter.scatter_object`` returns.
    """
    _refresh_vllm_globals()
    # Look the function up only after the refresh, as the original call did.
    scatter = _vllm_adapter.scatter_object
    return scatter(accelerator, input_list, src=src)
# Expose wrapper functions that honor patched globals: each wrapper calls
# `_refresh_vllm_globals()` before delegating to the matching `_vllm_adapter`
# function.
_broadcast_object_list = _broadcast_object_list_wrapper
_gather_object_list = _gather_object_list_wrapper
_scatter_object = _scatter_object_wrapper
def _import_vllm_client_cls(import_fn: Callable[[str], Any] | None = None) -> Any:
"""Import the TRL VLLMClient using the caller-provided optional import hook."""
if import_fn is None:
vllm_mod = sys.modules.get("trl.extras.vllm_client")
if vllm_mod is None:
return None
return getattr(vllm_mod, "VLLMClient", None)
resolved_import = import_fn or globals().get("_optional_import") or _optional_import
return _adapter_import_vllm_client_cls(resolved_import)
class CompletionGenerator(LocalGenerationMixin, VLLMGenerationMixin):
    """Stateful helper that handles both local HF and vLLM completions.

    Combines ``LocalGenerationMixin`` and ``VLLMGenerationMixin`` into a single
    object built from one shared ``GenerationContext``.
    """

    def __init__(self, ctx: GenerationContext) -> None:
        """Initialise both mixins with the same generation context.

        Args:
            ctx: Generation context handed to each mixin's initializer.
        """
        # Each mixin initializer is invoked explicitly (not via ``super()``),
        # local first and then vLLM, both receiving the same ``ctx``.
        LocalGenerationMixin.__init__(self, ctx)
        VLLMGenerationMixin.__init__(self, ctx)
# Names exported by a star-import of this module. Underscore-prefixed helpers
# are listed explicitly; a star-import would skip them otherwise.
__all__ = [
"VLLMGenerationHelper",
"CompletionGenerator",
"GenerationContext",
"PROMPT_CHAR_LIMIT",
"_truncate_prompt",
"require_torch",
"safe_generate",
"torch",
"time",
"dist",
"_AggregatedGenerationState",
"_append_completion_group",
"_determine_retry_limit",
"_pending_generation_indices",
"_retry_incomplete_prompts",
"_retry_incomplete_prompts_impl",
"_seed_generation_groups_impl",
"_VLLMGenerationState",
"_broadcast_object_list",
"_gather_object_list",
"_import_vllm_client_cls",
"_is_peft_model_safe",
"_optional_import",
"_scatter_object",
"_zero3_gather_factory",
]