Source code for maxent_grpo.training.rollout.context
"""Shared generation context dataclass used by local and vLLM paths."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict
from maxent_grpo.training.runtime import GenerationSamplingConfig
from maxent_grpo.training.runtime.prompts import (
GenerationPenaltyConfig,
GenerationPenaltyPassthroughMixin,
)
from ..types import (
Accelerator as TypesAccelerator,
PreTrainedModel as TypesPreTrainedModel,
PreTrainedTokenizer as TypesPreTrainedTokenizer,
)
[docs]
@dataclass
class GenerationContext(GenerationPenaltyPassthroughMixin, GenerationSamplingConfig):
"""Configuration required to produce completions for each training batch."""
accelerator: TypesAccelerator
model: TypesPreTrainedModel
tokenizer: TypesPreTrainedTokenizer
generation_stats: Dict[str, int]
device: Any
penalty: GenerationPenaltyConfig = field(default_factory=GenerationPenaltyConfig)
prompt_char_limit: int | None = None
[docs]
def as_dict(self) -> Dict[str, Any]:
"""Return a lightweight representation useful for logging/debugging."""
return {
"device": str(self.device),
"max_prompt_len": self.max_prompt_len,
"max_completion_len": self.max_completion_len,
"top_k": self.gen_top_k,
"best_of": self.gen_best_of,
"use_vllm": self.use_vllm,
"vllm_mode": getattr(self, "vllm_mode", "server"),
"vllm_url": self.vllm_url,
}
__all__ = ["GenerationContext"]