maxent_grpo.training.weighting
Copyright 2025 Liv d’Aliberti
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
- class maxent_grpo.training.weighting.ControllerMetaSettings(enabled=False, method='analytic', learning_rate=0.0, tau_learning_rate=0.0, beta_learning_rate=0.0, beta_grad_clip=0.0, update_interval=1, objective='potential', analytic_steps=1, optimizer='sgd', truncation_steps=1, use_hessian=False, last_tau_grad=0.0, last_beta_grad=0.0)
Bases: object
Meta-controller knobs governing tau/beta adaptation.
- Parameters:
- class maxent_grpo.training.weighting.ControllerStateSnapshot(beta, tau, tau_log, tau_entropy_ema, meta=<factory>)
Bases: object
Serializable controller state describing tau/beta parameters.
- classmethod from_weighting(weighting_cfg)
Build a controller snapshot from the active weighting settings.
- Parameters:
weighting_cfg (WeightingConfigLike)
- Return type:
- classmethod from_dict(payload)
Instantiate a snapshot from a serialized payload.
- Parameters:
- Return type:
- apply_to_weighting(weighting_cfg)
Apply the snapshot contents to a weighting configuration.
- Parameters:
weighting_cfg (WeightingConfigLike)
- Return type:
None
- class maxent_grpo.training.weighting.KlControllerSettings(target, horizon, step_size)
Bases: object
Controller settings for KL regularization.
- class maxent_grpo.training.weighting.QDistributionSettings(temperature, epsilon)
Bases: object
Softmax temperature and smoothing for weighting.
- class maxent_grpo.training.weighting.TauSchedule(target_entropy, learning_rate, minimum_value, maximum_value, warmup_steps, target_entropy_start=None, target_entropy_final=None, target_entropy_horizon=0)
Bases: object
Hyperparameters controlling tau adaptation.
- Parameters:
- class maxent_grpo.training.weighting.WeightNormalizationSettings(denom, len_norm_ref)
Bases: object
Length-normalization flag and denominator scaling.
- maxent_grpo.training.weighting.apply_meta_controller_update(weighting_cfg, *, tau_grad=None, beta_grad=None, lr_scale=1.0)
Apply a deterministic meta-controller update in analytic mode.
- Parameters:
weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.
tau_grad (float | None) – Gradient of the controller objective w.r.t. tau.
beta_grad (float | None) – Gradient of the controller objective w.r.t. beta.
lr_scale (float) – Optional multiplier applied to the meta learning rate.
- Returns:
True when any parameter was updated.
- Return type:
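The update described above can be pictured as a plain gradient step on each scalar. The following is a minimal sketch under stated assumptions, not the library's implementation: `ToyWeighting`, `toy_meta_update`, the field names, the gradient-descent sign convention, and the clamp at zero are all hypothetical stand-ins chosen for illustration.

```python
from dataclasses import dataclass


@dataclass
class ToyWeighting:
    """Hypothetical stand-in for a weighting configuration."""
    tau: float = 0.5
    beta: float = 0.1
    tau_lr: float = 0.01          # assumed per-parameter meta learning rates
    beta_lr: float = 0.01
    beta_grad_clip: float = 0.0   # 0 disables clipping in this sketch


def toy_meta_update(cfg, *, tau_grad=None, beta_grad=None, lr_scale=1.0):
    """Apply one deterministic step in place; return True if anything changed."""
    updated = False
    if tau_grad is not None and cfg.tau_lr > 0.0:
        # Assumed convention: descend the controller objective, keep tau >= 0.
        cfg.tau = max(cfg.tau - cfg.tau_lr * lr_scale * tau_grad, 0.0)
        updated = True
    if beta_grad is not None and cfg.beta_lr > 0.0:
        grad = beta_grad
        if cfg.beta_grad_clip > 0.0:
            grad = max(-cfg.beta_grad_clip, min(cfg.beta_grad_clip, grad))
        cfg.beta = max(cfg.beta - cfg.beta_lr * lr_scale * grad, 0.0)
        updated = True
    return updated
```

Calling `toy_meta_update(cfg)` with no gradients leaves the configuration untouched and returns `False`, mirroring the documented return contract.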
- maxent_grpo.training.weighting.broadcast_controller_state(accelerator, weighting_cfg)
Sync controller scalars (tau, beta, entropy EMA/log) across ranks.
Prefer an all_gather-style sync via accelerator.gather (available on Accelerate 1.x), then fall back to broadcast_object_list when present. Returns True on success.
- Parameters:
accelerator (Any)
weighting_cfg (WeightingConfigLike)
- Return type:
- maxent_grpo.training.weighting.build_uniform_weight_stats(grouped_completions)
Return uniform weights per prompt as a GRPO-style fallback.
- Parameters:
- Return type:
WeightStats | None
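A uniform fallback gives every completion in a prompt group the same weight, 1/n for a group of n. The sketch below illustrates that idea only: `uniform_weights` is a hypothetical helper that returns plain nested lists rather than the `WeightStats` object the real function produces, and treating an empty input as `None` is an assumption based on the documented return type.

```python
def uniform_weights(grouped_completions):
    """Return one weight list per prompt group, or None for empty input."""
    if not grouped_completions:
        return None
    # Each group of size n gets weights of 1/n; empty groups stay empty.
    return [
        [1.0 / len(group)] * len(group) if group else []
        for group in grouped_completions
    ]
```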
- maxent_grpo.training.weighting.collect_weight_entropy(weights_grouped)
Summarize entropy statistics for grouped weights.
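One plausible reading of "entropy statistics" is the Shannon entropy of each group's weights, reduced to a few summary values. The sketch below assumes that reading; `group_entropy`, `entropy_summary`, and the `mean`/`min`/`max` keys are hypothetical and may not match the fields the real function reports.

```python
import math


def group_entropy(weights):
    """Shannon entropy in nats; zero weights contribute nothing."""
    return -sum(w * math.log(w) for w in weights if w > 0.0)


def entropy_summary(weights_grouped):
    """Summarize per-group entropies as mean, min, and max."""
    entropies = [group_entropy(group) for group in weights_grouped if group]
    if not entropies:
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    return {
        "mean": sum(entropies) / len(entropies),
        "min": min(entropies),
        "max": max(entropies),
    }
```

A uniform group of n weights reaches the maximum entropy log(n), while a one-hot group has entropy 0, so the summary indicates how concentrated the weighting has become.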
- maxent_grpo.training.weighting.compute_weight_stats(grouped_completions, reward_comp, ref_stats, weighting_cfg)
Compute normalized weights using q-values and reference log-probs.
- Parameters:
grouped_completions (list[list[str]]) – Completion groups per prompt.
reward_comp (RewardComputation) – Reward computation outputs used for q-distributions.
ref_stats (ReferenceLogprobs) – Reference-model log-probability statistics.
weighting_cfg (WeightingSettings) – Weighting configuration (tau/beta/targets).
- Returns:
Weight stats dataclass, or None if inputs are empty.
- Return type:
WeightStats | None
- maxent_grpo.training.weighting.controller_state_dict(weighting_cfg)
Return a serializable snapshot of the controller state.
- Parameters:
weighting_cfg (WeightingConfigLike) – Weighting configuration containing tau/beta scalars.
- Returns:
Dictionary describing controller parameters.
- Return type:
- maxent_grpo.training.weighting.load_controller_state(path, weighting_cfg)
Load controller parameters if a state file exists.
- Parameters:
path (str | None) – Filesystem path to a controller JSON file.
weighting_cfg (WeightingConfigLike) – Weighting configuration that will receive the values.
- Returns:
True when the controller state was loaded successfully.
- Return type:
- maxent_grpo.training.weighting.maybe_update_beta(weighting_cfg, measured_kl)
Adjust beta with a simple KL controller when targets are configured.
- Parameters:
weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.
measured_kl (float) – Observed KL divergence used for feedback.
- Return type:
None
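A "simple KL controller" is often a proportional rule: raise beta when the measured KL exceeds the target and lower it otherwise. The sketch below shows one such rule using the three `KlControllerSettings` fields (`target`, `horizon`, `step_size`). The exact formula, the clipping of the error to [-1, 1], and the function name `toy_update_beta` are assumptions, not the library's documented behavior.

```python
def toy_update_beta(beta, measured_kl, target, horizon, step_size):
    """Return an adjusted beta; leave it unchanged when no target is set."""
    if target <= 0.0 or horizon <= 0:
        return beta  # controller treated as disabled in this sketch
    # Relative error, clipped so one noisy batch cannot swing beta too far.
    error = max(-1.0, min(1.0, measured_kl / target - 1.0))
    # A longer horizon means smaller multiplicative steps.
    return max(beta * (1.0 + step_size * error / horizon), 0.0)
```

With `target=0.1`, `horizon=10`, and `step_size=1.0`, a measured KL of 0.2 scales beta up by 10 percent, and a measured KL of 0.05 scales it down by 5 percent.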
- maxent_grpo.training.weighting.maybe_update_tau(weighting_cfg, weight_stats, global_step, lr_scale=None)
Adjust tau to hit a target weight entropy if configured.
- Parameters:
weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.
weight_stats (WeightStats | WeightLoggingView | None) – Current batch weight statistics providing entropy. Can be raw per-batch stats or aggregated logging views.
global_step (int) – Training step used for warmup/EMA logic.
lr_scale (float | None) – Optional multiplicative scale applied to maxent_tau_lr (e.g., to follow the main LR scheduler).
- Return type:
None
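Steering tau toward a target entropy can be sketched as a feedback step in log space, clamped to the schedule's bounds. This is an illustration under stated assumptions: `toy_update_tau`, the log-space update, and the skip-during-warmup behavior are hypothetical, and the real function also tracks an entropy EMA that is omitted here. The sketch assumes a larger tau flattens the weights and so raises their entropy.

```python
import math


def toy_update_tau(tau, entropy, target_entropy, lr, tau_min, tau_max,
                   global_step, warmup_steps, lr_scale=None):
    """Return an adjusted tau; no change while still in warmup."""
    if global_step < warmup_steps:
        return tau
    scale = 1.0 if lr_scale is None else lr_scale
    # Entropy below target -> increase tau; entropy above target -> decrease.
    log_tau = math.log(tau) + lr * scale * (target_entropy - entropy)
    # Clamp to the configured [minimum_value, maximum_value] range.
    return min(max(math.exp(log_tau), tau_min), tau_max)
```

Updating in log space keeps tau strictly positive and makes each step a relative change, which is one common reason to store a `tau_log` value alongside `tau`.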
- maxent_grpo.training.weighting.save_controller_state(path, weighting_cfg)
Persist controller parameters to disk.
- Parameters:
path (str | None) – Destination path for the controller JSON file.
weighting_cfg (WeightingConfigLike) – Weighting configuration to serialize.
- Return type:
None
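The save/load pair amounts to a JSON round trip of the controller scalars. The sketch below shows that pattern with only the standard library. The helper names, the two-key payload, and the use of `SimpleNamespace` as a stand-in configuration are assumptions; the real state file likely carries more fields (for example the tau log and entropy EMA).

```python
import json
import os
import tempfile
from types import SimpleNamespace


def toy_save_state(path, cfg):
    """Write tau/beta to a JSON file; do nothing when path is empty."""
    if not path:
        return
    with open(path, "w", encoding="utf-8") as fh:
        json.dump({"tau": cfg.tau, "beta": cfg.beta}, fh)


def toy_load_state(path, cfg):
    """Restore tau/beta in place; return True only if a file was loaded."""
    if not path or not os.path.isfile(path):
        return False
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)
    cfg.tau = float(payload["tau"])
    cfg.beta = float(payload["beta"])
    return True


# Usage: round-trip the scalars through a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    state_path = os.path.join(tmp, "controller.json")
    toy_save_state(state_path, SimpleNamespace(tau=0.3, beta=0.05))
    restored = SimpleNamespace(tau=1.0, beta=1.0)
    loaded = toy_load_state(state_path, restored)
```

Returning `False` for a missing path lets a training loop call the loader unconditionally at startup and fall back to configured defaults.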
- maxent_grpo.training.weighting.split_reference_logprobs(grouped_completions, ref_stats, len_norm_ref)
Slice the (optionally length-normalized) reference log-probs per prompt group.
- Parameters:
grouped_completions (list[list[str]]) – Completion groups per prompt.
ref_stats (ReferenceLogprobs) – Reference log-probability statistics.
len_norm_ref (bool) – Whether ref_logp_sum is already length normalized.
- Returns:
Reference log-probability sums aligned with each group.
- Return type:
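Slicing per prompt group can be sketched as walking a flat sequence with a running offset. The sketch assumes the reference values are stored flat, in the same order as the flattened completions; that layout, the helper name `split_by_group`, and the length check are assumptions, and the length-normalization handling controlled by `len_norm_ref` is left out.

```python
def split_by_group(grouped_completions, flat_values):
    """Cut flat_values into chunks that line up with each completion group."""
    expected = sum(len(group) for group in grouped_completions)
    if expected != len(flat_values):
        raise ValueError(
            f"expected {expected} values, got {len(flat_values)}"
        )
    chunks, offset = [], 0
    for group in grouped_completions:
        chunks.append(list(flat_values[offset:offset + len(group)]))
        offset += len(group)
    return chunks
```

The same helper would serve for token counts, since both are per-completion values aligned with the same grouping.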
- maxent_grpo.training.weighting.split_reference_token_counts(grouped_completions, ref_stats)
Slice reference token counts per prompt group.
- maxent_grpo.training.weighting.weight_matrix_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)
Vectorized listwise weights for [prompts, generations] tensors.
- maxent_grpo.training.weighting.weight_vector_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)
Convert listwise q-values and reference log-probs into normalized weights.
q_values are already normalized probabilities. Any q-temperature should be applied upstream when those targets are constructed so the listwise posterior does not silently reapply the same temperature here.
Optionally normalize by token counts so each token contributes equally, mitigating length bias when reference log-probabilities are length-sensitive.
- Parameters:
q_values (list[float]) – Listwise probabilities per completion.
logp_values (list[float]) – Reference log-probabilities (or log ratios).
token_counts (list[float] | None) – Optional completion token counts for normalization.
weighting_cfg (WeightingSettings) – Weighting configuration containing tau/beta.
include_reference_term (bool) – Whether to include the reference-model factor.
normalize_by_tokens (bool) – Whether to scale weights by token counts.
- Returns:
Normalized weights aligned with q_values.
- Return type:
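One way to combine q-values with reference log-probs is the closed-form maximizer of an entropy- and KL-regularized objective, w_i ∝ exp((log q_i + beta * logp_i) / (tau + beta)). The sketch below implements that form as an assumption; the page does not state the formula the library uses. The name `toy_weight_vector`, the 1e-12 floor on q, and the omission of token-count normalization are all choices made for illustration.

```python
import math


def toy_weight_vector(q_values, logp_values, tau, beta,
                      include_reference_term=True):
    """Return normalized listwise weights aligned with q_values."""
    denom = tau + beta
    if denom <= 0.0:
        raise ValueError("tau + beta must be positive in this sketch")
    logits = []
    for q, logp in zip(q_values, logp_values):
        logit = math.log(max(q, 1e-12))  # floor avoids log(0)
        if include_reference_term:
            logit += beta * logp
        logits.append(logit / denom)
    # Softmax with max-subtraction for numerical stability.
    peak = max(logits)
    unnormalized = [math.exp(value - peak) for value in logits]
    total = sum(unnormalized)
    return [value / total for value in unnormalized]
```

In this form, `tau=1` with `beta=0` returns the q-values unchanged, a larger tau flattens the weights toward uniform, and a larger beta pulls them toward the reference distribution.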