maxent_grpo.training.scoring_common

Scoring helpers extracted from the MaxEnt-GRPO training loop.

Functions

_as_context_manager(value)

Return value as a context manager when possible, otherwise return a no-op context manager.
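
The behaviour described above might look roughly like the sketch below. This is not the module's actual source: the name as_context_manager and the duck-typing check on __enter__/__exit__ are assumptions made for illustration.

```python
from contextlib import nullcontext
import io


def as_context_manager(value):
    """Hypothetical stand-in for _as_context_manager.

    Assumption: "when possible" means the object already implements the
    context-manager protocol (__enter__ and __exit__).
    """
    if hasattr(value, "__enter__") and hasattr(value, "__exit__"):
        return value
    # Anything else is replaced by a context manager that does nothing.
    return nullcontext()


# A file-like object is already a context manager and is returned unchanged.
buffer = io.StringIO()
assert as_context_manager(buffer) is buffer

# A plain value falls back to the no-op context.
with as_context_manager(None):
    pass
```

A fallback of this kind lets calling code write a single `with` statement whether or not the supplied object supports the protocol.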

_autocast_context(accelerator, device)

Return the appropriate autocast context for the current accelerator and device.

_coerce_optional_int(value)

Return value coerced to int when possible, else None.
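
A minimal sketch of this kind of coercion, assuming that "when possible" means int(value) succeeds. The name coerce_optional_int and the exact set of caught exceptions are illustrative choices, not taken from the module.

```python
from typing import Any, Optional


def coerce_optional_int(value: Any) -> Optional[int]:
    """Hypothetical stand-in for _coerce_optional_int."""
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type; ValueError: unparsable string or NaN;
        # OverflowError: infinite float.
        return None
```

Under these assumptions, coerce_optional_int("7") gives 7, while coerce_optional_int("abc") and coerce_optional_int(float("inf")) both give None.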

_coerce_shape(value)

_describe_embedding_module(module, name)

Return a human-friendly summary of an embedding module.

_dist_all(dist, flag)

Return True if flag is True on all ranks (best-effort).

_dist_any(dist, flag)

Return True if flag is True on any rank (best-effort).
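
One way to read "best-effort" for _dist_all and _dist_any is: gather the flag from every rank, and fall back to the local flag if no distributed backend is available or the collective fails. The sketch below follows that reading. The function names, the FakeDist test double, and the use of get_world_size and all_gather_object on the dist object are assumptions about the interface, not the module's confirmed implementation.

```python
def _gather_flags(dist, flag):
    """Collect the boolean flag from all ranks, or just the local one."""
    if dist is None:
        return [bool(flag)]
    try:
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, bool(flag))
        return [bool(item) for item in gathered]
    except Exception:
        # Deliberately broad: a failed collective degrades to the local flag.
        return [bool(flag)]


def dist_any(dist, flag):
    """Hypothetical stand-in for _dist_any."""
    return any(_gather_flags(dist, flag))


def dist_all(dist, flag):
    """Hypothetical stand-in for _dist_all."""
    return all(_gather_flags(dist, flag))


class FakeDist:
    """Single-process test double that pretends other ranks reported flags."""

    def __init__(self, other_flags):
        self.other_flags = list(other_flags)

    def get_world_size(self):
        return len(self.other_flags) + 1

    def all_gather_object(self, out, obj):
        out[:] = [obj] + self.other_flags
```

With FakeDist([True]) as the backend and a local flag of False, dist_any reports True and dist_all reports False.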

_dist_collective_ready(torch_mod)

Return a dist module when initialized, otherwise None.

_get_config_value(config, key[, default])

Return a config value from either Mapping or object-style configs.
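
Supporting both Mapping and object-style configs usually comes down to choosing between key lookup and attribute lookup. A minimal sketch, in which the name get_config_value and the default of None are assumptions:

```python
from collections.abc import Mapping
from types import SimpleNamespace


def get_config_value(config, key, default=None):
    """Hypothetical stand-in for _get_config_value."""
    if isinstance(config, Mapping):
        return config.get(key, default)
    return getattr(config, key, default)


# Both config styles resolve the same key.
dict_config = {"vocab_size": 32000}
object_config = SimpleNamespace(vocab_size=32000)
assert get_config_value(dict_config, "vocab_size") == 32000
assert get_config_value(object_config, "vocab_size") == 32000
```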

_get_embedding_vocab_size(model, config)

Return the vocab size exposed by the model's embedding weights.

_maybe_long_tensor(value, torch_mod)

Return a tensor cast to long when the stub lacks long.

_model_has_non2d_embeddings(model)

Return True when any known embedding weight is not 2-D.

_prefetch_iterator(iterator, buffer_size)

Yield from iterator while prefetching up to buffer_size slices.
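
The sketch below shows one simple, single-threaded way to keep a bounded look-ahead buffer in front of an iterator. Whether the real helper prefetches synchronously or in a background thread is not stated here, so treat the deque-based approach and the name prefetch_iterator as assumptions.

```python
from collections import deque
from itertools import islice


def prefetch_iterator(iterator, buffer_size):
    """Hypothetical stand-in for _prefetch_iterator.

    Keeps up to buffer_size items pulled ahead of the item being yielded.
    """
    source = iter(iterator)
    if buffer_size <= 0:
        # No look-ahead requested: behave like the plain iterator.
        yield from source
        return
    buffer = deque(islice(source, buffer_size))
    while buffer:
        item = buffer.popleft()
        # Top the buffer back up before handing the item to the consumer.
        for upcoming in islice(source, 1):
            buffer.append(upcoming)
        yield item
```

Items come out in their original order; only the timing of when they are pulled from the source changes.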

_progress_log_enabled()

_refresh_torch()

Return the active torch module.

_resolve_dtype(dtype)

Normalize dtype objects coming from various stubs.

_score_slice_log_enabled()

_size_hint(tensor_obj, dim)

Return tensor.size(dim) with fallbacks for numpy-backed stubs.

_to_numpy_array(obj)

Return a numpy view of obj for stub compatibility.

_weight_is_stub_tensor(weight)

Return True for tensor-like stubs used in tests.

_weight_is_two_dimensional(weight)

Return True if the provided weight exposes a 2-D shape.
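
A check like this can be written against the shape attribute alone, which keeps it usable with both real tensors and test stubs. The sketch assumes the weight exposes a sized shape; the name weight_is_two_dimensional is illustrative.

```python
from types import SimpleNamespace


def weight_is_two_dimensional(weight):
    """Hypothetical stand-in for _weight_is_two_dimensional."""
    shape = getattr(weight, "shape", None)
    if shape is None:
        return False
    try:
        return len(tuple(shape)) == 2
    except TypeError:
        # shape is present but not iterable.
        return False


# An embedding-like stub with shape (vocab, hidden) passes the check.
embedding_stub = SimpleNamespace(shape=(32000, 768))
assert weight_is_two_dimensional(embedding_stub)
```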

Classes

_DistModuleLike(*args, **kwargs)

Minimal distributed API needed for best-effort gathers.

_LongDTypeProxy(target)

Lightweight wrapper that compares equal to torch.long in stubs.

_PadTokenGuard(targets, value)

Context manager that temporarily clamps padding attributes.
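
As a rough illustration of a guard of this kind, the sketch below overrides one attribute on several objects and restores the previous state on exit. It is not the module's implementation: reading "clamps" as "overrides with a fixed value", the attribute name pad_token_id, and the attr parameter are all assumptions.

```python
from types import SimpleNamespace

_MISSING = object()


class PadTokenGuard:
    """Hypothetical stand-in for _PadTokenGuard."""

    def __init__(self, targets, value, attr="pad_token_id"):
        self.targets = [t for t in targets if t is not None]
        self.value = value
        self.attr = attr  # assumed attribute name
        self._saved = []

    def __enter__(self):
        for target in self.targets:
            # Remember the old value (or its absence) before overriding it.
            self._saved.append((target, getattr(target, self.attr, _MISSING)))
            setattr(target, self.attr, self.value)
        return self

    def __exit__(self, exc_type, exc, tb):
        for target, old in reversed(self._saved):
            if old is _MISSING:
                delattr(target, self.attr)
            else:
                setattr(target, self.attr, old)
        self._saved.clear()
        return False  # never suppress exceptions


tokenizer = SimpleNamespace(pad_token_id=50257)
with PadTokenGuard([tokenizer], 0):
    assert tokenizer.pad_token_id == 0
assert tokenizer.pad_token_id == 50257
```

Restoring in __exit__ means the original values come back even when the guarded block raises.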

_TorchModuleLike(*args, **kwargs)