# Source code for maxent_grpo.training.seed_paper_eval_callback

"""Trainer callback for official SEED paper-style eval against the live vLLM server."""

from __future__ import annotations

import json
import logging
import os
import re
import subprocess
import time
from pathlib import Path
from typing import Any, Optional

try:
    from transformers.trainer_callback import TrainerCallback
except (ImportError, ModuleNotFoundError):  # pragma: no cover - optional in tests
    TrainerCallback = object  # type: ignore[assignment]

from maxent_grpo.seed_paper_eval import (
    build_seed_paper_eval_payload,
    build_step0_wandb_payload,
    default_python_executable,
    default_workspace_dir,
)

# Module-level logger, named after this module's import path.
LOG = logging.getLogger(__name__)


def _repo_root() -> Path:
    """Return the ancestor directory three levels above this module's folder.

    The path is derived from the resolved location of this file, so it does
    not depend on the current working directory.
    """
    module_file = Path(__file__).resolve()
    ancestors = module_file.parents
    return ancestors[3]


def _slugify(raw: str) -> str:
    """Turn an arbitrary label into a path-segment-friendly slug.

    Every run of characters outside ``A-Z a-z 0-9 . _ -`` collapses to a single
    ``-``; leading and trailing dashes are then dropped. When nothing is left,
    the placeholder ``"unnamed-run"`` is returned.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "-", raw.strip())
    slug = collapsed.strip("-")
    if not slug:
        return "unnamed-run"
    return slug


def _resolve_results_root(training_args: Any) -> Path:
    """Pick the root directory under which live eval results are written.

    A truthy ``seed_paper_eval_results_dir`` attribute on ``training_args``
    wins (with ``~`` expanded and made absolute); otherwise the path
    ``var/artifacts/seed_paper_eval/live`` under the repository root is used.
    """
    override = getattr(training_args, "seed_paper_eval_results_dir", None)
    if not override:
        return _repo_root() / "var" / "artifacts" / "seed_paper_eval" / "live"
    override_path = Path(str(override))
    return override_path.expanduser().absolute()


def _resolve_workspace_dir(training_args: Any) -> Path:
    """Pick the workspace directory handed to the eval tool.

    A truthy ``seed_paper_eval_workspace_dir`` attribute on ``training_args``
    wins (with ``~`` expanded and made absolute); otherwise the decision is
    delegated to ``default_workspace_dir`` for the repository root.
    """
    override = getattr(training_args, "seed_paper_eval_workspace_dir", None)
    if not override:
        return default_workspace_dir(_repo_root())
    override_path = Path(str(override))
    return override_path.expanduser().absolute()


def _resolve_step0_results_dir(training_args: Any) -> Path | None:
    """Locate the directory holding a step-0 paper-eval result, if configured.

    Lookup order:

    1. a truthy ``seed_paper_eval_step0_results_dir`` attribute on
       ``training_args`` (used as given, not whitespace-stripped);
    2. the ``MAXENT_STEP0_PAPER_EVAL_RESULTS_DIR`` environment variable,
       whitespace-stripped, when non-blank.

    Returns ``None`` when neither source provides a value.
    """
    explicit = getattr(training_args, "seed_paper_eval_step0_results_dir", None)
    if explicit:
        raw_dir = str(explicit)
    else:
        raw_dir = (os.environ.get("MAXENT_STEP0_PAPER_EVAL_RESULTS_DIR") or "").strip()
        if not raw_dir:
            return None
    return Path(raw_dir).expanduser().absolute()


def _resolve_python(training_args: Any) -> Path:
    """Pick the Python interpreter used to launch the eval tool.

    A truthy ``seed_paper_eval_python`` attribute on ``training_args`` wins;
    otherwise ``default_python_executable`` decides for the repository root.
    The configured path is made absolute without resolving symlinks.
    """
    override = getattr(training_args, "seed_paper_eval_python", None)
    if not override:
        return default_python_executable(_repo_root())
    interpreter = Path(str(override))
    return interpreter.expanduser().absolute()


def _resolve_vllm_url() -> str | None:
    """Read the vLLM server URL from the environment.

    ``MAXENT_VLLM_URL`` is consulted before ``VLLM_URL``; the first value that
    is non-blank after stripping whitespace is returned. ``None`` means no
    usable URL was found.
    """
    candidates = (os.environ.get(name) for name in ("MAXENT_VLLM_URL", "VLLM_URL"))
    for candidate in candidates:
        cleaned = candidate.strip() if candidate else ""
        if cleaned:
            return cleaned
    return None


def _env_int(keys: tuple[str, ...]) -> int | None:
    """Return the first environment variable in ``keys`` that parses as an int.

    Variables that are unset, blank, or not valid integers are skipped.
    ``None`` is returned when no key yields an integer.
    """
    for name in keys:
        text = str(os.environ.get(name, "")).strip()
        if text:
            try:
                return int(text)
            except (TypeError, ValueError):
                # Unparseable value: fall through to the next candidate key.
                continue
    return None


def _resolve_model_name_from_recipe_env() -> str:
    """Extract ``model_name_or_path`` from a recipe file named in the environment.

    ``GRPO_RECIPE_USED`` is tried before ``GRPO_RECIPE``. Each is treated as a
    path to a UTF-8 text file that is scanned for a line starting with
    ``model_name_or_path:``; surrounding single or double quotes are removed
    from the value. Files that cannot be read, or that lack a non-empty value,
    are skipped. Returns ``""`` when nothing is found.
    """
    pattern = r"^model_name_or_path:\s*(.+?)\s*$"
    for env_key in ("GRPO_RECIPE_USED", "GRPO_RECIPE"):
        recipe_path = os.environ.get(env_key)
        if not recipe_path:
            continue
        try:
            recipe_text = Path(recipe_path).read_text(encoding="utf-8")
        except OSError:
            continue
        found = re.search(pattern, recipe_text, re.MULTILINE)
        name = found.group(1).strip().strip("'").strip('"') if found else ""
        if name:
            return name
    return ""


def _resolve_model_name(training_args: Any) -> str:
    """Work out the model name passed to the eval tool.

    Sources are consulted in this order, and the first non-blank string wins:

    1. ``model_name_or_path``, ``hub_model_id``, ``model_id`` attributes on
       ``training_args`` (only ``str`` values are considered);
    2. the recipe file referenced by the environment;
    3. the ``SEED_PAPER_EVAL_MODEL_NAME``, ``MAXENT_MODEL``, ``GRPO_MODEL``,
       and ``MODEL`` environment variables.

    Returns ``""`` when every source comes up empty.
    """
    attr_names = ("model_name_or_path", "hub_model_id", "model_id")
    for attr_name in attr_names:
        attr_value = getattr(training_args, attr_name, None)
        if not isinstance(attr_value, str):
            continue
        trimmed = attr_value.strip()
        if trimmed:
            return trimmed

    from_recipe = _resolve_model_name_from_recipe_env()
    if from_recipe:
        return from_recipe

    env_names = (
        "SEED_PAPER_EVAL_MODEL_NAME",
        "MAXENT_MODEL",
        "GRPO_MODEL",
        "MODEL",
    )
    for env_name in env_names:
        env_value = os.environ.get(env_name)
        trimmed = env_value.strip() if env_value else ""
        if trimmed:
            return trimmed
    return ""


def build_live_seed_paper_eval_command(
    training_args: Any,
    *,
    step: int,
) -> tuple[list[str], Path]:
    """Build the subprocess command for one live SEED paper eval run.

    Args:
        training_args: Object whose attributes configure the eval. Missing
            attributes are treated as "not configured"; the ones read here are
            ``run_name``/``output_dir`` and the ``seed_paper_eval_*`` /
            ``seed_paper_reward_fast`` options mapped to CLI flags below.
        step: Training step the eval belongs to; used to name the results
            sub-directory (``step-NNNNNN``, zero-padded to six digits).

    Returns:
        A ``(command, results_dir)`` pair. ``command`` is the argv list that
        runs ``tools/seed_paper_eval.py`` under the repository root with the
        resolved interpreter; ``results_dir`` is
        ``<results root>/<slugified run name>/step-NNNNNN``.

    Raises:
        TypeError, ValueError: If a configured numeric option cannot be
            converted with ``int``/``float``.
    """
    repo_root = _repo_root()
    # Fall back from run_name to output_dir so distinct runs get distinct folders.
    run_name = str(
        getattr(training_args, "run_name", None)
        or getattr(training_args, "output_dir", None)
        or "unnamed-run"
    )
    results_dir = (
        _resolve_results_root(training_args)
        / _slugify(run_name)
        / f"step-{int(step):06d}"
    )
    command = [
        str(_resolve_python(training_args)),
        str(repo_root / "tools" / "seed_paper_eval.py"),
        "--model-name",
        _resolve_model_name(training_args),
        "--workspace-dir",
        str(_resolve_workspace_dir(training_args)),
        "--results-dir",
        str(results_dir),
        "--vllm-url",
        # An unset URL is passed as an empty string rather than omitted.
        str(_resolve_vllm_url() or ""),
    ]

    # String-valued options are forwarded only when truthy.
    template = getattr(training_args, "seed_paper_eval_template", None)
    if template:
        command.extend(["--template", str(template)])
    tasks = getattr(training_args, "seed_paper_eval_tasks", None)
    if tasks:
        command.extend(["--tasks", str(tasks)])

    # Numeric options are forwarded whenever they are set, including 0.
    max_test = getattr(training_args, "seed_paper_eval_max_test", None)
    if max_test is not None:
        command.extend(["--max-test", str(int(max_test))])
    batch_size = getattr(training_args, "seed_paper_eval_vllm_batch_size", None)
    if batch_size is not None:
        command.extend(["--vllm-batch-size", str(int(batch_size))])

    # Boolean switches.
    if bool(getattr(training_args, "seed_paper_reward_fast", False)):
        command.append("--seed-paper-reward-fast")
    if bool(getattr(training_args, "seed_paper_eval_pass_at_8_enabled", False)):
        command.append("--pass-at-8")

    # pass@8 sampling parameters.
    pass_at_8_samples = getattr(
        training_args,
        "seed_paper_eval_pass_at_8_samples",
        None,
    )
    if pass_at_8_samples is not None:
        command.extend(["--pass-at-8-samples", str(int(pass_at_8_samples))])
    pass_at_8_temperature = getattr(
        training_args,
        "seed_paper_eval_pass_at_8_temperature",
        None,
    )
    if pass_at_8_temperature is not None:
        command.extend(
            ["--pass-at-8-temperature", str(float(pass_at_8_temperature))]
        )
    pass_at_8_top_p = getattr(
        training_args,
        "seed_paper_eval_pass_at_8_top_p",
        None,
    )
    if pass_at_8_top_p is not None:
        command.extend(["--pass-at-8-top-p", str(float(pass_at_8_top_p))])

    return command, results_dir


def _resolve_process_rank(training_args: Any) -> int:
    """Return this process's rank, never negative.

    The ``RANK`` then ``SLURM_PROCID`` environment variables take precedence;
    otherwise ``training_args.process_index`` is used. Falls back to ``0`` when
    nothing usable is found.
    """
    value = _env_int(("RANK", "SLURM_PROCID"))
    if value is not None:
        return max(0, value)
    raw = getattr(training_args, "process_index", None)
    try:
        return max(0, int(raw))
    except (TypeError, ValueError):
        return 0


def _resolve_world_size(training_args: Any) -> int:
    """Return the number of participating processes, at least ``1``.

    The ``WORLD_SIZE`` then ``SLURM_NTASKS`` environment variables take
    precedence; otherwise ``training_args.world_size`` is used. Falls back to
    ``1`` when nothing usable is found.
    """
    value = _env_int(("WORLD_SIZE", "SLURM_NTASKS"))
    if value is not None:
        return max(1, value)
    raw = getattr(training_args, "world_size", None)
    try:
        return max(1, int(raw))
    except (TypeError, ValueError):
        return 1


def _coordination_dir(training_args: Any, *, step: int) -> Path:
    """Return the ``_coord`` directory inside the results directory for ``step``."""
    # Only the results directory is needed; the command itself is discarded.
    _, results_dir = build_live_seed_paper_eval_command(training_args, step=step)
    return results_dir / "_coord"


def _write_json_atomic(path: Path, payload: dict[str, object]) -> None:
    """Write ``payload`` as JSON to ``path`` via a temp file plus rename.

    Parent directories are created as needed. The temp file lives next to the
    target and carries this process's PID in its name, so readers polling
    ``path`` never observe a partially written file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(f"{path.name}.tmp-{os.getpid()}")
    tmp_path.write_text(
        json.dumps(payload, indent=2, sort_keys=True) + "\n",
        encoding="utf-8",
    )
    tmp_path.replace(path)


def _wait_for_rank_markers(
    coord_dir: Path,
    *,
    pattern: str,
    world_size: int,
    timeout_s: int,
    timeout_message: str,
) -> None:
    """Poll ``coord_dir`` until at least ``world_size`` files match ``pattern``.

    The directory is always checked before the deadline is evaluated, so
    markers that appear during the final sleep are still honoured.

    Raises:
        TimeoutError: With ``timeout_message`` once ``max(1, timeout_s)``
            seconds have elapsed without enough markers.
    """
    deadline = time.monotonic() + max(1, timeout_s)
    while True:
        if sum(1 for _ in coord_dir.glob(pattern)) >= world_size:
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(timeout_message)
        time.sleep(min(1.0, remaining))


def _wait_for_rank_arrivals(
    coord_dir: Path,
    *,
    world_size: int,
    timeout_s: int,
) -> None:
    """Block until ``world_size`` ``arrived-rank*.json`` markers exist.

    Raises:
        TimeoutError: If the markers do not all appear in time.
    """
    _wait_for_rank_markers(
        coord_dir,
        pattern="arrived-rank*.json",
        world_size=world_size,
        timeout_s=timeout_s,
        timeout_message=(
            f"Timed out waiting for {world_size} ranks to enter live SEED eval under {coord_dir}"
        ),
    )


def _mark_rank_release(
    coord_dir: Path,
    *,
    rank: int,
    world_size: int,
    step: int,
) -> None:
    """Write this rank's ``released-rankNNNNN.json`` marker into ``coord_dir``."""
    release_path = coord_dir / f"released-rank{rank:05d}.json"
    _write_json_atomic(
        release_path,
        {"rank": rank, "world_size": world_size, "pid": os.getpid(), "step": int(step)},
    )


def _wait_for_rank_releases(
    coord_dir: Path,
    *,
    world_size: int,
    timeout_s: int,
) -> None:
    """Block until ``world_size`` ``released-rank*.json`` markers exist.

    Raises:
        TimeoutError: If the markers do not all appear in time.
    """
    _wait_for_rank_markers(
        coord_dir,
        pattern="released-rank*.json",
        world_size=world_size,
        timeout_s=timeout_s,
        timeout_message=(
            f"Timed out waiting for {world_size} ranks to leave live SEED eval under {coord_dir}"
        ),
    )


def _wait_for_result_payload(
    result_path: Path,
    *,
    timeout_s: int,
) -> dict[str, object]:
    """Poll ``result_path`` until it holds parseable JSON and return it.

    A missing file or a JSON decode error is retried once per second until
    ``max(1, timeout_s)`` seconds have elapsed.

    Raises:
        TimeoutError: If no valid payload appears in time. When the file
            existed but never parsed, the last decode error is chained.
    """
    deadline = time.monotonic() + max(1, timeout_s)
    last_error: Exception | None = None
    while True:
        try:
            return json.loads(result_path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            pass  # Not published yet; keep polling.
        except json.JSONDecodeError as exc:
            last_error = exc
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(1.0, remaining))
    if last_error is not None:
        raise TimeoutError(
            f"Timed out waiting for a valid live SEED eval result payload at {result_path}: {last_error}"
        ) from last_error
    raise TimeoutError(f"Timed out waiting for live SEED eval result payload at {result_path}")


def _latest_summary_path(results_dir: Path) -> Path | None:
    """Return the last ``*.summary.json`` in ``results_dir`` by sorted path, or ``None``."""
    summaries = sorted(results_dir.glob("*.summary.json"))
    return summaries[-1] if summaries else None


def _wandb_run() -> Any | None:
    """Return ``wandb.run`` if ``wandb`` is importable, else ``None``."""
    try:
        import wandb
    except ImportError:  # pragma: no cover - optional dependency
        return None
    return getattr(wandb, "run", None)


def _log_summary_to_wandb(
    *,
    summary: dict[str, object],
    summary_path: Path,
    step: int,
) -> None:
    """Log a successful live eval summary to the active W&B run, if any.

    Metrics are emitted under both the ``seed_paper_eval_live`` and
    ``paper_eval`` prefixes and mirrored into the run summary together with
    bookkeeping keys (latest path/step, status, source, process warning).
    Does nothing when no W&B run is available.
    """
    run = _wandb_run()
    if run is None:
        return
    payload: dict[str, object] = {}
    for prefix in ("seed_paper_eval_live", "paper_eval"):
        _define_metric_axis(run, prefix=prefix)
        payload.update(build_seed_paper_eval_payload(summary, prefix=prefix))
        payload[f"{prefix}/training_step"] = int(step)
        payload[f"{prefix}/ok"] = 1.0
    if payload:
        run.log(payload, commit=True)
        for key, value in payload.items():
            run.summary[key] = value
    for prefix in ("seed_paper_eval_live", "paper_eval"):
        run.summary[f"{prefix}/latest_summary_path"] = str(summary_path)
        run.summary[f"{prefix}/latest_step"] = int(step)
        run.summary[f"{prefix}/status"] = "ok"
    run.summary["paper_eval/source"] = "live"
    warning = summary.get("process_warning")
    if warning is not None:
        run.summary["seed_paper_eval_live/process_warning"] = str(warning)
        run.summary["paper_eval/process_warning"] = str(warning)


def _log_step0_summary_to_wandb(
    *,
    summary: dict[str, object],
    summary_path: Path,
) -> bool:
    """Log a step-0 eval summary to the active W&B run.

    Metrics are emitted under the ``step0_paper_eval`` and ``paper_eval``
    prefixes at training step 0. The status keys are set to
    ``"expected_mismatch"`` when ``summary["expected_comparison"]`` is a dict
    whose ``"ok"`` entry is falsy.

    Returns:
        ``True`` if a W&B run was available and was updated, else ``False``.
    """
    run = _wandb_run()
    if run is None:
        return False
    payload: dict[str, object] = {}
    _define_metric_axis(run, prefix="step0_paper_eval")
    _define_metric_axis(run, prefix="paper_eval")
    payload.update(build_step0_wandb_payload(summary))
    payload.update(build_seed_paper_eval_payload(summary, prefix="paper_eval"))
    payload["step0_paper_eval/training_step"] = 0
    payload["step0_paper_eval/ok"] = 1.0
    payload["paper_eval/training_step"] = 0
    payload["paper_eval/ok"] = 1.0
    if payload:
        run.log(payload, commit=True)
        for key, value in payload.items():
            run.summary[key] = value
    run.summary["step0_paper_eval/summary_path"] = str(summary_path)
    run.summary["step0_paper_eval/latest_step"] = 0
    run.summary["step0_paper_eval/status"] = "ok"
    run.summary["paper_eval/summary_path"] = str(summary_path)
    run.summary["paper_eval/latest_step"] = 0
    run.summary["paper_eval/status"] = "ok"
    run.summary["paper_eval/source"] = "step0"
    warning = summary.get("process_warning")
    if warning is not None:
        run.summary["step0_paper_eval/process_warning"] = str(warning)
        run.summary["paper_eval/process_warning"] = str(warning)
    comparison = summary.get("expected_comparison")
    if isinstance(comparison, dict) and not bool(comparison.get("ok")):
        run.summary["step0_paper_eval/status"] = "expected_mismatch"
        run.summary["paper_eval/status"] = "expected_mismatch"
    return True


def _define_metric_axis(run: Any, *, prefix: str) -> None:
    """Register ``<prefix>/training_step`` as the step metric for ``<prefix>/*``.

    Silently does nothing when ``run`` has no callable ``define_metric``.
    """
    define_metric = getattr(run, "define_metric", None)
    if not callable(define_metric):
        return
    step_key = f"{prefix}/training_step"
    define_metric(step_key)
    define_metric(f"{prefix}/*", step_metric=step_key)


def _sync_step0_summary_to_current_run(training_args: Any) -> bool:
    """Replay the newest step-0 summary into the current W&B run.

    Returns:
        ``False`` when no step-0 results directory is configured, it holds no
        ``*.summary.json``, or no W&B run is available; ``True`` once the
        summary has been logged.
    """
    results_dir = _resolve_step0_results_dir(training_args)
    if results_dir is None:
        return False
    summary_path = _latest_summary_path(results_dir)
    if summary_path is None:
        return False
    summary = json.loads(summary_path.read_text(encoding="utf-8"))
    return _log_step0_summary_to_wandb(summary=summary, summary_path=summary_path)


class SeedPaperEvalCallback(TrainerCallback):
    """Run the official SEED paper-style eval against the live vLLM server.

    The eval is launched as a subprocess. With more than one process, rank 0
    runs it while the other ranks wait on marker files in a per-step
    coordination directory, so a filesystem visible to every rank is assumed.

    Which hook triggers the eval depends on the trainer configuration:

    * built-in evaluation enabled: ``on_evaluate`` only;
    * otherwise: ``on_train_begin`` (optional), ``on_step_end`` every
      ``eval_steps`` steps, and ``on_train_end``.

    Nothing runs unless ``training_args.seed_paper_eval_enabled`` is truthy.
    """

    # Default subprocess / rendezvous timeout in seconds (4 hours).
    _DEFAULT_TIMEOUT_S = 14400

    def __init__(self, training_args: Any) -> None:
        """Store the configuration object and start with no evaluated steps."""
        self.training_args = training_args
        # Steps whose eval already succeeded (or was satisfied by a step-0 sync).
        self._seen_steps: set[int] = set()

    def _enabled(self) -> bool:
        """Return whether ``seed_paper_eval_enabled`` is set on the training args."""
        return bool(getattr(self.training_args, "seed_paper_eval_enabled", False))

    def _timeout_s(self) -> int:
        """Return ``seed_paper_eval_timeout_s``, defaulting when unset or falsy."""
        return int(
            getattr(
                self.training_args,
                "seed_paper_eval_timeout_s",
                self._DEFAULT_TIMEOUT_S,
            )
            or self._DEFAULT_TIMEOUT_S
        )

    def _built_in_eval_enabled(self) -> bool:
        """Return whether the trainer's own evaluation loop is switched on.

        That requires a truthy ``do_eval`` and an ``eval_strategy`` whose
        lower-cased text is not ``""``, ``"no"``, or ``"none"``.
        """
        if not bool(getattr(self.training_args, "do_eval", False)):
            return False
        eval_strategy = getattr(self.training_args, "eval_strategy", None)
        if eval_strategy is None:
            return False
        text = str(eval_strategy).strip().lower()
        return text not in {"", "no", "none"}

    def _eval_steps(self) -> int:
        """Return ``eval_steps`` as a non-negative int, or ``0`` if unusable."""
        raw = getattr(self.training_args, "eval_steps", 0)
        try:
            return max(0, int(raw or 0))
        except (TypeError, ValueError):
            return 0

    def _report_failure_to_wandb(self, step: int, exc: Exception) -> None:
        """Record a failed eval in W&B on a best-effort basis.

        Any error raised while reporting is logged and swallowed, so the
        caller can still return the eval failure to its own caller.
        """
        try:
            run = _wandb_run()
            if run is None:
                return
            for prefix in ("seed_paper_eval_live", "paper_eval"):
                _define_metric_axis(run, prefix=prefix)
            run.log(
                {
                    "seed_paper_eval_live/training_step": int(step),
                    "seed_paper_eval_live/ok": 0.0,
                    "paper_eval/training_step": int(step),
                    "paper_eval/ok": 0.0,
                },
                commit=True,
            )
            run.summary["seed_paper_eval_live/last_error"] = str(exc)
            run.summary["seed_paper_eval_live/latest_step"] = int(step)
            run.summary["seed_paper_eval_live/status"] = "failed"
            run.summary["paper_eval/last_error"] = str(exc)
            run.summary["paper_eval/latest_step"] = int(step)
            run.summary["paper_eval/status"] = "failed"
            run.summary["paper_eval/source"] = "live"
        except Exception as report_exc:  # noqa: BLE001 - telemetry is best-effort
            LOG.warning(
                "Could not record live SEED paper eval failure in W&B at step %s: %s",
                step,
                report_exc,
            )

    def _run_live_eval_once(self, step: int) -> tuple[bool, str | None]:
        """Run the eval subprocess for ``step`` in this process.

        Returns:
            ``(True, None)`` when the eval succeeded, was already done for this
            step, or was skipped because no vLLM URL is configured;
            ``(False, message)`` when the subprocess or result handling failed.
        """
        if step in self._seen_steps:
            return True, None
        vllm_url = _resolve_vllm_url()
        if not vllm_url:
            LOG.warning(
                "Skipping live SEED paper eval at step %s because no vLLM URL is available.",
                step,
            )
            return True, None
        command, results_dir = build_live_seed_paper_eval_command(
            self.training_args,
            step=step,
        )
        env = os.environ.copy()
        # NOTE(review): CPython treats any non-empty PYTHONNOUSERSITE value as
        # "set", so "0" still disables the user site directory - confirm intent.
        env["PYTHONNOUSERSITE"] = "0"
        # Presumably keeps the child from opening its own W&B run; metrics are
        # logged by this process from the summary file instead.
        env["WANDB_DISABLED"] = "true"
        timeout_s = self._timeout_s()
        LOG.info(
            "Running live SEED paper eval | step=%s | results_dir=%s | cmd=%s",
            step,
            results_dir,
            " ".join(command),
        )
        try:
            subprocess.run(
                command,
                check=True,
                cwd=str(_repo_root()),
                env=env,
                text=True,
                timeout=timeout_s,
            )
            summary_path = _latest_summary_path(results_dir)
            if summary_path is None:
                raise RuntimeError(f"No summary JSON was produced under {results_dir}")
            summary = json.loads(summary_path.read_text(encoding="utf-8"))
            _log_summary_to_wandb(summary=summary, summary_path=summary_path, step=step)
            self._seen_steps.add(step)
            return True, None
        except Exception as exc:  # noqa: BLE001 - failure is returned, not raised
            LOG.warning("Live SEED paper eval failed at step %s: %s", step, exc)
            self._report_failure_to_wandb(step, exc)
            return False, str(exc)

    def _run_live_eval(self, step: int) -> None:
        """Run the eval for ``step``, coordinating across ranks when needed.

        With a single process the eval runs inline. Otherwise every rank drops
        an arrival marker, rank 0 waits for all arrivals, runs the eval and
        publishes ``result.json``, the other ranks wait for that file, and all
        ranks rendezvous again on release markers before returning.

        Raises:
            RuntimeError: If the eval failed and
                ``seed_paper_eval_fail_on_error`` is truthy.
            TimeoutError: If a rendezvous or the result file times out.
        """
        # A step recorded as done needs no second pass; skipping here also
        # avoids re-running the rendezvous on marker files left by the first.
        if step in self._seen_steps:
            return

        rank = _resolve_process_rank(self.training_args)
        world_size = _resolve_world_size(self.training_args)
        timeout_s = self._timeout_s()
        fail_on_error = bool(
            getattr(self.training_args, "seed_paper_eval_fail_on_error", False)
        )
        if world_size <= 1:
            ok, error = self._run_live_eval_once(step)
            if not ok and fail_on_error:
                raise RuntimeError(
                    f"Live SEED paper eval failed at step {step}: {error}"
                )
            return

        coord_dir = _coordination_dir(self.training_args, step=step)
        coord_dir.mkdir(parents=True, exist_ok=True)
        arrival_path = coord_dir / f"arrived-rank{rank:05d}.json"
        _write_json_atomic(
            arrival_path,
            {"rank": rank, "world_size": world_size, "pid": os.getpid(), "step": int(step)},
        )
        result_path = coord_dir / "result.json"
        ok = False
        error: str | None = None
        if rank == 0:
            _wait_for_rank_arrivals(coord_dir, world_size=world_size, timeout_s=timeout_s)
            ok, error = self._run_live_eval_once(step)
            _write_json_atomic(
                result_path,
                {
                    "ok": bool(ok),
                    "error": error,
                    "rank": rank,
                    "step": int(step),
                    "world_size": world_size,
                },
            )
        else:
            # Extra slack on top of the eval timeout so peers outlast rank 0.
            payload = _wait_for_result_payload(result_path, timeout_s=timeout_s + 300)
            ok = bool(payload.get("ok"))
            error_obj = payload.get("error")
            error = str(error_obj) if error_obj is not None else None

        # Hold every rank here until all peers have fully observed the eval result
        # and are ready to return to the trainer. Without this exit rendezvous,
        # rank 0 can enter checkpoint save collectives while other ranks are still
        # unwinding the eval callback, which can deadlock NCCL.
        _mark_rank_release(coord_dir, rank=rank, world_size=world_size, step=step)
        _wait_for_rank_releases(coord_dir, world_size=world_size, timeout_s=timeout_s + 300)
        if ok:
            self._seen_steps.add(step)
            return
        if fail_on_error:
            raise RuntimeError(
                f"Live SEED paper eval failed at step {step}: {error}"
            )

    def on_train_begin(
        self,
        args: Any,
        state: Any,
        control: Any,
        **kwargs: Any,
    ) -> Any:
        """Optionally evaluate before the first training step.

        Runs only when built-in evaluation is off and
        ``seed_paper_eval_on_start`` (falling back to ``eval_on_start``) is
        truthy. At step 0, an existing step-0 summary is replayed into W&B by
        the world-zero process instead of running a fresh eval.
        """
        _ = args
        _ = kwargs
        if not self._enabled():
            return control
        if self._built_in_eval_enabled():
            return control
        trigger_on_start = getattr(
            self.training_args,
            "seed_paper_eval_on_start",
            getattr(self.training_args, "eval_on_start", False),
        )
        if bool(trigger_on_start):
            step = int(getattr(state, "global_step", 0) or 0)
            synced_step0 = False
            if step == 0:
                step0_results_dir = _resolve_step0_results_dir(self.training_args)
                if (
                    step0_results_dir is not None
                    and _latest_summary_path(step0_results_dir) is not None
                ):
                    synced_step0 = True
                    self._seen_steps.add(0)
                    if bool(getattr(state, "is_world_process_zero", True)):
                        _sync_step0_summary_to_current_run(self.training_args)
            if not synced_step0:
                self._run_live_eval(step)
        return control

    def on_step_end(
        self,
        args: Any,
        state: Any,
        control: Any,
        **kwargs: Any,
    ) -> Any:
        """Evaluate every ``eval_steps`` steps when built-in evaluation is off."""
        _ = args
        _ = kwargs
        if not self._enabled():
            return control
        if self._built_in_eval_enabled():
            return control
        step = int(getattr(state, "global_step", 0) or 0)
        eval_steps = self._eval_steps()
        if step > 0 and eval_steps > 0 and step % eval_steps == 0:
            self._run_live_eval(step)
        return control

    def on_evaluate(
        self,
        args: Any,
        state: Any,
        control: Any,
        **kwargs: Any,
    ) -> Any:
        """Evaluate alongside the trainer's own evaluation when that is enabled."""
        _ = args
        _ = kwargs
        if not self._enabled():
            return control
        if not self._built_in_eval_enabled():
            return control
        step = int(getattr(state, "global_step", 0) or 0)
        self._run_live_eval(step)
        return control

    def on_train_end(
        self,
        args: Any,
        state: Any,
        control: Any,
        **kwargs: Any,
    ) -> Any:
        """Evaluate the final step when built-in evaluation is off."""
        _ = args
        _ = kwargs
        if not self._enabled():
            return control
        if self._built_in_eval_enabled():
            return control
        step = int(getattr(state, "global_step", 0) or 0)
        if step > 0:
            self._run_live_eval(step)
        return control


# Names exported by `from ... import *`; the underscore-prefixed helpers stay private.
__all__ = [ "SeedPaperEvalCallback", "build_live_seed_paper_eval_command", ]