# Source code for maxent_grpo.training.rollout.helpers

# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Completion generation helpers for the MaxEnt-GRPO runner."""

from __future__ import annotations

import importlib
import sys
import time
from typing import Any, Callable

import maxent_grpo.training.rollout.vllm_adapter as _vllm_adapter
from maxent_grpo.training.patches.vllm import safe_generate
from maxent_grpo.training.generation.vllm import VLLMGenerationHelper
from maxent_grpo.training.runtime import require_torch
from maxent_grpo.training.runtime.prompts import PROMPT_CHAR_LIMIT, _truncate_prompt
from maxent_grpo.training.generation.common import (
    AggregatedGenerationState as _AggregatedGenerationState,
    append_completion_group as _append_completion_group,
    determine_retry_limit as _determine_retry_limit,
    pending_generation_indices as _pending_generation_indices,
    retry_incomplete_prompts as _retry_incomplete_prompts,
    seed_generation_groups as _seed_generation_groups_impl,
)

from .context import GenerationContext
from .local import LocalGenerationMixin
from .vllm_adapter import (
    VLLMGenerationMixin,
    _VLLMGenerationState,
    _is_peft_model_safe,
    _optional_import,
    _zero3_gather_factory,
    _import_vllm_client_cls as _adapter_import_vllm_client_cls,
)

# Obtain torch through the project's require_torch helper; "generation" is the
# label handed to that helper.
torch = require_torch("generation")
# Module-level alias of the shared retry helper. _refresh_vllm_globals() rebinds
# this name to _retry_incomplete_prompts every time it runs.
_retry_incomplete_prompts_impl = _retry_incomplete_prompts


class _DistFallback:
    """Minimal torch.distributed stand-in for single-process tests.

    Each method carries the name of a ``torch.distributed`` function and behaves
    as if exactly one process exists: nothing is available or initialised, the
    world size is 1, and collectives only touch the caller's own output list.
    """

    def is_available(self) -> bool:
        """Always report that distributed support is unavailable."""
        return False

    def is_initialized(self) -> bool:
        """Always report that no process group has been initialised."""
        return False

    def get_world_size(self) -> int:
        """Return 1, the size of a single-process "world"."""
        return 1

    def all_gather_object(self, output: Any, value: Any) -> None:
        """Store ``value`` in the first slot of ``output``.

        :param output: Destination list; anything that is not a list is ignored.
        :param value: Object to record as the only gathered entry.
        """
        if not isinstance(output, list):
            return
        if output:
            output[0] = value
        else:
            output.append(value)

    def broadcast_object_list(self, _payload: Any, _src: int = 0) -> None:
        """Do nothing: with one process the payload is already in place."""
        return None

    def scatter_object_list(
        self, output: Any, input_list: Any = None, src: int = 0
    ) -> None:
        """Place ``input_list[src]`` in the first slot of ``output``.

        ``None`` is stored when ``input_list`` is not a list or ``src`` is out of
        range. An empty ``output`` list receives the same value a non-empty one
        would, instead of unconditionally getting ``None``.

        :param output: Destination list; anything that is not a list is ignored.
        :param input_list: Candidate objects to scatter.
        :param src: Index into ``input_list`` selecting the object to deliver.
        """
        if not isinstance(output, list):
            return
        received = None
        if isinstance(input_list, list) and 0 <= src < len(input_list):
            received = input_list[src]
        if output:
            output[0] = received
        else:
            # Empty destination: grow it so the caller can still read slot 0.
            output.append(received)


def _ensure_dist(dist_obj: Any) -> Any:
    """Return ``dist_obj`` if it looks usable, otherwise a ``_DistFallback``.

    An object is considered usable when it is not ``None`` and exposes every
    attribute named in the tuple below; any other input is replaced by a fresh
    fallback instance.

    :param dist_obj: Candidate distributed module/object, or ``None``.
    :returns: ``dist_obj`` itself or a new ``_DistFallback``.
    """
    needed_attrs = (
        "is_available",
        "is_initialized",
        "get_world_size",
        "all_gather_object",
        "broadcast_object_list",
    )
    usable = dist_obj is not None and all(
        hasattr(dist_obj, attr) for attr in needed_attrs
    )
    if usable:
        return dist_obj
    return _DistFallback()


# Module-wide distributed handle: torch.distributed when it exists and passes
# _ensure_dist's attribute check, otherwise a _DistFallback instance.
dist = _ensure_dist(getattr(torch, "distributed", None))


def _refresh_vllm_globals() -> None:
    """Copy this module's current bindings onto the vLLM adapter module.

    The values are read from module globals at call time, so replacements made
    on this module (for example by test monkeypatches) reach the adapter. The
    ``time`` module is fetched through ``importlib`` on every call rather than
    taken from this module's own import.
    """
    global _retry_incomplete_prompts_impl

    adapter = _vllm_adapter
    adapter.dist = dist
    adapter.safe_generate = safe_generate
    adapter.time = importlib.import_module("time")
    # Re-point the local alias at whatever the retry helper currently is.
    _retry_incomplete_prompts_impl = _retry_incomplete_prompts


def _gather_object_list_wrapper(accelerator: Any, value: Any) -> Any:
    """Call the adapter's ``gather_object_list`` after re-syncing its globals.

    :param accelerator: Forwarded unchanged to the adapter.
    :param value: Forwarded unchanged to the adapter.
    :returns: Whatever the adapter function returns.
    """
    # Sync first so the adapter sees this module's current bindings.
    _refresh_vllm_globals()
    gathered = _vllm_adapter.gather_object_list(accelerator, value)
    return gathered


def _broadcast_object_list_wrapper(
    accelerator: Any, payload: Any, *, src: int = 0
) -> Any:
    """Call the adapter's ``broadcast_object_list`` after re-syncing its globals.

    :param accelerator: Forwarded unchanged to the adapter.
    :param payload: Forwarded unchanged to the adapter.
    :param src: Keyword-only value forwarded as the adapter's ``src`` argument.
    :returns: Whatever the adapter function returns.
    """
    # Sync first so the adapter sees this module's current bindings.
    _refresh_vllm_globals()
    outcome = _vllm_adapter.broadcast_object_list(accelerator, payload, src=src)
    return outcome


def _scatter_object_wrapper(accelerator: Any, input_list: Any, *, src: int = 0) -> Any:
    """Call the adapter's ``scatter_object`` after re-syncing its globals.

    :param accelerator: Forwarded unchanged to the adapter.
    :param input_list: Forwarded unchanged to the adapter.
    :param src: Keyword-only value forwarded as the adapter's ``src`` argument.
    :returns: Whatever the adapter function returns.
    """
    # Sync first so the adapter sees this module's current bindings.
    _refresh_vllm_globals()
    scattered = _vllm_adapter.scatter_object(accelerator, input_list, src=src)
    return scattered


# Expose wrapper functions that honor patched globals: each alias below points
# at a wrapper that calls _refresh_vllm_globals() before delegating to the
# function of the same name (minus the leading underscore) on the vLLM adapter.
_broadcast_object_list = _broadcast_object_list_wrapper
_gather_object_list = _gather_object_list_wrapper
_scatter_object = _scatter_object_wrapper


def _import_vllm_client_cls(import_fn: Callable[[str], Any] | None = None) -> Any:
    """Look up the TRL ``VLLMClient`` class, optionally via an import hook.

    With a hook, the lookup is delegated to the adapter's importer using that
    hook. Without one, nothing is imported: only an already-loaded
    ``trl.extras.vllm_client`` entry in ``sys.modules`` is consulted.

    :param import_fn: Optional-import callable taking a module name, or ``None``.
    :returns: The adapter importer's result when a hook is given; otherwise the
        loaded module's ``VLLMClient`` attribute, or ``None`` if the module is
        not loaded or has no such attribute.
    """
    if import_fn is not None:
        # A falsy non-None hook still falls back to the module-level importer.
        hook = import_fn or globals().get("_optional_import") or _optional_import
        return _adapter_import_vllm_client_cls(hook)

    loaded_module = sys.modules.get("trl.extras.vllm_client")
    return None if loaded_module is None else getattr(loaded_module, "VLLMClient", None)


class CompletionGenerator(LocalGenerationMixin, VLLMGenerationMixin):
    """Stateful helper that handles both local HF and vLLM completions."""

    def __init__(self, ctx: GenerationContext) -> None:
        """Initialise both generation mixins with the same context.

        :param ctx: Generation context passed unchanged to each mixin initialiser.
        """
        # Each mixin's __init__ is called explicitly (not through super()), in
        # base-class order, so both receive ctx.
        LocalGenerationMixin.__init__(self, ctx)
        VLLMGenerationMixin.__init__(self, ctx)
# Names re-exported by this module, including underscore-prefixed helpers.
__all__ = [
    "VLLMGenerationHelper",
    "CompletionGenerator",
    "GenerationContext",
    "PROMPT_CHAR_LIMIT",
    "_truncate_prompt",
    "require_torch",
    "safe_generate",
    "torch",
    "time",
    "dist",
    "_AggregatedGenerationState",
    "_append_completion_group",
    "_determine_retry_limit",
    "_pending_generation_indices",
    "_retry_incomplete_prompts",
    "_retry_incomplete_prompts_impl",
    "_seed_generation_groups_impl",
    "_VLLMGenerationState",
    "_broadcast_object_list",
    "_gather_object_list",
    "_import_vllm_client_cls",
    "_is_peft_model_safe",
    "_optional_import",
    "_scatter_object",
    "_zero3_gather_factory",
]