"""Pydantic-powered validation for Hydra training configs.
This module inspects the resolved training arguments before a pipeline is
launched so that accidental MaxEnt toggles are caught early. The validator is
kept lightweight and depends only on :mod:`pydantic`, which is already part of
the runtime toolchain for several other components. Future guardrails can
extend this module with further schema checks (including GRPO + entropy-bonus
overrides under ``train-maxent``).
"""
from __future__ import annotations
import warnings
from dataclasses import MISSING, Field as DataclassField, fields
from typing import Any, Mapping, MutableMapping, Sequence
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
from pydantic.warnings import UnsupportedFieldAttributeWarning
from maxent_grpo.config import GRPOConfig
from maxent_grpo.objectives import normalize_maxent_objective_variant, resolve_objective_routing
# Only the top-level validator is exported from this module.
__all__ = [
    "validate_training_config",
]
# Installed at import time, so this filter silences pydantic's
# UnsupportedFieldAttributeWarning for the whole process, not just this module.
# NOTE(review): presumably triggered by a Field(...) declaration somewhere in
# the pydantic schemas -- confirm which one before removing this filter.
warnings.filterwarnings("ignore", category=UnsupportedFieldAttributeWarning)
def _field_default(field: DataclassField[object]) -> object | None:
    """Return the default value declared on a dataclass field.

    :param field: A :class:`dataclasses.Field` as yielded by
        :func:`dataclasses.fields`.
    :returns: The plain default when one is declared; otherwise the result of
        calling the default factory; ``None`` when the field declares neither.
    """
    declared = field.default
    if declared is MISSING:
        # No plain default: fall back to the factory, if there is one.
        factory = field.default_factory
        return None if factory is MISSING else factory()
    return declared
# Declared default of every GRPOConfig dataclass field whose name starts with
# "maxent_", keyed by field name. Computed once, at import time.
_MAXENT_DEFAULTS = {
    field.name: _field_default(field)
    for field in fields(GRPOConfig)
    if field.name.startswith("maxent_")
}
# Objective assumed for a CLI command when the config itself does not select
# one. Commands missing from this table get no command-specific default.
_DEFAULT_OBJECTIVE_BY_COMMAND = {
    "train-baseline": "grpo",
    "train-maxent": "maxent_entropy",
}
# Knobs for which _is_safe_grpo_maxent_override returns True whatever their
# value, so they are never reported as conflicting MaxEnt overrides.
_GRPO_SAFE_MAXENT_KNOBS = {
    "maxent_allow_empty_weight_fallback",
    "maxent_allow_stale_reference_logprobs",
    "maxent_length_normalize_ref",
    "maxent_logprob_chunk_size",
    "maxent_policy_entropy",
    "maxent_policy_entropy_mode",
    "maxent_prompt_cache_size",
    "maxent_reference_logprobs_source",
    "maxent_score_tail_tokens",
    "maxent_trl_reference_scoring",
}
# Keys that are no longer supported: _validate_removed_training_keys raises
# when any of them is present with a non-None value.
_REMOVED_TRAINING_KEYS = {
    "maxent_reward_signal_gate",
    "maxent_reward_signal_min_max",
    "maxent_reward_signal_std_threshold",
    "maxent_bonus_positive_only",
    "maxent_bonus_min_reward",
    "maxent_cusp_gate",
    "maxent_cusp_reward_threshold",
    "controller_meta_objective",
    "controller_meta_analytic_steps",
    "controller_meta_optimizer",
    "controller_meta_truncation_steps",
    "controller_meta_use_hessian",
}
class _TrainingSchema(BaseModel):
    """Schema holding the objective-routing knobs that are validated together.

    Unknown keys are rejected (``extra="forbid"``). Once the fields are
    populated, the model validator resolves the effective objective and
    refuses MaxEnt overrides when that objective is a GRPO variant.
    """

    model_config = ConfigDict(extra="forbid")

    # Routing inputs, forwarded verbatim to resolve_objective_routing.
    objective: str | None = None
    train_grpo_objective: bool | None = None
    maxent_objective_variant: str | None = None
    policy_entropy_bonus_coef: float | None = None
    # Fallback objective; "maxent_entropy" is used when this is unset or empty.
    default_objective: str | None = Field(default=None)
    # MaxEnt knobs the caller flagged as overridden (name -> value).
    maxent_overrides: dict[str, Any] = Field(default_factory=dict)

    @model_validator(mode="after")
    def _check_maxent_conflicts(self) -> "_TrainingSchema":
        """Reject MaxEnt overrides when the effective objective is GRPO.

        :returns: The validated model itself.
        :raises ValueError: If the resolved objective is ``grpo`` or
            ``grpo_entropy_bonus`` while ``maxent_overrides`` is non-empty.
        """
        fallback_objective = self.default_objective or "maxent_entropy"
        routing = resolve_objective_routing(
            objective=self.objective,
            train_grpo_objective=self.train_grpo_objective,
            maxent_objective_variant=self.maxent_objective_variant,
            policy_entropy_bonus_coef=self.policy_entropy_bonus_coef,
            default_objective=fallback_objective,
        )
        on_grpo_path = routing.objective in {"grpo", "grpo_entropy_bonus"}
        if not (on_grpo_path and self.maxent_overrides):
            return self
        # Sorted so the error message is deterministic.
        knobs = ", ".join(sorted(self.maxent_overrides))
        raise ValueError(
            "MaxEnt overrides (%s) require objective=maxent_entropy "
            "or objective=maxent_listwise" % knobs
        )
def _training_values(payload: Any) -> MutableMapping[str, Any]:
    """Return a mapping containing the knobs relevant to validation.

    Mapping payloads are shallow-copied key by key. For any other object only
    a fixed set of attributes is collected: every name in ``_MAXENT_DEFAULTS``,
    the objective-routing knobs, the SEED-GRPO knobs, and the batch-shape
    knobs read by ``_validate_listwise_microbatch_shape``.

    :param payload: A mapping, or an attribute-style config object.
    :returns: A new dict; attributes missing from ``payload`` are omitted.
    """
    if isinstance(payload, Mapping):
        return {key: payload[key] for key in payload}
    attr_names = set(_MAXENT_DEFAULTS) | {
        "objective",
        "train_grpo_objective",
        "maxent_objective_variant",
        "policy_entropy_bonus_coef",
        "seed_grpo_enabled",
        "seed_grpo_alpha",
        "seed_grpo_alpha_normalize_by_max_entropy",
        "seed_grpo_length_normalize_logprobs",
        # Without these the listwise divisibility check was silently skipped
        # for object payloads, although it ran for mapping payloads.
        "num_generations",
        "per_device_train_batch_size",
        "per_device_eval_batch_size",
    }
    return {
        name: getattr(payload, name)
        for name in attr_names
        if hasattr(payload, name)
    }
def _numeric_or_none(value: Any) -> float | None:
    """Coerce ``value`` to ``float``, or return ``None`` if that is impossible.

    :param value: Any object; typically a number, numeric string or ``None``.
    :returns: ``float(value)``, or ``None`` when the conversion fails.
    """
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: integers too large for a float (e.g. 10 ** 400).
        return None
def _integer_or_none(value: Any) -> int | None:
    """Coerce ``value`` to ``int``, or return ``None`` if that is impossible.

    Finite floats are truncated toward zero by ``int()``.

    :param value: Any object; typically a number, numeric string or ``None``.
    :returns: ``int(value)``, or ``None`` when the conversion fails.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")); ValueError also covers NaN.
        return None
def _normalize_entropy_mode(value: Any) -> str:
    """Map an entropy-mode spelling onto its canonical label.

    Falsy input counts as ``"exact"``. Matching ignores case and surrounding
    whitespace.

    :param value: Raw mode setting; converted with ``str()`` before matching.
    :returns: ``"exact"`` or ``"sample"``.
    :raises ValueError: If the spelling is not a recognised alias.
    """
    exact_aliases = ("", "none", "exact", "full", "distribution")
    sample_aliases = (
        "sample",
        "estimate",
        "estimated",
        "approx",
        "approximate",
        "token",
        "token_logp",
        "nll",
        "logp",
    )
    label = str(value or "exact").strip().lower()
    if label in exact_aliases:
        return "exact"
    if label in sample_aliases:
        return "sample"
    raise ValueError("maxent_policy_entropy_mode must be one of: exact, sample")
def _is_safe_grpo_maxent_override(name: str, value: Any) -> bool:
    """Tell whether a knob/value pair may be ignored as a MaxEnt override.

    :param name: Knob name.
    :param value: Value supplied for the knob.
    :returns: ``True`` for names in ``_GRPO_SAFE_MAXENT_KNOBS``, for a numeric
        ``maxent_alpha`` <= 0, for a numeric ``policy_entropy_bonus_coef`` > 0,
        and for a ``maxent_objective_variant`` that normalizes to
        ``"entropy"``; ``False`` otherwise.
    """
    if name in _GRPO_SAFE_MAXENT_KNOBS:
        return True
    if name == "maxent_objective_variant":
        return normalize_maxent_objective_variant(value, default="entropy") == "entropy"
    if name in ("maxent_alpha", "policy_entropy_bonus_coef"):
        number = _numeric_or_none(value)
        if number is None:
            # Non-numeric values are never considered safe.
            return False
        return number <= 0.0 if name == "maxent_alpha" else number > 0.0
    return False
def _maxent_overrides(values: Mapping[str, Any]) -> dict[str, Any]:
    """Collect the MaxEnt knobs that were moved away from their defaults.

    Only names listed in ``_MAXENT_DEFAULTS`` and present in ``values`` are
    considered. A knob is dropped when it equals its default, or when
    ``_is_safe_grpo_maxent_override`` accepts its value.

    :param values: Training values keyed by knob name.
    :returns: ``{name: value}`` for every remaining knob.
    """
    flagged: dict[str, Any] = {}
    for knob, default in _MAXENT_DEFAULTS.items():
        if knob in values:
            current = values[knob]
            # The identity check runs first so two ``None`` values never reach ``==``.
            unchanged = (current is None and default is None) or current == default
            if not unchanged and not _is_safe_grpo_maxent_override(knob, current):
                flagged[knob] = current
    return flagged
def _validate_listwise_microbatch_shape(values: Mapping[str, Any]) -> None:
    """Check the constraints that only apply to the listwise MaxEnt loss.

    Nothing is checked when routing selects the GRPO objective or a loss that
    is not listwise. Otherwise ``maxent_tau`` must be a number greater than
    zero, the resolved ``maxent_alpha`` must not be positive, and, when
    ``num_generations`` is a positive integer, every positive per-device batch
    size must be divisible by it.

    :param values: Training values keyed by knob name.
    :raises ValueError: If any of the constraints above is violated.
    """
    routing = resolve_objective_routing(
        objective=values.get("objective"),
        train_grpo_objective=values.get("train_grpo_objective"),
        maxent_objective_variant=values.get("maxent_objective_variant"),
        maxent_alpha=values.get("maxent_alpha"),
        policy_entropy_bonus_coef=values.get("policy_entropy_bonus_coef"),
    )
    if routing.train_grpo_objective or not routing.uses_listwise_loss:
        return
    tau = _numeric_or_none(values.get("maxent_tau"))
    # ``not tau > 0.0`` (rather than ``tau <= 0.0``) also rejects NaN, which
    # compares False against everything and used to slip through.
    if tau is None or not tau > 0.0:
        raise ValueError("listwise MaxEnt requires maxent_tau > 0")
    if routing.maxent_alpha > 0.0:
        raise ValueError("listwise MaxEnt does not use maxent_alpha; set it to 0")
    num_generations = _integer_or_none(values.get("num_generations"))
    if num_generations is None or num_generations <= 0:
        # Without a usable group size there is nothing to check.
        return
    for batch_name in ("per_device_train_batch_size", "per_device_eval_batch_size"):
        batch_size = _integer_or_none(values.get(batch_name))
        if batch_size is None or batch_size <= 0:
            continue
        if batch_size % num_generations != 0:
            raise ValueError(
                f"listwise MaxEnt requires {batch_name}={batch_size} to be divisible "
                f"by num_generations={num_generations} so each trainer microbatch "
                "contains whole prompt groups"
            )
def _validate_entropy_objective_settings(values: Mapping[str, Any]) -> None:
    """Reject entropy-loss settings that do not match the implemented math.

    When routing selects the entropy-regularized loss, the entropy mode
    (``"exact"`` if unset) must normalize to ``"exact"``.

    :param values: Training values keyed by knob name.
    :raises ValueError: If the mode normalizes to something other than
        ``"exact"``, or is not a recognised spelling at all.
    """
    routing = resolve_objective_routing(
        objective=values.get("objective"),
        train_grpo_objective=values.get("train_grpo_objective"),
        maxent_objective_variant=values.get("maxent_objective_variant"),
        maxent_alpha=values.get("maxent_alpha"),
        policy_entropy_bonus_coef=values.get("policy_entropy_bonus_coef"),
    )
    if not routing.uses_entropy_regularized_loss:
        return
    requested_mode = values.get("maxent_policy_entropy_mode", "exact")
    if _normalize_entropy_mode(requested_mode) == "exact":
        return
    raise ValueError(
        "Entropy-regularized MaxEnt requires maxent_policy_entropy_mode='exact'; "
        "sample mode is only valid for logging or GRPO reward bonuses."
    )
def _validate_seed_grpo_settings(values: Mapping[str, Any]) -> None:
    """Reject SEED-GRPO knobs that are incompatible with the selected objective.

    ``seed_grpo_alpha`` is range-checked whether or not SEED-GRPO is enabled.
    The objective check only runs when ``seed_grpo_enabled`` is truthy.

    :param values: Training values keyed by knob name.
    :raises ValueError: If ``seed_grpo_alpha`` is numeric but negative or NaN,
        or if SEED-GRPO is enabled while routing does not train the GRPO
        objective.
    """
    alpha = _numeric_or_none(values.get("seed_grpo_alpha"))
    # ``not alpha >= 0.0`` (rather than ``alpha < 0.0``) also rejects NaN,
    # which is not a usable non-negative coefficient.
    if alpha is not None and not alpha >= 0.0:
        raise ValueError("seed_grpo_alpha must be non-negative")
    if not bool(values.get("seed_grpo_enabled", False)):
        return
    routing = resolve_objective_routing(
        objective=values.get("objective"),
        train_grpo_objective=values.get("train_grpo_objective"),
        maxent_objective_variant=values.get("maxent_objective_variant"),
        maxent_alpha=values.get("maxent_alpha"),
        policy_entropy_bonus_coef=values.get("policy_entropy_bonus_coef"),
    )
    if not routing.train_grpo_objective:
        raise ValueError(
            "seed_grpo_enabled requires objective=grpo or objective=grpo_entropy_bonus"
        )
def _source_hint(command: str, *, recipe: str | None, training_args: Any) -> str:
    """Return a short string pointing at the config origin for error messages.

    The recipe location comes from ``recipe`` when given, then from a
    ``recipe_path`` attribute on ``training_args``, and finally from a
    ``"recipe_path"`` key when ``training_args`` is a mapping.

    :param command: CLI command being executed; always the first segment.
    :param recipe: Explicit recipe/source hint; wins when truthy.
    :param training_args: Config object or mapping to probe for ``recipe_path``.
    :returns: ``"<command>"`` or ``"<command> | <recipe path>"``.
    """
    recipe_path = recipe or getattr(training_args, "recipe_path", None)
    if not recipe_path and isinstance(training_args, Mapping):
        # Plain dict payloads expose recipe_path as a key, not an attribute.
        recipe_path = training_args.get("recipe_path")
    hints: list[str] = [command]
    if recipe_path:
        hints.append(str(recipe_path))
    return " | ".join(hints)
def _format_validation_errors(errors: Sequence[Mapping[str, Any]]) -> str:
    """Render validation error records as one ``"; "``-separated line.

    Each record becomes ``"<location>: <msg>"``, or just ``"<msg>"`` when it
    has no location. Tuple locations are joined with dots, skipping ``None``
    items; any other location is rendered with ``str()``.

    :param errors: Error records carrying optional ``"loc"`` and ``"msg"`` keys.
    :returns: The combined message; empty when ``errors`` is empty.
    """

    def _render(entry: Mapping[str, Any]) -> str:
        location = entry.get("loc") or ()
        if isinstance(location, tuple):
            dotted = ".".join(str(piece) for piece in location if piece is not None)
        else:
            dotted = str(location)
        text = entry.get("msg", "")
        return f"{dotted}: {text}" if dotted else text

    return "; ".join(_render(entry) for entry in errors)
def _validate_removed_training_keys(values: Mapping[str, Any]) -> None:
    """Fail when a removed training key is still being set.

    :param values: Training values keyed by knob name.
    :raises ValueError: If any name in ``_REMOVED_TRAINING_KEYS`` has a
        non-``None`` value; offenders are listed alphabetically.
    """
    offenders = [
        key for key in _REMOVED_TRAINING_KEYS if values.get(key) is not None
    ]
    if not offenders:
        return
    offenders.sort()
    raise ValueError("Removed training keys are no longer supported: " + ", ".join(offenders))
# Public entry point of this module (the only name listed in ``__all__``).
def validate_training_config(
    training_args: GRPOConfig | Mapping[str, Any],
    *,
    command: str,
    source: str | None = None,
) -> None:
    """Validate Hydrated training knobs before dispatching to a pipeline.

    The validator ensures that the canonical ``objective`` matches the presence
    of MaxEnt-specific options. When MaxEnt knobs are supplied while the
    effective objective stays on the native GRPO path, a :class:`ValueError`
    is raised so the job fails fast. Removed keys, the listwise batch shape,
    the entropy-mode setting and the SEED-GRPO knobs are checked as well.

    :param training_args: Training dataclass or mapping derived from Hydra.
    :param command: CLI command being executed (e.g., ``train-baseline``).
    :param source: Optional user-facing hint (recipe path, override description).
    :returns: ``None``. Raises on invalid or incompatible configurations.
    :raises ValueError: If incompatible knob combinations are detected; the
        message is prefixed with the command and, when known, the recipe path.
    """
    routing_keys = (
        "objective",
        "train_grpo_objective",
        "maxent_objective_variant",
        "policy_entropy_bonus_coef",
    )
    values = _training_values(training_args)
    command_default = _DEFAULT_OBJECTIVE_BY_COMMAND.get(command)

    # The follow-up validators see the command's default objective, but only
    # when the config left every routing knob unset.
    effective_values = dict(values)
    nothing_selected = all(effective_values.get(key) is None for key in routing_keys)
    if nothing_selected and command_default is not None:
        effective_values["objective"] = command_default

    try:
        # The schema gets the raw values plus the default as a separate field.
        schema_payload = {key: values.get(key) for key in routing_keys}
        schema_payload["default_objective"] = command_default
        schema_payload["maxent_overrides"] = _maxent_overrides(values)
        _TrainingSchema(**schema_payload)
        for check in (
            _validate_removed_training_keys,
            _validate_listwise_microbatch_shape,
            _validate_entropy_objective_settings,
            _validate_seed_grpo_settings,
        ):
            check(effective_values)
    except ValueError as exc:
        # pydantic's ValidationError is a ValueError subclass, so one handler
        # covers both; only the way the message is extracted differs.
        if isinstance(exc, ValidationError):
            detail = _format_validation_errors(exc.errors())
        else:
            detail = str(exc)
        origin = _source_hint(command, recipe=source, training_args=training_args)
        raise ValueError(f"{origin}: {detail}") from exc