maxent_grpo.training.rollout.context

Shared generation context dataclass used by both the local and vLLM generation paths.

Classes

GenerationContext(max_prompt_len, ...[, ...])

Configuration required to produce completions for each training batch.

class maxent_grpo.training.rollout.context.GenerationContext(max_prompt_len, max_completion_len, gen_temperature, gen_top_p, use_vllm, vllm, accelerator, model, tokenizer, generation_stats, device, penalty=<factory>, prompt_char_limit=None, *, vllm_mode='server')

Bases: GenerationPenaltyPassthroughMixin, GenerationSamplingConfig

Configuration required to produce completions for each training batch.

Parameters:
accelerator: TypesAccelerator
model: TypesPreTrainedModel
tokenizer: TypesPreTrainedTokenizer
generation_stats: Dict[str, int]
device: Any
penalty: GenerationPenaltyConfig
prompt_char_limit: int | None = None

as_dict()

Return a lightweight representation useful for logging/debugging.

Return type:

Dict[str, Any]
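
This page does not show which keys as_dict() returns. As a rough illustration of what a "lightweight representation" could look like, the sketch below uses a hypothetical stand-in class, SketchGenerationContext, which is not part of maxent_grpo. It reuses a subset of the field names documented here and assumes that heavy objects (model, tokenizer) are left out of the result so that it stays JSON-serializable. The real method may select different fields.

```python
import json
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class SketchGenerationContext:
    """Hypothetical stand-in for illustration; not the real GenerationContext."""

    max_prompt_len: int
    max_completion_len: int
    gen_temperature: float
    gen_top_p: float
    use_vllm: bool
    model: Any
    tokenizer: Any
    generation_stats: Dict[str, int] = field(default_factory=dict)
    device: Any = None
    prompt_char_limit: Optional[int] = None
    vllm_mode: str = "server"

    def as_dict(self) -> Dict[str, Any]:
        # Assumption: report only cheap, serializable settings and skip
        # heavy objects such as the model and tokenizer.
        return {
            "max_prompt_len": self.max_prompt_len,
            "max_completion_len": self.max_completion_len,
            "gen_temperature": self.gen_temperature,
            "gen_top_p": self.gen_top_p,
            "use_vllm": self.use_vllm,
            "vllm_mode": self.vllm_mode,
            "prompt_char_limit": self.prompt_char_limit,
            "device": str(self.device),
            # Copy so later counter updates do not mutate the logged snapshot.
            "generation_stats": dict(self.generation_stats),
        }


ctx = SketchGenerationContext(
    max_prompt_len=512,
    max_completion_len=256,
    gen_temperature=0.7,
    gen_top_p=0.9,
    use_vllm=False,
    model=object(),      # placeholder for a real model
    tokenizer=object(),  # placeholder for a real tokenizer
)
summary = ctx.as_dict()
print(json.dumps(summary, sort_keys=True))
```

Keeping the summary to plain scalars and a copied stats dict means it can be passed directly to a logger or serialized without touching model weights.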

max_prompt_len: int
max_completion_len: int
gen_temperature: float
gen_top_p: float
use_vllm: bool
vllm: VLLMClientConfig