maxent_grpo.training.runtime.setup

Setup utilities for loading runtime dependencies and accelerator plugins.

Functions

_maybe_create_deepspeed_plugin()

Construct a DeepSpeedPlugin from Accelerate env/config when available.

get_trl_prepare_deepspeed()

Return TRL's prepare_deepspeed helper when available.

require_accelerator(context)

Return accelerate.Accelerator or raise a helpful RuntimeError.

require_dataloader(context)

Return torch.utils.data.DataLoader with a descriptive error on failure.

require_deepspeed(context[, module])

Return a DeepSpeed module import or raise a contextual RuntimeError.

require_torch(context)

Return the torch module or raise a helpful RuntimeError.

require_transformer_base_classes(context)

Return (PreTrainedModel, PreTrainedTokenizer) with clear failure messages.
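A minimal sketch of what _maybe_create_deepspeed_plugin() might do. It assumes the plugin is built only when accelerate launch has exported ACCELERATE_USE_DEEPSPEED=true. The function name below is hypothetical, and the real helper may also read an Accelerate config file, which this sketch ignores.

```python
import os
from typing import Any, Optional


def maybe_create_deepspeed_plugin_sketch() -> Optional[Any]:
    """Hypothetical stand-in for _maybe_create_deepspeed_plugin()."""
    # Assumption: `accelerate launch` with DeepSpeed enabled exports this flag.
    if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() != "true":
        return None
    try:
        from accelerate.utils import DeepSpeedPlugin
    except ImportError:
        # Accelerate is optional here; report "no plugin" instead of failing.
        return None
    # Assumption: with no arguments, DeepSpeedPlugin fills unset fields
    # (such as the ZeRO stage) from ACCELERATE_DEEPSPEED_* environment variables.
    return DeepSpeedPlugin()
```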

class maxent_grpo.training.runtime.setup.GenerationSamplingConfig(max_prompt_len, max_completion_len, gen_temperature, gen_top_p, use_vllm, vllm, *, vllm_mode='server')

Bases: object

Shared completion sampling knobs (HF + vLLM).

Parameters:
max_prompt_len: int
max_completion_len: int
gen_temperature: float
gen_top_p: float
use_vllm: bool
vllm: VLLMClientConfig
vllm_mode: str = 'server'
property vllm_url: str

Backward-compatible accessor for the vLLM endpoint URL.

property vllm_rounds_cfg: int

Backward-compatible accessor for the maximum vLLM retry rounds.

property vllm_retry_sleep: float

Backward-compatible accessor for the per-round retry sleep.

property vllm_backfill_local: bool

Backward-compatible accessor for the local fallback behavior.

property vllm_request_logprobs: bool

Backward-compatible accessor for whether to request logprobs.

property vllm_best_of: int | None

Backward-compatible accessor for the best-of sampling count.

property vllm_frequency_penalty: float

Backward-compatible accessor for the frequency penalty value.

property vllm_presence_penalty: float

Backward-compatible accessor for the presence penalty value.

property vllm_top_k: int | None

Backward-compatible accessor for the top-k sampling limit.

property vllm_stop_sequences: List[str] | None

Backward-compatible accessor for the stop sequences.

property vllm_include_stop_str_in_output: bool

Whether vLLM should preserve matched stop strings in output text.

property vllm_timeout: float

Backward-compatible accessor for the request timeout.

property vllm_max_retries: int

Backward-compatible accessor for the maximum number of request retries.

property vllm_backoff: float

Backward-compatible accessor for the exponential backoff factor.

property vllm_backoff_multiplier: float

Multiplier applied to the backoff delay after each attempt.

property vllm_guided_json: str | None

Backward-compatible accessor for the JSON schema used in guided decoding.

property vllm_guided_regex: str | None

Backward-compatible accessor for regex-guided decoding.

property vllm_logit_bias: Dict[str, float] | None

Backward-compatible accessor for the logit-bias configuration.

property vllm_request_id_prefix: str | None

Backward-compatible accessor for the request-ID prefix.

property vllm_sync_weights: bool

Whether to push model weights to the vLLM server before generation.
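The accessors above expose flat, legacy-style names that forward to the nested vllm config. A sketch of that forwarding pattern follows. The class names are hypothetical, and VLLMClientConfigSketch is a three-field stand-in for VLLMClientConfig whose field names are assumptions, since the real fields are not documented here.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class VLLMClientConfigSketch:
    # Hypothetical subset of VLLMClientConfig; real field names may differ.
    url: str
    timeout: float
    stop_sequences: Optional[List[str]] = None


@dataclass
class GenerationSamplingConfigSketch:
    max_prompt_len: int
    max_completion_len: int
    gen_temperature: float
    gen_top_p: float
    use_vllm: bool
    vllm: VLLMClientConfigSketch
    vllm_mode: str = "server"

    # Legacy flat names read through to the nested client config, so code
    # written against the old flat layout keeps working.
    @property
    def vllm_url(self) -> str:
        return self.vllm.url

    @property
    def vllm_timeout(self) -> float:
        return self.vllm.timeout

    @property
    def vllm_stop_sequences(self) -> Optional[List[str]]:
        return self.vllm.stop_sequences
```

Keeping the values on the nested object and exposing read-only properties leaves a single source of truth, while older call sites such as cfg.vllm_url continue to resolve.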

class maxent_grpo.training.runtime.setup.MaxEntOptions(tau=<factory>, q_temperature=<factory>, q_epsilon=<factory>, length_normalize_ref=<factory>)

Bases: object

Lightweight knobs specific to MaxEnt sequence-level updates.

Parameters:
tau: float
q_temperature: float
q_epsilon: float
length_normalize_ref: bool
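The <factory> defaults in the MaxEntOptions signature are how Python's inspect module, and therefore Sphinx, renders dataclass fields declared with field(default_factory=...). The sketch below uses a hypothetical class name, and its default values are placeholders, not the library's actual defaults.

```python
import inspect
from dataclasses import dataclass, field


@dataclass
class MaxEntOptionsSketch:
    # Placeholder defaults for illustration only; the real values are not
    # documented on this page.
    tau: float = field(default_factory=lambda: 0.0)
    q_temperature: float = field(default_factory=lambda: 1.0)
    q_epsilon: float = field(default_factory=lambda: 1e-6)
    length_normalize_ref: bool = field(default_factory=lambda: True)


# Each field built with default_factory shows up as "<factory>" here.
print(inspect.signature(MaxEntOptionsSketch))
```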

maxent_grpo.training.runtime.setup.get_trl_prepare_deepspeed()

Return TRL’s prepare_deepspeed helper when available.

Return type:

Any | None
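A sketch of the optional-import pattern behind get_trl_prepare_deepspeed(). The function name is hypothetical, and it assumes the helper is importable as trl.models.utils.prepare_deepspeed. That location has varied across TRL versions, so an ImportError is treated as "unavailable".

```python
from typing import Any, Callable, Optional


def get_trl_prepare_deepspeed_sketch() -> Optional[Callable[..., Any]]:
    """Return TRL's prepare_deepspeed helper if importable, else None."""
    try:
        # Assumption: the helper lives here in the installed TRL version.
        from trl.models.utils import prepare_deepspeed
    except ImportError:
        # TRL missing, or this version keeps the helper elsewhere.
        return None
    return prepare_deepspeed
```

Callers can branch on None and fall back to their own DeepSpeed preparation path.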

maxent_grpo.training.runtime.setup.require_accelerator(context)

Return accelerate.Accelerator or raise a helpful RuntimeError.

Parameters:

context (str)

Return type:

Any

maxent_grpo.training.runtime.setup.require_dataloader(context)

Return torch.utils.data.DataLoader with a descriptive error on failure.

Parameters:

context (str)

Return type:

Any

maxent_grpo.training.runtime.setup.require_deepspeed(context, module='deepspeed')

Return a DeepSpeed module import or raise a contextual RuntimeError.

Parameters:

context (str)

module (str)

Return type:

Any
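Because require_deepspeed takes a module argument, one helper can guard both the top-level package and a dotted submodule path. A sketch follows; the function name and error text are hypothetical.

```python
import importlib
from types import ModuleType


def require_deepspeed_sketch(context: str, module: str = "deepspeed") -> ModuleType:
    """Import module (top-level package or dotted submodule path) or raise."""
    try:
        return importlib.import_module(module)
    except ImportError as exc:
        # Name both the caller and the module that failed to import.
        raise RuntimeError(
            f"{context} needs {module}, but it could not be imported. "
            "Install DeepSpeed or disable the DeepSpeed code path."
        ) from exc
```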

maxent_grpo.training.runtime.setup.require_torch(context)

Return the torch module or raise a helpful RuntimeError.

Parameters:

context (str)

Return type:

Any

maxent_grpo.training.runtime.setup.require_transformer_base_classes(context)

Return (PreTrainedModel, PreTrainedTokenizer) with clear failure messages.

Parameters:

context (str)

Return type:

Tuple[Any, Any]
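A sketch of require_transformer_base_classes that returns the pair as a tuple. The function name and message are hypothetical.

```python
from typing import Any, Tuple


def require_transformer_base_classes_sketch(context: str) -> Tuple[Any, Any]:
    """Return (PreTrainedModel, PreTrainedTokenizer) or raise with context."""
    try:
        from transformers import PreTrainedModel, PreTrainedTokenizer
    except ImportError as exc:
        raise RuntimeError(
            f"{context} requires transformers "
            "(PreTrainedModel, PreTrainedTokenizer)."
        ) from exc
    return PreTrainedModel, PreTrainedTokenizer
```

The returned classes are convenient for isinstance checks. In transformers, PreTrainedTokenizerFast derives from PreTrainedTokenizerBase rather than PreTrainedTokenizer, so a check against PreTrainedTokenizer alone will not match fast tokenizers.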