Source code for maxent_grpo.training.runtime.config
"""Configuration dataclasses for the training runtime."""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
[docs]
@dataclass
class MaxEntOptions:
"""Lightweight knobs specific to MaxEnt sequence-level updates."""
tau: float = field(default_factory=lambda: float(os.environ.get("MAXENT_TAU", 0.2)))
q_temperature: float = field(
default_factory=lambda: float(os.environ.get("MAXENT_Q_TEMPERATURE", 1.0))
)
q_epsilon: float = field(
default_factory=lambda: float(os.environ.get("MAXENT_Q_EPS", 1e-6))
)
length_normalize_ref: bool = field(
default_factory=lambda: os.environ.get("MAXENT_LENGTH_NORM_REF", "1")
not in {"0", "false", "False"}
)
[docs]
@dataclass
class VLLMClientConfig:
"""Configuration for vLLM-backed completion generation with all exposed knobs."""
url: str
rounds_cfg: int
retry_sleep: float
backfill_local: bool
request_logprobs: bool
best_of: Optional[int] = None
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
include_stop_str_in_output: bool = False
timeout: float = 120.0
max_retries: int = 3
backoff: float = 1.0
backoff_multiplier: float = 2.0
guided_json: Optional[str] = None
guided_regex: Optional[str] = None
logit_bias: Optional[Dict[str, float]] = None
request_id_prefix: Optional[str] = None
sync_weights: bool = False
[docs]
@dataclass
class GenerationSamplingConfig:
"""Shared completion sampling knobs (HF + vLLM)."""
max_prompt_len: int
max_completion_len: int
gen_temperature: float
gen_top_p: float
use_vllm: bool
vllm: VLLMClientConfig
vllm_mode: str = field(default="server", kw_only=True)
@property
def vllm_url(self) -> str:
"""Backward-compatible accessor for the vLLM endpoint URL."""
return self.vllm.url
@property
def vllm_rounds_cfg(self) -> int:
"""Backward-compatible accessor for the maximum vLLM retry rounds."""
return self.vllm.rounds_cfg
@property
def vllm_retry_sleep(self) -> float:
"""Backward-compatible accessor for the per-round retry sleep."""
return self.vllm.retry_sleep
@property
def vllm_backfill_local(self) -> bool:
"""Backward-compatible accessor for local fallback behavior."""
return self.vllm.backfill_local
@property
def vllm_request_logprobs(self) -> bool:
"""Backward-compatible accessor for whether to request logprobs."""
return self.vllm.request_logprobs
@property
def vllm_best_of(self) -> Optional[int]:
"""Backward-compatible accessor for the best-of sampling count."""
return self.vllm.best_of
@property
def vllm_frequency_penalty(self) -> float:
"""Backward-compatible accessor for the frequency penalty value."""
return self.vllm.frequency_penalty
@property
def vllm_presence_penalty(self) -> float:
"""Backward-compatible accessor for the presence penalty value."""
return self.vllm.presence_penalty
@property
def vllm_top_k(self) -> Optional[int]:
"""Backward-compatible accessor for the top-k sampling limit."""
return self.vllm.top_k
@property
def vllm_stop_sequences(self) -> Optional[List[str]]:
"""Backward-compatible accessor for stop sequences."""
return self.vllm.stop_sequences
@property
def vllm_include_stop_str_in_output(self) -> bool:
"""Whether vLLM should preserve matched stop strings in output text."""
return bool(getattr(self.vllm, "include_stop_str_in_output", False))
@property
def vllm_timeout(self) -> float:
"""Backward-compatible accessor for request timeout."""
return self.vllm.timeout
@property
def vllm_max_retries(self) -> int:
"""Backward-compatible accessor for maximum request retries."""
return self.vllm.max_retries
@property
def vllm_backoff(self) -> float:
"""Backward-compatible accessor for exponential backoff factor."""
return self.vllm.backoff
@property
def vllm_backoff_multiplier(self) -> float:
"""Multiplier applied to the backoff delay after each attempt."""
return getattr(self.vllm, "backoff_multiplier", 2.0)
@property
def vllm_guided_json(self) -> Optional[str]:
"""Backward-compatible accessor for JSON schema-guided decoding."""
return self.vllm.guided_json
@property
def vllm_guided_regex(self) -> Optional[str]:
"""Backward-compatible accessor for regex-guided decoding."""
return self.vllm.guided_regex
@property
def vllm_logit_bias(self) -> Optional[Dict[str, float]]:
"""Backward-compatible accessor for logit bias configuration."""
return self.vllm.logit_bias
@property
def vllm_request_id_prefix(self) -> Optional[str]:
"""Backward-compatible accessor for request-id prefixes."""
return self.vllm.request_id_prefix
@property
def vllm_sync_weights(self) -> bool:
"""Whether to push model weights to the vLLM server before generation."""
return bool(getattr(self.vllm, "sync_weights", False))
__all__ = [
"GenerationSamplingConfig",
"MaxEntOptions",
"VLLMClientConfig",
]