maxent_grpo.training.metrics

Metrics and logging helpers for the MaxEnt-GRPO training loop.

Key entry points

log_local_step

Emits per-rank metrics for debugging and updates the accumulator used for windowed averages.

log_training_step

Aggregates metrics across processes, forwards them to W&B and/or the Accelerate logger, and writes a structured log line.

LogStepArtifacts

Lightweight container that bundles loss outputs, diagnostics, gradient norms, and epoch progress.

The module also exposes helpers for building W&B sample tables, gathering statistics across ranks, and summarizing reward/weighting diagnostics. Docstrings follow Sphinx conventions so the documentation clearly describes the available metrics and their shapes.

Functions

_as_float(value)

Return a finite float or None when conversion fails.

_base_metric_block(payload, global_step)

Return loss/optimizer scalars that mirror the TRL trainer.

_build_metrics_payload(ctx, state, prepared, ...)

Return a structured payload describing the current step.

_build_sample_table(prepared, step, max_rows)

Return W&B table columns/rows for sample completions.

_clip_metric_block(diagnostics)

Return PPO-style clipping diagnostics.

_drop_prefix(metrics, prefix)

Remove all keys that start with prefix from metrics.

_emit_metrics(ctx, metrics, global_step, *, ...)

Emit metrics to stdout, Accelerate, and optionally W&B.

_entropy_bonus_impact(reward_stats)

Return reward/bonus summary values when an entropy bonus is present.

_epoch_from_global_step(schedule, global_step)

Return the current epoch progress given the training schedule.

_filter_metrics(metrics, ctx)

Return metrics filtered according to the configured mode.

_fraction_zero_std_groups(accelerator, ...)

Return the global fraction of zero-variance advantage groups.

_gather_dict_of_lists_for_metrics(...[, ...])

Gather dict-of-list structures across processes.

_gather_list_for_metrics(accelerator, values, *)

Gather a sequence of floats across processes.

_get_wandb()

Return the wandb module when available (facilitates testing).

_length_metric_block(length_stats)

Return metrics summarizing completion lengths.

_log_debug_metrics(step, metrics)

Emit a concise debug line with key metrics for the current step.

_log_entropy_bonus_impact(metrics, step, *, tag)

Emit a concise log line showing entropy bonus impact when present.

_log_like_grpo_enabled(training_args)

Return True when GRPO-style per-rank logging is requested.

_log_sample_table(ctx, state, prepared)

Log a W&B table with prompt/completion samples when enabled.

_logging_controls(ctx)

Return logging cadence (strategy, steps, first-step flag).

_loss_component_block(loss_outputs)

Break down the loss into individual components.

_mean_std(values)

Compute mean/std for a list of values.

_metrics_mode()

Return the logging mode for metrics filtering.

_policy_entropy_from_scores(scores)

Return token-weighted policy entropy from a SequenceScores-like object.

_pretty_print_metrics(metrics)

Return a deterministic, pretty JSON string for human-readable logs.

_quantile_stats(values, quantiles)

Compute simple linear-interpolated quantiles for logging.

_reward_component_stats(per_reward_values)

Convert raw reward samples into summary statistics.

_reward_metric_block(payload)

_rich_completion_sync_enabled(training_args)

Return whether ranks should synchronize after rich completion logging.

_rich_completion_wandb_enabled(training_args)

Return whether enriched completion tables should also go to W&B.

_should_log(ctx, step)

Return True when metrics should be emitted for this step.

_slim_metrics(metrics, _ctx)

Return a compact metrics dict for W&B/console logging.

_sum_scalar_for_metrics(accelerator, value, *)

Sum a scalar across all processes.

_summarize_reward_stats(accelerator, ...[, ...])

Aggregate reward/advantage stats into a lightweight view.

_summarize_weight_stats(accelerator, ...[, ...])

Summarize entropy statistics for logging.

_update_weighting_history(weighting, global_step)

Cache the last-seen tau/beta for delta logging.

_wait_after_rich_completion_logging(...)

Synchronize ranks after rich completion logging when configured.

_weight_metric_block(payload)

Return entropy diagnostics for the MaxEnt weighting distribution.

_weighting_config_block(payload, global_step)

Log controller hyperparameters for both GRPO and MaxEnt-GRPO.

_write_sample_table_sidecar(*, output_dir, ...)

Persist the full completion table locally for deterministic downstream analysis.

accumulate_metrics(state, metrics)

Accumulate per-batch metrics so the global log can show running averages.

build_training_metrics_dict(payload, global_step)

Return the flattened metrics dictionary for logging.

flush_metric_averages(state)

Return averaged metrics and clear the accumulator.

log_local_step(ctx, state, prepared, ...[, ...])

Log metrics for the current step on the main process only.

log_training_metrics(logging_cfg, ...)

Emit scalar metrics to logging callbacks and return them.

log_training_step(ctx, state, prepared, ...)

Emit global metrics (including optional W&B logging).

summarize_reward_stats(accelerator, ...[, ...])

Aggregate reward statistics across all ranks.

summarize_weight_stats(accelerator, ...[, ...])

Aggregate per-batch weight statistics across all processes.
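
Several of the private helpers listed above (_as_float, _mean_std, _quantile_stats) are small numeric utilities. The following is a minimal sketch of what such helpers could look like, based only on their one-line summaries. The function names, the choice of population standard deviation, and the empty-input behavior are assumptions for illustration, not the module's actual implementation.

```python
import math
from typing import Any, Optional, Sequence


def as_float_sketch(value: Any) -> Optional[float]:
    """Return a finite float, or None when conversion fails (hypothetical helper)."""
    try:
        result = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    # NaN and +/-inf are treated as failed conversions.
    return result if math.isfinite(result) else None


def mean_std_sketch(values: Sequence[float]) -> tuple[float, float]:
    """Return (mean, population std); assumes (0.0, 0.0) for empty input."""
    if not values:
        return 0.0, 0.0
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, math.sqrt(variance)


def quantile_stats_sketch(
    values: Sequence[float], quantiles: Sequence[float]
) -> dict[float, float]:
    """Return linear-interpolated quantiles keyed by the requested fraction."""
    if not values:
        return {}
    ordered = sorted(values)
    out: dict[float, float] = {}
    for q in quantiles:
        # Fractional index into the sorted sample.
        position = q * (len(ordered) - 1)
        lower = math.floor(position)
        upper = min(lower + 1, len(ordered) - 1)
        fraction = position - lower
        out[q] = ordered[lower] + fraction * (ordered[upper] - ordered[lower])
    return out
```

Returning None instead of raising keeps a single malformed value from aborting a logging call, which is presumably why the real helper has the same contract.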

Exceptions

_FallbackWandbError

Fallback error used when wandb is unavailable.
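
The pairing of _get_wandb with a fallback exception suggests the common optional-dependency pattern: import the tracker lazily and degrade gracefully when it is missing. A minimal sketch of that pattern follows. The names FallbackTrackerError, load_optional_module, and require_module are hypothetical, and the module's real behavior may differ.

```python
import importlib
from types import ModuleType
from typing import Optional


class FallbackTrackerError(RuntimeError):
    """Hypothetical stand-in raised when an optional tracker is unavailable."""


def load_optional_module(name: str) -> Optional[ModuleType]:
    """Return the imported module, or None when it is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def require_module(name: str) -> ModuleType:
    """Return the module or raise the fallback error with a clear message."""
    module = load_optional_module(name)
    if module is None:
        raise FallbackTrackerError(f"optional dependency {name!r} is not installed")
    return module
```

Routing the import through one function also gives tests a single place to patch when they need to simulate a missing or stubbed tracker.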

maxent_grpo.training.metrics.accumulate_metrics(state, metrics)

Accumulate per-batch metrics so the global log can show running averages.

Parameters:
  • state (MetricState) – Mutable metric accumulator storing sums/counts.

  • metrics (dict[str, Any]) – Scalar metrics emitted for the current step.

Return type:

None

maxent_grpo.training.metrics.build_training_metrics_dict(payload, global_step)

Return the flattened metrics dictionary for logging.

Parameters:
  • payload (TrainingMetricsPayload) – Structured metrics payload produced by the training loop.

  • global_step (int) – Current optimizer step used for logging context.

Returns:

Flat mapping of scalar metrics keyed by name.

Return type:

dict[str, Any]
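
To illustrate what "flattened" means here, the sketch below collapses a nested mapping into a flat dictionary of scalar metrics. It is only an approximation: the real input is a TrainingMetricsPayload whose fields are not documented on this page, so the nested-dict input, the "/" separator, and the name flatten_metrics_sketch are all assumptions.

```python
from typing import Any, Mapping


def flatten_metrics_sketch(
    payload: Mapping[str, Any], prefix: str = ""
) -> dict[str, Any]:
    """Flatten nested mappings into 'outer/inner' keys (hypothetical helper)."""
    flat: dict[str, Any] = {}
    for key, value in payload.items():
        name = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Recurse so arbitrarily deep blocks end up as flat keys.
            flat.update(flatten_metrics_sketch(value, name))
        else:
            flat[name] = value
    return flat


example = flatten_metrics_sketch(
    {"train": {"loss": 0.5, "reward": {"mean": 1.0}}, "lr": 1e-5}
)
```

A flat mapping is the shape most logging backends accept directly, which is likely why the function returns dict[str, Any] rather than the structured payload.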

maxent_grpo.training.metrics.flush_metric_averages(state)

Return averaged metrics and clear the accumulator.

Parameters:
  • state (MetricState) – Metric accumulator to flush.

Returns:

Mapping of metric name to averaged value.

Return type:

dict[str, float]
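
The accumulate/flush pair described above amounts to a windowed average over sums and counts. A minimal sketch under stated assumptions: MetricStateSketch is a hypothetical stand-in for MetricState (whose real fields are not shown on this page), and skipping non-numeric values is a guess at reasonable behavior.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class MetricStateSketch:
    """Hypothetical accumulator storing per-metric sums and counts."""

    sums: dict[str, float] = field(default_factory=dict)
    counts: dict[str, int] = field(default_factory=dict)


def accumulate_sketch(state: MetricStateSketch, metrics: dict[str, Any]) -> None:
    """Add this step's numeric metrics to the running sums and counts."""
    for name, value in metrics.items():
        # Skip strings, None, and booleans; only real scalars are averaged.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            continue
        state.sums[name] = state.sums.get(name, 0.0) + float(value)
        state.counts[name] = state.counts.get(name, 0) + 1


def flush_sketch(state: MetricStateSketch) -> dict[str, float]:
    """Return per-metric averages and reset the accumulator."""
    averages = {name: state.sums[name] / state.counts[name] for name in state.sums}
    state.sums.clear()
    state.counts.clear()
    return averages


state = MetricStateSketch()
accumulate_sketch(state, {"loss": 1.0, "tag": "warmup"})
accumulate_sketch(state, {"loss": 3.0})
window_average = flush_sketch(state)
```

Tracking a count per metric, rather than one global step count, keeps the average correct for metrics that are only reported on some steps.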

maxent_grpo.training.metrics.log_local_step(ctx, state, prepared, log_artifacts, current_lr, *, reward_view=None, weight_view=None, emit=True)

Log metrics for the current step on the main process only.

Parameters:
  • ctx (training.types.TrainingLoopContext) – Full training loop context containing runtime/logging handles.

  • state (MetricState) – Metric accumulator tracking sums and counts.

  • prepared (PreparedBatch) – Prepared batch with reward and weighting statistics.

  • log_artifacts (LogStepArtifacts) – Loss outputs and diagnostics emitted by the optimizer step.

  • current_lr (float) – Learning rate applied for the current step.

  • reward_view (RewardLoggingView | None) – Optional reward statistics aggregated across ranks.

  • weight_view (WeightLoggingView | None) – Optional weight statistics aggregated across ranks.

  • emit (bool) – When False, skip emitting logs and only accumulate averages.

Return type:

None
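
The emit flag separates accumulation from emission: every step can feed the averages, while only some steps print. One way a caller might derive that flag from a logging cadence is sketched below. The function name, the "steps"/"no" strategy values, and the first-step rule are assumptions modelled loosely on the (strategy, steps, first-step flag) triple that _logging_controls is described as returning.

```python
def should_emit_sketch(
    step: int,
    *,
    strategy: str = "steps",
    logging_steps: int = 10,
    log_first_step: bool = True,
) -> bool:
    """Decide whether this step should emit logs (hypothetical gating rule)."""
    if strategy == "no":
        return False
    if log_first_step and step == 1:
        return True
    return logging_steps > 0 and step % logging_steps == 0


# Every step would still be accumulated; only these steps would be emitted.
emitting_steps = [s for s in range(1, 26) if should_emit_sketch(s)]
```

In a training loop, the result would be passed as emit=should_emit_sketch(step) so that steps between log lines still contribute to the windowed averages.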

maxent_grpo.training.metrics.log_training_metrics(logging_cfg, global_step, payload)

Emit scalar metrics to logging callbacks and return them.

Parameters:
  • logging_cfg (LoggingHandles) – Logging handles (W&B, tensorboard, stdout, etc.).

  • global_step (int) – Current optimizer step.

  • payload (TrainingMetricsPayload) – Structured metrics payload to log.

Returns:

Flattened metrics dictionary emitted to loggers.

Return type:

dict[str, Any]

maxent_grpo.training.metrics.log_training_step(ctx, state, prepared, log_artifacts, current_lr, *, reward_view=None, weight_view=None)

Emit global metrics (including optional W&B logging).

Parameters:
  • ctx (training.types.TrainingLoopContext) – Training context containing runtime/logging handles.

  • state (MetricState) – Metric accumulator tracking running averages.

  • prepared (PreparedBatch) – Batch artifacts with reward/weight statistics.

  • log_artifacts (LogStepArtifacts) – Loss outputs and diagnostics for the step.

  • current_lr (float) – Learning rate applied for the current step.

  • reward_view (RewardLoggingView | None) – Optional reward statistics aggregated across ranks.

  • weight_view (WeightLoggingView | None) – Optional weight statistics aggregated across ranks.

Return type:

None

maxent_grpo.training.metrics.summarize_reward_stats(accelerator, reward_comp, *, log_like_grpo=False)

Aggregate reward statistics across all ranks.

Exposes the internal helper so that training code can gather reward diagnostics even on non-main ranks before metrics are logged.

Parameters:
  • accelerator (Accelerator) – Accelerate handle used for reductions.

  • reward_comp (RewardComputation | None) – Reward computation outputs for the current batch.

  • log_like_grpo (bool) – When True, skip global reductions and keep local statistics for GRPO-style logging.

Returns:

Aggregated reward statistics for logging.

Return type:

RewardLoggingView
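
The difference between a global reduction and GRPO-style local statistics can be shown without a distributed runtime. The sketch below simulates ranks as plain lists and merges per-rank moments (count, sum, sum of squares) into a global mean and standard deviation. All names are hypothetical, and the real function works through an Accelerator and returns a RewardLoggingView rather than a tuple.

```python
import math
from typing import Sequence

Moments = tuple[float, float, float]  # (count, sum, sum of squares)


def local_moments(values: Sequence[float]) -> Moments:
    """Compute the reducible moments for one rank's rewards."""
    return (float(len(values)), float(sum(values)), float(sum(v * v for v in values)))


def reduce_moments(per_rank: Sequence[Moments]) -> tuple[float, float]:
    """Merge per-rank moments into a (mean, population std) pair."""
    count = sum(m[0] for m in per_rank)
    if count == 0:
        return 0.0, 0.0
    total = sum(m[1] for m in per_rank)
    total_sq = sum(m[2] for m in per_rank)
    mean = total / count
    # Clamp at zero to absorb tiny negative values from rounding.
    variance = max(total_sq / count - mean * mean, 0.0)
    return mean, math.sqrt(variance)


def summarize_rewards_sketch(
    per_rank_rewards: Sequence[Sequence[float]],
    rank: int,
    *,
    log_like_grpo: bool = False,
) -> tuple[float, float]:
    """Return local stats when log_like_grpo is set, else global stats."""
    if log_like_grpo:
        return reduce_moments([local_moments(per_rank_rewards[rank])])
    return reduce_moments([local_moments(v) for v in per_rank_rewards])


rewards = [[1.0, 1.0], [3.0, 3.0]]  # two simulated ranks
global_stats = summarize_rewards_sketch(rewards, rank=0)
local_stats = summarize_rewards_sketch(rewards, rank=0, log_like_grpo=True)
```

Here the global view reports a mean of 2.0 with nonzero spread, while rank 0 alone sees a mean of 1.0 and zero spread, which is the kind of divergence the log_like_grpo switch makes explicit.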

maxent_grpo.training.metrics.summarize_weight_stats(accelerator, weight_stats, *, log_like_grpo=False)

Aggregate per-batch weight statistics across all processes.

Exposes the internal summarization helper so controller logic can rely on the same cross-rank entropy measurement used for logging.

Parameters:
  • accelerator (Accelerator) – Accelerate handle used for reductions.

  • weight_stats (WeightStats) – Weight statistics for the current batch.

  • log_like_grpo (bool) – When True, skip global reductions and keep local statistics for GRPO-style logging.

Returns:

Aggregated weight statistics for logging.

Return type:

WeightLoggingView
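
As a rough illustration of an entropy measurement over a weighting distribution, the sketch below computes the Shannon entropy of normalized per-group weights and averages it over groups from all simulated ranks. The names, the normalization step, and the plain unweighted average are assumptions; the fields of WeightStats and the exact statistic the module reports are not documented on this page.

```python
import math
from typing import Sequence


def weight_entropy_sketch(weights: Sequence[float]) -> float:
    """Shannon entropy (in nats) of weights normalized to sum to one."""
    total = sum(weights)
    if total <= 0:
        return 0.0
    # Zero-weight entries contribute nothing to the entropy.
    probs = [w / total for w in weights if w > 0]
    return -sum(p * math.log(p) for p in probs)


def mean_entropy_across_ranks(
    per_rank_groups: Sequence[Sequence[Sequence[float]]],
) -> float:
    """Average group entropy over every group on every simulated rank."""
    entropies = [
        weight_entropy_sketch(group)
        for groups in per_rank_groups
        for group in groups
    ]
    return sum(entropies) / len(entropies) if entropies else 0.0


uniform_entropy = weight_entropy_sketch([0.25, 0.25, 0.25, 0.25])
peaked_entropy = weight_entropy_sketch([1.0, 0.0, 0.0, 0.0])
```

Uniform weights give the maximum entropy, log(n) for n completions, and a one-hot weighting gives zero, so the value indicates how concentrated the weighting is on a few completions.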