scores.LogitScore

Base class for logit-based confidence scores.

Usage

scores.LogitScore()

Supports multiclass, binary (single/two-logit), and multilabel tasks. Handles temperature fitting and input normalization for all cases.

Parameters

temperature: float or None = None

Optional temperature to apply to logits. If None, no temperature scaling is applied until fit() fits a temperature from labels.

task: ("multiclass", "binary", "multilabel") = "multiclass"

Type of classification task. Determines score computation and temperature fitting loss.

Notes

Input shapes and label formats by task:

  • multiclass: logits (N, C), labels (N,) long
  • binary single-logit: logits (N,) or (N, 1), labels (N,) float/long
  • binary two-logit: logits (N, 2), labels (N,) long
  • multilabel: logits (N, C), labels (N, C) float
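The shape conventions above can be restated as a small lookup. The sketch below is illustrative only: `expected_label_shape` is a hypothetical helper, not part of the library, and it works on plain shape tuples rather than tensors so that it stays dependency-free.

```python
def expected_label_shape(logits_shape, task):
    """Return the label shape implied by a logits shape for a given task.

    Hypothetical helper for illustration; not part of the library.
    """
    shape = tuple(logits_shape)
    n = shape[0]
    if task == "multiclass":
        # logits (N, C) -> labels (N,)
        if len(shape) != 2:
            raise ValueError("multiclass logits must have shape (N, C)")
        return (n,)
    if task == "binary":
        # single-logit: (N,) or (N, 1); two-logit: (N, 2) -> labels (N,)
        single = len(shape) == 1 or shape[1:] == (1,)
        two = shape[1:] == (2,)
        if not (single or two):
            raise ValueError("binary logits must have shape (N,), (N, 1) or (N, 2)")
        return (n,)
    if task == "multilabel":
        # logits (N, C) -> labels (N, C)
        if len(shape) != 2:
            raise ValueError("multilabel logits must have shape (N, C)")
        return (n, shape[1])
    raise ValueError(f"unknown task: {task!r}")
```

For example, `expected_label_shape((4, 5), "multilabel")` gives `(4, 5)`, while every valid binary layout maps to `(N,)`.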

Examples

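import torch
from seapig.scores.logits import SoftmaxScore

# 4 samples, 3 classes: multiclass logits of shape (N, C)
logits = torch.randn(4, 3)
score = SoftmaxScore()
# One score per sample; lower values indicate higher confidence
score.score(logits)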

Methods

Name      Description
fit()     Fit the score on reference logits.
score()   Compute confidence scores for query logits.
select()  Select samples for prediction based on their confidence score.

fit()

Fit the score on reference logits.

Usage

fit(
    X=None,
    Y=None,
    model=None,
    loader=None,
    outdir=None,
    prefix=None,
    *args,
    **kwargs
)

This method supports two usage modes:

  1. Precomputed logits: Supply logits directly via X, with optional labels via Y for temperature fitting.
  2. On-the-fly extraction: Supply a model with a .logits() method and a DataLoader to extract logits automatically.

Supply either precomputed logits (X) or model together with loader, but not both.
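The either/or rule can be pictured as a small argument check. This is a sketch only: `resolve_fit_mode` and its error messages are hypothetical, and the library's actual validation may differ.

```python
def resolve_fit_mode(X=None, model=None, loader=None):
    """Decide which fit() usage mode a set of arguments corresponds to.

    Hypothetical helper for illustration; not part of the library.
    """
    has_logits = X is not None
    has_extraction_args = model is not None or loader is not None
    if has_logits and has_extraction_args:
        raise ValueError("pass either X or model+loader, not both")
    if has_logits:
        return "precomputed"
    if model is not None and loader is not None:
        return "extract"
    # Covers: nothing passed, or only one of model/loader passed
    raise ValueError("pass X, or both model and loader")
```

Under these assumptions, passing only `model` without `loader` is rejected, as is mixing `X` with either of them.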

Parameters

X: torch.Tensor or None = None

Reference logits. Shape depends on task (see the class Notes above). Required when not using model and loader.

Y: torch.Tensor or None = None

Optional labels for temperature fitting. Shape/type depends on task.

model: torch.nn.Module or None = None

Model with a .logits(x) method. Required when not using precomputed logits.

loader: DataLoader or None = None

DataLoader yielding batches for inference. Required when using model.

outdir: Path or str or None = None

Optional directory to save/load logits. Only used with model and loader.

prefix: str or None = None

Optional prefix for saved files. Only used with model and loader.

Notes

If labels are provided, the temperature is fitted to minimize the negative log-likelihood (NLL) for the task.
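To make the idea of fitting a temperature by NLL concrete, here is a minimal multiclass sketch in plain Python. It assumes a simple grid search over candidate temperatures. The optimizer the library actually uses is not stated here, and `nll` and `fit_temperature` are illustrative names, not library functions.

```python
import math

def nll(logits, labels, temperature):
    """Mean negative log-likelihood of softmax(logits / temperature)."""
    total = 0.0
    for row, y in zip(logits, labels):
        scaled = [z / temperature for z in row]
        m = max(scaled)  # subtract the max for a stable log-sum-exp
        log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
        total += log_norm - scaled[y]
    return total / len(labels)

def fit_temperature(logits, labels, grid=None):
    """Pick the grid temperature with the lowest NLL (assumed search strategy)."""
    if grid is None:
        grid = [0.05 * k for k in range(1, 201)]  # 0.05 .. 10.0
    return min(grid, key=lambda t: nll(logits, labels, t))

# Overconfident toy logits: large margins, but one of four predictions is wrong.
logits = [[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0], [8.0, 0.0, 0.0]]
labels = [0, 1, 2, 1]
best_t = fit_temperature(logits, labels)
```

Because the toy model is overconfident, the fitted temperature comes out above 1, which softens the predicted probabilities and lowers the NLL relative to the unscaled logits.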


score()

Compute confidence scores for query logits.

Usage

score(query_logits)
Parameters

query_logits: torch.Tensor

Logits for samples to score. Shape depends on task.

Returns

torch.Tensor

1-D tensor of shape (M,), where M is the number of query samples. Lower values indicate higher confidence.

select()

Select samples for prediction based on their confidence score.

Usage

select(query_logits)

Samples with scores lower than the threshold are selected for prediction, while samples with scores higher than the threshold are excluded.

Parameters

query_logits: torch.Tensor

Logits for samples to select. Shape depends on task.

Returns

dict[str, torch.Tensor]

A dict with keys 'score' (confidence scores) and 'selected' (boolean mask where True means the sample is selected).
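The thresholding rule and the shape of the returned dict can be sketched as follows. This is an illustration under stated assumptions: plain lists stand in for tensors, `threshold` is a hypothetical explicit argument (how the library obtains its threshold is not described here), and a strict `<` comparison is assumed for scores equal to the threshold.

```python
def select_by_threshold(scores, threshold):
    """Build a select()-style result from precomputed confidence scores.

    Hypothetical helper for illustration; not part of the library.
    """
    # Lower score means higher confidence, so keep scores below the threshold
    selected = [s < threshold for s in scores]
    return {"score": list(scores), "selected": selected}

result = select_by_threshold([0.05, 0.40, 0.90], threshold=0.5)
```

Here the first two samples fall below the threshold and are selected, while the third is excluded.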

See Also