maxent_grpo.training.zero_utils

Utilities to safely integrate DeepSpeed ZeRO with optional dependencies.

Functions

_call_gather_fn(gather_fn, params, modifier_rank)

Invoke GatheredParameters, handling DeepSpeed versions both with and without modifier_rank support.
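As a rough illustration of this kind of compatibility shim (not the module's actual code), the sketch below inspects the callable's signature and passes modifier_rank only when it is accepted. The names call_gather, new_style and old_style are hypothetical stand-ins and do not come from this module or from DeepSpeed.

```python
import inspect
from contextlib import nullcontext


def call_gather(gather_fn, params, modifier_rank=None):
    """Hypothetical shim: pass modifier_rank only if gather_fn accepts it."""
    try:
        accepts = "modifier_rank" in inspect.signature(gather_fn).parameters
    except (TypeError, ValueError):
        # Signature could not be inspected; assume the older call form.
        accepts = False
    if accepts:
        return gather_fn(params, modifier_rank=modifier_rank)
    return gather_fn(params)


def new_style(params, modifier_rank=None):
    # Stand-in for a gather helper that accepts modifier_rank.
    return nullcontext(("new", modifier_rank))


def old_style(params):
    # Stand-in for a gather helper that predates modifier_rank.
    return nullcontext(("old", None))
```

Checking the signature up front, rather than catching TypeError after the call, avoids hiding a TypeError raised inside the gather helper itself.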

_deepspeed_engine_cls()

Return the DeepSpeedEngine class when available.

_disable_hf_deepspeed_zero3_init()

Temporarily disable HF DeepSpeed ZeRO-3 init for model loading.

_embedding_weight_needing_gather(model)

Return the embedding weight tensor when ZeRO gathering is required.

_embedding_weights_needing_gather(model)

Return all embedding-like weights requiring ZeRO gathering.

_ensure_cuda_fallback()

Return a cuda namespace exposing is_available and empty_cache.
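A minimal sketch of what such a fallback could look like, assuming torch is an optional dependency. The helper name cuda_namespace is hypothetical; the real function may build its namespace differently.

```python
from types import SimpleNamespace


def cuda_namespace():
    """Hypothetical helper: return torch.cuda, or a CPU-safe stub."""
    try:
        import torch  # optional dependency

        cuda = getattr(torch, "cuda", None)
    except Exception:
        # torch is missing or failed to import on this machine.
        cuda = None
    if cuda is not None and hasattr(cuda, "is_available") and hasattr(cuda, "empty_cache"):
        return cuda
    # Stub: report no GPU and make cache clearing a no-op.
    return SimpleNamespace(is_available=lambda: False, empty_cache=lambda: None)
```

Callers can then write cuda_namespace().is_available() without guarding every call site against a missing torch install.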

_ensure_deepspeed_ready()

Best-effort initialization of DeepSpeed helpers when installed.

_gather_callable()

Return the callable GatheredParameters helper when available.

_is_deepspeed_engine(model)

Return True when the provided model is a DeepSpeed engine.
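The optional-import pattern behind _deepspeed_engine_cls and _is_deepspeed_engine might look roughly like the sketch below. The function names here are hypothetical, and the sketch assumes DeepSpeedEngine is importable from the top-level deepspeed package when DeepSpeed is installed.

```python
def deepspeed_engine_cls():
    """Hypothetical helper: return DeepSpeedEngine, or None if unavailable."""
    try:
        from deepspeed import DeepSpeedEngine
    except Exception:
        # DeepSpeed is not installed, or its import fails on this machine.
        return None
    return DeepSpeedEngine


def is_deepspeed_engine(model):
    """Return True only when DeepSpeed is importable and model is an engine."""
    engine_cls = deepspeed_engine_cls()
    return engine_cls is not None and isinstance(model, engine_cls)
```

With this shape, code that never touches DeepSpeed pays no import cost at module load and gets a plain False for ordinary models.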

_maybe_patch_zero_no_sync(model)

Patch DeepSpeedEngine.no_sync to a no-op when gradients are partitioned.
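To illustrate the idea only, the sketch below replaces no_sync on a single object with a function returning a do-nothing context manager. FakeEngine and patch_no_sync are hypothetical, and unlike the function described above, this sketch patches the instance rather than the DeepSpeedEngine class.

```python
from contextlib import nullcontext


class FakeEngine:
    """Stand-in for an engine whose no_sync is unsafe with partitioned gradients."""

    def no_sync(self):
        raise RuntimeError("no_sync is incompatible with partitioned gradients")


def patch_no_sync(engine, partitions_gradients):
    """Hypothetical helper: make engine.no_sync a no-op context manager."""
    if partitions_gradients and hasattr(engine, "no_sync"):
        # Instance-level override; the class itself is left untouched.
        engine.no_sync = lambda *args, **kwargs: nullcontext()
    return engine
```

After patching, a training loop can keep writing "with engine.no_sync():" and the block simply runs without changing synchronization behavior.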

_maybe_zero_gather_embedding(model)

Gather ZeRO-sharded embedding weights before a forward pass.

_maybe_zero_gather_params(model, enabled[, ...])

Gather ZeRO-partitioned params only when needed.
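A conditional gather of this kind can be sketched as a context manager that is a pass-through unless gathering is both requested and possible. The name maybe_gather and its gather_fn argument are assumptions for illustration; the real function takes additional arguments that are elided in the signature above.

```python
from contextlib import contextmanager


@contextmanager
def maybe_gather(params, enabled, gather_fn=None):
    """Hypothetical helper: enter gather_fn(params) only when it is needed."""
    if not enabled or gather_fn is None or not params:
        # Nothing to gather: behave as a plain pass-through.
        yield
        return
    with gather_fn(params):
        # Parameters are gathered for the duration of the caller's block.
        yield
```

Keeping the disabled path free of any gather call means non-ZeRO runs avoid the collective communication entirely.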

_reserve_zero_gather_params(params)

Reserve parameter ids for a ZeRO gather region.

_zero_param_list(model)

Return a parameter list for ZeRO-gather contexts, unwrapping engines.

_zero_param_ready_without_gather(param)

Return True when a ZeRO parameter is already materialized.

_zero_partitioning_gradients(model)

Return whether the model partitions gradients (ZeRO-3).

_zero_stage(model)

Return the DeepSpeed ZeRO stage for a model when available.
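One plausible way to read the stage defensively is shown below. It assumes the model, if it is an engine, exposes a zero_optimization_stage() method as DeepSpeedEngine does; the helper name zero_stage, the FakeEngine class, and the choice to return None when the stage is unknown are all assumptions of this sketch.

```python
def zero_stage(model):
    """Hypothetical helper: return the ZeRO stage as an int, or None if unknown."""
    getter = getattr(model, "zero_optimization_stage", None)
    if not callable(getter):
        # Plain models (or missing DeepSpeed) have no stage to report.
        return None
    try:
        return int(getter())
    except Exception:
        # Treat any failure as "stage unavailable" rather than crashing.
        return None


class FakeEngine:
    """Stand-in for an engine configured with ZeRO stage 3."""

    def zero_optimization_stage(self):
        return 3
```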

_zero_status_name(param)

Best-effort extraction of the DeepSpeed ZeRO status name.
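The sketch below shows a best-effort lookup of this sort. It assumes ZeRO-managed parameters carry a ds_status attribute holding an enum-like value; FakeStatus, zero_status_name and the example param are hypothetical stand-ins so that the snippet runs without DeepSpeed.

```python
from enum import Enum
from types import SimpleNamespace


class FakeStatus(Enum):
    # Stand-in for a ZeRO parameter status enum.
    AVAILABLE = 1
    NOT_AVAILABLE = 2


def zero_status_name(param):
    """Hypothetical helper: return the status name as a string, or None."""
    status = getattr(param, "ds_status", None)
    if status is None:
        # Not a ZeRO-managed parameter.
        return None
    # Enum members expose .name; fall back to the raw value otherwise.
    return str(getattr(status, "name", status))


param = SimpleNamespace(ds_status=FakeStatus.AVAILABLE)
```

A readiness check such as the one performed by _zero_param_ready_without_gather could then compare the returned name against the "available" state, though the exact comparison used by the module is not shown here.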

Classes

GatherCallable(*args, **kwargs)

Callable signature exposed by DeepSpeed GatheredParameters.
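A callable signature like this is commonly expressed with typing.Protocol. The sketch below is an assumption about the general shape, not the module's definition: GatherLike, its parameter names and fake_gather are hypothetical.

```python
from contextlib import nullcontext
from typing import Any, ContextManager, Optional, Protocol, Sequence, runtime_checkable


@runtime_checkable
class GatherLike(Protocol):
    """Hypothetical protocol: a callable that returns a context manager."""

    def __call__(
        self, params: Sequence[Any], modifier_rank: Optional[int] = None
    ) -> ContextManager[Any]:
        ...


def fake_gather(params, modifier_rank=None):
    # Stand-in gather helper that performs no communication.
    return nullcontext()
```

Typing the helper as a protocol lets static checkers validate call sites even when DeepSpeed itself is not installed.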