"""Dependency loading utilities used by the training runtime."""
from __future__ import annotations
# pylint: disable=broad-exception-caught
import logging
import os
import sys
from importlib.machinery import ModuleSpec
from types import ModuleType, SimpleNamespace
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
from maxent_grpo.utils.imports import (
cached_import as _import_module,
optional_import as _optional_dependency,
require_dependency as _require_dependency,
)
LOG = logging.getLogger(__name__)
if TYPE_CHECKING:
    # Only static type checkers take this branch, so accelerate is never
    # imported here at runtime.
    from accelerate import Accelerator
    from accelerate.utils import DeepSpeedPlugin
else:  # pragma: no cover - typing fallbacks
    # At runtime both names are plain aliases of ``Any``.
    Accelerator = Any
    DeepSpeedPlugin = Any

try:  # Optional dependency for reading accelerate config files
    import yaml
except (ImportError, ModuleNotFoundError):  # pragma: no cover - optional
    # ``yaml`` is left as None when the import fails; code that parses the
    # Accelerate config file checks for this before using it.
    yaml = None
def _allow_stubbed_deps() -> bool:
    """Tell whether lightweight dependency stubs may be installed.

    Stubs are permitted when pytest has been imported into this process or
    when the ``PYTEST_CURRENT_TEST`` environment variable is set.
    """
    if "pytest" in sys.modules:
        return True
    return "PYTEST_CURRENT_TEST" in os.environ
def _install_torch_stub(hint: str) -> ModuleType:
    """Install a lightweight ``torch`` stub so imports can succeed in tests.

    The stub is stored in ``sys.modules["torch"]`` and entries are ensured for
    ``torch.nn``, ``torch.nn.functional``, ``torch.optim``, ``torch.utils`` and
    ``torch.utils.data`` (existing entries are reused, missing attributes are
    filled in).  Tensor helpers are backed by NumPy when it can be imported and
    degrade to plain Python containers otherwise.

    Args:
        hint: Message carried by the ``RuntimeError`` that the placeholder
            ``DataLoader`` raises when it is instantiated.

    Returns:
        The stub module.  If a stub carrying the ``__MAXENT_STUB__`` marker is
        already registered under ``torch`` it is returned unchanged.
    """
    existing = sys.modules.get("torch")
    if isinstance(existing, ModuleType) and getattr(existing, "__MAXENT_STUB__", False):
        return existing
    try:
        import numpy as _np
    except (
        ImportError,
        ModuleNotFoundError,
        OSError,
        RuntimeError,
    ):  # pragma: no cover - be defensive
        _np = None

    class _StubDType:
        """Named dtype placeholder that optionally carries a NumPy dtype."""

        def __init__(self, name: str, np_dtype: Any = None) -> None:
            self.name = f"torch.{name}"
            self.np_dtype = np_dtype

        def __repr__(self) -> str:  # pragma: no cover - representational only
            return self.name

    class _StubTensor:
        """Minimal tensor stand-in wrapping a NumPy array or raw Python data.

        ``arr`` holds ``numpy.array(data)`` when NumPy is available, otherwise
        the data object as given.  Most methods have a reduced, best-effort
        behaviour in the NumPy-free case (often returning ``self``).
        """

        def __init__(self, data: Any, dtype: Any = None) -> None:
            self.arr = _np.array(data) if _np is not None else data
            self.dtype = dtype

        @property
        def shape(self) -> Any:
            if _np is None:
                try:
                    return (len(self.arr),)
                except TypeError:
                    # Scalars have no length: report a 0-d shape.
                    return ()
            return self.arr.shape

        @property
        def ndim(self) -> int:
            return len(self.shape)

        def dim(self) -> int:
            return self.ndim

        # Device / dtype / autograd movement is a no-op on the stub.
        def detach(self) -> "_StubTensor":
            return self

        def float(self) -> "_StubTensor":
            return self

        def cpu(self) -> "_StubTensor":
            return self

        def to(self, *_args: Any, **_kwargs: Any) -> "_StubTensor":
            return self

        def clone(self) -> "_StubTensor":
            if _np is None:
                if isinstance(self.arr, list):
                    return _StubTensor(list(self.arr), dtype=self.dtype)
                return _StubTensor(self.arr, dtype=self.dtype)
            return _StubTensor(self.arr.copy(), dtype=self.dtype)

        def view(self, *shape: int) -> "_StubTensor":
            if _np is None:
                return self
            return _StubTensor(self.arr.reshape(*shape), dtype=self.dtype)

        def reshape(self, *shape: int) -> "_StubTensor":
            return self.view(*shape)

        def unsqueeze(self, dim: int) -> "_StubTensor":
            if _np is None:
                return _StubTensor([self.arr], dtype=self.dtype)
            return _StubTensor(_np.expand_dims(self.arr, axis=dim), dtype=self.dtype)

        def repeat(self, *sizes: int) -> "_StubTensor":
            # Accept both repeat(2, 3) and repeat((2, 3)).
            if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
                sizes = tuple(int(size) for size in sizes[0])
            if _np is None:
                data = self.arr
                for size in reversed(tuple(int(size) for size in sizes)):
                    data = [data for _ in range(size)]
                return _StubTensor(data, dtype=self.dtype)
            return _StubTensor(
                _np.tile(self.arr, tuple(int(size) for size in sizes)),
                dtype=self.dtype,
            )

        def clamp(
            self,
            min_value: Optional[float] = None,
            max_value: Optional[float] = None,
            **kwargs: Any,
        ) -> "_StubTensor":
            # ``min=`` / ``max=`` keywords are honoured when the explicit
            # ``min_value`` / ``max_value`` arguments were not supplied.
            if "min" in kwargs and min_value is None:
                min_value = kwargs["min"]
            if "max" in kwargs and max_value is None:
                max_value = kwargs["max"]
            if _np is None:
                return self
            data = self.arr
            if min_value is not None:
                data = _np.maximum(data, min_value)
            if max_value is not None:
                data = _np.minimum(data, max_value)
            return _StubTensor(data, dtype=self.dtype)

        def numel(self) -> int:
            if _np is None:
                try:
                    return len(self.arr)
                except TypeError:
                    return 1
            return int(self.arr.size)

        def size(self, dim: Optional[int] = None) -> Any:
            if _np is None:
                return len(self.arr) if dim is not None else ()
            if dim is None:
                return self.arr.shape
            return self.arr.shape[dim]

        def tolist(self) -> Any:
            if _np is None:
                return list(self.arr) if isinstance(self.arr, list) else self.arr
            return self.arr.tolist()

        def item(self) -> Any:
            if _np is None:
                return self.arr
            return self.arr.item()

        def sum(self, dim: Optional[int] = None) -> "_StubTensor":
            if _np is None:
                return self
            return _StubTensor(self.arr.sum(axis=dim), dtype=self.dtype)

        def mean(self, dim: Optional[int] = None) -> "_StubTensor":
            if _np is None:
                return self
            return _StubTensor(self.arr.mean(axis=dim), dtype=self.dtype)

        def log(self) -> "_StubTensor":
            if _np is None:
                return self
            return _StubTensor(_np.log(self.arr), dtype=self.dtype)

        def exp(self) -> "_StubTensor":
            if _np is None:
                return self
            return _StubTensor(_np.exp(self.arr), dtype=self.dtype)

        def max(self) -> Any:
            if _np is None:
                return self.arr
            return self.arr.max()

        def __len__(self) -> int:
            try:
                return len(self.arr)
            except TypeError:
                return 1

        def __iter__(self):
            return iter(self.arr)

        def __getitem__(self, idx: Any) -> Any:
            if _np is None:
                return self.arr[idx]
            value = self.arr[idx]
            # Array and NumPy-scalar results are re-wrapped as stub tensors.
            if isinstance(value, (_np.ndarray, _np.generic)):
                return _StubTensor(value, dtype=self.dtype)
            return value

        def new_full(self, shape: Any, fill_value: float) -> "_StubTensor":
            if _np is None:
                return _StubTensor([fill_value], dtype=self.dtype)
            return _StubTensor(_np.full(shape, fill_value), dtype=self.dtype)

        def __float__(self) -> float:
            if _np is None:
                try:
                    return float(self.arr)
                except Exception:
                    return 0.0
            return float(self.arr.item())

        def _binary_op(self, other: Any, op) -> "_StubTensor":
            """Apply ``op`` to the wrapped data and ``other`` (unwrapped if a stub)."""
            if isinstance(other, _StubTensor):
                other = other.arr
            if _np is None:
                try:
                    return _StubTensor(op(self.arr, other), dtype=self.dtype)
                except (TypeError, ValueError, OverflowError):
                    # Best effort without NumPy: leave the operand untouched.
                    return self
            return _StubTensor(op(self.arr, other), dtype=self.dtype)

        def __add__(self, other: Any) -> "_StubTensor":
            return self._binary_op(other, lambda a, b: a + b)

        def __sub__(self, other: Any) -> "_StubTensor":
            return self._binary_op(other, lambda a, b: a - b)

        def __mul__(self, other: Any) -> "_StubTensor":
            return self._binary_op(other, lambda a, b: a * b)

        def __truediv__(self, other: Any) -> "_StubTensor":
            return self._binary_op(other, lambda a, b: a / b)

        def __lt__(self, other: Any) -> "_StubTensor":
            return self._binary_op(other, lambda a, b: a < b)

        def __le__(self, other: Any) -> "_StubTensor":
            return self._binary_op(other, lambda a, b: a <= b)

        def __gt__(self, other: Any) -> "_StubTensor":
            return self._binary_op(other, lambda a, b: a > b)

        def __ge__(self, other: Any) -> "_StubTensor":
            return self._binary_op(other, lambda a, b: a >= b)

        def lt(self, other: Any) -> "_StubTensor":
            return self.__lt__(other)

        def le(self, other: Any) -> "_StubTensor":
            return self.__le__(other)

        def gt(self, other: Any) -> "_StubTensor":
            return self.__gt__(other)

        def ge(self, other: Any) -> "_StubTensor":
            return self.__ge__(other)

        def __array__(
            self, dtype: Any = None
        ) -> Any:  # pragma: no cover - numpy interop
            if _np is None:
                return self.arr
            return self.arr.astype(dtype) if dtype is not None else self.arr

    def _wrap(data: Any, dtype: Any = None) -> _StubTensor:
        return _StubTensor(data, dtype=dtype)

    def _flat_fill(shape: Any, value: Any) -> list:
        """NumPy-free fallback: flat list sized by the first entry of ``shape``."""
        length = int(shape[0]) if isinstance(shape, (list, tuple)) else 1
        return [value for _ in range(length)]

    def _zeros(shape: Any, dtype: Any = None) -> _StubTensor:
        if _np is None:
            return _wrap(_flat_fill(shape, 0), dtype)
        return _wrap(_np.zeros(shape), dtype=dtype)

    def _ones(shape: Any, dtype: Any = None) -> _StubTensor:
        if _np is None:
            return _wrap(_flat_fill(shape, 1), dtype)
        return _wrap(_np.ones(shape), dtype=dtype)

    def _full(shape: Any, fill_value: float, dtype: Any = None) -> _StubTensor:
        if _np is None:
            return _wrap(_flat_fill(shape, fill_value), dtype)
        return _wrap(_np.full(shape, fill_value), dtype=dtype)

    def _arange(*args: Any, **_kwargs: Any) -> _StubTensor:
        if _np is None:
            return _wrap(list(range(*args)))
        return _wrap(_np.arange(*args))

    def _cat(tensors: Any, dim: int = 0) -> _StubTensor:
        if _np is None:
            data = []
            for t in tensors:
                data.extend(getattr(t, "arr", t))
            return _wrap(data)
        arrs = [_np.array(getattr(t, "arr", t)) for t in tensors]
        return _wrap(_np.concatenate(arrs, axis=dim))

    def _stack(tensors: Any, dim: int = 0) -> _StubTensor:
        if _np is None:
            return _wrap([getattr(t, "arr", t) for t in tensors])
        arrs = [_np.array(getattr(t, "arr", t)) for t in tensors]
        return _wrap(_np.stack(arrs, axis=dim))

    def _sum(tensor: Any, dim: Optional[int] = None) -> _StubTensor:
        data = getattr(tensor, "arr", tensor)
        if _np is None:
            if isinstance(data, (list, tuple)):
                return _wrap(sum(float(v) for v in data))
            try:
                return _wrap(float(data))
            except (TypeError, ValueError):
                return _wrap(0.0)
        return _wrap(_np.sum(_np.array(data), axis=dim))

    def _nansum(tensor: Any, dim: Optional[int] = None) -> _StubTensor:
        data = getattr(tensor, "arr", tensor)
        if _np is None:
            return _sum(data, dim=dim)
        return _wrap(_np.nansum(_np.array(data), axis=dim))

    def _clamp(
        tensor: Any,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None,
        **kwargs: Any,
    ) -> _StubTensor:
        value = tensor if isinstance(tensor, _StubTensor) else _wrap(tensor)
        return value.clamp(min_value=min_value, max_value=max_value, **kwargs)

    def _log(tensor: Any) -> _StubTensor:
        value = tensor if isinstance(tensor, _StubTensor) else _wrap(tensor)
        return value.log()

    def _exp(tensor: Any) -> _StubTensor:
        value = tensor if isinstance(tensor, _StubTensor) else _wrap(tensor)
        return value.exp()

    def _softmax(tensor: Any, dim: int = -1) -> _StubTensor:
        value = getattr(tensor, "arr", tensor)
        if _np is None:
            return tensor if isinstance(tensor, _StubTensor) else _wrap(tensor)
        arr = _np.array(value, dtype=float)
        # Subtract the per-axis max before exponentiating.
        shifted = arr - _np.max(arr, axis=dim, keepdims=True)
        exp_vals = _np.exp(shifted)
        denom = _np.sum(exp_vals, axis=dim, keepdims=True)
        return _wrap(exp_vals / denom)

    def _randn(*shape: int) -> _StubTensor:
        if _np is None:
            size = int(shape[0]) if shape else 1
            return _wrap([0.0 for _ in range(size)])
        normalized_shape = shape or (1,)
        return _wrap(_np.random.randn(*normalized_shape))

    class _Module:
        """Minimal torch.nn.Module compatible with test doubles."""

        def __init__(self, *_args: Any, **_kwargs: Any) -> None:
            self.training = True

        def forward(self, *_args: Any, **_kwargs: Any) -> Any:
            raise TypeError("Module stub is not callable without a forward method")

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            return self.forward(*args, **kwargs)

        def eval(self) -> "_Module":
            self.training = False
            return self

        def train(self, mode: bool = True) -> "_Module":
            self.training = bool(mode)
            return self

        def parameters(self) -> list:
            return []

    class _Linear(_Module):
        """Simple Linear stub that preserves output shape semantics."""

        def __init__(
            self, in_features: int, out_features: int, *_args: Any, **_kwargs: Any
        ) -> None:
            super().__init__()
            self.in_features = int(in_features)
            self.out_features = int(out_features)
            if _np is None:
                self.weight = _wrap(
                    [[0.0 for _ in range(self.in_features)] for _ in range(self.out_features)]
                )
                self.bias = _wrap([0.0 for _ in range(self.out_features)])
            else:
                self.weight = _wrap(_np.zeros((self.out_features, self.in_features)))
                self.bias = _wrap(_np.zeros((self.out_features,)))

        def forward(self, *args: Any, **_kwargs: Any) -> _StubTensor:
            inputs = args[0] if args else None
            data = getattr(inputs, "arr", inputs)
            if _np is None:
                return _wrap([0.0 for _ in range(self.out_features)])
            array = _np.array(data)
            # Output is all zeros; only the trailing dimension is replaced.
            if array.ndim == 0:
                out_shape = (self.out_features,)
            else:
                out_shape = tuple(array.shape[:-1]) + (self.out_features,)
            return _wrap(_np.zeros(out_shape))

    stub = ModuleType("torch")
    stub.__MAXENT_STUB__ = True
    # A fresh ModuleType has ``__spec__`` set to None and no ``__path__``, so
    # both are populated here.
    stub.__spec__ = ModuleSpec("torch", loader=None)
    stub.__path__ = []
    stub.Tensor = _StubTensor
    stub.tensor = _wrap
    stub.as_tensor = _wrap
    stub.zeros = _zeros
    stub.ones = _ones

    def _shape_from(value: Any) -> Any:
        shape = getattr(value, "shape", None)
        if shape is not None:
            return shape
        try:
            return (len(value),)
        except TypeError:
            return (1,)

    stub.ones_like = lambda tensor, *_a, **_k: _ones(
        _shape_from(getattr(tensor, "arr", tensor))
    )
    stub.zeros_like = lambda tensor, *_a, **_k: _zeros(
        _shape_from(getattr(tensor, "arr", tensor))
    )
    stub.full = _full
    stub.arange = _arange
    stub.cat = _cat
    stub.stack = _stack
    stub.sum = _sum
    stub.nansum = _nansum
    stub.clamp = _clamp
    stub.log = _log
    stub.exp = _exp
    stub.softmax = _softmax
    stub.randn = _randn
    stub.float32 = _StubDType("float32", _np.float32 if _np is not None else None)
    stub.float16 = _StubDType("float16", _np.float16 if _np is not None else None)
    stub.bfloat16 = _StubDType("bfloat16")
    long_dtype = _StubDType("int64", _np.int64 if _np is not None else None)
    stub.long = long_dtype
    stub.int64 = long_dtype
    stub.bool = _StubDType("bool", _np.bool_ if _np is not None else None)

    # ``contextlib.nullcontext`` exists on every interpreter that accepts this
    # module's ``from __future__ import annotations`` (3.7+).  The previous
    # SimpleNamespace fallback could not work as a context manager because
    # ``with`` looks ``__enter__``/``__exit__`` up on the type, not the instance.
    from contextlib import nullcontext as _nullcontext

    def _no_grad() -> Any:
        return _nullcontext()

    stub.no_grad = _no_grad
    stub.device = lambda *_a, **_k: "cpu"
    stub.cuda = SimpleNamespace(is_available=lambda: False)
    stub.distributed = None
    stub.__version__ = "0.0.0-stub"
    stub.nn = SimpleNamespace(
        Module=_Module,
        Linear=_Linear,
        utils=SimpleNamespace(clip_grad_norm_=lambda *_a, **_k: 0.0),
        functional=SimpleNamespace(),
    )
    stub.optim = SimpleNamespace(Optimizer=object, AdamW=None)
    sys.modules["torch"] = stub

    # Submodule entries: reuse whatever is already registered and only fill in
    # the attributes that are missing.
    nn_mod = sys.modules.setdefault("torch.nn", ModuleType("torch.nn"))
    if not hasattr(nn_mod, "Module"):
        nn_mod.Module = _Module
    if not hasattr(nn_mod, "Linear"):
        nn_mod.Linear = _Linear
    nn_fn_mod = sys.modules.setdefault(
        "torch.nn.functional", ModuleType("torch.nn.functional")
    )
    setattr(nn_mod, "functional", nn_fn_mod)
    optim_mod = sys.modules.setdefault("torch.optim", ModuleType("torch.optim"))
    if not hasattr(optim_mod, "Optimizer"):
        optim_mod.Optimizer = object
    if not hasattr(optim_mod, "AdamW"):
        optim_mod.AdamW = None
    utils_mod = sys.modules.setdefault("torch.utils", ModuleType("torch.utils"))
    data_mod = sys.modules.setdefault(
        "torch.utils.data", ModuleType("torch.utils.data")
    )
    if not hasattr(data_mod, "DataLoader"):

        class _DataLoader:  # pragma: no cover - stubbed fallback
            def __init__(self, *_args: Any, **_kwargs: Any) -> None:
                raise RuntimeError(hint)

        data_mod.DataLoader = _DataLoader
    if not hasattr(data_mod, "Sampler"):

        class _Sampler:  # pragma: no cover - stubbed fallback
            pass

        data_mod.Sampler = _Sampler
    setattr(utils_mod, "data", data_mod)
    # The import system does not bind a submodule on its parent when it is
    # found pre-registered in ``sys.modules``; bind it explicitly so that
    # ``torch.utils.data`` also resolves through attribute access.
    setattr(stub, "utils", utils_mod)
    return stub
def _install_accelerate_stub(_hint: str) -> ModuleType:
    """Install a lightweight accelerate stub for tests.

    Registers stub modules under ``accelerate`` and ``accelerate.state`` in
    ``sys.modules``.  The top-level stub exposes a single-process
    ``Accelerator`` class, and ``accelerate.state`` exposes a
    ``DistributedType`` class with a ``DEEPSPEED`` attribute.

    Args:
        _hint: Unused; accepted so the signature mirrors the torch stub
            installer.

    Returns:
        The stub module.  If a stub carrying the ``__MAXENT_STUB__`` marker is
        already registered under ``accelerate`` it is returned unchanged.
    """
    existing = sys.modules.get("accelerate")
    if isinstance(existing, ModuleType) and getattr(existing, "__MAXENT_STUB__", False):
        return existing

    class _Accelerator:
        """Single-process accelerator placeholder; constructor args are ignored."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            _ = (args, kwargs)
            self.is_main_process = True
            self.num_processes = 1
            self.process_index = 0
            self.device = "cpu"

        def wait_for_everyone(self) -> None:
            return None

    class _DistributedType:
        DEEPSPEED = "deepspeed"

    accel_mod = ModuleType("accelerate")
    accel_mod.__MAXENT_STUB__ = True
    # A fresh ModuleType has ``__spec__`` set to None and no ``__path__``.
    accel_mod.__spec__ = ModuleSpec("accelerate", loader=None)
    accel_mod.__path__ = []
    setattr(accel_mod, "Accelerator", _Accelerator)

    state_mod = ModuleType("accelerate.state")
    state_mod.__spec__ = ModuleSpec("accelerate.state", loader=None)
    setattr(state_mod, "DistributedType", _DistributedType)
    # The import system does not bind a submodule on its parent when it is
    # found pre-registered in ``sys.modules``; bind it explicitly so that
    # ``accelerate.state`` also resolves through attribute access.
    setattr(accel_mod, "state", state_mod)

    sys.modules["accelerate"] = accel_mod
    sys.modules["accelerate.state"] = state_mod
    return accel_mod
# [docs]
def require_torch(context: str) -> ModuleType:
    """Import torch, falling back to the test stub when stubs are allowed.

    Args:
        context: Short description of the caller, interpolated into the
            error hint.

    Returns:
        The module returned by the dependency loader, or the torch stub.

    Raises:
        RuntimeError: If the import fails and stubbed dependencies are not
            allowed.
    """
    hint = (
        f"PyTorch is required for MaxEnt-GRPO {context}. "
        "Install it via `pip install torch`."
    )
    try:
        torch_module = _require_dependency("torch", hint)
    except ImportError as exc:
        if not _allow_stubbed_deps():
            raise RuntimeError(hint) from exc
        torch_module = _install_torch_stub(hint)
    return torch_module
# [docs]
def require_dataloader(context: str) -> Any:
    """Fetch ``torch.utils.data.DataLoader``, or a placeholder that explains the failure.

    When ``torch.utils.data`` cannot be imported but a torch module (real or
    stub) is obtainable, a fresh ``torch.utils.data`` module holding a
    placeholder ``DataLoader`` is registered in ``sys.modules`` and attached
    to ``torch.utils``.

    Args:
        context: Short description of the caller, interpolated into the
            error hint.

    Returns:
        The ``DataLoader`` attribute of ``torch.utils.data``, or the local
        placeholder class when that attribute is absent.

    Raises:
        RuntimeError: If neither ``torch.utils.data`` nor any torch module
            (including the test stub) can be obtained.
    """
    hint = (
        f"Torch's DataLoader is required for MaxEnt-GRPO {context}. "
        "Install torch first."
    )

    class DataLoader:  # pragma: no cover - import-time fallback
        # Instantiating the placeholder always fails with the install hint.
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            raise RuntimeError(hint)

    try:
        data_module = _require_dependency("torch.utils.data", hint)
    except ImportError as exc:
        torch_pkg = _optional_dependency("torch")
        if torch_pkg is None:
            if _allow_stubbed_deps():
                torch_pkg = _install_torch_stub(hint)
            if torch_pkg is None:
                raise RuntimeError(hint) from exc
        # Register a replacement ``torch.utils.data`` exposing the placeholder.
        data_module = ModuleType("torch.utils.data")
        data_module.DataLoader = DataLoader
        sys.modules["torch.utils.data"] = data_module
        utils_pkg = getattr(torch_pkg, "utils", None)
        if utils_pkg is None:
            utils_pkg = ModuleType("torch.utils")
            sys.modules["torch.utils"] = utils_pkg
            torch_pkg.utils = utils_pkg
        utils_pkg.data = data_module

    loader_cls = getattr(data_module, "DataLoader", None)
    if loader_cls is not None:
        return loader_cls
    # The module exists but lacks ``DataLoader``: install the placeholder.
    data_module.DataLoader = DataLoader
    return DataLoader
# [docs]
def require_accelerator(context: str) -> Any:
    """Look up ``accelerate.Accelerator``, using the test stub when allowed.

    Args:
        context: Short description of the caller, interpolated into the
            error hint.

    Returns:
        The ``Accelerator`` attribute of the accelerate module or of the stub.

    Raises:
        RuntimeError: If accelerate cannot be imported and stubs are not
            allowed, or if no ``Accelerator`` attribute can be found.
    """
    hint = (
        f"Accelerate is required for MaxEnt-GRPO {context}. "
        "Install it via `pip install accelerate`."
    )
    try:
        module = _require_dependency("accelerate", hint)
    except ImportError as exc:
        if not _allow_stubbed_deps():
            raise RuntimeError(hint) from exc
        module = _install_accelerate_stub(hint)

    found = getattr(module, "Accelerator", None)
    if found is None and _allow_stubbed_deps():
        # The imported module lacks Accelerator; retry against the stub.
        found = getattr(_install_accelerate_stub(hint), "Accelerator", None)
    if found is None:
        raise RuntimeError(hint)
    return found
# [docs]
def require_deepspeed(context: str, module: str = "deepspeed") -> ModuleType:
    """Import a DeepSpeed module, converting import failures to RuntimeError.

    Args:
        context: Short description of the caller, interpolated into the
            error hint.
        module: Dotted name of the module to import.

    Returns:
        Whatever the dependency loader returns for ``module``.

    Raises:
        RuntimeError: If the import raises ``ImportError``.
    """
    hint = (
        f"DeepSpeed is required for MaxEnt-GRPO {context}. "
        "Install it with `pip install deepspeed`."
    )
    try:
        loaded = _require_dependency(module, hint)
    except ImportError as exc:  # pragma: no cover - import guard
        raise RuntimeError(hint) from exc
    return loaded
# [docs]
def get_trl_prepare_deepspeed() -> Optional[Any]:
    """Locate TRL's ``prepare_deepspeed`` helper, if one can be found.

    ``trl.models.utils`` is probed first, then ``trl.trainer.utils``.  The
    first callable ``prepare_deepspeed`` attribute found is returned.

    Returns:
        The helper callable, or ``None`` when neither module provides one.
    """
    candidates = ("trl.models.utils", "trl.trainer.utils")
    for dotted_name in candidates:
        module = _optional_dependency(dotted_name)
        helper = None if module is None else getattr(module, "prepare_deepspeed", None)
        if callable(helper):
            return helper
    return None
def _maybe_create_deepspeed_plugin() -> Optional[DeepSpeedPlugin]:
    """Construct a DeepSpeedPlugin from Accelerate env/config when available.

    The plugin is only built when ``ACCELERATE_USE_DEEPSPEED`` equals
    ``"true"`` (case-insensitive).  Options are read from the
    ``deepspeed_config`` mapping of the YAML file named by
    ``ACCELERATE_CONFIG_FILE``, when that file exists and ``yaml`` is
    importable.  Reading is best effort: an unreadable or malformed file, or
    one whose top level or ``deepspeed_config`` entry is not a mapping, is
    treated as an empty configuration.

    Returns:
        A ``DeepSpeedPlugin`` instance, or ``None`` when DeepSpeed is not
        enabled or every option resolved to ``None``.

    Raises:
        ImportError: If ``accelerate.utils`` does not expose
            ``DeepSpeedPlugin``.  Failures from the dependency loader itself
            propagate unchanged.
    """
    if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() != "true":
        return None
    ds_module = _require_dependency(
        "accelerate.utils",
        (
            "DeepSpeed integration requires the Accelerate package. "
            "Install it via `pip install accelerate[deepspeed]`."
        ),
    )
    ds_class = getattr(ds_module, "DeepSpeedPlugin", None)
    if ds_class is None:
        raise ImportError(
            "accelerate.utils does not expose DeepSpeedPlugin; update Accelerate."
        )

    ds_cfg: Dict[str, Any] = {}
    cfg_path = os.environ.get("ACCELERATE_CONFIG_FILE")
    if cfg_path and yaml is not None and os.path.isfile(cfg_path):
        handled_exceptions: Tuple[type[BaseException], ...] = (OSError, ValueError)
        # Only add yaml.YAMLError when the yaml object really provides an
        # exception class under that name.
        yaml_error = getattr(yaml, "YAMLError", None)
        if isinstance(yaml_error, type) and issubclass(yaml_error, BaseException):
            handled_exceptions = handled_exceptions + (yaml_error,)
        try:
            with open(cfg_path, "r", encoding="utf-8") as cfg_file:
                raw = yaml.safe_load(cfg_file)
        except handled_exceptions as exc:
            # Best effort: keep going with defaults, but leave a trace.
            LOG.warning("Could not read Accelerate config %s: %s", cfg_path, exc)
        else:
            # safe_load may return any YAML value (list, str, None, ...);
            # only mappings can carry a deepspeed_config section.
            section = raw.get("deepspeed_config") if isinstance(raw, dict) else None
            if isinstance(section, dict):
                ds_cfg = section

    # ZeRO stage defaults to 3 when the key is absent; an explicit null
    # leaves it unset.
    zero_stage_raw = ds_cfg.get("zero_stage", 3)
    zero_stage = int(zero_stage_raw) if zero_stage_raw is not None else None
    candidate_kwargs = {
        "zero_stage": zero_stage,
        "offload_param_device": ds_cfg.get("offload_param_device"),
        "offload_optimizer_device": ds_cfg.get("offload_optimizer_device"),
        "zero3_init_flag": ds_cfg.get("zero3_init_flag"),
        "zero3_save_16bit_model": ds_cfg.get("zero3_save_16bit_model"),
    }
    # Unset options are dropped rather than passed as None.
    kwargs = {k: v for k, v in candidate_kwargs.items() if v is not None}
    if not kwargs:
        return None
    return ds_class(**kwargs)
# [docs]
def maybe_create_deepspeed_plugin() -> Optional[DeepSpeedPlugin]:
    """Public entry point that delegates to ``_maybe_create_deepspeed_plugin``."""
    plugin = _maybe_create_deepspeed_plugin()
    return plugin
# Names exported by ``from <this module> import *``.  Besides the public
# ``require_*`` helpers this also lists the typing aliases and the
# underscore-prefixed import helpers.
# NOTE(review): "require_transformer_base_classes" has no definition in the
# visible part of this module -- confirm it exists elsewhere in the file,
# otherwise a star-import will fail on it.
__all__ = [
    "Accelerator",
    "DeepSpeedPlugin",
    "_import_module",
    "_optional_dependency",
    "_require_dependency",
    "_maybe_create_deepspeed_plugin",
    "maybe_create_deepspeed_plugin",
    "get_trl_prepare_deepspeed",
    "require_accelerator",
    "require_dataloader",
    "require_deepspeed",
    "require_torch",
    "require_transformer_base_classes",
]