Source code for maxent_grpo.training.generation.vocab_guard

"""Shared helpers for masking model-only token IDs during generation."""

from __future__ import annotations

import logging
import os
from typing import Any, Dict, List, Optional, cast

from maxent_grpo.training.scoring_common import (
    _coerce_optional_int,
    _get_config_value,
    _get_embedding_vocab_size,
)

LOG = logging.getLogger(__name__)
_INVALID_TOKEN_BLOCK_BIAS = -1.0e9


def _resolve_served_model_id(ctx: Any) -> Optional[str]:
    """Best-effort resolution of the external vLLM-served model identifier."""

    env_model = os.getenv("MAXENT_VLLM_SERVER_MODEL_NAME")
    if isinstance(env_model, str) and env_model.strip():
        return env_model.strip()

    for key in (
        "vllm_model_id",
        "served_model_id",
        "model_name",
        "model_id",
        "hub_model_id",
        "model_name_or_path",
    ):
        value = getattr(ctx, key, None)
        if isinstance(value, str) and value.strip():
            return value.strip()

    training_args = getattr(ctx, "training_args", None)
    if training_args is not None:
        for key in ("model_name_or_path", "hub_model_id", "model_id"):
            value = getattr(training_args, key, None)
            if isinstance(value, str) and value.strip():
                return value.strip()
    return None


def _resolve_served_model_vocab_limit(ctx: Any) -> Optional[int]:
    """Return the output-vocab width exposed by the external vLLM model."""

    cached = getattr(ctx, "_served_model_vocab_limit", None)
    if isinstance(cached, int) and cached > 0:
        return int(cached)

    env_limit = _coerce_optional_int(os.getenv("MAXENT_VLLM_SERVER_MODEL_VOCAB_LIMIT"))
    if isinstance(env_limit, int) and env_limit > 0:
        setattr(ctx, "_served_model_vocab_limit", int(env_limit))
        return int(env_limit)

    model_id = _resolve_served_model_id(ctx)
    if not model_id:
        return None

    try:
        from transformers import AutoConfig

        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    except Exception as exc:  # pragma: no cover - defensive
        LOG.debug("Unable to resolve served model vocab limit for %s: %s", model_id, exc)
        return None

    vocab_limit = _coerce_optional_int(_get_config_value(config, "vocab_size", None))
    if not isinstance(vocab_limit, int) or vocab_limit <= 0:
        return None

    setattr(ctx, "_served_model_vocab_limit", int(vocab_limit))
    if not bool(getattr(ctx, "_served_model_vocab_limit_logged", False)):
        LOG.warning(
            "Resolved served-model vocab limit=%d for server-mode vLLM generation (model=%s).",
            int(vocab_limit),
            model_id,
        )
        setattr(ctx, "_served_model_vocab_limit_logged", True)
    return int(vocab_limit)


[docs] def resolve_tokenizer_vocab_limit(tokenizer: Any) -> Optional[int]: """Return the maximum token id addressable by the tokenizer plus one.""" if tokenizer is None: return None candidates: List[int] = [] for attr in ("vocab_size",): value = _coerce_optional_int(getattr(tokenizer, attr, None)) if value is not None and value > 0: candidates.append(int(value)) try: tokenizer_len = len(tokenizer) except Exception: tokenizer_len = None if isinstance(tokenizer_len, int) and tokenizer_len > 0: candidates.append(int(tokenizer_len)) if not candidates: return None return max(candidates)
[docs] def resolve_model_vocab_limit(ctx: Any) -> Optional[int]: """Return the model output-vocab width exposed to generation.""" model = getattr(ctx, "model", None) accelerator = getattr(ctx, "accelerator", None) unwrap_fn = getattr(accelerator, "unwrap_model", None) base_model = model if callable(unwrap_fn): try: base_model = unwrap_fn(model) except Exception: base_model = model if base_model is None: return None config = getattr(base_model, "config", None) embedding_vocab = _get_embedding_vocab_size(base_model, config) config_vocab = _coerce_optional_int(_get_config_value(config, "vocab_size", None)) candidates = [ int(value) for value in (embedding_vocab, config_vocab) if isinstance(value, int) and int(value) > 0 ] use_vllm = bool(getattr(ctx, "use_vllm", False)) vllm_mode = str(getattr(ctx, "vllm_mode", "server") or "server").strip().lower() if use_vllm and vllm_mode == "server": served_model_vocab = _resolve_served_model_vocab_limit(ctx) if isinstance(served_model_vocab, int) and served_model_vocab > 0: candidates.append(int(served_model_vocab)) if not candidates: return None return max(candidates)
[docs] def merge_invalid_token_block_logit_bias( ctx: Any, existing_bias: Any, ) -> Optional[Dict[str, float]]: """Block model-only token IDs that the tokenizer cannot represent.""" tokenizer = getattr(ctx, "tokenizer", None) tokenizer_limit = resolve_tokenizer_vocab_limit(tokenizer) model_limit = resolve_model_vocab_limit(ctx) if ( not isinstance(tokenizer_limit, int) or tokenizer_limit <= 0 or not isinstance(model_limit, int) or model_limit <= tokenizer_limit ): return cast(Optional[Dict[str, float]], existing_bias) merged: Dict[str, float] = {} if isinstance(existing_bias, dict): for key, value in existing_bias.items(): try: merged[str(int(key))] = float(value) except (TypeError, ValueError): continue blocked = 0 for token_id in range(int(tokenizer_limit), int(model_limit)): key = str(int(token_id)) prev = merged.get(key) if prev is None or prev > _INVALID_TOKEN_BLOCK_BIAS: merged[key] = _INVALID_TOKEN_BLOCK_BIAS blocked += 1 if not bool(getattr(ctx, "_vllm_invalid_token_block_logged", False)): LOG.warning( "Blocking %d tokenizer-inaccessible token IDs for vLLM generation (tokenizer_limit=%d, model_limit=%d).", blocked, tokenizer_limit, model_limit, ) setattr(ctx, "_vllm_invalid_token_block_logged", True) stats = getattr(ctx, "generation_stats", None) if isinstance(stats, dict): stats["vllm_invalid_token_block_count"] = blocked stats["vllm_invalid_token_block_min_id"] = int(tokenizer_limit) stats["vllm_invalid_token_block_max_id"] = int(model_limit - 1) return merged
[docs] def resolve_allowed_token_ids(ctx: Any) -> Optional[List[int]]: """Return a cached hard allowlist for tokenizer-addressable token IDs.""" tokenizer = getattr(ctx, "tokenizer", None) tokenizer_limit = resolve_tokenizer_vocab_limit(tokenizer) model_limit = resolve_model_vocab_limit(ctx) if ( not isinstance(tokenizer_limit, int) or tokenizer_limit <= 0 or not isinstance(model_limit, int) or model_limit <= tokenizer_limit ): return None cached = getattr(ctx, "_vllm_allowed_token_ids", None) cached_limit = getattr(ctx, "_vllm_allowed_token_ids_limit", None) if isinstance(cached, list) and cached_limit == int(tokenizer_limit): return cached allowed = list(range(int(tokenizer_limit))) setattr(ctx, "_vllm_allowed_token_ids", allowed) setattr(ctx, "_vllm_allowed_token_ids_limit", int(tokenizer_limit)) if not bool(getattr(ctx, "_vllm_allowed_token_ids_logged", False)): LOG.warning( "Allowing only %d tokenizer-addressable token IDs for vLLM generation (tokenizer_limit=%d, model_limit=%d).", tokenizer_limit, tokenizer_limit, model_limit, ) setattr(ctx, "_vllm_allowed_token_ids_logged", True) stats = getattr(ctx, "generation_stats", None) if isinstance(stats, dict): stats["vllm_allowed_token_ids_count"] = int(tokenizer_limit) return allowed
[docs] def resolve_blocked_token_ids(ctx: Any) -> List[int]: """Return tokenizer-inaccessible model token IDs for local generation guards.""" tokenizer = getattr(ctx, "tokenizer", None) tokenizer_limit = resolve_tokenizer_vocab_limit(tokenizer) model_limit = resolve_model_vocab_limit(ctx) if ( not isinstance(tokenizer_limit, int) or tokenizer_limit <= 0 or not isinstance(model_limit, int) or model_limit <= tokenizer_limit ): return [] cached = getattr(ctx, "_local_blocked_token_ids", None) cached_limits = getattr(ctx, "_local_blocked_token_ids_limits", None) if ( isinstance(cached, list) and isinstance(cached_limits, tuple) and cached_limits == (int(tokenizer_limit), int(model_limit)) ): return cached blocked = list(range(int(tokenizer_limit), int(model_limit))) setattr(ctx, "_local_blocked_token_ids", blocked) setattr( ctx, "_local_blocked_token_ids_limits", (int(tokenizer_limit), int(model_limit)), ) if not bool(getattr(ctx, "_local_invalid_token_block_logged", False)): LOG.warning( "Blocking %d tokenizer-inaccessible token IDs for local generation (tokenizer_limit=%d, model_limit=%d).", len(blocked), tokenizer_limit, model_limit, ) setattr(ctx, "_local_invalid_token_block_logged", True) stats = getattr(ctx, "generation_stats", None) if isinstance(stats, dict): stats["local_invalid_token_block_count"] = len(blocked) return blocked