maxent_grpo.training.baseline

Minimal GRPO training entrypoint built on TRL.

This script wires up a standard trl.GRPOTrainer with:

  • Dataset loading via core.data.get_dataset.

  • Simple chat-templated prompts built from a dataset column.

  • A small registry of reward functions from maxent_grpo.rewards.basic.

It aims to be a clean baseline without experimental features (e.g., replay buffers, schedulers, or custom trainers). Use it together with maxent_grpo.config.ScriptArguments, maxent_grpo.config.GRPOConfig, and TRL’s TrlParser.

Key functions

  • _to_prompt: Convert a dataset row to a chat prompt and its gold answer.

  • main: Load the data and model, construct the GRPOTrainer, run training and evaluation, and handle the Hub push and model card creation.
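The row-to-prompt conversion described for _to_prompt can be sketched as follows. This is a minimal illustration, not the module's implementation: the function name to_prompt_sketch, the column names problem and answer, the optional system prompt, and the returned dict keys are all assumptions made for the example.

```python
from typing import Any, Dict, List, Mapping, Optional


def to_prompt_sketch(
    row: Mapping[str, Any],
    prompt_column: str = "problem",  # assumed column name
    answer_column: str = "answer",   # assumed column name
    system_prompt: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for _to_prompt: chat messages plus the gold answer."""
    messages: List[Dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": str(row[prompt_column])})
    # Keep the gold answer next to the prompt so reward functions can score completions.
    return {"prompt": messages, "answer": str(row.get(answer_column, ""))}


example = to_prompt_sketch({"problem": "What is 2 + 3?", "answer": "5"})
print(example["prompt"][-1]["content"])  # What is 2 + 3?
```

Keeping the answer as its own field is convenient with TRL's GRPOTrainer, which forwards extra dataset columns to reward functions as keyword arguments.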

License

Copyright 2025 Liv d’Aliberti

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Functions

_builtin_weight_transfer_trainer_init(...)

Initialize trainer-side vLLM weight transfer across version variants.

_canonical_eval_benchmark_label(spec)

Return stable benchmark labels used in eval metric suffixes.

_clear_vllm_client_buffer(client)

Reset any buffered trainer-side weight updates.

_collect_dataset_columns(dataset)

Return per-split column names when discoverable.

_encode_vllm_batched_update(names, dtypes, ...)

Encode batched vLLM weight metadata through TRL's legacy request model.

_ensure_split_mapping(dataset)

Coerce dataset-like objects into a split->dataset mapping.

_force_vllm_dtype(training_args[, tokenizer])

Ensure TRL vLLM init respects local dtype and colocate engine overrides.

_get_column_names(dataset)

Return a best-effort list of column names for a dataset split.

_guided_decoding_kwargs(guided_decoding)

Extract vLLM guided-decoding fields across version variants.

_import_builtin_vllm_weight_transfer()

Return vLLM's built-in NCCL transfer engine when available.

_loopback_host(base_url)

_main_process_first(training_args, desc)

Return a process-ordering context when TrainingArguments provides one.

_maybe_align_model_tokenizer_vocab(model, ...)

Resize model embeddings when tokenizer exposes additional addressable ids.

_normalize_vllm_generate_url(base_url)

Return the canonical /generate endpoint for a vLLM server base URL.

_patch_trl_vllm_client_init()

Patch TRL VLLMClient init handshake to avoid POST-first deadlocks.

_patch_vllm_guided_decoding_compat()

Bridge TRL 0.18 guided decoding onto vLLM 0.16 structured outputs.

_resolve_eval_dataset_preset(spec)

Resolve built-in benchmark aliases used by training eval configs.

_resolve_eval_dataset_spec(spec, *, ...)

Resolve one evaluation dataset spec (preset alias or HF dataset id).

_resolve_prompt_column(dataset, prompt_column)

Return an inferred prompt column when the default is missing.

_resolve_vllm_client_generate_boundary(client)

Resolve tokenizer/model boundary metadata for live server-mode rollouts.

_resolve_vllm_group_port()

Resolve the vLLM communicator port from launcher environment.

_split_eval_dataset_specs(raw_name)

Return normalized evaluation dataset spec entries from config.

_temporary_env(overrides)

Temporarily set environment variables while preserving prior values.

_tensor_nbytes(tensor)

Best-effort tensor size in bytes for batching decisions.

_validate_dataset_columns(dataset, *, ...)

Fail fast if required dataset columns are missing.

_validate_vllm_completion_ids(...)

Fail fast when live rollouts contain tokenizer-inaccessible token IDs.

_vllm_client_nccl_overrides(base_url)

Return conservative NCCL settings for loopback vLLM sync.

_vllm_sync_chunk_bytes()

Return the max weight-sync batch size for server-mode vLLM updates.

run_baseline_training(script_args, ...)

Entrypoint that loads data/model, builds trainer, and runs GRPO.
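_temporary_env is described as temporarily setting environment variables while preserving prior values. The usual way to do that is a context manager that records the old state and restores it in a finally clause. The sketch below assumes that convention; the name temporary_env_sketch and the rule that a None value unsets a variable are hypothetical, not documented behavior of the real helper.

```python
import os
from contextlib import contextmanager
from typing import Iterator, Mapping, Optional


@contextmanager
def temporary_env_sketch(overrides: Mapping[str, Optional[str]]) -> Iterator[None]:
    """Apply env overrides, then restore each variable's previous state on exit."""
    # Record prior values; None means the variable was not set before.
    previous = {key: os.environ.get(key) for key in overrides}
    try:
        for key, value in overrides.items():
            if value is None:
                os.environ.pop(key, None)  # assumed convention: None unsets the variable
            else:
                os.environ[key] = value
        yield
    finally:
        # Runs even if the wrapped block raises, so the environment never leaks.
        for key, old in previous.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old
```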
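_normalize_vllm_generate_url returns the canonical /generate endpoint for a vLLM server base URL. One plausible normalization is sketched below. The name normalize_generate_url_sketch is hypothetical, and so are the rules it applies: a missing scheme defaults to http://, trailing slashes are dropped, an existing /generate suffix is kept rather than duplicated, and any query string is discarded. The real helper may treat these cases differently.

```python
from urllib.parse import urlsplit, urlunsplit


def normalize_generate_url_sketch(base_url: str) -> str:
    """Return '<scheme>://<host>[:port][/path]/generate' for a server base URL."""
    url = base_url.strip()
    if "://" not in url:
        # Without a scheme, urlsplit would misread 'host:port' as 'scheme:path'.
        url = "http://" + url  # assumed default scheme
    parts = urlsplit(url)
    path = parts.path.rstrip("/")
    if not path.endswith("/generate"):
        path = f"{path}/generate"
    # Drop query and fragment: the endpoint is addressed by path only.
    return urlunsplit((parts.scheme, parts.netloc, path, "", ""))
```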
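_main_process_first returns a process-ordering context when TrainingArguments provides one. In transformers, TrainingArguments.main_process_first(local=True, desc=...) is a context manager that lets the main process run a block (typically dataset preprocessing) while the other ranks wait. A minimal sketch of the fallback logic follows, assuming a no-op nullcontext when the method is absent. The name main_process_first_sketch is hypothetical.

```python
from contextlib import nullcontext
from typing import Any, ContextManager


def main_process_first_sketch(training_args: Any, desc: str) -> ContextManager[Any]:
    """Use training_args.main_process_first when present, else a no-op context."""
    ordering = getattr(training_args, "main_process_first", None)
    if callable(ordering):
        # Main process enters first; other ranks are released when it finishes.
        return ordering(desc=desc)
    # Plain objects (e.g., test doubles or older configs) get a do-nothing context.
    return nullcontext()
```

Usage would look like "with main_process_first_sketch(training_args, desc='tokenize'): dataset = dataset.map(...)", so that a shared cache is built once instead of concurrently on every rank.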

Classes

ChatTemplate(*args, **kwargs)

Protocol for objects with chat templating capabilities.

_LazyModuleProxy(module_name)

Proxy that lazily imports a module on first attribute access.
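_LazyModuleProxy defers an import until the first attribute access. A plausible motivation is keeping heavy optional dependencies out of import time, though the page does not say so. The sketch below shows the pattern under that assumption. The class name LazyModuleProxySketch and its private attributes are hypothetical; it relies on __getattr__ being called only for attributes that normal lookup cannot find.

```python
import importlib
from types import ModuleType
from typing import Any, Optional


class LazyModuleProxySketch:
    """Defer importing module_name until an attribute is first requested."""

    def __init__(self, module_name: str) -> None:
        self._module_name = module_name
        self._module: Optional[ModuleType] = None

    def _load(self) -> ModuleType:
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return self._module

    def __getattr__(self, name: str) -> Any:
        # Only reached for names missing from the instance and class, so the
        # attributes set in __init__ never route back through here.
        return getattr(self._load(), name)
```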
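ChatTemplate is a typing Protocol, so any object with the right methods satisfies it structurally, without inheriting from it. The page does not list the protocol's members. The sketch below assumes a single apply_chat_template method, mirroring the method of that name on Hugging Face tokenizers. ChatTemplateSketch and the toy EchoTemplate are hypothetical names.

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class ChatTemplateSketch(Protocol):
    """Structural type: anything exposing an apply_chat_template method."""

    def apply_chat_template(
        self, conversation: List[Dict[str, str]], **kwargs: Any
    ) -> Any: ...


class EchoTemplate:
    """Toy object that satisfies the protocol without subclassing it."""

    def apply_chat_template(self, conversation: List[Dict[str, str]], **kwargs: Any) -> str:
        text = "\n".join(f"{m['role']}: {m['content']}" for m in conversation)
        if kwargs.get("add_generation_prompt"):
            text += "\nassistant:"
        return text
```

Note that runtime_checkable isinstance checks only confirm the method exists, not that its signature matches.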

maxent_grpo.training.baseline.get_reward_funcs(script_args, _ref_model=None, _tokenizer=None)

Resolve reward function callables from names.

Parameters:
  • script_args (RewardConfig)

  • _ref_model (Optional['PreTrainedModel'])

  • _tokenizer (Optional['PreTrainedTokenizerBase'])

Return type:

List['RewardFunction']
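get_reward_funcs maps configured reward names to callables. The real function draws on the registry in maxent_grpo.rewards.basic. The sketch below uses a made-up two-entry registry instead. The reward functions, the reward_funcs field on RewardConfigSketch, and the error message are hypothetical. The callables follow TRL's reward-function convention of (prompts, completions, **kwargs) returning one float per completion.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

RewardFunction = Callable[..., List[float]]  # (prompts, completions, **kwargs) -> scores


def exact_match_reward(prompts: List[Any], completions: List[str], answer: List[str], **kwargs: Any) -> List[float]:
    # Hypothetical reward: 1.0 when the stripped completion equals the gold answer.
    return [1.0 if c.strip() == a.strip() else 0.0 for c, a in zip(completions, answer)]


def length_penalty_reward(prompts: List[Any], completions: List[str], **kwargs: Any) -> List[float]:
    # Hypothetical reward: small penalty proportional to completion length.
    return [-len(c) / 1000.0 for c in completions]


REWARD_REGISTRY: Dict[str, RewardFunction] = {
    "exact_match": exact_match_reward,
    "length_penalty": length_penalty_reward,
}


@dataclass
class RewardConfigSketch:
    reward_funcs: List[str] = field(default_factory=lambda: ["exact_match"])


def get_reward_funcs_sketch(script_args: RewardConfigSketch) -> List[RewardFunction]:
    unknown = [name for name in script_args.reward_funcs if name not in REWARD_REGISTRY]
    if unknown:
        # Fail at startup rather than mid-training when a name is misspelled.
        raise ValueError(f"Unknown reward function(s): {unknown}; available: {sorted(REWARD_REGISTRY)}")
    return [REWARD_REGISTRY[name] for name in script_args.reward_funcs]
```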

maxent_grpo.training.baseline.run_baseline_training(script_args, training_args, model_args)

Entrypoint that loads data/model, builds trainer, and runs GRPO.

If training_args.do_eval is enabled and an eval split exists, the function also evaluates on a small subsample of that split to keep evaluation fast.
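The eval subsampling mentioned for run_baseline_training can be illustrated with a reproducible random subset. The docstring does not state the sample size or seed. The helper name subsample_eval_sketch, the default seed, and the rule that a non-positive cap keeps every row are assumptions. With a Hugging Face datasets.Dataset, the analogous calls would be dataset.shuffle(seed=...).select(range(n)).

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def subsample_eval_sketch(rows: Sequence[T], max_samples: int, seed: int = 42) -> List[T]:
    """Return a reproducible random subset of at most max_samples rows, in original order."""
    # Assumed convention: a non-positive cap means "no cap".
    if max_samples <= 0 or len(rows) <= max_samples:
        return list(rows)
    rng = random.Random(seed)  # private RNG so the global random state is untouched
    keep = sorted(rng.sample(range(len(rows)), max_samples))
    return [rows[i] for i in keep]
```

Fixing the seed keeps the same eval rows across runs, so metric changes reflect the model rather than the sample.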

Parameters:
  • script_args (GRPOScriptArguments) – Script configuration including dataset and rewards.

  • training_args (GRPOConfig) – GRPO trainer arguments from TRL.

  • model_args (trl.ModelConfig) – Model configuration for TRL/transformers.

Returns:

None. Side effects include training, evaluation, and checkpointing.

Return type:

None