# Source code for maxent_grpo.training.metrics

# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Metrics and logging helpers for the MaxEnt-GRPO training loop.

Key entry points
----------------

``log_local_step``
    Emits per-rank metrics for debugging and updates the accumulator used for
    windowed averages.
``log_training_step``
    Aggregates metrics across processes, forwards them to ``wandb`` and/or the
    ``accelerate`` logger, and dumps a structured log line.
``LogStepArtifacts``
    Lightweight container that bundles loss outputs, diagnostics, gradient
    norms, and epoch progress.

The module also exposes helpers for building W&B sample tables,
gathering statistics across ranks, and summarizing reward/weighting diagnostics.
Docstrings follow Sphinx conventions so the documentation clearly describes the
available metrics and their shapes.
"""

from __future__ import annotations

import logging
import sys
import math
import json
import os
from typing import (
    Any,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Callable,
    TYPE_CHECKING,
)

from maxent_grpo.training.runtime.logging import _log_wandb
from maxent_grpo.training.telemetry.trl_logging import _normalize_prefixes
from .runtime import resolve_run_metadata
from .types import (
    Accelerator,
    BatchDiagnostics,
    LengthStats,
    LogStepArtifacts,
    LoggingHandles,
    LoggingConfigView,
    MetricState,
    OptimizationSchedule,
    RewardComponentStats,
    RewardComputation,
    RewardLoggingView,
    TrainingLoopContext,
    TrainingMetricsPayload,
    TrainingScalarStats,
    TokenUsageStats,
)
from .weighting import WeightLoggingView, WeightStats

if TYPE_CHECKING:  # Avoid importing heavy pipeline/scoring deps at runtime
    from .pipeline import PreparedBatch
    from .types import LossOutputs

# Module-level logger used by the helpers in this file.
LOG = logging.getLogger(__name__)
# NOTE(review): presumably the number of rows shown in W&B sample tables;
# the consumer is not visible in this part of the file -- confirm.
_WANDB_SAMPLE_ROWS = 4
# One-shot guard: ``_should_log`` flips ``"epoch"`` to ``True`` after warning
# that ``logging_strategy=epoch`` is unsupported, so the warning is not repeated.
_LOG_STRATEGY_WARNED = {"epoch": False}
# ``(metric key, short label)`` pairs rendered by ``_log_debug_metrics``.
_DEBUG_METRIC_FIELDS = (
    ("train/loss", "loss"),
    ("train/reward", "reward"),
    ("train/reward_std", "reward_std"),
    ("train/q_entropy_mean", "q_entropy"),
    ("train/weight_entropy", "w_entropy"),
    ("train/kl", "kl"),
    ("train/tau", "tau"),
    ("train/beta", "beta"),
)


def _rich_completion_wandb_enabled(training_args: Any) -> bool:
    """Tell whether enriched completion tables are also sent to W&B.

    :param training_args: Object carrying the optional
        ``rich_log_completions_to_wandb`` attribute, or ``None``.
    :type training_args: Any
    :returns: ``False`` when ``training_args`` is ``None`` or the attribute is
        missing; otherwise the truthiness of the attribute.
    :rtype: bool
    """
    flag = (
        False
        if training_args is None
        else getattr(training_args, "rich_log_completions_to_wandb", False)
    )
    return bool(flag)


def _rich_completion_sync_enabled(training_args: Any) -> bool:
    """Tell whether ranks should synchronize after rich completion logging.

    :param training_args: Object carrying the optional
        ``rich_log_completions_synchronize_ranks`` attribute, or ``None``.
    :type training_args: Any
    :returns: ``True`` when ``training_args`` is ``None`` or the attribute is
        missing; otherwise the truthiness of the attribute.
    :rtype: bool
    """
    flag = (
        True
        if training_args is None
        else getattr(training_args, "rich_log_completions_synchronize_ranks", True)
    )
    return bool(flag)


def _wait_after_rich_completion_logging(
    accelerator: Any,
    training_args: Any,
) -> None:
    """Run the accelerator barrier after rich completion logging if enabled.

    :param accelerator: Object that may expose a callable
        ``wait_for_everyone`` attribute.
    :type accelerator: Any
    :param training_args: Passed to ``_rich_completion_sync_enabled`` to decide
        whether synchronization is requested.
    :type training_args: Any
    :returns: ``None``.
    """
    if _rich_completion_sync_enabled(training_args):
        barrier = getattr(accelerator, "wait_for_everyone", None)
        # A missing or non-callable barrier is silently ignored.
        if callable(barrier):
            barrier()


def _as_float(value: Any) -> Optional[float]:
    """Return a finite float or ``None`` when conversion fails.

    Booleans are rejected outright.  When ``float(value)`` raises
    ``TypeError``/``ValueError`` the value's ``item()`` method (if any) is
    tried instead.

    :param value: Arbitrary object to convert.
    :type value: Any
    :returns: The converted float, or ``None`` when the value is a bool,
        cannot be converted, overflows, or is NaN/infinite.
    :rtype: float | None
    """
    if isinstance(value, bool):
        return None
    try:
        candidate = float(value)
    except OverflowError:
        # e.g. an int too large for a double: not representable as a finite float.
        return None
    except (TypeError, ValueError):
        item_fn = getattr(value, "item", None)
        if not callable(item_fn):
            return None
        try:
            candidate = float(item_fn())
        except (TypeError, ValueError, OverflowError, RuntimeError):
            # ``item()`` itself may refuse (e.g. a multi-element container).
            return None
    return candidate if math.isfinite(candidate) else None


def _metrics_mode() -> str:
    """Return the normalized metrics filtering mode from the environment.

    ``MAXENT_WANDB_METRICS_MODE`` wins when it is set to a non-empty value;
    otherwise ``MAXENT_METRICS_MODE`` is used (empty string when unset).

    :returns: The selected value, stripped and lower-cased.
    :rtype: str
    """
    preferred = os.environ.get("MAXENT_WANDB_METRICS_MODE")
    if preferred:
        raw = preferred
    else:
        raw = os.environ.get("MAXENT_METRICS_MODE", "")
    return raw.strip().lower()


def _drop_prefix(metrics: Dict[str, Any], prefix: str) -> None:
    """Delete, in place, every entry of ``metrics`` whose key begins with ``prefix``.

    :param metrics: Dictionary to prune (mutated).
    :type metrics: dict[str, Any]
    :param prefix: Key prefix selecting the entries to remove.
    :type prefix: str
    :returns: ``None``.
    """
    # Snapshot the matching keys first so the dict is not resized mid-iteration.
    doomed = tuple(name for name in metrics if name.startswith(prefix))
    for name in doomed:
        metrics.pop(name, None)


def _slim_metrics(
    metrics: Dict[str, Any],
    _ctx: TrainingLoopContext,
) -> Dict[str, Any]:
    """Return a compact metrics dict for W&B/console logging.

    The input is copied; ``metrics`` itself is never mutated.  The copy is
    pruned of controller/weighting detail, duplicate aliases, most completion
    and quantile statistics, and of groups that are irrelevant for the active
    objective (meta controller disabled, no entropy bonus, GRPO objective).

    :param metrics: Full metrics mapping for the step.
    :type metrics: dict[str, Any]
    :param _ctx: Training loop context (currently unused).
    :type _ctx: TrainingLoopContext
    :returns: A new, reduced metrics dictionary.
    :rtype: dict[str, Any]
    """
    slim = dict(metrics)
    # Read the switches from the untouched input: some of these keys are
    # removed from ``slim`` below.
    maxent_objective = _as_float(metrics.get("train/maxent_objective"))
    entropy_bonus_coef = _as_float(metrics.get("train/entropy_bonus_coef"))
    meta_enabled = _as_float(metrics.get("train/meta/enabled"))

    for prefix in (
        "train/weighting/",
        "train/kl_controller/",
        "train/kl_per_token_bucket/",
        "train/kl_per_token_bucket_tokens/",
        "train/clip_ratio/",
    ):
        _drop_prefix(slim, prefix)

    for key in (
        "train/loss/total",
        "train/objective/minimize",
        "train/objective/maximize",
        "train/kl_coeff",
        "train/grpo_objective",
        "train/maxent_objective",
        "train/len_norm_ref",
        "train/weight_norm_denom",
        "train/tau_log",
        "train/delta_tau_abs",
        "train/delta_beta_abs",
    ):
        slim.pop(key, None)

    keep_completions = {
        "train/completions/mean_length_sampled",
        "train/completions/mean_length_terminated",
        "train/completions/clipped_frac",
        "train/completions/diversity/jaccard",
        "train/completions/diversity/jaccard_micro",
        "train/completions/diversity/distinct_1",
        "train/completions/diversity/distinct_2",
        "train/completions/diversity/distinct_1_micro",
        "train/completions/diversity/distinct_2_micro",
    }
    for key in [k for k in slim if k.startswith("train/completions/")]:
        if key not in keep_completions:
            slim.pop(key, None)

    keep_reward_quantiles = {"train/reward_p05", "train/reward_p50", "train/reward_p95"}
    for key in [k for k in slim if k.startswith("train/reward_p")]:
        if key not in keep_reward_quantiles:
            slim.pop(key, None)

    # Per-component quantiles: only the *last* path segment decides whether a
    # key is a quantile (``p`` followed by digits).  Checking for "/p" anywhere
    # in the key would also discard ``/mean`` and ``/std`` of any reward
    # component whose name happens to start with "p".
    keep_quantile_suffixes = {"p05", "p50", "p95"}
    for key in [k for k in slim if k.startswith("train/rewards/")]:
        suffix = key.rsplit("/", 1)[-1]
        is_quantile = suffix[:1] == "p" and suffix[1:].isdigit()
        if is_quantile and suffix not in keep_quantile_suffixes:
            slim.pop(key, None)

    # With a single reward component whose mean equals the overall reward the
    # per-component block is redundant.
    comp_means = [
        k for k in slim if k.startswith("train/rewards/") and k.endswith("/mean")
    ]
    if len(comp_means) == 1:
        comp_mean = comp_means[0]
        if _as_float(slim.get(comp_mean)) == _as_float(slim.get("train/reward")):
            reward_key = comp_mean.split("/")[2]
            _drop_prefix(slim, f"train/rewards/{reward_key}/")

    if meta_enabled == 0.0:
        _drop_prefix(slim, "train/meta/")

    if entropy_bonus_coef in (None, 0.0):
        for key in list(slim):
            if key.startswith("train/entropy_bonus"):
                slim.pop(key, None)
        slim.pop("train/reward_without_entropy_bonus", None)
        slim.pop("train/reward_with_entropy_bonus", None)
        _drop_prefix(slim, "train/rewards/entropy_bonus/")

    if maxent_objective == 0.0:
        for key in list(slim):
            if key.startswith(("train/weight_entropy", "train/advantage_entropy")):
                slim.pop(key, None)
        tau_val = _as_float(slim.get("train/tau"))
        if tau_val in (None, 0.0):
            slim.pop("train/tau", None)
            _drop_prefix(slim, "train/tau_")
    else:
        for key in (
            "train/weight_entropy_min",
            "train/weight_entropy_max",
            "train/weight_entropy_ema",
            "train/q_entropy_min",
            "train/q_entropy_max",
        ):
            slim.pop(key, None)

    return slim


def _filter_metrics(
    metrics: Dict[str, Any],
    ctx: TrainingLoopContext,
) -> Dict[str, Any]:
    """Apply the environment-selected metrics filter.

    :param metrics: Full metrics mapping for the step.
    :type metrics: dict[str, Any]
    :param ctx: Training loop context forwarded to ``_slim_metrics``.
    :type ctx: TrainingLoopContext
    :returns: The slimmed copy when the mode is one of ``slim``, ``compact``,
        ``minimal`` or ``lite``; otherwise ``metrics`` itself, unchanged
        (this covers empty, ``full``/``all``/``default`` and unknown modes).
    :rtype: dict[str, Any]
    """
    if _metrics_mode() in {"slim", "compact", "minimal", "lite"}:
        return _slim_metrics(metrics, ctx)
    return metrics


def _log_like_grpo_enabled(training_args: Any) -> bool:
    """Return ``True`` when GRPO-style per-rank logging is requested.

    :param training_args: Object carrying the optional ``log_like_grpo``
        attribute; any falsy value disables the feature.
    :type training_args: Any
    :returns: The flag coerced to ``bool``; ``False`` when the arguments are
        falsy, the attribute is missing, or coercion raises
        ``TypeError``/``ValueError``.
    :rtype: bool
    """
    if not training_args:
        return False
    flag_val = getattr(training_args, "log_like_grpo", False)
    if isinstance(flag_val, bool):
        return flag_val
    try:
        coerced = bool(flag_val)
    except (TypeError, ValueError):
        coerced = False
    return coerced


def _logging_controls(ctx: TrainingLoopContext) -> tuple[str, int, bool]:
    """Read the logging cadence from ``ctx.training_args``.

    :param ctx: Object that may expose ``training_args`` with
        ``logging_strategy``, ``logging_steps`` and ``logging_first_step``.
    :type ctx: TrainingLoopContext
    :returns: ``(strategy, steps, first_step)`` where ``strategy`` is
        lower-cased (default ``"steps"``), ``steps`` is at least ``1`` and
        ``first_step`` defaults to ``True``.
    :rtype: tuple[str, int, bool]
    """
    training_args = getattr(ctx, "training_args", None)

    raw_strategy = getattr(training_args, "logging_strategy", "steps")
    strategy = str(raw_strategy or "steps").lower()

    raw_steps = getattr(training_args, "logging_steps", 1)
    interval = int(raw_steps or 1)
    if interval <= 0:
        # Non-positive intervals fall back to logging every step.
        interval = 1

    log_first = bool(getattr(training_args, "logging_first_step", True))
    return strategy, interval, log_first


def _should_log(ctx: TrainingLoopContext, step: int) -> bool:
    """Decide whether metrics are emitted for ``step``.

    :param ctx: Training loop context handed to ``_logging_controls``.
    :type ctx: TrainingLoopContext
    :param step: Current step index.
    :type step: int
    :returns: ``False`` for the ``no``/``none``/``off`` and
        ``epoch``/``epochs`` strategies; for step ``0`` the configured
        first-step flag; otherwise whether ``step`` is a multiple of the
        logging interval.
    :rtype: bool
    """
    strategy, interval, log_first = _logging_controls(ctx)
    # Epoch-based logging is unsupported: warn once per process, then skip.
    if strategy in {"epoch", "epochs"} and not _LOG_STRATEGY_WARNED["epoch"]:
        LOG.warning(
            "logging_strategy=epoch is not supported in the custom loop; disabling step logs."
        )
        _LOG_STRATEGY_WARNED["epoch"] = True
    if strategy in {"no", "none", "off", "epoch", "epochs"}:
        return False
    return log_first if step == 0 else (step % interval) == 0


def _log_debug_metrics(step: int, metrics: Dict[str, Any]) -> None:
    """Write one DEBUG line summarizing the key metrics of ``step``.

    Does nothing unless DEBUG logging is enabled.  Only finite ``int``/``float``
    values are rendered; per-component reward means
    (``train/rewards/<name>/mean``) are appended in sorted key order.

    :param step: Step index included in the log line.
    :type step: int
    :param metrics: Metrics mapping for the step.
    :type metrics: dict[str, Any]
    :returns: ``None``.
    """
    if not LOG.isEnabledFor(logging.DEBUG):
        return

    def _finite_number(candidate: Any) -> bool:
        return isinstance(candidate, (int, float)) and math.isfinite(float(candidate))

    parts: List[str] = []
    for key, label in _DEBUG_METRIC_FIELDS:
        value = metrics.get(key)
        if _finite_number(value):
            parts.append(f"{label}={float(value):.6f}")

    reward_components: List[str] = []
    for key in sorted(metrics):
        if key.startswith("train/rewards/") and key.endswith("/mean"):
            value = metrics.get(key)
            if _finite_number(value):
                # The component name is the path segment right before "mean".
                short_name = key.split("/")[-2]
                reward_components.append(f"{short_name}={float(value):.4f}")
    if reward_components:
        parts.append("rewards[" + ", ".join(reward_components) + "]")

    summary = " ".join(parts) if parts else "no-metrics"
    LOG.debug("debug metrics step %d | %s", step, summary)


def _log_entropy_bonus_impact(
    metrics: Dict[str, Any],
    step: int,
    *,
    tag: str,
) -> None:
    """Log an INFO line describing the entropy bonus, when one is present.

    Nothing is logged unless ``train/entropy_bonus_mean`` and
    ``train/reward_without_entropy_bonus`` are numeric and a numeric reward
    with bonus is available (``train/reward_with_entropy_bonus`` or, failing
    that, ``train/reward``).  The objective loss and bonus std are appended
    when available.

    :param metrics: Metrics mapping for the step.
    :type metrics: dict[str, Any]
    :param step: Step index included in the log line.
    :type step: int
    :param tag: Prefix placed at the start of the log line.
    :type tag: str
    :returns: ``None``.
    """
    numeric = (int, float)

    bonus_mean = metrics.get("train/entropy_bonus_mean")
    reward_no_bonus = metrics.get("train/reward_without_entropy_bonus")
    if not (isinstance(bonus_mean, numeric) and isinstance(reward_no_bonus, numeric)):
        return

    reward_with_bonus = metrics.get("train/reward_with_entropy_bonus")
    if not isinstance(reward_with_bonus, numeric):
        # Fall back to the overall reward, which already includes the bonus.
        reward_with_bonus = metrics.get("train/reward")
        if not isinstance(reward_with_bonus, numeric):
            return

    bonus_std = metrics.get("train/rewards/entropy_bonus/std")
    if isinstance(bonus_std, numeric) and math.isfinite(float(bonus_std)):
        bonus_std_str = f" | bonus_std={float(bonus_std):.6f}"
    else:
        bonus_std_str = ""

    log_args: List[Any] = [
        tag,
        step,
        float(reward_no_bonus),
        float(bonus_mean),
        float(reward_with_bonus),
    ]
    objective_loss = metrics.get("train/objective/minimize", metrics.get("train/loss"))
    if isinstance(objective_loss, numeric):
        template = "%s entropy bonus step %d | reward_no_bonus=%.6f | bonus_mean=%.6f | reward_with_bonus=%.6f | objective_loss=%.6f%s"
        log_args.append(float(objective_loss))
    else:
        template = "%s entropy bonus step %d | reward_no_bonus=%.6f | bonus_mean=%.6f | reward_with_bonus=%.6f%s"
    LOG.info(template, *log_args, bonus_std_str)


# ``wandb`` is optional: when it cannot be imported the module-level name is
# bound to ``None`` so callers can test for availability.
try:  # Optional dependency
    import wandb
except ImportError:  # pragma: no cover - optional logging backend
    wandb = None


class _FallbackWandbError(RuntimeError):
    """Fallback error used when wandb is unavailable."""


# ``WandbError`` resolves to ``wandb.errors.Error`` when that attribute chain
# exists; otherwise (wandb missing, or no ``errors.Error``) it is the local
# fallback class.
WandbError: type[BaseException]
if wandb is not None:
    WandbError = getattr(getattr(wandb, "errors", None), "Error", _FallbackWandbError)
else:
    WandbError = _FallbackWandbError


def _get_wandb() -> Optional[Any]:
    """Return the wandb module when available (facilitates testing).

    :returns: The module-level ``wandb`` binding, which is ``None`` when the
        optional import failed.
    :rtype: Any | None
    """
    return wandb


def _mean_std(values: Sequence[float]) -> Tuple[float, float]:
    """Return the mean and population standard deviation of ``values``.

    :param values: Sequence of numeric samples.
    :type values: Sequence[float]
    :returns: ``(mean, std)``; ``(0.0, 0.0)`` for an empty sequence and a
        standard deviation of ``0.0`` for a single sample.  The variance is
        divided by ``len(values)`` (no Bessel correction).
    :rtype: tuple[float, float]
    """
    if not values:
        return 0.0, 0.0
    count = len(values)
    mean_val = float(sum(values) / count)
    if count <= 1:
        return mean_val, 0.0
    squared_dev = sum((sample - mean_val) ** 2 for sample in values)
    # Guard against a tiny negative variance before taking the square root.
    spread = math.sqrt(max(squared_dev / count, 0.0))
    return mean_val, float(spread)


def _quantile_stats(
    values: Sequence[float],
    quantiles: Sequence[float],
) -> Dict[str, float]:
    """Compute simple linear-interpolated quantiles for logging.

    :param values: Numeric samples; an empty sequence yields ``{}``.
    :type values: Sequence[float]
    :param quantiles: Quantile fractions.  Values ``<= 0`` map to the minimum
        and values ``>= 1`` to the maximum sample.
    :type quantiles: Sequence[float]
    :returns: Mapping from ``"pNN"`` labels (percent, zero-padded to two
        digits) to the quantile value.
    :rtype: dict[str, float]
    """
    if not values:
        return {}

    def _label(fraction: float) -> str:
        # ``0.29 * 100`` is 28.999999999999996 in binary floating point, so a
        # bare ``int()`` would label it "p28".  Snap to the nearest integer
        # percent when within rounding noise; otherwise keep truncation.
        scaled = fraction * 100
        nearest = round(scaled)
        percent = nearest if abs(scaled - nearest) < 1e-9 else int(scaled)
        return f"p{percent:02d}"

    sorted_vals = sorted(float(v) for v in values)
    last_index = len(sorted_vals) - 1
    stats: Dict[str, float] = {}
    for q in quantiles:
        q = float(q)
        label = _label(q)
        if q <= 0.0:
            stats[label] = sorted_vals[0]
            continue
        if q >= 1.0:
            stats[label] = sorted_vals[-1]
            continue
        pos = q * last_index
        lo = int(math.floor(pos))
        hi = int(math.ceil(pos))
        if lo == hi:
            val = sorted_vals[lo]
        else:
            # Linear interpolation between the two neighbouring order statistics.
            frac = pos - lo
            val = sorted_vals[lo] * (1.0 - frac) + sorted_vals[hi] * frac
        stats[label] = float(val)
    return stats


def _gather_list_for_metrics(
    accelerator: Accelerator,
    values: Sequence[float],
    *,
    skip_global: bool = False,
) -> List[float]:
    """Gather a sequence of floats across processes.

    :param accelerator: Accelerate handle used for distributed comms.  Only
        ``num_processes`` and a callable ``gather_object`` are consulted.
    :type accelerator: Accelerator
    :param values: Local float values to gather.
    :type values: Sequence[float]
    :param skip_global: When ``True`` no cross-process gather is attempted and
        only the local values are returned.
    :type skip_global: bool
    :returns: Flattened list containing values from all ranks.  The local
        values alone are returned when gathering is skipped, unavailable, or
        does not produce a list.
    :rtype: list[float]
    """
    local = [float(v) for v in values]
    if skip_global or getattr(accelerator, "num_processes", 1) <= 1:
        return local
    gather_fn = getattr(accelerator, "gather_object", None)
    if not callable(gather_fn):
        return local
    gathered = gather_fn(local)
    if not isinstance(gathered, list):
        return local
    merged: List[float] = []
    for chunk in gathered:
        if isinstance(chunk, (int, float)):
            # Some gather implementations return an already-flattened list of
            # scalars rather than one list per rank; accept both layouts.
            merged.append(float(chunk))
        else:
            merged.extend(float(v) for v in chunk)
    return merged


def _gather_dict_of_lists_for_metrics(
    accelerator: Accelerator,
    values: Mapping[str, Sequence[float]],
    *,
    skip_global: bool = False,
) -> Dict[str, List[float]]:
    """Gather dict-of-list structures across processes.

    :param accelerator: Accelerate handle used for distributed comms.  Only
        ``num_processes`` and a callable ``gather_object`` are consulted.
    :type accelerator: Accelerator
    :param values: Mapping of metric name to local float sequence.
    :type values: Mapping[str, Sequence[float]]
    :param skip_global: When ``True`` only the local values are returned.
    :type skip_global: bool
    :returns: Mapping where each metric key contains the concatenated lists of
        every gathered shard (non-dict shards are ignored).  A float-converted
        copy of the local mapping is returned when gathering is skipped,
        unavailable, or does not produce a list.
    :rtype: dict[str, list[float]]
    """

    def _local_copy() -> Dict[str, List[float]]:
        return {key: [float(v) for v in seq] for key, seq in values.items()}

    if skip_global or getattr(accelerator, "num_processes", 1) <= 1:
        return _local_copy()
    gather_fn = getattr(accelerator, "gather_object", None)
    if not callable(gather_fn):
        return _local_copy()
    gathered = gather_fn(_local_copy())
    if not isinstance(gathered, list):
        return _local_copy()

    merged: Dict[str, List[float]] = {}
    for shard in gathered:
        if isinstance(shard, dict):
            for key, seq in shard.items():
                bucket = merged.setdefault(key, [])
                bucket.extend(float(v) for v in seq)
    return merged


def _sum_scalar_for_metrics(
    accelerator: Accelerator,
    value: float,
    *,
    skip_global: bool = False,
) -> float:
    """Sum a scalar across all processes.

    :param accelerator: Accelerate handle used for reductions.
    :type accelerator: Accelerator
    :param value: Scalar value contributed by the local rank.
    :type value: float
    :param skip_global: Forwarded to ``_gather_list_for_metrics``; when
        ``True`` only the local value is summed.
    :type skip_global: bool
    :returns: Sum of the gathered contributions.
    :rtype: float
    """
    contributions = _gather_list_for_metrics(
        accelerator,
        [value],
        skip_global=skip_global,
    )
    total = sum(contributions)
    return float(total)


def _policy_entropy_from_scores(scores: Any) -> Optional[float]:
    """Return token-weighted policy entropy from a SequenceScores-like object.

    :param scores: Object exposing ``policy_entropy_sum`` and
        ``denom_tok_tensor``.
    :type scores: Any
    :returns: Total entropy divided by total token count, or ``None`` when
        either attribute is missing/``None``, a total cannot be computed, or
        the token total is not positive.
    :rtype: float | None
    """
    entropy_sum = getattr(scores, "policy_entropy_sum", None)
    token_counts = getattr(scores, "denom_tok_tensor", None)
    if entropy_sum is None or token_counts is None:
        return None

    recoverable = (AttributeError, RuntimeError, TypeError, ValueError)

    def _grand_total(container: Any) -> Optional[float]:
        # First try the tensor-style chain, then a plain ``.sum()``.
        try:
            return float(container.detach().float().sum().cpu().item())
        except recoverable:
            pass
        try:
            return float(container.sum())
        except recoverable:
            return None

    entropy_total = _grand_total(entropy_sum)
    if entropy_total is None:
        return None
    token_total = _grand_total(token_counts)
    if token_total is None or token_total <= 0:
        return None
    return entropy_total / token_total


def _base_metric_block(
    payload: TrainingMetricsPayload, global_step: int
) -> Dict[str, Any]:
    """Return loss/optimizer scalars that mirror the TRL trainer.

    :param payload: Metrics payload providing ``scalars``, ``loss_outputs``
        and ``config.weighting``.
    :type payload: TrainingMetricsPayload
    :param global_step: Current optimizer step, logged as a float.
    :type global_step: int
    :returns: Mapping of ``train/*`` keys to values.  Per-completion-token
        ratios are added only when completion tokens were counted, and the
        optional scalars only when they are not ``None``.
    :rtype: dict[str, Any]
    """
    scalars = payload.scalars
    loss_outputs = payload.loss_outputs
    total_loss = loss_outputs.total_loss_scalar
    weighting = payload.config.weighting
    uses_grpo = getattr(weighting, "train_grpo_objective", False)

    metrics: Dict[str, Any] = {
        "train/loss": total_loss,
        "train/loss/total": total_loss,
        "train/objective/minimize": total_loss,
        "train/objective/maximize": -float(total_loss),
        "train/learning_rate": scalars.current_lr,
        "train/epoch": scalars.epoch_progress,
        "train/global_step": float(global_step),
        "train/num_tokens": scalars.num_input_tokens,
        "train/avg_completion_tokens": scalars.avg_completion_tokens,
        "train/ref_logp_mean": scalars.ref_logp_mean,
        "train/beta": weighting.beta,
        "train/tau": weighting.tau,
        "train/kl_coeff": weighting.beta,
        "train/grpo_objective": 1.0 if uses_grpo else 0.0,
    }

    completion_tokens = scalars.num_completion_tokens
    if completion_tokens > 0:
        token_denom = float(completion_tokens)
        kl_scalar = getattr(loss_outputs, "kl_loss_scalar", None)
        if kl_scalar is None:
            # Fall back to the nested ``scalars.kl_loss`` field when present.
            nested = getattr(loss_outputs, "scalars", None)
            kl_scalar = getattr(nested, "kl_loss", None)
        if kl_scalar is not None:
            metrics["train/kl_per_completion_token"] = max(
                0.0, float(kl_scalar) / token_denom
            )
        # Ratios are floored at zero before logging.
        metrics["train/loss_per_completion_token"] = max(
            0.0, float(total_loss) / token_denom
        )

    optional_fields = (
        ("train/grad_norm", "grad_norm_scalar"),
        ("train/vllm_latency_ms", "vllm_latency_ms"),
        ("train/policy_entropy", "policy_entropy"),
        ("train/entropy_bonus_coef", "entropy_bonus_coef"),
        ("train/entropy_bonus_reward_std", "entropy_bonus_reward_std"),
    )
    for metric_key, attr_name in optional_fields:
        optional_value = getattr(scalars, attr_name)
        if optional_value is not None:
            metrics[metric_key] = optional_value
    return metrics


def _loss_component_block(loss_outputs: "LossOutputs") -> Dict[str, float]:
    """Break the loss down into its individual components.

    :param loss_outputs: Object exposing ``policy_loss_scalar``,
        ``kl_loss_scalar``, ``weighted_kl_loss_scalar`` and
        ``clip_loss_scalar``.
    :type loss_outputs: LossOutputs
    :returns: ``train/loss/*`` entries; the clip entry is present only when
        ``clip_loss_scalar`` is not ``None``.
    :rtype: dict[str, float]
    """
    components: Dict[str, float] = {}
    components["train/loss/policy"] = loss_outputs.policy_loss_scalar
    components["train/loss/kl"] = loss_outputs.kl_loss_scalar
    components["train/loss/weighted_kl"] = loss_outputs.weighted_kl_loss_scalar
    clip_value = loss_outputs.clip_loss_scalar
    if clip_value is None:
        return components
    components["train/loss/clip"] = clip_value
    return components


def _length_metric_block(length_stats: LengthStats) -> Dict[str, float]:
    """Build the completion-length metrics.

    :param length_stats: Object exposing min/mean/max lengths for sampled and
        terminated completions plus ``clipped_ratio``.
    :type length_stats: LengthStats
    :returns: ``train/completions/*`` entries; the clipped fraction is clamped
        to ``[0, 1]``.
    :rtype: dict[str, float]
    """
    # Clamp to the valid [0, 1] range to avoid noisy out-of-range values.
    capped = min(1.0, float(length_stats.clipped_ratio))
    clipped_frac = max(0.0, capped)

    metrics: Dict[str, float] = {}
    metrics["train/completions/mean_length_sampled"] = length_stats.mean_length
    metrics["train/completions/min_length_sampled"] = length_stats.min_length
    metrics["train/completions/max_length_sampled"] = length_stats.max_length
    metrics["train/completions/clipped_frac"] = clipped_frac
    metrics["train/completions/mean_length_terminated"] = length_stats.mean_terminated
    metrics["train/completions/min_length_terminated"] = length_stats.min_terminated
    metrics["train/completions/max_length_terminated"] = length_stats.max_terminated
    return metrics


def _entropy_bonus_impact(
    reward_stats: RewardLoggingView,
) -> Optional[Tuple[float, float, float, float, float]]:
    """Summarize how much the entropy bonus contributes to the reward.

    :param reward_stats: Object exposing ``per_reward`` (mapping with an
        optional ``"entropy_bonus"`` entry carrying ``mean``) and
        ``reward_mean``.
    :type reward_stats: RewardLoggingView
    :returns: ``None`` when no entropy bonus entry exists; otherwise
        ``(reward_without_bonus, reward_with_bonus, bonus_mean,
        bonus / |reward_without_bonus|, bonus / |reward_with_bonus|)`` with
        both denominators floored at ``1e-8``.
    :rtype: tuple[float, float, float, float, float] | None
    """
    bonus_stats = reward_stats.per_reward.get("entropy_bonus")
    if bonus_stats is None:
        return None

    bonus = float(bonus_stats.mean)
    with_bonus = float(reward_stats.reward_mean)
    without_bonus = with_bonus - bonus
    # Floor the denominators so a zero reward cannot divide by zero.
    frac_of_base = bonus / max(abs(without_bonus), 1e-8)
    frac_of_total = bonus / max(abs(with_bonus), 1e-8)
    return without_bonus, with_bonus, bonus, frac_of_base, frac_of_total


def _reward_metric_block(payload: TrainingMetricsPayload) -> Dict[str, float]:
    """Build reward, q-entropy and seed-GRPO metrics from ``payload.reward_stats``.

    :param payload: Metrics payload whose ``reward_stats`` supplies the
        aggregate reward statistics, quantiles and per-reward breakdown.
    :type payload: TrainingMetricsPayload
    :returns: ``train/*`` entries for the aggregate reward, its quantiles,
        each reward component (mean/std/quantiles) and, when an entropy bonus
        component exists, the bonus impact summary.
    :rtype: dict[str, float]
    """
    reward_stats = payload.reward_stats
    scalar_fields = (
        ("train/reward", "reward_mean"),
        ("train/reward_std", "reward_std"),
        ("train/frac_reward_zero_std", "frac_zero_std"),
        ("train/q_entropy_mean", "q_entropy_mean"),
        ("train/q_entropy_std", "q_entropy_std"),
        ("train/q_entropy_min", "q_entropy_min"),
        ("train/q_entropy_max", "q_entropy_max"),
        ("train/seed_grpo/semantic_entropy_mean", "semantic_entropy_mean"),
        ("train/seed_grpo/semantic_entropy_std", "semantic_entropy_std"),
        ("train/seed_grpo/semantic_entropy_min", "semantic_entropy_min"),
        ("train/seed_grpo/semantic_entropy_max", "semantic_entropy_max"),
        ("train/seed_grpo/advantage_scale_mean", "advantage_scale_mean"),
        ("train/seed_grpo/advantage_scale_min", "advantage_scale_min"),
        ("train/seed_grpo/advantage_scale_max", "advantage_scale_max"),
        ("train/seed_grpo/alpha_effective", "seed_alpha_effective"),
        ("train/seed_grpo/max_possible_entropy", "seed_max_possible_entropy"),
    )
    metrics: Dict[str, float] = {
        metric_key: getattr(reward_stats, attr_name)
        for metric_key, attr_name in scalar_fields
    }

    for quantile_key, value in reward_stats.reward_quantiles.items():
        metrics[f"train/reward_{quantile_key}"] = value

    per_reward_quantiles = reward_stats.per_reward_quantiles
    for reward_key, stats in reward_stats.per_reward.items():
        metrics[f"train/rewards/{reward_key}/mean"] = stats.mean
        metrics[f"train/rewards/{reward_key}/std"] = stats.std
        component_quantiles = per_reward_quantiles.get(reward_key, {})
        for quantile_key, value in component_quantiles.items():
            metrics[f"train/rewards/{reward_key}/{quantile_key}"] = value

    bonus_summary = _entropy_bonus_impact(reward_stats)
    if bonus_summary is not None:
        # Keys are listed in the same order as the summary tuple.
        bonus_keys = (
            "train/reward_without_entropy_bonus",
            "train/reward_with_entropy_bonus",
            "train/entropy_bonus_mean",
            "train/entropy_bonus_frac_of_base",
            "train/entropy_bonus_frac_of_total",
        )
        for metric_key, value in zip(bonus_keys, bonus_summary):
            metrics[metric_key] = value
    return metrics


def _clip_metric_block(diagnostics: "BatchDiagnostics") -> Dict[str, float]:
    """Return PPO-style clipping diagnostics plus KL statistics.

    :param diagnostics: Object exposing the ``clip_ratio*`` fields,
        ``kl_value`` and, optionally, ``kl_per_token_by_len_bucket`` /
        ``kl_token_count_by_len_bucket`` mappings.
    :type diagnostics: BatchDiagnostics
    :returns: ``train/clip_ratio*`` entries, ``train/kl`` when a KL value is
        available, and per-bucket KL means and token counts in sorted bucket
        order (missing counts default to ``0.0``).
    :rtype: dict[str, float]
    """
    metrics: Dict[str, float] = {"train/clip_ratio": diagnostics.clip_ratio}
    metrics["train/clip_ratio/low_mean"] = diagnostics.clip_ratio_low_mean
    metrics["train/clip_ratio/low_min"] = diagnostics.clip_ratio_low_min
    metrics["train/clip_ratio/high_mean"] = diagnostics.clip_ratio_high_mean
    metrics["train/clip_ratio/high_max"] = diagnostics.clip_ratio_high_max
    metrics["train/clip_ratio/region_mean"] = diagnostics.clip_ratio_region_mean

    kl_value = diagnostics.kl_value
    if kl_value is not None:
        metrics["train/kl"] = kl_value

    # Absent or ``None`` bucket mappings are treated as empty.
    bucket_means = getattr(diagnostics, "kl_per_token_by_len_bucket", {}) or {}
    bucket_token_counts = getattr(diagnostics, "kl_token_count_by_len_bucket", {}) or {}
    for bucket in sorted(bucket_means):
        token_count = bucket_token_counts.get(bucket, 0.0)
        metrics[f"train/kl_per_token_bucket/{bucket}"] = bucket_means[bucket]
        metrics[f"train/kl_per_token_bucket_tokens/{bucket}"] = token_count
    return metrics


def _weight_metric_block(payload: TrainingMetricsPayload) -> Dict[str, float]:
    """Entropy diagnostics for the MaxEnt weighting distribution.

    :param payload: Metrics payload providing ``weight_stats`` and
        ``config.weighting``.
    :type payload: TrainingMetricsPayload
    :returns: Weight/advantage entropy entries, plus
        ``train/weight_entropy_ema`` when the weighting config carries a
        numeric ``_tau_entropy_ema`` attribute.
    :rtype: dict[str, float]
    """
    weight_stats = payload.weight_stats
    field_map = (
        ("train/weight_entropy", "entropy"),
        ("train/weight_entropy_norm", "entropy_norm"),
        ("train/weight_entropy_min", "entropy_min"),
        ("train/weight_entropy_max", "entropy_max"),
        ("train/advantage_entropy_mean", "advantage_entropy_mean"),
        ("train/advantage_entropy_std", "advantage_entropy_std"),
    )
    metrics = {
        metric_key: getattr(weight_stats, attr_name)
        for metric_key, attr_name in field_map
    }
    entropy_ema = getattr(payload.config.weighting, "_tau_entropy_ema", None)
    if isinstance(entropy_ema, (int, float)):
        metrics["train/weight_entropy_ema"] = float(entropy_ema)
    return metrics


def _weighting_config_block(
    payload: TrainingMetricsPayload, global_step: int
) -> Dict[str, float]:
    """Log controller hyperparameters for both GRPO and MaxEnt-GRPO."""
    weighting = payload.config.weighting
    prev_tau = getattr(weighting, "_prev_tau", None)
    prev_beta = getattr(weighting, "_prev_beta", None)
    delta_tau = float(weighting.tau) - float(prev_tau) if prev_tau is not None else 0.0
    delta_beta = (
        float(weighting.beta) - float(prev_beta) if prev_beta is not None else 0.0
    )
    meta_cfg = getattr(weighting, "controller_meta", None)
    meta_enabled = bool(getattr(meta_cfg, "enabled", False))
    tau_lr_effective = getattr(weighting, "_tau_lr_effective", weighting.tau_lr)
    metrics: Dict[str, float] = {
        "train/weight_norm_denom": weighting.denom,
        "train/weighting/tau": float(weighting.tau),
        "train/weighting/beta": float(weighting.beta),
        "train/tau_log": float(
            getattr(weighting, "_tau_log", math.log(max(weighting.tau, 1e-8)))
        ),
        "train/q_temperature": weighting.q_temperature,
        "train/q_epsilon": weighting.q_epsilon,
        "train/tau_lr": float(tau_lr_effective),
        "train/tau_min": weighting.tau_min,
        "train/tau_max": weighting.tau_max,
        "train/tau_warmup_steps": float(weighting.tau_warmup_steps),
        "train/tau_target_entropy": float(
            weighting.tau_target_entropy
            if weighting.tau_target_entropy is not None
            else 0.0
        ),
        "train/tau_target_enabled": (
            1.0 if weighting.tau_target_entropy is not None else 0.0
        ),
        "train/tau_schedule_active": (
            1.0
            if (
                (not meta_enabled)
                and weighting.tau_target_entropy is not None
                and global_step > max(0, weighting.tau_warmup_steps)
            )
            else 0.0
        ),
        "train/kl_controller_target": weighting.kl_target,
        "train/kl_controller_horizon": float(weighting.kl_horizon),
        "train/kl_controller_step_size": weighting.kl_ctl_step_size,
        "train/kl_controller_enabled": (
            1.0
            if (not meta_enabled)
            and weighting.kl_target > 0.0
            and weighting.kl_horizon > 0
            and weighting.kl_ctl_step_size > 0.0
            else 0.0
        ),
        "train/len_norm_ref": 1.0 if weighting.len_norm_ref else 0.0,
        "train/maxent_objective": 0.0 if weighting.train_grpo_objective else 1.0,
        "train/delta_tau": delta_tau,
        "train/delta_tau_abs": abs(delta_tau),
        "train/delta_beta": delta_beta,
        "train/delta_beta_abs": abs(delta_beta),
    }
    metrics["train/weighting/weight_norm_denom"] = metrics["train/weight_norm_denom"]
    metrics["train/weighting/tau_log"] = metrics["train/tau_log"]
    metrics["train/weighting/q_temperature"] = metrics["train/q_temperature"]
    metrics["train/weighting/q_epsilon"] = metrics["train/q_epsilon"]
    metrics["train/weighting/tau_lr"] = metrics["train/tau_lr"]
    metrics["train/weighting/tau_min"] = metrics["train/tau_min"]
    metrics["train/weighting/tau_max"] = metrics["train/tau_max"]
    metrics["train/weighting/tau_warmup_steps"] = metrics["train/tau_warmup_steps"]
    metrics["train/weighting/tau_target_entropy"] = metrics["train/tau_target_entropy"]
    metrics["train/weighting/tau_schedule_active"] = metrics[
        "train/tau_schedule_active"
    ]
    metrics["train/weighting/delta_tau"] = metrics["train/delta_tau"]
    metrics["train/weighting/delta_tau_abs"] = metrics["train/delta_tau_abs"]
    metrics["train/weighting/delta_beta"] = metrics["train/delta_beta"]
    metrics["train/weighting/delta_beta_abs"] = metrics["train/delta_beta_abs"]
    # Error-to-target signals for KL and weight entropy controllers.
    kl_measured = payload.diagnostics.kl_value
    if kl_measured is None:
        kl_measured = getattr(payload.loss_outputs, "kl_loss_scalar", None)
    if (
        isinstance(kl_measured, (int, float))
        and weighting.kl_target
        and weighting.kl_target > 0.0
    ):
        metrics["train/kl_error_to_target"] = float(kl_measured) - weighting.kl_target
        metrics["train/kl_ratio_to_target"] = float(kl_measured) / max(
            weighting.kl_target, 1e-8
        )
    target_entropy = weighting.tau_target_entropy
    if target_entropy is not None:
        entropy_error = payload.weight_stats.entropy - float(target_entropy)
        metrics["train/weight_entropy_error"] = entropy_error
        metrics["train/weight_entropy_abs_error"] = abs(entropy_error)
        # Treat the squared error as a simple controller "loss" so it shows up alongside
        # the main model loss in dashboards (e.g., W&B).
        metrics["train/tau_loss"] = 0.5 * entropy_error * entropy_error
    meta_cfg = getattr(weighting, "controller_meta", None)
    meta_enabled = bool(getattr(meta_cfg, "enabled", False))
    metrics["train/meta/enabled"] = 1.0 if meta_enabled else 0.0
    metrics["train/meta/lr"] = (
        float(getattr(meta_cfg, "learning_rate", 0.0)) if meta_cfg else 0.0
    )
    metrics["train/meta/update_interval"] = float(
        getattr(meta_cfg, "update_interval", 0.0) if meta_cfg else 0.0
    )
    metrics["train/meta/truncation_steps"] = float(
        getattr(meta_cfg, "truncation_steps", getattr(meta_cfg, "analytic_steps", 0))
        if meta_cfg
        else 0.0
    )
    metrics["train/meta/use_hessian"] = (
        1.0 if meta_cfg and getattr(meta_cfg, "use_hessian", False) else 0.0
    )
    tau_grad = float(getattr(weighting, "_meta_last_tau_grad", 0.0))
    beta_grad = float(getattr(weighting, "_meta_last_beta_grad", 0.0))
    metrics["train/meta/tau_grad"] = tau_grad
    metrics["train/meta/beta_grad"] = beta_grad
    metrics["train/meta/grad_norm"] = math.sqrt(
        tau_grad * tau_grad + beta_grad * beta_grad
    )
    metrics["train/meta/loss"] = float(getattr(weighting, "_meta_last_loss", 0.0))
    metrics["train/meta/tau_projected"] = (
        1.0 if getattr(weighting, "_meta_tau_projected", False) else 0.0
    )
    metrics["train/meta/beta_projected"] = (
        1.0 if getattr(weighting, "_meta_beta_projected", False) else 0.0
    )
    metrics.setdefault("train/weighting/tau_loss", metrics.get("train/tau_loss", 0.0))
    metrics["train/kl_controller/target"] = metrics["train/kl_controller_target"]
    metrics["train/kl_controller/horizon"] = metrics["train/kl_controller_horizon"]
    metrics["train/kl_controller/step_size"] = metrics["train/kl_controller_step_size"]
    metrics["train/kl_controller/enabled"] = metrics["train/kl_controller_enabled"]
    return metrics


def build_training_metrics_dict(
    payload: TrainingMetricsPayload,
    global_step: int,
) -> Dict[str, Any]:
    """Flatten a structured metrics payload into one logging dictionary.

    Run metadata is written first, followed by the base, loss-component,
    length, reward, weight, weighting-config and clip blocks in that order, so
    a key produced by a later block overwrites the same key from an earlier
    one.  ``train/kl`` is back-filled from ``loss_outputs.kl_loss_scalar`` when
    no block supplied it, and diversity metrics (if any) are added under the
    ``train/completions/diversity/`` prefix.

    :param payload: Structured metrics payload produced by the training loop.
    :type payload: TrainingMetricsPayload
    :param global_step: Current optimizer step used for logging context.
    :type global_step: int
    :returns: Flat mapping of scalar metrics keyed by name.
    :rtype: dict[str, Any]
    """
    flat: Dict[str, Any] = {}
    flat.update(resolve_run_metadata())
    # Order matters: later blocks win on key collisions.
    metric_blocks = (
        _base_metric_block(payload, global_step),
        _loss_component_block(payload.loss_outputs),
        _length_metric_block(payload.length_stats),
        _reward_metric_block(payload),
        _weight_metric_block(payload),
        _weighting_config_block(payload, global_step),
        _clip_metric_block(payload.diagnostics),
    )
    for block in metric_blocks:
        flat.update(block)

    if "train/kl" not in flat:
        fallback_kl = getattr(payload.loss_outputs, "kl_loss_scalar", None)
        if isinstance(fallback_kl, (int, float)):
            flat["train/kl"] = float(fallback_kl)

    if payload.diversity_metrics:
        for name, value in payload.diversity_metrics.items():
            flat[f"train/completions/diversity/{name}"] = value
    return flat
def log_training_metrics(
    logging_cfg: LoggingHandles,
    global_step: int,
    payload: TrainingMetricsPayload,
) -> Dict[str, Any]:
    """Build the flat metrics dictionary, hand it to the loggers, and return it.

    :param logging_cfg: Logging handles (W&B, tensorboard, stdout, etc.).
    :type logging_cfg: LoggingHandles
    :param global_step: Current optimizer step.
    :type global_step: int
    :param payload: Structured metrics payload to log.
    :type payload: TrainingMetricsPayload
    :returns: Flattened metrics dictionary emitted to loggers.
    :rtype: dict[str, Any]
    """
    flat_metrics = build_training_metrics_dict(payload, global_step)
    logging_cfg.log_metrics(flat_metrics, global_step)
    # The metric writer is optional; flush only when it exposes a callable
    # ``flush`` attribute.
    flush_fn = getattr(getattr(logging_cfg, "metric_writer", None), "flush", None)
    if callable(flush_fn):
        flush_fn()
    return flat_metrics
def _reward_component_stats(
    per_reward_values: Mapping[str, Sequence[float]],
) -> Dict[str, RewardComponentStats]:
    """Summarize each reward component's samples as a mean/std pair.

    :param per_reward_values: Mapping of reward key to local samples.
    :type per_reward_values: Mapping[str, Sequence[float]]
    :returns: Mapping of reward key to its ``RewardComponentStats`` entry.
    :rtype: dict[str, RewardComponentStats]
    """
    summary: Dict[str, RewardComponentStats] = {}
    for reward_key, samples in per_reward_values.items():
        component_mean, component_std = _mean_std(list(map(float, samples)))
        summary[reward_key] = RewardComponentStats(
            mean=component_mean, std=component_std
        )
    return summary


def _fraction_zero_std_groups(
    accelerator: Accelerator,
    advantage_groups: Sequence[Sequence[float]],
    *,
    skip_global: bool = False,
) -> float:
    """Return the fraction of advantage groups that are (near) zero variance.

    A group counts as zero-variance when every advantage in it has an absolute
    value below ``1e-8``.  Empty groups are ignored entirely.

    :param accelerator: Accelerate handle used for reductions.
    :type accelerator: Accelerator
    :param advantage_groups: Advantage samples grouped per prompt.
    :type advantage_groups: Sequence[Sequence[float]]
    :param skip_global: Forwarded to ``_sum_scalar_for_metrics``.
    :type skip_global: bool
    :returns: Fraction of groups whose advantages are (near) zero variance, or
        ``0.0`` when there are no non-empty groups.
    :rtype: float
    """
    populated = [group for group in advantage_groups if group]
    degenerate = [
        group for group in populated if all(abs(adv) < 1e-8 for adv in group)
    ]
    # Keep the reduction order (zero-count first, then total) stable.
    zero_std_total = _sum_scalar_for_metrics(
        accelerator, float(len(degenerate)), skip_global=skip_global
    )
    group_total = _sum_scalar_for_metrics(
        accelerator, float(len(populated)), skip_global=skip_global
    )
    if group_total > 0:
        return zero_std_total / group_total
    return 0.0


def _summarize_reward_stats(
    accelerator: Accelerator,
    reward_comp: RewardComputation,
    *,
    skip_global: bool = False,
) -> RewardLoggingView:
    """Aggregate reward/advantage stats into a lightweight view.

    :param accelerator: Accelerate handle used for reductions.
    :type accelerator: Accelerator
    :param reward_comp: Reward computation outputs from the batch.
    :type reward_comp: RewardComputation
    :param skip_global: Forwarded to every gather/reduction helper.
    :type skip_global: bool
    :returns: Lightweight logging view containing aggregated stats.
    :rtype: RewardLoggingView
    """

    def _gather(values: Any) -> Any:
        """Gather a flat list with this call's accelerator/skip settings."""
        return _gather_list_for_metrics(accelerator, values, skip_global=skip_global)

    def _optional_scalar(attr_name: str) -> List[float]:
        """Return ``[value]`` for a present, non-``None`` attribute, else ``[]``."""
        raw = getattr(reward_comp, attr_name, None)
        return [] if raw is None else [float(raw or 0.0)]

    def _entropy_of(probs: Any) -> float:
        """Accumulate ``-p * log(p)`` with ``p`` clamped to at least ``1e-12``."""
        acc = 0.0
        for prob in probs:
            clamped = max(float(prob), 1e-12)
            acc -= clamped * math.log(clamped)
        return acc

    quantile_levels = (0.0, 0.05, 0.25, 0.5, 0.75, 0.95, 1.0)

    all_rewards = _gather(reward_comp.total_utils)
    reward_mean, reward_std = _mean_std(all_rewards)
    adv_samples = _gather(reward_comp.advantage_samples)
    adv_mean, adv_std = _mean_std(adv_samples)
    per_reward_values = _gather_dict_of_lists_for_metrics(
        accelerator, reward_comp.per_reward_values, skip_global=skip_global
    )

    # Q-distribution entropy captures how sharp the ranking is per prompt;
    # empty groups contribute nothing.
    q_entropies = _gather(
        [_entropy_of(q_vals) for q_vals in reward_comp.q_grouped if q_vals]
    )
    q_entropy_mean, q_entropy_std = _mean_std(q_entropies)
    q_entropy_min = min(q_entropies) if q_entropies else 0.0
    q_entropy_max = max(q_entropies) if q_entropies else 0.0

    seed_entropies = _gather(
        list(getattr(reward_comp, "seed_semantic_entropies", []) or [])
    )
    semantic_entropy_mean, semantic_entropy_std = _mean_std(seed_entropies)
    semantic_entropy_min = min(seed_entropies) if seed_entropies else 0.0
    semantic_entropy_max = max(seed_entropies) if seed_entropies else 0.0

    seed_scales = _gather(
        list(getattr(reward_comp, "seed_advantage_scales", []) or [])
    )
    advantage_scale_mean, _ = _mean_std(seed_scales)
    # Min/max of the advantage scale fall back to 1.0 when nothing was recorded.
    advantage_scale_min = min(seed_scales) if seed_scales else 1.0
    advantage_scale_max = max(seed_scales) if seed_scales else 1.0

    seed_alpha_effective, _ = _mean_std(
        _gather(_optional_scalar("seed_alpha_effective"))
    )
    seed_max_possible_entropy, _ = _mean_std(
        _gather(_optional_scalar("seed_max_possible_entropy"))
    )

    reward_quantiles = _quantile_stats(all_rewards, quantile_levels)
    per_reward_quantiles: Dict[str, Dict[str, float]] = {
        reward_key: _quantile_stats(values, quantile_levels)
        for reward_key, values in per_reward_values.items()
        if values
    }

    frac_zero_std = _fraction_zero_std_groups(
        accelerator, reward_comp.advantage.grouped, skip_global=skip_global
    )
    per_reward = _reward_component_stats(per_reward_values)
    return RewardLoggingView(
        reward_mean=reward_mean,
        reward_std=reward_std,
        frac_zero_std=frac_zero_std,
        advantage_mean=adv_mean,
        advantage_std=adv_std,
        advantage_count=len(adv_samples),
        per_reward=per_reward,
        q_entropy_mean=q_entropy_mean,
        q_entropy_std=q_entropy_std,
        q_entropy_min=q_entropy_min,
        q_entropy_max=q_entropy_max,
        semantic_entropy_mean=semantic_entropy_mean,
        semantic_entropy_std=semantic_entropy_std,
        semantic_entropy_min=semantic_entropy_min,
        semantic_entropy_max=semantic_entropy_max,
        advantage_scale_mean=advantage_scale_mean,
        advantage_scale_min=advantage_scale_min,
        advantage_scale_max=advantage_scale_max,
        seed_alpha_effective=seed_alpha_effective,
        seed_max_possible_entropy=seed_max_possible_entropy,
        reward_quantiles=reward_quantiles,
        per_reward_quantiles=per_reward_quantiles,
    )
def summarize_reward_stats(
    accelerator: Accelerator,
    reward_comp: Optional[RewardComputation],
    *,
    log_like_grpo: bool = False,
) -> RewardLoggingView:
    """Aggregate reward statistics across all ranks.

    Public entry point for the internal helper so that training code can gather
    reward diagnostics even on non-main ranks before metrics are logged.  When
    no reward computation is available an all-zero/empty view is returned.

    :param accelerator: Accelerate handle used for reductions.
    :type accelerator: Accelerator
    :param reward_comp: Reward computation outputs for the current batch.
    :type reward_comp: RewardComputation | None
    :param log_like_grpo: When ``True``, skip global reductions and keep local
        statistics for GRPO-style logging.
    :type log_like_grpo: bool
    :returns: Aggregated reward statistics for logging.
    :rtype: RewardLoggingView
    """
    if reward_comp is not None:
        return _summarize_reward_stats(
            accelerator, reward_comp, skip_global=log_like_grpo
        )

    # No reward data for this batch: every scalar statistic is zero and every
    # container is empty.
    zeroed_scalars = dict.fromkeys(
        (
            "reward_mean",
            "reward_std",
            "frac_zero_std",
            "advantage_mean",
            "advantage_std",
            "q_entropy_mean",
            "q_entropy_std",
            "q_entropy_min",
            "q_entropy_max",
        ),
        0.0,
    )
    return RewardLoggingView(
        advantage_count=0,
        per_reward={},
        reward_quantiles={},
        per_reward_quantiles={},
        **zeroed_scalars,
    )
def _summarize_weight_stats(
    accelerator: Accelerator,
    weight_stats: WeightStats,
    *,
    skip_global: bool = False,
) -> WeightLoggingView:
    """Summarize weight-entropy statistics for logging.

    :param accelerator: Accelerate handle used for distributed reductions.
    :type accelerator: accelerate.Accelerator
    :param weight_stats: Per-batch weight diagnostics.
    :type weight_stats: training.types.WeightStats
    :param skip_global: Forwarded to every gather/reduction helper.
    :type skip_global: bool
    :returns: Aggregated entropy metrics per batch.
    :rtype: WeightLoggingView
    """

    def _normalized_entropy(group: Any) -> float:
        """Entropy of a weight group divided by ``log(len(group))``.

        Returns ``0.0`` for degenerate groups: a single element, no numeric
        weights, or a non-positive total mass.
        """
        log_size = math.log(max(len(group), 1))
        if log_size <= 0.0:
            return 0.0
        # Non-numeric entries are dropped; the rest are floored at 1e-12.
        masses = [
            max(float(weight), 1e-12)
            for weight in group
            if isinstance(weight, (int, float))
        ]
        if not masses:
            return 0.0
        mass_total = sum(masses)
        if mass_total <= 0.0:
            return 0.0
        shares = [mass / mass_total for mass in masses]
        raw_entropy = -sum(share * math.log(share) for share in shares)
        return float(raw_entropy / log_size)

    groups = getattr(weight_stats, "weights_grouped", []) or []
    prompt_count = len(groups)
    entropy_val = float(getattr(weight_stats, "weight_entropy", 0.0))
    local_norm_total = float(
        sum(_normalized_entropy(group) for group in groups if group)
    )

    # Reduction order is kept fixed: normalized sum, weighted entropy sum,
    # prompt count, then the three list gathers.
    entropy_norm_sum = _sum_scalar_for_metrics(
        accelerator, local_norm_total, skip_global=skip_global
    )
    entropy_sum = _sum_scalar_for_metrics(
        accelerator,
        float(entropy_val * max(prompt_count, 0)),
        skip_global=skip_global,
    )
    prompt_total = _sum_scalar_for_metrics(
        accelerator, float(prompt_count), skip_global=skip_global
    )

    if prompt_total > 0:
        entropy_mean = entropy_sum / prompt_total
        entropy_norm_mean = entropy_norm_sum / prompt_total
    else:
        # Without any prompts fall back to the local entropy value.
        entropy_mean = entropy_val
        entropy_norm_mean = 0.0

    entropy_min_vals = _gather_list_for_metrics(
        accelerator,
        [getattr(weight_stats, "weight_entropy_min", 0.0)],
        skip_global=skip_global,
    )
    entropy_max_vals = _gather_list_for_metrics(
        accelerator,
        [getattr(weight_stats, "weight_entropy_max", 0.0)],
        skip_global=skip_global,
    )
    ent_adv_values = _gather_list_for_metrics(
        accelerator,
        getattr(weight_stats, "advantage_entropy", []),
        skip_global=skip_global,
    )
    ent_adv_mean, ent_adv_std = _mean_std(ent_adv_values)

    return WeightLoggingView(
        entropy=entropy_mean,
        entropy_norm=entropy_norm_mean,
        entropy_min=min(entropy_min_vals) if entropy_min_vals else 0.0,
        entropy_max=max(entropy_max_vals) if entropy_max_vals else 0.0,
        advantage_entropy_mean=ent_adv_mean,
        advantage_entropy_std=ent_adv_std,
    )
def summarize_weight_stats(
    accelerator: Accelerator,
    weight_stats: WeightStats,
    *,
    log_like_grpo: bool = False,
) -> WeightLoggingView:
    """Aggregate per-batch weight statistics across all processes.

    Public entry point for the internal summarization helper so controller
    logic can rely on the same entropy measurement that is used for logging.

    :param accelerator: Accelerate handle used for reductions.
    :type accelerator: Accelerator
    :param weight_stats: Weight statistics for the current batch.
    :type weight_stats: WeightStats
    :param log_like_grpo: When ``True``, skip global reductions and keep local
        statistics for GRPO-style logging.
    :type log_like_grpo: bool
    :returns: Aggregated weight statistics for logging.
    :rtype: WeightLoggingView
    """
    weight_view = _summarize_weight_stats(
        accelerator,
        weight_stats,
        skip_global=log_like_grpo,
    )
    return weight_view
def _build_metrics_payload(
    ctx: TrainingLoopContext,
    state: MetricState,
    prepared: PreparedBatch,
    log_artifacts: LogStepArtifacts,
    current_lr: float,
    *,
    reward_view: Optional[RewardLoggingView] = None,
    weight_view: Optional[WeightLoggingView] = None,
) -> TrainingMetricsPayload:
    """Return a structured payload describing the current step.

    :param ctx: Full training loop context.
    :type ctx: training.types.TrainingLoopContext
    :param state: Metric accumulator providing token counts and step numbers.
    :type state: MetricState
    :param prepared: Batch artifacts containing reward/weight stats.
    :type prepared: PreparedBatch
    :param log_artifacts: Loss/diagnostic bundle produced by
        :func:`log_local_step`.
    :type log_artifacts: LogStepArtifacts
    :param current_lr: Learning rate applied for the step (logged for
        reference).
    :type current_lr: float
    :param reward_view: Optional pre-aggregated reward statistics.  When
        ``None`` the helper summarizes them itself.  Ignored when GRPO-style
        logging is enabled, in which case local statistics are always
        recomputed.
    :type reward_view: RewardLoggingView | None
    :param weight_view: Optional pre-aggregated weight statistics; same
        handling as ``reward_view``.
    :type weight_view: WeightLoggingView | None
    :returns: Aggregated metrics suitable for logging.
    :rtype: training.types.TrainingMetricsPayload
    """
    accelerator = ctx.runtime.accelerator
    training_args = getattr(ctx, "training_args", None)
    log_like_grpo = _log_like_grpo_enabled(training_args)
    config_view = LoggingConfigView(
        weighting=ctx.scoring.weighting,
        clipping=ctx.scoring.clipping,
        schedule=ctx.optimization.schedule,
    )
    policy_entropy = _policy_entropy_from_scores(getattr(prepared, "scores", None))

    # Both entropy-bonus scalars are best-effort: anything that cannot be
    # converted to float is reported as ``None``.
    entropy_bonus_coef = None
    entropy_bonus_reward_std = None
    scoring_cfg = getattr(ctx, "scoring", None)
    if scoring_cfg is not None:
        try:
            entropy_bonus_coef = float(
                getattr(scoring_cfg, "policy_entropy_bonus_coef", 0.0)
            )
        except (TypeError, ValueError):
            entropy_bonus_coef = None
    try:
        # ``float(None)`` raises TypeError, which maps a missing attribute to
        # ``None`` below.
        entropy_bonus_reward_std = float(
            getattr(prepared.reward_comp, "entropy_bonus_scale", None)
        )
    except (TypeError, ValueError):
        entropy_bonus_reward_std = None

    scalar_stats = TrainingScalarStats(
        ref_logp_mean=prepared.ref_stats.ref_logp_mean,
        tokens=TokenUsageStats(
            avg_completion_tokens=prepared.ref_stats.avg_completion_tokens,
            num_completion_tokens=prepared.num_completion_tokens,
            num_input_tokens=state.num_input_tokens_seen,
        ),
        current_lr=current_lr,
        grad_norm_scalar=log_artifacts.grad_norm_scalar,
        epoch_progress=log_artifacts.epoch_progress,
        vllm_latency_ms=(
            float(ctx.generation.generation_stats.get("vllm_last_latency_ms", 0.0))
            if ctx.generation.use_vllm
            else None
        ),
        policy_entropy=policy_entropy,
        entropy_bonus_coef=entropy_bonus_coef,
        entropy_bonus_reward_std=entropy_bonus_reward_std,
    )

    if log_like_grpo:
        # GRPO-style logging always uses freshly computed local statistics.
        reward_stats_payload = _summarize_reward_stats(
            accelerator, prepared.reward_comp, skip_global=True
        )
        weight_stats_payload = _summarize_weight_stats(
            accelerator, prepared.weight_stats, skip_global=True
        )
    else:
        reward_stats_payload = (
            reward_view
            if reward_view is not None
            else _summarize_reward_stats(accelerator, prepared.reward_comp)
        )
        weight_stats_payload = (
            weight_view
            if weight_view is not None
            else _summarize_weight_stats(accelerator, prepared.weight_stats)
        )

    return TrainingMetricsPayload(
        reward_stats=reward_stats_payload,
        weight_stats=weight_stats_payload,
        loss_outputs=log_artifacts.loss_outputs,
        diagnostics=log_artifacts.diagnostics,
        length_stats=prepared.length_stats,
        config=config_view,
        scalars=scalar_stats,
        diversity_metrics=prepared.diversity_metrics,
    )


def _epoch_from_global_step(schedule: OptimizationSchedule, global_step: int) -> float:
    """Return the current epoch progress given the training schedule.

    The first usable source wins: ``steps_per_epoch``, then
    ``num_generations``, then ``num_epochs``/``total_training_steps``.  Fields
    that are missing, ``None`` or non-positive are ignored; when nothing is
    usable the raw step is returned as a float.

    :param schedule: Optimization schedule containing step/epoch metadata.
    :type schedule: OptimizationSchedule
    :param global_step: Current optimizer step.
    :type global_step: int
    :returns: Fractional epoch progress.
    :rtype: float
    """
    steps_per_epoch = getattr(schedule, "steps_per_epoch", None)
    if steps_per_epoch and steps_per_epoch > 0:
        return float(global_step) / float(steps_per_epoch)
    num_generations = getattr(schedule, "num_generations", 0)
    num_epochs = getattr(schedule, "num_epochs", 0)
    total_steps = getattr(schedule, "total_training_steps", 0)
    if num_generations and num_generations > 0:
        return float(global_step) / float(num_generations)
    # Truthiness guard first so ``None`` never reaches the ``>`` comparison.
    if total_steps and num_epochs and total_steps > 0 and num_epochs > 0:
        return float(global_step) * float(num_epochs) / float(total_steps)
    return float(global_step)


def _emit_metrics(
    ctx: TrainingLoopContext,
    metrics: Dict[str, Any],
    global_step: int,
    *,
    log_to_wandb: bool,
    tag: str,
    metric_logger: Optional[Callable[[Dict[str, Any]], None]] = None,
) -> Dict[str, Any]:
    """Emit metrics to the configured sinks and write an info log line.

    :param ctx: Training context containing logging handles.
    :type ctx: training.types.TrainingLoopContext
    :param metrics: Dictionary of scalar metrics to log.
    :type metrics: dict[str, Any]
    :param global_step: Current optimizer step.
    :type global_step: int
    :param log_to_wandb: Whether to forward the metrics to W&B via
        ``logging_handles``.
    :type log_to_wandb: bool
    :param tag: Human-readable prefix for the info log line.
    :type tag: str
    :param metric_logger: Optional callable that receives the filtered
        metrics instead of the default logging handles.
    :type metric_logger: Callable[[dict[str, Any]], None] | None
    :returns: The filtered metrics dictionary that was emitted.
    :rtype: dict[str, Any]
    """
    accelerator = ctx.runtime.accelerator
    logging_handles = ctx.logging
    metrics_to_emit = _filter_metrics(metrics, ctx)
    # NOTE(review): branch nesting reconstructed from a whitespace-collapsed
    # listing; confirm against the repository that the W&B/Accelerate fallback
    # only runs when no ``metric_logger`` is supplied.
    if log_to_wandb:
        if metric_logger is not None:
            metric_logger(metrics_to_emit)
        else:
            logging_handles.log_metrics(metrics_to_emit, global_step)
            _log_wandb(
                getattr(logging_handles, "wandb_run", None),
                metrics_to_emit,
                global_step,
            )
            accelerator_log = getattr(accelerator, "log", None)
            if callable(accelerator_log):
                try:
                    accelerator_log(metrics_to_emit, step=global_step)
                except TypeError:
                    # Some ``log`` callables do not accept ``step``.
                    accelerator_log(metrics_to_emit)
    elif metric_logger is not None:
        metric_logger(metrics_to_emit)
    try:
        kv_pairs = " | ".join(
            f"{key}={metrics_to_emit[key]}" for key in sorted(metrics_to_emit.keys())
        )
    except (TypeError, ValueError, KeyError):
        # E.g. keys that cannot be sorted together; log the raw mapping.
        kv_pairs = str(metrics_to_emit)
    LOG.info("%s metrics step %d | %s", tag, global_step, kv_pairs)
    return metrics_to_emit


def _pretty_print_metrics(metrics: Dict[str, Any]) -> str:
    """Return a deterministic, pretty JSON string for human-readable logs.

    Falls back to ``str(metrics)`` when JSON encoding fails.
    """
    try:
        return json.dumps(metrics, indent=2, sort_keys=True, default=str)
    except (TypeError, ValueError):
        return str(metrics)


def _update_weighting_history(weighting: Any, global_step: int) -> None:
    """Cache the last-seen tau/beta for delta logging.

    Best-effort: attribute or conversion failures are ignored.
    """
    try:
        setattr(weighting, "_prev_tau", float(weighting.tau))
        setattr(weighting, "_prev_beta", float(weighting.beta))
        setattr(weighting, "_prev_step", int(global_step))
    except (AttributeError, TypeError, ValueError):
        return
def accumulate_metrics(state: MetricState, metrics: Dict[str, Any]) -> None:
    """Fold one step's metrics into the running sums and counts on ``state``.

    ``train/global_step`` and ``train/epoch`` are never accumulated, and any
    value that ``float()`` rejects is skipped.

    :param state: Mutable metric accumulator storing sums/counts.
    :type state: MetricState
    :param metrics: Scalar metrics emitted for the current step.
    :type metrics: dict[str, Any]
    """
    not_averaged = {"train/global_step", "train/epoch"}
    for name, raw_value in metrics.items():
        if name in not_averaged:
            continue
        try:
            as_float = float(raw_value)
        except (TypeError, ValueError):
            # Non-numeric entries (e.g. ``None`` or free-form strings).
            continue
        else:
            state.metric_sums[name] = state.metric_sums.get(name, 0.0) + as_float
            state.metric_counts[name] = state.metric_counts.get(name, 0) + 1
def flush_metric_averages(state: MetricState) -> Dict[str, float]:
    """Return the per-metric averages and reset the accumulator.

    A metric with a missing or non-positive count is divided by one.

    :param state: Metric accumulator to flush.
    :type state: MetricState
    :returns: Mapping of metric name to averaged value.
    :rtype: dict[str, float]
    """
    averaged: Dict[str, float] = {
        name: running_sum / float(max(state.metric_counts.get(name, 1), 1))
        for name, running_sum in state.metric_sums.items()
    }
    state.metric_sums.clear()
    state.metric_counts.clear()
    return averaged
def log_local_step(
    ctx: TrainingLoopContext,
    state: MetricState,
    prepared: PreparedBatch,
    log_artifacts: LogStepArtifacts,
    current_lr: float,
    *,
    reward_view: Optional[RewardLoggingView] = None,
    weight_view: Optional[WeightLoggingView] = None,
    emit: bool = True,
) -> None:
    """Log metrics for the current step on the main process only.

    The step's metrics are always accumulated into ``state``.  With GRPO-style
    logging enabled, logging steps flush the accumulated averages and emit a
    "Global" entry; otherwise the raw step metrics are emitted as a "Local"
    entry.

    :param ctx: Full training loop context containing runtime/logging handles.
    :type ctx: training.types.TrainingLoopContext
    :param state: Metric accumulator tracking sums and counts.
    :type state: MetricState
    :param prepared: Prepared batch with reward and weighting statistics.
    :type prepared: PreparedBatch
    :param log_artifacts: Loss outputs and diagnostics emitted by the optimizer
        step.
    :type log_artifacts: LogStepArtifacts
    :param current_lr: Learning rate applied for the current step.
    :type current_lr: float
    :param reward_view: Optional reward statistics aggregated across ranks.
    :type reward_view: RewardLoggingView | None
    :param weight_view: Optional weight statistics aggregated across ranks.
    :type weight_view: WeightLoggingView | None
    :param emit: When ``False``, skip emitting logs and only accumulate
        averages.
    :type emit: bool
    """
    accelerator = ctx.runtime.accelerator
    if not accelerator.is_main_process:
        return

    training_args = getattr(ctx, "training_args", None)
    log_like_grpo = _log_like_grpo_enabled(training_args)
    payload = _build_metrics_payload(
        ctx,
        state,
        prepared,
        log_artifacts,
        current_lr,
        reward_view=reward_view,
        weight_view=weight_view,
    )
    metrics = build_training_metrics_dict(payload, state.global_step)
    metrics["train/global_step"] = float(state.global_step)
    accumulate_metrics(state, metrics)
    if not emit:
        return

    if not log_like_grpo:
        # Local path: debug-log every step, emit only on logging steps.
        _log_debug_metrics(state.global_step, metrics)
        if not _should_log(ctx, state.global_step):
            return
        _log_entropy_bonus_impact(metrics, state.global_step, tag="Local")
        with ctx.logging.step_logger(state.global_step, enabled=True) as step_logger:
            _emit_metrics(
                ctx,
                metrics,
                state.global_step,
                log_to_wandb=False,
                tag="Local",
                metric_logger=getattr(step_logger, "log", None),
            )
        return

    # GRPO-style path.
    if _should_log(ctx, state.global_step) and accelerator.is_main_process:
        LOG.info(
            "step %d | epoch %.2f | loss=%.4f | tau=%.3f beta=%.3f",
            state.global_step,
            log_artifacts.epoch_progress,
            log_artifacts.loss_outputs.total_loss_scalar,
            ctx.scoring.weighting.tau,
            ctx.scoring.weighting.beta,
        )
    if not _should_log(ctx, state.global_step):
        return

    # Windowed averages override the raw values of the current step.
    window_means = flush_metric_averages(state)
    merged = dict(metrics)
    if window_means:
        merged.update(window_means)
    if "train/epoch" not in merged:
        merged["train/epoch"] = _epoch_from_global_step(
            ctx.optimization.schedule,
            state.global_step,
        )
    merged.setdefault("train/global_step", float(state.global_step))
    _log_entropy_bonus_impact(merged, state.global_step, tag="Global")
    _log_debug_metrics(state.global_step, merged)
    normalized = _normalize_prefixes(dict(merged), is_eval=False)
    with ctx.logging.step_logger(state.global_step, enabled=True) as step_logger:
        _emit_metrics(
            ctx,
            normalized,
            state.global_step,
            log_to_wandb=True,
            tag="Global",
            metric_logger=getattr(step_logger, "log", None),
        )
    # NOTE(review): placed after the ``with`` block; the nesting was lost in the
    # whitespace-collapsed listing, so confirm against the repository.
    _update_weighting_history(ctx.scoring.weighting, state.global_step)

    if training_args is not None:
        legacy_flag = getattr(training_args, "log_completions", False)
        show_samples = bool(
            getattr(training_args, "rich_log_completions", legacy_flag)
        )
    else:
        # Without training args the sample table is logged.
        show_samples = True
    if show_samples:
        _log_sample_table(ctx, state, prepared)
    return
def _build_sample_table(
    prepared: PreparedBatch,
    step: int,
    max_rows: int,
) -> Tuple[List[str], List[List[Any]]]:
    """Return W&B table columns/rows for sample completions.

    One row is produced per prompt/completion pair.  Besides the prompt and
    completion text, each row carries the step, the position of the sample
    inside its prompt group, the descending reward rank inside that group, the
    total reward, the advantage, the ``q`` mass, the raw update weight, a
    normalized "mass proxy" of that weight, and one ``reward/<key>`` column per
    entry of ``per_reward_values`` (keys sorted alphabetically).

    Per-sample sequences that are shorter than the row index do not raise:
    float columns fall back to ``NaN``, index/rank columns to ``-1`` and the
    group size to ``0``.

    :param prepared: Batch artifacts used to extract prompts/completions.
        ``reward_comp.pairs``, ``reward_comp.per_reward_values`` and
        ``reward_comp.advantage_samples`` are required; ``total_utils``,
        ``q_grouped``, ``weight_stats.weights_grouped`` and
        ``grouped_completions`` are optional.
    :type prepared: PreparedBatch
    :param step: Global step used in the W&B table rows.
    :type step: int
    :param max_rows: Maximum number of rows to include in the table.  The
        actual row count is clamped to the number of available
        prompt/completion pairs.
    :type max_rows: int
    :returns: Tuple containing table columns and row data.
    :rtype: tuple[list[str], list[list[Any]]]
    """
    pairs = prepared.reward_comp.pairs
    prompts = pairs.prompts
    completions = pairs.completions
    reward_values = prepared.reward_comp.per_reward_values
    reward_keys = sorted(reward_values.keys())
    advantages = prepared.reward_comp.advantage_samples
    total_utils = list(getattr(prepared.reward_comp, "total_utils", []) or [])
    q_grouped = list(getattr(prepared.reward_comp, "q_grouped", []) or [])
    weight_groups = list(
        getattr(getattr(prepared, "weight_stats", None), "weights_grouped", []) or []
    )

    def _to_float(value: Any, fill: float) -> float:
        """Coerce ``value`` to float, substituting ``fill`` when it cannot be parsed."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return fill

    def _flatten_groups(groups: Any, *, fill: float = float("nan")) -> List[float]:
        """Concatenate the list-typed groups into one flat list of floats."""
        # Groups that are not plain lists are skipped entirely.
        return [
            _to_float(value, fill)
            for group in groups or []
            if isinstance(group, list)
            for value in group
        ]

    def _weight_mass_proxy(group: Any) -> List[float]:
        """Normalize one weight group into non-negative shares summing to one."""
        if not isinstance(group, list) or not group:
            return []
        weights = [_to_float(value, 0.0) for value in group]
        if any(val < 0.0 for val in weights):
            # Mixed signs: clip negatives to zero and normalize what is left.
            positives = [max(val, 0.0) for val in weights]
            pos_total = sum(positives)
            if pos_total > 0.0:
                return [val / pos_total for val in positives]
        nonneg_total = sum(max(val, 0.0) for val in weights)
        if nonneg_total > 0.0 and all(val >= 0.0 for val in weights):
            return [max(val, 0.0) / nonneg_total for val in weights]
        # No positive mass at all: fall back to shares of absolute magnitude.
        abs_total = sum(abs(val) for val in weights)
        if abs_total > 0.0:
            return [abs(val) / abs_total for val in weights]
        return [float("nan")] * len(weights)

    def _ranking_reward(position: int) -> float:
        """Return the reward used for ranking; missing or NaN values sort last."""
        if position >= len(total_utils):
            return float("-inf")
        reward = float(total_utils[position])
        return float("-inf") if math.isnan(reward) else reward

    def _float_at(values: Any, idx: int) -> float:
        """Return ``values[idx]`` as float, or NaN when ``idx`` is out of range."""
        return float(values[idx]) if idx < len(values) else float("nan")

    def _int_at(values: List[int], idx: int, default: int) -> int:
        """Return ``values[idx]`` as int, or ``default`` when out of range."""
        return int(values[idx]) if idx < len(values) else default

    q_samples = _flatten_groups(q_grouped)
    weight_raw_samples = _flatten_groups(weight_groups)
    weight_mass_samples = _flatten_groups(
        [_weight_mass_proxy(group) for group in weight_groups]
    )

    prompt_index_samples: List[int] = []
    completion_index_samples: List[int] = []
    group_size_samples: List[int] = []
    reward_rank_samples: List[int] = []
    group_offset = 0
    for prompt_idx, completion_group in enumerate(
        getattr(prepared, "grouped_completions", []) or []
    ):
        group_size = len(completion_group)
        if group_size <= 0:
            continue
        # Rank by descending reward, ties broken by position in the group.
        # ``total_utils`` may be shorter than the flattened groups, so every
        # lookup goes through ``_ranking_reward`` instead of indexing directly.
        ordered = sorted(
            (-_ranking_reward(group_offset + local_idx), local_idx)
            for local_idx in range(group_size)
        )
        reward_rank = {
            local_idx: rank for rank, (_, local_idx) in enumerate(ordered, start=1)
        }
        for local_idx in range(group_size):
            prompt_index_samples.append(prompt_idx)
            completion_index_samples.append(local_idx)
            group_size_samples.append(group_size)
            reward_rank_samples.append(reward_rank[local_idx])
        group_offset += group_size

    columns = [
        "step",
        "prompt_index",
        "completion_index",
        "group_size",
        "reward_rank_desc",
        "prompt",
        "completion",
        "reward_total",
        "advantage",
        "q_mass",
        "update_weight_raw",
        "update_mass_proxy",
    ] + [f"reward/{key}" for key in reward_keys]

    # Never index past the available pairs, even if the caller asks for more.
    row_count = min(max_rows, len(prompts), len(completions))
    rows: List[List[Any]] = []
    for idx in range(row_count):
        row: List[Any] = [
            step,
            _int_at(prompt_index_samples, idx, -1),
            _int_at(completion_index_samples, idx, -1),
            _int_at(group_size_samples, idx, 0),
            _int_at(reward_rank_samples, idx, -1),
            prompts[idx],
            completions[idx],
            _float_at(total_utils, idx),
            _float_at(advantages, idx),
            _float_at(q_samples, idx),
            _float_at(weight_raw_samples, idx),
            _float_at(weight_mass_samples, idx),
        ]
        row.extend(_float_at(reward_values.get(key, []), idx) for key in reward_keys)
        rows.append(row)
    return columns, rows


def _write_sample_table_sidecar(
    *,
    output_dir: str,
    table_key: str,
    step: int,
    columns: Sequence[str],
    rows: Sequence[Sequence[Any]],
) -> Optional[str]:
    """Persist the full completion table locally for deterministic downstream analysis.

    The table is written as JSON with ``columns`` and ``data`` keys to
    ``<output_dir>/rich_completions/<table_key>_step_<step:06d>.json``.

    :param output_dir: Base directory; an empty value disables the sidecar.
    :type output_dir: str
    :param table_key: Table name used as the file-name prefix.
    :type table_key: str
    :param step: Global step encoded (zero-padded to six digits) in the file name.
    :type step: int
    :param columns: Column names of the table.
    :type columns: Sequence[str]
    :param rows: Row data of the table.
    :type rows: Sequence[Sequence[Any]]
    :returns: Path of the written file, or ``None`` when ``output_dir`` is
        empty or an ``OSError`` occurred while creating/writing the file.
    :rtype: str | None
    """
    if not output_dir:
        return None
    try:
        sidecar_dir = os.path.join(output_dir, "rich_completions")
        os.makedirs(sidecar_dir, exist_ok=True)
        payload = {"columns": list(columns), "data": [list(row) for row in rows]}
        path = os.path.join(
            sidecar_dir,
            f"{table_key}_step_{int(step):06d}.json",
        )
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle)
        return path
    except OSError:
        # Filesystem failures are swallowed; the caller only gets ``None``.
        return None


def _log_sample_table(
    ctx: TrainingLoopContext,
    state: MetricState,
    prepared: PreparedBatch,
) -> None:
    """Log a W&B table with prompt/completion samples when enabled.

    On the main process the table is built from ``prepared``, written to a
    JSON sidecar under ``training_args.output_dir`` (when that is a non-empty
    string), announced through ``LOG.info`` and, if a W&B run is present and
    ``_rich_completion_wandb_enabled`` allows it, logged to W&B truncated to
    ``_WANDB_SAMPLE_ROWS`` rows.  ``_wait_after_rich_completion_logging`` is
    called before every return, including on non-main processes.

    :param ctx: Training context providing logging handles and accelerator state.
    :type ctx: training.types.TrainingLoopContext
    :param state: Metric state containing the global step for table rows.
    :type state: MetricState
    :param prepared: Batch artifacts whose ``RewardComputation`` holds the
        prompts, completions, and reward components to display.
    :type prepared: PreparedBatch
    """
    training_args = getattr(ctx, "training_args", None)
    wandb_run = ctx.logging.wandb_run
    accelerator = ctx.runtime.accelerator

    wandb_mod = _get_wandb()
    if wandb_mod is None and "wandb" in sys.modules:
        wandb_mod = sys.modules["wandb"]
    if wandb_mod is None:

        class _FallbackWandb:
            """Stand-in exposing ``Table`` when no wandb module is available."""

            def Table(
                self,
                columns: Any = None,
                rows: Any = None,
                **_kwargs: Any,
            ) -> Dict[str, Any]:
                return {"columns": columns, "rows": rows}

        wandb_mod = _FallbackWandb()

    if not accelerator.is_main_process:
        _wait_after_rich_completion_logging(accelerator, training_args)
        return

    pairs = getattr(prepared.reward_comp, "pairs", None)
    if pairs is None or not pairs.prompts or not pairs.completions:
        _wait_after_rich_completion_logging(accelerator, training_args)
        return

    total_rows = min(len(pairs.prompts), len(pairs.completions))
    if total_rows <= 0:
        _wait_after_rich_completion_logging(accelerator, training_args)
        return

    columns, rows = _build_sample_table(prepared, state.global_step, total_rows)
    if not rows:
        _wait_after_rich_completion_logging(accelerator, training_args)
        return

    table_key = "rich_completions"
    if training_args is not None:
        key_value = getattr(training_args, "rich_log_completions_key", table_key)
        if isinstance(key_value, str) and key_value.strip():
            table_key = key_value.strip()

    sidecar_path = None
    if training_args is not None:
        output_dir = getattr(training_args, "output_dir", None)
        if isinstance(output_dir, str) and output_dir.strip():
            sidecar_path = _write_sample_table_sidecar(
                output_dir=output_dir.strip(),
                table_key=table_key,
                step=state.global_step,
                columns=columns,
                rows=rows,
            )

    LOG.info(
        "Logging enriched completion table | key=%s step=%d columns=%s rows=%d sidecar=%s",
        table_key,
        state.global_step,
        columns,
        len(rows),
        sidecar_path or "<none>",
    )

    if wandb_run is None or not _rich_completion_wandb_enabled(training_args):
        _wait_after_rich_completion_logging(accelerator, training_args)
        return

    try:
        # Only the first ``_WANDB_SAMPLE_ROWS`` rows go to W&B; the sidecar
        # above received the full table.
        wandb_run.log(
            {table_key: wandb_mod.Table(columns=columns, rows=rows[:_WANDB_SAMPLE_ROWS])},
            step=state.global_step,
        )
    except WandbError:
        # W&B errors are swallowed here; the post-logging wait still runs.
        _wait_after_rich_completion_logging(accelerator, training_args)
        return
    _wait_after_rich_completion_logging(accelerator, training_args)
def log_training_step(
    ctx: TrainingLoopContext,
    state: MetricState,
    prepared: PreparedBatch,
    log_artifacts: LogStepArtifacts,
    current_lr: float,
    *,
    reward_view: Optional[RewardLoggingView] = None,
    weight_view: Optional[WeightLoggingView] = None,
) -> None:
    """Emit global metrics (including optional W&B logging).

    Nothing is emitted when GRPO-style logging is enabled for
    ``ctx.training_args`` or when ``_should_log`` rejects the current step.
    Otherwise the windowed averages held in ``state`` are flushed and emitted;
    if the flush yields nothing, metrics are built from the current step's
    payload instead.  Afterwards the weighting history is updated and, when
    requested by the training arguments (or when there are none), the
    completion sample table is logged.

    :param ctx: Training context containing runtime/logging handles.
    :type ctx: training.types.TrainingLoopContext
    :param state: Metric accumulator tracking running averages.
    :type state: MetricState
    :param prepared: Batch artifacts with reward/weight statistics.
    :type prepared: PreparedBatch
    :param log_artifacts: Loss outputs and diagnostics for the step.
    :type log_artifacts: LogStepArtifacts
    :param current_lr: Learning rate applied for the current step.
    :type current_lr: float
    :param reward_view: Optional reward statistics aggregated across ranks.
    :type reward_view: RewardLoggingView | None
    :param weight_view: Optional weight statistics aggregated across ranks.
    :type weight_view: WeightLoggingView | None
    """
    training_args = getattr(ctx, "training_args", None)
    if _log_like_grpo_enabled(training_args) or not _should_log(
        ctx, state.global_step
    ):
        return

    accelerator = ctx.runtime.accelerator
    if accelerator.is_main_process:
        LOG.info(
            "step %d | epoch %.2f | loss=%.4f | tau=%.3f beta=%.3f",
            state.global_step,
            log_artifacts.epoch_progress,
            log_artifacts.loss_outputs.total_loss_scalar,
            ctx.scoring.weighting.tau,
            ctx.scoring.weighting.beta,
        )

    step_metrics = flush_metric_averages(state)
    if not step_metrics:
        # Nothing accumulated in the window: report this step's raw metrics.
        step_metrics = build_training_metrics_dict(
            _build_metrics_payload(
                ctx,
                state,
                prepared,
                log_artifacts,
                current_lr,
                reward_view=reward_view,
                weight_view=weight_view,
            ),
            state.global_step,
        )
        step_metrics["train/global_step"] = float(state.global_step)
    else:
        # Averages are stamped with the current step and epoch before emission.
        step_metrics["train/global_step"] = float(state.global_step)
        step_metrics["train/epoch"] = _epoch_from_global_step(
            ctx.optimization.schedule,
            state.global_step,
        )

    with ctx.logging.step_logger(state.global_step, enabled=True) as step_logger:
        _emit_metrics(
            ctx,
            step_metrics,
            state.global_step,
            log_to_wandb=True,
            tag="Global",
            metric_logger=getattr(step_logger, "log", None),
        )

    if accelerator.is_main_process:
        rendered = _pretty_print_metrics(step_metrics)
        LOG.info("Global metrics (pretty) step %d\n%s", state.global_step, rendered)

    _update_weighting_history(ctx.scoring.weighting, state.global_step)

    # Without training args the sample table is always logged; otherwise
    # ``rich_log_completions`` decides, with ``log_completions`` as fallback.
    wants_sample_table = (
        True
        if training_args is None
        else bool(
            getattr(
                training_args,
                "rich_log_completions",
                getattr(training_args, "log_completions", False),
            )
        )
    )
    if wants_sample_table:
        _log_sample_table(ctx, state, prepared)
# Public names exported by this module.
__all__ = [
    "LogStepArtifacts",
    "accumulate_metrics",
    "build_training_metrics_dict",
    "flush_metric_averages",
    "log_local_step",
    "log_training_metrics",
    "log_training_step",
    "summarize_reward_stats",
    "summarize_weight_stats",
]