# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training loop state helpers for controller and checkpoint management."""
from __future__ import annotations
import json
from collections.abc import Mapping
import logging
import os
import shutil
import sys
import inspect
from typing import Any, Callable, Dict, Optional, Tuple, Protocol, TYPE_CHECKING, cast
from types import SimpleNamespace
from .types import (
LoggingHandles,
OptimizationSchedule,
TrainingLoopState,
)
from .optim import detect_deepspeed_state
from .weighting.types import WeightingConfigLike
from .weighting.logic import (
CONTROLLER_STATE_FILENAME,
broadcast_controller_state,
load_controller_state,
)
LOG = logging.getLogger(__name__)
_checkpoint_log_once = {"config": False, "strategy": False, "steps": False}
if TYPE_CHECKING:
from maxent_grpo.config import GRPOConfig as GRPOConfigType
else:
GRPOConfigType = object
class ControllerPathsLike(Protocol):
"""Minimal controller path settings used by checkpoint helpers."""
state_path: Optional[str]
class AcceleratorLike(Protocol):
"""Subset of Accelerator API used by training state utilities."""
@property
def is_main_process(self) -> bool:
"""Return True when on the main process."""
raise NotImplementedError
def wait_for_everyone(self) -> None:
"""Synchronize all processes."""
raise NotImplementedError
def load_state(self, input_dir: str, **kwargs: object) -> object:
"""Load accelerator state from ``path``."""
raise NotImplementedError
def _is_safetensors_available() -> bool:
import importlib.util
return importlib.util.find_spec("safetensors") is not None
def _callable_accepts_kwargs(fn: object) -> bool:
if not callable(fn):
return True
try:
sig = inspect.signature(cast(Callable[..., object], fn))
except (TypeError, ValueError):
return True
return any(
param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()
)
def _callable_accepts_param(fn: object, name: str) -> bool:
if not callable(fn):
return True
try:
sig = inspect.signature(cast(Callable[..., object], fn))
except (TypeError, ValueError):
return True
if name in sig.parameters:
return True
return _callable_accepts_kwargs(fn)
def _checkpoint_has_hf_weights(checkpoint_dir: str) -> bool:
candidates = (
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
)
return any(
os.path.isfile(os.path.join(checkpoint_dir, name)) for name in candidates
)
def _safetensors_header_has_valid_tensors(path: str) -> bool:
"""Return True when a safetensors file declares non-empty tensors.
This is a lightweight validation that reads only the JSON header and checks that:
- at least one tensor entry exists
- no tensor has a zero-sized dimension
"""
try:
with open(path, "rb") as handle:
header_len_raw = handle.read(8)
if len(header_len_raw) != 8:
return False
header_len = int.from_bytes(header_len_raw, "little", signed=False)
if header_len <= 0 or header_len > 128 * 1024 * 1024:
return False
header = handle.read(header_len)
meta = json.loads(header.decode("utf-8"))
except (OSError, ValueError, UnicodeDecodeError):
return False
if not isinstance(meta, dict):
return False
saw_tensor = False
for key, value in meta.items():
if key == "__metadata__":
continue
if not isinstance(value, dict):
continue
shape = value.get("shape")
if not isinstance(shape, list):
continue
if not shape:
return False
dims: list[int] = []
for dim in shape:
if not isinstance(dim, int):
return False
dims.append(dim)
if any(dim <= 0 for dim in dims):
return False
saw_tensor = True
return saw_tensor
def _checkpoint_has_valid_hf_weights(checkpoint_dir: str) -> bool:
"""Return True when a checkpoint directory contains loadable, non-empty HF weights."""
safetensors_path = os.path.join(checkpoint_dir, "model.safetensors")
if os.path.isfile(safetensors_path):
return _safetensors_header_has_valid_tensors(safetensors_path)
torch_bin = os.path.join(checkpoint_dir, "pytorch_model.bin")
if os.path.isfile(torch_bin):
try:
return os.path.getsize(torch_bin) > 1_000_000
except OSError:
return False
# Sharded checkpoints: validate that the index exists and at least one shard file is present.
index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
if os.path.isfile(index_path):
try:
with open(index_path, "r", encoding="utf-8") as handle:
idx = json.load(handle) or {}
weight_map = idx.get("weight_map") or {}
if not isinstance(weight_map, dict) or not weight_map:
return False
shard_files = sorted({str(v) for v in weight_map.values() if v})
except (OSError, ValueError, UnicodeDecodeError):
return False
for shard in shard_files:
shard_path = os.path.join(checkpoint_dir, shard)
if os.path.isfile(shard_path):
try:
if os.path.getsize(shard_path) > 1_000_000:
return True
except OSError:
continue
return False
index_path = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
if os.path.isfile(index_path):
try:
with open(index_path, "r", encoding="utf-8") as handle:
idx = json.load(handle) or {}
weight_map = idx.get("weight_map") or {}
if not isinstance(weight_map, dict) or not weight_map:
return False
shard_files = sorted({str(v) for v in weight_map.values() if v})
except (OSError, ValueError, UnicodeDecodeError):
return False
for shard in shard_files:
shard_path = os.path.join(checkpoint_dir, shard)
if os.path.isfile(shard_path):
try:
if os.path.getsize(shard_path) > 1_000_000:
return True
except OSError:
continue
return False
return False
def _read_checkpoint_latest_tag(checkpoint_dir: str) -> Optional[str]:
"""Return the DeepSpeed/Accelerate checkpoint tag stored in ``latest`` if present."""
latest_path = os.path.join(checkpoint_dir, "latest")
if not os.path.isfile(latest_path):
return None
try:
with open(latest_path, "r", encoding="utf-8") as handle:
tag = (handle.read() or "").strip()
except OSError:
return None
if not tag:
return None
# Defensive: tags should be a single path component (avoid traversal).
if os.sep in tag or (os.altsep and os.altsep in tag):
return None
return tag
def _checkpoint_has_accelerate_state(checkpoint_dir: str) -> bool:
"""Return True when a checkpoint directory looks loadable via ``accelerator.load_state``.
Note: ``accelerator.load_state`` only supports checkpoints produced by
``accelerator.save_state`` (it expects specific filenames/structure, e.g.
``<checkpoint_dir>/pytorch_model/...`` for DeepSpeed). Hugging Face
``Trainer`` checkpoints (e.g. ``checkpoint-150/global_step150``) are *not*
compatible with ``accelerator.load_state`` even though they may contain a
DeepSpeed tag directory.
"""
# Accelerate (DeepSpeed) stores shards under a directory named after MODEL_NAME,
# e.g. <checkpoint_dir>/pytorch_model/{latest,global_step...}
# Accelerate (FSDP) uses a different model directory name.
for name in ("pytorch_model", "pytorch_model_fsdp"):
if os.path.isdir(os.path.join(checkpoint_dir, name)):
return True
try:
for entry in os.listdir(checkpoint_dir):
if entry.startswith("pytorch_model_") and os.path.isdir(
os.path.join(checkpoint_dir, entry)
):
return True
except OSError as exc:
LOG.debug("Failed to scan checkpoint dir for model shards: %s", exc)
# Accelerate's non-DeepSpeed checkpoints use `.bin` for optimizer/scheduler and
# `random_states_<rank>.pkl` (torch-saved payloads) for RNG tracking.
for name in (
"optimizer.bin",
"scheduler.bin",
"sampler.bin",
"dl_state_dict.bin",
"scaler.pt",
"random_states_0.pkl",
):
if os.path.isfile(os.path.join(checkpoint_dir, name)):
return True
try:
for entry in os.listdir(checkpoint_dir):
if entry.startswith("optimizer_") and entry.endswith(".bin"):
return True
if entry.startswith("scheduler_") and entry.endswith(".bin"):
return True
if entry.startswith("random_states_") and entry.endswith(".pkl"):
return True
except OSError as exc:
LOG.debug("Failed to scan checkpoint dir for optimizer/random state: %s", exc)
return False
def _checkpoint_has_deepspeed_engine_state(checkpoint_dir: str) -> bool:
"""Return True when a checkpoint looks like a DeepSpeed engine checkpoint.
This matches Hugging Face Trainer/TRL checkpoints that contain a ``latest`` file
pointing at a ``global_step*`` directory with ZeRO shards (for example
``zero_pp_rank_*_model_states.pt``).
"""
tag = _read_checkpoint_latest_tag(checkpoint_dir)
if not tag or not tag.startswith("global_step"):
return False
tag_dir = os.path.join(checkpoint_dir, tag)
if not os.path.isdir(tag_dir):
return False
try:
for name in os.listdir(tag_dir):
if name.endswith("_model_states.pt") or name.endswith("_optim_states.pt"):
return True
except OSError:
return False
return False
def _checkpoint_has_zero_shards(checkpoint_dir: str) -> bool:
"""Return True when a checkpoint directory contains ZeRO shard files."""
tag = _read_checkpoint_latest_tag(checkpoint_dir)
if not tag:
return False
tag_dir = os.path.join(checkpoint_dir, tag)
if not os.path.isdir(tag_dir):
return False
try:
for name in os.listdir(tag_dir):
if name.endswith("_model_states.pt") or name.endswith("_optim_states.pt"):
return True
except OSError:
return False
return False
def _maybe_convert_zero_checkpoint_to_hf(
checkpoint_dir: str,
*,
max_shard_size: str = "100GB",
) -> bool:
"""Attempt to convert ZeRO shards into a consolidated HF weight file."""
if not _checkpoint_has_zero_shards(checkpoint_dir):
return False
try:
from deepspeed.utils import zero_to_fp32 # type: ignore
except (
ImportError,
ModuleNotFoundError,
) as exc: # pragma: no cover - optional dependency
LOG.warning(
"DeepSpeed zero_to_fp32 unavailable; cannot consolidate ZeRO shards in %s: %s",
checkpoint_dir,
exc,
)
return False
safe_serialization = bool(_is_safetensors_available())
try:
zero_to_fp32.convert_zero_checkpoint_to_fp32_state_dict(
checkpoint_dir,
checkpoint_dir,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization,
)
except (OSError, RuntimeError, TypeError, ValueError) as exc:
LOG.warning(
"Failed to convert ZeRO shards into consolidated weights for %s: %s",
checkpoint_dir,
exc,
)
return False
return True
def _normalize_checkpoint_dir(path: str) -> str:
"""Promote DeepSpeed tag subfolders (e.g., ``global_step100``/``pytorch_model``) to their parent."""
if not isinstance(path, str) or not path:
return path
normalized = path.rstrip(os.sep)
if not os.path.isdir(normalized):
return path
base = os.path.basename(normalized)
if base == "pytorch_model" or base.startswith("global_step"):
parent = os.path.dirname(normalized)
parent_tag = _read_checkpoint_latest_tag(parent)
if parent_tag == base:
return parent
if os.path.basename(parent).startswith("checkpoint-"):
return parent
return path
def _state_dict_has_zero_sized_tensors(state_dict: Optional[Dict[str, object]]) -> bool:
if not isinstance(state_dict, dict) or not state_dict:
return True
try:
import torch
except (ImportError, ModuleNotFoundError): # pragma: no cover - optional runtime
return False
for value in state_dict.values():
try:
is_tensor = isinstance(value, torch.Tensor)
except TypeError:
is_tensor = False
if is_tensor:
try:
if cast(Any, value).numel() == 0:
return True
except (AttributeError, RuntimeError, TypeError):
continue
# If there are no tensors (e.g., stubbed accelerators in tests), treat the
# payload as non-indicative rather than invalid.
return False
def _remove_hf_weight_files(checkpoint_dir: str) -> None:
candidates = (
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
)
for name in candidates:
path = os.path.join(checkpoint_dir, name)
try:
if os.path.isfile(path):
os.remove(path)
except OSError:
continue
def _save_consolidated_hf_weights(
*,
model_to_save: object,
checkpoint_dir: str,
state_dict: Optional[Dict[str, object]],
max_shard_size: str = "100GB",
) -> None:
save_pretrained = getattr(model_to_save, "save_pretrained", None)
if not callable(save_pretrained):
raise TypeError("Model does not define save_pretrained()")
save_kwargs: Dict[str, object] = {}
if state_dict is not None:
save_kwargs["state_dict"] = state_dict
safetensors_ok = _is_safetensors_available()
save_kwargs["safe_serialization"] = bool(safetensors_ok)
save_kwargs["max_shard_size"] = max_shard_size
def _try_save(kwargs: Dict[str, object]) -> None:
save_pretrained(checkpoint_dir, **kwargs)
def _retry_without_kwargs(kwargs: Dict[str, object]) -> None:
retry_kwargs = dict(kwargs)
last_exc: Optional[TypeError] = None
for key in ("max_shard_size", "safe_serialization", "state_dict"):
retry_kwargs.pop(key, None)
try:
_try_save(retry_kwargs)
return
except TypeError as exc:
last_exc = exc
continue
if last_exc is not None:
raise last_exc
raise TypeError("save_pretrained rejected provided kwargs")
try:
_try_save(save_kwargs)
except TypeError:
_retry_without_kwargs(save_kwargs)
return
except (OSError, RuntimeError, ValueError):
if save_kwargs.get("safe_serialization") is True:
retry_kwargs = dict(save_kwargs)
retry_kwargs["safe_serialization"] = False
try:
_try_save(retry_kwargs)
return
except TypeError:
_retry_without_kwargs(retry_kwargs)
return
raise
def _parse_checkpoint_step(path: str) -> Optional[int]:
"""Return the numeric suffix from a ``checkpoint-<n>`` directory."""
tail = os.path.basename(path.rstrip(os.sep))
if tail.startswith("checkpoint-"):
try:
return int(tail.split("-")[-1])
except (TypeError, ValueError):
return None
return None
def _parse_save_total_limit(value: object) -> int:
"""Normalize ``save_total_limit`` configuration values."""
if value is None:
return 0
try:
limit = int(cast(Any, value))
except (TypeError, ValueError):
return 0
return max(limit, 0)
def _prune_old_checkpoints(output_dir: Optional[str], limit: int) -> None:
"""Delete checkpoints to respect ``save_total_limit``."""
if not output_dir or limit <= 0:
return
try:
entries: list[tuple[int, str]] = []
for name in os.listdir(output_dir):
if not name.startswith("checkpoint-"):
continue
path = os.path.join(output_dir, name)
if not os.path.isdir(path):
continue
step = _parse_checkpoint_step(name)
key = step if step is not None else -1
entries.append((key, name))
except OSError:
return
if len(entries) <= limit:
return
entries.sort(key=lambda pair: (pair[0], pair[1]))
to_remove = entries[: len(entries) - limit]
for _, name in to_remove:
path = os.path.join(output_dir, name)
try:
shutil.rmtree(path)
except OSError as exc:
LOG.warning("Failed to prune checkpoint %s: %s", path, exc)
def _get_last_checkpoint(output_dir: Optional[str]) -> Optional[str]:
"""Best-effort discovery of the latest checkpoint under ``output_dir``."""
if not output_dir or not os.path.isdir(output_dir):
return None
try:
from transformers.trainer_utils import get_last_checkpoint
except (ImportError, ModuleNotFoundError): # pragma: no cover - optional dependency
get_last_checkpoint = None
if callable(get_last_checkpoint):
try:
last = get_last_checkpoint(output_dir)
except (OSError, RuntimeError, ValueError): # pragma: no cover - defensive
last = None
if last:
return last
try:
entries = [
d
for d in os.listdir(output_dir)
if d.startswith("checkpoint-")
and os.path.isdir(os.path.join(output_dir, d))
]
except OSError:
return None
if not entries:
return None
entries.sort(key=lambda name: _parse_checkpoint_step(name) or -1)
return os.path.join(output_dir, entries[-1])
[docs]
def resolve_resume_checkpoint(
training_args: GRPOConfigType,
) -> Tuple[Optional[str], bool]:
"""Resolve the checkpoint path to resume from, if any.
:param training_args: Trainer configuration with resume flags and output_dir.
:type training_args: Any
:returns: Tuple of (checkpoint path or None, whether resume was requested).
:rtype: tuple[str | None, bool]
"""
resume_cfg = getattr(training_args, "resume_from_checkpoint", None)
init_path = getattr(training_args, "init_from_checkpoint", None)
output_dir = getattr(training_args, "output_dir", None)
requested = bool(resume_cfg) or bool(init_path)
def _validate(path: Optional[str]) -> Optional[str]:
if isinstance(path, str) and path:
candidate = _normalize_checkpoint_dir(path)
if os.path.isdir(candidate):
if _checkpoint_has_valid_hf_weights(candidate):
return candidate
if _checkpoint_has_accelerate_state(candidate):
if _checkpoint_has_hf_weights(candidate):
LOG.warning(
"Checkpoint %s contains HF weight files but they look invalid; "
"will attempt resume via accelerator state instead.",
candidate,
)
else:
LOG.info(
"Checkpoint %s does not contain consolidated HF weights; "
"will attempt resume via accelerator state.",
candidate,
)
return candidate
if _checkpoint_has_hf_weights(candidate):
LOG.warning(
"Checkpoint %s contains HF weight files but they look invalid "
"(e.g., zero-sized tensors); ignoring this checkpoint.",
candidate,
)
else:
LOG.warning(
"Checkpoint %s does not contain a loadable HF weight file "
"(expected model.safetensors or pytorch_model.bin) or accelerator state; "
"ignoring this checkpoint.",
candidate,
)
return None
if path:
LOG.warning(
"resume_from_checkpoint=%s was requested but the path does not exist; "
"starting from scratch.",
path,
)
return None
if isinstance(init_path, str) and init_path:
resolved = _validate(init_path)
if resolved:
return resolved, True
resolved = None
if isinstance(resume_cfg, str) and resume_cfg:
resolved = _validate(resume_cfg)
elif resume_cfg:
resolved = _validate(_get_last_checkpoint(output_dir))
if resolved is None:
LOG.warning(
"resume_from_checkpoint was requested but no checkpoint was found under %s; "
"starting from scratch.",
output_dir or "<unspecified>",
)
else:
resolved = None
return resolved, requested
def _write_trainer_state_json(
checkpoint_dir: str,
training_args: GRPOConfigType,
*,
global_step: Optional[int],
num_input_tokens_seen: Optional[float] = None,
base_state: Optional[Dict[str, object]] = None,
accelerator: Optional[object] = None,
) -> None:
"""Persist a minimal trainer_state.json so future resumes find the step."""
payload: Dict[str, object] = {
"global_step": int(global_step or 0),
"max_steps": getattr(training_args, "max_steps", None),
"num_train_epochs": getattr(training_args, "num_train_epochs", None),
"save_steps": getattr(training_args, "save_steps", None),
"logging_steps": getattr(training_args, "logging_steps", None),
"is_local_process_zero": True,
"is_world_process_zero": True,
"log_history": [],
}
if num_input_tokens_seen is not None:
payload["num_input_tokens_seen"] = float(num_input_tokens_seen)
if base_state:
for key in (
"best_model_checkpoint",
"best_metric",
"best_global_step",
"log_history",
):
if key in base_state:
payload[key] = base_state[key]
if accelerator is not None:
payload["is_local_process_zero"] = bool(
getattr(
accelerator,
"is_local_process_zero",
getattr(accelerator, "is_local_main_process", True),
)
)
payload["is_world_process_zero"] = bool(
getattr(
accelerator,
"is_world_process_zero",
getattr(accelerator, "is_main_process", True),
)
)
state_path = os.path.join(checkpoint_dir, "trainer_state.json")
try:
with open(state_path, "w", encoding="utf-8") as fh:
json.dump(payload, fh, indent=2)
except (
OSError,
TypeError,
ValueError,
) as exc: # pragma: no cover - filesystem errors
LOG.warning("Failed to write trainer_state.json to %s: %s", checkpoint_dir, exc)
[docs]
def build_checkpoint_saver(
training_args: GRPOConfigType,
runtime_handles: object,
optim_handles: object,
tokenizer: object,
*,
state_ref: Optional[Dict[str, object]] = None,
base_trainer_state: Optional[Dict[str, object]] = None,
controller_cfg: Optional[ControllerPathsLike] = None,
) -> Callable[[str], None]:
"""Return a save_checkpoint callable compatible with LoggingHandles.
The returned callable snapshots accelerator state, model/optimizer weights,
trainer state metadata, and optional controller state into a checkpoint
directory under ``output_dir``.
:param training_args: Training configuration containing output/checkpoint options.
:param runtime_handles: Runtime bundle providing model/accelerator references.
:param optim_handles: Optimizer bundle used for saving optimizer state.
:param tokenizer: Tokenizer to serialize alongside checkpoints.
:param state_ref: Mutable state dict used for cross-callback coordination.
:param base_trainer_state: Optional base trainer state JSON to merge into saves.
:param controller_cfg: Optional controller state paths for MaxEnt.
:returns: Callable ``save_checkpoint(name: str) -> None``.
:rtype: Callable[[str], None]
"""
output_dir = getattr(training_args, "output_dir", None)
accelerator = getattr(runtime_handles, "accelerator", None)
model = getattr(runtime_handles, "model", None)
optimizer = getattr(optim_handles, "optimizer", None)
save_total_limit = _parse_save_total_limit(
getattr(training_args, "save_total_limit", None)
)
state_ref = state_ref if isinstance(state_ref, dict) else {}
push_enabled = bool(
getattr(training_args, "push_to_hub", False)
or getattr(training_args, "push_to_hub_revision", False)
)
hub_strategy = str(getattr(training_args, "hub_strategy", "end") or "end").lower()
push_every_save = push_enabled and hub_strategy in {"every_save", "checkpoint"}
def _step_from_name(name: str) -> Optional[int]:
if not isinstance(name, str):
return None
return _parse_checkpoint_step(name)
def _save_checkpoint(checkpoint_name: str) -> None:
if not output_dir:
return
checkpoint_dir = os.path.join(output_dir, checkpoint_name)
try:
os.makedirs(checkpoint_dir, exist_ok=True)
except OSError as exc: # pragma: no cover - filesystem guard
LOG.warning(
"Failed to create checkpoint directory %s: %s", checkpoint_dir, exc
)
return
wait_for_all = getattr(accelerator, "wait_for_everyone", None)
if callable(wait_for_all):
wait_for_all()
save_state_fn = getattr(accelerator, "save_state", None)
if callable(save_state_fn):
try:
save_state_fn(checkpoint_dir)
except (
OSError,
RuntimeError,
TypeError,
ValueError,
) as exc: # pragma: no cover - accelerator dependent
LOG.warning(
"Failed to save accelerator state to %s: %s", checkpoint_dir, exc
)
state_dict: Optional[Dict[str, object]] = None
get_state_dict_fn = getattr(accelerator, "get_state_dict", None)
if callable(get_state_dict_fn) and model is not None:
# Needed for ZeRO-3/FSDP: gathers a full (saveable) state_dict.
# This may involve collective ops, so run it on all ranks.
ds_state = None
try:
ds_state = detect_deepspeed_state(accelerator)
except (AttributeError, RuntimeError, TypeError, ValueError):
ds_state = None
candidates = [model]
# Some Accelerate versions behave differently depending on whether the model is wrapped.
unwrap = getattr(accelerator, "unwrap_model", None)
if callable(unwrap):
try:
unwrapped = unwrap(model)
except (AttributeError, RuntimeError, TypeError, ValueError):
unwrapped = model
if unwrapped is not model:
if ds_state and (
ds_state.use_deepspeed or ds_state.zero_stage >= 2
):
LOG.warning(
"Skipping unwrapped model state_dict gather under DeepSpeed for %s; "
"use ZeRO shards or enable 16-bit gathering on save.",
checkpoint_dir,
)
else:
candidates.append(unwrapped)
for candidate in candidates:
try:
gathered = get_state_dict_fn(candidate)
except (
AttributeError,
OSError,
RuntimeError,
TypeError,
ValueError,
) as exc: # pragma: no cover - backend specific
LOG.warning(
"Failed to gather consolidated model state_dict for %s: %s",
checkpoint_dir,
exc,
)
continue
state_dict_candidate = (
gathered if isinstance(gathered, Mapping) else None
)
if state_dict_candidate is not None and not isinstance(
state_dict_candidate, dict
):
state_dict_candidate = dict(state_dict_candidate)
if _state_dict_has_zero_sized_tensors(state_dict_candidate):
if ds_state and (
ds_state.use_deepspeed or ds_state.zero_stage >= 2
):
LOG.warning(
"Accelerator returned an invalid consolidated state_dict for %s "
"(zero-sized tensors detected); skipping fallback under DeepSpeed.",
checkpoint_dir,
)
else:
LOG.warning(
"Accelerator returned an invalid consolidated state_dict for %s "
"(zero-sized tensors detected); trying fallback.",
checkpoint_dir,
)
continue
state_dict = state_dict_candidate
break
if getattr(accelerator, "is_main_process", True):
model_to_save = model
unwrap = getattr(accelerator, "unwrap_model", None)
if callable(unwrap):
try:
model_to_save = unwrap(model)
except (AttributeError, RuntimeError, TypeError, ValueError):
model_to_save = model
try:
_save_consolidated_hf_weights(
model_to_save=model_to_save,
checkpoint_dir=checkpoint_dir,
state_dict=state_dict,
)
except (
OSError,
RuntimeError,
TypeError,
ValueError,
) as exc: # pragma: no cover - model save guard
LOG.warning(
"Failed to save model weights to %s: %s", checkpoint_dir, exc
)
hf_has_weights = _checkpoint_has_hf_weights(checkpoint_dir)
hf_valid = hf_has_weights and _checkpoint_has_valid_hf_weights(
checkpoint_dir
)
if not hf_has_weights:
LOG.warning(
"Checkpoint %s does not contain a loadable HF weight file "
"(expected model.safetensors or pytorch_model.bin).",
checkpoint_dir,
)
elif not hf_valid:
LOG.error(
"Checkpoint %s contains invalid HF weight files (e.g., zero-sized tensors). "
"Removing the HF weight artifacts to avoid poisoning future resumes.",
checkpoint_dir,
)
_remove_hf_weight_files(checkpoint_dir)
hf_has_weights = False
hf_valid = False
if not hf_valid:
try:
ds_state = detect_deepspeed_state(accelerator)
except (AttributeError, RuntimeError, TypeError, ValueError):
ds_state = None
if ds_state and (ds_state.use_deepspeed or ds_state.zero_stage >= 2):
if _maybe_convert_zero_checkpoint_to_hf(checkpoint_dir):
hf_has_weights = _checkpoint_has_hf_weights(checkpoint_dir)
hf_valid = hf_has_weights and _checkpoint_has_valid_hf_weights(
checkpoint_dir
)
if hf_valid:
LOG.info(
"DeepSpeed ZeRO shards consolidated for checkpoint %s.",
checkpoint_dir,
)
else:
LOG.warning(
"DeepSpeed consolidation completed but HF weights are still missing/invalid in %s.",
checkpoint_dir,
)
save_tokenizer = getattr(tokenizer, "save_pretrained", None)
if callable(save_tokenizer):
try:
save_tokenizer(checkpoint_dir)
except (
AttributeError,
OSError,
RuntimeError,
TypeError,
ValueError,
) as exc: # pragma: no cover - tokenizer optional
LOG.warning(
"Failed to save tokenizer to %s: %s", checkpoint_dir, exc
)
else: # pragma: no cover - tokenizer optional
LOG.warning(
"Failed to save tokenizer to %s: %s",
checkpoint_dir,
"save_pretrained unavailable",
)
if optimizer is not None:
try:
import torch
state_dict_fn = getattr(optimizer, "state_dict", None)
if callable(state_dict_fn):
torch.save(
state_dict_fn(),
os.path.join(checkpoint_dir, "optimizer.pt"),
)
else:
LOG.warning(
"Optimizer state_dict unavailable; skipping optimizer.pt for %s",
checkpoint_dir,
)
except (
OSError,
RuntimeError,
TypeError,
ValueError,
) as exc: # pragma: no cover - optimizer optional
LOG.warning(
"Failed to save optimizer state to %s: %s",
checkpoint_dir,
exc,
)
if push_every_save:
try:
from maxent_grpo.core.hub import push_to_hub_revision
push_args = SimpleNamespace(
**getattr(training_args, "__dict__", {})
)
push_args.output_dir = checkpoint_dir
push_args.push_to_hub_revision = True
push_to_hub_revision(
cast(GRPOConfigType, push_args),
extra_ignore_patterns=[],
include_checkpoints=True,
)
except (
ImportError,
OSError,
RuntimeError,
TypeError,
ValueError,
) as exc: # pragma: no cover - optional hub deps
LOG.warning(
"Failed to push checkpoint %s to Hub: %s",
checkpoint_dir,
exc,
)
state_obj = state_ref.get("state")
global_step = (
int(getattr(state_obj, "global_step", 0))
if state_obj is not None
else (_step_from_name(checkpoint_name) or 0)
)
num_tokens = getattr(state_obj, "num_input_tokens_seen", None)
_write_trainer_state_json(
checkpoint_dir,
training_args,
global_step=global_step,
num_input_tokens_seen=num_tokens,
base_state=base_trainer_state,
accelerator=accelerator,
)
controller_state_src = (
getattr(controller_cfg, "state_path", None) if controller_cfg else None
)
if controller_state_src and os.path.isfile(controller_state_src):
dst_path = os.path.join(checkpoint_dir, CONTROLLER_STATE_FILENAME)
try:
shutil.copy2(controller_state_src, dst_path)
except OSError as exc:
LOG.warning(
"Failed to copy controller state to %s: %s",
dst_path,
exc,
)
if save_total_limit > 0:
_prune_old_checkpoints(output_dir, save_total_limit)
if callable(wait_for_all):
wait_for_all()
return _save_checkpoint
[docs]
def maybe_clear_stale_controller_state(
accelerator: AcceleratorLike, controller_cfg: ControllerPathsLike
) -> None:
"""Delete a stale controller state file when overwriting the output dir.
:param accelerator: Accelerate handle used to determine the main process
and trigger ``wait_for_everyone`` guards.
:type accelerator: AcceleratorLike
:param controller_cfg: Paths describing the active controller
checkpoint/restore locations.
:type controller_cfg: ControllerPathsLike
"""
resume_path = getattr(controller_cfg, "resume_from", None)
if resume_path:
return
if not getattr(controller_cfg, "overwrite_existing", False):
return
state_path = getattr(controller_cfg, "state_path", None)
if not state_path or not os.path.isfile(state_path):
return
if accelerator.is_main_process:
try:
os.remove(state_path)
LOG.info(
"Removed stale controller state at %s due to overwrite_output_dir.",
state_path,
)
except OSError as exc: # pragma: no cover - filesystem race
LOG.warning(
"Failed to remove stale controller state %s: %s", state_path, exc
)
wait_for_all = getattr(accelerator, "wait_for_everyone", None)
if callable(wait_for_all):
wait_for_all()
def _load_controller_file(
path: Optional[str],
_accelerator: Optional[AcceleratorLike],
weighting_cfg: WeightingConfigLike,
) -> bool:
"""Load controller parameters from ``path`` when available.
:param path: Filesystem path to a serialized controller state.
:type path: str | None
:param accelerator: Optional accelerator handle (unused, for signature parity/tests).
:type accelerator: AcceleratorLike | None
:param weighting_cfg: Mutable weighting configuration that will receive
the loaded parameters.
:type weighting_cfg: WeightingConfigLike
:returns: ``True`` when the controller state was loaded successfully.
:rtype: bool
"""
if not path:
return False
load_fn = globals().get("load_controller_state", load_controller_state)
success = False
if callable(load_fn):
try:
success = bool(load_fn(path, weighting_cfg))
except (OSError, RuntimeError, TypeError, ValueError) as exc:
LOG.warning("Failed to load controller state %s: %s", path, exc)
success = False
else:
success = bool(load_fn)
if success:
# Emit a simple success log for test visibility, then the detailed metrics
LOG.info("Loaded controller state from %s", path)
beta_val = getattr(weighting_cfg, "beta", None)
tau_val = getattr(weighting_cfg, "tau", None)
LOG.info(
"Loaded controller state from %s | beta=%s tau=%s",
path,
beta_val,
tau_val,
)
return success
[docs]
def load_controller_state_chain(
controller_cfg: ControllerPathsLike,
accelerator: AcceleratorLike,
weighting_cfg: WeightingConfigLike,
) -> bool:
"""Attempt to load controller state from resume directory or the current state.
:param controller_cfg: Filesystem paths for controller checkpoints.
:type controller_cfg: ControllerPathsLike
:param accelerator: Accelerate handle performing logging/synchronization.
:type accelerator: AcceleratorLike
:param weighting_cfg: Mutable weighting settings that receive the loaded parameters.
:type weighting_cfg: WeightingConfigLike
:returns: ``True`` when controller resume was requested or a controller
checkpoint was successfully loaded.
:rtype: bool
"""
maybe_clear_stale_controller_state(accelerator, controller_cfg)
resume_path = getattr(controller_cfg, "resume_from", None)
controller_loaded = False
resume_attempted = False
tried_paths: list[str] = []
if isinstance(resume_path, str) and resume_path:
resume_attempted = True
resume_state_file = os.path.join(resume_path, CONTROLLER_STATE_FILENAME)
tried_paths.append(resume_state_file)
controller_loaded = _load_controller_file(
resume_state_file, accelerator, weighting_cfg
)
if not controller_loaded and controller_cfg.state_path:
tried_paths.append(controller_cfg.state_path)
controller_loaded = _load_controller_file(
controller_cfg.state_path, accelerator, weighting_cfg
)
if not controller_loaded and tried_paths:
if resume_attempted:
LOG.warning(
"Controller resume was requested but no state was loaded | tried=%s",
", ".join(tried_paths),
)
else:
LOG.info(
"No controller state found; starting fresh | tried=%s",
", ".join(tried_paths),
)
broadcast_controller_state(accelerator, weighting_cfg)
return bool(controller_loaded or resume_attempted)
[docs]
def maybe_load_accelerator_state(
resume_state_path: Optional[str], accelerator: AcceleratorLike
) -> None:
"""Load an accelerator state directory when resuming if available.
:param resume_state_path: Filesystem path to an accelerator state directory
(e.g., saved by ``accelerator.save_state``).
:type resume_state_path: str | None
:param accelerator: Accelerate handle whose ``load_state`` method will be invoked.
:type accelerator: AcceleratorLike
:returns: ``None``.
"""
load_state_fn = getattr(accelerator, "load_state", None)
if isinstance(resume_state_path, str) and resume_state_path:
resume_state_path = _normalize_checkpoint_dir(resume_state_path)
if (
isinstance(resume_state_path, str)
and resume_state_path
and os.path.isdir(resume_state_path)
and callable(load_state_fn)
):
if not _checkpoint_has_accelerate_state(resume_state_path):
if _checkpoint_has_deepspeed_engine_state(resume_state_path):
loaded = False
for model in getattr(accelerator, "_models", []) or []:
load_checkpoint = getattr(model, "load_checkpoint", None)
if not callable(load_checkpoint):
continue
# Custom loops may use a different optimizer parameter-group
# layout than the original Trainer run; default to loading
# weights only (no optimizer/scheduler state).
try:
load_checkpoint(
resume_state_path,
tag=None,
load_optimizer_states=False,
load_lr_scheduler_states=False,
)
except TypeError:
load_checkpoint(resume_state_path)
loaded = True
if loaded:
accelerator.wait_for_everyone()
LOG.info(
"Loaded DeepSpeed checkpoint state from %s", resume_state_path
)
return
try:
load_state_fn(resume_state_path)
accelerator.wait_for_everyone()
LOG.info("Loaded accelerator state from %s", resume_state_path)
except (
AssertionError,
OSError,
RuntimeError,
ValueError,
) as exc: # pragma: no cover - environment dependent
LOG.warning(
"Failed to load accelerator state from %s: %s", resume_state_path, exc
)
[docs]
def maybe_checkpoint(
logging_cfg: LoggingHandles, accelerator: AcceleratorLike, global_step: int
) -> None:
"""Checkpoint periodically while on the main process.
:param logging_cfg: Logging handles containing checkpoint callbacks and
scheduling knobs.
:type logging_cfg: ~maxent_grpo.training.types.logging.LoggingHandles
:param accelerator: Accelerate handle used for synchronization and
main-process checks.
:type accelerator: AcceleratorLike
:param global_step: Current optimizer step; used to decide whether
``save_steps`` divides the step index evenly.
:type global_step: int
:returns: ``None``.
"""
if accelerator.is_main_process and not _checkpoint_log_once["config"]:
LOG.info(
"Checkpoint guard | strategy=%s | save_steps=%s",
getattr(logging_cfg, "save_strategy", None),
getattr(logging_cfg, "save_steps", None),
)
_checkpoint_log_once["config"] = True
strategy = str(getattr(logging_cfg, "save_strategy", "")).lower()
for prefix in ("savestrategy.", "intervalstrategy."):
if strategy.startswith(prefix):
strategy = strategy.split(".", 1)[1]
should_save = (
strategy == "steps"
and logging_cfg.save_steps > 0
and global_step % logging_cfg.save_steps == 0
)
if accelerator.is_main_process:
if strategy != "steps":
if not _checkpoint_log_once["strategy"]:
LOG.info(
"Skipping checkpoint: save_strategy=%s (global_step=%s)",
strategy,
global_step,
)
_checkpoint_log_once["strategy"] = True
elif logging_cfg.save_steps <= 0:
if not _checkpoint_log_once["steps"]:
LOG.info(
"Skipping checkpoint: save_steps<=0 (save_steps=%s | global_step=%s)",
logging_cfg.save_steps,
global_step,
)
_checkpoint_log_once["steps"] = True
wait_for_all = getattr(accelerator, "wait_for_everyone", None)
if callable(wait_for_all):
wait_for_all()
if should_save:
if accelerator.is_main_process:
LOG.info(
"Triggering checkpoint save at step %s (save_steps=%s)",
global_step,
logging_cfg.save_steps,
)
logging_cfg.save_checkpoint(f"checkpoint-{global_step}")
if callable(wait_for_all):
wait_for_all()
if callable(wait_for_all):
wait_for_all()
[docs]
def check_stop_condition(
schedule: OptimizationSchedule, loop_state: "TrainingLoopState"
) -> None:
"""Set stop flag when the configured number of steps is reached.
:param schedule: Optimization schedule describing ``total_training_steps``.
:type schedule: training.types.OptimizationSchedule
:param loop_state: Mutable training loop state whose ``stop_training`` flag
should be updated when the threshold is crossed.
:type loop_state: training.types.TrainingLoopState
:returns: ``None``.
"""
if (
schedule.total_training_steps > 0
and loop_state.global_step >= schedule.total_training_steps
):
loop_state.stop_training = True
[docs]
def build_training_state(training_args: GRPOConfigType) -> LoggingHandles:
"""Construct minimal logging handles for the custom runner.
:param training_args: Training configuration providing save strategy/steps.
:returns: ``LoggingHandles`` instance with a no-op checkpoint saver.
:rtype: LoggingHandles
"""
class _NoopWriter:
def __init__(self) -> None:
self.logged = []
def log(self, metrics: Dict[str, object], step: int) -> None:
self.logged.append((step, metrics))
def flush(self) -> None:
return
writer = _NoopWriter()
return LoggingHandles(
metric_writer=writer,
save_checkpoint=lambda *_a, **_k: None,
save_strategy=getattr(training_args, "save_strategy", "no"),
save_steps=int(getattr(training_args, "save_steps", 0) or 0),
wandb_run=None,
)
__all__ = [
"maybe_clear_stale_controller_state",
"load_controller_state_chain",
"resolve_resume_checkpoint",
"load_trainer_state_metadata",
"maybe_load_accelerator_state",
"maybe_checkpoint",
"check_stop_condition",
"build_checkpoint_saver",
"build_training_state",
]
# Preserve a self-reference so monkeypatch paths like ``training.state.state`` resolve
# even after test shuffling or aliasing.
state = sys.modules[__name__]