maxent_grpo.rewards.maxent
Copyright 2025 Liv d’Aliberti
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Re-export MaxEnt reward helpers under a dedicated namespace.
- maxent_grpo.rewards.maxent.compute_reward_statistics(gen_batch, reward_spec, device, q_temperature, q_epsilon, controller_beta=None, controller_tau=None, scale_rewards=True, zero_truncated_completion_rewards=False, max_completion_len=None, seed_grpo_enabled=False, seed_grpo_alpha=0.0417, seed_grpo_alpha_normalize_by_max_entropy=True, seed_grpo_length_normalize_logprobs=True, seed_grpo_num_generations=None)
Compute utilities, q-distributions, and flattened prompt/completion pairs.
- Parameters:
gen_batch (GenerationBatch) – Generation batch containing grouped completions/meta.
reward_spec (RewardSpec) – Reward configuration (functions + weights).
device (torch.device) – Torch device used for reward moment computations.
q_temperature (float) – Temperature used when forming q-distributions.
q_epsilon (float) – Epsilon floor ensuring full support in q-distribution.
controller_beta (float | None) – Optional KL controller beta logged with stats.
controller_tau (float | None) – Optional controller tau logged alongside q temp.
scale_rewards (bool)
zero_truncated_completion_rewards (bool)
max_completion_len (int | None)
seed_grpo_enabled (bool)
seed_grpo_alpha (float)
seed_grpo_alpha_normalize_by_max_entropy (bool)
seed_grpo_length_normalize_logprobs (bool)
seed_grpo_num_generations (int | None)
- Returns:
Populated RewardComputation, or None when inputs are empty.
- Return type:
RewardComputation | None
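A minimal sketch of how q_temperature and q_epsilon might shape a per-prompt q-distribution, assuming a softmax over utilities divided by the temperature, followed by mixing with a uniform floor so every completion keeps some mass. The helper q_distribution is hypothetical; the library's own construction may differ.

```python
import math


def q_distribution(utilities: list[float], temperature: float, epsilon: float) -> list[float]:
    """Hypothetical helper: softmax(utilities / temperature) mixed with a uniform floor."""
    if not utilities:
        return []
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [u / temperature for u in utilities]
    peak = max(scaled)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    n = len(exps)
    # Assumed epsilon floor: every completion keeps at least epsilon / n probability mass.
    return [(1.0 - epsilon) * e / total + epsilon / n for e in exps]


# One prompt group with three completions; the first has the highest utility.
print(q_distribution([1.0, 0.0, 0.0], temperature=1.0, epsilon=0.1))
```

Mixing with a uniform distribution is one simple way to guarantee full support; clamping and renormalizing is another.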
- maxent_grpo.rewards.maxent.compute_reward_totals(reward_spec, completion_batch, flat_answers)
Evaluate reward functions and aggregate per-sequence utilities.
- Parameters:
reward_spec (RewardSpec) – Reward configuration specifying callables/weights.
flat_answers (list[str]) – Flattened answer strings aligned with completions.
- Returns:
Tuple of total utilities and per-reward raw values.
- Return type:
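compute_reward_totals combines several reward callables into one utility per completion. A minimal sketch, assuming each reward function maps (completions, answers) to a list of floats and that a RewardSpec reduces to parallel lists of callables and weights; weighted_reward_totals, exact_match, and brevity are hypothetical names.

```python
from typing import Callable, Sequence

# Assumed reward-function shape: score a flat list of completions against their answers.
RewardFn = Callable[[Sequence[str], Sequence[str]], list[float]]


def weighted_reward_totals(
    reward_fns: Sequence[RewardFn],
    weights: Sequence[float],
    completions: Sequence[str],
    answers: Sequence[str],
) -> tuple[list[float], dict[str, list[float]]]:
    """Return weighted per-completion totals plus each function's raw scores."""
    if len(reward_fns) != len(weights):
        raise ValueError("each reward function needs exactly one weight")
    totals = [0.0] * len(completions)
    raw: dict[str, list[float]] = {}
    for fn, weight in zip(reward_fns, weights):
        scores = list(fn(completions, answers))
        raw[fn.__name__] = scores
        for i, score in enumerate(scores):
            totals[i] += weight * score
    return totals, raw


def exact_match(completions: Sequence[str], answers: Sequence[str]) -> list[float]:
    return [1.0 if c.strip() == a else 0.0 for c, a in zip(completions, answers)]


def brevity(completions: Sequence[str], answers: Sequence[str]) -> list[float]:
    return [-len(c) / 100 for c in completions]


totals, raw = weighted_reward_totals([exact_match, brevity], [1.0, 0.5], ["4", "five"], ["4", "5"])
```

Keeping the raw per-function scores alongside the totals makes it possible to log each reward component separately.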
- maxent_grpo.rewards.maxent.group_advantages(grouped_comps, total_utils, *, scale_rewards=True)
Return normalized advantages per prompt group and flattened samples.
- Parameters:
- Returns:
Tuple of grouped advantages and flattened advantage samples.
- Return type:
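A minimal sketch of GRPO-style group normalization, assuming each completion's advantage is its utility minus the group mean and, when scale_rewards is true, divided by the group's population standard deviation plus a small eps. The eps value and the helper name group_advantages_sketch are assumptions, not the library's constants.

```python
import statistics
from typing import Sequence


def group_advantages_sketch(
    grouped_comps: Sequence[Sequence[str]],
    total_utils: Sequence[float],
    *,
    scale_rewards: bool = True,
    eps: float = 1e-4,
) -> tuple[list[list[float]], list[float]]:
    """Center utilities per prompt group and optionally scale by the group std."""
    if sum(len(g) for g in grouped_comps) != len(total_utils):
        raise ValueError("total_utils must align with the flattened completions")
    grouped_adv: list[list[float]] = []
    offset = 0
    for comps in grouped_comps:
        utils = list(total_utils[offset : offset + len(comps)])
        offset += len(comps)
        if not utils:
            grouped_adv.append([])
            continue
        mean = statistics.fmean(utils)
        adv = [u - mean for u in utils]
        if scale_rewards:
            std = statistics.pstdev(utils)  # 0.0 for constant groups, so eps avoids division by zero
            adv = [a / (std + eps) for a in adv]
        grouped_adv.append(adv)
    flat = [a for group in grouped_adv for a in group]
    return grouped_adv, flat
```

A group whose completions all earn the same utility yields all-zero advantages, so it contributes no learning signal in this sketch.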
- maxent_grpo.rewards.maxent.prepare_generation_batch(batch, generator, generation_stats, expected_generations, max_retry_rounds=None)
Generate completions and retry prompts that initially returned nothing.
- Parameters:
batch (dict[str, list[str]]) – Mini-batch containing prompt/answer lists.
generator (GenerationFn) – Callable that produces grouped completions and metadata.
generation_stats (dict[str, int]) – Mutable statistics dictionary updated in-place.
expected_generations (int) – Desired completions per prompt.
max_retry_rounds (int | None) – Optional cap overriding the default retry limit.
- Returns:
Populated GenerationBatch, or None if generation fails after retries.
- Return type:
GenerationBatch | None
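The retry behavior can be pictured with a small sketch, assuming the generator takes a list of prompts plus a per-prompt count and returns one completion list per prompt. generate_with_retries, the default of two retry rounds, and the stats keys retry_rounds and failed_prompts are hypothetical.

```python
from typing import Callable, Optional

# Assumed generator shape: (prompts, completions_per_prompt) -> grouped completions.
GenerateFn = Callable[[list[str], int], list[list[str]]]


def generate_with_retries(
    prompts: list[str],
    generator: GenerateFn,
    stats: dict[str, int],
    expected_generations: int,
    max_retry_rounds: int = 2,
) -> Optional[list[list[str]]]:
    """Regenerate only the prompts that came back empty, up to max_retry_rounds times."""
    grouped = generator(prompts, expected_generations)
    for _ in range(max_retry_rounds):
        missing = [i for i, comps in enumerate(grouped) if not comps]
        if not missing:
            break
        stats["retry_rounds"] = stats.get("retry_rounds", 0) + 1  # stats are mutated in place
        retried = generator([prompts[i] for i in missing], expected_generations)
        for i, comps in zip(missing, retried):
            grouped[i] = comps
    failed = sum(1 for comps in grouped if not comps)
    if failed:
        stats["failed_prompts"] = stats.get("failed_prompts", 0) + failed
        return None
    return grouped


# Usage: a generator that returns nothing for "b" on its first call only.
_calls = {"n": 0}


def flaky_generator(batch: list[str], n: int) -> list[list[str]]:
    _calls["n"] += 1
    return [[] if (p == "b" and _calls["n"] == 1) else [p + "!"] * n for p in batch]


run_stats: dict[str, int] = {}
result = generate_with_retries(["a", "b"], flaky_generator, run_stats, expected_generations=2)
```

Retrying only the empty prompts avoids regenerating completions that already succeeded.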
- maxent_grpo.rewards.maxent.reward_moments(total_utils, device)
Compute reward mean/std on CPU or current accelerator device.
- class maxent_grpo.rewards.maxent.WeightLoggingView(entropy=0.0, entropy_norm=0.0, entropy_min=0.0, entropy_max=0.0, advantage_entropy_mean=0.0, advantage_entropy_std=0.0)
Bases: object
Aggregated entropy statistics for logging.
- Parameters:
- class maxent_grpo.rewards.maxent.WeightStats(weights_grouped, flat_weights, weight_entropy, weight_entropy_min, weight_entropy_max, advantage_entropy)
Bases: object
Weights per completion and entropy diagnostics.
- Parameters:
- class maxent_grpo.rewards.maxent.WeightingConfigLike(*args, **kwargs)
Bases: Protocol
Protocol for objects that carry controller weighting scalars.
beta and tau are required. Optional attributes such as denom or train_grpo_objective may be present and are accessed via getattr.
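A minimal sketch of the structural-typing pattern this protocol describes: beta and tau are declared, while optional attributes are read defensively with getattr. HasControllerScalars, MinimalWeighting, and describe_controller are hypothetical stand-ins, not the library's classes.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasControllerScalars(Protocol):
    """Hypothetical mirror of the required attributes of WeightingConfigLike."""

    beta: float
    tau: float


@dataclass
class MinimalWeighting:
    beta: float
    tau: float


def describe_controller(cfg: HasControllerScalars) -> str:
    # Optional attribute: absent on MinimalWeighting, so getattr supplies a default.
    grpo = bool(getattr(cfg, "train_grpo_objective", False))
    return f"tau={cfg.tau} beta={cfg.beta} grpo={grpo}"
```

Any object exposing beta and tau satisfies the protocol; no inheritance is required.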
- class maxent_grpo.rewards.maxent.WeightingSettings(tau, beta, normalization, q_distribution, tau_schedule, kl_controller, train_grpo_objective, scale_rewards=True, controller_meta=<factory>, controller_state=None, allow_empty_weight_fallback=False)
Bases: object
Sequence weighting hyperparameters with convenience accessors.
- Parameters:
tau (float)
beta (float)
normalization (WeightNormalizationSettings)
q_distribution (QDistributionSettings)
tau_schedule (TauSchedule)
kl_controller (KlControllerSettings)
train_grpo_objective (bool)
scale_rewards (bool)
controller_meta (ControllerMetaSettings)
controller_state (TorchControllerState | None)
allow_empty_weight_fallback (bool)
- normalization: WeightNormalizationSettings
- q_distribution: QDistributionSettings
- tau_schedule: TauSchedule
- kl_controller: KlControllerSettings
- controller_meta: ControllerMetaSettings
- controller_state: TorchControllerState | None = None
- property denom: float
Return the denominator used for weight normalization.
- Returns:
Normalization denominator applied to weights.
- Return type:
float
- property len_norm_ref: bool
Return whether reference log-probs are length-normalized.
- Returns:
True when reference stats are length-normalized.
- Return type:
bool
- property q_temperature: float
Return the q-distribution temperature.
- Returns:
Temperature applied to the q-distribution softmax.
- Return type:
float
- property q_epsilon: float
Return the epsilon smoothing factor.
- Returns:
Epsilon smoothing applied to the q-distribution.
- Return type:
float
- property tau_target_entropy: float | None
Return the target weight entropy.
- Returns:
Desired entropy target (None to disable adaptation).
- Return type:
float | None
- property tau_lr: float
Return the learning rate for tau adaptation.
- Returns:
Scalar learning rate for tau updates.
- Return type:
float
- property tau_min: float
Return the minimum tau value.
- Returns:
Lower bound applied to tau.
- Return type:
float
- property tau_max: float
Return the maximum tau value.
- Returns:
Upper bound applied to tau.
- Return type:
float
- property tau_warmup_steps: int
Return the tau warmup horizon.
- Returns:
Number of steps used to warm up tau updates.
- Return type:
int
- property kl_target: float
Return the KL target.
- Returns:
Desired KL divergence target.
- Return type:
float
- maxent_grpo.rewards.maxent.apply_meta_controller_update(weighting_cfg, *, tau_grad=None, beta_grad=None, lr_scale=1.0)
Apply a deterministic meta-controller update in analytic mode.
- Parameters:
weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.
tau_grad (float | None) – Gradient of the controller objective w.r.t. tau.
beta_grad (float | None) – Gradient of the controller objective w.r.t. beta.
lr_scale (float) – Optional multiplier applied to the meta learning rate.
- Returns:
True when any parameter was updated.
- Return type:
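One way to read "deterministic meta-controller update" is a plain gradient-descent step on tau and beta. A minimal sketch under that assumption; ControllerScalars, its meta_lr and tau bounds, and the non-negativity clamp on beta are hypothetical choices, not the library's settings.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ControllerScalars:
    """Hypothetical container; the real WeightingSettings carries many more fields."""

    tau: float
    beta: float
    meta_lr: float = 0.01
    tau_min: float = 1e-3
    tau_max: float = 10.0


def meta_controller_step(
    cfg: ControllerScalars,
    *,
    tau_grad: Optional[float] = None,
    beta_grad: Optional[float] = None,
    lr_scale: float = 1.0,
) -> bool:
    """Take one descent step on the supplied gradients; return True if any was applied."""
    lr = cfg.meta_lr * lr_scale
    if lr <= 0:
        return False
    updated = False
    if tau_grad is not None:
        cfg.tau = min(max(cfg.tau - lr * tau_grad, cfg.tau_min), cfg.tau_max)
        updated = True
    if beta_grad is not None:
        cfg.beta = max(cfg.beta - lr * beta_grad, 0.0)  # keep the KL weight non-negative
        updated = True
    return updated
```

The config is mutated in place, matching the "mutated in-place" note on weighting_cfg above.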
- maxent_grpo.rewards.maxent.broadcast_controller_state(accelerator, weighting_cfg)
Sync controller scalars (tau, beta, entropy EMA/log) across ranks.
Prefers an all_gather-style sync via accelerator.gather (available in Accelerate 1.x), then falls back to broadcast_object_list when present. Returns True on success.
- Parameters:
accelerator (Any)
weighting_cfg (WeightingConfigLike)
- Return type:
- maxent_grpo.rewards.maxent.build_uniform_weight_stats(grouped_completions)
Return uniform weights per prompt as a GRPO-style fallback.
- Parameters:
- Return type:
WeightStats | None
- maxent_grpo.rewards.maxent.collect_weight_entropy(weights_grouped)
Summarize entropy statistics for grouped weights.
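A minimal sketch of the kind of summary WeightLoggingView holds, assuming Shannon entropy per prompt group, normalization by log(group size), and min/max taken across groups. weight_entropy_summary is hypothetical. Under these assumptions, uniform weights such as those from build_uniform_weight_stats reach the maximum normalized entropy of 1.0.

```python
import math


def weight_entropy_summary(weights_grouped: list[list[float]]) -> dict[str, float]:
    """Mean, normalized mean, min, and max Shannon entropy over prompt groups."""
    entropies: list[float] = []
    normalized: list[float] = []
    for weights in weights_grouped:
        if not weights:
            continue
        # Zero weights contribute nothing (0 * log 0 is taken as 0); clamp avoids -0.0.
        h = max(0.0, -sum(w * math.log(w) for w in weights if w > 0))
        entropies.append(h)
        n = len(weights)
        normalized.append(h / math.log(n) if n > 1 else 0.0)
    if not entropies:
        return {"entropy": 0.0, "entropy_norm": 0.0, "entropy_min": 0.0, "entropy_max": 0.0}
    return {
        "entropy": sum(entropies) / len(entropies),
        "entropy_norm": sum(normalized) / len(normalized),
        "entropy_min": min(entropies),
        "entropy_max": max(entropies),
    }


print(weight_entropy_summary([[0.25] * 4, [1.0, 0.0]]))
```

Normalizing by log(n) makes groups of different sizes comparable on a 0 to 1 scale.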
- maxent_grpo.rewards.maxent.compute_weight_stats(grouped_completions, reward_comp, ref_stats, weighting_cfg)
Compute normalized weights using q-values and reference log-probs.
- Parameters:
grouped_completions (list[list[str]]) – Completion groups per prompt.
reward_comp (RewardComputation) – Reward computation outputs used for q-distributions.
ref_stats (ReferenceLogprobs) – Reference-model log-probability statistics.
weighting_cfg (WeightingSettings) – Weighting configuration (tau/beta/targets).
- Returns:
Weight stats dataclass, or None if inputs are empty.
- Return type:
WeightStats | None
- maxent_grpo.rewards.maxent.controller_state_dict(weighting_cfg)
Return a serializable snapshot of the controller state.
- Parameters:
weighting_cfg (WeightingConfigLike) – Weighting configuration containing tau/beta scalars.
- Returns:
Dictionary describing controller parameters.
- Return type:
- maxent_grpo.rewards.maxent.load_controller_state(path, weighting_cfg)
Load controller parameters if a state file exists.
- Parameters:
path (str | None) – Filesystem path to a controller JSON file.
weighting_cfg (WeightingConfigLike) – Weighting configuration that will receive the values.
- Returns:
True when the controller state was loaded successfully.
- Return type:
- maxent_grpo.rewards.maxent.maybe_update_beta(weighting_cfg, measured_kl)
Adjust beta with a simple KL controller when targets are configured.
- Parameters:
weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.
measured_kl (float) – Observed KL divergence used for feedback.
- Return type:
None
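The docstring says only "a simple KL controller". A minimal sketch of one common rule, assuming a clipped relative error and a multiplicative update in the style of adaptive KL controllers from RLHF work; kl_controlled_beta and its horizon, n_steps, and clip parameters are hypothetical, and the library's actual rule may differ.

```python
def kl_controlled_beta(
    beta: float,
    measured_kl: float,
    kl_target: float,
    *,
    horizon: int = 10_000,
    n_steps: int = 1,
    clip: float = 0.2,
) -> float:
    """Raise beta when KL overshoots the target and lower it when KL undershoots."""
    if kl_target <= 0:
        return beta  # assumed convention: no positive target means the controller is off
    # Relative error, clipped so one noisy KL estimate cannot move beta too far.
    error = max(-clip, min(clip, measured_kl / kl_target - 1.0))
    return beta * (1.0 + error * n_steps / horizon)
```

A multiplicative update keeps beta positive as long as clip * n_steps / horizon stays below 1.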
- maxent_grpo.rewards.maxent.maybe_update_tau(weighting_cfg, weight_stats, global_step, lr_scale=None)
Adjust tau to hit a target weight entropy if configured.
- Parameters:
weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.
weight_stats (WeightStats | WeightLoggingView | None) – Current batch weight statistics providing entropy. Can be raw per-batch stats or aggregated logging views.
global_step (int) – Training step used for warmup/EMA logic.
lr_scale (float | None) – Optional multiplicative scale applied to maxent_tau_lr (e.g., to follow the main LR scheduler).
- Return type:
None
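A minimal sketch of entropy-targeted tau adaptation, assuming a larger tau yields flatter weights (higher entropy), so tau moves in the direction of target minus measured entropy. The linear warmup ramp and the helper entropy_targeted_tau are assumptions; the parameters loosely mirror the WeightingSettings accessors tau_lr, tau_min, tau_max, and tau_warmup_steps.

```python
from typing import Optional


def entropy_targeted_tau(
    tau: float,
    entropy: float,
    target: Optional[float],
    *,
    global_step: int,
    lr: float,
    tau_min: float,
    tau_max: float,
    warmup_steps: int = 0,
    lr_scale: Optional[float] = None,
) -> float:
    """Nudge tau toward the value that makes weight entropy match the target."""
    if target is None:
        return tau  # adaptation disabled
    scale = 1.0 if lr_scale is None else lr_scale
    if warmup_steps > 0:
        scale *= min(1.0, global_step / warmup_steps)  # assumed linear warmup ramp
    new_tau = tau + lr * scale * (target - entropy)
    return min(max(new_tau, tau_min), tau_max)
```

Clamping to [tau_min, tau_max] keeps a run of low-entropy batches from driving tau without bound.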
- maxent_grpo.rewards.maxent.save_controller_state(path, weighting_cfg)
Persist controller parameters to disk.
- Parameters:
path (str | None) – Destination path for the controller JSON file.
weighting_cfg (WeightingConfigLike) – Weighting configuration to serialize.
- Return type:
None
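A minimal sketch of the JSON round trip implied by controller_state_dict, save_controller_state, and load_controller_state, assuming the file stores only tau and beta. The real snapshot may include further controller fields (broadcast_controller_state mentions an entropy EMA/log); the *_sketch helpers are hypothetical.

```python
import json
import os
import tempfile
from types import SimpleNamespace
from typing import Any, Optional


def state_dict_sketch(cfg: Any) -> dict[str, float]:
    """Serializable snapshot of the two required controller scalars."""
    return {"tau": float(cfg.tau), "beta": float(cfg.beta)}


def save_state_sketch(path: Optional[str], cfg: Any) -> None:
    if not path:
        return  # nothing to do when no destination is configured
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(state_dict_sketch(cfg), fh)


def load_state_sketch(path: Optional[str], cfg: Any) -> bool:
    if not path or not os.path.isfile(path):
        return False
    try:
        with open(path, encoding="utf-8") as fh:
            state = json.load(fh)
        tau, beta = float(state["tau"]), float(state["beta"])
    except (OSError, ValueError, KeyError, TypeError):
        return False  # unreadable or malformed file: leave cfg untouched
    cfg.tau, cfg.beta = tau, beta
    return True


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "controller.json")
        save_state_sketch(path, SimpleNamespace(tau=0.3, beta=0.05))
        restored = SimpleNamespace(tau=1.0, beta=1.0)
        print(load_state_sketch(path, restored), restored)
```

Parsing both values before assigning either one means a malformed file never leaves the config half-updated.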
- maxent_grpo.rewards.maxent.split_reference_logprobs(grouped_completions, ref_stats, len_norm_ref)
Slice the (optionally length-normalized) reference log-probs per prompt group.
- Parameters:
grouped_completions (list[list[str]]) – Completion groups per prompt.
ref_stats (ReferenceLogprobs) – Reference log-probability statistics.
len_norm_ref (bool) – Whether ref_logp_sum is already length-normalized.
- Returns:
Reference log-probability sums aligned with each group.
- Return type:
- maxent_grpo.rewards.maxent.split_reference_token_counts(grouped_completions, ref_stats)
Slice reference token counts per prompt group.
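Both split_reference_logprobs and split_reference_token_counts slice a flat per-completion sequence by the size of each prompt group. A minimal sketch, assuming the flat values are ordered prompt by prompt; split_by_groups is a hypothetical name, and any length normalization is left to the caller.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def split_by_groups(
    grouped_completions: Sequence[Sequence[str]], flat_values: Sequence[T]
) -> list[list[T]]:
    """Slice flat per-completion values into one list per prompt group."""
    sizes = [len(group) for group in grouped_completions]
    if sum(sizes) != len(flat_values):
        raise ValueError("flat_values must have one entry per completion")
    out: list[list[T]] = []
    offset = 0
    for size in sizes:
        out.append(list(flat_values[offset : offset + size]))
        offset += size
    return out
```

Checking the total length up front catches misaligned reference statistics before they silently shift values between prompts.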
- maxent_grpo.rewards.maxent.weight_matrix_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)
Vectorized listwise weights for [prompts, generations] tensors.
- maxent_grpo.rewards.maxent.weight_vector_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)
Convert listwise q-values and reference log-probs into normalized weights.
q_values are already normalized probabilities. Any q-temperature should be applied upstream when those targets are constructed, so the listwise posterior does not silently reapply the same temperature here. Optionally normalize by token counts so each token contributes equally, mitigating length bias when reference log-probabilities are length-sensitive.
- Parameters:
q_values (list[float]) – Listwise probabilities per completion.
logp_values (list[float]) – Reference log-probabilities (or log ratios).
token_counts (list[float] | None) – Optional completion token counts for normalization.
weighting_cfg (WeightingSettings) – Weighting configuration containing tau/beta.
include_reference_term (bool) – Whether to include the reference-model factor.
normalize_by_tokens (bool) – Whether to scale weights by token counts.
- Returns:
Normalized weights aligned with q_values.
- Return type:
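The docstring does not state the weighting formula. A minimal sketch, assuming the weights solve max over w of sum_i w_i log q_i + tau H(w) - beta KL(w || pi_ref) on the group simplex, whose closed form is w_i proportional to exp((log q_i + beta log pi_ref,i) / (tau + beta)). Reading tau + beta as the denom accessor, and dividing reference log-probs by token counts when normalize_by_tokens is set, are also assumptions; maxent_weights is hypothetical.

```python
import math
from typing import Optional, Sequence


def maxent_weights(
    q_values: Sequence[float],
    logp_values: Sequence[float],
    token_counts: Optional[Sequence[float]],
    *,
    tau: float,
    beta: float,
    include_reference_term: bool = True,
    normalize_by_tokens: bool = True,
) -> list[float]:
    """Weights proportional to exp((log q + beta * log p_ref) / (tau + beta))."""
    if not q_values:
        return []
    denom = tau + beta  # assumed to play the role of the denom accessor
    if denom <= 0:
        raise ValueError("tau + beta must be positive")
    logits = []
    for i, q in enumerate(q_values):
        score = math.log(max(q, 1e-12))  # q is already a probability; no extra temperature
        if include_reference_term:
            ref = logp_values[i]
            if normalize_by_tokens and token_counts is not None:
                ref /= max(token_counts[i], 1.0)  # per-token log-prob to reduce length bias
            score += beta * ref
        logits.append(score / denom)
    peak = max(logits)  # stable softmax over the group
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

With beta = 0 the weights reduce to q raised to 1/tau and renormalized, so tau alone controls how sharply the group concentrates on high-q completions.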