maxent_grpo.training.state

Training loop state helpers for controller and checkpoint management.

Functions

_callable_accepts_kwargs(fn)

_callable_accepts_param(fn, name)

_checkpoint_has_accelerate_state(checkpoint_dir)

Return True when a checkpoint directory looks loadable via accelerator.load_state.

_checkpoint_has_deepspeed_engine_state(...)

Return True when a checkpoint looks like a DeepSpeed engine checkpoint.

_checkpoint_has_hf_weights(checkpoint_dir)

_checkpoint_has_valid_hf_weights(checkpoint_dir)

Return True when a checkpoint directory contains loadable, non-empty HF weights.

_checkpoint_has_zero_shards(checkpoint_dir)

Return True when a checkpoint directory contains ZeRO shard files.

_get_last_checkpoint(output_dir)

Best-effort discovery of the latest checkpoint under output_dir.

_is_safetensors_available()

_load_controller_file(path, _accelerator, ...)

Load controller parameters from path when available.

_maybe_convert_zero_checkpoint_to_hf(...[, ...])

Attempt to convert ZeRO shards into a consolidated HF weight file.

_normalize_checkpoint_dir(path)

Promote DeepSpeed tag subfolders (e.g., global_step100/pytorch_model) to their parent.

_parse_checkpoint_step(path)

Return the numeric suffix from a checkpoint-<n> directory.

_parse_save_total_limit(value)

Normalize save_total_limit configuration values.

_prune_old_checkpoints(output_dir, limit)

Delete checkpoints to respect save_total_limit.

_read_checkpoint_latest_tag(checkpoint_dir)

Return the DeepSpeed/Accelerate checkpoint tag stored in the latest file, if present.

_remove_hf_weight_files(checkpoint_dir)

_safetensors_header_has_valid_tensors(path)

Return True when a safetensors file declares non-empty tensors.

_save_consolidated_hf_weights(*, ...[, ...])

_state_dict_has_zero_sized_tensors(state_dict)

_write_trainer_state_json(checkpoint_dir, ...)

Persist a minimal trainer_state.json so future resumes find the step.

build_checkpoint_saver(training_args, ...[, ...])

Return a save_checkpoint callable compatible with LoggingHandles.

build_training_state(training_args)

Construct minimal logging handles for the custom runner.

check_stop_condition(schedule, loop_state)

Set the stop flag when the configured number of steps is reached.

load_controller_state_chain(controller_cfg, ...)

Attempt to load controller state from the resume directory or the current state location.

load_trainer_state_metadata(checkpoint_path)

Load trainer_state.json if available for resume bookkeeping.

maybe_checkpoint(logging_cfg, accelerator, ...)

Checkpoint periodically while on the main process.

maybe_clear_stale_controller_state(...)

Delete a stale controller state file when overwriting the output dir.

maybe_load_accelerator_state(...)

When resuming, load the accelerator state directory if one is available.

resolve_resume_checkpoint(training_args)

Resolve the checkpoint path to resume from, if any.
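
The weight-validation helpers listed above are private, so their exact behavior is not documented here. As an illustration of what a check such as _safetensors_header_has_valid_tensors might involve, the following is a minimal sketch that relies only on the public safetensors file layout (an 8-byte little-endian header length followed by a JSON header). The function name and the rule "every declared tensor must be non-empty" are assumptions made for this example, not the module's implementation.

```python
import json
import struct


def safetensors_header_has_valid_tensors(path):
    """Hypothetical stand-in: True when the header lists tensors and none is empty."""
    try:
        with open(path, "rb") as fh:
            prefix = fh.read(8)
            if len(prefix) != 8:
                return False
            # The first 8 bytes hold the JSON header size as a little-endian u64.
            (header_len,) = struct.unpack("<Q", prefix)
            header = json.loads(fh.read(header_len))
    except (OSError, ValueError):
        return False
    if not isinstance(header, dict):
        return False
    # "__metadata__" is an optional free-form entry, not a tensor.
    tensors = {k: v for k, v in header.items() if k != "__metadata__"}
    if not tensors:
        return False
    try:
        for info in tensors.values():
            start, end = info["data_offsets"]
            # Assumed rule: a zero dimension or an empty byte range means "empty".
            if end <= start or 0 in info["shape"]:
                return False
    except (KeyError, TypeError, ValueError):
        return False
    return True
```

Only the header is read, so a check of this kind stays cheap even for multi-gigabyte weight files.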

Classes

AcceleratorLike(*args, **kwargs)

Subset of Accelerator API used by training state utilities.

ControllerPathsLike(*args, **kwargs)

Minimal controller path settings used by checkpoint helpers.
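
The (*args, **kwargs) signatures suggest that both classes are structural types rather than concrete classes. The sketch below shows how such a protocol could be declared and satisfied by a test double. The member list (is_main_process, wait_for_everyone, load_state) is inferred from the parameter descriptions on this page, and the names AcceleratorLikeSketch and FakeAccelerator are hypothetical.

```python
from typing import List, Optional, Protocol, runtime_checkable


@runtime_checkable
class AcceleratorLikeSketch(Protocol):
    """Assumed subset of the Accelerate API; the real protocol may differ."""

    is_main_process: bool

    def wait_for_everyone(self) -> None: ...

    def load_state(self, input_dir: Optional[str] = None) -> None: ...


class FakeAccelerator:
    """Single-process stand-in that records load_state calls for tests."""

    is_main_process = True

    def __init__(self) -> None:
        self.loaded: List[Optional[str]] = []

    def wait_for_everyone(self) -> None:
        pass  # Nothing to synchronize in a single process.

    def load_state(self, input_dir: Optional[str] = None) -> None:
        self.loaded.append(input_dir)


# Structural typing: no inheritance is needed for the check to pass.
assert isinstance(FakeAccelerator(), AcceleratorLikeSketch)
```

Typing the helpers against a protocol like this lets them be exercised without a real distributed launch.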

maxent_grpo.training.state.maybe_clear_stale_controller_state(accelerator, controller_cfg)

Delete a stale controller state file when overwriting the output dir.

Parameters:
  • accelerator (AcceleratorLike) – Accelerate handle used to determine the main process and trigger wait_for_everyone guards.

  • controller_cfg (ControllerPathsLike) – Paths describing the active controller checkpoint/restore locations.

Return type:

None

maxent_grpo.training.state.load_controller_state_chain(controller_cfg, accelerator, weighting_cfg)

Attempt to load controller state from the resume directory or the current state location.

Parameters:
  • controller_cfg (ControllerPathsLike) – Filesystem paths for controller checkpoints.

  • accelerator (AcceleratorLike) – Accelerate handle performing logging/synchronization.

  • weighting_cfg (WeightingConfigLike) – Mutable weighting settings that receive the loaded parameters.

Returns:

True when controller resume was requested or a controller checkpoint was successfully loaded.

Return type:

bool

maxent_grpo.training.state.resolve_resume_checkpoint(training_args)

Resolve the checkpoint path to resume from, if any.

Parameters:

training_args (Any) – Trainer configuration with resume flags and output_dir.

Returns:

Tuple of (checkpoint path or None, whether resume was requested).

Return type:

tuple[str | None, bool]
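
To make the (path, requested) return shape concrete, here is a minimal sketch of how resolution could work when combined with checkpoint-<n> discovery. It assumes the configuration exposes a resume_from_checkpoint attribute that is either a path or a truthy flag. That attribute name and all function names below are illustrative and not taken from the module.

```python
import os
import re
from types import SimpleNamespace

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def parse_checkpoint_step(path):
    """Return n for a .../checkpoint-<n> directory name, else None."""
    match = _CHECKPOINT_RE.match(os.path.basename(os.path.normpath(path)))
    return int(match.group(1)) if match else None


def get_last_checkpoint(output_dir):
    """Pick the checkpoint-<n> subdirectory with the largest numeric step."""
    if not output_dir or not os.path.isdir(output_dir):
        return None
    candidates = []
    for name in os.listdir(output_dir):
        full = os.path.join(output_dir, name)
        step = parse_checkpoint_step(full)
        if step is not None and os.path.isdir(full):
            candidates.append((step, full))
    # Compare numerically so checkpoint-100 beats checkpoint-20.
    return max(candidates)[1] if candidates else None


def resolve_resume_checkpoint_sketch(training_args):
    """Return (checkpoint path or None, whether resume was requested)."""
    requested = getattr(training_args, "resume_from_checkpoint", None)
    if not requested:
        return None, False
    if isinstance(requested, str):
        return requested, True  # Explicit path: use it as given.
    # Truthy flag without a path: fall back to discovery under output_dir.
    return get_last_checkpoint(getattr(training_args, "output_dir", None)), True


args = SimpleNamespace(output_dir="runs/demo", resume_from_checkpoint=None)
print(resolve_resume_checkpoint_sketch(args))  # (None, False)
```

Returning the flag separately lets a caller distinguish "resume not requested" from "resume requested but nothing found", which are both represented by a None path.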

maxent_grpo.training.state.load_trainer_state_metadata(checkpoint_path)

Load trainer_state.json if available for resume bookkeeping.

Parameters:

checkpoint_path (str | None) – Path to a checkpoint directory.

Returns:

Parsed metadata fields (global_step, best metrics, etc.).

Return type:

dict[str, Any]
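
A tolerant loader of this kind can be sketched in a few lines. The version below assumes that a missing or unreadable trainer_state.json yields an empty dict and that the whole JSON object is returned unfiltered. Both choices are assumptions for illustration, and the function name is hypothetical.

```python
import json
import os


def load_trainer_state_metadata_sketch(checkpoint_path):
    """Return the parsed trainer_state.json, or {} when it cannot be read."""
    if not checkpoint_path:
        return {}
    state_file = os.path.join(checkpoint_path, "trainer_state.json")
    try:
        with open(state_file, "r", encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError):
        # Missing file or malformed JSON: resume bookkeeping starts empty.
        return {}
    return data if isinstance(data, dict) else {}
```

A caller could then read the step with metadata.get("global_step", 0) and treat an absent key as a fresh start.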

maxent_grpo.training.state.maybe_load_accelerator_state(resume_state_path, accelerator)

When resuming, load the accelerator state directory if one is available.

Parameters:
  • resume_state_path (str | None) – Filesystem path to an accelerator state directory (e.g., saved by accelerator.save_state).

  • accelerator (AcceleratorLike) – Accelerate handle whose load_state method will be invoked.

Returns:

None.

Return type:

None

maxent_grpo.training.state.maybe_checkpoint(logging_cfg, accelerator, global_step)

Checkpoint periodically while on the main process.

Parameters:
  • logging_cfg (LoggingHandles) – Logging handles containing checkpoint callbacks and scheduling knobs.

  • accelerator (AcceleratorLike) – Accelerate handle used for synchronization and main-process checks.

  • global_step (int) – Current optimizer step; used to decide whether save_steps divides the step index evenly.

Returns:

None.

Return type:

None
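
The scheduling rule described for global_step can be illustrated with a small stand-alone sketch. The field names save_steps and save_checkpoint on the handles object, the checkpoint-<n> naming, and the placement of the synchronization barrier are assumptions. Only the divisibility test and the main-process guard come from the description above.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class LoggingHandlesSketch:
    """Hypothetical stand-in for LoggingHandles."""

    save_checkpoint: Callable[[str], None]
    save_steps: int = 0


@dataclass
class SingleProcessAccelerator:
    """Minimal accelerator double with the two members the sketch uses."""

    is_main_process: bool = True

    def wait_for_everyone(self) -> None:
        pass


def maybe_checkpoint_sketch(logging_cfg, accelerator, global_step):
    steps = logging_cfg.save_steps
    # Save only when save_steps is enabled and divides the step evenly.
    if steps <= 0 or global_step <= 0 or global_step % steps != 0:
        return
    if accelerator.is_main_process:
        logging_cfg.save_checkpoint(f"checkpoint-{global_step}")
    # Assumed: other ranks wait until the main process finishes writing.
    accelerator.wait_for_everyone()
```

With save_steps=4, calling the function for steps 1 through 10 triggers saves at steps 4 and 8 only.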

maxent_grpo.training.state.check_stop_condition(schedule, loop_state)

Set the stop flag when the configured number of steps is reached.

Parameters:
  • schedule (training.types.OptimizationSchedule) – Optimization schedule describing total_training_steps.

  • loop_state (training.types.TrainingLoopState) – Mutable training loop state whose stop_training flag should be updated when the threshold is crossed.

Returns:

None.

Return type:

None
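
As a rough illustration of the in-place update, the sketch below models the two arguments as small dataclasses. The fields total_training_steps and stop_training are named in the parameter descriptions. The global_step field and the class names are assumptions for this example.

```python
from dataclasses import dataclass


@dataclass
class OptimizationScheduleSketch:
    total_training_steps: int


@dataclass
class TrainingLoopStateSketch:
    global_step: int = 0  # Assumed field name for the current step.
    stop_training: bool = False


def check_stop_condition_sketch(schedule, loop_state):
    """Mutate loop_state in place; nothing is returned."""
    total = schedule.total_training_steps
    # A non-positive total is treated here as "no step limit".
    if total > 0 and loop_state.global_step >= total:
        loop_state.stop_training = True


schedule = OptimizationScheduleSketch(total_training_steps=3)
state = TrainingLoopStateSketch()
while not state.stop_training:
    state.global_step += 1  # One optimizer step.
    check_stop_condition_sketch(schedule, state)
```

Because the flag lives on the mutable loop state, the training loop only has to test state.stop_training at the top of each iteration.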

maxent_grpo.training.state.build_checkpoint_saver(training_args, runtime_handles, optim_handles, tokenizer, *, state_ref=None, base_trainer_state=None, controller_cfg=None)

Return a save_checkpoint callable compatible with LoggingHandles.

The returned callable snapshots accelerator state, model weights, optimizer state, trainer state metadata, and optional controller state into a checkpoint directory under output_dir.

Parameters:
  • training_args (object) – Training configuration containing output/checkpoint options.

  • runtime_handles (object) – Runtime bundle providing model/accelerator references.

  • optim_handles (object) – Optimizer bundle used for saving optimizer state.

  • tokenizer (object) – Tokenizer to serialize alongside checkpoints.

  • state_ref (Dict[str, object] | None) – Mutable state dict used for cross-callback coordination.

  • base_trainer_state (Dict[str, object] | None) – Optional base trainer state JSON to merge into saves.

  • controller_cfg (ControllerPathsLike | None) – Optional controller state paths for MaxEnt.

Returns:

Callable save_checkpoint(name: str) -> None.

Return type:

Callable[[str], None]
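
The factory-plus-closure shape can be sketched as follows. This reduced version writes only a trainer_state.json and omits model, optimizer, tokenizer, accelerator, and controller state entirely. The simplified signature and the use of state_ref["global_step"] are assumptions for illustration and do not reproduce the module's implementation.

```python
import json
import os
from typing import Callable, Dict, Optional


def build_checkpoint_saver_sketch(
    output_dir: str,
    *,
    state_ref: Optional[Dict[str, object]] = None,
    base_trainer_state: Optional[Dict[str, object]] = None,
) -> Callable[[str], None]:
    """Return a save_checkpoint(name) callable bound to output_dir."""
    shared = state_ref if state_ref is not None else {}

    def save_checkpoint(name: str) -> None:
        checkpoint_dir = os.path.join(output_dir, name)
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Merge the optional base state, then stamp the current step.
        trainer_state = dict(base_trainer_state or {})
        trainer_state["global_step"] = int(shared.get("global_step", 0))
        state_file = os.path.join(checkpoint_dir, "trainer_state.json")
        with open(state_file, "w", encoding="utf-8") as fh:
            json.dump(trainer_state, fh, indent=2)

    return save_checkpoint
```

The closure reads state_ref at call time rather than at construction time, so a training loop that updates the shared dict before each save gets the current step written without rebuilding the saver.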

maxent_grpo.training.state.build_training_state(training_args)

Construct minimal logging handles for the custom runner.

Parameters:

training_args (object) – Training configuration providing save strategy/steps.

Returns:

LoggingHandles instance with a no-op checkpoint saver.

Return type:

LoggingHandles