maxent_grpo.training.weighting.logic

Weighting helpers extracted from the MaxEnt-GRPO training loop.

Functions

_ensure_tau_history(weighting_cfg[, ...])

Ensure tau controller history fields have finite defaults.

_maybe_init_controller_state(weighting_cfg)

Attach a Torch-backed controller state if torch is available.

_resolve_target_entropy(weighting_cfg, ...)

Compute the active target entropy, honoring optional annealing settings.

_split_ref_logprobs_per_token(...)

Return per-token reference log-probs sliced per prompt group.

_sync_controller_state(weighting_cfg)

Ensure the TorchControllerState mirrors the scalar tau/beta.

_to_float_list(values)

Return a best-effort list of floats extracted from values.

apply_meta_controller_update(weighting_cfg, *)

Apply a deterministic meta-controller update in analytic mode.

broadcast_controller_state(accelerator, ...)

Sync controller scalars (tau, beta, entropy EMA/log) across ranks.

build_uniform_weight_stats(grouped_completions)

Return uniform weights per prompt as a GRPO-style fallback.

build_weighting_settings(cfg)

Convenience builder for WeightingSettings from GRPOConfig.

collect_weight_entropy(weights_grouped)

Summarize entropy statistics for grouped weights.

compute_weight_stats(grouped_completions, ...)

Compute normalized weights using q-values and reference log-probs.

controller_state_dict(weighting_cfg)

Return a serializable snapshot of the controller state.

load_controller_state(path, weighting_cfg)

Load controller parameters if a state file exists.

maybe_update_beta(weighting_cfg, measured_kl)

Adjust beta with a simple KL controller when targets are configured.

maybe_update_tau(weighting_cfg, ...[, lr_scale])

Adjust tau to hit a target weight entropy if configured.

save_controller_state(path, weighting_cfg)

Persist controller parameters to disk.

split_reference_logprobs(...)

Slice the (optionally length-normalized) reference log-probs per prompt group.

split_reference_token_counts(...)

Slice reference token counts per prompt group.

weight_matrix_from_q(q_values, logp_values, ...)

Vectorized listwise weights for [prompts, generations] tensors.

weight_vector_from_q(q_values, logp_values, ...)

Convert listwise q-values and reference log-probs into normalized weights.

maxent_grpo.training.weighting.logic.build_weighting_settings(cfg)

Convenience builder for WeightingSettings from GRPOConfig.

Parameters:

cfg (GRPOConfig)

Return type:

WeightingSettings

maxent_grpo.training.weighting.logic.split_reference_logprobs(grouped_completions, ref_stats, len_norm_ref)

Slice the (optionally length-normalized) reference log-probs per prompt group.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • ref_stats (ReferenceLogprobs) – Reference log-probability statistics.

  • len_norm_ref (bool) – Whether ref_logp_sum is already length-normalized.

Returns:

Reference log-probability sums aligned with each group.

Return type:

list[list[float]]

maxent_grpo.training.weighting.logic.split_reference_token_counts(grouped_completions, ref_stats)

Slice reference token counts per prompt group.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • ref_stats (ReferenceLogprobs) – Reference log-probability statistics.

Returns:

Reference token counts grouped by prompt.

Return type:

list[list[float]]
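
The two slicing helpers above can be pictured with a minimal sketch. It assumes the reference statistics hold one flat value per completion, ordered prompt by prompt; that layout and the name split_flat_values are illustrative assumptions, not the library's API.

```python
from typing import List, Sequence


def split_flat_values(
    grouped_completions: Sequence[Sequence[str]],
    flat_values: Sequence[float],
) -> List[List[float]]:
    """Slice a flat per-completion sequence into one list per prompt group.

    Hypothetical helper: assumes flat_values is ordered prompt by prompt,
    matching the order of grouped_completions.
    """
    groups: List[List[float]] = []
    offset = 0
    for group in grouped_completions:
        size = len(group)
        # Take exactly as many values as this prompt has completions.
        groups.append([float(v) for v in flat_values[offset:offset + size]])
        offset += size
    return groups


grouped = [["a1", "a2"], ["b1", "b2", "b3"]]
flat_logp = [-1.0, -2.0, -3.0, -4.0, -5.0]
per_group = split_flat_values(grouped, flat_logp)
```

The same slicing would apply to token counts, since both outputs are aligned with grouped_completions.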

maxent_grpo.training.weighting.logic.weight_vector_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)

Convert listwise q-values and reference log-probs into normalized weights.

q_values are already normalized probabilities. Any q-temperature should be applied upstream when those targets are constructed so the listwise posterior does not silently reapply the same temperature here.

Optionally normalize by token counts so each token contributes equally, mitigating length bias when reference log-probabilities are length-sensitive.

Parameters:
  • q_values (list[float]) – Listwise probabilities per completion.

  • logp_values (list[float]) – Reference log-probabilities (or log ratios).

  • token_counts (list[float] | None) – Optional completion token counts for normalization.

  • weighting_cfg (WeightingSettings) – Weighting configuration containing tau/beta.

  • include_reference_term (bool) – Whether to include the reference-model factor.

  • normalize_by_tokens (bool) – Whether to scale weights by token counts.

Returns:

Normalized weights aligned with q_values.

Return type:

list[float]
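
As a rough illustration of the listwise weighting described above, the sketch below assumes weights proportional to exp((log q + beta * log p_ref) / (tau + beta)), with the reference log-probability divided by the token count when token normalization is on. Both the formula and the name sketch_weight_vector are assumptions made for illustration; this page does not state the exact expression the library uses.

```python
import math
from typing import List, Optional, Sequence


def sketch_weight_vector(
    q_values: Sequence[float],
    logp_values: Sequence[float],
    token_counts: Optional[Sequence[float]],
    tau: float,
    beta: float,
    include_reference_term: bool = True,
    normalize_by_tokens: bool = True,
) -> List[float]:
    """Hypothetical listwise weights: softmax of (log q + beta * logp) / (tau + beta)."""
    if not q_values:
        return []
    denom = max(tau + beta, 1e-8)  # guard against a zero temperature
    logits = []
    for i, q in enumerate(q_values):
        logit = math.log(max(q, 1e-12))  # q is assumed already normalized upstream
        if include_reference_term:
            logp = logp_values[i]
            if normalize_by_tokens and token_counts is not None:
                # Assumed meaning of token normalization: per-token average log-prob.
                logp = logp / max(token_counts[i], 1.0)
            logit += beta * logp
        logits.append(logit / denom)
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


weights = sketch_weight_vector([0.8, 0.2], [-4.0, -4.0], [4.0, 4.0], tau=1.0, beta=0.0)
```

With beta set to 0 and tau set to 1 the sketch returns the q-values unchanged, which is a convenient sanity check for the assumed formula.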

maxent_grpo.training.weighting.logic.weight_matrix_from_q(q_values, logp_values, token_counts, weighting_cfg, *, include_reference_term=True, normalize_by_tokens=True)

Vectorized listwise weights for [prompts, generations] tensors.

Return type:

Any

maxent_grpo.training.weighting.logic.maybe_update_beta(weighting_cfg, measured_kl)

Adjust beta with a simple KL controller when targets are configured.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.

  • measured_kl (float) – Observed KL divergence used for feedback.

Return type:

None
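
A simple KL controller of this kind is often a proportional rule: raise beta when the measured KL exceeds the target and lower it otherwise. The sketch below is one such rule under stated assumptions. The class BetaSketch, its fields kl_target and step_size, and the clipping range are all hypothetical; the library's actual update rule is not documented on this page.

```python
import math
from dataclasses import dataclass


@dataclass
class BetaSketch:
    """Hypothetical stand-in for the beta-related part of a weighting config."""
    beta: float
    kl_target: float          # assumed: a non-positive target disables the controller
    step_size: float = 0.1    # assumed proportional gain


def sketch_update_beta(cfg: BetaSketch, measured_kl: float) -> None:
    """Mutate cfg.beta in place with a clipped proportional step."""
    if cfg.kl_target <= 0.0 or not math.isfinite(measured_kl):
        return  # no target configured or unusable measurement: leave beta alone
    # Relative error, clipped so one noisy batch cannot move beta too far.
    error = max(-1.0, min(1.0, measured_kl / cfg.kl_target - 1.0))
    cfg.beta = max(0.0, cfg.beta * (1.0 + cfg.step_size * error))


beta_cfg = BetaSketch(beta=0.04, kl_target=0.05)
sketch_update_beta(beta_cfg, measured_kl=0.10)  # KL is above target, so beta grows
```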

maxent_grpo.training.weighting.logic.apply_meta_controller_update(weighting_cfg, *, tau_grad=None, beta_grad=None, lr_scale=1.0)

Apply a deterministic meta-controller update in analytic mode.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.

  • tau_grad (float | None) – Gradient of the controller objective w.r.t. tau.

  • beta_grad (float | None) – Gradient of the controller objective w.r.t. beta.

  • lr_scale (float) – Optional multiplier applied to the meta learning rate.

Returns:

True when any parameter was updated.

Return type:

bool
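
One plausible reading of this update is a plain gradient-descent step on tau and beta that reports whether anything changed. The sketch below follows that reading. The class MetaSketch, the field meta_lr, and the clamp at zero are assumptions for illustration, not the documented behavior.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class MetaSketch:
    """Hypothetical stand-in holding the two controller scalars."""
    tau: float
    beta: float
    meta_lr: float = 0.01  # assumed name for the meta learning rate


def sketch_meta_update(
    cfg: MetaSketch,
    *,
    tau_grad: Optional[float] = None,
    beta_grad: Optional[float] = None,
    lr_scale: float = 1.0,
) -> bool:
    """Apply a deterministic gradient step; return True if any scalar was updated."""
    lr = cfg.meta_lr * lr_scale
    updated = False
    if tau_grad is not None and math.isfinite(tau_grad):
        cfg.tau = max(0.0, cfg.tau - lr * tau_grad)  # keep tau non-negative
        updated = True
    if beta_grad is not None and math.isfinite(beta_grad):
        cfg.beta = max(0.0, cfg.beta - lr * beta_grad)  # keep beta non-negative
        updated = True
    return updated


meta_cfg = MetaSketch(tau=0.5, beta=0.04)
changed = sketch_meta_update(meta_cfg, tau_grad=2.0)
```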

maxent_grpo.training.weighting.logic.maybe_update_tau(weighting_cfg, weight_stats, global_step, lr_scale=None)

Adjust tau to hit a target weight entropy if configured.

Parameters:
  • weighting_cfg (WeightingSettings) – Weighting configuration mutated in-place.

  • weight_stats (WeightStats | WeightLoggingView | None) – Current batch weight statistics providing entropy. Can be raw per-batch stats or aggregated logging views.

  • global_step (int) – Training step used for warmup/EMA logic.

  • lr_scale (float | None) – Optional multiplicative scale applied to maxent_tau_lr (e.g., to follow the main LR scheduler).

Return type:

None
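
The idea of steering tau toward a target weight entropy can be sketched as a multiplicative step. The sketch assumes that raising tau flattens the weights and so raises their entropy. Only maxent_tau_lr is named on this page; the class TauSketch, its other fields, the warmup check, and the log-space update are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class TauSketch:
    """Hypothetical stand-in for the tau-related part of a weighting config."""
    tau: float
    maxent_tau_lr: float
    target_entropy: Optional[float]  # assumed: None disables the controller
    tau_min: float = 1e-3
    tau_max: float = 10.0
    warmup_steps: int = 0


def sketch_update_tau(
    cfg: TauSketch,
    measured_entropy: Optional[float],
    global_step: int,
    lr_scale: Optional[float] = None,
) -> None:
    """Mutate cfg.tau in place so the weight entropy drifts toward the target."""
    if cfg.target_entropy is None or global_step < cfg.warmup_steps:
        return
    if measured_entropy is None or not math.isfinite(measured_entropy):
        return
    lr = cfg.maxent_tau_lr * (1.0 if lr_scale is None else lr_scale)
    # Entropy below target gives a positive error, which raises tau.
    error = cfg.target_entropy - measured_entropy
    new_tau = cfg.tau * math.exp(lr * error)
    cfg.tau = min(cfg.tau_max, max(cfg.tau_min, new_tau))


tau_cfg = TauSketch(tau=0.5, maxent_tau_lr=0.1, target_entropy=1.0)
sketch_update_tau(tau_cfg, measured_entropy=0.5, global_step=10)
```

Updating in log-space keeps tau positive without an extra check, and the clamp bounds how far a run of noisy batches can push it.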

maxent_grpo.training.weighting.logic.broadcast_controller_state(accelerator, weighting_cfg)

Sync controller scalars (tau, beta, entropy EMA/log) across ranks.

The sync first tries an all_gather-style exchange via accelerator.gather (available on Accelerate 1.x) and falls back to broadcast_object_list when that function is present. Returns True on success.

Return type:

bool

maxent_grpo.training.weighting.logic.controller_state_dict(weighting_cfg)

Return a serializable snapshot of the controller state.

Parameters:

weighting_cfg (WeightingConfigLike) – Weighting configuration containing tau/beta scalars.

Returns:

Dictionary describing controller parameters.

Return type:

dict[str, float]

maxent_grpo.training.weighting.logic.save_controller_state(path, weighting_cfg)

Persist controller parameters to disk.

Parameters:
  • path (str | None) – Destination path for the controller JSON file.

  • weighting_cfg (WeightingConfigLike) – Weighting configuration to serialize.

Return type:

None

maxent_grpo.training.weighting.logic.load_controller_state(path, weighting_cfg)

Load controller parameters if a state file exists.

Parameters:
  • path (str | None) – Filesystem path to a controller JSON file.

  • weighting_cfg (WeightingConfigLike) – Weighting configuration that will receive the values.

Returns:

True when the controller state was loaded successfully.

Return type:

bool
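
The snapshot, save, and load helpers above amount to a JSON round trip of the controller scalars. A minimal sketch follows. It assumes the file holds only tau and beta under those keys, and the function names are hypothetical; the real state may carry further fields, such as the entropy EMA mentioned for broadcast_controller_state.

```python
import json
import os
import tempfile
from types import SimpleNamespace
from typing import Dict, Optional


def sketch_state_dict(cfg) -> Dict[str, float]:
    """Snapshot the assumed controller scalars as plain floats."""
    return {"tau": float(cfg.tau), "beta": float(cfg.beta)}


def sketch_save_state(path: Optional[str], cfg) -> None:
    """Write the snapshot as JSON; do nothing when no path is given."""
    if not path:
        return
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(sketch_state_dict(cfg), handle)


def sketch_load_state(path: Optional[str], cfg) -> bool:
    """Copy tau/beta from the file into cfg; return False if nothing was loaded."""
    if not path or not os.path.isfile(path):
        return False
    try:
        with open(path, encoding="utf-8") as handle:
            state = json.load(handle)
        tau, beta = float(state["tau"]), float(state["beta"])
    except (OSError, ValueError, KeyError, TypeError):
        return False  # unreadable or malformed file: leave cfg untouched
    cfg.tau, cfg.beta = tau, beta
    return True


with tempfile.TemporaryDirectory() as tmp:
    state_path = os.path.join(tmp, "controller_state.json")
    sketch_save_state(state_path, SimpleNamespace(tau=0.7, beta=0.02))
    restored = SimpleNamespace(tau=0.0, beta=0.0)
    loaded = sketch_load_state(state_path, restored)
    missing = sketch_load_state(os.path.join(tmp, "absent.json"), restored)
```

Parsing both values before assigning either keeps the configuration unchanged when the file is malformed.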

maxent_grpo.training.weighting.logic.collect_weight_entropy(weights_grouped)

Summarize entropy statistics for grouped weights.

Parameters:

weights_grouped (list[list[float]]) – Weight samples grouped per prompt.

Returns:

Tuple containing (mean entropy, min entropy, max entropy, advantage samples).

Return type:

tuple[float, float, float, list[float]]
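
The entropy part of this summary can be illustrated with Shannon entropy in nats, computed per prompt group and then reduced to mean, minimum, and maximum. The sketch covers only those three values. It omits the fourth tuple element because this page does not define how the advantage samples are derived. The function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def sketch_group_entropy(weights: Sequence[float]) -> float:
    """Shannon entropy (natural log) of one normalized weight vector."""
    # Zero weights contribute nothing, matching the limit w * log(w) -> 0.
    return -sum(w * math.log(w) for w in weights if w > 0.0)


def sketch_entropy_summary(
    weights_grouped: Sequence[Sequence[float]],
) -> Tuple[float, float, float]:
    """Return (mean, min, max) entropy over the non-empty groups."""
    entropies: List[float] = [sketch_group_entropy(g) for g in weights_grouped if g]
    if not entropies:
        return 0.0, 0.0, 0.0
    return sum(entropies) / len(entropies), min(entropies), max(entropies)


mean_h, min_h, max_h = sketch_entropy_summary([[0.5, 0.5], [1.0, 0.0]])
```

A uniform group of K weights reaches the maximum entropy log K, while a one-hot group has entropy 0, so the pair in the example spans the full range for K = 2.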

maxent_grpo.training.weighting.logic.compute_weight_stats(grouped_completions, reward_comp, ref_stats, weighting_cfg)

Compute normalized weights using q-values and reference log-probs.

Parameters:
  • grouped_completions (list[list[str]]) – Completion groups per prompt.

  • reward_comp (RewardComputation) – Reward computation outputs used for q-distributions.

  • ref_stats (ReferenceLogprobs) – Reference-model log-probability statistics.

  • weighting_cfg (WeightingSettings) – Weighting configuration (tau/beta/targets).

Returns:

Weight stats dataclass or None if inputs are empty.

Return type:

WeightStats | None

maxent_grpo.training.weighting.logic.build_uniform_weight_stats(grouped_completions)

Return uniform weights per prompt as a GRPO-style fallback.

Parameters:

grouped_completions (list[list[str]]) – Completion groups per prompt.

Return type:

WeightStats | None
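
The uniform fallback can be sketched as giving each of the K completions of a prompt the weight 1/K. The sketch returns plain nested lists rather than a WeightStats object, whose fields are not shown on this page, and the name sketch_uniform_weights is hypothetical. Returning None for empty input mirrors the WeightStats | None return type and is an assumption about when None occurs.

```python
from typing import List, Optional, Sequence


def sketch_uniform_weights(
    grouped_completions: Sequence[Sequence[str]],
) -> Optional[List[List[float]]]:
    """Give every completion in a prompt group the same weight, 1 / group size."""
    if not grouped_completions:
        return None  # assumed behavior for empty input
    weights: List[List[float]] = []
    for group in grouped_completions:
        size = len(group)
        # An empty group yields an empty weight list instead of dividing by zero.
        weights.append([1.0 / size] * size if size else [])
    return weights


uniform = sketch_uniform_weights([["a1", "a2"], ["b1", "b2", "b3", "b4"]])
```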