maxent_grpo.training.trl_trainer

Custom TRL GRPOTrainer wrapper used by the MaxEnt-GRPO pipelines.

This module is the single place where GRPO-vs-MaxEnt objective behavior should diverge at runtime. The surrounding training pipeline (dataset mapping, reward loading, model/tokenizer setup, trainer wiring, launch entrypoints) is kept shared so objective comparisons stay fair and easy to audit.

Functions

_adapter_disabled_context(model)

Disable adapters when the model exposes a supported API.

_apply_eos_completion_mask(completion_ids, ...)

Mask completion tokens after the first EOS token (TRL-style).
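
To illustrate the masking rule, here is a minimal list-based sketch. The name `eos_completion_mask` and the plain-list inputs are assumptions for this example; the module's helper works on tensors and its full signature is elided above. The sketch assumes the EOS token itself stays unmasked, which is how TRL's GRPO code treats it.

```python
from typing import List


def eos_completion_mask(
    completion_ids: List[List[int]], eos_token_id: int
) -> List[List[int]]:
    """Return a 0/1 mask that keeps tokens up to and including the first EOS."""
    masks = []
    for row in completion_ids:
        mask = []
        seen_eos = False
        for token_id in row:
            # Everything after the first EOS is masked out (0).
            mask.append(0 if seen_eos else 1)
            if token_id == eos_token_id:
                seen_eos = True
        masks.append(mask)
    return masks
```

Rows that never emit EOS keep every position, so truncated completions still contribute all of their tokens.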

_build_ema_alias_index(params)

Index tensors by canonicalized names for alias-aware EMA matching.

_build_prompt_text(example, tokenizer)

Render one trainer example into the exact text sent to generation.

_build_rich_rollout_rows(*, step, ...[, ...])

Build prompt-major rollout rows for within-group distribution analysis.

_build_seed_worker(num_workers, rank)

Return a worker_init_fn compatible with the active transformers seed_worker signature.

_canonical_metric_key(key)

Normalize metric aliases to one canonical key namespace.

_clamp_log_delta(delta)

Clamp log-probability deltas before exponentiating.
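
Clamping matters because an unbounded log-ratio overflows once it is exponentiated. A scalar sketch of the idea follows; the name `clamp_log_delta` without the underscore prefix and the bound of 20.0 are assumptions for illustration, not the module's actual values.

```python
import math


def clamp_log_delta(delta: float, bound: float = 20.0) -> float:
    """Clamp a log-probability difference to [-bound, bound] (bound is assumed)."""
    return max(-bound, min(bound, delta))


# math.exp(1000.0) would raise OverflowError; the clamped value stays finite.
ratio = math.exp(clamp_log_delta(1000.0))
```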

_coerce_bool(value, *, default)

Convert flexible config values to bool without surprising string truthiness.
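
The "surprising string truthiness" is that `bool("false")` is `True` in Python, because any non-empty string is truthy. One way such a coercion could look is sketched below; the accepted spellings and the fallback to `default` for unrecognized strings are assumptions, not the module's documented behavior.

```python
from typing import Any

# Accepted spellings are an assumption for this sketch.
_TRUE_STRINGS = {"1", "true", "yes", "on"}
_FALSE_STRINGS = {"0", "false", "no", "off"}


def coerce_bool(value: Any, *, default: bool) -> bool:
    """Interpret common config spellings of booleans; fall back to default."""
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        text = value.strip().lower()
        if text in _TRUE_STRINGS:
            return True
        if text in _FALSE_STRINGS:
            return False
        # Unrecognized strings do not silently become True.
        return default
    return bool(value)
```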

_coerce_non_negative_float(value, *[, default])

Convert config values to a finite non-negative float.
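
A sketch of this kind of defensive conversion is shown below. Returning `default` for negative, non-finite, or unparseable input is an assumption; the real helper may clamp negatives to zero instead.

```python
import math
from typing import Any


def coerce_non_negative_float(value: Any, *, default: float = 0.0) -> float:
    """Return value as a finite float >= 0, or default when that is impossible."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        # None, unparseable strings, and unsupported types land here.
        return default
    if not math.isfinite(number) or number < 0.0:
        return default
    return number
```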

_completion_diversity_metrics(...[, ...])

Return coarse diversity metrics for grouped completions.

_empty_dataset_like(dataset)

Return an empty dataset preserving the input dataset type when possible.

_entropy_normalization_scale(valid_vocab_size)

Return the log-vocab normalization constant for exact entropy metrics.
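
The log-vocab constant is the largest entropy a distribution over V tokens can reach, so dividing by it maps entropy into [0, 1]. A small sketch follows; the fallback of 1.0 for vocab sizes of one or less and the helper `normalized_entropy` are assumptions added for illustration.

```python
import math
from typing import Sequence


def entropy_normalization_scale(valid_vocab_size: int) -> float:
    """Return log(V); fall back to 1.0 when V <= 1 to avoid dividing by zero."""
    if valid_vocab_size <= 1:
        return 1.0
    return math.log(valid_vocab_size)


def normalized_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy of probs divided by the maximum possible entropy."""
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return entropy / entropy_normalization_scale(len(probs))
```

Under this sketch a uniform distribution scores 1 and a one-hot distribution scores 0, regardless of vocabulary size.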

_find_token_prefix_len_for_text(tokenizer, ...)

Map a decoded text prefix back onto a token prefix length on a best-effort basis.

_flatten_prompt_major_tensor(tensor)

Convert a prompt-major [prompts, generations, ...] tensor to flat order.

_gather_eval_benchmark_ids_for_prompts(...)

Return gathered prompt-major benchmark ids when present.

_is_main_process(trainer)

Return whether the active trainer rank should emit shared metrics.

_legacy_metric_aliases(key)

Return compatibility aliases for a canonical metric key.

_local_metric_tensor(value)

Return a detached local metric tensor without any distributed gather.

_mask_invalid_logit_columns(logits, *, ...)

Mask logit columns that correspond to tokenizer-inaccessible token IDs.

_mean(values)

Return the arithmetic mean for a non-empty list, else 0.0.

_metric_suffix_from_benchmark(name)

Return a metric-safe benchmark suffix (e.g., AIME24).

_metric_tensor_for_logging(trainer, value, ...)

Return a metric tensor for logging, avoiding DDP gathers in local-only eval.

_nanmax_tensor(tensor)

Return the max value while ignoring NaNs.

_nanmin_tensor(tensor)

Return the min value while ignoring NaNs.
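
NaN-aware reductions are needed because ordinary comparisons against NaN are always false, which makes the result of a plain max or min depend on element order. A list-based sketch of both helpers follows; returning NaN when every entry is NaN is an assumption, and the real helpers take tensors.

```python
import math
from typing import Sequence


def nanmax(values: Sequence[float]) -> float:
    """Max over non-NaN entries; NaN if nothing is left (assumed behavior)."""
    kept = [v for v in values if not math.isnan(v)]
    return max(kept) if kept else math.nan


def nanmin(values: Sequence[float]) -> float:
    """Min over non-NaN entries; NaN if nothing is left (assumed behavior)."""
    kept = [v for v in values if not math.isnan(v)]
    return min(kept) if kept else math.nan
```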

_normalize_group_mass_proxy(values)

Convert a per-group signal into a non-negative mass proxy.

_normalize_listwise_q_targets(q_grouped, *, ...)

Validate listwise q targets and project them onto the simplex.
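
One simple way to land on the simplex is to clip negatives and renormalize each prompt group, as sketched below. This is clip-and-renormalize, not a Euclidean projection, and it may differ from what the module does. The function name, the `eps` threshold, and the uniform fallback for all-zero groups are assumptions.

```python
import math
from typing import List, Sequence


def normalize_q_targets(
    q_grouped: Sequence[Sequence[float]], eps: float = 1e-12
) -> List[List[float]]:
    """Make each group non-negative and sum to one."""
    normalized = []
    for group in q_grouped:
        if len(group) == 0:
            raise ValueError("each prompt group needs at least one candidate")
        if any(not math.isfinite(q) for q in group):
            raise ValueError("q targets must be finite")
        clipped = [max(q, 0.0) for q in group]
        total = sum(clipped)
        if total <= eps:
            # No usable mass: fall back to a uniform target (assumption).
            normalized.append([1.0 / len(group)] * len(group))
        else:
            normalized.append([q / total for q in clipped])
    return normalized
```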

_normalize_text_for_prefix_match(text)

Normalize text for lightweight decode-prefix comparisons.

_numeric_or_none(value)

Best-effort numeric conversion used for logging filters.

_pad_completion_rows(rows, *, pad_token_id, ...)

Pad variable-length completion token rows and return ids + mask tensors.
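
The padding step can be sketched with plain lists as follows. The real helper returns tensors and accepts further keyword arguments that are elided above; this version only shows how the ids and mask line up.

```python
from typing import List, Sequence, Tuple


def pad_completion_rows(
    rows: Sequence[Sequence[int]], *, pad_token_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad rows to a common length and return (ids, mask)."""
    max_len = max((len(row) for row in rows), default=0)
    ids, mask = [], []
    for row in rows:
        pad = max_len - len(row)
        ids.append(list(row) + [pad_token_id] * pad)
        # 1 marks a real token, 0 marks padding.
        mask.append([1] * len(row) + [0] * pad)
    return ids, mask
```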

_pad_logprob_rows(rows, *, device, dtype)

Pad per-token log-prob rows with zeros to a dense tensor.

_reshape_prompt_major_tensor(tensor, group_size)

Reshape prompt-major flat rollouts into [prompts, generations, ...].
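
"Prompt-major" means all generations for the first prompt come first, then all generations for the second prompt, and so on. The list-based sketch below shows the reshape and its inverse, `_flatten_prompt_major_tensor`; the function names without the underscore prefix and the use of lists instead of tensors are assumptions for illustration.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def reshape_prompt_major(flat: Sequence[T], group_size: int) -> List[List[T]]:
    """Group a flat prompt-major sequence into [prompts][generations]."""
    if group_size <= 0 or len(flat) % group_size != 0:
        raise ValueError("row count must be a multiple of a positive group_size")
    return [list(flat[i:i + group_size]) for i in range(0, len(flat), group_size)]


def flatten_prompt_major(grouped: Sequence[Sequence[T]]) -> List[T]:
    """Inverse of reshape_prompt_major: concatenate groups in prompt order."""
    return [item for group in grouped for item in group]
```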

_resolve_ema_source_param(ref_name, ...)

Return the matching policy tensor for ref_name and whether aliasing was used.

_resolve_prompt_group_sizes(tensor_dict, ...)

Infer flat row count and prompt count for listwise prompt groups.

_resolve_token_id_upper_bound(model[, tokenizer])

Return a conservative upper bound for valid token IDs.

_resolve_tokenizer_vocab_limit(tokenizer)

Return the full positive vocab-size limit exposed by the tokenizer.

_resolve_vocab_size_limit(model)

Return the smallest positive vocab-size limit exposed by the model.

_selected_logps_and_entropy(logits, ...)

Return selected token log-probs and a differentiable entropy term.

_shuffle_listwise_tensor_dict(tensor_dict, ...)

Shuffle prompt groups while preserving candidate order within each group.
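
The key property is that whole groups move together, so candidates for the same prompt stay adjacent and in their original order. A sketch over a single flat list follows; the real helper operates on a dict of tensors, and the name and `seed` parameter here are assumptions.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shuffle_prompt_groups(flat_rows: Sequence[T], group_size: int, seed: int) -> List[T]:
    """Permute prompt groups while leaving the order inside each group intact."""
    if group_size <= 0 or len(flat_rows) % group_size != 0:
        raise ValueError("row count must be a multiple of a positive group_size")
    groups = [
        list(flat_rows[i:i + group_size])
        for i in range(0, len(flat_rows), group_size)
    ]
    # Shuffle the list of groups, never the rows inside a group.
    random.Random(seed).shuffle(groups)
    return [row for group in groups for row in group]
```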

_split_listwise_tensor_dict(tensor_dict, ...)

Split buffered listwise tensors by whole prompt groups.

_strip_ema_param_prefixes(name)

Remove known wrapper prefixes used in policy/reference param names.

_strip_mode_prefix(key, mode)

Remove a train/eval prefix from metric keys when applicable.

_supports_adapter_disabled_reference(model)

Return whether the model exposes an adapter-disable reference path.

_token_prefix_search_order(target_len, max_len)

Return a small symmetric search window around a candidate prefix length.
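
A symmetric window can be generated by trying the candidate first and then alternating one step below and one step above it. The sketch below assumes a `radius` parameter of 2 and inclusive bounds of 0 and `max_len`; neither is confirmed by the signature above.

```python
from typing import List


def token_prefix_search_order(target_len: int, max_len: int, radius: int = 2) -> List[int]:
    """Return candidate prefix lengths ordered by distance from target_len."""
    center = min(max(target_len, 0), max_len)
    order = [center]
    for offset in range(1, radius + 1):
        for candidate in (center - offset, center + offset):
            # Skip candidates that fall outside the valid prefix range.
            if 0 <= candidate <= max_len:
                order.append(candidate)
    return order
```

Searching nearest-first keeps the number of re-decoding attempts small when the initial guess is only off by a token or two.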

_tokenize_for_diversity(text[, tokenizer])

Tokenize a completion for diversity metrics.

_use_lightweight_greedy_eval(trainer, mode)

Return whether training-time eval is using the lightweight greedy path.

_use_local_only_eval_diversity_metrics(...)

Return whether eval diversity logging should stay local to main rank.

_use_local_only_lightweight_eval_metrics(...)

Return whether greedy-only eval should stay main-rank-only for metrics.

_use_sharded_prompt_major_greedy_eval(...)

Return whether greedy-only eval should shard prompt-major batches across ranks.

_weighted_mean(values, weights)

Return the weighted mean or 0.0 when weights are empty.
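
Both `_mean` and `_weighted_mean` return 0.0 for empty input instead of raising, which keeps metric logging from failing on empty batches. A sketch of the pair follows; returning 0.0 when the weights sum to zero is an additional assumption.

```python
from typing import Sequence


def mean(values: Sequence[float]) -> float:
    """Arithmetic mean, or 0.0 for an empty sequence."""
    return sum(values) / len(values) if values else 0.0


def weighted_mean(values: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted mean, or 0.0 when weights are empty or sum to zero."""
    total = sum(weights)
    if not weights or total == 0:
        return 0.0
    return sum(v * w for v, w in zip(values, weights)) / total
```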

_write_rich_rollout_sidecar(*, output_dir, ...)

Persist prompt-major rollout rows for downstream figure generation.

apply_chat_template(example, _tokenizer)

build_custom_grpo_trainer(parent_cls)

Return a GRPOTrainer subclass with MaxEnt hooks enabled.

gather(value)

is_conversational(example)

maybe_apply_chat_template(example, _tokenizer)

wrap_trl_trainer(trainer_cls)

Ensure a trainer class emits TRL-style logs and metrics.

maxent_grpo.training.trl_trainer.build_custom_grpo_trainer(parent_cls)

Return a GRPOTrainer subclass with MaxEnt hooks enabled.

Parameters:

parent_cls (Type[Any]) – Base TRL GRPOTrainer class.

Returns:

Wrapped GRPOTrainer subclass.

Return type:

Type[Any]
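
The class-factory pattern behind this function can be sketched without TRL installed. Everything below is hypothetical: `FakeGRPOTrainer` stands in for the TRL base class, and `compute_loss` and `extra_loss_term` are illustrative hook names, not the hooks the module actually overrides.

```python
from typing import Any, Type


def build_custom_trainer(parent_cls: Type[Any]) -> Type[Any]:
    """Return a subclass of parent_cls with one overridden hook (illustrative)."""

    class CustomTrainer(parent_cls):
        def compute_loss(self, *args: Any, **kwargs: Any) -> float:
            # Delegate to the parent, then add the objective-specific term.
            base = super().compute_loss(*args, **kwargs)
            return base + self.extra_loss_term()

        def extra_loss_term(self) -> float:
            # Placeholder for an entropy-style bonus; zero keeps parent behavior.
            return 0.0

    CustomTrainer.__name__ = f"Custom{parent_cls.__name__}"
    return CustomTrainer


class FakeGRPOTrainer:
    """Hypothetical stand-in for the TRL base class."""

    def compute_loss(self, *args: Any, **kwargs: Any) -> float:
        return 1.0
```

Building the subclass at call time lets the caller pass whichever parent class is installed, so the shared pipeline does not need to import a fixed trainer type.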

maxent_grpo.training.trl_trainer.wrap_trl_trainer(trainer_cls)

Ensure a trainer class emits TRL-style logs and metrics.

Parameters:

trainer_cls (Type[Any]) – Trainer class to wrap.

Return type:

Type[Any]