maxent_grpo.core.model

Copyright 2025 Liv d’Aliberti

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Tokenizer/model loading helpers for training scripts.

This module exposes two utilities:

  • get_tokenizer: Load a tokenizer with optional chat template override.

  • get_model: Load an AutoModelForCausalLM with optional quantization and device-map resolution via TRL helpers, honoring the configured attention implementation and dtype and keeping the model compatible with gradient checkpointing.

Functions

_force_nonreentrant_checkpointing(model)

Best-effort enforcement of non-reentrant gradient checkpointing.

get_model(model_args, training_args)

Construct the causal LM with optional quantization and device map.

get_tokenizer(model_args, training_args)

Load and optionally customize the tokenizer.

Classes

ChatMessage

Type definition for chat message format.

maxent_grpo.core.model.get_kbit_device_map(*args, **kwargs)
maxent_grpo.core.model.get_quantization_config(*args, **kwargs)
class maxent_grpo.core.model.ChatMessage

Bases: TypedDict

Type definition for chat message format.

role: str
content: str
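
As a minimal illustration of the ChatMessage shape, the sketch below redefines an equivalent TypedDict locally so that it runs without maxent_grpo installed, then builds a short conversation. The local redefinition and the sample messages are assumptions made for illustration; only the two field names come from this page.

```python
from typing import List, TypedDict


class ChatMessage(TypedDict):
    """Local stand-in mirroring the documented fields (role, content)."""

    role: str
    content: str


# Hypothetical conversation; the role names follow the common chat convention.
messages: List[ChatMessage] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# Collect the roles in order, e.g. to check turn structure before templating.
roles = [m["role"] for m in messages]
```

Because a TypedDict is a plain dict at runtime, a list like this can be passed directly to any function that expects chat-style messages.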
maxent_grpo.core.model.get_tokenizer(model_args, training_args)

Load and optionally customize the tokenizer.

The function loads a tokenizer from the Hub using the provided model identifiers. When a chat_template override is configured, it is set on the tokenizer before the tokenizer is returned.

Parameters:
  • model_args (trl.ModelConfig) – Model configuration (name, revision, trust flags) used to locate the tokenizer on the Hub.

  • training_args (GRPOConfig) – Training configuration, specifically the optional chat_template used to override the tokenizer template.

Returns:

A pre-trained tokenizer instance.

Return type:

transformers.PreTrainedTokenizer
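
The override step described above can be sketched roughly as follows. This is not the module's actual code: apply_chat_template_override is a hypothetical helper, and a SimpleNamespace stands in for a real tokenizer so the example runs offline. In practice the tokenizer would presumably come from something like transformers.AutoTokenizer.from_pretrained, called with the identifiers in model_args.

```python
from types import SimpleNamespace
from typing import Any, Optional


def apply_chat_template_override(tokenizer: Any, chat_template: Optional[str]) -> Any:
    """Set tokenizer.chat_template only when an override is configured.

    Hypothetical helper; it mirrors the behaviour described in the docs.
    """
    if chat_template is not None:
        tokenizer.chat_template = chat_template
    return tokenizer


# Stand-in for a loaded tokenizer (assumption: only chat_template matters here).
default_template = "{{ messages }}"
tok = SimpleNamespace(chat_template=default_template)

# No override configured: the tokenizer's own template is kept.
unchanged = apply_chat_template_override(tok, None).chat_template

# Override configured: the new template replaces the default.
override = "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
changed = apply_chat_template_override(tok, override).chat_template
```

Treating None as "no override" leaves the tokenizer's built-in template untouched, which matches the optional nature of chat_template in the training configuration.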

maxent_grpo.core.model.get_model(model_args, training_args)

Construct the causal LM with optional quantization and device map.

Parameters:
  • model_args (trl.ModelConfig) – Model configuration (quantization, dtype, attention implementation, revision, trust settings) forwarded to from_pretrained.

  • training_args (GRPOConfig) – Training configuration (used for use_cache and gradient checkpointing compatibility).

Returns:

A loaded AutoModelForCausalLM instance, configured with a device map and quantization settings when available.

Return type:

transformers.AutoModelForCausalLM

Raises:

ValueError – Propagated from the underlying model loader if the model identifier or revision is invalid.
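
To make the parameter flow more concrete, here is a sketch of how the keyword arguments forwarded to from_pretrained might be assembled. Everything in it is an assumption: FakeModelArgs, FakeTrainingArgs, and build_model_kwargs are hypothetical stand-ins, the field names follow trl.ModelConfig conventions, and the rule of disabling use_cache under gradient checkpointing is a common pattern rather than confirmed behaviour of this module.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FakeModelArgs:
    """Hypothetical stand-in for the trl.ModelConfig fields used here."""

    model_revision: str = "main"
    trust_remote_code: bool = False
    attn_implementation: Optional[str] = None
    torch_dtype: Optional[str] = None


@dataclass
class FakeTrainingArgs:
    """Hypothetical stand-in for the GRPOConfig field used here."""

    gradient_checkpointing: bool = False


def build_model_kwargs(
    model_args: FakeModelArgs,
    training_args: FakeTrainingArgs,
    quantization_config: Optional[Any] = None,
    device_map: Optional[Any] = None,
) -> Dict[str, Any]:
    """Assemble from_pretrained-style kwargs (illustrative only)."""
    kwargs: Dict[str, Any] = {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": model_args.torch_dtype,
        # Assumption: the KV cache is disabled when gradient checkpointing is on.
        "use_cache": not training_args.gradient_checkpointing,
    }
    # Assumption: a device map is attached only alongside a quantization config.
    if quantization_config is not None:
        kwargs["quantization_config"] = quantization_config
        kwargs["device_map"] = device_map
    return kwargs


plain = build_model_kwargs(FakeModelArgs(), FakeTrainingArgs(gradient_checkpointing=True))
quantized = build_model_kwargs(
    FakeModelArgs(torch_dtype="bfloat16"),
    FakeTrainingArgs(),
    quantization_config={"load_in_4bit": True},  # placeholder for a real config object
    device_map={"": 0},
)
```

In the real module, the quantization config and device map would presumably come from get_quantization_config and get_kbit_device_map, and the resulting kwargs would be passed to AutoModelForCausalLM.from_pretrained together with the model name.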