"""Public CompletionGenerator that wires local and vLLM helpers together."""
from __future__ import annotations
# pylint: disable=broad-exception-caught
import time
import logging
from typing import Any, Dict, List, Optional, Tuple
from maxent_grpo.training.patches.vllm import VLLMLogprobResult, safe_generate
from .context import GenerationContext
from .distributed import _scatter_object
from .local import LocalGenerationMixin
from .vllm_adapter import (
VLLMGenerationHelper,
VLLMGenerationMixin,
_is_peft_model_safe,
)
LOG = logging.getLogger(__name__)
class CompletionGenerator(LocalGenerationMixin, VLLMGenerationMixin):
    """Stateful helper that handles both local HF and vLLM completions.

    Construction always runs the local mixin initialiser, attempts the vLLM
    mixin initialiser when the context exposes an ``accelerator`` attribute,
    and then guarantees that ``self._vllm_helper`` exists with this module's
    patchable hooks attached to it.
    """

    def __init__(self, ctx: GenerationContext) -> None:
        """Initialise the local and vLLM generation paths from ``ctx``.

        Args:
            ctx: Generation context handed to both mixins and to the vLLM
                helper.
        """
        LocalGenerationMixin.__init__(self, ctx)

        if hasattr(ctx, "accelerator"):
            try:
                VLLMGenerationMixin.__init__(self, ctx)
            except (ImportError, RuntimeError, AttributeError, ValueError) as exc:
                # Best effort: a helper is built directly below, so this is
                # not fatal -- but leave a trace instead of failing silently.
                LOG.debug("vLLM mixin initialisation failed: %s", exc)
                self._vllm_helper = None
        else:
            self._vllm_helper = None

        helper = getattr(self, "_vllm_helper", None)
        if helper is None:
            # A bare global name is resolved at call time, so a replacement
            # installed on this module's ``VLLMGenerationHelper`` attribute
            # is picked up here.
            helper_cls = VLLMGenerationHelper
            self._vllm_helper = helper_cls(ctx, self._generate_local)
            helper = self._vllm_helper

        # Surface patchable hooks for tests so monkeypatched helpers.* propagate.
        helper._safe_generate = safe_generate
        helper._scatter_object = _scatter_object
        helper._time = time
        helper._is_peft_model_safe = _is_peft_model_safe
        helper._fallback_generate = self._generate_local

        if hasattr(self, "_configure_vllm_mode"):
            try:
                self._configure_vllm_mode()
            except Exception as exc:  # pragma: no cover - defensive
                LOG.debug("vLLM mode configuration failed: %s", exc)

    def describe(self) -> Dict[str, Any]:
        """Expose the underlying generation configuration for logging.

        Returns:
            The result of ``self.ctx.as_dict()``.
        """
        return self.ctx.as_dict()

    def generate(
        self,
        prompts: List[str],
        num_samples: int,
        per_prompt_counts: Optional[List[int]] = None,
    ) -> Tuple[List[List[str]], Optional[List[List[Optional[VLLMLogprobResult]]]]]:
        """Produce completions, preferring vLLM when configured.

        Args:
            prompts: Prompt strings to complete.
            num_samples: Sample count forwarded to the selected backend.
            per_prompt_counts: Optional per-prompt values forwarded to the
                selected backend; when given, its length must equal
                ``len(prompts)``.

        Returns:
            The result of ``_generate_vllm_collective`` when the context's
            ``use_vllm`` flag is truthy. Otherwise ``([], None)`` for an
            empty prompt list, or the result of ``_generate_local``.

        Raises:
            ValueError: If ``per_prompt_counts`` is given and its length
                differs from ``len(prompts)``.
        """
        if per_prompt_counts is not None and len(per_prompt_counts) != len(prompts):
            raise ValueError(
                "per_prompt_counts length must match prompts length in generate()"
            )

        # Read the flag once so the value that is logged is the value that
        # selects the backend; a context lacking the attribute is treated as
        # "vLLM disabled" rather than raising AttributeError after the log.
        use_vllm = getattr(self.ctx, "use_vllm", False)
        if per_prompt_counts is None:
            counts_desc = "none"
        else:
            counts_desc = f"len={len(per_prompt_counts)}"
        LOG.debug(
            "CompletionGenerator.generate | prompts=%d | num_samples=%d | "
            "use_vllm=%s | per_prompt_counts=%s",
            len(prompts),
            num_samples,
            use_vllm,
            counts_desc,
        )

        if use_vllm:
            # The vLLM path is entered even when ``prompts`` is empty; the
            # empty-input shortcut below applies to the local path only.
            # NOTE(review): presumably every rank must join the collective
            # call -- confirm against ``_generate_vllm_collective``.
            return self._generate_vllm_collective(
                prompts, num_samples, per_prompt_counts
            )

        if not prompts:
            return [], None

        LOG.debug("CompletionGenerator.generate using local HF path")
        return self._generate_local(prompts, num_samples, per_prompt_counts)
# Names exported by a star-import of this module. Alongside the generator
# class and its context type, this lists the hook objects that
# ``CompletionGenerator.__init__`` attaches to the vLLM helper
# (``safe_generate``, ``_scatter_object``, ``_is_peft_model_safe``, ``time``).
__all__ = [
    "CompletionGenerator",
    "GenerationContext",
    "safe_generate",
    "_scatter_object",
    "_is_peft_model_safe",
    "time",
]