Source code for maxent_grpo.training.runtime.prompts

"""Prompt-related helpers and sampling penalties."""

from __future__ import annotations
# pylint: disable=broad-exception-caught

import logging
import os
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Protocol,
    TYPE_CHECKING,
    Union,
    cast,
)

from maxent_grpo.prompt_templates import (
    normalize_prompt_template,
    render_prompt_template,
)

if TYPE_CHECKING:  # pragma: no cover - type checking only
    from transformers.tokenization_utils import PreTrainedTokenizer

# Module-level logger named after this module.
LOG = logging.getLogger(__name__)
# Default prompt length cap, resolved once at import time. MAX_PROMPT_TOKENS
# takes precedence; MAX_PROMPT_CHARS is consulted only when it is unset, and
# "2048" is used when neither variable is set. A value that int() cannot parse
# raises ValueError during import.
PROMPT_CHAR_LIMIT = int(
    os.environ.get("MAX_PROMPT_TOKENS", os.environ.get("MAX_PROMPT_CHARS", "2048"))
)
# Default suffix strings; both are empty.
DEFAULT_PROMPT_SUFFIX = ""
DEFAULT_EVAL_PROMPT_SUFFIX = ""
# Mutable cache of "already warned" flags so truncation warnings are logged only
# once. truncate_prompt also stores "warned_tokens" and "warned_tokens_error"
# here, and sync_trunc_state merges externally supplied flags into it.
_TRUNC_STATE = {"warned": False}


class ChatTokenizer(Protocol):
    """Structural type for tokenizers that can render chat conversations.

    Any object providing these members satisfies the protocol. The bodies
    below only raise ``NotImplementedError``; concrete tokenizers supply the
    real behaviour.
    """

    def apply_chat_template(
        self,
        conversation: List[Dict[str, str]],
        tokenize: bool = True,
        add_generation_prompt: bool = True,
    ) -> Union[str, List[int]]:
        """Turn a list of chat messages into a model-ready prompt.

        :param conversation: Chat messages, each a mapping of string keys to
            string values.
        :param tokenize: Flag forwarded to the tokenizer implementation.
        :param add_generation_prompt: Flag forwarded to the tokenizer
            implementation.
        :returns: Either a prompt string or a list of token ids.
        :raises NotImplementedError: Always, in this protocol stub.
        """
        raise NotImplementedError

    @property
    def eos_token_id(self) -> Optional[int]:
        """Return the tokenizer's EOS token id, or ``None`` if it has none.

        :raises NotImplementedError: Always, in this protocol stub.
        """
        raise NotImplementedError

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Let chat tokenizers be called like regular HF tokenizers.

        :raises NotImplementedError: Always, in this protocol stub.
        """
        raise NotImplementedError
@dataclass
class GenerationPenaltyConfig:
    """Bundle of penalty and stop-sequence overrides for completion sampling.

    Every field has a neutral default (``None`` or ``0.0``).

    :ivar gen_top_k: Top-k sampling limit, or ``None``.
    :ivar gen_best_of: Best-of sampling count, or ``None``.
    :ivar gen_frequency_penalty: Frequency penalty strength.
    :ivar gen_presence_penalty: Presence penalty strength.
    :ivar gen_stop_sequences: Stop sequences, or ``None``.
    """

    gen_top_k: Optional[int] = None
    gen_best_of: Optional[int] = None
    gen_frequency_penalty: float = 0.0
    gen_presence_penalty: float = 0.0
    gen_stop_sequences: Optional[List[str]] = None
class GenerationPenaltyPassthroughMixin:
    """Mixin that mirrors ``self.penalty`` fields as legacy ``gen_*`` attributes.

    Each property reads from and writes to the attribute of the same name on
    ``self.penalty``; the mixin stores nothing itself.
    """

    penalty: GenerationPenaltyConfig

    # --- top-k -----------------------------------------------------------
    @property
    def gen_top_k(self) -> Optional[int]:
        """Legacy accessor for the top-k sampling limit."""
        return self.penalty.gen_top_k

    @gen_top_k.setter
    def gen_top_k(self, new_value: Optional[int]) -> None:
        self.penalty.gen_top_k = new_value

    # --- best-of ---------------------------------------------------------
    @property
    def gen_best_of(self) -> Optional[int]:
        """Legacy accessor for the best-of sampling count."""
        return self.penalty.gen_best_of

    @gen_best_of.setter
    def gen_best_of(self, new_value: Optional[int]) -> None:
        self.penalty.gen_best_of = new_value

    # --- frequency penalty -----------------------------------------------
    @property
    def gen_frequency_penalty(self) -> float:
        """Legacy accessor for the frequency penalty strength."""
        return self.penalty.gen_frequency_penalty

    @gen_frequency_penalty.setter
    def gen_frequency_penalty(self, new_value: float) -> None:
        self.penalty.gen_frequency_penalty = new_value

    # --- presence penalty ------------------------------------------------
    @property
    def gen_presence_penalty(self) -> float:
        """Legacy accessor for the presence penalty strength."""
        return self.penalty.gen_presence_penalty

    @gen_presence_penalty.setter
    def gen_presence_penalty(self, new_value: float) -> None:
        self.penalty.gen_presence_penalty = new_value

    # --- stop sequences --------------------------------------------------
    @property
    def gen_stop_sequences(self) -> Optional[List[str]]:
        """Legacy accessor for the stop sequences."""
        return self.penalty.gen_stop_sequences

    @gen_stop_sequences.setter
    def gen_stop_sequences(self, new_value: Optional[List[str]]) -> None:
        self.penalty.gen_stop_sequences = new_value
def _token_ids_from_encoding(encoded: Any) -> Optional[List[Any]]:
    """Extract a flat token-id list from tokenizer output, or ``None``."""
    # Function-scope import keeps this block self-contained.
    from collections.abc import Mapping

    # Hugging Face ``BatchEncoding`` derives from ``UserDict`` rather than
    # ``dict``, so a plain ``dict`` check would skip it; ``Mapping`` covers both.
    ids = encoded.get("input_ids") if isinstance(encoded, Mapping) else encoded
    if hasattr(ids, "tolist"):
        # Array/tensor-like ids: convert when possible, otherwise keep as-is.
        try:
            ids = ids.tolist()
        except (TypeError, ValueError, AttributeError, RuntimeError):
            pass
    if isinstance(ids, list) and ids and isinstance(ids[0], list):
        # Batched output: only the first row belongs to this prompt.
        ids = ids[0]
    return ids if isinstance(ids, list) else None


def truncate_prompt(
    prompt: str,
    char_limit: Optional[int] = None,
    *,
    tokenizer: Optional[Any] = None,
    max_tokens: Optional[int] = None,
) -> str:
    """Clamp prompt strings to a safe token length when possible.

    When a tokenizer and a positive limit are available the prompt is measured
    in tokens: a prompt within the limit is returned unchanged, and a longer
    one is cut to the first ``limit`` token ids and decoded back to text. The
    character limit is only a fallback, used when there is no tokenizer, when
    tokenization fails or yields no usable id list, or when the tokenizer has
    no callable ``decode``.

    Each kind of warning is logged once per process, tracked in
    ``_TRUNC_STATE``.

    :param prompt: Prompt string to clamp.
    :param char_limit: Optional character limit fallback. When ``None`` the
        module-level ``PROMPT_CHAR_LIMIT`` is used. Also serves as the token
        limit when ``max_tokens`` is ``None``.
    :param tokenizer: Optional tokenizer used to enforce token limits.
    :param max_tokens: Optional token limit override (preferred when tokenizer
        is available).
    :returns: The original prompt when under the limit, otherwise a truncated
        prefix.
    :rtype: str
    """
    token_limit = max_tokens if max_tokens is not None else char_limit
    if tokenizer is not None and token_limit is not None and token_limit > 0:
        try:
            ids = _token_ids_from_encoding(
                tokenizer(prompt, add_special_tokens=False)
            )
            if ids is not None:
                if len(ids) <= token_limit:
                    # Confirmed within the token budget: do not let the
                    # character fallback clip an acceptable prompt.
                    return prompt
                decode = getattr(tokenizer, "decode", None)
                if callable(decode):
                    decode_fn = cast(Callable[..., str], decode)
                    truncated = decode_fn(  # pylint: disable=not-callable
                        ids[:token_limit], skip_special_tokens=False
                    )
                    if not _TRUNC_STATE.get("warned_tokens", False):
                        LOG.warning(
                            "Prompt length exceeded %d tokens; truncating. "
                            "Override via MAX_PROMPT_TOKENS if needed.",
                            token_limit,
                        )
                        _TRUNC_STATE["warned_tokens"] = True
                    return truncated
        except Exception as exc:
            # Deliberately best-effort: any tokenizer failure degrades to the
            # character-based fallback below instead of aborting the caller.
            if not _TRUNC_STATE.get("warned_tokens_error", False):
                LOG.debug(
                    "Token-based prompt truncation failed; falling back: %s", exc
                )
                _TRUNC_STATE["warned_tokens_error"] = True

    limit = char_limit if char_limit is not None else PROMPT_CHAR_LIMIT
    if limit <= 0 or len(prompt) <= limit:
        return prompt
    if not _TRUNC_STATE.get("warned", False):
        LOG.warning(
            "Prompt length exceeded %d characters; truncating. "
            "Override via MAX_PROMPT_TOKENS (or MAX_PROMPT_CHARS for legacy) if needed.",
            limit,
        )
        _TRUNC_STATE["warned"] = True
    return prompt[:limit]
def sync_trunc_state(state: Dict[str, Any]) -> None:
    """Fold externally tracked truncation flags into the shared warning cache.

    Anything that is not a ``dict`` is ignored without error.

    :param state: Dictionary of state keys to merge (e.g., ``{"warned": True}``).
    :returns: ``None``.
    """
    if not isinstance(state, dict):
        return
    _TRUNC_STATE.update(state)
def _prompt_suffix_from_env(env_var: str, default: str) -> str:
    """Look up a prompt suffix in the environment.

    :param env_var: Name of the environment variable to read.
    :param default: Value returned when the variable is not set. A variable
        set to the empty string is returned as-is, not replaced.
    :returns: The variable's value, or ``default`` when it is unset.
    :rtype: str
    """
    return os.environ.get(env_var, default)
def append_prompt_suffix(prompt: str) -> str:
    """Return ``prompt`` unchanged.

    This is the hook where a format-reminder suffix would be attached to
    prompts; this implementation applies no suffix.

    :param prompt: Prompt text.
    :returns: The same prompt, unmodified.
    :rtype: str
    """
    return prompt
def append_eval_prompt_suffix(prompt: str) -> str:
    """Return ``prompt`` unchanged.

    This is the hook where an eval-only format reminder would be attached to
    prompts; this implementation applies no suffix.

    :param prompt: Prompt text.
    :returns: The same prompt, unmodified.
    :rtype: str
    """
    return prompt
# Backwards compatibility for existing imports.
_truncate_prompt = truncate_prompt


def _prompt_char_limit_from_tokens(max_prompt_len: int) -> int:
    """Return the token cap used for prompt truncation.

    A positive ``max_prompt_len`` yields the larger of that value and
    ``PROMPT_CHAR_LIMIT``; anything else yields ``PROMPT_CHAR_LIMIT``.

    :param max_prompt_len: Maximum number of tokens allowed for prompts.
    :returns: Token limit used by ``truncate_prompt`` when tokenizers are
        available.
    :rtype: int
    """
    has_override = bool(max_prompt_len) and max_prompt_len > 0
    if not has_override:
        return PROMPT_CHAR_LIMIT
    return max(int(max_prompt_len), int(PROMPT_CHAR_LIMIT))


def _require_prompt_column(example: Dict[str, Any], prompt_column: str) -> None:
    """Raise ``KeyError`` when ``prompt_column`` is absent from a dataset row.

    The error message lists the row's columns in sorted order, or
    ``<unknown>`` when they cannot be enumerated.
    """
    if prompt_column not in example:
        try:
            columns = ", ".join(sorted(map(str, example.keys())))
        except (AttributeError, TypeError):
            columns = "<unknown>"
        raise KeyError(
            f"Missing prompt column '{prompt_column}' in dataset row. "
            f"Available columns: {columns}"
        )


def _to_prompt(
    example: Dict[str, Any],
    tokenizer: Union["PreTrainedTokenizer", ChatTokenizer],
    prompt_column: str,
    system_prompt: Optional[str],
    char_limit: Optional[int] = None,
    *,
    return_messages: bool = False,
    prompt_template: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the prompt/answer pair shared by the training pipelines.

    The answer is read from the ``answer`` field, then ``solution``, then
    defaults to an empty string, and is always converted with ``str``.

    :param example: Dataset row containing a prompt and optional answer fields.
    :param tokenizer: Tokenizer or chat template adapter used to render prompts.
    :param prompt_column: Column name to read the user prompt from. When it is
        ``"problem"`` and missing, ``"prompt"`` then ``"question"`` are tried.
    :param system_prompt: Optional system prompt prepended to the conversation.
    :param char_limit: Optional cap applied after formatting; defaults to
        ``PROMPT_CHAR_LIMIT``.
    :param return_messages: When True, return the raw conversation list instead
        of rendering a prompt string (TRL/open-r1 style).
    :param prompt_template: Optional SEED-style prompt template selector. When
        set, the prompt is rendered directly as a string and bypasses tokenizer
        chat templating.
    :returns: Mapping with ``prompt`` and ``answer`` fields. ``prompt`` is a
        string unless ``return_messages=True`` (then it is a list of messages).
    :rtype: dict[str, Any]
    :raises KeyError: If the prompt column is missing from the example.
    """

    def _answer() -> str:
        # Evaluated lazily at each return site, exactly where the value is needed.
        return str(example.get("answer", example.get("solution", "")))

    column = prompt_column
    if prompt_column not in example and prompt_column == "problem":
        fallback = next(
            (name for name in ("prompt", "question") if name in example), None
        )
        if fallback is not None:
            LOG.info(
                "Prompt column '%s' missing; falling back to '%s'.",
                prompt_column,
                fallback,
            )
            column = fallback
    _require_prompt_column(example, column)

    raw_user = example.get(column)
    user_text = str(raw_user) if raw_user is not None else ""

    template_name = normalize_prompt_template(prompt_template, default=None)
    if template_name is not None:
        # Template mode: render straight to text, no chat templating.
        limit = PROMPT_CHAR_LIMIT if char_limit is None else char_limit
        rendered = truncate_prompt(
            render_prompt_template(user_text, template_name),
            limit,
            tokenizer=tokenizer,
            max_tokens=limit,
        )
        return {"prompt": rendered, "answer": _answer()}

    messages = [{"role": "user", "content": user_text}]
    if system_prompt:
        messages.insert(0, {"role": "system", "content": system_prompt})
    if return_messages:
        return {"prompt": messages, "answer": _answer()}

    # Try the tokenizer's chat template; a missing template, a listed error,
    # or a non-string result all lead to the plain-text transcript below.
    try:
        apply_fn = getattr(tokenizer, "apply_chat_template", None)
        templated = (
            apply_fn(messages, tokenize=False, add_generation_prompt=True)
            if callable(apply_fn)
            else None
        )
    except (AttributeError, TypeError, ValueError, RuntimeError):
        templated = None
    if isinstance(templated, str):
        prompt = templated
    else:
        lines = [f"{m['role'].upper()}: {m['content']}" for m in messages]
        prompt = "\n".join(lines) + "\nASSISTANT:"

    limit = PROMPT_CHAR_LIMIT if char_limit is None else char_limit
    # Never cap below what the bare user turn needs in the plain-text layout.
    floor = len("USER: ") + len(user_text) + len("\nASSISTANT:")
    if limit and 0 < limit < floor:
        limit = floor
    prompt = truncate_prompt(
        prompt,
        limit,
        tokenizer=tokenizer,
        max_tokens=limit,
    )
    # Defensive: ensure the user message survives even if truncation or a
    # template removed it entirely.
    if user_text and user_text not in prompt:
        prompt = f"{prompt}\n{user_text}"
    return {"prompt": prompt, "answer": _answer()}


__all__ = [
    "ChatTokenizer",
    "GenerationPenaltyConfig",
    "GenerationPenaltyPassthroughMixin",
    "DEFAULT_PROMPT_SUFFIX",
    "DEFAULT_EVAL_PROMPT_SUFFIX",
    "PROMPT_CHAR_LIMIT",
    "_TRUNC_STATE",
    "_prompt_char_limit_from_tokens",
    "_to_prompt",
    "_truncate_prompt",
    "append_prompt_suffix",
    "append_eval_prompt_suffix",
    "sync_trunc_state",
    "truncate_prompt",
]