# Source code for maxent_grpo.training.telemetry.trl_logging

"""Lightweight logging helpers to mirror MaxEnt metrics inside TRL trainers.

These utilities attach a small mixin to TRL trainers (e.g. ``GRPOTrainer``) so
per-step logs also include the tau/beta and controller diagnostics used by the
custom MaxEnt loop. The helpers are dependency-light and tolerate missing
``transformers``/TRL pieces so unit tests can exercise them with
SimpleNamespace stubs.
"""

from __future__ import annotations

import logging
import math
from typing import Any, Dict, Optional, cast, TYPE_CHECKING

try:  # Optional dependency for callback-based logging patch
    from transformers.trainer_callback import TrainerCallback
except (ImportError, ModuleNotFoundError):  # pragma: no cover - optional dependency
    TrainerCallback = None

# Module logger; used for debug-level diagnostics when optional metric capture
# or callback attachment fails.
LOG = logging.getLogger(__name__)


def _numeric_or_none(value: Any) -> Optional[float]:
    """Return a finite float or ``None`` when conversion fails."""

    if isinstance(value, bool):
        return None
    try:  # Handle plain numbers, numpy scalars, and torch scalars
        candidate = float(value)
    except (TypeError, ValueError):
        item_fn = getattr(value, "item", None)
        if callable(item_fn):
            try:
                candidate = float(cast(Any, item_fn)())
            except (TypeError, ValueError):
                return None
        else:
            return None
    return candidate if math.isfinite(candidate) else None


def _with_prefix(prefix: str, key: str) -> str:
    """Helper to attach a prefix if not already present."""

    return key if key.startswith(prefix) else f"{prefix}{key}"


# Legacy TRL completion-metric suffixes mapped to the canonical suffixes shared
# with the custom MaxEnt loop. ``_canonicalize_rollout_metric_keys`` uses this
# to add the canonical key next to the legacy one (the legacy key is kept).
_CANONICAL_COMPLETION_KEYS = {
    "completions/mean_length": "completions/mean_length_sampled",
    "completions/min_length": "completions/min_length_sampled",
    "completions/max_length": "completions/max_length_sampled",
    "completions/clipped_ratio": "completions/clipped_frac",
    "completions/mean_terminated_length": "completions/mean_length_terminated",
    "completions/min_terminated_length": "completions/min_length_terminated",
    "completions/max_terminated_length": "completions/max_length_terminated",
}


def _canonicalize_rollout_metric_keys(metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``metrics`` with canonical rollout-metric aliases added.

    For both the ``train/`` and ``eval/`` namespaces:

    * legacy TRL completion keys listed in ``_CANONICAL_COMPLETION_KEYS`` gain
      their canonical counterpart;
    * ``<mode>/diversity/<name>`` keys are mirrored to
      ``<mode>/completions/diversity/<name>``.

    Keys that already exist are never overwritten; the input is not mutated.
    """
    result = dict(metrics)
    for mode in ("train/", "eval/"):
        for legacy, canonical in _CANONICAL_COMPLETION_KEYS.items():
            source = mode + legacy
            if source in result:
                result.setdefault(mode + canonical, result[source])
        diversity_head = mode + "diversity/"
        # Snapshot first: the loop below inserts new keys into ``result``.
        diversity_keys = [name for name in result if name.startswith(diversity_head)]
        for source in diversity_keys:
            tail = source[len(diversity_head) :]
            result.setdefault(f"{mode}completions/diversity/{tail}", result[source])
    return result


def _fix_clipped_ratio(metrics: Dict[str, Any], args: Any) -> None:
    """Rewrite ``*completions/clipped_ratio`` entries as ``*clipped_frac`` in place.

    Every matching key is removed from ``metrics``. If its value is numeric it
    is re-inserted under the ``clipped_frac`` name, clamped to ``[0, 1]``;
    non-numeric values are dropped. Negative values (which TRL can emit) are
    corrected using the world size from ``args`` (``world_size``, then
    ``num_processes``, then ``process_count``) when it exceeds 1, and are
    clamped to zero otherwise.
    """

    # Best-effort world size for correcting cross-rank aggregation mistakes.
    world_size: Optional[float] = None
    for attr in ("world_size", "num_processes", "process_count"):
        world_size = _numeric_or_none(getattr(args, attr, None))
        if world_size not in (None, 0.0):
            break

    clip_keys = [name for name in metrics if "completions/clipped_ratio" in name]
    for key in clip_keys:
        raw = _numeric_or_none(metrics.pop(key))
        if raw is None:
            continue
        frac = float(raw)
        if frac < 0.0:
            # TRL gathers completions across ranks but divides by the local
            # batch size; undo that when possible, otherwise clamp the negative
            # count to zero rather than inflating the ratio.
            frac = (
                1.0 - (1.0 - frac) / world_size
                if world_size and world_size > 1.0
                else 0.0
            )
        metrics[key.replace("clipped_ratio", "clipped_frac")] = min(1.0, max(0.0, frac))


# Choose the callback base class. Type checkers always see the real
# ``TrainerCallback``; at runtime the base falls back to ``object`` when the
# guarded ``transformers`` import at the top of the module failed and left
# ``TrainerCallback`` set to ``None``.
if TYPE_CHECKING:  # pragma: no cover - typing only
    from transformers.trainer_callback import TrainerCallback as _TrainerCallback

    _WeightingLogCallbackBase = _TrainerCallback
else:
    _WeightingLogCallbackBase = (
        TrainerCallback if TrainerCallback is not None else object
    )


class _WeightingLogCallback(_WeightingLogCallbackBase):
    """Trainer callback that re-normalizes metrics at ``on_log`` time.

    A fallback for trainers whose logging path bypasses the ``log`` override.
    It applies the same loss mirroring, clipped-ratio fix and prefix
    normalization. It then replaces the contents of ``logs`` in place, so any
    code holding the same dict sees the normalized keys.
    """

    def on_log(
        self,
        args: Any,
        state: Any,
        control: Any,
        logs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        # ``state``/``kwargs`` exist only for TrainerCallback signature compatibility.
        del state, kwargs
        if logs is not None:
            working = dict(logs)
            _augment_loss_metrics(working)
            _fix_clipped_ratio(working, args)
            cleaned = _normalize_prefixes(working, is_eval=False)
            logs.clear()
            logs.update(cleaned)
        return control


def _fix_clipped_ratio_metrics(trainer: Any) -> None:
    """Sanitize in-memory `_metrics` before GRPOTrainer aggregates them.

    ``trainer._metrics`` is assumed to map a mode name to a dict of
    metric-name -> list of per-step values (TODO confirm against the TRL
    version in use). For every key containing ``completions/clipped_ratio``
    whose value is a list/tuple, each entry is converted to a float in
    ``[0, 1]``:

    * negative entries are treated as counts and divided by
      ``args.num_generations`` (minimum 1);
    * non-numeric entries are discarded.

    An independent copy of the cleaned series is also published under the
    matching ``clipped_frac`` key when that key is not already present. Series
    with no numeric entries are left untouched.

    :param trainer: Object with optional ``_metrics`` and ``args`` attributes.
    """

    metrics_map = getattr(trainer, "_metrics", None)
    if not isinstance(metrics_map, dict):
        return
    args = getattr(trainer, "args", None)
    num_generations = _numeric_or_none(getattr(args, "num_generations", None)) or 1.0
    denom = max(1.0, num_generations)

    def _normalize(val: Any) -> Optional[float]:
        v = _numeric_or_none(val)
        if v is None:
            return None
        v = float(v)
        if v < 0.0:
            v = -v / denom  # TRL emits a negative count; convert to fraction.
        return max(0.0, min(1.0, v))

    for mode_metrics in metrics_map.values():
        if not isinstance(mode_metrics, dict):
            continue
        # Snapshot the matching keys: ``clipped_frac`` keys are inserted below.
        clip_keys = [k for k in mode_metrics if "completions/clipped_ratio" in k]
        for key in clip_keys:
            values = mode_metrics.get(key, [])
            if not isinstance(values, (list, tuple)):
                continue
            normalized = [
                fixed for fixed in map(_normalize, values) if fixed is not None
            ]
            if not normalized:
                # Keep the original list if nothing was normalized.
                continue
            mode_metrics[key] = normalized
            # Inject a normalized copy for downstream consumers to avoid
            # surfacing raw negative counts. It must be a separate list:
            # sharing one object between both keys would let later appends to
            # one key silently alter the other.
            mode_metrics.setdefault(
                key.replace("clipped_ratio", "clipped_frac"), list(normalized)
            )


def _augment_loss_metrics(metrics: Dict[str, Any]) -> None:
    """Mirror bare loss/KL entries under ``train/`` keys, mutating ``metrics``.

    * ``loss`` -> ``train/loss/total``
    * ``kl`` -> ``train/kl`` and ``train/loss/kl``
    * the first numeric value among known aliases -> ``train/loss/policy``,
      ``train/loss/clip`` and ``train/loss/value``

    Existing target keys are never overwritten; non-numeric sources are ignored.
    """

    total = _numeric_or_none(metrics.get("loss"))
    if total is not None:
        metrics.setdefault("train/loss/total", float(total))

    kl = _numeric_or_none(metrics.get("kl"))
    if kl is not None:
        for target in ("train/kl", "train/loss/kl"):
            metrics.setdefault(target, float(kl))

    # Heuristically mirror policy/clip/value losses under whichever alias the
    # upstream trainer happened to use; the first numeric alias wins.
    alias_table = (
        ("train/loss/policy", ("policy_loss", "loss/policy", "train/policy_loss")),
        ("train/loss/clip", ("clip_loss", "loss/clip", "train/clip_loss")),
        ("train/loss/value", ("value_loss", "loss/value", "train/value_loss")),
    )
    for target, aliases in alias_table:
        if target in metrics:
            continue
        hits = (_numeric_or_none(metrics.get(alias)) for alias in aliases)
        first = next((hit for hit in hits if hit is not None), None)
        if first is not None:
            metrics[target] = float(first)


def _merge_loss_components_from_trainer(metrics: Dict[str, Any], trainer: Any) -> None:
    """Copy numeric entries of ``trainer._last_loss_components`` into ``metrics``.

    Each component ``name`` is written as ``train/loss/<name>``. Keys already
    present in ``metrics`` win, non-numeric components are skipped, and nothing
    happens when the trainer holds no component dict.
    """

    components = getattr(trainer, "_last_loss_components", None)
    if isinstance(components, dict):
        for name, raw in components.items():
            number = _numeric_or_none(raw)
            if number is not None:
                metrics.setdefault(f"train/loss/{name}", float(number))


# Bare (unprefixed) metric names that are renamed rather than simply prefixed.
# Each maps to the suffixes written under the active ``train/``/``eval/`` prefix.
_BARE_KEY_RENAMES: Dict[str, tuple] = {
    "loss": ("loss/total",),
    "kl": ("kl", "loss/kl"),
    "frac_reward_zero_std": ("reward/zero_fraction",),
    "beta": ("weighting/beta",),
    "kl_coeff": ("weighting/beta",),
    "kl_coef": ("weighting/beta",),
    "kl_coefficient": ("weighting/beta",),
    "tau": ("weighting/tau",),
}


def _normalize_prefixes(
    metrics: Dict[str, Any], is_eval: bool = False
) -> Dict[str, Any]:
    """Return a copy of metrics with bare keys moved under train/ or eval/.

    Rules, applied per key in input order (``<prefix>`` is ``eval/`` when
    ``is_eval`` else ``train/``):

    * keys already under ``train/``/``eval/`` are kept verbatim. Those ending in
      ``weighting/tau``/``weighting/beta`` also add ``<mode>/tau``/``<mode>/beta``
      for the key's *own* mode, so eval values never populate ``train/*``.
    * ``eval_frac_reward_zero_std`` -> ``eval/reward/zero_fraction``, matching
      the train-side naming.
    * other ``eval_<name>``/``train_<name>`` keys -> ``eval/<name>``/``train/<name>``.
    * names in ``_BARE_KEY_RENAMES`` -> ``<prefix><suffix>`` for each suffix.
    * ``kl_controller_<name>`` -> ``<prefix>kl_controller/<name>``.
    * everything else -> ``<prefix><key>``.

    Canonical rollout aliases are then added and ``<mode>/tau``/``<mode>/beta``
    are materialized from ``<mode>/weighting/{tau,beta}`` when missing.

    :param metrics: Raw metrics mapping; not mutated.
    :param is_eval: Route bare keys to ``eval/`` instead of ``train/``.
    :returns: A new normalized mapping.
    """

    prefix = "eval/" if is_eval else "train/"
    out: Dict[str, Any] = {}
    for key, val in metrics.items():
        if key.startswith(("train/", "eval/")):
            out[key] = val
            mode = key.split("/", 1)[0]
            if key.endswith("weighting/tau"):
                out.setdefault(f"{mode}/tau", val)
            if key.endswith("weighting/beta"):
                out.setdefault(f"{mode}/beta", val)
            continue
        if key == "eval_frac_reward_zero_std":
            # Must precede the generic ``eval_`` rule below, which would
            # otherwise produce ``eval/frac_reward_zero_std``.
            out["eval/reward/zero_fraction"] = val
            continue
        # Eval-prefixed keys from TRL (e.g., eval_loss, eval_reward)
        if key.startswith("eval_"):
            out[_with_prefix("eval/", key[len("eval_") :])] = val
            continue
        # Train-prefixed aliases (e.g., train_loss) — normalize to train/
        if key.startswith("train_"):
            out[_with_prefix("train/", key[len("train_") :])] = val
            continue
        if key in _BARE_KEY_RENAMES:
            for suffix in _BARE_KEY_RENAMES[key]:
                out[f"{prefix}{suffix}"] = val
            continue
        if key.startswith("kl_controller_"):
            out[f"{prefix}kl_controller/{key[len('kl_controller_') :]}"] = val
            continue
        # Covers reward, reward_std, rewards/*, completions/* and any other key.
        out[_with_prefix(prefix, key)] = val
    out = _canonicalize_rollout_metric_keys(out)
    # Always materialize top-level aliases for weighting scalars so dashboards
    # can rely on `<mode>/tau` and `<mode>/beta` regardless of the upstream key
    # shape (e.g., `weighting/tau`, `train/weighting/tau`, or `tau`).
    alias_mode = "eval" if is_eval else "train"
    for name in ("tau", "beta"):
        source = f"{alias_mode}/weighting/{name}"
        if source in out:
            out.setdefault(f"{alias_mode}/{name}", out[source])
    return out


class _WeightingMetricHelper:
    """Helper that derives tau/beta metrics from a trainer + its args.

    Values are read from the trainer when available and otherwise from an
    ``args`` object: the trainer's own ``args`` attribute if present, else the
    one given at construction. The helper remembers the previously reported
    tau/beta so each call can report the change since the last call.
    """

    def __init__(self, args: Any) -> None:
        self._args = args
        # Last reported values; ``None`` until the first ``metrics_for_trainer`` call.
        self._prev_tau: Optional[float] = None
        self._prev_beta: Optional[float] = None

    @staticmethod
    def _bool_flag(val: Any) -> Optional[bool]:
        """Return a bool for truthy/falsey flags, preserving ``None``.

        NOTE: non-bool values use Python truthiness, so e.g. the string
        ``"false"`` maps to ``True``.
        """

        if isinstance(val, bool):
            return val
        if val is None:
            return None
        try:
            return bool(val)
        except (TypeError, ValueError):
            return None

    def _current_tau(self, trainer: Any) -> float:
        """Return ``trainer.tau``/``trainer.maxent_tau``, else ``args.maxent_tau``, else 0."""

        for attr in ("tau", "maxent_tau"):
            tau_val = _numeric_or_none(getattr(trainer, attr, None))
            if tau_val is not None:
                return tau_val
        args = getattr(trainer, "args", self._args)
        tau_val = _numeric_or_none(getattr(args, "maxent_tau", None))
        return tau_val if tau_val is not None else 0.0

    def _current_beta(self, trainer: Any) -> float:
        """Return the first numeric KL coefficient from the controller/trainer, else ``args.beta``, else 0."""

        kl_ctl = getattr(trainer, "kl_ctl", None)
        for candidate in (
            getattr(kl_ctl, "value", None),
            getattr(kl_ctl, "current_kl_coef", None),
            getattr(kl_ctl, "kl_coef", None),
            getattr(trainer, "kl_coef", None),
            getattr(trainer, "kl_coefficient", None),
            getattr(trainer, "beta", None),
        ):
            beta_val = _numeric_or_none(candidate)
            if beta_val is not None:
                return beta_val
        args = getattr(trainer, "args", self._args)
        beta_arg = _numeric_or_none(getattr(args, "beta", None))
        if beta_arg is not None:
            return beta_arg
        return 0.0

    def metrics_for_trainer(self, trainer: Any) -> Dict[str, float]:
        """Build the extra weighting metrics for the provided trainer.

        Every returned value is a finite float; entries whose source is missing
        or non-numeric are omitted. Each call updates the stored previous
        tau/beta used for the ``delta_*`` metrics.
        """

        args = getattr(trainer, "args", self._args)
        tau = float(self._current_tau(trainer))
        beta = float(self._current_beta(trainer))
        denom = _numeric_or_none(getattr(trainer, "weight_norm_denom", None))
        if denom is None:
            denom = max(tau + beta, 1.0)
        delta_tau = tau - self._prev_tau if self._prev_tau is not None else 0.0
        delta_beta = beta - self._prev_beta if self._prev_beta is not None else 0.0
        self._prev_tau = tau
        self._prev_beta = beta

        warmup_steps = getattr(args, "maxent_tau_warmup_steps", -1)
        warmup_steps = warmup_steps if isinstance(warmup_steps, int) else -1
        target_entropy_val = _numeric_or_none(
            getattr(args, "maxent_target_weight_entropy", None)
        )
        state = getattr(trainer, "state", None)
        global_step = (
            getattr(state, "global_step", 0)
            if state is not None
            else getattr(args, "global_step", 0)
        )
        # The tau schedule counts as active once a target entropy is configured
        # and the warmup window has passed.
        schedule_active = target_entropy_val is not None and global_step > max(
            0, warmup_steps
        )
        schedule_flag = 1.0 if schedule_active else 0.0

        q_temperature = _numeric_or_none(getattr(args, "maxent_q_temperature", None))
        q_epsilon = _numeric_or_none(getattr(args, "maxent_q_epsilon", None))
        tau_lr = _numeric_or_none(getattr(args, "maxent_tau_lr", None))
        tau_min = _numeric_or_none(getattr(args, "maxent_tau_min", None))
        tau_max = _numeric_or_none(getattr(args, "maxent_tau_max", None))

        train_grpo_flag = self._bool_flag(getattr(args, "train_grpo_objective", None))
        # Back-compat aliases used by some recipes/TRL surfaces.
        if train_grpo_flag is None:
            train_grpo_flag = self._bool_flag(getattr(args, "grpo_objective", None))
        maxent_flag = self._bool_flag(getattr(args, "maxent_objective", None))
        if train_grpo_flag is None and maxent_flag is not None:
            train_grpo_flag = not maxent_flag
        if train_grpo_flag is None:
            train_grpo_flag = True  # default to GRPO when unspecified
        if maxent_flag is None:
            maxent_flag = not train_grpo_flag

        # KL controller settings are read once and shared by the reported
        # values and both ``enabled`` aliases. The controller counts as enabled
        # only when the objective-specific flag is truthy and target, horizon
        # and step size are all set to non-zero numbers.
        kl_target = _numeric_or_none(getattr(args, "kl_target", None))
        kl_horizon = _numeric_or_none(getattr(args, "kl_horizon", None))
        kl_step_size = _numeric_or_none(getattr(args, "kl_ctl_step_size", None))
        controller_flag_name = (
            "grpo_beta_controller_enabled"
            if train_grpo_flag
            else "maxent_beta_controller_enabled"
        )
        kl_controller_on = bool(
            self._bool_flag(getattr(args, controller_flag_name, None))
            and kl_target not in {None, 0.0}
            and kl_horizon not in {None, 0.0}
            and kl_step_size not in {None, 0.0}
        )
        kl_controller_flag = 1.0 if kl_controller_on else 0.0

        metrics: Dict[str, Optional[float]] = {
            "train/weighting/tau": tau,
            "train/weighting/beta": beta,
            "train/tau": tau,
            "train/beta": beta,
            "train/kl_coeff": beta,
            "train/weighting/weight_norm_denom": denom,
            "train/weight_norm_denom": denom,
            "train/weighting/tau_log": math.log(max(tau, 1e-8)),
            "train/weighting/q_temperature": q_temperature,
            "train/weighting/q_epsilon": q_epsilon,
            "train/weighting/tau_lr": tau_lr,
            "train/weighting/tau_min": tau_min,
            "train/weighting/tau_max": tau_max,
            "train/weighting/tau_warmup_steps": float(warmup_steps),
            "train/weighting/tau_target_entropy": target_entropy_val,
            "train/weighting/tau_schedule_active": schedule_flag,
            "train/tau_target_enabled": 1.0 if target_entropy_val is not None else 0.0,
            "train/tau_schedule_active": schedule_flag,
            "train/weighting/delta_tau": delta_tau,
            "train/weighting/delta_tau_abs": abs(delta_tau),
            "train/weighting/delta_beta": delta_beta,
            "train/weighting/delta_beta_abs": abs(delta_beta),
            "train/delta_tau": delta_tau,
            "train/delta_beta": delta_beta,
            "train/kl_controller/target": kl_target,
            "train/kl_controller/horizon": kl_horizon,
            "train/kl_controller/step_size": kl_step_size,
            "train/grpo_objective": 1.0 if train_grpo_flag else 0.0,
            "train/maxent_objective": 1.0 if maxent_flag else 0.0,
            "train/kl_controller/enabled": kl_controller_flag,
            "train/kl_controller_enabled": kl_controller_flag,
        }

        meta_enabled = self._bool_flag(getattr(args, "controller_meta_enabled", None))
        meta_lr = _numeric_or_none(getattr(args, "controller_meta_lr", None)) or 0.0
        meta_interval = (
            _numeric_or_none(getattr(args, "controller_meta_update_interval", None))
            or 0.0
        )
        # A falsy truncation_steps falls back to analytic_steps.
        meta_trunc = (
            _numeric_or_none(
                getattr(args, "controller_meta_truncation_steps", None)
                or getattr(args, "controller_meta_analytic_steps", None)
            )
            or 0.0
        )
        meta_use_hessian = (
            1.0
            if self._bool_flag(getattr(args, "controller_meta_use_hessian", None))
            else 0.0
        )
        metrics.update(
            {
                "train/meta/enabled": 1.0 if meta_enabled else 0.0,
                "train/meta/lr": meta_lr if meta_enabled else 0.0,
                "train/meta/update_interval": meta_interval if meta_enabled else 0.0,
                "train/meta/truncation_steps": meta_trunc if meta_enabled else 0.0,
                "train/meta/use_hessian": meta_use_hessian if meta_enabled else 0.0,
                "train/meta/tau_grad": 0.0,
                "train/meta/beta_grad": 0.0,
                "train/meta/grad_norm": 0.0,
                "train/meta/loss": 0.0,
                "train/meta/tau_projected": 0.0,
                "train/meta/beta_projected": 0.0,
            }
        )

        # Drop missing/non-finite entries so loggers only ever see floats.
        sanitized: Dict[str, float] = {}
        for key, val in metrics.items():
            val_num = _numeric_or_none(val)
            if val_num is not None:
                sanitized[key] = float(val_num)
        return sanitized


class _WeightingLoggingMixin:
    """Mixin that injects weighting metrics into Trainer.log.

    Meant to precede a Trainer class in the MRO. ``log`` enriches and
    normalizes the record, then forwards it to the next ``log`` via ``super()``.
    """

    def _init_weighting_logger(self) -> None:
        """Create the ``_WeightingMetricHelper`` on first use."""

        if getattr(self, "_weighting_metric_helper", None) is None:
            self._weighting_metric_helper = _WeightingMetricHelper(
                getattr(self, "args", None)
            )

    def _cache_train_kl_for_alpha(self, metrics: Dict[str, Any]) -> None:
        """Persist latest train KL so adaptive MaxEnt alpha can read it.

        Stores the first numeric value among ``train/kl``, ``train/loss/kl`` and
        ``kl`` on ``self._last_train_kl_for_alpha``. Leaves it unchanged when
        none is present or the attribute cannot be set.
        """

        if not isinstance(metrics, dict):
            return
        for key in ("train/kl", "train/loss/kl", "kl"):
            kl_val = _numeric_or_none(metrics.get(key))
            if kl_val is None:
                continue
            try:
                setattr(self, "_last_train_kl_for_alpha", float(kl_val))
            except (AttributeError, RuntimeError, TypeError, ValueError):
                return
            return

    def log(self, logs: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
        """Enrich, sanitize and normalize ``logs``, then delegate upstream."""

        self._init_weighting_logger()
        helper = getattr(self, "_weighting_metric_helper", None)
        merged = dict(logs or {})
        if helper is not None:
            try:
                extra = helper.metrics_for_trainer(self)
                for key, value in extra.items():
                    merged.setdefault(key, value)
            except (AttributeError, RuntimeError, TypeError, ValueError) as err:
                # Defensive: helper can rely on optional Trainer attributes
                # that may be missing in lightweight stubs.
                LOG.debug("Failed to compute weighting metrics: %s", err)
        # Evaluation records (``eval_*``/``eval/*`` keys) must not be patched
        # with the cached scalar. ``_last_loss_scalar`` is refreshed by
        # ``compute_loss``, which may also run during evaluation, and a ``loss``
        # injected here would be reported as ``train/loss/total``.
        is_eval_record = any(
            isinstance(key, str) and key.startswith(("eval_", "eval/"))
            for key in merged
        )
        precise_loss = _numeric_or_none(getattr(self, "_last_loss_scalar", None))
        if precise_loss is not None and not is_eval_record:
            # Prefer the precise loss captured during compute_loss to avoid the
            # 4-decimal rounding applied by the upstream Trainer logger.
            logged_loss = _numeric_or_none(merged.get("loss"))
            merged.setdefault("train/loss/total_raw", precise_loss)
            if logged_loss is None or logged_loss == 0.0:
                merged["loss"] = precise_loss
        _merge_loss_components_from_trainer(merged, self)
        _augment_loss_metrics(merged)
        _fix_clipped_ratio(merged, getattr(self, "args", None))
        # Sanitize TRL's internal metrics so the downstream GRPOTrainer log doesn't
        # reintroduce negative clipped_ratio values.
        _fix_clipped_ratio_metrics(self)
        normalized = _normalize_prefixes(merged, is_eval=False)
        self._cache_train_kl_for_alpha(normalized)
        return cast(Any, super()).log(normalized, *args, **kwargs)


def ensure_weighting_logging(trainer_cls: type) -> type:
    """Wrap a Trainer subclass to include weighting metric logging once.

    Classes already marked with ``_MAXENT_WEIGHTING_LOGGING`` are returned
    unchanged. Non-class callables (e.g. stubs returning ``SimpleNamespace``)
    are wrapped in a delegating proxy whose ``log`` normalizes metrics before
    forwarding to the inner object's ``log`` when it has one.

    :param trainer_cls: Trainer class (or callable) to wrap.
    :type trainer_cls: type
    :returns: Wrapped trainer class emitting normalized weighting metrics.
    :rtype: type
    """

    if not isinstance(trainer_cls, type):
        # Allow callables (e.g., stubs returning SimpleNamespace) to be used like classes.
        callable_trainer = trainer_cls

        class _CallableTrainer(_WeightingLoggingMixin):
            _MAXENT_WEIGHTING_LOGGING = True

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                self._inner = callable_trainer(*args, **kwargs)
                self.logged_kwargs = kwargs

            def __getattr__(self, name: str) -> Any:
                # Only reached when normal lookup fails. Read ``_inner`` from
                # the instance dict: if it is not set yet (the wrapped callable
                # raised in ``__init__``, or the instance is being copied or
                # unpickled), ``self._inner`` would re-enter ``__getattr__`` and
                # recurse until RecursionError.
                try:
                    inner = self.__dict__["_inner"]
                except KeyError:
                    raise AttributeError(
                        f"{type(self).__name__!r} object has no attribute {name!r}"
                    ) from None
                return getattr(inner, name)

            def log(self, logs: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
                merged = dict(logs or {})
                _merge_loss_components_from_trainer(merged, self)
                _augment_loss_metrics(merged)
                _fix_clipped_ratio(merged, getattr(self, "args", None))
                normalized = _normalize_prefixes(merged, is_eval=False)
                if hasattr(self._inner, "log"):
                    return self._inner.log(normalized, *args, **kwargs)
                return None

        _CallableTrainer.__name__ = (
            f"WeightingLogged{getattr(trainer_cls, '__name__', 'Callable')}"
        )
        return _CallableTrainer

    if getattr(trainer_cls, "_MAXENT_WEIGHTING_LOGGING", False):
        return trainer_cls

    class _LossCaptureMixin:
        """Capture loss component dicts returned by compute_loss for logging."""

        def compute_loss(self, *args: Any, **kwargs: Any) -> Any:
            super_obj = cast(Any, super())
            loss = super_obj.compute_loss(*args, **kwargs)
            # Reset per call so stale values from an earlier call never leak.
            setattr(self, "_last_loss_components", None)
            setattr(self, "_last_loss_scalar", None)
            loss_value = loss
            if isinstance(loss, tuple) and len(loss) >= 2:
                maybe_components = loss[1]
                if isinstance(maybe_components, dict):
                    try:
                        self._last_loss_components = {
                            str(k): v for k, v in maybe_components.items()
                        }
                    except (
                        RuntimeError,
                        TypeError,
                        ValueError,
                    ) as err:  # pragma: no cover - defensive
                        LOG.debug(
                            "Failed to capture loss components for logging: %s", err
                        )
                loss_value = loss[0]
            try:
                # Capture a precise scalar before upstream rounding.
                if hasattr(loss_value, "mean"):
                    loss_value = loss_value.mean()
                self._last_loss_scalar = float(loss_value.item())
            except (AttributeError, RuntimeError, TypeError, ValueError) as err:
                LOG.debug("Failed to capture precise loss scalar: %s", err)
            return loss

    class WeightingLoggedTrainer(
        _LossCaptureMixin, _WeightingLoggingMixin, trainer_cls
    ):
        _MAXENT_WEIGHTING_LOGGING = True

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            # Belt-and-suspenders: attach callback-based normalization in case
            # parent classes bypass the log override.
            try:
                cb_handler = getattr(self, "callback_handler", None)
                callbacks = getattr(cb_handler, "callbacks", []) if cb_handler else []
                already_added = any(
                    isinstance(cb, _WeightingLogCallback) for cb in callbacks
                )
                if not already_added and hasattr(self, "add_callback"):
                    self.add_callback(_WeightingLogCallback())
            except (AttributeError, RuntimeError, TypeError) as err:
                LOG.debug("Failed to attach weighting log callback: %s", err)

    WeightingLoggedTrainer.__name__ = f"WeightingLogged{trainer_cls.__name__}"
    return WeightingLoggedTrainer
# Public API: only the trainer-wrapping entry point is exported by
# ``from ... import *``.
__all__ = ["ensure_weighting_logging"]