# Source code for maxent_grpo.training.scoring_logprob

# Copyright 2025 Liv d'Aliberti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

"""Model logprob computation and sequence-score assembly helpers."""

from __future__ import annotations

from contextlib import ExitStack, nullcontext
import inspect
import os
import time
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    cast,
)

import numpy as np

from .scoring_batching import iter_batch_slices
from .scoring_common import (
    LOG,
    TorchDevice,
    _PadTokenGuard,
    _SCORING_EXCEPTIONS,
    _TorchModuleLike,
    _autocast_context,
    _coerce_optional_int,
    _describe_embedding_module,
    _dist_all,
    _dist_any,
    _dist_collective_ready,
    _get_config_value,
    _get_embedding_vocab_size,
    _prefetch_iterator,
    _progress_log_enabled,
    _refresh_torch,
    _score_slice_log_enabled,
    _to_numpy_array,
    _weight_is_stub_tensor,
    _weight_is_two_dimensional,
)
from .types import (
    BatchingSettings,
    PreTrainedModel,
    ReferenceLogprobs,
    RuntimeHandles,
    ScoreBatch,
    SequenceScores,
    Tensor,
)
from .zero_utils import _maybe_zero_gather_params

# Bind the module-level ``torch`` name to whatever ``_refresh_torch()`` returns.
# ``_chunked_sequence_logprobs`` calls ``_refresh_torch()`` again on entry
# instead of relying on this import-time binding.
torch = _refresh_torch()


def _summon_fsdp_full_param_context(model: PreTrainedModel) -> ContextManager[object]:
    """Build the model's ``summon_full_params`` context, or a no-op context.

    The ``summon_full_params`` attribute is looked up on ``model``. When it is
    missing or not callable, a ``nullcontext`` is returned. Otherwise the
    attribute is invoked first with no arguments and, if that raises
    ``TypeError``, once more with ``recurse=True``. If both call shapes raise
    ``TypeError`` the function falls back to ``nullcontext``.

    Args:
        model: Object that may expose a ``summon_full_params`` callable.

    Returns:
        Whatever ``summon_full_params`` returned for the first accepted call
        shape, or ``nullcontext()`` when no call shape is usable.

    Note:
        Only ``TypeError`` is intercepted; any other exception raised by the
        ``summon_full_params`` call propagates to the caller.
    """
    summon = cast(
        Optional[Callable[..., ContextManager[object]]],
        getattr(model, "summon_full_params", None),
    )
    if not callable(summon):
        return nullcontext()
    # Candidate keyword sets, tried in order: bare call, then ``recurse=True``.
    for call_kwargs in ({}, {"recurse": True}):
        try:
            return summon(**call_kwargs)
        except TypeError:
            continue
    # Neither signature was accepted: behave as if nothing needs gathering.
    return nullcontext()

def _chunked_sequence_logprobs(
    model: PreTrainedModel,
    *,
    input_ids: Tensor,
    attention_mask: Tensor,
    labels: Tensor,
    chunk_size: int,
    gather_full_params: bool = False,  # retained for parity
    zero_gather_all_ranks: bool = False,
    return_hidden: bool = False,
    pooling: str = "mean",
    return_entropy: bool = False,
    entropy_mode: str = "exact",
    return_token_logp: bool = False,
) -> Optional[
    Tuple[
        Tensor,
        Tensor,
        Optional[Tensor],
        Optional[Tensor],
        Optional[Tensor],
        Optional[Tensor],
    ]
]:
    """Compute summed log-probabilities per sequence with optional chunking/pooled states/entropy."""

    torch_mod = _refresh_torch()
    slice_log = _score_slice_log_enabled()
    entropy_mode_norm = str(entropy_mode or "exact").strip().lower()
    if entropy_mode_norm in {"", "none"}:
        entropy_mode_norm = "exact"
    if entropy_mode_norm in {"exact", "full", "distribution"}:
        entropy_mode_norm = "exact"
    elif entropy_mode_norm in {
        "sample",
        "estimate",
        "estimated",
        "approx",
        "approximate",
        "token",
        "token_logp",
        "nll",
        "logp",
    }:
        entropy_mode_norm = "sample"
    else:
        if return_entropy:
            warned = getattr(_chunked_sequence_logprobs, "_entropy_mode_warned", False)
            if not warned:
                LOG.warning(
                    "Unknown entropy_mode=%s; falling back to 'exact'.",
                    entropy_mode,
                )
                setattr(_chunked_sequence_logprobs, "_entropy_mode_warned", True)
        entropy_mode_norm = "exact"
    use_exact_entropy = return_entropy and entropy_mode_norm == "exact"
    use_sample_entropy = return_entropy and entropy_mode_norm == "sample"
    _ = (chunk_size,)  # parity with distributed APIs; currently unused

    def _compress_completion_token_logp(
        token_logp: Tensor, token_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Return completion-only token logprobs and mask (pad to max completion length)."""
        try:
            torch_tensor = getattr(torch_mod, "Tensor", None)
            if torch_tensor is None:
                return token_logp, token_mask
            if not isinstance(token_logp, torch_tensor) or not isinstance(
                token_mask, torch_tensor
            ):
                return token_logp, token_mask
        except (TypeError, AttributeError):
            return token_logp, token_mask
        if token_logp.ndim < 2 or token_mask.ndim < 2:
            return token_logp, token_mask
        if token_logp.shape != token_mask.shape:
            return token_logp, token_mask
        mask_bool = token_mask != 0
        try:
            mask_int = mask_bool.to(dtype=getattr(torch_mod, "long", None))
        except _SCORING_EXCEPTIONS:
            try:
                mask_int = mask_bool.long()
            except _SCORING_EXCEPTIONS:
                return token_logp, token_mask
        counts = mask_int.sum(dim=1)
        if getattr(counts, "numel", lambda: 0)() == 0:
            return token_logp[:, :0], token_mask[:, :0]
        try:
            max_len = int(counts.max().item())
        except _SCORING_EXCEPTIONS:
            return token_logp, token_mask
        if max_len <= 0:
            try:
                empty = torch_mod.zeros(
                    (token_logp.shape[0], 0),
                    dtype=getattr(token_logp, "dtype", None),
                    device=getattr(token_logp, "device", None),
                )
            except _SCORING_EXCEPTIONS:
                empty = torch_mod.zeros((token_logp.shape[0], 0))
            return empty, empty
        try:
            comp_logp = torch_mod.zeros(
                (token_logp.shape[0], max_len),
                dtype=getattr(token_logp, "dtype", None),
                device=getattr(token_logp, "device", None),
            )
            comp_mask = torch_mod.zeros(
                (token_logp.shape[0], max_len),
                dtype=getattr(token_logp, "dtype", None),
                device=getattr(token_logp, "device", None),
            )
        except _SCORING_EXCEPTIONS:
            comp_logp = torch_mod.zeros((token_logp.shape[0], max_len))
            comp_mask = torch_mod.zeros((token_logp.shape[0], max_len))
        pos = mask_int.cumsum(dim=1) - 1
        try:
            rows, cols = mask_bool.nonzero(as_tuple=True)
        except (TypeError, AttributeError):
            try:
                nz = mask_bool.nonzero()
                rows, cols = nz[:, 0], nz[:, 1]
            except _SCORING_EXCEPTIONS:
                return token_logp, token_mask
        if getattr(rows, "numel", lambda: 0)() > 0:
            comp_cols = pos[rows, cols]
            try:
                comp_logp[rows, comp_cols] = token_logp[rows, cols]
                comp_mask[rows, comp_cols] = 1
            except _SCORING_EXCEPTIONS:
                return token_logp, token_mask
        return comp_logp, comp_mask

    # Fallback for lightweight stubs
    if not hasattr(model, "forward"):
        label_arr = np.asarray(getattr(labels, "arr", labels))
        valid_counts = (label_arr != -100).sum(axis=1)
        logp = torch_mod.tensor(
            np.zeros(len(valid_counts), dtype=float),
            dtype=getattr(torch_mod, "float32", None),
        )
        tok_tensor = torch_mod.tensor(
            valid_counts, dtype=getattr(torch_mod, "float32", None)
        )
        entropy_sum = (
            torch_mod.tensor(
                np.zeros(len(valid_counts), dtype=float),
                dtype=getattr(torch_mod, "float32", None),
            )
            if return_entropy
            else None
        )
        if return_token_logp:
            token_logp = torch_mod.zeros(
                (len(valid_counts), 0),
                dtype=getattr(torch_mod, "float32", None),
            )
            return logp, tok_tensor, None, entropy_sum, token_logp, token_logp
        return logp, tok_tensor, None, entropy_sum
    shape = getattr(input_ids, "shape", None)
    if shape and len(shape) > 1 and shape[1] == 0:
        batch_size = shape[0]
        zero = torch_mod.tensor(
            np.zeros(batch_size, dtype=float), dtype=getattr(torch_mod, "float32", None)
        )
        tok_tensor = torch_mod.tensor(
            np.zeros(batch_size, dtype=float), dtype=getattr(torch_mod, "float32", None)
        )
        entropy_sum = (
            torch_mod.tensor(
                np.zeros(batch_size, dtype=float),
                dtype=getattr(torch_mod, "float32", None),
            )
            if return_entropy
            else None
        )
        if return_token_logp:
            token_logp = torch_mod.zeros(
                (batch_size, 0),
                dtype=getattr(torch_mod, "float32", None),
            )
            return zero, tok_tensor, None, entropy_sum, token_logp, token_logp
        return zero, tok_tensor, None, entropy_sum
    # DeepSpeed ZeRO-3 shards parameters to 1-D partitions; gather embeddings
    # so the reference forward sees 2-D weights without full-parameter all-gather.
    gather_ctx = nullcontext()
    zero_gather_strategy = "none"
    dist = _dist_collective_ready(torch_mod)
    if gather_full_params:
        try:  # pragma: no cover - exercised in distributed runs
            import deepspeed

            params: list[Any] = []
            param_iter = getattr(model, "parameters", None)
            if callable(param_iter):
                try:
                    params = list(cast(Iterable[Any], param_iter()))
                except _SCORING_EXCEPTIONS:
                    params = []
            # Always invoke GatheredParameters when available, even with an empty
            # list, so stubbed environments can verify the context is entered.
            try:
                gather_ctx = deepspeed.zero.GatheredParameters(
                    params or [], modifier_rank=None
                )
            except TypeError:
                gather_ctx = deepspeed.zero.GatheredParameters(params or [])
            zero_gather_strategy = "manual_full"
        except ImportError:
            gather_ctx = nullcontext()
    else:
        try:  # pragma: no cover - exercised in distributed runs
            import deepspeed

            zero_mod = getattr(deepspeed, "zero", None)
            is_enabled_fn = (
                getattr(zero_mod, "is_enabled", None) if zero_mod is not None else None
            )
            if callable(is_enabled_fn) and is_enabled_fn():
                to_gather: list[Any] = []
                if os.environ.get("MAXENT_DISABLE_SCORING_ZERO_GATHER", "").strip():
                    gather_ctx = nullcontext()
                inp_emb = (
                    model.get_input_embeddings()
                    if hasattr(model, "get_input_embeddings")
                    else None
                )
                input_weight = (
                    getattr(inp_emb, "weight", None)
                    if inp_emb is not None and hasattr(inp_emb, "weight")
                    else None
                )
                out_emb = (
                    model.get_output_embeddings()
                    if hasattr(model, "get_output_embeddings")
                    else None
                )
                if out_emb is None and hasattr(model, "lm_head"):
                    out_emb = model.lm_head
                output_weight = (
                    getattr(out_emb, "weight", None)
                    if out_emb is not None and hasattr(out_emb, "weight")
                    else None
                )

                input_present_local = input_weight is not None
                output_present_local = output_weight is not None
                input_present_all = (
                    _dist_all(dist, input_present_local)
                    if dist is not None
                    else input_present_local
                )
                output_present_all = (
                    _dist_all(dist, output_present_local)
                    if dist is not None
                    else output_present_local
                )

                needs_input_local = bool(
                    input_weight is not None
                    and not _weight_is_two_dimensional(input_weight)
                )
                needs_output_local = bool(
                    output_weight is not None
                    and not _weight_is_two_dimensional(output_weight)
                )
                needs_input_any = (
                    _dist_any(dist, needs_input_local)
                    if dist is not None
                    else needs_input_local
                )
                needs_output_any = (
                    _dist_any(dist, needs_output_local)
                    if dist is not None
                    else needs_output_local
                )

                if needs_input_any and input_present_all and input_weight is not None:
                    to_gather.append(input_weight)
                if (
                    needs_output_any
                    and output_present_all
                    and output_weight is not None
                ):
                    to_gather.append(output_weight)

                if (needs_input_any and not input_present_all) or (
                    needs_output_any and not output_present_all
                ):
                    LOG.warning(
                        "DeepSpeed ZeRO gather decision mismatch across ranks; skipping embed gather to avoid deadlock | "
                        "input_present_local=%s input_present_all=%s needs_input_local=%s needs_input_any=%s | "
                        "output_present_local=%s output_present_all=%s needs_output_local=%s needs_output_any=%s",
                        input_present_local,
                        input_present_all,
                        needs_input_local,
                        needs_input_any,
                        output_present_local,
                        output_present_all,
                        needs_output_local,
                        needs_output_any,
                    )
                    to_gather = []

                if to_gather:
                    # De-duplicate parameters to avoid repeated GatheredParameters calls
                    # on shared/tied embedding weights.
                    seen: set[int] = set()
                    unique: list[Any] = []
                    for param in to_gather:
                        param_id = id(param)
                        if param_id in seen:
                            continue
                        seen.add(param_id)
                        unique.append(param)
                    LOG.debug(
                        "DeepSpeed ZeRO gather for scoring | tensors=%d | shapes=%s",
                        len(unique),
                        [getattr(param, "shape", None) for param in unique],
                    )
                    try:
                        gather_ctx = deepspeed.zero.GatheredParameters(
                            unique, modifier_rank=None
                        )
                    except TypeError:
                        gather_ctx = deepspeed.zero.GatheredParameters(unique)
                    zero_gather_strategy = "manual_embed"
        except ImportError:
            gather_ctx = nullcontext()
        except _SCORING_EXCEPTIONS:
            gather_ctx = nullcontext()

    use_helper_zero_gather = zero_gather_strategy == "none"
    helper_gather_all_ranks = bool(zero_gather_all_ranks or gather_full_params)
    fsdp_ctx = _summon_fsdp_full_param_context(model)
    stack = ExitStack()
    if slice_log:
        LOG.info(
            "chunked_sequence_logprobs enter gather_ctx | gather_full_params=%s strategy=%s helper_zero_gather=%s",
            gather_full_params,
            zero_gather_strategy,
            use_helper_zero_gather,
        )
    stack.enter_context(gather_ctx)
    if slice_log:
        LOG.info("chunked_sequence_logprobs entered gather_ctx")
    if use_helper_zero_gather:
        # Enforce a single ZeRO gather strategy per scoring pass. If we already
        # entered an explicit DeepSpeed gather context above, skip helper gathers
        # to avoid nested GatheredParameters over overlapping tensors.
        if slice_log:
            LOG.info("chunked_sequence_logprobs enter zero_gather_params")
        stack.enter_context(
            _maybe_zero_gather_params(
                model, enabled=True, gather_all_ranks=helper_gather_all_ranks
            )
        )
        if slice_log:
            LOG.info("chunked_sequence_logprobs entered zero_gather_params")
    else:
        LOG.debug(
            "Using manual ZeRO gather strategy=%s; skipping helper gather contexts.",
            zero_gather_strategy,
        )
    if slice_log:
        LOG.info("chunked_sequence_logprobs enter fsdp_ctx")
    stack.enter_context(fsdp_ctx)
    if slice_log:
        LOG.info("chunked_sequence_logprobs entered fsdp_ctx")
    config = getattr(model, "config", None)
    padding_idx = _coerce_optional_int(_get_config_value(config, "pad_token_id", None))
    embedding_vocab_size = _get_embedding_vocab_size(model, config)
    vocab_size = _coerce_optional_int(_get_config_value(config, "vocab_size", None))
    pad_targets: list[tuple[Any, str]] = []
    if config is not None:
        pad_targets.append((config, "pad_token_id"))
    seen_modules: set[int] = set()
    embed_token_module = getattr(model, "embed_tokens", None)
    if embed_token_module is not None:
        seen_modules.add(id(embed_token_module))
        if hasattr(embed_token_module, "padding_idx"):
            pad_targets.append((embed_token_module, "padding_idx"))
    try:
        input_embed_module = model.get_input_embeddings()
    except _SCORING_EXCEPTIONS:
        input_embed_module = None
    if (
        input_embed_module is not None
        and id(input_embed_module) not in seen_modules
        and hasattr(input_embed_module, "padding_idx")
    ):
        pad_targets.append((input_embed_module, "padding_idx"))
    final_padding_idx = padding_idx
    if padding_idx is not None:
        limit: Optional[int] = None
        if embedding_vocab_size is not None:
            limit = embedding_vocab_size - 1
        if vocab_size is not None:
            vocab_limit = vocab_size - 1
            if limit is None or vocab_limit < limit:
                limit = vocab_limit
        if limit is not None:
            limit = max(limit, 0)
            if padding_idx > limit:
                final_padding_idx = limit
    pad_ctx = nullcontext()
    if (
        padding_idx is not None
        and final_padding_idx is not None
        and final_padding_idx != padding_idx
        and pad_targets
    ):
        LOG.debug(
            "Clamping padding idx for scoring | original=%s final=%s",
            padding_idx,
            final_padding_idx,
        )
        pad_ctx = _PadTokenGuard(pad_targets, final_padding_idx)
        padding_idx = final_padding_idx
    stack.enter_context(pad_ctx)
    with stack:
        if slice_log:
            LOG.info(
                "chunked_sequence_logprobs start | input_ids_shape=%s attention_mask_shape=%s labels_shape=%s",
                getattr(input_ids, "shape", None),
                getattr(attention_mask, "shape", None),
                getattr(labels, "shape", None),
            )
        LOG.debug(
            "chunked_sequence_logprobs start | gather_full_params=%s return_hidden=%s pooling=%s | "
            "input_ids_shape=%s dtype=%s device=%s | attention_mask_shape=%s | labels_shape=%s dtype=%s device=%s",
            gather_full_params,
            return_hidden,
            pooling,
            getattr(input_ids, "shape", None),
            getattr(input_ids, "dtype", None),
            getattr(input_ids, "device", None),
            getattr(attention_mask, "shape", None),
            getattr(labels, "shape", None),
            getattr(labels, "dtype", None),
            getattr(labels, "device", None),
        )
        # High-level shape logging for reference scoring.
        LOG.debug(
            "reference scoring inputs | input_ids_shape=%s attention_mask_shape=%s labels_shape=%s",
            getattr(input_ids, "shape", None),
            getattr(attention_mask, "shape", None),
            getattr(labels, "shape", None),
        )
        LOG.debug(
            "reference scoring pad metadata | model.config.pad_token_id=%s embedding_vocab_size=%s",
            padding_idx,
            embedding_vocab_size,
        )

        embed_descs: list[str] = []
        embed_tokens = getattr(model, "embed_tokens", None)
        embed_descs.append(_describe_embedding_module(embed_tokens, "embed_tokens"))
        try:
            input_embeddings = model.get_input_embeddings()
        except _SCORING_EXCEPTIONS:
            input_embeddings = None
        if input_embeddings is not None and input_embeddings is not embed_tokens:
            embed_descs.append(
                _describe_embedding_module(input_embeddings, "input_embeddings")
            )
        # Conservative guard: if the reference model's embedding weights are
        # not 2-D under the gathered parameter contexts, skip reference
        # scoring to avoid noisy runtime errors from torch.embedding.
        for module in (embed_tokens, input_embeddings):
            if module is None:
                continue
            weight = getattr(module, "weight", None)
            if weight is not None and not _weight_is_two_dimensional(weight):
                if _weight_is_stub_tensor(weight):
                    LOG.debug(
                        "Non-2D stub embedding weight; continuing reference scoring | %s",
                        " | ".join(embed_descs),
                    )
                    continue
                LOG.warning(
                    "Skipping reference scoring due to non-2D embedding weight | %s",
                    " | ".join(embed_descs),
                )
                return None
        LOG.debug(
            "reference scoring embeddings | %s | pad_token_id=%s vocab=%s",
            " | ".join(embed_descs),
            padding_idx,
            embedding_vocab_size,
        )
        batch = getattr(input_ids, "shape", None)
        batch = batch[0] if batch else 0
        chunk_limit = int(chunk_size) if chunk_size is not None else 0
        if (
            chunk_limit <= 0
            and batch > 1
            and not os.environ.get("MAXENT_DISABLE_LOGPROB_AUTOBATCH", "").strip()
        ):
            try:
                vocab_size_guess = int(
                    getattr(getattr(model, "config", None), "vocab_size", 0) or 0
                )
            except (TypeError, ValueError):
                vocab_size_guess = 0
            seq_len = getattr(input_ids, "shape", None)
            seq_len = int(seq_len[1]) if seq_len and len(seq_len) > 1 else 0
            device_str = str(getattr(input_ids, "device", "")).lower()
            if vocab_size_guess > 0 and seq_len > 0 and "cuda" in device_str:
                try:
                    model_dtype = getattr(model, "dtype", None)
                    dtype_str = str(
                        getattr(model_dtype, "name", model_dtype) or ""
                    ).lower()
                except _SCORING_EXCEPTIONS:
                    dtype_str = ""
                bytes_per_elem = 2
                if "float32" in dtype_str or "fp32" in dtype_str:
                    bytes_per_elem = 4
                target_mb_raw = os.environ.get("MAXENT_LOGPROB_TARGET_LOGITS_MB", "256")
                try:
                    target_bytes = int(float(target_mb_raw) * 1024 * 1024)
                except (TypeError, ValueError):
                    target_bytes = 256 * 1024 * 1024
                bytes_per_seq = max(1, seq_len * vocab_size_guess * bytes_per_elem)
                auto_limit = max(1, min(batch, target_bytes // bytes_per_seq))
                if auto_limit < batch:
                    warned = getattr(
                        _chunked_sequence_logprobs, "_autobatch_warned", False
                    )
                    if not warned:
                        LOG.warning(
                            "Auto-tuning reference scoring batch chunk size to avoid large logits tensors | "
                            "requested_chunk_size=%s auto_chunk_size=%s batch=%s seq_len=%s vocab=%s target_logits_mb=%s "
                            "(set MAXENT_DISABLE_LOGPROB_AUTOBATCH=1 to disable)",
                            chunk_size,
                            auto_limit,
                            batch,
                            seq_len,
                            vocab_size_guess,
                            target_mb_raw,
                        )
                        setattr(_chunked_sequence_logprobs, "_autobatch_warned", True)
                    chunk_limit = auto_limit
        if chunk_limit <= 0 or chunk_limit >= batch:
            chunk_indices = [(0, batch)]
        else:
            chunk_indices = [
                (start, min(start + chunk_limit, batch))
                for start in range(0, batch, chunk_limit)
            ]
        LOG.debug(
            "reference scoring chunk plan | total_batch=%s chunk_size=%s chunks=%s",
            batch,
            chunk_limit,
            len(chunk_indices),
        )
    logp_chunks: list[Tensor] = []
    tok_chunks: list[Tensor] = []
    pooled_chunks: list[Tensor] = [] if return_hidden else []
    entropy_chunks: list[Tensor] = [] if return_entropy else []
    token_logp_chunks: list[Tensor] = [] if return_token_logp else []
    token_mask_chunks: list[Tensor] = [] if return_token_logp else []
    for idx, (start, end) in enumerate(chunk_indices):
        LOG.debug(
            "reference scoring chunk begin | chunk=%s | slice=[%s:%s] | rows=%s",
            idx,
            start,
            end,
            end - start,
        )
        if slice_log:
            LOG.info(
                "chunked_sequence_logprobs forward start | chunk=%s slice=[%s:%s]",
                idx,
                start,
                end,
            )
            forward_start = time.monotonic()
        ids_chunk = input_ids[start:end]
        mask_chunk = attention_mask[start:end] if attention_mask is not None else None
        label_chunk = labels[start:end]
        wants_labels = False
        for callable_name in ("forward", "__call__"):
            candidate = getattr(model, callable_name, None)
            if not callable(candidate):
                continue
            try:
                sig = inspect.signature(candidate)
                param = sig.parameters.get("labels")
                if param is not None and param.default is inspect.Signature.empty:
                    wants_labels = True
                    break
            except (TypeError, ValueError):
                continue
        call_kwargs: dict[str, Any] = {
            "input_ids": ids_chunk,
            "attention_mask": mask_chunk,
            "output_hidden_states": return_hidden,
        }
        if wants_labels:
            call_kwargs["labels"] = label_chunk
        call_target = model if callable(model) else getattr(model, "forward", None)
        if not callable(call_target):
            raise TypeError("Model is not callable and lacks a forward method")
        try:
            outputs = call_target(**call_kwargs)
        except TypeError:
            call_kwargs.pop("output_hidden_states", None)
            try:
                outputs = call_target(**call_kwargs)
            except TypeError:
                if not wants_labels:
                    call_kwargs.pop("labels", None)
                outputs = call_target(**call_kwargs)
        if slice_log:
            LOG.info(
                "chunked_sequence_logprobs forward done | chunk=%s seconds=%.2f",
                idx,
                time.monotonic() - forward_start,
            )
        outputs_any = cast(Any, outputs)
        logits = getattr(outputs_any, "logits", None)
        if logits is None:
            raise AttributeError("Model outputs missing logits")
        LOG.debug(
            "reference scoring logits metadata | chunk=%s | shape=%s dtype=%s device=%s",
            idx,
            getattr(logits, "shape", None),
            getattr(logits, "dtype", None),
            getattr(logits, "device", None),
        )
        # Causal LM logits at position t predict token t+1. Align labels by
        # shifting so we score next-token log-probs for all non-masked targets.
        if logits.size(1) <= 1:
            batch_rows = logits.size(0)
            try:
                seq_logp_chunk = torch_mod.zeros(
                    (batch_rows,),
                    dtype=getattr(torch_mod, "float32", None),
                    device=getattr(logits, "device", None),
                )
            except _SCORING_EXCEPTIONS:
                seq_logp_chunk = torch_mod.tensor(
                    np.zeros(batch_rows, dtype=float),
                    dtype=getattr(torch_mod, "float32", None),
                    device=getattr(logits, "device", None),
                )
            tok_tensor_chunk = torch_mod.ones(
                (batch_rows,),
                dtype=getattr(torch_mod, "long", None),
                device=getattr(logits, "device", None),
            )
            logp_chunks.append(seq_logp_chunk)
            tok_chunks.append(tok_tensor_chunk)
            if return_token_logp:
                try:
                    token_logp_chunk = torch_mod.zeros(
                        (batch_rows, 0),
                        dtype=getattr(torch_mod, "float32", None),
                        device=getattr(seq_logp_chunk, "device", None),
                    )
                except _SCORING_EXCEPTIONS:
                    token_logp_chunk = torch_mod.tensor(
                        np.zeros((batch_rows, 0), dtype=float),
                        dtype=getattr(torch_mod, "float32", None),
                        device=getattr(seq_logp_chunk, "device", None),
                    )
                token_mask_chunk = token_logp_chunk
                token_logp_chunks.append(token_logp_chunk)
                token_mask_chunks.append(token_mask_chunk)
            if return_entropy:
                try:
                    entropy_chunk = torch_mod.zeros(
                        (batch_rows,),
                        dtype=getattr(torch_mod, "float32", None),
                        device=getattr(seq_logp_chunk, "device", None),
                    )
                except _SCORING_EXCEPTIONS:
                    entropy_chunk = torch_mod.tensor(
                        np.zeros(batch_rows, dtype=float),
                        dtype=getattr(torch_mod, "float32", None),
                        device=getattr(seq_logp_chunk, "device", None),
                    )
                entropy_chunks.append(entropy_chunk)
            preview_vals = None
            preview_source = seq_logp_chunk
            detach_fn = getattr(seq_logp_chunk, "detach", None)
            if callable(detach_fn):
                try:
                    preview_source = detach_fn()
                except _SCORING_EXCEPTIONS:
                    preview_source = seq_logp_chunk
            preview_source_any = cast(Any, preview_source)
            if hasattr(preview_source_any, "cpu"):
                try:
                    preview_vals = preview_source_any.cpu().reshape(-1)[:3].tolist()
                except _SCORING_EXCEPTIONS:
                    preview_vals = None
            LOG.debug(
                "reference scoring chunk stats | chunk=%s | ids_shape=%s | mask_shape=%s | "
                "seq_logp_shape=%s dtype=%s device=%s | tok_shape=%s dtype=%s device=%s | "
                "tok_sum=%s | valid_token_mask_sum=%s | seq_logp_preview=%s",
                idx,
                getattr(ids_chunk, "shape", None),
                getattr(mask_chunk, "shape", None),
                getattr(seq_logp_chunk, "shape", None),
                getattr(seq_logp_chunk, "dtype", None),
                getattr(seq_logp_chunk, "device", None),
                getattr(tok_tensor_chunk, "shape", None),
                getattr(tok_tensor_chunk, "dtype", None),
                getattr(tok_tensor_chunk, "device", None),
                float(batch_rows),
                0,
                preview_vals,
            )
            continue

        shifted_logits = logits[:, :-1, :]
        shifted_labels = label_chunk[:, 1:]
        label_mask = shifted_labels != -100
        safe_labels = shifted_labels.masked_fill(~label_mask, 0)
        # Memory-lean path for MaxEnt/reference scoring when we only need
        # sequence-level log-prob sums (no per-token outputs/entropy).
        if not return_token_logp and not return_entropy:
            try:
                nonzero_kwargs = {"as_tuple": True}
                valid_rows, valid_cols = label_mask.nonzero(**nonzero_kwargs)
                valid_count = int(getattr(valid_rows, "numel", lambda: 0)())
                batch_rows = shifted_logits.size(0)
                seq_logp_chunk = torch_mod.zeros(
                    (batch_rows,),
                    dtype=getattr(shifted_logits, "dtype", None),
                    device=getattr(shifted_logits, "device", None),
                )
                if valid_count > 0:
                    raw_chunk = os.getenv("MAXENT_LOGPROB_TOKEN_CHUNK", "256")
                    try:
                        token_chunk = max(1, int(raw_chunk))
                    except (TypeError, ValueError):
                        token_chunk = 256
                    target_ids = safe_labels[valid_rows, valid_cols]
                    logsumexp_fn = getattr(torch_mod, "logsumexp", None)
                    for start_pos in range(0, valid_count, token_chunk):
                        end_pos = min(start_pos + token_chunk, valid_count)
                        row_chunk = valid_rows[start_pos:end_pos]
                        col_chunk = valid_cols[start_pos:end_pos]
                        tgt_chunk = target_ids[start_pos:end_pos]
                        selected_logits = shifted_logits[row_chunk, col_chunk, :]
                        target_logits = selected_logits.gather(
                            dim=-1, index=tgt_chunk.unsqueeze(-1)
                        ).squeeze(-1)
                        if callable(logsumexp_fn):
                            log_denom = logsumexp_fn(selected_logits, dim=-1)
                            to_fn = getattr(target_logits, "to", None)
                            if callable(to_fn):
                                target_logits = to_fn(
                                    dtype=getattr(log_denom, "dtype", None)
                                )
                            token_logp_chunk = cast(Any, target_logits) - log_denom
                        else:
                            chunk_log_probs = torch_mod.nn.functional.log_softmax(
                                selected_logits, dim=-1
                            )
                            token_logp_chunk = chunk_log_probs.gather(
                                dim=-1, index=tgt_chunk.unsqueeze(-1)
                            ).squeeze(-1)
                        seq_logp_chunk = cast(
                            Tensor,
                            seq_logp_chunk.index_add(
                                0,
                                row_chunk,
                                cast(
                                    Any,
                                    token_logp_chunk.to(
                                        dtype=getattr(seq_logp_chunk, "dtype", None)
                                    ),
                                ),
                            ),
                        )
                tok_tensor_chunk = label_mask.sum(dim=1).clamp(min=1)
                logp_chunks.append(seq_logp_chunk)
                tok_chunks.append(tok_tensor_chunk)
                continue
            except _SCORING_EXCEPTIONS as exc:
                LOG.debug(
                    "Memory-lean sequence logprob path failed; falling back to dense path: %s",
                    exc,
                )
        gather_labels = safe_labels.unsqueeze(-1)
        log_probs = None
        token_logp = None
        if use_exact_entropy:
            try:
                log_probs = torch_mod.nn.functional.log_softmax(shifted_logits, dim=-1)
                token_logp = log_probs.gather(dim=-1, index=gather_labels).squeeze(-1)
            except _SCORING_EXCEPTIONS:
                log_probs = None
                token_logp = None
        if token_logp is None:
            try:
                logsumexp_fn = getattr(torch_mod, "logsumexp", None)
                if not callable(logsumexp_fn):
                    raise AttributeError("torch.logsumexp unavailable")
                log_denom = logsumexp_fn(shifted_logits, dim=-1)
                target_logits = shifted_logits.gather(
                    dim=-1, index=gather_labels
                ).squeeze(-1)
                to_fn = getattr(target_logits, "to", None)
                if callable(to_fn):
                    target_logits = to_fn(dtype=getattr(log_denom, "dtype", None))
                token_logp = cast(Any, target_logits) - log_denom
            except _SCORING_EXCEPTIONS:
                log_probs = torch_mod.nn.functional.log_softmax(shifted_logits, dim=-1)
                token_logp = log_probs.gather(dim=-1, index=gather_labels).squeeze(-1)
        # If mask tensors are real torch tensors but token_logp is a stub tensor,
        # coerce token_logp into the active torch module to avoid type mismatches.
        torch_tensor = getattr(torch_mod, "Tensor", None)
        if (
            torch_tensor is not None
            and isinstance(label_mask, torch_tensor)
            and not isinstance(token_logp, torch_tensor)
        ):
            try:
                token_logp = torch_mod.tensor(_to_numpy_array(token_logp))
            except _SCORING_EXCEPTIONS as exc:
                LOG.debug("Failed to coerce token_logp into torch tensor: %s", exc)
        mask_float = label_mask
        type_as_fn = getattr(mask_float, "type_as", None)
        if callable(type_as_fn):
            try:
                is_tensor_fn = getattr(torch_mod, "is_tensor", None)
                if callable(is_tensor_fn) and not is_tensor_fn(token_logp):
                    raise TypeError("type_as requires a torch tensor")
                mask_mod = getattr(type(mask_float), "__module__", "")
                token_mod = getattr(type(token_logp), "__module__", "")
                if mask_mod.startswith("torch") != token_mod.startswith("torch"):
                    raise TypeError("type_as requires matching torch tensors")
                if not isinstance(token_logp, type(mask_float)):
                    raise TypeError("type_as requires same tensor types")
                mask_float = type_as_fn(token_logp)
            except _SCORING_EXCEPTIONS:
                type_as_fn = None
        if not callable(type_as_fn):
            to_fn = getattr(mask_float, "to", None)
            if callable(to_fn):
                try:
                    mask_float = to_fn(dtype=getattr(token_logp, "dtype", None))
                except _SCORING_EXCEPTIONS:
                    float_fn = getattr(mask_float, "float", None)
                    if callable(float_fn):
                        mask_float = float_fn()
        seq_logp_chunk = cast(
            Tensor,
            (cast(Any, token_logp) * cast(Any, mask_float)).sum(dim=1),
        )
        tok_tensor_chunk = label_mask.sum(dim=1).clamp(min=1)
        logp_chunks.append(seq_logp_chunk)
        tok_chunks.append(tok_tensor_chunk)
        if return_token_logp:
            try:
                comp_logp, comp_mask = _compress_completion_token_logp(
                    cast(Tensor, token_logp), cast(Tensor, label_mask)
                )
            except _SCORING_EXCEPTIONS:
                comp_logp, comp_mask = token_logp, label_mask
            try:
                token_logp_chunks.append(cast(Tensor, comp_logp))
            except _SCORING_EXCEPTIONS:
                token_logp_chunks.append(torch_mod.tensor(_to_numpy_array(comp_logp)))
            token_mask_chunks.append(cast(Tensor, comp_mask))
        if return_entropy:
            entropy_chunk = None
            if use_sample_entropy:
                try:
                    entropy_chunk = (-seq_logp_chunk).to(
                        dtype=getattr(torch_mod, "float32", None)
                        or getattr(seq_logp_chunk, "dtype", None)
                    )
                except _SCORING_EXCEPTIONS:
                    entropy_chunk = None
            if entropy_chunk is None:
                try:
                    if log_probs is None:
                        log_probs = torch_mod.nn.functional.log_softmax(
                            shifted_logits, dim=-1
                        )
                    ent = -(log_probs.exp() * log_probs).sum(dim=-1)
                    entropy_chunk = (ent * cast(Any, mask_float)).sum(dim=1)
                except _SCORING_EXCEPTIONS as exc:
                    LOG.debug("Failed to compute policy entropy: %s", exc)
            if entropy_chunk is None:
                try:
                    entropy_chunk = torch_mod.zeros(
                        (seq_logp_chunk.shape[0],),
                        dtype=getattr(torch_mod, "float32", None),
                        device=getattr(seq_logp_chunk, "device", None),
                    )
                except _SCORING_EXCEPTIONS:
                    entropy_chunk = torch_mod.tensor(
                        np.zeros(seq_logp_chunk.shape[0], dtype=float),
                        dtype=getattr(torch_mod, "float32", None),
                        device=getattr(seq_logp_chunk, "device", None),
                    )
            entropy_chunks.append(entropy_chunk)
        try:
            tok_sum = tok_tensor_chunk.detach().cpu().sum().item()
        except _SCORING_EXCEPTIONS:
            tok_sum = None
        try:
            logp_preview = seq_logp_chunk.detach().cpu().reshape(-1)[:3].tolist()
        except _SCORING_EXCEPTIONS:
            logp_preview = None
        try:
            valid_tokens = int(label_mask.sum().detach().cpu().item())
        except _SCORING_EXCEPTIONS:
            valid_tokens = None
        LOG.debug(
            "reference scoring chunk stats | chunk=%s | ids_shape=%s | mask_shape=%s | "
            "seq_logp_shape=%s dtype=%s device=%s | tok_shape=%s dtype=%s device=%s | "
            "tok_sum=%s | valid_token_mask_sum=%s | seq_logp_preview=%s",
            idx,
            getattr(ids_chunk, "shape", None),
            getattr(mask_chunk, "shape", None),
            getattr(seq_logp_chunk, "shape", None),
            getattr(seq_logp_chunk, "dtype", None),
            getattr(seq_logp_chunk, "device", None),
            getattr(tok_tensor_chunk, "shape", None),
            getattr(tok_tensor_chunk, "dtype", None),
            getattr(tok_tensor_chunk, "device", None),
            tok_sum,
            valid_tokens,
            logp_preview,
        )
        hidden_states = getattr(outputs_any, "hidden_states", None)
        if return_hidden and hidden_states is not None:
            hidden = hidden_states[-1]
            mask = mask_chunk
            if pooling == "last":
                pooled = hidden[:, -1, :]
            else:
                if mask is None:
                    pooled = hidden.mean(dim=1)
                else:
                    mask = mask.unsqueeze(-1)
                    type_as_fn = getattr(mask, "type_as", None)
                    if callable(type_as_fn):
                        try:
                            is_tensor_fn = getattr(torch_mod, "is_tensor", None)
                            if callable(is_tensor_fn) and not is_tensor_fn(hidden):
                                raise TypeError("type_as requires a torch tensor")
                            mask_mod = getattr(type(mask), "__module__", "")
                            hidden_mod = getattr(type(hidden), "__module__", "")
                            if mask_mod.startswith("torch") != hidden_mod.startswith(
                                "torch"
                            ):
                                raise TypeError(
                                    "type_as requires matching torch tensors"
                                )
                            if not isinstance(hidden, type(mask)):
                                raise TypeError("type_as requires same tensor types")
                            mask = type_as_fn(hidden)
                        except _SCORING_EXCEPTIONS:
                            type_as_fn = None
                    if not callable(type_as_fn):
                        to_fn = getattr(mask, "to", None)
                        if callable(to_fn):
                            try:
                                mask = to_fn(dtype=hidden.dtype)
                            except _SCORING_EXCEPTIONS:
                                float_fn = getattr(mask, "float", None)
                                if callable(float_fn):
                                    mask = float_fn()
                        else:
                            float_fn = getattr(mask, "float", None)
                            if callable(float_fn):
                                mask = float_fn()
                    mask_any = cast(Any, mask)
                    pooled = (hidden * mask_any).sum(dim=1) / mask_any.sum(dim=1).clamp(
                        min=1.0
                    )
            pooled_chunks.append(pooled)
    seq_logp = (
        logp_chunks[0] if len(logp_chunks) == 1 else torch_mod.cat(logp_chunks, dim=0)
    )
    tok_tensor = (
        tok_chunks[0] if len(tok_chunks) == 1 else torch_mod.cat(tok_chunks, dim=0)
    )
    pooled_hidden: Optional[Tensor] = None
    if pooled_chunks:
        pooled_hidden = (
            pooled_chunks[0]
            if len(pooled_chunks) == 1
            else torch_mod.cat(pooled_chunks, dim=0)
        )
    entropy_sum: Optional[Tensor] = None
    if return_entropy:
        if entropy_chunks:
            entropy_sum = (
                entropy_chunks[0]
                if len(entropy_chunks) == 1
                else torch_mod.cat(entropy_chunks, dim=0)
            )
        else:
            entropy_sum = torch_mod.zeros_like(tok_tensor)
    try:
        logp_sample = seq_logp.detach().cpu().reshape(-1)[:4].tolist()
    except _SCORING_EXCEPTIONS:
        logp_sample = None
    LOG.debug(
        "chunked_sequence_logprobs finish | seq_logp_shape=%s tok_shape=%s pooled_shape=%s | "
        "seq_logp_dtype=%s tok_dtype=%s pooled_dtype=%s | seq_logp_device=%s tok_device=%s pooled_device=%s | "
        "seq_logp_numel=%s tok_numel=%s | seq_logp_preview=%s",
        getattr(seq_logp, "shape", None),
        getattr(tok_tensor, "shape", None),
        getattr(pooled_hidden, "shape", None) if pooled_hidden is not None else None,
        getattr(seq_logp, "dtype", None),
        getattr(tok_tensor, "dtype", None),
        getattr(pooled_hidden, "dtype", None) if pooled_hidden is not None else None,
        getattr(seq_logp, "device", None),
        getattr(tok_tensor, "device", None),
        getattr(pooled_hidden, "device", None) if pooled_hidden is not None else None,
        getattr(seq_logp, "numel", lambda: None)(),
        getattr(tok_tensor, "numel", lambda: None)(),
        logp_sample,
    )
    if return_token_logp:
        token_logp: Optional[Tensor] = None
        token_mask: Optional[Tensor] = None
        if token_logp_chunks:
            token_logp = (
                token_logp_chunks[0]
                if len(token_logp_chunks) == 1
                else torch_mod.cat(token_logp_chunks, dim=0)
            )
            token_mask = (
                token_mask_chunks[0]
                if len(token_mask_chunks) == 1
                else torch_mod.cat(token_mask_chunks, dim=0)
            )
        return seq_logp, tok_tensor, pooled_hidden, entropy_sum, token_logp, token_mask
    return seq_logp, tok_tensor, pooled_hidden, entropy_sum


def selective_log_softmax(logits: Tensor, index: Tensor) -> Tensor:
    """Return log-probabilities of the selected ids without a batch-wide softmax.

    TRL-style helper.  For float32/float64 logits the selected logit is
    gathered and a ``logsumexp`` (computed one leading-axis entry at a time)
    is subtracted from it.  For any other dtype, or when that path raises one
    of ``_SCORING_EXCEPTIONS``, each leading-axis entry is normalised
    separately and the selected ids are gathered from the result.

    :param logits: Logits; reductions and gathers run over the last axis.
    :param index: Ids to select; ``index.unsqueeze(-1)`` is gathered along the
        last axis of ``logits``.
    :returns: Log-probabilities of the selected ids.
    :rtype: Tensor
    """
    torch_mod = _refresh_torch()
    full_precision = {
        getattr(torch_mod, "float32", None),
        getattr(torch_mod, "float64", None),
    }
    if getattr(logits, "dtype", None) in full_precision:
        try:
            picked = torch_mod.gather(
                logits, dim=-1, index=index.unsqueeze(-1)
            ).squeeze(-1)
            # Reduce one leading-axis entry at a time to keep the peak
            # temporary small.
            normalisers = torch_mod.stack(
                [torch_mod.logsumexp(entry, dim=-1) for entry in logits]
            )
            return cast(Tensor, picked - normalisers)
        except _SCORING_EXCEPTIONS:
            # Fall through to the row-wise path below.
            pass

    functional = getattr(getattr(torch_mod, "nn", None), "functional", None)
    log_softmax = getattr(functional, "log_softmax", None) if functional else None

    def _row_log_probs(row_logits: Any) -> Any:
        """Normalise one entry with the best primitive the torch module offers."""
        if callable(log_softmax):
            return log_softmax(row_logits, dim=-1)
        logsumexp_fn = getattr(torch_mod, "logsumexp", None)
        if callable(logsumexp_fn):
            return row_logits - logsumexp_fn(row_logits, dim=-1, keepdim=True)
        # pragma: no cover - best effort fallback
        return row_logits

    gathered_rows: List[Tensor] = []
    for row_logits, row_ids in zip(logits, index):
        row_selected = (
            _row_log_probs(row_logits)
            .gather(dim=-1, index=row_ids.unsqueeze(-1))
            .squeeze(-1)
        )
        gathered_rows.append(cast(Tensor, row_selected))
    return cast(Tensor, torch_mod.stack(gathered_rows))
def _trl_get_per_token_logps(
    model: PreTrainedModel,
    input_ids: Tensor,
    attention_mask: Tensor,
    logits_to_keep: int,
    *,
    temperature: Optional[float] = None,
    batch_size: Optional[int] = None,
) -> Tensor:
    """TRL-style per-token log-probabilities for completion tokens.

    The model is run over ``input_ids`` in row batches.  For every batch the
    final logit position is dropped, the trailing ``logits_to_keep`` positions
    are paired with the trailing ``logits_to_keep`` input ids, and the
    log-probability of each of those ids is computed with
    :func:`selective_log_softmax`.

    :param model: Callable model; invoked with ``input_ids`` and
        ``attention_mask`` (plus ``logits_to_keep`` when it accepts it).
    :param input_ids: Token ids; indexed as ``[rows, positions]``.
    :param attention_mask: Mask sliced row-wise in step with ``input_ids``.
    :param logits_to_keep: Number of trailing positions to score.
    :param temperature: Optional divisor applied to the logits; ``None`` and
        ``1.0`` leave them untouched.
    :param batch_size: Rows per forward pass; ``None`` or a non-positive value
        scores all rows in one pass.
    :returns: Tensor with one row per input row and ``logits_to_keep`` columns.
        An empty float32 tensor of shape ``(rows, 0)`` is returned when
        ``logits_to_keep <= 0`` and of shape ``(0, logits_to_keep)`` when there
        are no input rows.
    :raises ValueError: If the model forward yields ``None`` for the logits.
    :rtype: Tensor
    """
    torch_mod = _refresh_torch()
    total_rows = int(getattr(input_ids, "shape", [0])[0] or 0)
    if logits_to_keep <= 0 or total_rows == 0:
        # Nothing to score.  Returning early also avoids concatenating an empty
        # list below, which torch rejects.
        return cast(
            Tensor,
            torch_mod.zeros(
                (total_rows, max(int(logits_to_keep), 0)),
                dtype=getattr(torch_mod, "float32", None),
                device=getattr(input_ids, "device", None),
            ),
        )
    temp = float(temperature if temperature is not None else 1.0)
    step = int(batch_size or 0)
    if step <= 0:
        step = total_rows
    all_logps: List[Tensor] = []
    for start in range(0, total_rows, step):
        input_ids_batch = input_ids[start : start + step]
        attention_mask_batch = attention_mask[start : start + step]
        try:
            # Ask for one extra position; the last logit is discarded below.
            outputs = model(
                input_ids=input_ids_batch,
                attention_mask=attention_mask_batch,
                logits_to_keep=logits_to_keep + 1,
            )
        except TypeError:
            # Retry without the ``logits_to_keep`` keyword.
            outputs = model(
                input_ids=input_ids_batch, attention_mask=attention_mask_batch
            )
        logits = getattr(outputs, "logits", outputs)
        if logits is None:
            raise ValueError("Model forward returned no logits for TRL scoring.")
        logits = logits[:, :-1, :]
        input_ids_batch = input_ids_batch[:, -logits_to_keep:]
        logits = logits[:, -logits_to_keep:]
        if temp != 1.0:
            logits = logits / temp
        logps = selective_log_softmax(logits, input_ids_batch)
        all_logps.append(cast(Tensor, logps))
    return cast(Tensor, torch_mod.cat(all_logps, dim=0))
def score_model_outputs(
    model: PreTrainedModel,
    score_batch: ScoreBatch,
    batching_cfg: BatchingSettings,
    runtime: RuntimeHandles,
    *,
    return_hidden: bool = False,
    pooling: str = "mean",
    return_entropy: bool = False,
    entropy_mode: str = "exact",
    return_token_logp: bool = False,
) -> Optional[
    Tuple[
        Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]
    ]
]:
    """Score every slice of ``score_batch`` with ``model`` and merge the results.

    Each slice is passed to ``_chunked_sequence_logprobs``; the per-slice
    outputs are concatenated along the first axis.  Per-token outputs are
    right-padded with zeros to the widest slice before concatenation.

    :param model: Current policy model used for scoring.
    :param score_batch: Prepared scoring batch.
    :param batching_cfg: Batching config controlling logprob chunking.
    :param runtime: Runtime handles providing device and accelerator state.
    :param return_hidden: When ``True``, also return pooled hidden states.
    :param pooling: Pooling strategy applied to hidden states.
    :param return_entropy: When ``True``, the result includes the concatenated
        per-sequence entropy sums.
    :param entropy_mode: Forwarded unchanged to ``_chunked_sequence_logprobs``.
    :param return_token_logp: When ``True``, the result ends with the padded
        per-token log-probs and their mask.
    :returns: Tuple of ``(cur_logp_sum, pooled_hidden[, policy_entropy_sum][, token_logp, token_mask])``
        or ``None`` if no slice was scored or a slice produced no result.
    :rtype: tuple[Tensor, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None] | tuple[Tensor, Tensor | None, Tensor | None, Tensor | None, Tensor | None] | None
    """
    logp_parts: List[Tensor] = []
    hidden_parts: List[Tensor] = []
    entropy_parts: List[Tensor] = []
    tok_logp_parts: List[Tensor] = []
    tok_mask_parts: List[Tensor] = []

    log_each_slice = _score_slice_log_enabled()
    log_progress = _progress_log_enabled()
    tokenizer = getattr(runtime, "tokenizer", None)
    batches = _prefetch_iterator(
        iter_batch_slices(
            score_batch,
            runtime.device,
            eos_token_id=getattr(tokenizer, "eos_token_id", None),
            apply_eos_mask=True,
        ),
        getattr(batching_cfg, "slice_prefetch", 0),
    )
    # Per-token log-probs stay attached to the graph only while autograd is on.
    keep_token_grad = torch.is_grad_enabled()

    if log_progress:
        LOG.info(
            "score_model_outputs start | total_sequences=%s slice_size=%s device=%s logprob_chunk_size=%s",
            score_batch.total_sequences,
            score_batch.slice_size,
            getattr(runtime.device, "type", runtime.device),
            getattr(batching_cfg, "logprob_chunk_size", None),
        )
    LOG.debug(
        "current scoring batch metadata | total_sequences=%s slice_size=%s device=%s",
        score_batch.total_sequences,
        score_batch.slice_size,
        getattr(runtime.device, "type", runtime.device),
    )

    with _autocast_context(runtime.accelerator, runtime.device):
        for slice_idx, (slice_inputs, slice_mask, slice_labels) in enumerate(batches):
            if log_each_slice:
                LOG.info(
                    "score_model_outputs slice start | idx=%d input_ids_shape=%s attention_mask_shape=%s labels_shape=%s",
                    slice_idx,
                    getattr(slice_inputs, "shape", None),
                    getattr(slice_mask, "shape", None),
                    getattr(slice_labels, "shape", None),
                )
            started = time.monotonic()
            LOG.debug(
                "current scoring slice inputs | input_ids_shape=%s attention_mask_shape=%s labels_shape=%s",
                getattr(slice_inputs, "shape", None),
                getattr(slice_mask, "shape", None),
                getattr(slice_labels, "shape", None),
            )
            result = _chunked_sequence_logprobs(
                model,
                input_ids=slice_inputs,
                attention_mask=slice_mask,
                labels=slice_labels,
                chunk_size=batching_cfg.logprob_chunk_size,
                return_hidden=return_hidden,
                pooling=pooling,
                return_entropy=return_entropy,
                entropy_mode=entropy_mode,
                return_token_logp=return_token_logp,
            )
            if log_each_slice and result is None:
                LOG.info(
                    "score_model_outputs slice done | idx=%d seconds=%.2f result=None",
                    slice_idx,
                    time.monotonic() - started,
                )
            elif log_each_slice:
                fields = cast(Sequence[Any], result)
                field_count = len(fields)
                logged_logp = fields[0] if field_count >= 1 else None
                logged_tok = fields[1] if field_count >= 2 else None
                logged_pooled = fields[2] if field_count >= 3 else None
                logged_entropy = fields[3] if field_count >= 4 else None
                LOG.info(
                    "score_model_outputs slice done | idx=%d seconds=%.2f logp_shape=%s tok_shape=%s pooled=%s entropy=%s",
                    slice_idx,
                    time.monotonic() - started,
                    getattr(logged_logp, "shape", None),
                    getattr(logged_tok, "shape", None),
                    getattr(logged_pooled, "shape", None),
                    getattr(logged_entropy, "shape", None),
                )
            if result is None:
                # A slice without a result aborts the whole scoring pass.
                return None

            slice_logp, _tok_counts, slice_pooled, slice_entropy = result[:4]
            slice_tok_logp = None
            slice_tok_mask = None
            if return_token_logp and len(result) >= 6:
                slice_tok_logp, slice_tok_mask = result[4], result[5]

            logp_parts.append(slice_logp)
            if slice_pooled is not None:
                hidden_parts.append(slice_pooled.detach())
            if slice_entropy is not None:
                entropy_parts.append(slice_entropy.detach())
            if return_token_logp:
                if slice_tok_logp is not None:
                    tok_logp_parts.append(
                        slice_tok_logp if keep_token_grad else slice_tok_logp.detach()
                    )
                if slice_tok_mask is not None:
                    tok_mask_parts.append(slice_tok_mask.detach())

    if not logp_parts:
        return None

    pooled_hidden = torch.cat(hidden_parts, dim=0) if hidden_parts else None

    def _zero_columns(like: Any, columns: int) -> Any:
        """Zeros with ``like``'s row count, device and dtype (plain zeros on TypeError)."""
        try:
            return torch.zeros(
                (like.shape[0], columns),
                device=getattr(like, "device", None),
                dtype=getattr(like, "dtype", None),
            )
        except TypeError:
            return torch.zeros((like.shape[0], columns))

    token_logp = None
    token_mask = None
    if return_token_logp and tok_logp_parts:
        width = max(getattr(part, "shape", [0, 0])[1] for part in tok_logp_parts)
        if width < 0:
            width = 0
        aligned_logps: List[Tensor] = []
        aligned_masks: List[Tensor] = []
        for logp_part, mask_part in zip(tok_logp_parts, tok_mask_parts):
            part_width = getattr(logp_part, "shape", [0, 0])[1]
            missing = width - part_width
            if part_width == width:
                aligned_logps.append(logp_part)
                aligned_masks.append(mask_part)
            elif missing <= 0:
                aligned_logps.append(logp_part[:, :width])
                aligned_masks.append(mask_part[:, :width])
            else:
                # Right-pad narrower slices so all slices share one width.
                logp_fill = _zero_columns(logp_part, missing)
                mask_fill = _zero_columns(mask_part, missing)
                aligned_logps.append(torch.cat([logp_part, logp_fill], dim=1))
                aligned_masks.append(torch.cat([mask_part, mask_fill], dim=1))
        token_logp = torch.cat(aligned_logps, dim=0) if aligned_logps else None
        token_mask = torch.cat(aligned_masks, dim=0) if aligned_masks else None

    if not return_entropy:
        output = (
            torch.cat(logp_parts, dim=0),
            pooled_hidden,
            token_logp,
            token_mask,
        )
        if log_progress:
            LOG.info(
                "score_model_outputs done | slices=%d logp_shape=%s pooled=%s token_logp=%s",
                len(logp_parts),
                getattr(output[0], "shape", None),
                getattr(pooled_hidden, "shape", None),
                getattr(token_logp, "shape", None),
            )
        return output if return_token_logp else output[:2]

    entropy_sum = torch.cat(entropy_parts, dim=0) if entropy_parts else None
    output = (
        torch.cat(logp_parts, dim=0),
        pooled_hidden,
        entropy_sum,
        token_logp,
        token_mask,
    )
    if log_progress:
        LOG.info(
            "score_model_outputs done | slices=%d logp_shape=%s pooled=%s entropy=%s token_logp=%s",
            len(logp_parts),
            getattr(output[0], "shape", None),
            getattr(pooled_hidden, "shape", None),
            getattr(entropy_sum, "shape", None),
            getattr(token_logp, "shape", None),
        )
    return output if return_token_logp else output[:3]
def _as_torch_tensor( torch_mod: _TorchModuleLike, value: object, *, device: Optional[TorchDevice], dtype: Optional[object], ) -> Tensor: """Best-effort conversion of ``value`` into a torch tensor on ``device``.""" ctor = getattr(torch_mod, "as_tensor", getattr(torch_mod, "tensor", None)) if ctor is None: raise RuntimeError("Torch tensor constructor unavailable") if isinstance(value, torch_mod.Tensor): tensor = value else: payload = getattr(value, "arr", None) if payload is None: payload = getattr(value, "data", value) try: tensor = ctor(payload) except _SCORING_EXCEPTIONS: tensor = ctor([]) if dtype is not None: to_fn = getattr(tensor, "to", None) if callable(to_fn): try: tensor = to_fn(dtype=dtype) except _SCORING_EXCEPTIONS: clone_fn = getattr(tensor, "clone", None) if callable(clone_fn): tensor = clone_fn() to_fn = getattr(tensor, "to", None) if callable(to_fn): tensor = to_fn(dtype=dtype) if device is not None and getattr(tensor, "device", None) != device: to_fn = getattr(tensor, "to", None) if callable(to_fn): tensor = to_fn(device=device) return cast(Tensor, tensor) def _match_tensor_length( torch_mod: _TorchModuleLike, tensor: Tensor, target_len: int, *, device: Optional[TorchDevice], dtype: Optional[object], fill_value: float = 0.0, ) -> Tensor: """Return ``tensor`` reshaped/padded to ``target_len`` elements.""" def _full(shape: Tuple[int, ...], value: float) -> Tensor: kwargs = {} if device is not None: kwargs["device"] = device if dtype is not None: kwargs["dtype"] = dtype return torch_mod.full(shape, value, **kwargs) def _zeros(shape: Tuple[int, ...]) -> Tensor: kwargs = {} if device is not None: kwargs["device"] = device if dtype is not None: kwargs["dtype"] = dtype try: return torch_mod.zeros(shape, **kwargs) except TypeError: return torch_mod.zeros(shape) tensor = tensor.view(-1) cur_len = int(getattr(tensor, "numel", lambda: 0)()) if target_len <= 0: return _zeros((0,)) if cur_len == target_len: return tensor if cur_len == 0: return 
_full((target_len,), fill_value) if cur_len == 1: scalar_val = float(getattr(tensor[0], "item", lambda: tensor[0])()) return _full((target_len,), scalar_val) min_len = min(cur_len, target_len) tensor = tensor[:min_len] if min_len == target_len: return tensor pad_val = float(getattr(tensor[-1], "item", lambda: tensor[-1])()) pad = _full((target_len - min_len,), pad_val) return torch_mod.cat([tensor, pad], dim=0)
def build_sequence_scores(
    cur_logp_sum: Tensor,
    ref_stats: ReferenceLogprobs,
    pooled_hidden: Optional[Tensor] = None,
    *,
    behavior_logp_sum: Optional[Tensor] = None,
    policy_entropy_sum: Optional[Tensor] = None,
    token_logp: Optional[Tensor] = None,
    token_mask: Optional[Tensor] = None,
    old_token_logp: Optional[Tensor] = None,
) -> SequenceScores:
    """Assemble ``SequenceScores`` from current and reference log-probs.

    Every per-sequence input is converted to a float32 tensor on the device of
    ``cur_logp_sum`` and resized to its flattened length.

    :param cur_logp_sum: Current policy log-prob sums per sequence.
    :param ref_stats: Reference log-prob stats; ``ref_logp_sum_raw`` is
        preferred over ``ref_logp_sum``, and ``ref_tok_counts`` (clamped to at
        least 1) supplies the token denominators.
    :param pooled_hidden: Optional pooled hidden states, passed through as-is.
    :param behavior_logp_sum: Optional behavior-policy log-probs; defaults to a
        detached copy of the current log-prob sums.
    :param policy_entropy_sum: Optional per-sequence entropy sums.
    :param token_logp: Optional per-token log-probs (converted, not resized).
    :param token_mask: Optional per-token mask (converted, not resized).
    :param old_token_logp: Optional earlier per-token log-probs; when omitted
        while ``token_logp`` is given, a detached ``token_logp`` is used.
    :returns: ``SequenceScores`` holding the converted tensors and
        ``log_ratio_train = cur_logp_sum - reference``.
    :rtype: SequenceScores
    """
    torch_mod = _refresh_torch()
    base_dtype = getattr(torch_mod, "float32", None)

    cur_tensor = _as_torch_tensor(
        torch_mod,
        cur_logp_sum,
        device=getattr(cur_logp_sum, "device", None),
        dtype=base_dtype,
    ).view(-1)
    device = getattr(cur_tensor, "device", None)
    cur_len = int(getattr(cur_tensor, "numel", lambda: 0)())

    def _convert(source: object) -> Tensor:
        """Convert ``source`` to the shared device and dtype."""
        return _as_torch_tensor(torch_mod, source, device=device, dtype=base_dtype)

    def _convert_and_fit(source: object, fill: float) -> Tensor:
        """Convert ``source`` and resize it to one entry per sequence."""
        return _match_tensor_length(
            torch_mod,
            _convert(source),
            cur_len,
            device=device,
            dtype=base_dtype,
            fill_value=fill,
        )

    ref_source = getattr(ref_stats, "ref_logp_sum_raw", None)
    if ref_source is None:
        ref_source = getattr(ref_stats, "ref_logp_sum", None)
    ref_tensor = _convert_and_fit(ref_source if ref_source is not None else [], 0.0)

    denom_source = getattr(ref_stats, "ref_tok_counts", None)
    denom_tensor = _convert_and_fit(
        denom_source if denom_source is not None else [], 1.0
    ).clamp(min=1.0)

    behavior_tensor = (
        cur_tensor.detach()
        if behavior_logp_sum is None
        else _convert_and_fit(behavior_logp_sum, 0.0)
    )

    policy_entropy_tensor: Optional[Tensor] = (
        None if policy_entropy_sum is None else _convert_and_fit(policy_entropy_sum, 0.0)
    )

    token_logp_tensor: Optional[Tensor] = (
        None if token_logp is None else _convert(token_logp)
    )
    token_mask_tensor: Optional[Tensor] = (
        None if token_mask is None else _convert(token_mask)
    )
    old_token_logp_tensor: Optional[Tensor] = (
        None if old_token_logp is None else _convert(old_token_logp)
    )
    if old_token_logp_tensor is None and token_logp_tensor is not None:
        try:
            old_token_logp_tensor = token_logp_tensor.detach()
        except _SCORING_EXCEPTIONS:
            old_token_logp_tensor = _convert(token_logp_tensor)

    log_ratio_train = cur_tensor - ref_tensor
    if getattr(log_ratio_train, "numel", lambda: 0)() == 0 and cur_len > 0:
        # Subtraction produced nothing although sequences exist: use zeros.
        log_ratio_train = torch_mod.zeros((cur_len,), device=device, dtype=base_dtype)

    return SequenceScores(
        cur_logp_sum=cur_tensor,
        behavior_logp_sum=behavior_tensor,
        log_ratio_train=log_ratio_train,
        denom_tok_tensor=denom_tensor,
        pooled_hidden=pooled_hidden,
        policy_entropy_sum=policy_entropy_tensor,
        token_logp=token_logp_tensor,
        token_mask=token_mask_tensor,
        old_token_logp=old_token_logp_tensor,
    )