maxent_grpo.training.metrics
Metrics and logging helpers for the MaxEnt-GRPO training loop.
Key entry points
- log_local_step – Emits per-rank metrics for debugging and updates the accumulator used for windowed averages.
- log_training_step – Aggregates metrics across processes, forwards them to wandb and/or the accelerate logger, and dumps a structured log line.
- LogStepArtifacts – Lightweight container that bundles loss outputs, diagnostics, gradient norms, and epoch progress.
The module also exposes helpers for building W&B sample tables, gathering statistics across ranks, and summarizing reward/weighting diagnostics. Docstrings follow Sphinx conventions so the documentation clearly describes the available metrics and their shapes.
Functions
|
Return a finite float or |
|
Return loss/optimizer scalars that mirror the TRL trainer. |
|
Return a structured payload describing the current step. |
|
Return W&B table columns/rows for sample completions. |
|
Return PPO-style clipping diagnostics. |
|
Remove all keys that start with |
|
Emit metrics to stdout, Accelerate, and optionally W&B. |
|
Return reward/bonus summary values when an entropy bonus is present. |
|
Return the current epoch progress given the training schedule. |
|
Return metrics filtered according to the configured mode. |
|
Return the global fraction of zero-variance advantage groups. |
|
Gather dict-of-list structures across processes. |
|
Gather a sequence of floats across processes. |
|
Return the wandb module when available (facilitates testing). |
|
Metrics summarizing completion lengths. |
|
Emit a concise debug line with key metrics for the current step. |
|
Emit a concise log line showing entropy bonus impact when present. |
|
Return |
|
Log a W&B table with prompt/completion samples when enabled. |
|
Return logging cadence (strategy, steps, first-step flag). |
|
Break down the loss into individual components. |
|
Compute mean/std for a list of values. |
|
Return the logging mode for metrics filtering. |
|
Return token-weighted policy entropy from a SequenceScores-like object. |
|
Return a deterministic, pretty JSON string for human-readable logs. |
|
Compute simple linear-interpolated quantiles for logging. |
|
Convert raw reward samples into summary statistics. |
|
|
|
Return whether ranks should synchronize after rich completion logging. |
|
Return whether enriched completion tables should also go to W&B. |
|
Return True when metrics should be emitted for this step. |
|
Return a compact metrics dict for W&B/console logging. |
|
Sum a scalar across all processes. |
|
Aggregate reward/advantage stats into a lightweight view. |
|
Summarize entropy statistics for logging. |
|
Cache the last-seen tau/beta for delta logging. |
|
Synchronize ranks after rich completion logging when configured. |
|
Entropy diagnostics for the MaxEnt weighting distribution. |
|
Log controller hyperparameters for both GRPO and MaxEnt-GRPO. |
|
Persist the full completion table locally for deterministic downstream analysis. |
accumulate_metrics |
Accumulate per-batch metrics so the global log can show running averages. |
build_training_metrics_dict |
Return the flattened metrics dictionary for logging. |
flush_metric_averages |
Return averaged metrics and clear the accumulator. |
log_local_step |
Log metrics for the current step on the main process only. |
log_training_metrics |
Emit scalar metrics to logging callbacks and return them. |
log_training_step |
Emit global metrics (including optional W&B logging). |
summarize_reward_stats |
Aggregate reward statistics across all ranks. |
summarize_weight_stats |
Aggregate per-batch weight statistics across all processes. |
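One of the helpers listed above is described as computing "simple linear-interpolated quantiles for logging". The sketch below shows what such a computation typically looks like. The name `linear_quantiles`, its signature, and the choice to return 0.0 for empty input are assumptions made for illustration, not the module's actual API.

```python
from typing import Dict, Sequence


def linear_quantiles(values: Sequence[float], qs: Sequence[float]) -> Dict[float, float]:
    """Hypothetical helper: linear-interpolated quantiles of ``values``."""
    if not values:
        # Assumption: an empty sample maps every quantile to 0.0.
        return {q: 0.0 for q in qs}
    ordered = sorted(values)
    last = len(ordered) - 1
    out: Dict[float, float] = {}
    for q in qs:
        # Fractional index into the sorted list, with q clamped to [0, 1].
        pos = min(max(q, 0.0), 1.0) * last
        lo = int(pos)
        hi = min(lo + 1, last)
        frac = pos - lo
        # Interpolate between the two neighbouring order statistics.
        out[q] = ordered[lo] * (1.0 - frac) + ordered[hi] * frac
    return out
```

For the sample `[1.0, 2.0, 3.0, 4.0]`, the median falls halfway between the second and third values, so the interpolated result is 2.5.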
Exceptions
|
Fallback error used when wandb is unavailable. |
- maxent_grpo.training.metrics.accumulate_metrics(state, metrics)
Accumulate per-batch metrics so the global log can show running averages.
- Parameters:
state (MetricState) – Mutable metric accumulator storing sums/counts.
metrics (dict[str, Any]) – Scalar metrics emitted for the current step.
- Return type:
None
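The sums/counts accumulator pattern described here can be sketched as follows. `SketchMetricState`, `accumulate`, and `flush` are hypothetical stand-ins written for this illustration. The real `MetricState` fields and the exact filtering rules are not shown in this page, so skipping non-numeric and non-finite values is an assumption.

```python
import math
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class SketchMetricState:
    """Hypothetical stand-in for MetricState: running sums and counts."""
    sums: Dict[str, float] = field(default_factory=dict)
    counts: Dict[str, int] = field(default_factory=dict)


def accumulate(state: SketchMetricState, metrics: Dict[str, Any]) -> None:
    """Add each finite numeric metric to the running sum for its key."""
    for name, value in metrics.items():
        # Assumption: booleans, strings, and other non-numeric values are skipped.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            continue
        if not math.isfinite(value):
            continue
        state.sums[name] = state.sums.get(name, 0.0) + float(value)
        state.counts[name] = state.counts.get(name, 0) + 1


def flush(state: SketchMetricState) -> Dict[str, float]:
    """Return per-key averages and reset the accumulator."""
    averages = {name: state.sums[name] / state.counts[name] for name in state.sums}
    state.sums.clear()
    state.counts.clear()
    return averages
```

Keeping a separate count per key means a metric that only appears on some steps is averaged over the steps where it was actually reported.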
- maxent_grpo.training.metrics.build_training_metrics_dict(payload, global_step)
Return the flattened metrics dictionary for logging.
- Parameters:
payload (TrainingMetricsPayload) – Structured metrics payload produced by the training loop.
global_step (int) – Current optimizer step used for logging context.
- Returns:
Flat mapping of scalar metrics keyed by name.
- Return type:
- maxent_grpo.training.metrics.flush_metric_averages(state)
Return averaged metrics and clear the accumulator.
- Parameters:
state (MetricState) – Metric accumulator to flush.
- Returns:
Mapping of metric name to averaged value.
- Return type:
- maxent_grpo.training.metrics.log_local_step(ctx, state, prepared, log_artifacts, current_lr, *, reward_view=None, weight_view=None, emit=True)
Log metrics for the current step on the main process only.
- Parameters:
ctx (training.types.TrainingLoopContext) – Full training loop context containing runtime/logging handles.
state (MetricState) – Metric accumulator tracking sums and counts.
prepared (PreparedBatch) – Prepared batch with reward and weighting statistics.
log_artifacts (LogStepArtifacts) – Loss outputs and diagnostics emitted by the optimizer step.
current_lr (float) – Learning rate applied for the current step.
reward_view (RewardLoggingView | None) – Optional reward statistics aggregated across ranks.
weight_view (WeightLoggingView | None) – Optional weight statistics aggregated across ranks.
emit (bool) – When False, skip emitting logs and only accumulate averages.
- Return type:
None
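The `emit` flag separates "accumulate every step" from "emit only on logging steps". A minimal sketch of that pattern follows. `should_emit` and `run_window` are hypothetical names, and the every-N-steps cadence with an optional first-step log is an assumption about how a caller might decide the flag, not the module's confirmed behavior.

```python
from typing import Dict, List, Sequence


def should_emit(step: int, logging_steps: int, log_first_step: bool = False) -> bool:
    """Hypothetical cadence check: emit every ``logging_steps`` optimizer steps."""
    if logging_steps <= 0:
        return False
    if step == 1 and log_first_step:
        return True
    return step % logging_steps == 0


def run_window(losses: Sequence[float], logging_steps: int) -> Dict[int, float]:
    """Accumulate on every step and emit a windowed average on logging steps."""
    window: List[float] = []
    emitted: Dict[int, float] = {}
    for step, loss in enumerate(losses, start=1):
        window.append(loss)  # always accumulate, as with emit=False
        if should_emit(step, logging_steps):
            emitted[step] = sum(window) / len(window)
            window.clear()  # start a fresh averaging window
    return emitted
```

With four steps and `logging_steps=2`, the sketch emits two values, each the average of the two steps since the previous emission.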
- maxent_grpo.training.metrics.log_training_metrics(logging_cfg, global_step, payload)
Emit scalar metrics to logging callbacks and return them.
- Parameters:
logging_cfg (LoggingHandles) – Logging handles (W&B, tensorboard, stdout, etc.).
global_step (int) – Current optimizer step.
payload (TrainingMetricsPayload) – Structured metrics payload to log.
- Returns:
Flattened metrics dictionary emitted to loggers.
- Return type:
- maxent_grpo.training.metrics.log_training_step(ctx, state, prepared, log_artifacts, current_lr, *, reward_view=None, weight_view=None)
Emit global metrics (including optional W&B logging).
- Parameters:
ctx (training.types.TrainingLoopContext) – Training context containing runtime/logging handles.
state (MetricState) – Metric accumulator tracking running averages.
prepared (PreparedBatch) – Batch artifacts with reward/weight statistics.
log_artifacts (LogStepArtifacts) – Loss outputs and diagnostics for the step.
current_lr (float) – Learning rate applied for the current step.
reward_view (RewardLoggingView | None) – Optional reward statistics aggregated across ranks.
weight_view (WeightLoggingView | None) – Optional weight statistics aggregated across ranks.
- Return type:
None
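Aggregating a statistic across processes usually means reducing sufficient statistics (sum, sum of squares, count) rather than averaging per-rank means. The sketch below simulates that with plain Python lists standing in for ranks. `local_moments` and `combine_moments` are hypothetical names. In a real run the three totals would come from a collective sum across processes, and whether this module reports population or sample standard deviation is not stated here, so the population form is an assumption.

```python
import math
from typing import List, Sequence, Tuple


def local_moments(values: Sequence[float]) -> Tuple[float, float, int]:
    """Per-rank sufficient statistics: (sum, sum of squares, count)."""
    return (sum(values), sum(v * v for v in values), len(values))


def combine_moments(per_rank: List[Tuple[float, float, int]]) -> Tuple[float, float]:
    """Combine per-rank moments into a global (mean, population std)."""
    total = sum(m[0] for m in per_rank)
    total_sq = sum(m[1] for m in per_rank)
    count = sum(m[2] for m in per_rank)
    if count == 0:
        return (0.0, 0.0)  # assumption: empty input reports zeros
    mean = total / count
    # Clamp at zero to absorb tiny negative values from floating-point rounding.
    variance = max(total_sq / count - mean * mean, 0.0)
    return (mean, math.sqrt(variance))
```

For ranks holding `[1.0, 2.0]` and `[3.0]`, the global mean is 2.0, whereas averaging the per-rank means would give 2.25. This is why the counts have to be reduced alongside the sums when ranks hold different numbers of samples.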
- maxent_grpo.training.metrics.summarize_reward_stats(accelerator, reward_comp, *, log_like_grpo=False)
Aggregate reward statistics across all ranks.
Exposes the internal helper so that training code can gather reward diagnostics even on non-main ranks before metrics are logged.
- Parameters:
accelerator (Accelerator) – Accelerate handle used for reductions.
reward_comp (RewardComputation | None) – Reward computation outputs for the current batch.
log_like_grpo (bool) – When True, skip global reductions and keep local statistics for GRPO-style logging.
- Returns:
Aggregated reward statistics for logging.
- Return type:
- maxent_grpo.training.metrics.summarize_weight_stats(accelerator, weight_stats, *, log_like_grpo=False)
Aggregate per-batch weight statistics across all processes.
Exposes the internal summarization helper so controller logic can rely on the same cross-rank entropy measurement used for logging.
- Parameters:
accelerator (Accelerator) – Accelerate handle used for reductions.
weight_stats (WeightStats) – Weight statistics for the current batch.
log_like_grpo (bool) – When True, skip global reductions and keep local statistics for GRPO-style logging.
- Returns:
Aggregated weight statistics for logging.
- Return type: