scores.SoftmaxScore

Maximum softmax probability uncertainty score.

Usage

scores.SoftmaxScore(temperature=None, task="multiclass")

Supports multiclass, binary (single-logit or two-logit), and multilabel tasks. The score is the negated maximum softmax probability, so a higher maximum probability yields a lower score (lower uncertainty).

Parameters

temperature: float or None = None

Optional initial temperature. If None, the temperature is fitted when labels are passed to fit().

task: ("multiclass", "binary", "multilabel") = "multiclass"

Task type for score computation.

Examples

import torch
from seapig.scores.logits import SoftmaxScore
logits = torch.randn(2, 4)
SoftmaxScore().score(logits)

Methods

Name Description
fit() Fit the score on reference logits.
score() Compute task-aware softmax-based uncertainty score.
select() Select samples for prediction based on their uncertainty score.
set_threshold() Set a threshold based on a specific quantile on the available scores.
get_threshold() Get the current threshold value.
plot() Plot densities for uncertainty scores.

fit()

Fit the score on reference logits.

Usage

fit(
    X=None,
    Y=None,
    model=None,
    loader=None,
    outdir=None,
    prefix=None,
    *args,
    **kwargs
)

This method supports two usage modes:

  1. Precomputed logits: Supply logits directly via X, with optional labels via Y for temperature fitting.
  2. On-the-fly extraction: Supply a model with a .logits() method and a DataLoader to extract logits automatically.

Supply either precomputed logits (X) or a model together with a loader, but not both.
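The second mode can be pictured with a short sketch. It assumes that the loader yields either bare input tensors or (input, label) tuples, which the documentation above does not specify; `TinyModel` and `collect_logits` are hypothetical names used only for illustration and are not part of the library.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

class TinyModel(torch.nn.Module):
    """Hypothetical model exposing the .logits(x) method that fit() expects."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def logits(self, x):
        return self.linear(x)

@torch.no_grad()
def collect_logits(model, loader):
    """Run the model over every batch and concatenate the logits (illustration only)."""
    model.eval()
    chunks = []
    for batch in loader:
        # Assumption: a batch is either a tensor or a tuple/list whose first item is the input.
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        chunks.append(model.logits(x))
    return torch.cat(chunks)

torch.manual_seed(0)
features = torch.randn(10, 3)
loader = DataLoader(TensorDataset(features), batch_size=4)
model = TinyModel()
logits = collect_logits(model, loader)
print(logits.shape)  # 10 samples, 4 classes
```

The resulting tensor plays the same role as X in the first mode.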

Parameters

X: torch.Tensor or None = None

Reference logits. Shape depends on task (see class docstring). Required when not using model and loader.

Y: torch.Tensor or None = None

Optional labels for temperature fitting. Shape/type depends on task.

model: torch.nn.Module or None = None

Model with a .logits(x) method. Required when not using precomputed logits.

loader: DataLoader or None = None

DataLoader yielding batches for inference. Required when using model.

outdir: Path or str or None = None

Optional directory to save/load logits. Only used with model and loader.

prefix: str or None = None

Optional prefix for saved files. Only used with model and loader.

Notes

If labels are provided, temperature is fitted to minimize NLL for the task.
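A minimal sketch of temperature fitting for the multiclass case follows. The documentation only states that the NLL is minimized; the optimizer (Adam), the log-temperature parametrization, and the helper name `fit_temperature` are assumptions made for this example and may differ from the library's implementation.

```python
import torch
import torch.nn.functional as F

def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, steps: int = 300) -> float:
    """Hypothetical helper: find T > 0 minimizing the NLL of softmax(logits / T)."""
    log_t = torch.zeros(1, requires_grad=True)  # optimize log T so that T stays positive
    optimizer = torch.optim.Adam([log_t], lr=0.05)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.cross_entropy(logits / log_t.exp(), labels)  # multiclass NLL
        loss.backward()
        optimizer.step()
    return log_t.exp().item()

torch.manual_seed(0)
base = 2.0 * torch.randn(2000, 5)  # logits that generate the labels
labels = torch.multinomial(torch.softmax(base, dim=-1), 1).squeeze(1)
overconfident = 3.0 * base  # same ranking, but inflated confidence
temperature = fit_temperature(overconfident, labels)
print(f"fitted temperature: {temperature:.2f}")
```

Because the logits were inflated before fitting, a temperature above 1 is expected here; dividing by it softens the softmax probabilities and lowers the NLL.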


score()

Compute task-aware softmax-based uncertainty score.

Usage

score(query_logits)

The score depends on the task:

  1. Multiclass: -max softmax probability.
  2. Binary, single-logit: -sigmoid(|logit|).
  3. Binary, two-logit: -max softmax probability.
  4. Multilabel: -min(max(p, 1-p)), with the minimum taken over labels and p = sigmoid(logit).
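As a rough illustration, the four cases can be recomputed in plain PyTorch. The function name `softmax_uncertainty` and the rule used to detect the single-logit case (a 1-D input or a trailing dimension of size 1) are assumptions made for this sketch, not the library's implementation.

```python
import torch

def softmax_uncertainty(logits: torch.Tensor, task: str = "multiclass") -> torch.Tensor:
    """Hypothetical re-implementation of the documented formulas (illustration only)."""
    if task == "multilabel":
        p = torch.sigmoid(logits)              # (M, L) per-label probabilities
        confidence = torch.maximum(p, 1 - p)   # confidence of each label decision
        return -confidence.min(dim=-1).values  # the least confident label sets the score
    if task == "binary" and (logits.ndim == 1 or logits.shape[-1] == 1):
        # Single-logit case: sigmoid(|z|) equals max(p, 1 - p) for p = sigmoid(z).
        return -torch.sigmoid(logits.abs()).reshape(-1)
    # Multiclass and two-logit binary: negated maximum softmax probability.
    return -torch.softmax(logits, dim=-1).max(dim=-1).values

uniform = torch.zeros(2, 4)  # maximally uncertain 4-class logits
confident = torch.tensor([[10.0, 0.0, 0.0, 0.0]])
print(softmax_uncertainty(uniform))    # each entry equals -1/4
print(softmax_uncertainty(confident))  # close to -1, i.e. low uncertainty
```

In every case the score lies in [-1, 0), and values near -1 correspond to confident predictions.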

Parameters

query_logits: torch.Tensor

Logits for samples to score. Shape depends on task.

Returns

torch.Tensor

1-D tensor of shape (M,), where M is the number of query samples. Lower values indicate lower uncertainty.

select()

Select samples for prediction based on their uncertainty score.

Usage

select(query_logits)

Samples with scores lower than the threshold are selected for prediction, while samples with scores higher than the threshold are excluded.

Parameters

query_logits: torch.Tensor

Logits for samples to select. Shape depends on task.

Returns

dict[str, torch.Tensor]

A dict with keys 'score' (uncertainty scores) and 'selected' (boolean mask where True means the sample is selected).
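The selection rule can be sketched in a few lines. The helper name `select_by_threshold` is hypothetical, and the handling of scores exactly equal to the threshold (treated as selected here) is an assumption, since the description above only covers strictly lower and strictly higher scores.

```python
import torch

def select_by_threshold(scores: torch.Tensor, threshold: float) -> dict:
    """Hypothetical helper mirroring the documented return structure of select()."""
    # Assumption: ties (score == threshold) count as selected.
    return {"score": scores, "selected": scores <= threshold}

scores = torch.tensor([-0.99, -0.60, -0.30])  # negated max-softmax probabilities
result = select_by_threshold(scores, threshold=-0.50)
print(result["selected"])  # the two most confident samples are kept
```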

set_threshold()

Set a threshold based on a specific quantile on the available scores.

Usage

set_threshold(q=0.99)

Samples with scores higher than this threshold are excluded from prediction.

Parameters

q: float = 0.99

Quantile in the interval (0, 1) used to compute the threshold from the stored calibration scores. Defaults to 0.99.

Raises

ValueError

If no calibration scores are available yet.
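One way to picture the threshold computation is shown below. It assumes that the threshold is the q-th quantile of the calibration scores as computed by `torch.quantile` (linear interpolation); the helper name `quantile_threshold` is hypothetical, and the library may use a different interpolation rule.

```python
import torch

def quantile_threshold(calibration_scores, q: float = 0.99) -> float:
    """Hypothetical helper: q-th quantile of the calibration scores (illustration only)."""
    if calibration_scores is None or calibration_scores.numel() == 0:
        raise ValueError("No calibration scores available; fit the score first.")
    return torch.quantile(calibration_scores, q).item()

calibration = torch.arange(101, dtype=torch.float32)  # toy scores 0, 1, ..., 100
threshold = quantile_threshold(calibration, q=0.99)
print(threshold)
```

With q=0.99, roughly 1% of the calibration samples score above the threshold and would be excluded from prediction.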

get_threshold()

Get the current threshold value.

Usage

get_threshold()

plot()

Plot densities for uncertainty scores.

Usage

plot(query_scores=None, bins=100)

By default, this method plots densities for the uncertainty scores. Optionally, it can also plot densities for query_scores.

Parameters

query_scores: torch.Tensor | None = None

Query scores to include in the plot. Defaults to None.

bins: int = 100

Number of bins used for density estimation. Defaults to 100.
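The density estimate behind such a plot can be approximated with a normalized histogram. The sketch below is an assumption about how the densities might be computed, not the library's plotting code: `density_histograms` is a hypothetical helper, and it only returns the densities and bin edges rather than drawing anything.

```python
import torch

def density_histograms(reference_scores, query_scores=None, bins: int = 100) -> dict:
    """Hypothetical helper: histogram densities on bin edges shared by both score sets."""
    if query_scores is None:
        all_scores = reference_scores
    else:
        all_scores = torch.cat([reference_scores, query_scores])
    lo, hi = all_scores.min().item(), all_scores.max().item()
    ref_density, edges = torch.histogram(reference_scores, bins=bins, range=(lo, hi), density=True)
    out = {"edges": edges, "reference": ref_density}
    if query_scores is not None:
        out["query"], _ = torch.histogram(query_scores, bins=bins, range=(lo, hi), density=True)
    return out

torch.manual_seed(0)
reference = -0.5 - 0.5 * torch.rand(1000)  # toy scores in [-1, -0.5]
query = -0.25 - 0.75 * torch.rand(200)     # toy scores in [-1, -0.25]
densities = density_histograms(reference, query, bins=20)
print(densities["edges"].shape)  # bins + 1 edges
```

Sharing the bin edges keeps the reference and query densities directly comparable when they are overlaid.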