Source code for maxent_grpo.core.hub

#!/usr/bin/env python
# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for working with the Hugging Face Hub.

This module provides:

- Upload utilities to push a training output directory to a dedicated branch
  (revision) with basic safety checks.
- Small metadata helpers such as parameter count inference from a repo ID
  (via naming conventions or safetensors metadata) and choosing a valid GPU
  count for vLLM tensor parallelism.

"""

from __future__ import annotations

import logging
import re
from concurrent.futures import Future
from typing import Any, List, Optional, TYPE_CHECKING

try:  # pragma: no cover - optional dependency or incomplete install
    from transformers import AutoConfig
except ModuleNotFoundError as exc:  # pragma: no cover - special-case explicit test signal
    if "transformers missing" in str(exc):
        raise

    class AutoConfig:  # type: ignore[no-redef]
        """Fallback stub when transformers is missing or incomplete."""

        @staticmethod
        def from_pretrained(*_args: Any, **_kwargs: Any) -> Any:
            raise RuntimeError("transformers is not installed or lacks AutoConfig")
except (ImportError, RuntimeError, AttributeError):  # pragma: no cover

    class AutoConfig:  # type: ignore[no-redef]
        """Fallback stub when transformers is missing or incomplete."""

        @staticmethod
        def from_pretrained(*_args: Any, **_kwargs: Any) -> Any:
            raise RuntimeError("transformers is not installed or lacks AutoConfig")


if TYPE_CHECKING:  # pragma: no cover - import types without runtime dependency
    from huggingface_hub import (
        create_branch,
        create_repo,
        get_safetensors_metadata,
        list_repo_commits,
        list_repo_files,
        list_repo_refs,
        repo_exists,
        upload_folder,
        CommitInfo,
    )
    from huggingface_hub.errors import HfHubHTTPError, NotASafetensorsRepoError
else:
    try:  # pragma: no cover - optional dependency
        from huggingface_hub import (
            create_branch,
            create_repo,
            get_safetensors_metadata,
            list_repo_commits,
            list_repo_files,
            list_repo_refs,
            repo_exists,
            upload_folder,
            CommitInfo,
        )
        from huggingface_hub.errors import HfHubHTTPError, NotASafetensorsRepoError
    except ModuleNotFoundError:  # pragma: no cover - provide safe fallbacks for tests

        def create_branch(*_args: Any, **_kwargs: Any) -> None:
            raise RuntimeError("huggingface_hub is not installed")

        def create_repo(*_args: Any, **_kwargs: Any) -> str:
            raise RuntimeError("huggingface_hub is not installed")

        def get_safetensors_metadata(*_args: Any, **_kwargs: Any) -> Any:
            raise RuntimeError("huggingface_hub is not installed")

        def list_repo_commits(*_args: Any, **_kwargs: Any) -> List[Any]:
            return []

        def list_repo_files(*_args: Any, **_kwargs: Any) -> List[str]:
            return []

        class _EmptyBranches:
            branches: List[Any] = []

        def list_repo_refs(*_args: Any, **_kwargs: Any) -> _EmptyBranches:
            return _EmptyBranches()

        def repo_exists(*_args: Any, **_kwargs: Any) -> bool:
            return False

        def upload_folder(*_args: Any, **_kwargs: Any) -> Future["CommitInfo"]:
            raise RuntimeError("huggingface_hub is not installed")

        class CommitInfo:
            commit_id = ""

        class HfHubHTTPError(Exception):
            pass

        class NotASafetensorsRepoError(Exception):
            pass


logger = logging.getLogger(__name__)


if TYPE_CHECKING:  # only for type checking; avoids runtime dependency
    from maxent_grpo.config import GRPOConfig


[docs] def push_to_hub_revision( training_args: "GRPOConfig", extra_ignore_patterns: Optional[List[str]] = None, *, include_checkpoints: bool = False, ) -> Future[CommitInfo]: """Push a checkpoint directory to a branch on the Hub. The helper will create the repository if missing, ensure the target branch exists (forked from the latest commit when possible), and upload the ``output_dir`` contents while ignoring common checkpoint artefacts. Uploads are executed asynchronously via ``run_as_future=True`` to avoid blocking training scripts. :param training_args: Training config with Hub identifiers (``hub_model_id`` and ``hub_model_revision``) and the local ``output_dir`` to upload. :type training_args: GRPOConfig :param include_checkpoints: When True, do not ignore checkpoint-* folders. :type include_checkpoints: bool :param extra_ignore_patterns: Additional filename patterns to ignore during upload; appended to the default ``checkpoint-*`` and ``*.pth`` filters. :type extra_ignore_patterns: list[str] | None :returns: Future that completes when the upload finishes, resolving to the Hub commit metadata. :rtype: concurrent.futures.Future[huggingface_hub.CommitInfo] :raises ValueError: If ``hub_model_id`` is not set in ``training_args``. """ if not training_args.hub_model_id: raise ValueError("hub_model_id must be set in training_args") revision = training_args.hub_model_revision or "main" output_dir = training_args.output_dir if not output_dir: raise ValueError("output_dir must be set in training_args") # Create a repo if it doesn't exist yet repo_url: str = create_repo( repo_id=training_args.hub_model_id, private=True, exist_ok=True ) # Get initial commit to branch from (repo may be empty on first push) try: commits = list_repo_commits(training_args.hub_model_id) initial_commit = commits[-1] if commits else None base_rev: Optional[str] = ( getattr(initial_commit, "commit_id", None) if initial_commit is not None else None ) except HfHubHTTPError: # Fall back to default branch tip base_rev = None # Now create the branch we'll be pushing to create_branch( repo_id=training_args.hub_model_id, branch=revision, revision=base_rev, exist_ok=True, ) logger.info("Created target repo at %s", repo_url) logger.info("Pushing to the Hub revision %s...", revision) ignore_patterns: List[str] = [] if not include_checkpoints: ignore_patterns.extend(["checkpoint-*", "*.pth"]) if extra_ignore_patterns: ignore_patterns.extend(extra_ignore_patterns) future: Future[CommitInfo] = upload_folder( repo_id=training_args.hub_model_id, folder_path=output_dir, revision=revision, commit_message=f"Add {revision} checkpoint", ignore_patterns=ignore_patterns, run_as_future=True, ) logger.info( "Pushed to %s revision %s successfully!", repo_url, revision, ) return future
[docs] def ensure_hf_repo_ready(training_args: "GRPOConfig") -> None: """Verify Hub credentials and provision the target repo/branch upfront. The helper is a best-effort preflight. When Hub access is not configured (or push is disabled), it returns early. Errors in network calls are surfaced as ``RuntimeError`` to avoid silent misconfiguration. :param training_args: Training config with Hub identifiers and push flags. :type training_args: GRPOConfig :returns: ``None``. The function exits early when Hub pushes are disabled. :rtype: None :raises RuntimeError: If the Hub preflight fails due to network or auth errors. """ push_requested = bool( getattr(training_args, "push_to_hub", False) or getattr(training_args, "push_to_hub_revision", False) ) if not push_requested: return repo_id = getattr(training_args, "hub_model_id", None) if not repo_id: logger.warning( "push_to_hub requested but hub_model_id is unset; skipping preflight" ) return revision = getattr(training_args, "hub_model_revision", None) or "main" try: repo_url = create_repo(repo_id=repo_id, private=True, exist_ok=True) base_rev: Optional[str] try: base_rev = list_repo_commits(repo_id)[-1].commit_id except (IndexError, HfHubHTTPError): base_rev = None create_branch( repo_id=repo_id, branch=revision, revision=base_rev, exist_ok=True ) logger.info( "Verified Hugging Face repo %s (revision %s) is ready", repo_url, revision ) except RuntimeError as exc: logger.warning("Skipping Hub preflight: %s", exc) return except ( HfHubHTTPError, OSError, ValueError, ) as exc: # pragma: no cover - network dependent raise RuntimeError( "Failed to preflight Hugging Face Hub access; check credentials/network" ) from exc
[docs] def check_hub_revision_exists(training_args: "GRPOConfig") -> None: """Validate whether a target Hub revision exists and is safe to write. The check avoids clobbering populated branches unless explicitly permitted via ``overwrite_hub_revision``. A README in the branch is treated as a signal that the branch has content. :param training_args: Training config with Hub identifiers and safety flags such as ``push_to_hub_revision`` and ``overwrite_hub_revision``. :type training_args: GRPOConfig :returns: ``None``. Raises if the target revision appears non-empty and overwriting is disallowed. :rtype: None :raises ValueError: If the revision exists and appears non-empty without setting ``overwrite_hub_revision``. """ repo_id = getattr(training_args, "hub_model_id", None) if not repo_id: logger.warning( "push_to_hub_revision requested but hub_model_id is unset; skipping revision check" ) return if repo_exists(repo_id): if training_args.push_to_hub_revision is True: # First check if the revision exists revisions = [rev.name for rev in list_repo_refs(repo_id).branches] # If the revision exists, we next check it has a README file if training_args.hub_model_revision in revisions: repo_files = list_repo_files( repo_id=repo_id, revision=training_args.hub_model_revision, ) if ( "README.md" in repo_files and training_args.overwrite_hub_revision is False ): raise ValueError( f"Revision {training_args.hub_model_revision} already exists. " "Use --overwrite_hub_revision to overwrite it." )
[docs] def get_param_count_from_repo_id(repo_id: Optional[str]) -> int: """Infer parameter count from naming conventions or Hub metadata. Prefers parsing strings like ``42m``, ``1.5b`` or products like ``8x7b`` from the repo ID. Falls back to safetensors metadata when no pattern is found. :param repo_id: Hub repository ID. :type repo_id: str | None :returns: Best guess of total parameter count, or ``-1`` if unknown after attempting both pattern extraction and safetensors metadata lookup, or if ``repo_id`` is missing. :rtype: int """ if not repo_id: return -1 # Pattern to match products (like 8x7b) and single values (like 42m) pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])" matches = re.findall(pattern, repo_id.lower()) param_counts = [] for _full_match, number1, _, _, number2, _, unit in matches: if number2: # If there's a second number, it's a product number = float(number1) * float(number2) else: # Otherwise, it's a single value number = float(number1) if unit == "b": number *= 1_000_000_000 # Convert to billion elif unit == "m": number *= 1_000_000 # Convert to million param_counts.append(number) if len(param_counts) > 0: # Return the largest number from the string pattern return int(max(param_counts)) # Fallback: try to read from Hub metadata try: # pragma: no cover - behavior depends on environment metadata = get_safetensors_metadata(repo_id) return int(list(metadata.parameter_count.values())[0]) except (HfHubHTTPError, NotASafetensorsRepoError, ValueError, KeyError, TypeError): return -1
[docs] def get_gpu_count_for_vllm( model_name: Optional[str], revision: Optional[str] = "main", num_gpus: int = 8, ) -> int: """Choose a valid GPU count for vLLM tensor parallelism. vLLM requires that the number of attention heads and 64 are divisible by the tensor parallel size. This function decrements ``num_gpus`` until the constraints are satisfied. :param model_name: Model repository ID used to fetch the ``AutoConfig``. :type model_name: str | None :param revision: Repo revision/branch to inspect. :type revision: str | None :param num_gpus: Starting number of GPUs available; decremented until the constraints are satisfied. :type num_gpus: int :returns: A compatible number of GPUs for vLLM tensor parallelism. :rtype: int """ if num_gpus <= 0: logger.warning("Invalid num_gpus=%d supplied; defaulting to 1", num_gpus) num_gpus = 1 if not model_name: logger.warning("Missing model_name; using num_gpus=%d", num_gpus) return num_gpus safe_revision = revision or "main" try: config = AutoConfig.from_pretrained( model_name, revision=safe_revision, trust_remote_code=True ) except ( OSError, RuntimeError, ValueError, ) as exc: # pragma: no cover - network dependent logger.warning( "Unable to load config for %s (revision %s): %s; using num_gpus=%d", model_name, safe_revision, exc, num_gpus, ) return num_gpus # Get number of attention heads num_heads = getattr(config, "num_attention_heads", None) if not isinstance(num_heads, int) or num_heads <= 0: logger.warning( "Unable to infer num_attention_heads for %s; using num_gpus=%d", model_name, num_gpus, ) return max(1, num_gpus) # Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus while num_heads % num_gpus != 0 or 64 % num_gpus != 0: logger.info( "Reducing num_gpus from %d to %d to make num_heads divisible by num_gpus", num_gpus, num_gpus - 1, ) num_gpus -= 1 return num_gpus