"""
Copyright 2025 Liv d'Aliberti
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Tokenizer/model loading helpers for training scripts.

This module exposes two utilities:

- ``get_tokenizer``: Load a tokenizer with an optional chat-template override,
  falling back to the EOS token as the pad token when none is set.
- ``get_model``: Load an ``AutoModelForCausalLM`` with optional quantization
  and device-map resolution via TRL helpers, honoring the configured
  attention implementation and dtype, enabling gradient checkpointing when
  requested, and optionally wrapping the model with ``torch.compile``.
"""
from __future__ import annotations
import functools
import logging
from typing import Any, Dict, Optional, TypedDict, TYPE_CHECKING, Union, cast
try:  # pragma: no cover - optional dependency
    import torch
except (
    ImportError, AttributeError, OSError, RuntimeError, ValueError
):  # pragma: no cover - allow stubbed environments

    class _TorchStub:
        """Stand-in for ``torch`` exposing only the names this module needs.

        The dtype "values" are plain strings, so resolving a dtype name
        against this object yields the same string back. ``dtype`` is an
        empty class so ``torch.dtype`` lookups still resolve.
        """

        class dtype:  # noqa: N801 - mimic torch.dtype
            """Placeholder for ``torch.dtype``."""

        float16 = "float16"
        bfloat16 = "bfloat16"
        float32 = "float32"

    torch = _TorchStub()  # type: ignore[assignment]
try: # pragma: no cover - optional dependency
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.tokenization_utils import (
PreTrainedTokenizer as _PreTrainedTokenizer,
)
except (
ImportError,
AttributeError,
OSError,
RuntimeError,
ValueError,
): # pragma: no cover - allow stubbed environments
class AutoModelForCausalLM: # type: ignore[no-redef]
@classmethod
def from_pretrained(cls, *_args: Any, **_kwargs: Any) -> Any:
raise RuntimeError("transformers is required for AutoModelForCausalLM")
class AutoTokenizer: # type: ignore[no-redef]
@classmethod
def from_pretrained(cls, *_args: Any, **_kwargs: Any) -> Any:
raise RuntimeError("transformers is required for AutoTokenizer")
class _PreTrainedTokenizer: # type: ignore[no-redef]
pass
try:  # pragma: no cover - optional dependency
    from trl import ModelConfig, get_kbit_device_map, get_quantization_config
except (
    ImportError, AttributeError, OSError, RuntimeError, ValueError
):  # pragma: no cover - allow stubbed environments

    class ModelConfig:  # type: ignore[no-redef]
        """Empty placeholder for ``trl.ModelConfig`` when TRL is unavailable."""

    def get_kbit_device_map(*_args: Any, **_kwargs: Any) -> Any:
        """Stub for TRL's helper: always reports "no device map"."""
        return None

    def get_quantization_config(*_args: Any, **_kwargs: Any) -> Any:
        """Stub for TRL's helper: always reports "no quantization"."""
        return None
from maxent_grpo.config import GRPOConfig
PreTrainedTokenizer = _PreTrainedTokenizer

# Annotation aliases: type checkers see the real library classes, while at
# runtime everything collapses to ``Any`` (or the possibly-stubbed
# ``torch.dtype``) so annotations never require optional packages.
if TYPE_CHECKING:
    from torch import dtype as TorchDType  # pragma: no cover

    AutoModelForCausalLMType = AutoModelForCausalLM
    PreTrainedTokenizerType = PreTrainedTokenizer
    ModelConfigType = ModelConfig
else:  # pragma: no cover - runtime fallbacks
    AutoModelForCausalLMType = Any
    PreTrainedTokenizerType = Any
    ModelConfigType = Any
    TorchDType = getattr(torch, "dtype", Any)

LOG = logging.getLogger(__name__)
def _force_nonreentrant_checkpointing(model: Any) -> bool:
"""Best-effort enforcement of non-reentrant gradient checkpointing."""
if model is None:
return False
checkpoint_mod = getattr(getattr(torch, "utils", None), "checkpoint", None)
checkpoint_fn = (
getattr(checkpoint_mod, "checkpoint", None) if checkpoint_mod else None
)
if checkpoint_fn is None:
return False
gc_func = functools.partial(checkpoint_fn, use_reentrant=False)
set_gc = getattr(model, "_set_gradient_checkpointing", None)
if callable(set_gc):
try:
set_gc(enable=True, gradient_checkpointing_func=gc_func)
return True
except TypeError:
try:
set_gc(value=True)
except (AttributeError, TypeError, ValueError):
pass
applied = False
modules = getattr(model, "modules", None)
if callable(modules):
for module in modules():
if hasattr(module, "gradient_checkpointing"):
try:
setattr(module, "_gradient_checkpointing_func", gc_func)
setattr(module, "gradient_checkpointing", True)
applied = True
except (AttributeError, TypeError, ValueError, RuntimeError):
continue
if hasattr(model, "gradient_checkpointing"):
try:
setattr(model, "_gradient_checkpointing_func", gc_func)
setattr(model, "gradient_checkpointing", True)
applied = True
except (AttributeError, TypeError, ValueError, RuntimeError):
pass
return applied
class ChatMessage(TypedDict):
    """Type definition for chat message format.

    Declares only the two keys and their ``str`` types; no validation of the
    values is performed by this definition.
    """

    # Speaker of the message. Presumably values such as "system", "user" or
    # "assistant" — assumed from chat-template conventions, not enforced here.
    role: str
    # Text of the message.
    content: str
def get_tokenizer(
    model_args: ModelConfigType, training_args: GRPOConfig
) -> PreTrainedTokenizerType | Any:
    """Load and optionally customize the tokenizer.

    The tokenizer is fetched with ``AutoTokenizer.from_pretrained`` using the
    model identifier, revision and ``trust_remote_code`` flag from
    ``model_args``. When ``training_args.chat_template`` is set it replaces
    the tokenizer's template. If the tokenizer has no ``pad_token`` but does
    have an ``eos_token``, the EOS token is reused as the pad token.

    :param model_args: Model configuration (name, revision, trust flags) used to
        locate the tokenizer on the Hub.
    :type model_args: ``trl.ModelConfig``
    :param training_args: Training configuration, specifically the optional
        ``chat_template`` used to override the tokenizer template (treated as
        ``None`` when the attribute is absent).
    :type training_args: GRPOConfig
    :returns: A pre-trained tokenizer instance.
    :rtype: ``transformers.PreTrainedTokenizer``
    :raises ValueError: If ``model_args.model_name_or_path`` is missing or empty.
    """
    model_name_or_path = getattr(model_args, "model_name_or_path", None)
    if not model_name_or_path:
        # Fail fast with a clear message instead of an opaque Hub lookup error.
        raise ValueError("model_name_or_path must be set in model_args")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    chat_template = getattr(training_args, "chat_template", None)
    if chat_template is not None:
        tokenizer.chat_template = chat_template
    # Keep tokenizer setup identical across GRPO and MaxEnt paths.
    pad_token = getattr(tokenizer, "pad_token", None)
    eos_token = getattr(tokenizer, "eos_token", None)
    if pad_token is None and eos_token is not None:
        try:
            setattr(tokenizer, "pad_token", eos_token)
        except (AttributeError, TypeError, ValueError):
            LOG.debug("Failed to set tokenizer.pad_token from eos_token.")
    return tokenizer
def get_model(
    model_args: ModelConfigType, training_args: GRPOConfig
) -> AutoModelForCausalLMType:
    """Construct the causal LM with optional quantization and device map.

    :param model_args: Model configuration (name, quantization, dtype,
        attention implementation, revision, trust settings) forwarded to
        ``from_pretrained``. The dtype is read from ``torch_dtype`` or, when
        that is absent or ``None``, from ``dtype``.
    :type model_args: ``trl.ModelConfig``
    :param training_args: Training configuration. Reads
        ``gradient_checkpointing``, ``gradient_checkpointing_kwargs``,
        ``torch_compile``, ``torch_compile_mode`` and ``deepspeed`` when present.
    :type training_args: GRPOConfig
    :returns: The loaded model, possibly wrapped by ``torch.compile``, with
        ``pad_token_id`` defaulted to the EOS id when it was unset.
    :rtype: ``transformers.AutoModelForCausalLM``
    :raises ValueError: If ``model_args.model_name_or_path`` is missing or
        empty; also propagated from model loading if identifiers or revisions
        are invalid.
    """
    model_name_or_path = getattr(model_args, "model_name_or_path", None)
    if not model_name_or_path:
        raise ValueError("model_name_or_path must be set in model_args")
    gradient_checkpointing = bool(
        getattr(training_args, "gradient_checkpointing", False)
    )
    quantization_config: Optional[Any] = get_quantization_config(
        cast(ModelConfig, model_args)
    )
    # Only ask TRL for a k-bit device map when quantized loading is requested.
    device_map: Optional[Dict[str, Any]] = (
        get_kbit_device_map() if quantization_config is not None else None
    )
    model_kwargs: Dict[str, Any] = {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": _resolve_torch_dtype(model_args),
        # Disable the KV cache when activations will be recomputed by
        # gradient checkpointing.
        "use_cache": not gradient_checkpointing,
        "device_map": device_map,
        "quantization_config": quantization_config,
    }
    model = AutoModelForCausalLM.from_pretrained(  # nosec B615
        model_name_or_path,
        **model_kwargs,
    )
    _align_pad_token_ids(model)
    if gradient_checkpointing:
        _enable_gradient_checkpointing(model, training_args)
    model = _maybe_compile_model(model, training_args)
    return cast(AutoModelForCausalLMType, model)


def _resolve_torch_dtype(model_args: Any) -> Union[str, TorchDType, None]:
    """Return the dtype value to forward to ``from_pretrained``.

    ``None``, ``"auto"`` and non-string values (e.g. a real ``torch.dtype``)
    pass through unchanged. Other strings such as ``"bfloat16"`` are mapped
    to the matching ``torch`` attribute when one exists and kept verbatim
    otherwise.
    """
    # Read defensively: a missing ``torch_dtype`` must not raise. Newer TRL
    # ``ModelConfig`` versions name the field ``dtype`` — accept either.
    requested = getattr(model_args, "torch_dtype", None)
    if requested is None:
        requested = getattr(model_args, "dtype", None)
    if isinstance(requested, str) and requested != "auto":
        return getattr(torch, requested, requested)
    return requested


def _align_pad_token_ids(model: Any) -> None:
    """Default missing ``pad_token_id`` values to the (first) EOS token id.

    Updates ``model.config.pad_token_id`` from ``eos_token_id`` (taking the
    first entry when it is a list/tuple), then copies the config's pad id into
    ``model.generation_config`` when that is also unset. Failures are logged
    at debug level and otherwise ignored.
    """
    try:
        cfg = getattr(model, "config", None)
        if cfg is not None and getattr(cfg, "pad_token_id", None) is None:
            eos_token_id = getattr(cfg, "eos_token_id", None)
            if isinstance(eos_token_id, (list, tuple)):
                eos_token_id = eos_token_id[0] if eos_token_id else None
            if isinstance(eos_token_id, int):
                cfg.pad_token_id = eos_token_id
        gen_cfg = getattr(model, "generation_config", None)
        if gen_cfg is not None and getattr(gen_cfg, "pad_token_id", None) is None:
            pad_token_id = (
                getattr(cfg, "pad_token_id", None) if cfg is not None else None
            )
            if isinstance(pad_token_id, int):
                gen_cfg.pad_token_id = pad_token_id
    except (AttributeError, TypeError, ValueError):
        LOG.debug("Failed to align model pad_token_id settings.")


def _enable_gradient_checkpointing(model: Any, training_args: Any) -> None:
    """Turn on gradient checkpointing, honoring ``gradient_checkpointing_kwargs``.

    ``transformers.PreTrainedModel.gradient_checkpointing_enable`` takes its
    options as a single ``gradient_checkpointing_kwargs`` dict, so that form
    is tried first. Unpacked keyword arguments (for custom models) and a bare
    call follow. Non-reentrant checkpointing is enforced manually when the
    bare call fails, or when it succeeds but an explicitly requested
    ``use_reentrant=False`` could not be passed through.
    """
    enable_fn = getattr(model, "gradient_checkpointing_enable", None)
    if not callable(enable_fn):
        return
    gc_kwargs = getattr(training_args, "gradient_checkpointing_kwargs", None)
    if gc_kwargs is not None and not isinstance(gc_kwargs, dict):
        LOG.warning(
            "Ignoring non-dict gradient_checkpointing_kwargs=%r.",
            gc_kwargs,
        )
    kwargs = dict(gc_kwargs) if isinstance(gc_kwargs, dict) else {}
    if kwargs:
        try:
            enable_fn(gradient_checkpointing_kwargs=kwargs)
            return
        except TypeError:
            pass
        try:
            enable_fn(**kwargs)
            return
        except TypeError as exc:
            LOG.warning(
                "gradient_checkpointing_enable did not accept kwargs (%s); "
                "retrying without kwargs.",
                exc,
            )
    try:
        enable_fn()
    except TypeError as exc:
        LOG.warning(
            "gradient_checkpointing_enable() failed (%s); "
            "forcing non-reentrant checkpointing manually.",
            exc,
        )
    else:
        # A bare call uses the model's own default (transformers defaults to
        # ``use_reentrant=True``), so only re-apply when the caller explicitly
        # asked for non-reentrant checkpointing and that request was dropped.
        if kwargs.get("use_reentrant") is not False:
            return
    if not _force_nonreentrant_checkpointing(model):
        LOG.warning(
            "Failed to enforce non-reentrant checkpointing; "
            "model may still use reentrant mode."
        )


def _enable_dynamo_error_suppression() -> None:
    """Ask TorchDynamo to fall back to eager execution on compile errors.

    ``torch.compile`` only wraps the model; compilation happens lazily on the
    first forward pass. The flag therefore has to stay set after ``compile``
    returns for it to take effect, so it is left enabled (process-wide)
    instead of being restored immediately.
    """
    try:
        import torch._dynamo as dynamo_mod
    except (ImportError, AttributeError, RuntimeError):
        return
    dynamo_config = getattr(dynamo_mod, "config", None)
    if dynamo_config is None:
        return
    try:
        dynamo_config.suppress_errors = True
    except (AttributeError, TypeError):
        LOG.debug("Failed to set torch._dynamo suppress_errors flag.")


def _maybe_compile_model(model: Any, training_args: Any) -> Any:
    """Wrap ``model`` with ``torch.compile`` when ``training_args.torch_compile`` is set.

    Skipped when a DeepSpeed config is present. The compile mode comes from
    ``training_args.torch_compile_mode`` and defaults to ``"max-autotune"``.
    Errors raised while wrapping are logged and the original (eager) model is
    returned.
    """
    if not getattr(training_args, "torch_compile", False):
        return model
    # torch.compile is fragile with DeepSpeed/ZeRO wrapping; skip when deepspeed config is present.
    if getattr(training_args, "deepspeed", None):
        LOG.warning(
            "Skipping torch.compile because deepspeed is enabled; set torch_compile=false to silence."
        )
        return model
    _enable_dynamo_error_suppression()
    compile_fn = getattr(torch, "compile", None)
    if not callable(compile_fn):
        return model
    mode = getattr(training_args, "torch_compile_mode", None) or "max-autotune"
    try:
        return compile_fn(model, mode=mode)
    except TypeError:
        # ``compile`` implementations without a ``mode`` keyword.
        try:
            return compile_fn(model)
        except (RuntimeError, TypeError, ValueError):
            pass
    except (RuntimeError, ValueError):
        pass
    # Best-effort: keep the original model when compilation fails.
    LOG.warning("torch.compile failed; falling back to eager mode.")
    return model