maxent_grpo.rewards.maxent

Copyright 2025 Liv d’Aliberti

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Re-export MaxEnt reward helpers under a dedicated namespace.

maxent_grpo.rewards.maxent.compute_reward_statistics(gen_batch, reward_spec, device, q_temperature, q_epsilon, controller_beta=None, controller_tau=None, scale_rewards=True, zero_truncated_completion_rewards=False, max_completion_len=None, seed_grpo_enabled=False, seed_grpo_alpha=0.0417, seed_grpo_alpha_normalize_by_max_entropy=True, seed_grpo_length_normalize_logprobs=True, seed_grpo_num_generations=None)[source]

Compute utilities, q-distributions, and flattened prompt/completion pairs.

Parameters:
  • gen_batch (GenerationBatch) – Generation batch containing grouped completions/meta.

  • reward_spec (RewardSpec) – Reward configuration (functions + weights).

  • device (torch.device) – Torch device used for reward moment computations.

  • q_temperature (float) – Temperature used when forming q-distributions.

  • q_epsilon (float) – Epsilon floor ensuring full support in q-distribution.

  • controller_beta (float | None) – Optional KL controller beta logged with stats.

  • controller_tau (float | None) – Optional controller tau logged alongside q temp.

  • scale_rewards (bool)

  • zero_truncated_completion_rewards (bool)

  • max_completion_len (int | None)

  • seed_grpo_enabled (bool)

  • seed_grpo_alpha (float)

  • seed_grpo_alpha_normalize_by_max_entropy (bool)

  • seed_grpo_length_normalize_logprobs (bool)

  • seed_grpo_num_generations (int | None)

Returns:

Populated RewardComputation or None when inputs are empty.

Return type:

RewardComputation | None
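The role of q_temperature and q_epsilon can be illustrated with a minimal sketch for a single prompt group. The helper name q_distribution and the choice to mix with a uniform distribution (rather than clamp and renormalize) are assumptions made for illustration; this page does not show how the library applies the floor.

```python
import math


def q_distribution(utilities, temperature, epsilon):
    """Temperature softmax over one prompt group's utilities with an epsilon floor.

    Hypothetical helper: the uniform-mixing floor is an assumption, not the
    library's confirmed behavior.
    """
    scaled = [u / temperature for u in utilities]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    n = len(probs)
    # Mixing with the uniform distribution gives every completion
    # at least epsilon / n probability mass (full support).
    return [(1.0 - epsilon) * p + epsilon / n for p in probs]


q = q_distribution([1.0, 0.0, 0.0, -1.0], temperature=0.5, epsilon=1e-3)
```

Lower temperatures concentrate q on the highest-utility completion, while the floor keeps low-utility completions from receiving exactly zero mass.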

maxent_grpo.rewards.maxent.compute_reward_totals(reward_spec, completion_batch, flat_answers)[source]

Evaluate reward functions and aggregate per-sequence utilities.

Parameters:
  • reward_spec (RewardSpec) – Reward configuration specifying callables/weights.

  • completion_batch (list[str]) – Flattened completion texts.

  • flat_answers (list[str]) – Flattened answer strings aligned with completions.

Returns:

Tuple of total utilities and per-reward raw values.

Return type:

tuple[list[float], dict[str, list[float]]]
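The aggregation described above might look like the following sketch. The layout of RewardSpec is not documented on this page, so a plain dict mapping a reward name to a (callable, weight) pair stands in for it; reward_totals, exact_match, and brevity are hypothetical names.

```python
def reward_totals(reward_fns, completions, answers):
    """Return (weighted totals, raw per-reward values).

    reward_fns is an assumed stand-in for RewardSpec:
    {name: (callable(completions, answers) -> list[float], weight)}.
    """
    per_reward = {}
    totals = [0.0] * len(completions)
    for name, (fn, weight) in reward_fns.items():
        values = [float(v) for v in fn(completions, answers)]
        per_reward[name] = values  # keep raw values for logging
        totals = [t + weight * v for t, v in zip(totals, values)]
    return totals, per_reward


def exact_match(completions, answers):
    return [1.0 if c.strip() == a.strip() else 0.0
            for c, a in zip(completions, answers)]


def brevity(completions, answers):
    return [1.0 / (1 + len(c.split())) for c in completions]


totals, per_reward = reward_totals(
    {"exact": (exact_match, 1.0), "brevity": (brevity, 0.5)},
    completions=["4", "it is 5"],
    answers=["4", "4"],
)
```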

maxent_grpo.rewards.maxent.group_advantages(grouped_comps, total_utils, *, scale_rewards=True)[source]

Return normalized advantages per prompt group and flattened samples.

Parameters:
  • grouped_comps (list[list[str]]) – Completions grouped by prompt.

  • total_utils (list[float]) – Flattened utilities aligned with completions.

  • scale_rewards (bool) – Whether to divide the mean-centered utilities by the group standard deviation (the TRL default).

Returns:

Tuple of grouped advantages and flattened advantage samples.

Return type:

tuple[list[list[float]], list[float]]
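A minimal pure-Python sketch of per-group normalization follows. The population standard deviation and the eps stabilizer are assumptions; the library may use a different estimator or constant.

```python
import statistics


def group_advantages_sketch(grouped_comps, total_utils, scale_rewards=True, eps=1e-4):
    """Center (and optionally scale) utilities within each prompt group."""
    grouped, flat = [], []
    offset = 0
    for comps in grouped_comps:
        # total_utils is flat, so slice out this group's utilities.
        utils = total_utils[offset:offset + len(comps)]
        offset += len(comps)
        if not utils:
            grouped.append([])
            continue
        mean = statistics.fmean(utils)
        adv = [u - mean for u in utils]
        if scale_rewards:
            std = statistics.pstdev(utils)  # assumption: population std
            adv = [a / (std + eps) for a in adv]
        grouped.append(adv)
        flat.extend(adv)
    return grouped, flat


grouped, flat = group_advantages_sketch(
    [["a", "b"], ["c", "d"]], [1.0, 0.0, 2.0, 2.0]
)
```

A group whose completions all receive the same utility (the second group here) yields zero advantages and therefore contributes no learning signal.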

maxent_grpo.rewards.maxent.prepare_generation_batch(batch, generator, generation_stats, expected_generations, max_retry_rounds=None)[source]

Generate completions and retry prompts that initially returned nothing.

Parameters:
  • batch (dict[str, list[str]]) – Mini-batch containing prompt/answer lists.

  • generator (GenerationFn) – Callable that produces grouped completions and metadata.

  • generation_stats (dict[str, int]) – Mutable statistics dictionary updated in-place.

  • expected_generations (int) – Desired completions per prompt.

  • max_retry_rounds (int | None) – Optional cap overriding the default retry limit.

Returns:

Populated GenerationBatch or None if generation fails after retries.

Return type:

GenerationBatch | None
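The retry behavior can be sketched as below. The generator signature (prompts, n) -> grouped completions and the "retry_rounds" statistics key are hypothetical; the real GenerationFn and GenerationBatch types are not shown on this page.

```python
def generate_with_retries(prompts, generator, expected_generations, stats,
                          max_retry_rounds=2):
    """Regenerate only the prompts whose completion group came back empty."""
    grouped = generator(prompts, expected_generations)
    for _ in range(max_retry_rounds):
        missing = [i for i, group in enumerate(grouped) if not group]
        if not missing:
            break
        stats["retry_rounds"] = stats.get("retry_rounds", 0) + 1  # in-place update
        retried = generator([prompts[i] for i in missing], expected_generations)
        for i, group in zip(missing, retried):
            grouped[i] = group
    if any(not group for group in grouped):
        return None  # still empty after all retries
    return grouped


calls = {"n": 0}


def flaky_generator(prompts, n):
    """Toy generator that returns nothing for prompt 'b' on its first call."""
    calls["n"] += 1
    return [
        [] if (p == "b" and calls["n"] == 1) else [f"{p}-{k}" for k in range(n)]
        for p in prompts
    ]


stats = {}
result = generate_with_retries(["a", "b"], flaky_generator, 2, stats)
failed = generate_with_retries(["a"], lambda ps, n: [[] for _ in ps], 2, {})
```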

maxent_grpo.rewards.maxent.reward_moments(total_utils, device)[source]

Compute reward mean/std on CPU or current accelerator device.

Parameters:
  • total_utils (list[float]) – Flattened reward totals per completion.

  • device (torch.device) – Device used for tensor computations.

Returns:

Tuple containing (mean, std) rewards.

Return type:

tuple[float, float]

class maxent_grpo.rewards.maxent.WeightLoggingView(entropy=0.0, entropy_norm=0.0, entropy_min=0.0, entropy_max=0.0, advantage_entropy_mean=0.0, advantage_entropy_std=0.0)[source]

Bases: object

Aggregated entropy statistics for logging.

Parameters:
entropy: float = 0.0
entropy_norm: float = 0.0
entropy_min: float = 0.0
entropy_max: float = 0.0
advantage_entropy_mean: float = 0.0
advantage_entropy_std: float = 0.0

class maxent_grpo.rewards.maxent.WeightStats(weights_grouped, flat_weights, weight_entropy, weight_entropy_min, weight_entropy_max, advantage_entropy)[source]

Bases: object

Weights per completion and entropy diagnostics.

Parameters:
weights_grouped: List[List[float]]
flat_weights: List[float]
weight_entropy: float
weight_entropy_min: float
weight_entropy_max: float
advantage_entropy: List[float]

class maxent_grpo.rewards.maxent.WeightingConfigLike(*args, **kwargs)[source]

Bases: Protocol

Protocol for objects that carry controller weighting scalars.

beta and tau are required. Optional attributes such as denom or train_grpo_objective may be present and are accessed via getattr.
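The pattern of required attributes plus getattr-guarded optional ones can be sketched as follows. ConfigLike, MinimalCfg, RichCfg, and describe are hypothetical names, and the False default for a missing train_grpo_objective is an assumption.

```python
from dataclasses import dataclass
from typing import Protocol


class ConfigLike(Protocol):
    """Structural type: any object exposing float tau and beta qualifies."""
    beta: float
    tau: float


@dataclass
class MinimalCfg:
    tau: float
    beta: float


@dataclass
class RichCfg(MinimalCfg):
    train_grpo_objective: bool = True


def describe(cfg: ConfigLike) -> dict:
    return {
        "tau": cfg.tau,    # required attribute
        "beta": cfg.beta,  # required attribute
        # Optional attribute: read defensively; the default is an assumption.
        "train_grpo_objective": getattr(cfg, "train_grpo_objective", False),
    }


minimal = describe(MinimalCfg(tau=0.5, beta=0.04))
rich = describe(RichCfg(tau=0.5, beta=0.04))
```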

beta: float
tau: float

class maxent_grpo.rewards.maxent.WeightingSettings(tau, beta, normalization, q_distribution, tau_schedule, kl_controller, train_grpo_objective, scale_rewards=True, controller_meta=<factory>, controller_state=None, allow_empty_weight_fallback=False)[source]

Bases: object

Sequence weighting hyperparameters with convenience accessors.

Parameters:
tau: float
beta: float
normalization: WeightNormalizationSettings
q_distribution: QDistributionSettings
tau_schedule: TauSchedule
kl_controller: KlControllerSettings
train_grpo_objective: bool
scale_rewards: bool = True
controller_meta: ControllerMetaSettings
controller_state: TorchControllerState | None = None
allow_empty_weight_fallback: bool = False

property denom: float

Return the denominator used for weight normalization.

Returns:

Normalization denominator applied to weights.

Return type:

float

property len_norm_ref: bool

Return whether reference log-probs are length-normalized.

Returns:

True when reference stats are length-normalized.

Return type:

bool

property q_temperature: float

Return the q-distribution temperature.

Returns:

Temperature applied to the q-distribution softmax.

Return type:

float

property q_epsilon: float

Return the epsilon smoothing factor.

Returns:

Epsilon smoothing applied to the q-distribution.

Return type:

float

property tau_target_entropy: float | None

Return the target weight entropy.

Returns:

Desired entropy target (None to disable adaptation).

Return type:

float | None

property tau_lr: float

Return the learning rate for tau adaptation.

Returns:

Scalar learning rate for tau updates.

Return type:

float

property tau_min: float

Return the minimum tau value.

Returns:

Lower bound applied to tau.

Return type:

float

property tau_max: float

Return the maximum tau value.

Returns:

Upper bound applied to tau.

Return type:

float

property tau_warmup_steps: int

Return the tau warmup horizon.

Returns:

Number of steps used to warm up tau updates.

Return type:

int

property kl_target: float

Return the KL target.

Returns:

Desired KL divergence target.

Return type:

float

property kl_horizon: int

Return the KL controller horizon.

Returns:

Number of steps used for the KL controller horizon.

Return type:

int

property kl_ctl_step_size: float

Return the KL controller step size.

Returns:

Step size multiplier used by the KL controller.

Return type:

float

maxent_grpo.rewards.maxent.apply_meta_controller_update(weighting_cfg, *, tau_grad=None, beta_grad=None, lr_scale=1.0)[source]

Apply a deterministic meta-controller update in analytic mode.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.

  • tau_grad (float | None) – Gradient of the controller objective w.r.t. tau.

  • beta_grad (float | None) – Gradient of the controller objective w.r.t. beta.

  • lr_scale (float) – Optional multiplier applied to the meta learning rate.

Returns:

True when any parameter was updated.

Return type:

bool

maxent_grpo.rewards.maxent.broadcast_controller_state(accelerator, weighting_cfg)[source]

Sync controller scalars (tau, beta, entropy EMA/log) across ranks.

Prefers an all_gather-style sync via accelerator.gather (available in Accelerate 1.x) and falls back to broadcast_object_list when that is available. Returns True on success.

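Parameters:
  • accelerator – Accelerate accelerator used to sync across ranks.

  • weighting_cfg – Weighting configuration whose controller scalars are synced.

Return type: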

bool

maxent_grpo.rewards.maxent.build_uniform_weight_stats(grouped_completions)[source]

Return uniform weights per prompt as a GRPO-style fallback.

Parameters:

grouped_completions (List[List[str]]) – Completion groups per prompt.

Return type:

WeightStats | None
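The uniform fallback assigns each completion in a group of size n the weight 1/n. A minimal sketch that returns only the grouped weights (the real function returns a WeightStats with entropy diagnostics; uniform_weights is a hypothetical name):

```python
def uniform_weights(grouped_completions):
    """Give every completion in a group of size n the weight 1/n."""
    if not any(grouped_completions):
        return None  # nothing to weight
    return [
        [1.0 / len(group)] * len(group) if group else []
        for group in grouped_completions
    ]


weights = uniform_weights([["a", "b", "c", "d"], ["e", "f"]])
empty = uniform_weights([])
```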

maxent_grpo.rewards.maxent.collect_weight_entropy(weights_grouped)[source]

Summarize entropy statistics for grouped weights.

Parameters:

weights_grouped (list[list[float]]) – Weight samples grouped per prompt.

Returns:

Tuple containing (mean entropy, min entropy, max entropy, advantage samples).

Return type:

tuple[float, float, float, list[float]]
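As an illustration, the first three tuple entries can be computed from the Shannon entropy of each group's weights. This sketch assumes natural-log entropy and omits the fourth entry, whose exact definition is not given here; group_entropy and entropy_summary are hypothetical names.

```python
import math


def group_entropy(weights):
    """Shannon entropy (in nats) of one group's normalized weights."""
    return -sum(w * math.log(w) for w in weights if w > 0.0)


def entropy_summary(weights_grouped):
    """Return (mean, min, max) entropy over the non-empty groups."""
    entropies = [group_entropy(g) for g in weights_grouped if g]
    if not entropies:
        return 0.0, 0.0, 0.0
    return sum(entropies) / len(entropies), min(entropies), max(entropies)


mean_h, min_h, max_h = entropy_summary(
    [[0.25, 0.25, 0.25, 0.25], [1.0, 0.0, 0.0, 0.0]]
)
```

Uniform weights over n completions reach the maximum entropy log(n), while a one-hot group has entropy zero.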

maxent_grpo.rewards.maxent.compute_weight_stats(grouped_completions, reward_comp, ref_stats, weighting_cfg)[source]

Compute normalized weights using q-values and reference log-probs.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • reward_comp (RewardComputation) – Reward computation outputs used for q-distributions.

  • ref_stats (ReferenceLogprobs) – Reference-model log-probability statistics.

  • weighting_cfg (WeightingSettings) – Weighting configuration (tau/beta/targets).

Returns:

Weight stats dataclass or None if inputs are empty.

Return type:

WeightStats | None

maxent_grpo.rewards.maxent.controller_state_dict(weighting_cfg)[source]

Return a serializable snapshot of the controller state.

Parameters:

weighting_cfg (WeightingConfigLike) – Weighting configuration containing tau/beta scalars.

Returns:

Dictionary describing controller parameters.

Return type:

dict[str, float]

maxent_grpo.rewards.maxent.load_controller_state(path, weighting_cfg)[source]

Load controller parameters if a state file exists.

Parameters:
  • path (str | None) – Filesystem path to a controller JSON file.

  • weighting_cfg (WeightingConfigLike) – Weighting configuration that will receive the values.

Returns:

True when the controller state was loaded successfully.

Return type:

bool

maxent_grpo.rewards.maxent.maybe_update_beta(weighting_cfg, measured_kl)[source]

Adjust beta with a simple KL controller when targets are configured.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.

  • measured_kl (float) – Observed KL divergence used for feedback.

Return type:

None
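One common form of such a controller is a clipped proportional update, sketched below. The rule shown, including the 0.2 clip and the way step_size and kl_horizon are combined, is an assumption modeled on adaptive KL controllers used in PPO-style training; it is not confirmed to match this library's update.

```python
def updated_beta(beta, measured_kl, kl_target, kl_horizon, step_size):
    """Return a new beta nudged toward keeping measured_kl near kl_target."""
    if kl_target <= 0.0 or kl_horizon <= 0:
        return beta  # controller disabled when no target is configured
    # Relative error, clipped so a single noisy batch cannot swing beta far.
    error = max(-0.2, min(0.2, measured_kl / kl_target - 1.0))
    return beta * (1.0 + step_size * error / kl_horizon)


raised = updated_beta(0.04, measured_kl=0.2, kl_target=0.1, kl_horizon=10, step_size=1.0)
lowered = updated_beta(0.04, measured_kl=0.05, kl_target=0.1, kl_horizon=10, step_size=1.0)
unchanged = updated_beta(0.04, measured_kl=0.2, kl_target=0.0, kl_horizon=10, step_size=1.0)
```

When the measured KL exceeds the target, beta grows and penalizes divergence more strongly; when it falls below the target, beta shrinks.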

maxent_grpo.rewards.maxent.maybe_update_tau(weighting_cfg, weight_stats, global_step, lr_scale=None)[source]

Adjust tau to hit a target weight entropy if configured.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.

  • weight_stats (WeightStats | WeightLoggingView | None) – Current batch weight statistics providing entropy. Can be raw per-batch stats or aggregated logging views.

  • global_step (int) – Training step used for warmup/EMA logic.

  • lr_scale (float | None) – Optional multiplicative scale applied to maxent_tau_lr (e.g., to follow the main LR scheduler).

Return type:

None
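A minimal sketch of entropy-targeted adaptation follows. It assumes a multiplicative update in log space and that a larger tau flattens the weights and so raises entropy; the library's actual rule (including its EMA logic) may differ, and updated_tau is a hypothetical name.

```python
import math


def updated_tau(tau, entropy, target_entropy, lr, tau_min, tau_max,
                global_step, warmup_steps=0):
    """Return tau moved toward the value that yields target_entropy."""
    if target_entropy is None or global_step < warmup_steps:
        return tau  # adaptation disabled or still warming up
    # Entropy below target -> raise tau; entropy above target -> lower it.
    new_tau = tau * math.exp(lr * (target_entropy - entropy))
    return min(tau_max, max(tau_min, new_tau))  # clamp to [tau_min, tau_max]


raised = updated_tau(0.5, 0.8, 1.2, 0.1, 0.05, 1.0, global_step=100)
lowered = updated_tau(0.5, 1.6, 1.2, 0.1, 0.05, 1.0, global_step=100)
disabled = updated_tau(0.5, 0.8, None, 0.1, 0.05, 1.0, global_step=100)
warming = updated_tau(0.5, 0.8, 1.2, 0.1, 0.05, 1.0, global_step=3, warmup_steps=10)
clamped = updated_tau(0.9, 0.0, 5.0, 1.0, 0.05, 1.0, global_step=0)
```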

maxent_grpo.rewards.maxent.save_controller_state(path, weighting_cfg)[source]

Persist controller parameters to disk.

Parameters:
  • path (str | None) – Destination path for the controller JSON file.

  • weighting_cfg (WeightingConfigLike) – Weighting configuration to serialize.

Return type:

None
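The save/load pair amounts to a JSON round trip of the controller scalars. The sketch below assumes a file containing only "tau" and "beta" keys and uses a hypothetical ToyWeightingCfg; the real state file may carry additional fields.

```python
import json
import os
import tempfile
from dataclasses import dataclass


@dataclass
class ToyWeightingCfg:
    """Hypothetical stand-in for an object satisfying WeightingConfigLike."""
    tau: float
    beta: float


def save_state(path, cfg):
    if not path:
        return  # a None or empty path means persistence is disabled
    with open(path, "w", encoding="utf-8") as fh:
        json.dump({"tau": float(cfg.tau), "beta": float(cfg.beta)}, fh)


def load_state(path, cfg):
    if not path or not os.path.isfile(path):
        return False  # nothing to load
    with open(path, encoding="utf-8") as fh:
        payload = json.load(fh)
    cfg.tau = float(payload["tau"])
    cfg.beta = float(payload["beta"])
    return True


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "controller.json")
    save_state(path, ToyWeightingCfg(tau=0.3, beta=0.04))
    restored = ToyWeightingCfg(tau=0.0, beta=0.0)
    loaded = load_state(path, restored)
    missing = load_state(os.path.join(tmp, "absent.json"), restored)
```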

maxent_grpo.rewards.maxent.split_reference_logprobs(grouped_completions, ref_stats, len_norm_ref)[source]

Slice the (optionally length-normalized) reference log-probs per prompt group.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • ref_stats (ReferenceLogprobs) – Reference log-probability statistics.

  • len_norm_ref (bool) – Whether ref_logp_sum is already length-normalized.

Returns:

Reference log-probability sums aligned with each group.

Return type:

list[list[float]]

maxent_grpo.rewards.maxent.split_reference_token_counts(grouped_completions, ref_stats)[source]

Slice reference token counts per prompt group.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • ref_stats (ReferenceLogprobs) – Reference log-probability statistics.

Returns:

Reference token counts grouped by prompt.

Return type:

list[list[float]]
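Both slicing helpers re-group a flat per-completion sequence so that it lines up with grouped_completions. A minimal sketch over plain lists (the fields of ReferenceLogprobs are not shown on this page, so the flat values are passed directly; split_by_group is a hypothetical name):

```python
def split_by_group(grouped_completions, flat_values):
    """Slice flat per-completion values into one list per prompt group."""
    grouped, offset = [], 0
    for comps in grouped_completions:
        grouped.append(list(flat_values[offset:offset + len(comps)]))
        offset += len(comps)
    return grouped


groups = [["a", "b"], ["c"]]
logps = split_by_group(groups, [-1.0, -2.0, -3.0])
counts = split_by_group(groups, [5.0, 7.0, 3.0])
```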

maxent_grpo.rewards.maxent.weight_matrix_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)[source]

Vectorized listwise weights for [prompts, generations] tensors.

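Parameters:

Same arguments as weight_vector_from_q, with q_values, logp_values, and token_counts supplied as [prompts, generations] tensors.

Return type: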

Any

maxent_grpo.rewards.maxent.weight_vector_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)[source]

Convert listwise q-values and reference log-probs into normalized weights.

q_values are already normalized probabilities. Any q-temperature should be applied upstream when those targets are constructed so the listwise posterior does not silently reapply the same temperature here.

Optionally normalize by token counts so each token contributes equally, mitigating length bias when reference log-probabilities are length-sensitive.

Parameters:
  • q_values (list[float]) – Listwise probabilities per completion.

  • logp_values (list[float]) – Reference log-probabilities (or log ratios).

  • token_counts (list[float] | None) – Optional completion token counts for normalization.

  • weighting_cfg (WeightingSettings) – Weighting configuration containing tau/beta.

  • include_reference_term (bool) – Whether to include the reference-model factor.

  • normalize_by_tokens (bool) – Whether to scale weights by token counts.

Returns:

Normalized weights aligned with q_values.

Return type:

list[float]
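The sketch below shows one way q-values and reference log-probs could combine into normalized weights. The functional form, w_i ∝ exp((log q_i + beta · logp_i) / (tau + beta)), is an assumption chosen for illustration, as is dividing logp_i by the token count when normalize_by_tokens is set; neither is confirmed by this page.

```python
import math


def weights_from_q(q_values, logp_values, token_counts, tau, beta,
                   include_reference_term=True, normalize_by_tokens=True):
    """Softmax over (log q + beta * logp_ref) / denom (assumed form)."""
    denom = tau + beta if include_reference_term else tau
    if denom <= 0.0:
        raise ValueError("tau (plus beta) must be positive")
    logits = []
    for i, q in enumerate(q_values):
        logit = math.log(max(q, 1e-12))  # guard against log(0)
        if include_reference_term:
            logp = logp_values[i]
            if normalize_by_tokens and token_counts is not None:
                logp /= max(token_counts[i], 1.0)  # per-token log-prob
            logit += beta * logp
        logits.append(logit / denom)
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


q = [0.7, 0.2, 0.1]
w_ref = weights_from_q(q, [-5.0, -3.0, -8.0], [5.0, 3.0, 4.0], tau=0.5, beta=0.04)
w_q_only = weights_from_q(q, [-5.0, -3.0, -8.0], None, tau=1.0, beta=0.0,
                          include_reference_term=False)
```

Under this assumed form, setting tau to 1 and dropping the reference term returns the q-values unchanged, which matches the note that q_values are already normalized probabilities.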