maxent_grpo.core.model
Copyright 2025 Liv d’Aliberti
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Tokenizer/model loading helpers for training scripts.
This module exposes two utilities:
- get_tokenizer: Load a tokenizer with an optional chat template override.
- get_model: Load an AutoModelForCausalLM with optional quantization and device map resolution via TRL helpers, respecting the attention implementation and dtype choices as well as gradient checkpointing compatibility.
Functions
- Best-effort enforcement of non-reentrant gradient checkpointing.
- get_model(model_args, training_args) – Construct the causal LM with optional quantization and device map.
- get_tokenizer(model_args, training_args) – Load and optionally customize the tokenizer.
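The function list above includes a best-effort helper for non-reentrant gradient checkpointing, but this reference shows neither its name nor its body. What follows is a minimal sketch of what such enforcement could look like. It assumes the helper works through the gradient_checkpointing and gradient_checkpointing_kwargs fields of Hugging Face TrainingArguments; the Trainer forwards the latter to model.gradient_checkpointing_enable. The name enforce_non_reentrant_checkpointing is hypothetical.

```python
from types import SimpleNamespace
from typing import Any


def enforce_non_reentrant_checkpointing(training_args: Any) -> Any:
    """Hypothetical helper: force use_reentrant=False when checkpointing is on.

    Sketch only; assumes training_args exposes the Hugging Face
    TrainingArguments fields gradient_checkpointing and
    gradient_checkpointing_kwargs.
    """
    if not getattr(training_args, "gradient_checkpointing", False):
        return training_args  # checkpointing disabled: nothing to enforce
    # Copy any existing kwargs so unrelated checkpoint options are preserved.
    kwargs = dict(getattr(training_args, "gradient_checkpointing_kwargs", None) or {})
    kwargs["use_reentrant"] = False
    try:
        training_args.gradient_checkpointing_kwargs = kwargs
    except AttributeError:
        # Best effort: immutable configs (e.g. frozen dataclasses, whose
        # FrozenInstanceError subclasses AttributeError) are left unchanged.
        pass
    return training_args


# Usage with a stand-in for the training config.
example_args = SimpleNamespace(
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True, "preserve_rng_state": True},
)
enforce_non_reentrant_checkpointing(example_args)
```

The PyTorch documentation for torch.utils.checkpoint recommends use_reentrant=False. Overriding the flag rather than only defaulting it matches the "enforcement" wording. Swallowing the assignment error keeps the sketch "best-effort".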
Classes
- ChatMessage – Type definition for chat message format.
- maxent_grpo.core.model.get_kbit_device_map(*args, **kwargs)
- maxent_grpo.core.model.get_quantization_config(*args, **kwargs)
- class maxent_grpo.core.model.ChatMessage
Bases: TypedDict
Type definition for chat message format.
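This reference does not list the fields of ChatMessage. Assume it follows the role/content shape that Hugging Face chat templates consume. Under that assumption, a stand-in TypedDict and a conversation built from it might look like the code below. ChatMessageSketch and make_conversation are hypothetical names, and the field names are assumed.

```python
from typing import List, TypedDict


class ChatMessageSketch(TypedDict):
    """Hypothetical stand-in for ChatMessage; the field names are assumed."""

    role: str     # e.g. "system", "user" or "assistant"
    content: str  # the message text


def make_conversation(system_prompt: str, user_prompt: str) -> List[ChatMessageSketch]:
    """Build a two-turn prompt in the assumed role/content format."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


conversation = make_conversation("You are a helpful assistant.", "What is 2 + 2?")
```

A list like conversation can be passed to a tokenizer's apply_chat_template method. That method renders the messages with the tokenizer's chat template, including a template overridden by get_tokenizer.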
- maxent_grpo.core.model.get_tokenizer(model_args, training_args)
Load and optionally customize the tokenizer.
The function downloads a tokenizer from the Hub using the provided model identifiers. When a chat_template override is configured, it is injected into the tokenizer before returning.
- Parameters:
  - model_args (trl.ModelConfig) – Model configuration (name, revision, trust flags) used to locate the tokenizer on the Hub.
  - training_args (GRPOConfig) – Training configuration, specifically the optional chat_template used to override the tokenizer template.
- Returns:
A pre-trained tokenizer instance.
- Return type:
transformers.PreTrainedTokenizer
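Here is a minimal sketch of the behaviour described above. It assumes the trl.ModelConfig fields model_name_or_path, model_revision and trust_remote_code, plus a chat_template attribute on the training config. The loader parameter is hypothetical. It defaults to AutoTokenizer.from_pretrained and exists only so the example can run offline with a stub.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


def get_tokenizer_sketch(
    model_args: Any,
    training_args: Any,
    loader: Optional[Callable[..., Any]] = None,
) -> Any:
    """Load a tokenizer, then apply an optional chat-template override (sketch)."""
    if loader is None:
        # Imported lazily so the stubbed example below needs no transformers install.
        from transformers import AutoTokenizer

        loader = AutoTokenizer.from_pretrained
    tokenizer = loader(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    # Only replace the template when the training config supplies one.
    chat_template = getattr(training_args, "chat_template", None)
    if chat_template is not None:
        tokenizer.chat_template = chat_template
    return tokenizer


# Offline usage: a stub loader stands in for the Hub download.
def stub_loader(name: str, **kwargs: Any) -> SimpleNamespace:
    return SimpleNamespace(name=name, load_kwargs=kwargs, chat_template="hub default")


model_cfg = SimpleNamespace(
    model_name_or_path="org/model", model_revision="main", trust_remote_code=False
)
overridden = get_tokenizer_sketch(
    model_cfg, SimpleNamespace(chat_template="{{ messages }}"), loader=stub_loader
)
unchanged = get_tokenizer_sketch(
    model_cfg, SimpleNamespace(chat_template=None), loader=stub_loader
)
```

Checking against None rather than truthiness means only an absent override leaves the Hub template in place. An explicitly supplied string always wins.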
- maxent_grpo.core.model.get_model(model_args, training_args)
Construct the causal LM with optional quantization and device map.
- Parameters:
  - model_args (trl.ModelConfig) – Model configuration (quantization, dtype, attention implementation, revision, trust settings) forwarded to from_pretrained.
  - training_args (GRPOConfig) – Training configuration (used for use_cache and gradient checkpointing compatibility).
- Returns:
A loaded AutoModelForCausalLM instance, configured with a device map and quantization settings when available.
- Return type:
transformers.AutoModelForCausalLM
- Raises:
ValueError – Propagated from underlying model loading if identifiers or revisions are invalid.
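The keyword arguments that get_model forwards to from_pretrained can be sketched as a pure function. The sketch rests on these assumptions:
- It uses the trl.ModelConfig fields model_revision, trust_remote_code, attn_implementation and torch_dtype. Newer TRL and transformers releases may call the last one dtype.
- It reads a gradient_checkpointing flag on the training config.
- It requests a device map only when quantization is active, since TRL's get_kbit_device_map targets k-bit loading.

build_model_kwargs and load_model_sketch are hypothetical names.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Optional


def build_model_kwargs(
    model_args: Any,
    training_args: Any,
    quantization_config: Optional[Any] = None,
    device_map_fn: Optional[Callable[[], Any]] = None,
) -> Dict[str, Any]:
    """Assemble from_pretrained keyword arguments (hypothetical sketch)."""
    dtype = getattr(model_args, "torch_dtype", None)
    if dtype not in ("auto", None):
        import torch  # only needed to turn a name like "bfloat16" into torch.bfloat16

        dtype = getattr(torch, dtype)
    # Assumption: a device map is requested only for quantized (k-bit) loading.
    device_map = None
    if quantization_config is not None and device_map_fn is not None:
        device_map = device_map_fn()
    return {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": dtype,
        # The KV cache conflicts with gradient checkpointing during training.
        "use_cache": not training_args.gradient_checkpointing,
        "device_map": device_map,
        "quantization_config": quantization_config,
    }


def load_model_sketch(model_args: Any, training_args: Any) -> Any:
    """Full load path; requires transformers and trl to be installed."""
    from transformers import AutoModelForCausalLM
    from trl import get_kbit_device_map, get_quantization_config

    kwargs = build_model_kwargs(
        model_args,
        training_args,
        quantization_config=get_quantization_config(model_args),
        device_map_fn=get_kbit_device_map,
    )
    return AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **kwargs)


# Offline check of the kwargs logic with stand-in configs.
stub_model_cfg = SimpleNamespace(
    model_revision="main",
    trust_remote_code=False,
    attn_implementation="sdpa",
    torch_dtype="auto",
)
plain = build_model_kwargs(stub_model_cfg, SimpleNamespace(gradient_checkpointing=True))
quantized = build_model_kwargs(
    stub_model_cfg,
    SimpleNamespace(gradient_checkpointing=False),
    quantization_config="stand-in quantization config",
    device_map_fn=lambda: {"": 0},  # map the whole model to device 0
)
```

Keeping the kwargs assembly separate from the download makes the use_cache and device-map decisions testable without network access or a GPU.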