# Source code for maxent_grpo.training.runtime.setup
"""Setup utilities for loading runtime dependencies and accelerator plugins."""
from __future__ import annotations
from typing import Any, Optional, Tuple
from maxent_grpo.utils.imports import (
cached_import as _import_module,
optional_import as _optional_dependency,
require_dependency as _require_dependency,
)
from . import deps as _deps
from .config import (
GenerationSamplingConfig,
MaxEntOptions,
VLLMClientConfig,
)
# Re-export the accelerator-related names provided by the sibling ``deps``
# module so callers can import them from this module (both are listed in
# ``__all__``).
Accelerator = _deps.Accelerator
DeepSpeedPlugin = _deps.DeepSpeedPlugin
[docs]
def require_torch(context: str) -> Any:
    """Resolve the torch module via the shared dependency helpers.

    Thin pass-through to ``_deps.require_torch``.

    Args:
        context: Forwarded unchanged to the delegate; it is expected to
            raise a helpful ``RuntimeError`` when torch cannot be provided.

    Returns:
        Whatever ``_deps.require_torch`` returns (the torch module).
    """
    torch_module = _deps.require_torch(context)
    return torch_module
[docs]
def require_dataloader(context: str) -> Any:
    """Resolve ``torch.utils.data.DataLoader`` via the dependency helpers.

    Thin pass-through to ``_deps.require_dataloader``.

    Args:
        context: Forwarded unchanged to the delegate, which is expected to
            produce a descriptive error on failure.

    Returns:
        Whatever ``_deps.require_dataloader`` returns (the DataLoader).
    """
    dataloader = _deps.require_dataloader(context)
    return dataloader
[docs]
def require_accelerator(context: str) -> Any:
    """Resolve ``accelerate.Accelerator`` via the dependency helpers.

    Thin pass-through to ``_deps.require_accelerator``.

    Args:
        context: Forwarded unchanged to the delegate, which is expected to
            raise a helpful ``RuntimeError`` on failure.

    Returns:
        Whatever ``_deps.require_accelerator`` returns (the Accelerator).
    """
    accelerator = _deps.require_accelerator(context)
    return accelerator
[docs]
def require_deepspeed(context: str, module: str = "deepspeed") -> Any:
    """Resolve a DeepSpeed module via the dependency helpers.

    Thin pass-through to ``_deps.require_deepspeed``.

    Args:
        context: Forwarded unchanged to the delegate, which is expected to
            raise a contextual ``RuntimeError`` on failure.
        module: Name of the module to import; defaults to ``"deepspeed"``.

    Returns:
        Whatever ``_deps.require_deepspeed`` returns (the imported module).
    """
    deepspeed_module = _deps.require_deepspeed(context, module)
    return deepspeed_module
[docs]
def get_trl_prepare_deepspeed() -> Optional[Any]:
    """Look up TRL's ``prepare_deepspeed`` helper, if it is available.

    Thin pass-through to ``_deps.get_trl_prepare_deepspeed``.

    Returns:
        Whatever the delegate returns; the ``Optional`` return annotation
        indicates this may be ``None`` when the helper is unavailable.
    """
    prepare_fn = _deps.get_trl_prepare_deepspeed()
    return prepare_fn
def _maybe_create_deepspeed_plugin() -> Optional[Any]:
    """Build a DeepSpeedPlugin from Accelerate env/config when available.

    Thin pass-through to ``_deps.maybe_create_deepspeed_plugin``.

    Returns:
        Whatever the delegate returns; the ``Optional`` return annotation
        indicates this may be ``None`` when no plugin can be constructed.
    """
    plugin = _deps.maybe_create_deepspeed_plugin()
    return plugin
# Names exported by ``from <this module> import *``. The underscore-prefixed
# entries are listed explicitly so that star-imports still pick them up.
#
# Every entry must be bound in this module: a star-import raises
# AttributeError for any listed name that is missing.
# "require_transformer_base_classes" was previously listed here but is
# neither defined nor imported in this module, so it has been dropped;
# re-add it once a definition exists.
__all__ = [
    "Accelerator",
    "DeepSpeedPlugin",
    "GenerationSamplingConfig",
    "MaxEntOptions",
    "VLLMClientConfig",
    "_import_module",
    "_optional_dependency",
    "_require_dependency",
    "_maybe_create_deepspeed_plugin",
    "get_trl_prepare_deepspeed",
    "require_accelerator",
    "require_dataloader",
    "require_deepspeed",
    "require_torch",
]