maxent_grpo.training.rollout.context
Shared generation-context dataclass used by both the local and the vLLM generation paths.
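The summary above says a single context object serves both the local and the vLLM generation paths. A minimal sketch of that idea follows, assuming the caller simply switches on the `use_vllm` flag. `MiniGenerationContext`, `generate_batch`, `fake_local` and `fake_vllm` are hypothetical names invented for illustration; they are not part of maxent_grpo, and the real dispatch logic may differ.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class MiniGenerationContext:
    """Hypothetical, trimmed-down stand-in for GenerationContext."""
    max_completion_len: int
    gen_temperature: float
    gen_top_p: float
    use_vllm: bool


# Hypothetical backend shape: (prompts, ctx) -> one completion per prompt
Backend = Callable[[List[str], MiniGenerationContext], List[str]]


def generate_batch(ctx: MiniGenerationContext, prompts: List[str],
                   local_backend: Backend, vllm_backend: Backend) -> List[str]:
    """Choose a backend from ctx.use_vllm; both receive the same context."""
    backend = vllm_backend if ctx.use_vllm else local_backend
    return backend(prompts, ctx)


def fake_local(prompts: List[str], ctx: MiniGenerationContext) -> List[str]:
    # Stands in for in-process generation; only echoes the sampling settings
    return [f"local(T={ctx.gen_temperature}): {p}" for p in prompts]


def fake_vllm(prompts: List[str], ctx: MiniGenerationContext) -> List[str]:
    # Stands in for a vLLM request; only echoes the sampling settings
    return [f"vllm(top_p={ctx.gen_top_p}): {p}" for p in prompts]


ctx = MiniGenerationContext(max_completion_len=128, gen_temperature=0.7,
                            gen_top_p=0.9, use_vllm=False)
print(generate_batch(ctx, ["2+2="], fake_local, fake_vllm))
# prints ['local(T=0.7): 2+2=']
```

Keeping every sampling setting on one object means neither backend needs its own copy of the configuration; flipping `use_vllm` is the only change needed to switch paths in this sketch.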
Classes
GenerationContext: Configuration required to produce completions for each training batch.
- class maxent_grpo.training.rollout.context.GenerationContext(max_prompt_len, max_completion_len, gen_temperature, gen_top_p, use_vllm, vllm, accelerator, model, tokenizer, generation_stats, device, penalty=<factory>, prompt_char_limit=None, *, vllm_mode='server')
Bases: GenerationPenaltyPassthroughMixin, GenerationSamplingConfig
Configuration required to produce completions for each training batch.
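The two bases suggest that sampling fields come from a config dataclass while a mixin forwards penalty settings onto the context. The sketch below shows one way such a layout can work, assuming the mixin exposes penalty fields as read-only properties. `PenaltyConfig`, `PenaltyPassthroughMixin`, `SamplingConfig`, `DemoContext`, and the field names `repetition_penalty` and `presence_penalty` are all hypothetical stand-ins, not the actual maxent_grpo classes.

```python
from dataclasses import dataclass, field


@dataclass
class PenaltyConfig:
    """Hypothetical stand-in for GenerationPenaltyConfig (field names assumed)."""
    repetition_penalty: float = 1.0
    presence_penalty: float = 0.0


class PenaltyPassthroughMixin:
    """Hypothetical stand-in for GenerationPenaltyPassthroughMixin.

    Exposes selected penalty fields as read-only attributes of the host
    object, so callers can write ctx.repetition_penalty instead of
    ctx.penalty.repetition_penalty.
    """

    penalty: PenaltyConfig  # annotation only; the host dataclass defines the field

    @property
    def repetition_penalty(self) -> float:
        return self.penalty.repetition_penalty

    @property
    def presence_penalty(self) -> float:
        return self.penalty.presence_penalty


@dataclass
class SamplingConfig:
    """Hypothetical stand-in for GenerationSamplingConfig."""
    max_prompt_len: int
    max_completion_len: int
    gen_temperature: float
    gen_top_p: float


@dataclass
class DemoContext(PenaltyPassthroughMixin, SamplingConfig):
    # Dataclass fields inherited from SamplingConfig come first in __init__;
    # the mixin is not a dataclass, so it contributes no fields of its own.
    use_vllm: bool = False
    penalty: PenaltyConfig = field(default_factory=PenaltyConfig)


ctx = DemoContext(max_prompt_len=512, max_completion_len=256,
                  gen_temperature=0.7, gen_top_p=0.95)
ctx.penalty.repetition_penalty = 1.1
print(ctx.repetition_penalty)  # reads through to ctx.penalty; prints 1.1
```

Because the mixin holds no state, the penalty values live in exactly one place and the properties can never drift out of sync with `ctx.penalty`.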
- Parameters:
max_prompt_len (int)
max_completion_len (int)
gen_temperature (float)
gen_top_p (float)
use_vllm (bool)
vllm (VLLMClientConfig)
accelerator (Accelerator)
model (PreTrainedModel)
tokenizer (PreTrainedTokenizer)
device (Any)
penalty (GenerationPenaltyConfig)
prompt_char_limit (int | None)
vllm_mode (str)
- accelerator: TypesAccelerator
- model: TypesPreTrainedModel
- tokenizer: TypesPreTrainedTokenizer
- device: Any
- penalty: GenerationPenaltyConfig
- vllm: VLLMClientConfig