maxent_grpo.training.trl_trainer
Custom TRL GRPOTrainer wrapper used by the MaxEnt-GRPO pipelines.
This module is the single place where GRPO-vs-MaxEnt objective behavior should diverge at runtime. The surrounding training pipeline (dataset mapping, reward loading, model/tokenizer setup, trainer wiring, launch entrypoints) is kept shared so objective comparisons stay fair and easy to audit.
Functions
|
Disable adapters when the model exposes a supported API. |
|
Mask completion tokens after the first EOS token (TRL-style). |
|
Index tensors by canonicalized names for alias-aware EMA matching. |
|
Render one trainer example into the exact text sent to generation. |
|
Build prompt-major rollout rows for within-group distribution analysis. |
|
Return a worker_init_fn compatible with the active transformers seed_worker signature. |
|
Normalize metric aliases to one canonical key namespace. |
|
Clamp log-probability deltas before exponentiating. |
|
Convert flexible config values to bool without surprising string truthiness. |
|
Convert config values to a finite non-negative float. |
|
Return coarse diversity metrics for grouped completions. |
|
Return an empty dataset preserving the input dataset type when possible. |
|
Return the log-vocab normalization constant for exact entropy metrics. |
|
Map a decoded text prefix back onto a token prefix length on a best-effort basis. |
|
Convert a prompt-major |
|
Return gathered prompt-major benchmark ids when present. |
|
Return whether the active trainer rank should emit shared metrics. |
|
Return compatibility aliases for a canonical metric key. |
|
Return a detached local metric tensor without any distributed gather. |
|
Mask logit columns that correspond to tokenizer-inaccessible token IDs. |
|
Return the arithmetic mean for a non-empty list, else 0.0. |
|
Return a metric-safe benchmark suffix (e.g., |
|
Return a metric tensor for logging, avoiding DDP gathers in local-only eval. |
|
Return the max value while ignoring NaNs. |
|
Return the min value while ignoring NaNs. |
|
Convert a per-group signal into a non-negative mass proxy. |
|
Validate listwise q targets and project them onto the simplex. |
|
Normalize text for lightweight decode-prefix comparisons. |
|
Best-effort numeric conversion used for logging filters. |
|
Pad variable-length completion token rows and return ids + mask tensors. |
|
Pad per-token log-prob rows with zeros to a dense tensor. |
|
Reshape prompt-major flat rollouts into |
|
Return matching policy tensor for |
|
Infer flat row count and prompt count for listwise prompt groups. |
|
Return a conservative upper bound for valid token IDs. |
|
Return the full positive vocab-size limit exposed by the tokenizer. |
|
Return the smallest positive vocab-size limit exposed by the model. |
|
Return selected token log-probs and a differentiable entropy term. |
|
Shuffle prompt groups while preserving candidate order within each group. |
|
Split buffered listwise tensors by whole prompt groups. |
|
Remove known wrapper prefixes used in policy/reference param names. |
|
Remove a train/eval prefix from metric keys when applicable. |
|
Return whether the model exposes an adapter-disable reference path. |
|
Return a small symmetric search window around a candidate prefix length. |
|
Tokenize a completion for diversity metrics. |
|
Return whether training-time eval is using the lightweight greedy path. |
|
Return whether eval diversity logging should stay local to main rank. |
|
Return whether greedy-only eval should stay main-rank-only for metrics. |
|
Return whether greedy-only eval should shard prompt-major batches across ranks. |
|
Return the weighted mean or 0.0 when weights are empty. |
|
Persist prompt-major rollout rows for downstream figure generation. |
|
Return a GRPOTrainer subclass with MaxEnt hooks enabled. |
|
Ensure a trainer class emits TRL-style logs and metrics. |
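
A few of the small helpers summarized above are easy to picture in plain Python. The sketch below is illustrative only: the function names, signatures, default clamp value, and accepted string spellings are all assumptions, not the module's actual API, and the real helpers most likely operate on tensors rather than lists. It shows EOS masking that keeps the first EOS token (as TRL does), clamping a log-probability delta before exponentiating, string-safe bool coercion, and a mean that falls back to 0.0 on empty input.

```python
import math
from typing import Any, List, Sequence


def mask_after_first_eos(token_ids: Sequence[int], eos_id: int) -> List[int]:
    """Hypothetical helper: 1 for tokens up to and including the first EOS, else 0."""
    mask: List[int] = []
    seen_eos = False
    for tok in token_ids:
        # Tokens after the first EOS are masked out; the EOS itself is kept.
        mask.append(0 if seen_eos else 1)
        if tok == eos_id:
            seen_eos = True
    return mask


def safe_exp_delta(delta: float, clamp: float = 20.0) -> float:
    """Hypothetical helper: clamp a log-prob delta to [-clamp, clamp], then exponentiate.

    The default clamp of 20.0 is an assumption chosen for illustration.
    """
    return math.exp(max(-clamp, min(clamp, delta)))


def coerce_bool(value: Any, default: bool = False) -> bool:
    """Hypothetical helper: convert config values to bool without string truthiness traps."""
    if isinstance(value, bool):
        return value
    if value is None:
        return default
    if isinstance(value, str):
        text = value.strip().lower()
        if text in {"1", "true", "yes", "on"}:
            return True
        if text in {"0", "false", "no", "off", ""}:
            return False
        # Unrecognized strings fall back to the default instead of bool("...") == True.
        return default
    return bool(value)


def mean_or_zero(values: Sequence[float]) -> float:
    """Hypothetical helper: arithmetic mean for a non-empty list, else 0.0."""
    return sum(values) / len(values) if values else 0.0
```

The bool coercion is the one most worth spelling out: a plain `bool("false")` is `True` in Python, which is the "surprising string truthiness" a config loader needs to avoid.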