maxent_grpo.training.optim

Copyright 2025 Liv d’Aliberti

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Optimizer and gradient utilities shared across the training loop.

Functions

_filter_optimizer_kwargs(optimizer_cls, kwargs)

Drop optimizer kwargs unsupported by lightweight stubs or callables.

apply_learning_rate(handles, learning_rate)

Set the provided learning rate on all optimizer parameter groups.

build_optimization_handles(model, cfg)

Construct an optimizer/scheduler bundle that mirrors GRPO defaults.

clip_grad_norm_local(model, accelerator, ...)

Clip gradients via Accelerate when possible and return the norm.

configure_accumulation_steps(accelerator, ...)

Pass gradient accumulation steps to Accelerate when supported.

detect_deepspeed_state(accelerator)

Return DeepSpeed usage flags derived from the accelerator state.

epoch_progress(schedule, epoch, step_in_epoch)

Return floating-point epoch progress for logging.

optimizer_step(ctx, state, current_lr)

Perform an optimizer step and advance state.global_step.

require_accumulation_context(accelerator, model)

Return an accumulation context compatible with the current strategy.

scheduled_learning_rate(schedule, handles, step)

Return the learning rate for the given optimizer step.

sync_gradients_enabled(accelerator, global_step)

Return the sync_gradients flag and log it for debugging.

Classes

DeepspeedState(use_deepspeed, zero_stage)

Describe whether the current accelerator session uses DeepSpeed.

DistributedType()

class maxent_grpo.training.optim.DeepspeedState(use_deepspeed, zero_stage)

Bases: object

Describe whether the current accelerator session uses DeepSpeed.

Parameters:
  • use_deepspeed (bool)

  • zero_stage (int)

use_deepspeed: bool

zero_stage: int

maxent_grpo.training.optim.apply_learning_rate(handles, learning_rate)

Set the provided learning rate on all optimizer parameter groups.

Parameters:
  • handles (training.types.OptimizerHandles) – Wrapper containing the primary/base optimizers.

  • learning_rate (float) – Learning rate to apply across all parameter groups.

Return type:

None
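As a rough illustration of what "all optimizer parameter groups" means here, the following minimal sketch writes one learning rate into every group of every optimizer held by a handles object. StubOptimizer, StubHandles, and the field names optimizer and base_optimizer are assumptions made for the example; the real OptimizerHandles layout is not shown on this page.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class StubOptimizer:
    """Stand-in for a torch optimizer; it only exposes ``param_groups``."""
    param_groups: List[dict] = field(default_factory=list)


@dataclass
class StubHandles:
    """Hypothetical stand-in for OptimizerHandles (field names are assumed)."""
    optimizer: Any
    base_optimizer: Any = None


def apply_learning_rate_sketch(handles, learning_rate):
    # A wrapped optimizer and its base may be the same object, so visit each once.
    seen = set()
    for opt in (handles.optimizer, handles.base_optimizer):
        if opt is None or id(opt) in seen:
            continue
        seen.add(id(opt))
        for group in getattr(opt, "param_groups", []):
            group["lr"] = float(learning_rate)  # other group options are untouched


opt = StubOptimizer(param_groups=[{"lr": 1e-3}, {"lr": 1e-3, "weight_decay": 0.0}])
handles = StubHandles(optimizer=opt, base_optimizer=opt)
apply_learning_rate_sketch(handles, 5e-5)
lrs = [group["lr"] for group in opt.param_groups]
```

Only the lr entry of each group is overwritten, so per-group settings such as weight_decay survive the update.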

maxent_grpo.training.optim.clip_grad_norm_local(model, accelerator, max_grad_norm)

Clip gradients via Accelerate when possible and return the norm.

Parameters:
  • model (torch.nn.Module) – Model whose gradients should be clipped.

  • accelerator (Accelerator) – Accelerate handle providing clip_grad_norm_.

  • max_grad_norm (float) – Maximum norm applied during clipping.

Returns:

Gradient norm when clipping occurs, otherwise None.

Return type:

float | None
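The "clip when possible, otherwise return None" contract can be sketched without torch. In the example below, StubParam, StubModel, and StubAccelerator are hypothetical stand-ins, and treating a non-positive max_grad_norm as "no clipping" is an assumption rather than documented behavior. The stub mimics the shape of Accelerate's clip_grad_norm_(parameters, max_norm) call, which returns the norm measured before clipping.

```python
import math


class StubParam:
    """Stand-in for a parameter whose gradient is a flat list of floats."""
    def __init__(self, grad):
        self.grad = list(grad)


class StubModel:
    def __init__(self, params):
        self._params = params

    def parameters(self):
        return iter(self._params)


class StubAccelerator:
    """Hypothetical accelerator exposing a clip_grad_norm_-style method."""
    def clip_grad_norm_(self, parameters, max_norm):
        params = [p for p in parameters if p.grad is not None]
        total = math.sqrt(sum(g * g for p in params for g in p.grad))
        scale = min(1.0, max_norm / (total + 1e-6))
        for p in params:
            p.grad = [g * scale for g in p.grad]
        return total  # norm measured before scaling


def clip_grad_norm_local_sketch(model, accelerator, max_grad_norm):
    # Assumption: a missing or non-positive threshold disables clipping.
    if max_grad_norm is None or max_grad_norm <= 0:
        return None
    clip = getattr(accelerator, "clip_grad_norm_", None)
    if not callable(clip):
        return None  # nothing to delegate to
    return float(clip(model.parameters(), max_grad_norm))


model = StubModel([StubParam([3.0]), StubParam([4.0])])
norm = clip_grad_norm_local_sketch(model, StubAccelerator(), 1.0)
clipped = [p.grad[0] for p in model.parameters()]
```

The returned value is the pre-clip norm, which is usually the more useful number to log.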

maxent_grpo.training.optim.configure_accumulation_steps(accelerator, grad_accum_steps)

Pass gradient accumulation steps to Accelerate when supported.

Parameters:
  • accelerator (Accelerator) – Accelerate handle used to configure accumulation.

  • grad_accum_steps (int) – Desired gradient accumulation steps.

Return type:

None

maxent_grpo.training.optim.detect_deepspeed_state(accelerator)

Return DeepSpeed usage flags derived from the accelerator state.

Parameters:
  • accelerator (Accelerator) – Accelerator instance whose state is inspected.

Returns:

DeepspeedState describing DeepSpeed usage and ZeRO stage.

Return type:

DeepspeedState
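One plausible way to derive these flags is to look for a DeepSpeed plugin on the accelerator state. The sketch below assumes the plugin is reachable as accelerator.state.deepspeed_plugin and carries a zero_stage attribute, as in Accelerate; the fallback of zero_stage=0 when DeepSpeed is absent is an assumption, and the local DeepspeedState dataclass only mirrors the two documented fields.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class DeepspeedState:
    """Local mirror of the documented fields, for illustration only."""
    use_deepspeed: bool
    zero_stage: int


def detect_deepspeed_state_sketch(accelerator):
    state = getattr(accelerator, "state", None)
    plugin = getattr(state, "deepspeed_plugin", None)
    if plugin is None:
        # Assumption: report stage 0 when no DeepSpeed plugin is configured.
        return DeepspeedState(use_deepspeed=False, zero_stage=0)
    stage = int(getattr(plugin, "zero_stage", 0) or 0)
    return DeepspeedState(use_deepspeed=True, zero_stage=stage)


# Stand-in accelerators built from SimpleNamespace, not real Accelerate objects.
plain = SimpleNamespace(state=SimpleNamespace(deepspeed_plugin=None))
zero3 = SimpleNamespace(
    state=SimpleNamespace(deepspeed_plugin=SimpleNamespace(zero_stage=3))
)
plain_state = detect_deepspeed_state_sketch(plain)
zero3_state = detect_deepspeed_state_sketch(zero3)
```

Reading attributes through getattr keeps the helper usable with lightweight test doubles that lack a state attribute entirely.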

maxent_grpo.training.optim.epoch_progress(schedule, epoch, step_in_epoch)

Return floating-point epoch progress for logging.

Parameters:
  • schedule (training.types.OptimizationSchedule) – Optimization schedule describing steps per epoch.

  • epoch (int) – Current epoch index (zero-based).

  • step_in_epoch (int) – Step index inside the current epoch.

Returns:

Floating-point epoch progress suitable for logs.

Return type:

float
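A fractional epoch value is typically the epoch index plus the fraction of the epoch already completed. The sketch below assumes the schedule exposes a steps_per_epoch attribute (a hypothetical field name) and that step_in_epoch counts completed steps; the actual convention used by the library may differ.

```python
from types import SimpleNamespace


def epoch_progress_sketch(schedule, epoch, step_in_epoch):
    steps_per_epoch = getattr(schedule, "steps_per_epoch", None)
    if not steps_per_epoch:
        # Without a known epoch length, fall back to the integer epoch.
        return float(epoch)
    return float(epoch) + step_in_epoch / float(steps_per_epoch)


schedule = SimpleNamespace(steps_per_epoch=200)  # stand-in for OptimizationSchedule
progress = epoch_progress_sketch(schedule, epoch=2, step_in_epoch=50)
```

With 200 steps per epoch, step 50 of epoch 2 is reported as 2.25 epochs.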

maxent_grpo.training.optim.optimizer_step(ctx, state, current_lr)

Perform an optimizer step and advance state.global_step.

Parameters:
  • ctx (training.types.TrainingLoopContext) – Training context containing optimizer handles.

  • state (TrainingLoopState) – Mutable training state tracking global steps.

  • current_lr (float) – Learning rate to apply before stepping.

Returns:

Gradient norm (if available) for metrics/logging.

Return type:

float | None

maxent_grpo.training.optim.require_accumulation_context(accelerator, model)

Return an accumulation context compatible with the current strategy.

Parameters:
  • accelerator (Accelerator) – Accelerator instance providing accumulate.

  • model (Any) – Model passed to accelerator.accumulate when available.

Returns:

Context manager used to guard gradient accumulation.

Raises:

RuntimeError – If accumulation is required but unavailable.

Return type:

Any
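The fallback logic can be pictured as follows. This is a sketch under stated assumptions: "accumulation is required" is interpreted as gradient_accumulation_steps > 1, and a no-op nullcontext is returned when a single step makes accumulation unnecessary. Neither rule is confirmed by this page; accumulate and gradient_accumulation_steps are the names Accelerate uses.

```python
from contextlib import nullcontext
from types import SimpleNamespace


def require_accumulation_context_sketch(accelerator, model):
    accumulate = getattr(accelerator, "accumulate", None)
    if callable(accumulate):
        return accumulate(model)
    steps = getattr(accelerator, "gradient_accumulation_steps", 1) or 1
    if steps > 1:
        # Assumption: accumulating without accelerator support is an error.
        raise RuntimeError(
            "gradient accumulation requested but accelerator.accumulate is missing"
        )
    return nullcontext()


# Stand-in accelerators, not real Accelerate objects.
with_accumulate = SimpleNamespace(accumulate=lambda m: nullcontext(m))
without_accumulate = SimpleNamespace(gradient_accumulation_steps=1)
misconfigured = SimpleNamespace(gradient_accumulation_steps=4)

with require_accumulation_context_sketch(with_accumulate, "model") as guarded:
    pass  # forward and backward passes would run here
```

Returning a context manager in every non-error case lets the training loop use a single with statement regardless of the backend.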

maxent_grpo.training.optim.scheduled_learning_rate(schedule, handles, step)

Return the learning rate for the given optimizer step.

Parameters:
  • schedule (training.types.OptimizationSchedule) – Optimization schedule describing warmup/total steps.

  • handles (training.types.OptimizerHandles) – Optimizer handles (used to read base LR).

  • step (int) – Current optimizer step index.

Returns:

Learning rate for this step.

Return type:

float
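This page does not state the shape of the schedule. As one possibility, the sketch below computes linear warmup followed by linear decay to zero from a base learning rate. The function name and the warmup_steps and total_steps parameters are illustrative; the library may use a different curve or different field names.

```python
def linear_warmup_decay_lr(base_lr, step, warmup_steps, total_steps):
    """Assumed schedule: ramp 0 -> base_lr over warmup, then decay linearly to 0."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    span = max(total_steps - warmup_steps, 1)  # avoid division by zero
    return base_lr * remaining / span


# Sample the curve at a few steps with 10 warmup steps and 110 total steps.
samples = {
    step: linear_warmup_decay_lr(1.0, step, warmup_steps=10, total_steps=110)
    for step in (0, 5, 10, 60, 110)
}
```

In a training loop, a value computed this way would be handed to apply_learning_rate before each optimizer step.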

maxent_grpo.training.optim.sync_gradients_enabled(accelerator, global_step)

Return the sync_gradients flag and log it for debugging.

Parameters:
  • accelerator (Accelerator) – Accelerator instance exposing sync_gradients.

  • global_step (int) – Current optimizer step used for debug logging.

Returns:

True if gradients should be synchronized this step.

Return type:

bool

maxent_grpo.training.optim.build_optimization_handles(model, cfg)

Construct an optimizer/scheduler bundle that mirrors GRPO defaults.

The implementation follows the same AdamW parameter‑group semantics used by Hugging Face Trainer/TRL GRPO:

  • Parameters whose names contain "bias" or "LayerNorm.weight" are placed in a no‑decay group (weight_decay=0.0).

  • All other trainable parameters share a decay group with weight_decay=cfg.weight_decay.

  • Optimizer hyperparameters (learning rate, betas, epsilon) are taken from the GRPO/TrainingArguments instance so that MaxEnt runs stay aligned with the baseline GRPO trainer.

Parameters:
  • model (Any) – Model whose parameters will be optimized.

  • cfg (Any) – Training config carrying optimizer hyperparameters.

Returns:

OptimizerHandles with optimizer and metadata.

Return type:

OptimizerHandles

Raises:

ImportError – If torch is unavailable.
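The parameter-group rule described for build_optimization_handles can be sketched in plain Python. The helper name split_decay_groups and the stand-in parameters are hypothetical, and skipping parameters with requires_grad=False is an assumption drawn from the phrase "trainable parameters". The name-matching rule itself comes from the description above.

```python
from types import SimpleNamespace

NO_DECAY_MARKERS = ("bias", "LayerNorm.weight")


def split_decay_groups(named_parameters, weight_decay):
    """Return AdamW-style parameter groups: a decay group and a no-decay group."""
    decay, no_decay = [], []
    for name, param in named_parameters:
        if not getattr(param, "requires_grad", True):
            continue  # frozen parameters are not optimized
        if any(marker in name for marker in NO_DECAY_MARKERS):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


def _param(tag, trainable=True):
    # Stand-in for a torch parameter; only requires_grad matters here.
    return SimpleNamespace(tag=tag, requires_grad=trainable)


named = [
    ("encoder.layer.0.attention.dense.weight", _param("w0")),
    ("encoder.layer.0.attention.dense.bias", _param("b0")),
    ("encoder.layer.0.LayerNorm.weight", _param("ln_w")),
    ("embeddings.word_embeddings.weight", _param("emb", trainable=False)),
]
groups = split_decay_groups(named, weight_decay=0.01)
decay_tags = [p.tag for p in groups[0]["params"]]
no_decay_tags = [p.tag for p in groups[1]["params"]]
```

With real torch parameters, a list shaped like groups can be passed as the first argument to torch.optim.AdamW together with lr, betas, and eps taken from the training config.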