Source code for maxent_grpo.training.controller_objective

"""Meta-controller objectives for tau/beta adaptation."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Callable, Optional

from maxent_grpo.config import GRPOConfig
from .weighting.types import WeightingSettings

# Module-level logger, named after this module via ``__name__``.
LOG = logging.getLogger(__name__)


@dataclass
class ControllerGradients:
    """Pair of optional scalar gradients produced by a controller objective.

    Either field may be left as ``None`` when the objective has no update
    for that parameter.
    """

    # Gradient for the tau parameter, or None when there is no tau update.
    tau_grad: Optional[float] = None
    # Gradient for the beta parameter, or None when there is no beta update.
    beta_grad: Optional[float] = None

    def has_updates(self) -> bool:
        """Report whether at least one gradient holds an int/float value."""
        for candidate in (self.tau_grad, self.beta_grad):
            if isinstance(candidate, (int, float)):
                return True
        return False
@dataclass
class ControllerMetaContext:
    """Bundle of per-step inputs handed to a controller objective."""

    # Weighting settings; ``train_grpo_objective`` is read from it here.
    weighting: WeightingSettings
    # Statistics object; ``weight_entropy`` / ``entropy`` are read from it.
    weight_stats: Any
    # Loss outputs; ``kl_loss_scalar`` is read from it as a KL fallback.
    loss_outputs: Any
    # Training step counter (not read by the methods on this class).
    global_step: int
    # Learning-rate scale (not read by the methods on this class).
    lr_scale: float = 1.0
    # Batch object; ``scores.policy_entropy_sum`` and
    # ``scores.denom_tok_tensor`` are read from it when present.
    prepared_batch: Any = None
    # Explicit KL value; takes precedence over ``loss_outputs`` when numeric.
    kl_value: Optional[float] = None
    # Optional callback taking an int and returning optional gradients.
    backprop_fn: Optional[Callable[[int], Optional[ControllerGradients]]] = None

    def _batch_entropy_per_token(self) -> Optional[float]:
        """Return summed batch entropy divided by the summed token count.

        Reads ``policy_entropy_sum`` and ``denom_tok_tensor`` from
        ``prepared_batch.scores``. Returns ``None`` when either is missing,
        when neither reduction strategy works, or when the token total is
        not strictly positive.
        """
        scores = getattr(self.prepared_batch, "scores", None)
        entropy_src = getattr(scores, "policy_entropy_sum", None)
        token_src = getattr(scores, "denom_tok_tensor", None)
        if entropy_src is None or token_src is None:
            return None

        recoverable = (AttributeError, RuntimeError, TypeError, ValueError)
        try:
            # Tensor-style reduction (presumably torch tensors; confirm with
            # the producer of ``prepared_batch``).
            entropy_sum = float(entropy_src.detach().float().sum().cpu().item())
            token_sum = float(token_src.detach().float().sum().cpu().item())
        except recoverable:
            # If either tensor-style reduction fails, both totals are
            # recomputed through the plain ``.sum()`` path.
            try:
                entropy_sum = float(entropy_src.sum())
                token_sum = float(token_src.sum())
            except recoverable:
                return None

        if token_sum > 0:
            return entropy_sum / token_sum
        return None

    def entropy_value(self) -> Optional[float]:
        """Return the entropy value used for tau updates, if one is available.

        When ``weighting.train_grpo_objective`` is falsy, the per-token batch
        entropy from ``prepared_batch`` is preferred. Otherwise, or when that
        value cannot be computed, the first int/float found among
        ``weight_stats.weight_entropy`` and ``weight_stats.entropy`` is
        returned as a float. Returns ``None`` when nothing numeric is found.
        """
        if not getattr(self.weighting, "train_grpo_objective", True):
            per_token = self._batch_entropy_per_token()
            if per_token is not None:
                return per_token

        for field_name in ("weight_entropy", "entropy"):
            candidate = getattr(self.weight_stats, field_name, None)
            if isinstance(candidate, (int, float)):
                return float(candidate)
        return None

    def kl_metric(self) -> Optional[float]:
        """Return the KL metric as a float, or ``None`` when unavailable.

        ``kl_value`` wins when it is an int/float; otherwise
        ``loss_outputs.kl_loss_scalar`` is consulted.
        """
        metric = self.kl_value
        if not isinstance(metric, (int, float)):
            # Only touch loss_outputs when no explicit KL value was supplied.
            metric = getattr(self.loss_outputs, "kl_loss_scalar", None)
            if not isinstance(metric, (int, float)):
                return None
        return float(metric)
class ControllerObjective:
    """Common interface for controller objectives.

    Subclasses override :meth:`compute` and set ``name`` to their identifier.
    """

    name = "base"

    def compute(self, meta_ctx: ControllerMetaContext) -> Optional[ControllerGradients]:
        """Produce gradients for the given context; the base version is abstract.

        :param meta_ctx: Inputs for the objective.
        :returns: Gradient bundle, or ``None`` when there is nothing to apply.
        :raises NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
class AnalyticControllerObjective(ControllerObjective):
    """Closed-form gradients based on entropy/KL targets.

    Each gradient is the signed gap between a measured value and its
    configured target (``measured - target``).
    """

    name = "analytic"

    def compute(self, meta_ctx: ControllerMetaContext) -> Optional[ControllerGradients]:
        """Compute tau/beta gradients from the entropy and KL gaps.

        :param meta_ctx: Context supplying ``entropy_value()``,
            ``kl_metric()`` and the ``weighting`` targets
            (``tau_target_entropy`` and ``kl_target``).
        :returns: Gradient bundle, or ``None`` when neither gradient could
            be computed.
        """
        gradients = ControllerGradients()

        entropy_val = meta_ctx.entropy_value()
        target_entropy = meta_ctx.weighting.tau_target_entropy
        if entropy_val is not None and target_entropy is not None:
            gradients.tau_grad = entropy_val - float(target_entropy)

        kl_val = meta_ctx.kl_metric()
        target_kl = meta_ctx.weighting.kl_target
        # Guard against an unset (None) KL target: ordering ``None > 0``
        # raises TypeError. A missing or non-positive target yields no beta
        # gradient, mirroring the None handling of the entropy target above.
        if kl_val is not None and target_kl is not None and target_kl > 0:
            gradients.beta_grad = kl_val - float(target_kl)

        return gradients if gradients.has_updates() else None
class TruncatedBackpropControllerObjective(ControllerObjective):
    """Meta-gradient objective that delegates to a caller-provided callback.

    The callback is read from ``meta_ctx.backprop_fn`` and invoked with the
    configured number of steps.
    """

    name = "truncated_backprop"

    def __init__(self, steps: int = 1) -> None:
        """Store the step count, coerced to ``int`` and floored at 1."""
        requested = int(steps)
        self.steps = requested if requested > 1 else 1

    def compute(self, meta_ctx: ControllerMetaContext) -> Optional[ControllerGradients]:
        """Return the callback's gradients, or analytic ones as a fallback.

        The analytic objective is used when no callable callback is set, when
        the callback raises ``RuntimeError`` (logged as a warning), or when
        its result is falsy or reports no updates.
        """
        callback = meta_ctx.backprop_fn
        if callable(callback):
            try:
                outcome = callback(self.steps)
            except RuntimeError as exc:
                LOG.warning("Controller backprop callback failed: %s", exc)
            else:
                if outcome and outcome.has_updates():
                    return outcome
        # Fall back to analytic gradients so the controller still makes progress.
        return AnalyticControllerObjective().compute(meta_ctx)
def build_controller_objective(
    cfg: GRPOConfig, weighting: WeightingSettings
) -> Optional[ControllerObjective]:
    """Return the configured controller objective for the current run.

    The objective is selected by ``weighting.controller_meta.method``
    (case-insensitive; a missing or empty value means ``"analytic"``).
    Unknown methods log a warning and fall back to the analytic objective.

    For the truncated-backprop methods the step count comes from
    ``truncation_steps``; when that is ``None`` or non-positive,
    ``analytic_steps`` is used, and when that is missing or ``None`` the
    count defaults to 1.

    :param cfg: Training configuration (retained for compatibility; not used).
    :param weighting: Weighting settings containing controller meta config.
    :returns: Controller objective instance or ``None`` when disabled.
    :rtype: ControllerObjective | None
    """
    del cfg  # legacy argument retained for compatibility
    meta_cfg = getattr(weighting, "controller_meta", None)
    if meta_cfg is None or not getattr(meta_cfg, "enabled", False):
        return None
    method = str(getattr(meta_cfg, "method", "analytic") or "analytic").lower()
    if method in ("analytic", "analytic_grad", "potential"):
        return AnalyticControllerObjective()
    if method in ("first_order", "truncated", "truncated_backprop", "backprop"):
        steps = getattr(meta_cfg, "truncation_steps", None)
        if steps is None or steps <= 0:
            steps = getattr(meta_cfg, "analytic_steps", 1)
        if steps is None:
            # ``analytic_steps`` can exist but be unset; passing None on would
            # make the objective's ``int(steps)`` coercion raise TypeError.
            steps = 1
        return TruncatedBackpropControllerObjective(steps=steps)
    LOG.warning(
        "Unknown controller_meta_method=%s; falling back to analytic gradients.",
        method,
    )
    return AnalyticControllerObjective()
# Public names exported by this module.
__all__ = [
    "AnalyticControllerObjective",
    "ControllerGradients",
    "ControllerMetaContext",
    "ControllerObjective",
    "TruncatedBackpropControllerObjective",
    "build_controller_objective",
]