Source code for maxent_grpo.training.controller_optimizer

"""Meta-optimizer orchestration for controller updates."""

from __future__ import annotations

import math
import logging
from contextlib import nullcontext
from typing import Any, Callable, ContextManager, Optional, cast

from maxent_grpo.config import GRPOConfig
from maxent_grpo.training.runtime import require_torch
from .weighting.logic import _ensure_tau_history
from .weighting.types import WeightingSettings
from .controller_objective import ControllerGradients

LOG = logging.getLogger(__name__)


def _no_grad_context(torch_mod: Any) -> ContextManager[None]:
    """Build a gradient-free context from ``torch_mod``, or a no-op one.

    The ``no_grad`` attribute of ``torch_mod`` is looked up defensively:

    * a callable attribute is invoked and its result returned;
    * a non-callable attribute that exposes ``__enter__``/``__exit__`` is
      returned as-is;
    * anything else (including a missing attribute) yields ``nullcontext()``.
    """

    no_grad = getattr(torch_mod, "no_grad", None)
    if no_grad is None:
        return nullcontext()
    if callable(no_grad):
        return cast(ContextManager[None], no_grad())
    # Non-callable: accept it only if it already behaves like a context manager.
    looks_like_ctx = hasattr(no_grad, "__enter__") and hasattr(no_grad, "__exit__")
    if looks_like_ctx:
        return cast(ContextManager[None], no_grad)
    return nullcontext()


class ControllerMetaManager:
    """Manage meta-controller optimizer state and cadence.

    The manager reads its settings from ``weighting.controller_meta`` (falling
    back to ``controller_meta_*`` attributes on ``cfg``), optionally builds a
    ``torch.optim`` optimizer over the controller's tau/beta parameters, and
    applies tau/beta updates either through that optimizer or through a
    manual scalar update on the weighting settings.
    """

    def __init__(self, cfg: GRPOConfig, weighting: WeightingSettings) -> None:
        """Resolve settings and, when possible, build the meta optimizer.

        :param cfg: Training config supplying ``controller_meta_*`` fallbacks.
        :param weighting: Weighting settings holding ``controller_meta``,
            ``controller_state`` and the tau/beta scalars that get updated.
        """
        meta_cfg = getattr(weighting, "controller_meta", None)
        disable_reason = None
        self.enabled = bool(meta_cfg and meta_cfg.enabled)
        if not self.enabled:
            disable_reason = "controller_meta_enabled flag is false"
        self.update_interval = max(
            1,
            int(
                getattr(
                    meta_cfg,
                    "update_interval",
                    getattr(cfg, "controller_meta_update_interval", 1),
                )
            ),
        )
        legacy_lr = float(getattr(meta_cfg, "learning_rate", cfg.controller_meta_lr))
        tau_lr = float(
            getattr(
                meta_cfg,
                "tau_learning_rate",
                getattr(cfg, "controller_meta_tau_lr", 0.0),
            )
        )
        beta_lr = float(
            getattr(
                meta_cfg,
                "beta_learning_rate",
                getattr(cfg, "controller_meta_beta_lr", 0.0),
            )
        )
        # Per-parameter rates win when positive; otherwise use the shared rate.
        self.tau_learning_rate = tau_lr if tau_lr > 0.0 else legacy_lr
        self.beta_learning_rate = beta_lr if beta_lr > 0.0 else legacy_lr
        self.learning_rate = max(self.tau_learning_rate, self.beta_learning_rate, 0.0)
        self.beta_grad_clip = float(
            getattr(
                meta_cfg,
                "beta_grad_clip",
                getattr(cfg, "controller_meta_beta_grad_clip", 0.0),
            )
        )
        self.method = str(getattr(meta_cfg, "method", "analytic") or "analytic").lower()
        self.optimizer = str(getattr(meta_cfg, "optimizer", "sgd") or "sgd").lower()
        self.objective_name = str(
            getattr(meta_cfg, "objective", cfg.controller_meta_objective)
        )
        self.analytic_steps = max(
            1,
            int(
                getattr(
                    meta_cfg, "analytic_steps", cfg.controller_meta_analytic_steps or 1
                )
            ),
        )
        self.truncation_steps = max(
            1,
            int(
                getattr(
                    meta_cfg,
                    "truncation_steps",
                    getattr(
                        cfg,
                        "controller_meta_truncation_steps",
                        cfg.controller_meta_analytic_steps or 1,
                    ),
                )
            ),
        )
        self.use_hessian = bool(
            getattr(
                meta_cfg,
                "use_hessian",
                getattr(cfg, "controller_meta_use_hessian", False),
            )
        )
        self._torch = None
        self._controller_state = getattr(weighting, "controller_state", None)
        self._meta_optimizer = None
        self._weighting = weighting
        # NOTE: These updates are not "true" meta-gradients by default. The
        # controller objective returns tau/beta update signals, and (when an
        # optimizer is enabled) we map those signals onto Parameter.grads so we
        # can use standard torch.optim state (e.g., Adam momentum).
        self._requires_optimizer = self.method in (
            "analytic",
            "analytic_grad",
            "potential",
            "first_order",
            "truncated",
            "truncated_backprop",
            "backprop",
        )
        if self.optimizer not in ("sgd", "adam", "adamw"):
            LOG.warning(
                "Unsupported controller_meta_optimizer=%s; falling back to analytic updates.",
                self.optimizer,
            )
            self._requires_optimizer = False
        if self._controller_state is not None:
            if self.enabled:
                self._controller_state.enable_grad()
            else:
                self._controller_state.disable_grad()
        if self.enabled and self._requires_optimizer:
            try:
                self._torch = require_torch("controller_meta")
                if self._controller_state is None:
                    raise RuntimeError("controller_state required for meta optimizer")
                params = self._controller_state.parameters()
                if not params:
                    raise RuntimeError("controller_state missing parameters")
                self._meta_optimizer = self._build_optimizer()
            except (
                ImportError,
                ModuleNotFoundError,
                AttributeError,
                RuntimeError,
            ) as exc:  # pragma: no cover
                LOG.warning(
                    "Controller meta-optimizer falling back to analytic updates: %s",
                    exc,
                )
                self._requires_optimizer = False
                self._torch = None
                self._meta_optimizer = None
        if self.tau_learning_rate <= 0.0 and self.beta_learning_rate <= 0.0:
            self.enabled = False
            disable_reason = "controller_meta learning rates <= 0"
        proc_index = getattr(
            cfg, "process_index", getattr(cfg, "local_process_index", 0)
        )
        # Only report from process index 0 (or when no index is set).
        if proc_index in (0, None):
            if self.enabled:
                update_mode = self.optimizer if self._requires_optimizer else "analytic"
                LOG.info(
                    "Controller meta enabled | method=%s | objective=%s | tau_lr=%.4g | beta_lr=%.4g | beta_grad_clip=%s | update_interval=%d | update_mode=%s",
                    self.method,
                    self.objective_name,
                    self.tau_learning_rate,
                    self.beta_learning_rate,
                    f"{self.beta_grad_clip:.4g}"
                    if self.beta_grad_clip and self.beta_grad_clip > 0
                    else "off",
                    self.update_interval,
                    update_mode,
                )
            else:
                LOG.info(
                    "Controller meta disabled | reason=%s",
                    disable_reason or "flag disabled",
                )

    def should_run(self, global_step: int) -> bool:
        """Return True when a meta update is due at ``global_step``.

        Updates fire when ``global_step + 1`` is a multiple of
        ``update_interval``; a disabled manager never fires.
        """
        if not self.enabled:
            return False
        return (global_step + 1) % self.update_interval == 0

    @staticmethod
    def _grad_to_float(grad: Any) -> float:
        """Convert a gradient (tensor-like or plain number) to a Python float."""
        val = grad.detach() if hasattr(grad, "detach") else grad
        try:
            return float(val.item())
        except (
            AttributeError,
            TypeError,
            ValueError,
        ):  # pragma: no cover - numeric fallback
            return float(val)

    @staticmethod
    def _coerce_lr_scale(lr_scale: Any) -> float:
        """Return a finite, non-negative learning-rate multiplier.

        Values that cannot be converted to float, or that are NaN/inf, fall
        back to ``1.0``; negative values are clamped to ``0.0``. This mirrors
        the sanitisation the manual update path applies to its own scale.
        """
        try:
            value = float(lr_scale)
        except (TypeError, ValueError):
            return 1.0
        if not math.isfinite(value):
            return 1.0
        return max(value, 0.0)

    def make_backprop_fn(
        self,
    ) -> Optional[Callable[[int], Optional[ControllerGradients]]]:
        """Return a callback that computes gradients via autograd.

        Returns ``None`` unless the manager is enabled, optimizer-backed, and
        has both a controller state and a torch module. The callback reads
        ``.grad`` from the controller's tau/beta parameters and returns them as
        ``ControllerGradients``, or ``None`` when neither gradient is set.
        """
        if not (
            self.enabled
            and self._requires_optimizer
            and self._controller_state is not None
            and self._torch is not None
        ):
            return None
        state = self._controller_state
        to_float = self._grad_to_float

        def _backprop_fn(_inner_steps: int) -> Optional[ControllerGradients]:
            tau_grad = state.tau_param.grad
            beta_grad = state.beta_param.grad
            if tau_grad is None and beta_grad is None:
                return None
            tau_grad_val = to_float(tau_grad) if tau_grad is not None else None
            beta_grad_val = to_float(beta_grad) if beta_grad is not None else None
            if tau_grad_val is None and beta_grad_val is None:
                return None
            return ControllerGradients(
                tau_grad=tau_grad_val,
                beta_grad=beta_grad_val,
            )

        return _backprop_fn

    def apply_gradients(
        self,
        gradients: Optional[ControllerGradients],
        *,
        lr_scale: float,
    ) -> None:
        """Apply controller updates based on the configured method.

        :param gradients: tau/beta update signals; falsy values are a no-op.
        :param lr_scale: Multiplier applied to the base learning rates.
        """
        if not gradients:
            return
        if self._requires_optimizer and self._meta_optimizer is not None:
            # Populate parameter grads from the controller objective signals when
            # running in analytic mode. Optimizer-based modes expect the grads
            # to already be set on the controller parameters.
            if self.method in ("analytic", "analytic_grad", "potential"):
                self._set_optimizer_grads(gradients)
            self._apply_optimizer_step(lr_scale)
            setattr(
                self._weighting, "_meta_last_tau_grad", float(gradients.tau_grad or 0.0)
            )
            setattr(
                self._weighting,
                "_meta_last_beta_grad",
                float(gradients.beta_grad or 0.0),
            )
            meta_cfg = getattr(self._weighting, "controller_meta", None)
            if meta_cfg:
                meta_cfg.last_tau_grad = float(gradients.tau_grad or 0.0)
                meta_cfg.last_beta_grad = float(gradients.beta_grad or 0.0)
            return
        self._manual_update(gradients, lr_scale=lr_scale)

    def _set_optimizer_grads(self, gradients: ControllerGradients) -> None:
        """Map controller update signals onto controller Parameter.grads."""
        if self._controller_state is None or self._torch is None:
            return
        state = self._controller_state
        torch_mod = self._torch
        try:
            state.zero_grad()
        except (AttributeError, RuntimeError, TypeError) as exc:
            LOG.debug("Failed to zero controller grads: %s", exc)

        def _as_grad_tensor(param: Any, value: float) -> Any:
            dtype = getattr(param, "dtype", None)
            device = getattr(param, "device", None)
            try:
                return torch_mod.tensor(float(value), dtype=dtype, device=device)
            except TypeError:
                # Retry without ``device`` if the tensor factory rejects it.
                return torch_mod.tensor(float(value), dtype=dtype)

        # tau: gradient-descent style update (tau -= lr * tau_grad)
        if isinstance(gradients.tau_grad, (int, float)) and math.isfinite(
            float(gradients.tau_grad)
        ):
            state.tau_param.grad = _as_grad_tensor(
                state.tau_param, float(gradients.tau_grad)
            )
        # beta: "tighten KL when kl > target" update (beta += lr * beta_grad),
        # so we flip sign to match optimizer descent convention.
        if isinstance(gradients.beta_grad, (int, float)) and math.isfinite(
            float(gradients.beta_grad)
        ):
            grad_update_val = float(gradients.beta_grad)
            clip = float(self.beta_grad_clip or 0.0)
            if clip > 0.0 and math.isfinite(clip):
                grad_update_val = max(min(grad_update_val, clip), -clip)
            state.beta_param.grad = _as_grad_tensor(state.beta_param, -grad_update_val)

    def _manual_update(
        self, gradients: ControllerGradients, *, lr_scale: float
    ) -> None:
        """Update tau/beta scalars on the weighting settings without an optimizer.

        tau moves by ``-lr * tau_grad`` and is projected onto
        ``[tau_min, tau_max]`` (the upper bound only when ``tau_max > 0``);
        beta moves by ``+lr * clipped(beta_grad)`` and is floored at zero.
        After any update, the denominator and controller state are refreshed.
        """
        meta_cfg = getattr(self._weighting, "controller_meta", None)
        base_lr = 0.0
        if meta_cfg is not None:
            try:
                base_lr = float(getattr(meta_cfg, "learning_rate", 0.0) or 0.0)
            except (TypeError, ValueError):
                base_lr = 0.0
        lr_scale_val = float(lr_scale) if isinstance(lr_scale, (int, float)) else 1.0
        if not math.isfinite(lr_scale_val):
            lr_scale_val = 1.0
        lr_scale_val = max(lr_scale_val, 0.0)
        lr_tau = base_lr if base_lr > 0.0 else float(self.tau_learning_rate)
        lr_beta = base_lr if base_lr > 0.0 else float(self.beta_learning_rate)
        lr_tau *= lr_scale_val
        lr_beta *= lr_scale_val
        if lr_tau <= 0.0 and lr_beta <= 0.0:
            return
        updated = False
        tau_projected = False
        if isinstance(gradients.tau_grad, (int, float)) and math.isfinite(
            float(gradients.tau_grad)
        ):
            grad_val = float(gradients.tau_grad)
            raw_tau = self._weighting.tau - lr_tau * grad_val
            new_tau = max(self._weighting.tau_min, raw_tau)
            tau_max = float(self._weighting.tau_max)
            if tau_max > 0.0:
                clipped = min(new_tau, tau_max)
                tau_projected = tau_projected or clipped != new_tau
                new_tau = clipped
            if self._weighting.tau_min > 0.0 and new_tau <= self._weighting.tau_min:
                tau_projected = True
            self._weighting.tau = float(new_tau)
            try:
                setattr(
                    self._weighting,
                    "_tau_log",
                    math.log(max(self._weighting.tau, 1e-8)),
                )
            except (AttributeError, TypeError, ValueError) as exc:
                LOG.debug("Failed to update _tau_log after manual tau update: %s", exc)
            updated = True
            setattr(self._weighting, "_meta_last_tau_grad", float(grad_val))
            if meta_cfg is not None:
                try:
                    meta_cfg.last_tau_grad = float(grad_val)
                except (AttributeError, TypeError, ValueError) as exc:
                    LOG.debug("Failed to record controller meta tau_grad: %s", exc)
        beta_projected = False
        if isinstance(gradients.beta_grad, (int, float)) and math.isfinite(
            float(gradients.beta_grad)
        ):
            raw_grad_val = float(gradients.beta_grad)
            if lr_beta > 0.0:
                clip = float(self.beta_grad_clip or 0.0)
                grad_update_val = raw_grad_val
                if clip > 0.0 and math.isfinite(clip):
                    grad_update_val = max(min(grad_update_val, clip), -clip)
                # Beta tightens KL: when kl > target (grad > 0) beta must increase.
                raw_beta = self._weighting.beta + lr_beta * grad_update_val
                if raw_beta < 0.0:
                    beta_projected = True
                self._weighting.beta = max(float(raw_beta), 0.0)
                updated = True
            # The recorded value is the raw (unclipped) gradient.
            setattr(self._weighting, "_meta_last_beta_grad", float(raw_grad_val))
            if meta_cfg is not None:
                try:
                    meta_cfg.last_beta_grad = float(raw_grad_val)
                except (AttributeError, TypeError, ValueError) as exc:
                    LOG.debug("Failed to record controller meta beta_grad: %s", exc)
        if not updated:
            return
        _ensure_tau_history(self._weighting)
        if self._weighting.train_grpo_objective:
            self._weighting.denom = 1.0
        else:
            denom_sum = float(self._weighting.tau) + float(self._weighting.beta)
            self._weighting.denom = denom_sum if denom_sum > 0 else 1.0
        setattr(self._weighting, "_meta_tau_projected", bool(tau_projected))
        setattr(self._weighting, "_meta_beta_projected", bool(beta_projected))
        state = getattr(self._weighting, "controller_state", None)
        if state is not None:
            try:
                state.sync_from_scalars(self._weighting.tau, self._weighting.beta)
            except (AttributeError, TypeError, ValueError) as exc:
                LOG.debug("Failed to sync controller state from scalars: %s", exc)
            try:
                state.zero_grad()
            except (AttributeError, RuntimeError, TypeError) as exc:
                LOG.debug("Failed to zero controller state grads: %s", exc)

    def _build_optimizer(self) -> Any:
        """Construct the configured torch optimizer over tau/beta parameters.

        :raises RuntimeError: if torch or the controller state is missing, or
            the optimizer name is not one of ``sgd``/``adam``/``adamw``.
        """
        torch_mod = self._torch
        if torch_mod is None:
            raise RuntimeError("torch is not available for meta optimizer")
        state = self._controller_state
        if state is None:
            raise RuntimeError("controller_state required for meta optimizer")
        # Use param groups so tau/beta can have independent learning rates and
        # optimizer state (e.g., Adam moments).
        param_groups = [
            {"params": [state.tau_param], "base_lr": float(self.tau_learning_rate)},
            {"params": [state.beta_param], "base_lr": float(self.beta_learning_rate)},
        ]
        if self.optimizer == "sgd":
            return torch_mod.optim.SGD(param_groups, lr=self.learning_rate)
        if self.optimizer == "adam":
            return torch_mod.optim.Adam(param_groups, lr=self.learning_rate)
        if self.optimizer == "adamw":
            return torch_mod.optim.AdamW(param_groups, lr=self.learning_rate)
        raise RuntimeError(f"Unsupported controller_meta_optimizer={self.optimizer}")

    def _apply_optimizer_step(self, lr_scale: float) -> None:
        """Run one optimizer step, then project and publish tau/beta.

        Each param group's ``lr`` is set to its ``base_lr`` times a sanitised
        ``lr_scale`` (see ``_coerce_lr_scale``) before stepping, so a NaN,
        infinite or negative scale cannot corrupt the parameters. After the
        step, tau is clamped to the positive ``tau_min``/``tau_max`` bounds,
        beta is floored at zero, and the resulting scalars are written back to
        the weighting settings and the controller state.
        """
        if (
            self._controller_state is None
            or self._meta_optimizer is None
            or self._torch is None
        ):
            return
        scale = self._coerce_lr_scale(lr_scale)
        for group in self._meta_optimizer.param_groups:
            base_lr = group.get("base_lr", self.learning_rate)
            try:
                base_lr_val = float(base_lr)
            except (TypeError, ValueError):
                base_lr_val = float(self.learning_rate)
            group["lr"] = float(base_lr_val * scale)
        self._meta_optimizer.step()
        self._meta_optimizer.zero_grad(set_to_none=True)
        torch_mod = self._torch
        state = self._controller_state
        tau_min = float(self._weighting.tau_min)
        tau_max = float(self._weighting.tau_max)
        with _no_grad_context(torch_mod):
            # Non-positive bounds are treated as "no bound" on that side.
            if tau_min > 0.0 or tau_max > 0.0:
                min_val = tau_min if tau_min > 0.0 else None
                max_val = tau_max if tau_max > 0.0 else None
                clamp_fn = getattr(state.tau_param, "clamp_", None)
                if callable(clamp_fn):
                    clamp_fn(min=min_val, max=max_val)
                else:
                    clamp_res = state.tau_param.clamp(min=min_val, max=max_val)
                    state.tau_param.copy_(clamp_res)
            beta_clamp = getattr(state.beta_param, "clamp_", None)
            if callable(beta_clamp):
                beta_clamp(min=0.0)
            else:
                state.beta_param.copy_(state.beta_param.clamp(min=0.0))
        tau_val = float(state.tau_param.detach().item())
        beta_val = float(state.beta_param.detach().item())
        if tau_max > 0.0:
            tau_val = min(tau_val, tau_max)
        if tau_min > 0.0:
            tau_val = max(tau_val, tau_min)
        self._weighting.tau = tau_val
        self._weighting.beta = max(0.0, beta_val)
        try:
            setattr(
                self._weighting, "_tau_log", math.log(max(self._weighting.tau, 1e-8))
            )
        except (AttributeError, TypeError, ValueError) as exc:
            LOG.debug("Failed to update _tau_log after optimizer step: %s", exc)
        state.zero_grad()
        _ensure_tau_history(self._weighting)
        setattr(self._weighting, "_meta_tau_projected", False)
        setattr(self._weighting, "_meta_beta_projected", beta_val <= 0.0)
        if self._weighting.train_grpo_objective:
            self._weighting.denom = 1.0
        else:
            denom_sum = self._weighting.tau + self._weighting.beta
            self._weighting.denom = denom_sum if denom_sum > 0 else 1.0
        try:
            state.sync_from_scalars(self._weighting.tau, self._weighting.beta)
        except (AttributeError, TypeError, ValueError) as exc:
            LOG.debug("Failed to sync controller state after optimizer step: %s", exc)