maxent_grpo.training.types

Copyright 2025 Liv d’Aliberti

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

class maxent_grpo.training.types.AdvantageStats(grouped, samples)

Bases: object

Grouped and flattened advantages.

Parameters:
grouped: List[List[float]]
samples: List[float]
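
The relationship between `grouped` and `samples` can be sketched with a stand-in dataclass. This is a minimal illustration, assuming that `samples` is the row-major flattening of `grouped` and that advantages are group-mean-centered rewards; neither detail is stated on this page. `AdvantageStatsSketch` and `centered_advantages` are hypothetical names, not part of `maxent_grpo`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AdvantageStatsSketch:
    """Stand-in mirroring the documented fields (hypothetical)."""
    grouped: List[List[float]]
    samples: List[float]


def centered_advantages(rewards: List[List[float]]) -> AdvantageStatsSketch:
    """Subtract each group's mean reward, then flatten group by group."""
    grouped: List[List[float]] = []
    for group in rewards:
        mean = sum(group) / len(group) if group else 0.0
        grouped.append([r - mean for r in group])
    # Assumed layout: samples preserves group order, then in-group order.
    samples = [a for group in grouped for a in group]
    return AdvantageStatsSketch(grouped=grouped, samples=samples)


stats = centered_advantages([[1.0, 0.0], [2.0, 2.0, 5.0]])
```

Keeping both views avoids re-flattening when a per-sample loss needs a flat list while per-prompt diagnostics need the groups.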

class maxent_grpo.training.types.BatchingSettings(logprob_chunk_size, score_slice, prompt_length_cache_get=None, score_tail_tokens=None, slice_prefetch=0, prompt_cache_size=0)

Bases: object

Scoring batch/chunk hints.

Parameters:
logprob_chunk_size: int
score_slice: int
prompt_length_cache_get: Callable[[str], PromptCacheEntry] | None = None
score_tail_tokens: int | None = None
slice_prefetch: int = 0
prompt_cache_size: int = 0
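
As a rough illustration of how a chunk-size hint might be consumed, the sketch below splits a list of sequences into chunks of `logprob_chunk_size`. The interpretation (chunk size counts sequences, and a non-positive value disables chunking) is an assumption; `BatchingSettingsSketch` and `iter_chunks` are hypothetical and carry only a subset of the documented fields.

```python
from dataclasses import dataclass
from typing import Iterator, List, Optional, Sequence


@dataclass
class BatchingSettingsSketch:
    """Stand-in with a subset of the documented fields (hypothetical)."""
    logprob_chunk_size: int
    score_slice: int
    score_tail_tokens: Optional[int] = None


def iter_chunks(items: Sequence[str], chunk_size: int) -> Iterator[List[str]]:
    """Yield consecutive chunks; a non-positive size yields a single chunk."""
    if chunk_size <= 0:
        yield list(items)
        return
    for start in range(0, len(items), chunk_size):
        yield list(items[start:start + chunk_size])


settings = BatchingSettingsSketch(logprob_chunk_size=2, score_slice=4)
chunks = list(iter_chunks(["a", "b", "c", "d", "e"], settings.logprob_chunk_size))
```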

class maxent_grpo.training.types.ControllerPaths(state_path, resume_from, overwrite_existing=False)

Bases: object

Filesystem locations for adaptive controller state.

Parameters:
state_path: str | None
resume_from: str | None
overwrite_existing: bool = False
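
One plausible reading of `overwrite_existing` is that it guards an existing state file from being clobbered. The sketch below encodes that reading only; the actual semantics are not documented here. `ControllerPathsSketch` and `may_write_state` are hypothetical, and the `exists` argument is injected so the check can be exercised without touching the filesystem.

```python
import os
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ControllerPathsSketch:
    """Stand-in mirroring the documented fields (hypothetical)."""
    state_path: Optional[str]
    resume_from: Optional[str]
    overwrite_existing: bool = False


def may_write_state(
    paths: ControllerPathsSketch,
    exists: Callable[[str], bool] = os.path.exists,
) -> bool:
    """Assumed rule: write only if a path is set and is free or overwritable."""
    if paths.state_path is None:
        return False
    return paths.overwrite_existing or not exists(paths.state_path)


def always_exists(_path: str) -> bool:
    return True


protected = ControllerPathsSketch(state_path="controller.json", resume_from=None)
forced = ControllerPathsSketch("controller.json", None, overwrite_existing=True)
```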

class maxent_grpo.training.types.GenerationBatch(prompts, answers, grouped_completions, grouped_ref_meta, grouped_completion_info=None)

Bases: object

Completions grouped per prompt after filtering.

Parameters:
prompts: List[str]
answers: List[str]
grouped_completions: List[List[str]]
grouped_ref_meta: List[List[Any | None]] | None
grouped_completion_info: List[List[Dict[str, Any]]] | None = None

class maxent_grpo.training.types.LoggingConfigView(weighting, clipping, schedule)

Bases: object

Pointers to configs referenced while logging.

Parameters:
weighting: WeightingSettings
clipping: ClipSettings
schedule: OptimizationSchedule

class maxent_grpo.training.types.LogStepArtifacts(loss_outputs, diagnostics, grad_norm_scalar, epoch_progress)

Bases: object

Helper container for optimizer/loss diagnostics per step.

Parameters:
loss_outputs: LossOutputs
diagnostics: BatchDiagnostics
grad_norm_scalar: float | None
epoch_progress: float

as_dict()

Return a dict view useful for debugging/log statements.

Return type:

Dict[str, Any]
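
A dict view of this kind can be sketched as follows. The keys returned by the real `as_dict()` are not listed on this page; the version below simply assumes one key per documented field and returns the nested objects without copying them. `LogStepArtifactsSketch` is a hypothetical stand-in, and the `loss_outputs` and `diagnostics` values are placeholder dicts rather than real `LossOutputs` or `BatchDiagnostics` instances.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class LogStepArtifactsSketch:
    """Stand-in mirroring the documented fields (hypothetical)."""
    loss_outputs: Any
    diagnostics: Any
    grad_norm_scalar: Optional[float]
    epoch_progress: float

    def as_dict(self) -> Dict[str, Any]:
        # Shallow view: nested objects are returned as-is, not deep-copied.
        return {
            "loss_outputs": self.loss_outputs,
            "diagnostics": self.diagnostics,
            "grad_norm_scalar": self.grad_norm_scalar,
            "epoch_progress": self.epoch_progress,
        }


artifacts = LogStepArtifactsSketch(
    loss_outputs={"total": 1.5},      # placeholder for LossOutputs
    diagnostics={"clip_frac": 0.1},   # placeholder for BatchDiagnostics
    grad_norm_scalar=None,
    epoch_progress=0.25,
)
view = artifacts.as_dict()
```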

class maxent_grpo.training.types.LoopSettings(generation, evaluation, optimization, scoring, controller, controller_objective=None, controller_meta_manager=None)

Bases: object

Grouped training configuration shared across the loop.

Parameters:
generation: GenerationSettings
evaluation: EvaluationSettings
optimization: OptimizationSettings
scoring: ScoringSettings
controller: ControllerPaths
controller_objective: 'ControllerObjective' | None = None
controller_meta_manager: Any | None = None

class maxent_grpo.training.types.LossScalarBundle(total_loss, policy_loss, clip_loss, kl_loss, weighted_kl_loss)

Bases: object

Scalar contributions tracked for logging.

Parameters:
total_loss: float
policy_loss: float
clip_loss: float | None
kl_loss: float
weighted_kl_loss: float

class maxent_grpo.training.types.MetricWriter(*args, **kwargs)

Bases: Protocol

Protocol describing a metric writer used by the training loop.

log(metrics, step)

Record metrics for a training step.

Parameters:
  • metrics

  • step

Return type:

None

flush()

Flush buffered metrics to their storage backend.

Return type:

None
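
Because `MetricWriter` is a `Protocol`, any object with compatible `log` and `flush` methods satisfies it structurally, without inheriting from it. The sketch below shows an in-memory writer. The parameter types (`Dict[str, Any]` for `metrics`, `int` for `step`) are assumptions, since this page does not list them, and `MetricWriterSketch` and `InMemoryWriter` are hypothetical names.

```python
from typing import Any, Dict, List, Protocol, Tuple, runtime_checkable


@runtime_checkable
class MetricWriterSketch(Protocol):
    """Local stand-in for the documented protocol (types are assumed)."""

    def log(self, metrics: Dict[str, Any], step: int) -> None: ...

    def flush(self) -> None: ...


class InMemoryWriter:
    """Buffers metrics on log() and moves them to `flushed` on flush()."""

    def __init__(self) -> None:
        self.pending: List[Tuple[int, Dict[str, Any]]] = []
        self.flushed: List[Tuple[int, Dict[str, Any]]] = []

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        # Copy so later mutation by the caller does not alter the record.
        self.pending.append((step, dict(metrics)))

    def flush(self) -> None:
        self.flushed.extend(self.pending)
        self.pending.clear()


writer = InMemoryWriter()
writer.log({"loss": 0.5}, step=1)
writer.flush()
```

A writer like this is mainly useful in tests, where asserting on `flushed` is simpler than inspecting a real logging backend.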

class maxent_grpo.training.types.OptimizationSettings(schedule, handles)

Bases: object

Combined optimization metadata.

Parameters:
schedule: OptimizationSchedule
handles: OptimizerHandles

class maxent_grpo.training.types.PromptCompletionBatch(prompts, completions, metadata=None)

Bases: object

Flattened prompt/completion pairs.

Parameters:
prompts: List[str]
completions: List[str]
metadata: List[Dict[str, Any]] | None = None
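
The grouped layout of `GenerationBatch` and the flattened layout here are related by repeating each prompt once per completion. The sketch below shows that conversion under the assumption that pairs are emitted in group order; the library's own conversion routine is not shown on this page, and `PromptCompletionBatchSketch` and `flatten` are hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PromptCompletionBatchSketch:
    """Stand-in with the two required documented fields (hypothetical)."""
    prompts: List[str]
    completions: List[str]


def flatten(
    prompts: List[str], grouped_completions: List[List[str]]
) -> PromptCompletionBatchSketch:
    """Repeat each prompt once per completion in its group."""
    flat_prompts: List[str] = []
    flat_completions: List[str] = []
    for prompt, group in zip(prompts, grouped_completions):
        for completion in group:
            flat_prompts.append(prompt)
            flat_completions.append(completion)
    return PromptCompletionBatchSketch(flat_prompts, flat_completions)


flat = flatten(["p0", "p1"], [["a", "b"], ["c"]])
```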

class maxent_grpo.training.types.QDistribution(grouped, samples)

Bases: object

Sequence-level q-distribution.

Parameters:
grouped: List[List[float]]
samples: List[float]
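
This page says only that the class holds a sequence-level q-distribution. One common construction, shown here purely as an assumption, is a per-group softmax over rewards with a temperature, so that each inner list of `grouped` sums to 1. `group_softmax` and its `temperature` parameter are hypothetical and may not match how `maxent_grpo` computes q.

```python
import math
from typing import List


def group_softmax(rewards: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over one prompt's completion rewards."""
    if not rewards:
        return []
    scaled = [r / temperature for r in rewards]
    peak = max(scaled)  # subtract the max to avoid overflow in exp()
    weights = [math.exp(s - peak) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]


# grouped: one distribution per prompt; samples: the same values flattened.
grouped = [group_softmax(g) for g in [[1.0, 1.0], [0.0, 2.0, 0.0]]]
samples = [q for group in grouped for q in group]
```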

class maxent_grpo.training.types.RewardComponentStats(mean, std)

Bases: object

Mean/std summary for an individual reward component.

Parameters:
mean: float
std: float

class maxent_grpo.training.types.RewardMoments(mean, std)

Bases: object

Summary statistics for sequence rewards.

Parameters:
mean: float
std: float
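
A summary of this shape can be built from a flat list of rewards, as in the sketch below. Whether the library uses the population or the sample standard deviation is not stated here; the sketch assumes population (`statistics.pstdev`) and returns zeros for an empty input. `RewardMomentsSketch` and `reward_moments` are hypothetical names.

```python
from dataclasses import dataclass
from statistics import fmean, pstdev
from typing import Sequence


@dataclass
class RewardMomentsSketch:
    """Stand-in mirroring the documented fields (hypothetical)."""
    mean: float
    std: float


def reward_moments(rewards: Sequence[float]) -> RewardMomentsSketch:
    """Mean and population std of sequence rewards (assumed convention)."""
    if not rewards:
        return RewardMomentsSketch(mean=0.0, std=0.0)
    return RewardMomentsSketch(mean=fmean(rewards), std=pstdev(rewards))


moments = reward_moments([1.0, 3.0])
```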

class maxent_grpo.training.types.RuntimeHandles(accelerator, model, tokenizer, train_loader, train_sampler, device, get_ref_model, reference_model=None, prompt_cache_get=None)

Bases: object

Pointers to objects that should live for the entire training job.

Parameters:
accelerator: Accelerator
model: PreTrainedModel
tokenizer: PreTrainedTokenizer
train_loader: DataLoader
train_sampler: Sampler | None
device: Device
get_ref_model: Callable[[], PreTrainedModel]
reference_model: PreTrainedModel | None = None
prompt_cache_get: Callable[[str], PromptCacheEntry] | None = None

class maxent_grpo.training.types.TokenUsageStats(avg_completion_tokens, num_completion_tokens, num_input_tokens)

Bases: object

Aggregate completion/input token statistics.

Parameters:
avg_completion_tokens: float
num_completion_tokens: float
num_input_tokens: float

class maxent_grpo.training.types.TrainingScalarStats(ref_logp_mean, tokens, current_lr, grad_norm_scalar, epoch_progress, vllm_latency_ms, policy_entropy=None, entropy_bonus_coef=None, entropy_bonus_reward_std=None)

Bases: object

Scalar values that vary every logging step.

Parameters:
ref_logp_mean: float
tokens: TokenUsageStats
current_lr: float
grad_norm_scalar: float | None
epoch_progress: float
vllm_latency_ms: float | None
policy_entropy: float | None = None
entropy_bonus_coef: float | None = None
entropy_bonus_reward_std: float | None = None

property avg_completion_tokens: float

Return the average completion token length.

Returns:

Running average of completion token counts.

Return type:

float

property num_completion_tokens: float

Return the total completion token count processed.

Returns:

Total completion token count accumulated.

Return type:

float

property num_input_tokens: float

Return the total input token count processed.

Returns:

Total input token count accumulated.

Return type:

float
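
The three properties share their names with the fields of `TokenUsageStats`, which suggests they forward to the nested `tokens` object. The sketch below encodes that assumption; this page does not confirm it. Both classes are hypothetical stand-ins, and the second carries only a subset of the documented fields.

```python
from dataclasses import dataclass


@dataclass
class TokenUsageStatsSketch:
    """Stand-in mirroring the documented fields (hypothetical)."""
    avg_completion_tokens: float
    num_completion_tokens: float
    num_input_tokens: float


@dataclass
class TrainingScalarStatsSketch:
    """Subset of the documented fields; properties forward to `tokens`."""
    tokens: TokenUsageStatsSketch
    current_lr: float

    @property
    def avg_completion_tokens(self) -> float:
        return self.tokens.avg_completion_tokens

    @property
    def num_completion_tokens(self) -> float:
        return self.tokens.num_completion_tokens

    @property
    def num_input_tokens(self) -> float:
        return self.tokens.num_input_tokens


usage = TokenUsageStatsSketch(
    avg_completion_tokens=12.5, num_completion_tokens=100.0, num_input_tokens=40.0
)
scalars = TrainingScalarStatsSketch(tokens=usage, current_lr=1e-5)
```

Forwarding properties let logging code read `scalars.num_input_tokens` directly while the token counters stay grouped in one object.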

class maxent_grpo.training.types.SequenceScores(cur_logp_sum, behavior_logp_sum, log_ratio_train, denom_tok_tensor, pooled_hidden=None, policy_entropy_sum=None, token_logp=None, token_mask=None, old_token_logp=None)

Bases: object

Bundle sequence-level log-prob statistics.

Parameters:
cur_logp_sum: Any
behavior_logp_sum: Any
log_ratio_train: Any
denom_tok_tensor: Any
pooled_hidden: Any | None = None
policy_entropy_sum: Any | None = None
token_logp: Any | None = None
token_mask: Any | None = None
old_token_logp: Any | None = None

Modules

logging

Logging protocols and dataclasses shared across the training stack.

rewards

Per-batch dataclasses shared across the pipeline, loss, and metrics code.

runtime

Runtime handles and configuration dataclasses for the training loop.