# Source code for maxent_grpo.training.runtime.deepspeed
"""DeepSpeed and Accelerate integration helpers."""
from __future__ import annotations
import logging
import os
from typing import Any, Dict, Optional, Tuple
from maxent_grpo.utils.imports import optional_import as _optional_dependency
from maxent_grpo.utils.imports import require_dependency as _require_dependency
try: # Optional dependency for reading accelerate config files
import yaml
except (ImportError, ModuleNotFoundError): # pragma: no cover - optional
yaml = None
LOG = logging.getLogger(__name__)
# [docs]
def require_deepspeed(context: str, module: str = "deepspeed") -> Any:
    """Import and return a DeepSpeed module, failing with a RuntimeError.

    Args:
        context: Short description of the feature that needs DeepSpeed; it is
            interpolated into the error message.
        module: Dotted name of the module to import (``"deepspeed"`` by
            default).

    Returns:
        Whatever ``_require_dependency`` returns for ``module``.

    Raises:
        RuntimeError: If ``_require_dependency`` raises ``ImportError``; the
            original error is chained as the cause.
    """
    install_hint = (
        f"DeepSpeed is required for MaxEnt-GRPO {context}. "
        "Install it with `pip install deepspeed`."
    )
    try:
        imported = _require_dependency(module, install_hint)
    except ImportError as err:  # pragma: no cover - import guard
        # Surface a single, contextual error type to callers.
        raise RuntimeError(install_hint) from err
    else:
        return imported
# [docs]
def get_trl_prepare_deepspeed() -> Optional[Any]:
    """Locate TRL's ``prepare_deepspeed`` helper, if one can be imported.

    The candidate modules are probed in order (``trl.models.utils`` first,
    then ``trl.trainer.utils``) and the first callable attribute named
    ``prepare_deepspeed`` wins.

    Returns:
        The callable helper, or ``None`` when no candidate module provides one.
    """
    candidates = ("trl.models.utils", "trl.trainer.utils")
    for dotted_path in candidates:
        module = _optional_dependency(dotted_path)
        # A missing module simply yields no helper for this candidate.
        helper = (
            None if module is None else getattr(module, "prepare_deepspeed", None)
        )
        if callable(helper):
            return helper
    return None
def _load_accelerate_deepspeed_section(cfg_path: Optional[str]) -> Dict[str, Any]:
    """Return the ``deepspeed_config`` mapping from an Accelerate YAML file.

    This is a best-effort read: a missing path, a missing PyYAML install, an
    unreadable or unparsable file, or an unexpected document shape all yield
    an empty dict. Failures after the file is opened are logged as warnings so
    a silently ignored config is still visible in the logs.

    Args:
        cfg_path: Path to the Accelerate config file, or ``None``/empty.

    Returns:
        The ``deepspeed_config`` mapping, or ``{}`` when it is unavailable.
    """
    # ``not cfg_path`` is checked first so the optional ``yaml`` module is only
    # consulted when there is actually a file to read.
    if not cfg_path or yaml is None or not os.path.isfile(cfg_path):
        return {}

    handled: Tuple[type[BaseException], ...] = (OSError, ValueError)
    # Look the parser error up defensively in case the yaml module in use
    # does not expose a usable ``YAMLError`` class.
    yaml_error = getattr(yaml, "YAMLError", None)
    if isinstance(yaml_error, type) and issubclass(yaml_error, BaseException):
        handled = handled + (yaml_error,)

    log = logging.getLogger(__name__)
    try:
        with open(cfg_path, "r", encoding="utf-8") as cfg_file:
            raw = yaml.safe_load(cfg_file)
    except handled as exc:
        log.warning("Ignoring unreadable Accelerate config %s: %s", cfg_path, exc)
        return {}

    if not raw:
        return {}
    if not isinstance(raw, dict):
        # e.g. a YAML list or scalar at the document root.
        log.warning(
            "Ignoring Accelerate config %s: expected a mapping at the top level, got %s.",
            cfg_path,
            type(raw).__name__,
        )
        return {}

    section = raw.get("deepspeed_config")
    if not section:
        return {}
    if not isinstance(section, dict):
        log.warning(
            "Ignoring `deepspeed_config` in %s: expected a mapping, got %s.",
            cfg_path,
            type(section).__name__,
        )
        return {}
    return section


def _maybe_create_deepspeed_plugin() -> Optional[Any]:
    """Construct a DeepSpeedPlugin from Accelerate env/config when available.

    A plugin is only built when the ``ACCELERATE_USE_DEEPSPEED`` environment
    variable equals ``"true"`` (case-insensitive). Plugin arguments come from
    the ``deepspeed_config`` mapping of the YAML file named by
    ``ACCELERATE_CONFIG_FILE``; when that mapping is unavailable only the
    default ZeRO stage (3) is passed.

    Returns:
        An instance of ``accelerate.utils.DeepSpeedPlugin``, or ``None`` when
        DeepSpeed is not enabled or every plugin argument resolved to ``None``.

    Raises:
        ImportError: If ``accelerate.utils`` does not expose
            ``DeepSpeedPlugin``.
        ValueError: If the configured ``zero_stage`` cannot be converted to an
            integer.
    """
    if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() != "true":
        return None

    ds_module = _require_dependency(
        "accelerate.utils",
        (
            "DeepSpeed integration requires the Accelerate package. "
            "Install it via `pip install accelerate[deepspeed]`."
        ),
    )
    try:
        ds_class = getattr(ds_module, "DeepSpeedPlugin")
    except AttributeError as exc:  # pragma: no cover - defensive guard
        raise ImportError(
            "accelerate.utils does not expose DeepSpeedPlugin; update Accelerate."
        ) from exc

    cfg_path = os.environ.get("ACCELERATE_CONFIG_FILE")
    ds_cfg = _load_accelerate_deepspeed_section(cfg_path)

    # An explicit ``zero_stage: null`` in the config leaves the stage unset;
    # a missing key falls back to stage 3.
    zero_stage_raw = ds_cfg.get("zero_stage", 3)
    if zero_stage_raw is None:
        zero_stage = None
    else:
        try:
            zero_stage = int(zero_stage_raw)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"Invalid DeepSpeed zero_stage {zero_stage_raw!r} in {cfg_path}; "
                "expected an integer."
            ) from exc

    kwargs = {
        "zero_stage": zero_stage,
        "offload_param_device": ds_cfg.get("offload_param_device"),
        "offload_optimizer_device": ds_cfg.get("offload_optimizer_device"),
        "zero3_init_flag": ds_cfg.get("zero3_init_flag"),
        "zero3_save_16bit_model": ds_cfg.get("zero3_save_16bit_model"),
    }
    # Only forward values that were actually resolved, so the plugin's own
    # defaults apply to everything else.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if not kwargs:
        return None
    return ds_class(**kwargs)
# Names exported by a star-import of this module. The underscore-prefixed
# ``_maybe_create_deepspeed_plugin`` is listed explicitly, so it is exported
# alongside the two public helpers.
__all__ = [
"_maybe_create_deepspeed_plugin",
"get_trl_prepare_deepspeed",
"require_deepspeed",
]