# Source code for maxent_grpo.training.rollout.local

"""Local HF generation helpers split from the vLLM adapter."""

from __future__ import annotations
# pylint: disable=broad-exception-caught

from contextlib import nullcontext
import logging
import os
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, cast

from maxent_grpo.training.generation.vocab_guard import resolve_blocked_token_ids
from maxent_grpo.training.runtime import require_torch, require_transformer_base_classes
from maxent_grpo.training.runtime.prompts import PROMPT_CHAR_LIMIT, _truncate_prompt

from .context import GenerationContext

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Resolve torch through the project's runtime helper rather than a direct
# import. NOTE(review): the "generation" argument looks like a context label
# used by the helper (e.g. for error messages) -- confirm against
# ``maxent_grpo.training.runtime``.
torch = require_torch("generation")
try:
    PreTrainedModel, PreTrainedTokenizer = require_transformer_base_classes(
        "generation"
    )
except (
    ImportError,
    RuntimeError,
    ModuleNotFoundError,
):  # pragma: no cover - stub fallback
    # If the transformers base classes cannot be resolved, degrade both names
    # to ``Any`` so the module still imports.
    PreTrainedModel = Any
    PreTrainedTokenizer = Any

if TYPE_CHECKING:
    # Real types are imported only for static analysis.
    import torch as torch_types
    from transformers.tokenization_utils import (
        PreTrainedTokenizer as PreTrainedTokenizerType,
    )

    Tensor = torch_types.Tensor
else:
    # At runtime the typing aliases are plain ``Any``.
    PreTrainedTokenizerType = Any
    Tensor = Any


def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret the environment variable ``name`` as a boolean switch.

    :param name: Environment variable to look up.
    :param default: Value returned when the variable is not set at all.
    :returns: ``default`` if the variable is unset; otherwise ``True`` only
        when the value, stripped and lower-cased, is one of ``1``, ``true``,
        ``yes``, ``y`` or ``on``.
    """
    raw = os.environ.get(name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized in ("1", "true", "yes", "y", "on")
    return default


class LocalGenerationMixin:
    """Handle prompt expansion, tokenization, and local HF sampling."""

    ctx: GenerationContext

    def __init__(self, ctx: GenerationContext) -> None:
        """Store the generation context used by every helper on this mixin."""
        self.ctx = ctx

    def describe(self) -> dict[str, Any]:
        """Expose the underlying generation configuration for logging."""
        return self.ctx.as_dict()

    def _prompt_char_limit(self) -> int:
        """Return the character limit applied to prompts before generation.

        Two candidates are combined: ``PROMPT_CHAR_LIMIT`` (looked up on the
        ``rollout.helpers`` module first, falling back to the name imported
        by this module) and ``ctx.max_prompt_len * 4``.

        :returns: ``0`` when neither candidate is positive, the only positive
            candidate when exactly one is, otherwise the larger of the two.
        """
        try:
            # Resolve the constant through the helpers module at call time so
            # a value set on that module is honoured.
            helpers_mod = __import__(
                "maxent_grpo.training.rollout.helpers",
                fromlist=["PROMPT_CHAR_LIMIT"],
            )
            limit_base = getattr(helpers_mod, "PROMPT_CHAR_LIMIT", PROMPT_CHAR_LIMIT)
        except ImportError:
            limit_base = PROMPT_CHAR_LIMIT
        # Non-integer values are treated as "no limit configured".
        limit_env = int(limit_base) if isinstance(limit_base, int) else 0
        max_prompt_len = self.ctx.max_prompt_len
        approx_limit = (
            int(max_prompt_len) * 4 if max_prompt_len and max_prompt_len > 0 else 0
        )
        if limit_env <= 0 and approx_limit <= 0:
            return 0
        if limit_env <= 0:
            return approx_limit
        if approx_limit <= 0:
            return limit_env
        return max(limit_env, approx_limit)

    def _build_local_prompt_requests(
        self,
        prompts: List[str],
        target_counts: List[int],
    ) -> Tuple[List[str], List[int]]:
        """Expand prompts by their requested counts for local sampling.

        :param prompts: Prompt strings.
        :param target_counts: Number of completions wanted per prompt; values
            that are zero or negative drop the prompt from the expansion.
        :returns: ``(expanded_prompts, prompt_indices)`` where each prompt is
            repeated ``count`` times and ``prompt_indices[i]`` is the index in
            ``prompts`` that ``expanded_prompts[i]`` came from.
        """
        expanded_prompts: List[str] = []
        prompt_indices: List[int] = []
        for idx, (prompt, target_count) in enumerate(zip(prompts, target_counts)):
            adjusted_target = max(0, int(target_count))
            if adjusted_target <= 0:
                continue
            expanded_prompts.extend([prompt] * adjusted_target)
            prompt_indices.extend([idx] * adjusted_target)
        return expanded_prompts, prompt_indices

    def _tokenize_expanded_prompts(
        self,
        expanded_prompts: List[str],
    ) -> Tuple[Any, List[int]]:
        """Tokenize prompts for local generation and track prompt lengths.

        :param expanded_prompts: Prompts already repeated per requested count.
        :returns: ``(encoder_inputs, prompt_lengths)``. With a callable
            tokenizer, lengths are the per-row sums of ``attention_mask``.
            With a non-callable tokenizer, a dict-like stub is returned and
            lengths are the character lengths of the prompts.
        """
        tokenizer = self.ctx.tokenizer
        if callable(tokenizer):
            try:
                encoder_inputs = tokenizer(
                    expanded_prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=self.ctx.max_prompt_len,
                )
            except TypeError:
                # Simplified tokenizers may not accept the HF keyword set.
                encoder_inputs = tokenizer(expanded_prompts)
            if hasattr(encoder_inputs, "to"):
                encoder_inputs = encoder_inputs.to(self.ctx.device)
            mask = cast(Any, encoder_inputs["attention_mask"])
            # NOTE(review): these are counts of attended tokens per row. When
            # prompts of unequal length are padded in one batch this may differ
            # from the padded input width -- confirm that slicing generated
            # rows at this offset in ``_decode_sequences`` is intended.
            prompt_lengths = mask.sum(dim=1).detach().cpu().tolist()
            return encoder_inputs, prompt_lengths

        # Fallback for lightweight stubs that only provide ``decode``.
        lengths = [len(p) for p in expanded_prompts]

        class _Mask:
            """Minimal stand-in exposing the tensor call chain used on masks."""

            def __init__(self, vals: List[int]) -> None:
                self._vals = vals

            def sum(self, _dim: int = 1) -> "_Mask":
                return self

            def detach(self) -> "_Mask":
                return self

            def cpu(self) -> "_Mask":
                return self

            def tolist(self) -> List[int]:
                return list(self._vals)

        class _Inputs(dict):
            """Dict carrying only ``attention_mask`` with a no-op ``to``."""

            def __init__(self, lens: List[int]) -> None:
                super().__init__(attention_mask=_Mask(lens))

            def to(self, _device: Any) -> "_Inputs":
                return self

        return _Inputs(lengths), lengths

    def _run_local_model(
        self,
        encoder_inputs: Any,
        prompt_lengths: List[int],
    ) -> List[str]:
        """Run the HF model locally and decode completions.

        Behaviour can be tuned through environment variables:

        * ``MAXENT_LOCAL_SYNCED_GPUS`` -- value passed as ``synced_gpus``.
        * ``MAXENT_LOCAL_DISABLE_DYNAMO`` -- wrap generation in
          ``torch._dynamo.disable()`` when available; defaults to on when
          ``torch.distributed`` is initialized.
        * ``MAXENT_LOCAL_MAX_NEW_TOKENS`` -- positive integer cap applied on
          top of ``ctx.max_completion_len``.
        * ``MAXENT_LOCAL_MAX_TIME_S`` -- positive float passed as ``max_time``.
        * ``MAXENT_LOCAL_EMPTY_CACHE`` -- call ``torch.cuda.empty_cache()``
          before generating.

        :param encoder_inputs: Mapping of tokenized inputs, unpacked as
            keyword arguments into ``generate``.
        :param prompt_lengths: Per-row offsets used to strip the prompt from
            each generated row.
        :returns: One decoded completion string per generated row.
        :raises TypeError: Re-raised from ``generate`` when the message does
            not mention ``synced_gpus`` or ``max_time``.
        """
        unwrap = getattr(self.ctx.accelerator, "unwrap_model", None)
        gen_model = unwrap(self.ctx.model) if callable(unwrap) else self.ctx.model
        no_grad = getattr(torch, "no_grad", None) or nullcontext
        dist = getattr(torch, "distributed", None)
        dist_initialized = bool(
            dist
            and hasattr(dist, "is_available")
            and hasattr(dist, "is_initialized")
            and dist.is_available()
            and dist.is_initialized()
        )
        synced_gpus = _env_flag("MAXENT_LOCAL_SYNCED_GPUS", False)
        disable_dynamo = _env_flag("MAXENT_LOCAL_DISABLE_DYNAMO", dist_initialized)

        max_new_tokens = self.ctx.max_completion_len
        max_time: Optional[float] = None
        env_max_new = os.getenv("MAXENT_LOCAL_MAX_NEW_TOKENS")
        if env_max_new is not None:
            try:
                cap_val = int(env_max_new)
                if cap_val > 0:
                    # The env value can only lower the configured budget.
                    max_new_tokens = min(int(max_new_tokens), cap_val)
            except (TypeError, ValueError):
                LOG.warning(
                    "Invalid MAXENT_LOCAL_MAX_NEW_TOKENS=%r; using max_new_tokens=%s",
                    env_max_new,
                    max_new_tokens,
                )
        env_max_time = os.getenv("MAXENT_LOCAL_MAX_TIME_S")
        if env_max_time is not None:
            try:
                max_time_val = float(env_max_time)
                if max_time_val > 0:
                    max_time = max_time_val
            except (TypeError, ValueError):
                LOG.warning(
                    "Invalid MAXENT_LOCAL_MAX_TIME_S=%r; ignoring max_time override.",
                    env_max_time,
                )

        empty_cache = _env_flag("MAXENT_LOCAL_EMPTY_CACHE", False)
        if empty_cache and hasattr(torch, "cuda") and torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception as exc:
                # Best-effort cleanup: never fail generation over it, but
                # leave a trace instead of swallowing silently.
                LOG.debug("torch.cuda.empty_cache() failed: %s", exc)

        dynamo_ctx = nullcontext()
        if disable_dynamo:
            dynamo = getattr(torch, "_dynamo", None)
            disable_fn = (
                getattr(dynamo, "disable", None) if dynamo is not None else None
            )
            if callable(disable_fn):
                dynamo_ctx = disable_fn()

        LOG.debug(
            "HF generate start | model=%s | max_new_tokens=%s | max_time=%s | temp=%.3f | top_p=%.3f | top_k=%s | synced_gpus=%s | disable_dynamo=%s | empty_cache=%s",
            gen_model.__class__.__name__ if gen_model is not None else "None",
            max_new_tokens,
            max_time,
            self.ctx.gen_temperature,
            self.ctx.gen_top_p,
            self.ctx.gen_top_k,
            synced_gpus,
            disable_dynamo,
            empty_cache,
        )

        with no_grad(), dynamo_ctx:
            generate_fn = getattr(gen_model, "generate", None)
            if callable(generate_fn):
                gen_cfg = getattr(gen_model, "generation_config", None)
                if gen_cfg is not None and hasattr(gen_cfg, "synced_gpus"):
                    try:
                        setattr(gen_cfg, "synced_gpus", bool(synced_gpus))
                    except Exception as exc:
                        # Best-effort: the flag is also passed as a kwarg below.
                        LOG.debug(
                            "Could not set generation_config.synced_gpus: %s", exc
                        )
                generate_kwargs = dict(
                    do_sample=True,
                    temperature=self.ctx.gen_temperature,
                    top_p=self.ctx.gen_top_p,
                    top_k=self.ctx.gen_top_k,
                    max_new_tokens=max_new_tokens,
                    num_return_sequences=1,
                    synced_gpus=synced_gpus,
                )
                blocked_token_ids = resolve_blocked_token_ids(self.ctx)
                if blocked_token_ids:
                    # Each blocked id becomes its own single-token sequence.
                    generate_kwargs["bad_words_ids"] = [
                        [int(token_id)] for token_id in blocked_token_ids
                    ]
                if max_time is not None:
                    generate_kwargs["max_time"] = max_time
                try:
                    gen_out = generate_fn(**encoder_inputs, **generate_kwargs)
                except TypeError as exc:
                    # Retry once without the optional kwargs named in the
                    # error message; anything else propagates.
                    msg = str(exc)
                    retry = False
                    if "synced_gpus" in msg:
                        generate_kwargs.pop("synced_gpus", None)
                        retry = True
                    if "max_time" in msg:
                        generate_kwargs.pop("max_time", None)
                        retry = True
                    if retry:
                        gen_out = generate_fn(**encoder_inputs, **generate_kwargs)
                    else:
                        raise
            else:
                # Fallback for lightweight stubs without generation support.
                gen_out = encoder_inputs

        gen_out_any = cast(Any, gen_out)
        return self._decode_sequences(gen_out_any, prompt_lengths, self.ctx.tokenizer)

    def _generate_local(
        self,
        prompts: List[str],
        num_samples: int,
        per_prompt_counts: Optional[List[int]] = None,
    ) -> Tuple[List[List[str]], Optional[List[List[Optional[Any]]]]]:
        """Generate completions using the local HF model.

        :param prompts: Prompt strings; each is truncated to the character
            limit from ``_prompt_char_limit`` before tokenization.
        :param num_samples: Completions per prompt when no override is given.
        :param per_prompt_counts: Optional per-prompt completion counts; must
            have the same length as ``prompts``.
        :returns: ``(grouped, None)`` where ``grouped[i]`` holds the decoded
            completions for ``prompts[i]``. The second element is always
            ``None`` on this path.
        :raises ValueError: If ``per_prompt_counts`` length differs from
            ``prompts`` length.
        """
        try:
            helpers_mod = __import__(
                "maxent_grpo.training.rollout.helpers", fromlist=["_truncate_prompt"]
            )
            trunc_fn = getattr(helpers_mod, "_truncate_prompt", _truncate_prompt)
        except ImportError:
            trunc_fn = _truncate_prompt

        grouped: List[List[str]] = [[] for _ in prompts]
        if not prompts:
            return grouped, None

        char_limit = self._prompt_char_limit()
        prompts = [trunc_fn(prompt, char_limit) for prompt in prompts]
        target_counts = self._resolve_local_counts(
            prompts, num_samples, per_prompt_counts
        )
        LOG.debug(
            "Local generation | prompts=%d | num_samples=%d | char_limit=%d | per_prompt_counts=%s",
            len(prompts),
            num_samples,
            char_limit,
            # Report the caller-supplied override; the resolved counts are
            # never None, so checking them could never print "none".
            f"len={len(per_prompt_counts)}"
            if per_prompt_counts is not None
            else "none",
        )
        expanded_prompts, prompt_indices = self._build_local_prompt_requests(
            prompts,
            target_counts,
        )
        if not expanded_prompts:
            return grouped, None

        enc_inputs, prompt_lengths = self._tokenize_expanded_prompts(expanded_prompts)
        LOG.debug(
            "Local generation tokenize | expanded_prompts=%d | prompt_indices=%d | prompt_lengths_sample=%s",
            len(expanded_prompts),
            len(prompt_indices),
            prompt_lengths[: min(3, len(prompt_lengths))],
        )
        decoded = self._run_local_model(enc_inputs, prompt_lengths)
        for text, prompt_idx in zip(decoded, prompt_indices):
            grouped[prompt_idx].append(text)
        # Logged after grouping so ``first_prompt_count`` reflects the actual
        # number of completions collected for the first prompt.
        LOG.debug(
            "Local generation decode done | decoded=%d | first_prompt_count=%d",
            len(decoded),
            len(grouped[0]) if grouped else 0,
        )
        return grouped, None

    @staticmethod
    def _resolve_local_counts(
        prompts: List[str],
        default_count: int,
        overrides: Optional[List[int]],
    ) -> List[int]:
        """Resolve per-prompt generation counts for local sampling.

        :param prompts: Prompt strings (only the length is used).
        :param default_count: Count applied to every prompt when ``overrides``
            is ``None``.
        :param overrides: Optional explicit per-prompt counts.
        :returns: ``overrides`` itself when given, otherwise a new list with
            ``default_count`` repeated once per prompt.
        :raises ValueError: If ``overrides`` length differs from ``prompts``.
        """
        if overrides is None:
            return [default_count] * len(prompts)
        if len(overrides) != len(prompts):
            raise ValueError("per_prompt_counts length must match prompts length")
        return overrides

    @staticmethod
    def _decode_sequences(
        sequences: Any,
        prompt_lengths: List[int],
        tokenizer: PreTrainedTokenizerType,
    ) -> List[str]:
        """Decode model outputs into completion strings.

        :param sequences: Iterable of generated rows, paired positionally with
            ``prompt_lengths``.
        :param prompt_lengths: Offset at which each row's completion starts.
        :param tokenizer: Object whose ``decode`` turns ids into text. When
            decoding raises ``AttributeError`` the ids are stringified.
        :returns: One completion string per ``(row, prompt_length)`` pair.
        """
        outputs: List[str] = []
        for row, prompt_len in zip(sequences, prompt_lengths):
            completion_ids = row[int(prompt_len) :]
            try:
                outputs.append(
                    tokenizer.decode(completion_ids, skip_special_tokens=True)
                )
            except AttributeError:
                # Minimal tokenizer fallback: stringify the ids.
                try:
                    outputs.append(" ".join(str(int(tok)) for tok in completion_ids))
                except (TypeError, ValueError):
                    outputs.append(str(completion_ids))
        return outputs
# Names exported by ``from <module> import *``.
__all__ = ["LocalGenerationMixin"]