maxent_grpo.training.weighting

Copyright 2025 Liv d’Aliberti

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

class maxent_grpo.training.weighting.ControllerMetaSettings(enabled=False, method='analytic', learning_rate=0.0, tau_learning_rate=0.0, beta_learning_rate=0.0, beta_grad_clip=0.0, update_interval=1, objective='potential', analytic_steps=1, optimizer='sgd', truncation_steps=1, use_hessian=False, last_tau_grad=0.0, last_beta_grad=0.0)

Bases: object

Meta-controller knobs governing tau/beta adaptation.

Parameters:
  • enabled (bool)

  • method (str)

  • learning_rate (float)

  • tau_learning_rate (float)

  • beta_learning_rate (float)

  • beta_grad_clip (float)

  • update_interval (int)

  • objective (str)

  • analytic_steps (int)

  • optimizer (str)

  • truncation_steps (int)

  • use_hessian (bool)

  • last_tau_grad (float)

  • last_beta_grad (float)

enabled: bool = False
method: str = 'analytic'
learning_rate: float = 0.0
tau_learning_rate: float = 0.0
beta_learning_rate: float = 0.0
beta_grad_clip: float = 0.0
update_interval: int = 1
objective: str = 'potential'
analytic_steps: int = 1
optimizer: str = 'sgd'
truncation_steps: int = 1
use_hessian: bool = False
last_tau_grad: float = 0.0
last_beta_grad: float = 0.0
to_state()

Return a serializable snapshot of the meta-controller settings.

Return type:

Dict[str, Any]

apply_state(payload)

Update the meta-controller settings from a serialized payload.

Parameters:

payload (Mapping[str, Any])

Return type:

None
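
A minimal sketch of the to_state/apply_state round trip, assuming the methods copy dataclass fields to and from a plain dict and silently ignore unknown keys. MetaSettingsSketch is a hypothetical stand-in that carries only a few of the documented fields; it is not the module's class.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Mapping


@dataclass
class MetaSettingsSketch:
    # Hypothetical stand-in holding a few of ControllerMetaSettings' fields.
    enabled: bool = False
    method: str = "analytic"
    learning_rate: float = 0.0
    update_interval: int = 1
    last_tau_grad: float = 0.0

    def to_state(self) -> Dict[str, Any]:
        # Plain dict of field values, ready for json.dumps.
        return asdict(self)

    def apply_state(self, payload: Mapping[str, Any]) -> None:
        # Overwrite known fields only; unknown keys (e.g. from another version) are ignored.
        known = {f.name for f in fields(self)}
        for key, value in payload.items():
            if key in known:
                setattr(self, key, value)


restored = MetaSettingsSketch()
restored.apply_state(MetaSettingsSketch(enabled=True, learning_rate=0.01).to_state())
```

Ignoring unknown keys keeps older checkpoints loadable after fields are added; whether the real apply_state is this lenient is an assumption.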

class maxent_grpo.training.weighting.ControllerStateSnapshot(beta, tau, tau_log, tau_entropy_ema, meta=<factory>)

Bases: object

Serializable controller state describing tau/beta parameters.

Parameters:
  • beta (float)

  • tau (float)

  • tau_log (float)

  • tau_entropy_ema (float)

  • meta (Dict[str, Any])

beta: float
tau: float
tau_log: float
tau_entropy_ema: float
meta: Dict[str, Any]
STATE_VERSION: ClassVar[int] = 1
to_dict()

Serialize the snapshot to a JSON-friendly mapping.

Return type:

Dict[str, Any]

classmethod from_weighting(weighting_cfg)

Build a controller snapshot from the active weighting settings.

Parameters:

weighting_cfg (WeightingConfigLike)

Return type:

ControllerStateSnapshot

classmethod from_dict(payload)

Instantiate a snapshot from a serialized payload.

Parameters:

payload (Mapping[str, Any])

Return type:

ControllerStateSnapshot

apply_to_weighting(weighting_cfg)

Apply the snapshot contents to a weighting configuration.

Parameters:

weighting_cfg (WeightingConfigLike)

Return type:

None
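
The snapshot's four methods can be sketched as a versioned dict round trip. SnapshotSketch and WeightingStub are hypothetical stand-ins: the attribute names on the real WeightingConfigLike, the "version" key written by to_dict, and the rejection of newer versions in from_dict are all assumptions, not documented behavior.

```python
import math
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, Mapping


@dataclass
class WeightingStub:
    # Hypothetical weighting config; the real WeightingConfigLike may differ.
    tau: float = 0.5
    beta: float = 0.04
    tau_log: float = math.log(0.5)
    tau_entropy_ema: float = 0.0


@dataclass
class SnapshotSketch:
    beta: float
    tau: float
    tau_log: float
    tau_entropy_ema: float
    meta: Dict[str, Any] = field(default_factory=dict)
    STATE_VERSION: ClassVar[int] = 1  # class constant, not a dataclass field

    def to_dict(self) -> Dict[str, Any]:
        # Tag the payload with a version so readers can reject incompatible files.
        return {
            "version": self.STATE_VERSION,
            "beta": self.beta,
            "tau": self.tau,
            "tau_log": self.tau_log,
            "tau_entropy_ema": self.tau_entropy_ema,
            "meta": dict(self.meta),
        }

    @classmethod
    def from_dict(cls, payload: Mapping[str, Any]) -> "SnapshotSketch":
        version = int(payload.get("version", cls.STATE_VERSION))
        if version > cls.STATE_VERSION:
            raise ValueError(f"unsupported controller state version {version}")
        return cls(
            beta=float(payload["beta"]),
            tau=float(payload["tau"]),
            tau_log=float(payload["tau_log"]),
            tau_entropy_ema=float(payload["tau_entropy_ema"]),
            meta=dict(payload.get("meta", {})),
        )

    @classmethod
    def from_weighting(cls, cfg: WeightingStub) -> "SnapshotSketch":
        # Read the controller scalars off the live config.
        return cls(cfg.beta, cfg.tau, cfg.tau_log, cfg.tau_entropy_ema)

    def apply_to_weighting(self, cfg: WeightingStub) -> None:
        # Write the scalars back onto the config in place.
        cfg.beta, cfg.tau = self.beta, self.tau
        cfg.tau_log, cfg.tau_entropy_ema = self.tau_log, self.tau_entropy_ema
```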

class maxent_grpo.training.weighting.KlControllerSettings(target, horizon, step_size)

Bases: object

Controller settings for KL regularization.

Parameters:
  • target (float)

  • horizon (int)

  • step_size (float)

target: float
horizon: int
step_size: float
class maxent_grpo.training.weighting.QDistributionSettings(temperature, epsilon)

Bases: object

Softmax temperature and smoothing for weighting.

Parameters:
  • temperature (float)

  • epsilon (float)

temperature: float
epsilon: float
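
One plausible reading of the two QDistributionSettings fields: temperature divides the rewards before a softmax, and epsilon mixes the result with a uniform distribution so no completion gets zero mass. Both the mixing rule and the q_distribution helper are assumptions for illustration only.

```python
import math
from typing import List


def q_distribution(rewards: List[float], temperature: float, epsilon: float) -> List[float]:
    """Hypothetical: temperature softmax over rewards, mixed with a uniform floor."""
    if not rewards:
        return []
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [r / temperature for r in rewards]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    n = len(probs)
    # Assumed smoothing: convex mix with the uniform distribution.
    return [(1.0 - epsilon) * p + epsilon / n for p in probs]
```

Lower temperatures sharpen q toward the best reward; epsilon then bounds every entry below by epsilon / n.
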
class maxent_grpo.training.weighting.TauSchedule(target_entropy, learning_rate, minimum_value, maximum_value, warmup_steps, target_entropy_start=None, target_entropy_final=None, target_entropy_horizon=0)

Bases: object

Hyperparameters controlling tau adaptation.

Parameters:
  • target_entropy (float | None)

  • learning_rate (float)

  • minimum_value (float)

  • maximum_value (float)

  • warmup_steps (int)

  • target_entropy_start (float | None)

  • target_entropy_final (float | None)

  • target_entropy_horizon (int)

target_entropy: float | None
learning_rate: float
minimum_value: float
maximum_value: float
warmup_steps: int
target_entropy_start: float | None = None
target_entropy_final: float | None = None
target_entropy_horizon: int = 0
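
The target_entropy_start, target_entropy_final, and target_entropy_horizon fields suggest an annealed entropy target. The sketch below assumes linear interpolation from start to final over target_entropy_horizon steps, falling back to the static target_entropy when the schedule is not configured; the real schedule shape is not documented here, and TauScheduleSketch is hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TauScheduleSketch:
    # Mirrors a subset of TauSchedule's fields; the interpolation rule is assumed.
    target_entropy: Optional[float]
    target_entropy_start: Optional[float] = None
    target_entropy_final: Optional[float] = None
    target_entropy_horizon: int = 0

    def target_at(self, step: int) -> Optional[float]:
        start, final = self.target_entropy_start, self.target_entropy_final
        if start is None or final is None or self.target_entropy_horizon <= 0:
            return self.target_entropy  # static target (None: tau control off)
        # Fraction of the horizon elapsed, clamped to [0, 1].
        frac = min(max(step / self.target_entropy_horizon, 0.0), 1.0)
        return start + frac * (final - start)
```
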
class maxent_grpo.training.weighting.WeightNormalizationSettings(denom, len_norm_ref)

Bases: object

Length-normalization flag and denominator scaling.

Parameters:
  • denom (float)

  • len_norm_ref (bool)

denom: float
len_norm_ref: bool
maxent_grpo.training.weighting.apply_meta_controller_update(weighting_cfg, *, tau_grad=None, beta_grad=None, lr_scale=1.0)

Apply a deterministic meta-controller update in analytic mode.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration, mutated in place.

  • tau_grad (float | None) – Gradient of the controller objective w.r.t. tau.

  • beta_grad (float | None) – Gradient of the controller objective w.r.t. beta.

  • lr_scale (float) – Optional multiplier applied to the meta learning rate.

Returns:

True when any parameter was updated.

Return type:

bool
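
A sketch of one analytic meta-controller step, assuming plain gradient descent on tau and beta, per-parameter learning rates that fall back to learning_rate when zero, symmetric clipping of the beta gradient by beta_grad_clip, and non-negativity clamps. MetaCfgSketch and meta_step are hypothetical names; the real function reads these values from WeightingSettings.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MetaCfgSketch:
    # Hypothetical flattened config; field names echo ControllerMetaSettings.
    tau: float = 0.5
    beta: float = 0.05
    learning_rate: float = 0.0
    tau_learning_rate: float = 0.0
    beta_learning_rate: float = 0.0
    beta_grad_clip: float = 0.0


def meta_step(cfg: MetaCfgSketch, *, tau_grad: Optional[float] = None,
              beta_grad: Optional[float] = None, lr_scale: float = 1.0) -> bool:
    """One assumed gradient-descent step; returns True if any parameter changed."""
    updated = False
    if tau_grad is not None:
        # Per-parameter rate, falling back to the shared rate when unset (0.0).
        lr = (cfg.tau_learning_rate or cfg.learning_rate) * lr_scale
        if lr > 0:
            cfg.tau = max(cfg.tau - lr * tau_grad, 0.0)
            updated = True
    if beta_grad is not None:
        lr = (cfg.beta_learning_rate or cfg.learning_rate) * lr_scale
        if cfg.beta_grad_clip > 0:
            # Symmetric clipping keeps one noisy gradient from swinging beta.
            beta_grad = max(-cfg.beta_grad_clip, min(beta_grad, cfg.beta_grad_clip))
        if lr > 0:
            cfg.beta = max(cfg.beta - lr * beta_grad, 0.0)
            updated = True
    return updated
```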

maxent_grpo.training.weighting.broadcast_controller_state(accelerator, weighting_cfg)

Sync controller scalars (tau, beta, entropy EMA/log) across ranks.

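Prefers an all_gather-style sync via accelerator.gather (available in Accelerate 1.x) and falls back to broadcast_object_list when that is present. Returns True on success.

Parameters:
  • accelerator – The Accelerate Accelerator (or a compatible object) used for the cross-rank sync.

  • weighting_cfg – Weighting configuration whose controller scalars are synchronized.
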
Return type:

bool

maxent_grpo.training.weighting.build_uniform_weight_stats(grouped_completions)

Return uniform weights per prompt as a GRPO-style fallback.

Parameters:

grouped_completions (List[List[str]])

Return type:

WeightStats | None
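
The uniform fallback amounts to giving each of the n completions in a group weight 1/n, as in plain GRPO. The sketch returns bare nested lists rather than a WeightStats object; uniform_weights is a hypothetical helper.

```python
from typing import List


def uniform_weights(grouped_completions: List[List[str]]) -> List[List[float]]:
    """Equal weight 1/n for each of the n completions in a prompt group."""
    # Empty groups map to empty weight lists instead of dividing by zero.
    return [[1.0 / len(group)] * len(group) if group else [] for group in grouped_completions]
```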

maxent_grpo.training.weighting.collect_weight_entropy(weights_grouped)

Summarize entropy statistics for grouped weights.

Parameters:

weights_grouped (list[list[float]]) – Weight samples grouped per prompt.

Returns:

Tuple containing (mean entropy, min entropy, max entropy, advantage samples).

Return type:

tuple[float, float, float, list[float]]
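
A sketch of the entropy summary, assuming the Shannon entropy (in nats) of each renormalized group. It returns only the (mean, min, max) triple; the fourth element that collect_weight_entropy returns (advantage samples) is not modeled. weight_entropy_stats is a hypothetical name.

```python
import math
from typing import List, Tuple


def weight_entropy_stats(weights_grouped: List[List[float]]) -> Tuple[float, float, float]:
    """Mean/min/max Shannon entropy (nats) across prompt groups."""
    entropies = []
    for weights in weights_grouped:
        total = sum(weights)
        if total <= 0:
            continue  # skip empty or all-zero groups
        # Renormalize defensively and treat 0 * log 0 as 0.
        ent = -sum((w / total) * math.log(w / total) for w in weights if w > 0)
        entropies.append(ent)
    if not entropies:
        return 0.0, 0.0, 0.0
    return sum(entropies) / len(entropies), min(entropies), max(entropies)
```

A uniform group of n completions reaches the maximum entropy log(n); a one-hot group has entropy 0.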

maxent_grpo.training.weighting.compute_weight_stats(grouped_completions, reward_comp, ref_stats, weighting_cfg)

Compute normalized weights using q-values and reference log-probs.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • reward_comp (RewardComputation) – Reward computation outputs used for q-distributions.

  • ref_stats (ReferenceLogprobs) – Reference-model log-probability statistics.

  • weighting_cfg (WeightingSettings) – Weighting configuration (tau/beta/targets).

Returns:

Weight stats dataclass or None if inputs are empty.

Return type:

WeightStats | None

maxent_grpo.training.weighting.controller_state_dict(weighting_cfg)

Return a serializable snapshot of the controller state.

Parameters:

weighting_cfg (WeightingConfigLike) – Weighting configuration containing tau/beta scalars.

Returns:

Dictionary describing controller parameters.

Return type:

dict[str, float]

maxent_grpo.training.weighting.load_controller_state(path, weighting_cfg)

Load controller parameters if a state file exists.

Parameters:
  • path (str | None) – Filesystem path to a controller JSON file.

  • weighting_cfg (WeightingConfigLike) – Weighting configuration that will receive the values.

Returns:

True when the controller state was loaded successfully.

Return type:

bool

maxent_grpo.training.weighting.maybe_update_beta(weighting_cfg, measured_kl)

Adjust beta with a simple KL controller when targets are configured.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration, mutated in place.

  • measured_kl (float) – Observed KL divergence used for feedback.

Return type:

None
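
The controller is described only as "simple". As an illustration, the sketch swaps in a multiplicative rule in the style of the adaptive KL controller of Ziegler et al. (2019), with an assumed reading of the KlControllerSettings fields: target is the KL set point, step_size clips the proportional error, and horizon spreads each correction over that many updates. Unlike maybe_update_beta, the hypothetical update_beta returns the new value instead of mutating a config.

```python
from dataclasses import dataclass


@dataclass
class KlSketch:
    # Field names follow KlControllerSettings; their roles here are assumed.
    target: float
    horizon: int
    step_size: float


def update_beta(beta: float, measured_kl: float, kl: KlSketch) -> float:
    """Multiplicative adaptive-KL update; returns the new beta."""
    if kl.target <= 0 or kl.horizon <= 0:
        return beta  # controller not configured
    # Proportional error, clipped so one outlier batch cannot move beta too far.
    error = max(-kl.step_size, min(measured_kl / kl.target - 1.0, kl.step_size))
    return beta * (1.0 + error / kl.horizon)
```

KL above target raises beta (stronger pull toward the reference); KL below target lowers it.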

maxent_grpo.training.weighting.maybe_update_tau(weighting_cfg, weight_stats, global_step, lr_scale=None)

Adjust tau to hit a target weight entropy if configured.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration, mutated in place.

  • weight_stats (WeightStats | WeightLoggingView | None) – Weight statistics for the current batch, which supply the measured entropy. Accepts either raw per-batch stats or an aggregated logging view.

  • global_step (int) – Training step used for warmup/EMA logic.

  • lr_scale (float | None) – Optional multiplicative scale applied to maxent_tau_lr (e.g., to follow the main LR scheduler).

Return type:

None
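
A sketch of entropy-targeting tau control under several assumptions: tau is adapted in log space (suggested by the tau_log field), the measured entropy is smoothed by an EMA whose decay is a made-up constant, updates are skipped before warmup_steps, and tau is clamped to [minimum_value, maximum_value]. The update sign assumes a larger tau flattens the weights and so raises their entropy. TauStateSketch and update_tau are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class TauStateSketch:
    # Hypothetical state; field names echo TauSchedule and ControllerStateSnapshot.
    tau: float = 0.5
    tau_log: float = math.log(0.5)
    tau_entropy_ema: Optional[float] = None
    target_entropy: Optional[float] = 1.0
    learning_rate: float = 0.05
    minimum_value: float = 1e-3
    maximum_value: float = 10.0
    warmup_steps: int = 0
    ema_decay: float = 0.9  # assumed constant, not a documented field


def update_tau(state: TauStateSketch, entropy: float, global_step: int,
               lr_scale: Optional[float] = None) -> None:
    """Nudge log(tau) so the smoothed weight entropy tracks the target (assumed rule)."""
    # Keep the entropy EMA current even during warmup.
    if state.tau_entropy_ema is None:
        state.tau_entropy_ema = entropy
    else:
        d = state.ema_decay
        state.tau_entropy_ema = d * state.tau_entropy_ema + (1 - d) * entropy
    if state.target_entropy is None or global_step < state.warmup_steps:
        return
    lr = state.learning_rate * (lr_scale if lr_scale is not None else 1.0)
    # Entropy below target -> raise tau (flatter weights); above target -> lower it.
    state.tau_log += lr * (state.target_entropy - state.tau_entropy_ema)
    lo, hi = math.log(state.minimum_value), math.log(state.maximum_value)
    state.tau_log = min(max(state.tau_log, lo), hi)
    state.tau = math.exp(state.tau_log)
```

Working in log space keeps tau positive without an extra clamp at zero and makes steps proportional to tau's current scale.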

maxent_grpo.training.weighting.save_controller_state(path, weighting_cfg)

Persist controller parameters to disk.

Parameters:
  • path (str | None) – Destination path for the controller JSON file.

  • weighting_cfg (WeightingConfigLike) – Weighting configuration to serialize.

Return type:

None
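
A sketch of JSON persistence for the controller scalars, assuming a flat mapping of floats. A None path disables both operations, and loading returns False when the file is missing or unreadable, mirroring the documented return values. The write goes through a temporary file and os.replace so a crash mid-write cannot leave a truncated file. save_state and load_state are hypothetical helpers, not the module's API.

```python
import json
import os
from typing import Dict, Optional


def save_state(path: Optional[str], state: Dict[str, float]) -> None:
    """Write controller scalars as JSON; a None path disables persistence."""
    if not path:
        return
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as fh:
        json.dump(state, fh)
    os.replace(tmp, path)  # atomic rename: readers never see a half-written file


def load_state(path: Optional[str], target: Dict[str, float]) -> bool:
    """Copy saved scalars for keys already in `target`; return True on success."""
    if not path or not os.path.isfile(path):
        return False
    try:
        with open(path, encoding="utf-8") as fh:
            payload = json.load(fh)
        values = {k: float(v) for k, v in payload.items() if k in target}
    except (OSError, ValueError, TypeError, AttributeError):
        return False  # unreadable, malformed, or non-numeric payload
    target.update(values)
    return True
```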

maxent_grpo.training.weighting.split_reference_logprobs(grouped_completions, ref_stats, len_norm_ref)

Slice the (optionally length-normalized) reference log-probs per prompt group.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • ref_stats (ReferenceLogprobs) – Reference log-probability statistics.

  • len_norm_ref (bool) – Whether ref_logp_sum is already length normalized.

Returns:

Reference log-probability sums aligned with each group.

Return type:

list[list[float]]

maxent_grpo.training.weighting.split_reference_token_counts(grouped_completions, ref_stats)

Slice reference token counts per prompt group.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • ref_stats (ReferenceLogprobs) – Reference log-probability statistics.

Returns:

Reference token counts grouped by prompt.

Return type:

list[list[float]]
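
Both slicing helpers reduce to cutting a flat per-completion sequence into runs whose lengths match the completion groups. The sketch assumes ReferenceLogprobs stores log-prob sums and token counts as flat sequences in prompt order; length_normalize shows the per-token average that applies when the reference sums are not already length-normalized. Both function names are hypothetical.

```python
from typing import List, Sequence


def split_by_groups(grouped_completions: List[List[str]],
                    flat_values: Sequence[float]) -> List[List[float]]:
    """Cut a flat per-completion sequence into per-prompt groups."""
    sizes = [len(group) for group in grouped_completions]
    if sum(sizes) != len(flat_values):
        raise ValueError("flat values do not match the number of completions")
    out, offset = [], 0
    for size in sizes:
        out.append([float(v) for v in flat_values[offset:offset + size]])
        offset += size
    return out


def length_normalize(logp_groups: List[List[float]],
                     count_groups: List[List[float]]) -> List[List[float]]:
    """Per-token average log-prob; counts below 1 are floored to avoid division by zero."""
    return [[lp / max(n, 1.0) for lp, n in zip(lps, ns)]
            for lps, ns in zip(logp_groups, count_groups)]
```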

maxent_grpo.training.weighting.weight_matrix_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)

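Vectorized counterpart of weight_vector_from_q: computes listwise weights for [prompts, generations] tensors in a single call.

Parameters:
  • q_values – Listwise probabilities with shape [prompts, generations].

  • logp_values – Reference log-probabilities (or log ratios), aligned with q_values.

  • token_counts – Optional completion token counts for normalization, aligned with q_values.

  • weighting_cfg – Weighting configuration containing tau/beta.

  • include_reference_term (bool) – Whether to include the reference-model factor.

  • normalize_by_tokens (bool) – Whether to scale weights by token counts.
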
Return type:

Any

maxent_grpo.training.weighting.weight_vector_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)

Convert listwise q-values and reference log-probs into normalized weights.

q_values must already be normalized probabilities. Apply any q-temperature upstream, when those targets are constructed, so that the listwise posterior computed here does not silently apply the same temperature a second time.

Optionally normalize by token counts so each token contributes equally, mitigating length bias when reference log-probabilities are length-sensitive.

Parameters:
  • q_values (list[float]) – Listwise probabilities per completion.

  • logp_values (list[float]) – Reference log-probabilities (or log ratios).

  • token_counts (list[float] | None) – Optional completion token counts for normalization.

  • weighting_cfg (WeightingSettings) – Weighting configuration containing tau/beta.

  • include_reference_term (bool) – Whether to include the reference-model factor.

  • normalize_by_tokens (bool) – Whether to scale weights by token counts.

Returns:

Normalized weights aligned with q_values.

Return type:

list[float]
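
A sketch of the listwise weighting, assuming the closed form w_i ∝ exp((log q_i + β·ref_i) / (τ + β)), where ref_i is the reference log-probability, divided by its token count when normalize_by_tokens is set. This form is the maximizer of Σ_i w_i log q_i − β·KL(w ‖ π_ref) + τ·H(w) over the probability simplex; whether weight_vector_from_q uses exactly this parameterization, and this reading of normalize_by_tokens, are assumptions. The q_values are consumed as-is, with no extra temperature. weights_from_q is a hypothetical name taking tau and beta directly instead of a config.

```python
import math
from typing import List, Optional


def weights_from_q(q_values: List[float], logp_values: List[float],
                   token_counts: Optional[List[float]], tau: float, beta: float,
                   *, include_reference_term: bool = True,
                   normalize_by_tokens: bool = True) -> List[float]:
    """Assumed form: w_i proportional to exp((log q_i + beta * ref_i) / (tau + beta))."""
    if not q_values:
        return []
    denom = tau + beta
    if denom <= 0:
        raise ValueError("tau + beta must be positive")
    logits = []
    for i, q in enumerate(q_values):
        ref = 0.0
        if include_reference_term:
            ref = logp_values[i]
            if normalize_by_tokens and token_counts is not None:
                ref /= max(token_counts[i], 1.0)  # per-token reference log-prob
        # Floor q to avoid log(0) for completions with zero target mass.
        logits.append((math.log(max(q, 1e-12)) + beta * ref) / denom)
    peak = max(logits)  # stable softmax over the listwise logits
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

With β = 0 and τ = 1 the weights reproduce q exactly; increasing τ flattens them toward uniform, which is the lever that tau control adjusts.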

Modules

logic

Weighting helpers extracted from the MaxEnt-GRPO training loop.

types

Weighting-related dataclasses shared across the MaxEnt training loop.