maxent_grpo.training.types
Copyright 2025 Liv d’Aliberti
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
- class maxent_grpo.training.types.AdvantageStats(grouped, samples)
Bases: object
Grouped and flattened advantages.
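The two fields suggest that advantages are kept both per prompt group (`grouped`) and as one flat list aligned with the flattened completions (`samples`). Below is a minimal sketch that assumes group-relative (GRPO-style) advantages, where each reward is centred on its group mean. The dataclass is a stand-in that only mirrors the documented field names, and `compute_advantages` is a hypothetical helper. Some GRPO variants also divide by the group standard deviation; this sketch omits that step.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AdvantageStats:
    """Stand-in mirroring the documented fields (not the real class)."""

    grouped: List[List[float]]  # one list of advantages per prompt
    samples: List[float]  # the same values, flattened in prompt order


def compute_advantages(grouped_rewards: List[List[float]]) -> AdvantageStats:
    """Hypothetical helper: centre each reward on its group's mean."""
    grouped = []
    for rewards in grouped_rewards:
        mean = sum(rewards) / len(rewards)
        grouped.append([r - mean for r in rewards])
    # Flatten so samples[i] lines up with the i-th flattened completion.
    samples = [adv for group in grouped for adv in group]
    return AdvantageStats(grouped=grouped, samples=samples)


stats = compute_advantages([[1.0, 0.0], [0.5, 0.5, 2.0]])
print(stats.samples)  # → [0.5, -0.5, -0.5, -0.5, 1.0]
```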
- class maxent_grpo.training.types.BatchingSettings(logprob_chunk_size, score_slice, prompt_length_cache_get=None, score_tail_tokens=None, slice_prefetch=0, prompt_cache_size=0)
Bases: object
Scoring batch/chunk hints.
- prompt_length_cache_get: Callable[[str], PromptCacheEntry] | None = None
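Suppose `logprob_chunk_size` caps how many sequences are scored in one forward pass. Under that assumption, a scoring loop could split its batch as sketched below. The `iter_chunks` helper and the rule that a non-positive size disables chunking are both assumptions, not documented behavior of this class.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_chunks(items: Sequence[T], chunk_size: int) -> Iterator[List[T]]:
    """Hypothetical helper: yield consecutive slices of at most chunk_size items.

    A non-positive chunk_size is treated as "no chunking" (an assumption).
    """
    if chunk_size <= 0:
        if items:
            yield list(items)
        return
    for start in range(0, len(items), chunk_size):
        yield list(items[start : start + chunk_size])


sequences = ["s0", "s1", "s2", "s3", "s4"]
print([len(chunk) for chunk in iter_chunks(sequences, chunk_size=2)])  # → [2, 2, 1]
```

Smaller chunks lower peak memory during scoring but require more forward passes.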
- class maxent_grpo.training.types.ControllerPaths(state_path, resume_from, overwrite_existing=False)
Bases: object
Filesystem locations for adaptive controller state.
- class maxent_grpo.training.types.GenerationBatch(prompts, answers, grouped_completions, grouped_ref_meta, grouped_completion_info=None)
Bases: object
Completions grouped per prompt after filtering.
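`GenerationBatch` keeps completions grouped per prompt, while `PromptCompletionBatch` holds flattened prompt/completion pairs. A minimal sketch of converting the first into the second repeats each prompt once for each of its completions. The element types (plain strings) and the `flatten_generation_batch` helper are assumptions, and both dataclasses are stand-ins that only mirror the documented signatures.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class GenerationBatch:
    """Stand-in with the documented fields; element types are assumed."""

    prompts: List[str]
    answers: List[str]
    grouped_completions: List[List[str]]
    grouped_ref_meta: Optional[List[List[Any]]]
    grouped_completion_info: Optional[List[List[Any]]] = None


@dataclass
class PromptCompletionBatch:
    """Stand-in for the flattened prompt/completion pairs."""

    prompts: List[str]
    completions: List[str]
    metadata: Optional[List[Any]] = None


def flatten_generation_batch(batch: GenerationBatch) -> PromptCompletionBatch:
    """Hypothetical helper: repeat each prompt once per surviving completion."""
    prompts: List[str] = []
    completions: List[str] = []
    for prompt, group in zip(batch.prompts, batch.grouped_completions):
        prompts.extend([prompt] * len(group))
        completions.extend(group)
    return PromptCompletionBatch(prompts=prompts, completions=completions)


gen = GenerationBatch(
    prompts=["2+2?", "3+3?"],
    answers=["4", "6"],
    grouped_completions=[["4", "5"], ["6"]],  # group sizes may differ after filtering
    grouped_ref_meta=None,
)
flat = flatten_generation_batch(gen)
print(flat.prompts)  # → ['2+2?', '2+2?', '3+3?']
```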
- class maxent_grpo.training.types.LoggingConfigView(weighting, clipping, schedule)
Bases: object
Pointers to configs referenced while logging.
- Parameters:
weighting (WeightingSettings)
clipping (ClipSettings)
schedule (OptimizationSchedule)
- weighting: WeightingSettings
- clipping: ClipSettings
- schedule: OptimizationSchedule
- class maxent_grpo.training.types.LogStepArtifacts(loss_outputs, diagnostics, grad_norm_scalar, epoch_progress)
Bases: object
Helper container for per-step optimizer and loss diagnostics.
- Parameters:
loss_outputs (LossOutputs)
diagnostics (BatchDiagnostics)
grad_norm_scalar (float | None)
epoch_progress (float)
- loss_outputs: LossOutputs
- diagnostics: BatchDiagnostics
- class maxent_grpo.training.types.LoopSettings(generation, evaluation, optimization, scoring, controller, controller_objective=None, controller_meta_manager=None)
Bases: object
Grouped training configuration shared across the loop.
- Parameters:
generation (GenerationSettings)
evaluation (EvaluationSettings)
optimization (OptimizationSettings)
scoring (ScoringSettings)
controller (ControllerPaths)
controller_objective (Optional['ControllerObjective'])
controller_meta_manager (Optional[Any])
- generation: GenerationSettings
- evaluation: EvaluationSettings
- optimization: OptimizationSettings
- scoring: ScoringSettings
- controller: ControllerPaths
- class maxent_grpo.training.types.LossScalarBundle(total_loss, policy_loss, clip_loss, kl_loss, weighted_kl_loss)
Bases: object
Scalar loss contributions tracked for logging.
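Because the bundle exists for logging, one natural pattern is to turn it into a flat name-to-value mapping before passing it to a metric writer. The sketch below does this with `dataclasses.asdict`. The stand-in assumes every field is a plain float. The `train/loss/` prefix and the `to_log_dict` helper are hypothetical, and the example values are made up. The sketch makes no claim about how `total_loss` is composed from the other terms.

```python
from dataclasses import asdict, dataclass
from typing import Dict


@dataclass
class LossScalarBundle:
    """Stand-in mirroring the documented fields; float types are assumed."""

    total_loss: float
    policy_loss: float
    clip_loss: float
    kl_loss: float
    weighted_kl_loss: float


def to_log_dict(bundle: LossScalarBundle, prefix: str = "train/loss/") -> Dict[str, float]:
    """Hypothetical helper: prefix every scalar so it can be logged by name."""
    return {prefix + name: float(value) for name, value in asdict(bundle).items()}


scalars = LossScalarBundle(
    total_loss=1.3, policy_loss=1.0, clip_loss=0.1, kl_loss=2.0, weighted_kl_loss=0.2
)
print(sorted(to_log_dict(scalars)))
```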
- class maxent_grpo.training.types.MetricWriter(*args, **kwargs)
Bases: Protocol
Protocol describing a metric writer used by the training loop.
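Because `MetricWriter` derives from `typing.Protocol`, any object with matching methods satisfies it structurally, without inheriting from it. This page does not list the protocol's methods, so the `log(metrics, step)` signature below is a hypothetical placeholder. The sketch only illustrates how such a protocol is declared and consumed.

```python
from typing import Dict, List, Mapping, Protocol, Tuple, runtime_checkable


@runtime_checkable
class MetricWriter(Protocol):
    """Illustrative protocol; the real method names are not documented here."""

    def log(self, metrics: Mapping[str, float], step: int) -> None:
        ...


class InMemoryWriter:
    """Satisfies the protocol structurally, without inheriting from it."""

    def __init__(self) -> None:
        self.records: List[Tuple[int, Dict[str, float]]] = []

    def log(self, metrics: Mapping[str, float], step: int) -> None:
        self.records.append((step, dict(metrics)))


def log_step(writer: MetricWriter, step: int, loss: float) -> None:
    # The training loop depends only on the protocol, not a concrete backend.
    writer.log({"train/loss": loss}, step)


writer = InMemoryWriter()
log_step(writer, step=3, loss=0.5)
print(isinstance(writer, MetricWriter), writer.records)  # → True [(3, {'train/loss': 0.5})]
```

With `@runtime_checkable`, `isinstance` checks only that the methods are present. It does not check their signatures.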
- class maxent_grpo.training.types.OptimizationSettings(schedule, handles)
Bases: object
Combined optimization metadata.
- Parameters:
schedule (OptimizationSchedule)
handles (OptimizerHandles)
- schedule: OptimizationSchedule
- handles: OptimizerHandles
- class maxent_grpo.training.types.PromptCompletionBatch(prompts, completions, metadata=None)
Bases: object
Flattened prompt/completion pairs.
- class maxent_grpo.training.types.QDistribution(grouped, samples)
Bases: object
Sequence-level q-distribution.
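This page does not say how the q-distribution is built. One common choice in maximum-entropy weighting is a per-group softmax over rewards with temperature `tau`, that is, `q_i = exp(r_i / tau) / sum_j exp(r_j / tau)` within each prompt group. The sketch assumes that choice. `softmax_q` and `tau` are hypothetical names, and the dataclass only mirrors the documented field names.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class QDistribution:
    """Stand-in mirroring the documented fields (not the real class)."""

    grouped: List[List[float]]  # one probability vector per prompt
    samples: List[float]  # flattened in prompt order


def softmax_q(grouped_rewards: List[List[float]], tau: float = 1.0) -> QDistribution:
    """Hypothetical helper: q_i proportional to exp(r_i / tau) within each group."""
    grouped = []
    for rewards in grouped_rewards:
        scaled = [r / tau for r in rewards]
        peak = max(scaled)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scaled]
        total = sum(weights)
        grouped.append([w / total for w in weights])
    samples = [q for group in grouped for q in group]
    return QDistribution(grouped=grouped, samples=samples)


q = softmax_q([[1.0, 1.0], [0.0, 2.0]], tau=1.0)
print(q.grouped[0])  # → [0.5, 0.5]
```

Lower `tau` concentrates mass on the highest-reward completion. Higher `tau` pushes each group toward a uniform distribution.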
- class maxent_grpo.training.types.RewardComponentStats(mean, std)
Bases: object
Mean/std summary for an individual reward component.
- class maxent_grpo.training.types.RewardMoments(mean, std)
Bases: object
Summary statistics for sequence rewards.
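`RewardMoments` and `RewardComponentStats` share the same `(mean, std)` shape: the first summarizes the combined sequence rewards, and the second summarizes each named reward component. The sketch below uses the standard `statistics` module. It assumes a population standard deviation (`pstdev`); the page does not say whether the sample or population estimator is used. The helper names are hypothetical.

```python
import statistics
from dataclasses import dataclass
from typing import Dict, Sequence


@dataclass
class RewardMoments:
    """Stand-in mirroring the documented (mean, std) fields."""

    mean: float
    std: float


def reward_moments(rewards: Sequence[float]) -> RewardMoments:
    """Hypothetical helper using the population standard deviation (an assumption)."""
    return RewardMoments(mean=statistics.fmean(rewards), std=statistics.pstdev(rewards))


def component_stats(components: Dict[str, Sequence[float]]) -> Dict[str, RewardMoments]:
    """One mean/std pair per named reward component."""
    return {name: reward_moments(values) for name, values in components.items()}


print(reward_moments([1.0, 0.0, 1.0, 0.0]))  # → RewardMoments(mean=0.5, std=0.5)
```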
- class maxent_grpo.training.types.RuntimeHandles(accelerator, model, tokenizer, train_loader, train_sampler, device, get_ref_model, reference_model=None, prompt_cache_get=None)
Bases: object
Pointers to objects that should live for the entire training job.
- Parameters:
accelerator (Accelerator)
model (PreTrainedModel)
tokenizer (PreTrainedTokenizer)
train_loader (maxent_grpo.training.types.runtime.DataLoader)
train_sampler (torch.utils.data.Sampler | None)
device (Any)
get_ref_model (Callable[[], PreTrainedModel])
reference_model (PreTrainedModel | None)
prompt_cache_get (Callable[[str], PromptCacheEntry] | None)
- accelerator: Accelerator
- model: PreTrainedModel
- tokenizer: PreTrainedTokenizer
- train_loader: DataLoader
- device: Device
- get_ref_model: Callable[[], PreTrainedModel]
- reference_model: PreTrainedModel | None = None
- prompt_cache_get: Callable[[str], PromptCacheEntry] | None = None
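`get_ref_model` is a zero-argument callable and `reference_model` is optional, which suggests the reference policy may be loaded lazily. If that is the intent, a wrapper like the hypothetical `LazyReference` below builds the model on the first call and returns the cached object afterwards. A string stands in for a `PreTrainedModel` so the sketch runs without `transformers`.

```python
from typing import Callable, Generic, Optional, TypeVar

M = TypeVar("M")


class LazyReference(Generic[M]):
    """Hypothetical wrapper: build the reference model on first use, then reuse it."""

    def __init__(self, factory: Callable[[], M], preloaded: Optional[M] = None) -> None:
        self._factory = factory
        self._model = preloaded  # mirrors an already-populated reference_model

    def __call__(self) -> M:
        if self._model is None:
            self._model = self._factory()
        return self._model


calls = []


def load_reference() -> str:
    calls.append("load")  # stands in for an expensive model load
    return "ref-model"


get_ref_model = LazyReference(load_reference)
assert get_ref_model() is get_ref_model()
print(len(calls))  # → 1
```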
- class maxent_grpo.training.types.TokenUsageStats(avg_completion_tokens, num_completion_tokens, num_input_tokens)
Bases: object
Aggregate completion/input token statistics.
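For a single batch, the three fields could be derived from per-sequence token counts as sketched below. The sketch assumes `avg_completion_tokens` is the mean completion length, i.e. `num_completion_tokens / number_of_completions`. The training loop may instead track a running average across steps. The dataclass is a stand-in, and `token_usage` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class TokenUsageStats:
    """Stand-in mirroring the documented fields (not the real class)."""

    avg_completion_tokens: float
    num_completion_tokens: int
    num_input_tokens: int


def token_usage(prompt_lengths: Sequence[int], completion_lengths: Sequence[int]) -> TokenUsageStats:
    """Hypothetical helper: totals plus a per-completion mean for one batch."""
    total_completion = sum(completion_lengths)
    avg = total_completion / len(completion_lengths) if completion_lengths else 0.0
    return TokenUsageStats(
        avg_completion_tokens=avg,
        num_completion_tokens=total_completion,
        num_input_tokens=sum(prompt_lengths),
    )


print(token_usage([12, 12, 20], [30, 10, 20]))
# → TokenUsageStats(avg_completion_tokens=20.0, num_completion_tokens=60, num_input_tokens=44)
```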
- class maxent_grpo.training.types.TrainingScalarStats(ref_logp_mean, tokens, current_lr, grad_norm_scalar, epoch_progress, vllm_latency_ms, policy_entropy=None, entropy_bonus_coef=None, entropy_bonus_reward_std=None)
Bases: object
Scalar values that vary every logging step.
- tokens: TokenUsageStats
- property avg_completion_tokens: float
Return the average completion token length.
- Returns:
Running average of completion token counts.
- Return type:
float
- class maxent_grpo.training.types.SequenceScores(cur_logp_sum, behavior_logp_sum, log_ratio_train, denom_tok_tensor, pooled_hidden=None, policy_entropy_sum=None, token_logp=None, token_mask=None, old_token_logp=None)
Bases: object
Bundle of sequence-level log-prob statistics.
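The field names point to a standard policy-gradient quantity: the per-sequence log importance ratio between the current policy and the behavior policy that generated the sample. The sketch assumes that `log_ratio_train` equals `cur_logp_sum - behavior_logp_sum` and that `denom_tok_tensor` holds token counts used for length normalization. Neither definition is confirmed on this page, and the helper names are hypothetical. Plain lists stand in for tensors.

```python
import math
from typing import List, Sequence


def sequence_log_ratios(cur_logp_sum: Sequence[float], behavior_logp_sum: Sequence[float]) -> List[float]:
    """Assumed definition: log pi_theta(y|x) - log pi_behavior(y|x), per sequence."""
    return [cur - old for cur, old in zip(cur_logp_sum, behavior_logp_sum)]


def length_normalized(logp_sum: Sequence[float], num_tokens: Sequence[int]) -> List[float]:
    """Hypothetical use of a token-count denominator: mean log-prob per token."""
    return [lp / max(n, 1) for lp, n in zip(logp_sum, num_tokens)]


log_ratio = sequence_log_ratios([-10.0, -4.0], [-10.0, -5.0])
ratios = [math.exp(x) for x in log_ratio]  # importance weights pi_theta / pi_behavior
print(log_ratio)  # → [0.0, 1.0]
```

A ratio of 1.0 means the current policy assigns the sequence the same probability as the policy that sampled it. Clipped objectives bound how far this ratio may move the update.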