scores.EnergyScore

Energy-based uncertainty score.

Usage

scores.EnergyScore()

Computes the free energy of the logit distribution. Lower energy indicates lower uncertainty. Supports multiclass, binary, and multilabel tasks.
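As a rough illustration of what "free energy of the logit distribution" usually means in the multiclass case, the sketch below computes E = -T * log(sum(exp(z / T))) for a single logit vector in plain Python. The helper name free_energy is hypothetical, and this page does not confirm that the library uses exactly this formula, nor how the binary and multilabel variants are defined.

```python
import math

def free_energy(logits, temperature=1.0):
    """Free energy of one logit vector: -T * log(sum(exp(z / T))).

    Assumption: the standard multiclass energy formula; the library's
    own implementation may differ.
    """
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max before exponentiating for stability
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scaled))
    return -temperature * log_sum_exp

confident = free_energy([10.0, 0.0, 0.0])  # one dominant logit
uncertain = free_energy([0.0, 0.0, 0.0])   # flat logits
```

Under this formula the vector with one large logit gets the lower energy, which matches the statement that lower energy indicates lower uncertainty.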

Parameters

temperature: float or None = None

Optional initial temperature. If None, the temperature is fitted when labels are passed to fit().

task: ("multiclass", "binary", "multilabel") = "multiclass"

Task type for score computation.

Examples

import torch
from seapig.scores.logits import EnergyScore

logits = torch.randn(2, 3)   # 2 samples, 3 classes
EnergyScore().score(logits)  # 1-D tensor with one energy value per sample

Methods

Name             Description
fit()            Fit the score on reference logits.
score()          Compute energy for query logits (task-aware).
select()         Select samples for prediction based on their uncertainty score.
set_threshold()  Set a threshold based on a specific quantile on the available scores.
get_threshold()  Get the current threshold value.
plot()           Plot densities for uncertainty scores.

fit()

Fit the score on reference logits.

Usage

fit(
    X=None,
    Y=None,
    model=None,
    loader=None,
    outdir=None,
    prefix=None,
    *args,
    **kwargs
)

This method supports two usage modes:

  1. Precomputed logits: Supply logits directly via X, with optional labels via Y for temperature fitting.
  2. On-the-fly extraction: Supply a model with a .logits() method and a DataLoader to extract logits automatically.

Supply either precomputed logits (X) or a model together with a loader, but not both.
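The mutual exclusion between the two modes can be pictured with a small argument check. This is only a sketch: resolve_fit_mode is a hypothetical helper, and the exception type and messages the library actually uses are not documented here.

```python
def resolve_fit_mode(X=None, model=None, loader=None):
    """Return which fit() mode the arguments select, or raise ValueError.

    Hypothetical helper for illustration; not part of the documented API.
    """
    has_logits = X is not None
    has_extraction = model is not None or loader is not None
    if has_logits and has_extraction:
        raise ValueError("Pass either X or model and loader, not both.")
    if has_logits:
        return "precomputed"
    if model is not None and loader is not None:
        return "on-the-fly"
    # Reached when nothing was passed, or only one of model/loader was.
    raise ValueError("Pass X, or pass both model and loader.")
```

For instance, resolve_fit_mode(X=logits) selects the precomputed mode, while passing only model without loader is rejected.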

Parameters

X: torch.Tensor or None = None

Reference logits. Shape depends on task (see class docstring). Required when not using model and loader.

Y: torch.Tensor or None = None

Optional labels for temperature fitting. Shape/type depends on task.

model: torch.nn.Module or None = None

Model with a .logits(x) method. Required when not using precomputed logits.

loader: DataLoader or None = None

DataLoader yielding batches for inference. Required when using model.

outdir: Path or str or None = None

Optional directory to save/load logits. Only used with model and loader.

prefix: str or None = None

Optional prefix for saved files. Only used with model and loader.

Notes

If labels are provided, the temperature is fitted by minimizing the negative log-likelihood (NLL) for the task.
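To make the idea of fitting a temperature by minimizing NLL concrete, here is a minimal multiclass sketch in plain Python. It uses a simple grid search, which is a stand-in chosen for readability; the optimizer the library uses is not stated on this page, and the function names nll and fit_temperature are hypothetical.

```python
import math

def nll(logits, labels, temperature):
    """Mean negative log-likelihood of softmax(logits / T) for integer labels."""
    total = 0.0
    for row, y in zip(logits, labels):
        scaled = [z / temperature for z in row]
        m = max(scaled)  # stabilize the log-sum-exp
        log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
        total += log_norm - scaled[y]
    return total / len(labels)

def fit_temperature(logits, labels, grid=None):
    """Pick the grid temperature with the lowest NLL (grid search, an assumption)."""
    grid = grid or [0.1 * k for k in range(1, 101)]  # 0.1 .. 10.0
    return min(grid, key=lambda t: nll(logits, labels, t))

# Overconfident toy logits: large margins, but one of four predictions is wrong.
logits = [[8.0, 0.0], [8.0, 0.0], [0.0, 8.0], [8.0, 0.0]]
labels = [0, 0, 1, 1]
best_t = fit_temperature(logits, labels)
```

On this toy data the fitted temperature comes out well above 1, because softening the overconfident logits lowers the penalty for the one wrong prediction.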


score()

Compute energy for query logits (task-aware).

Usage

score(query_logits)
Parameters

query_logits: torch.Tensor

Logits for samples to score. Shape depends on task.

Returns

torch.Tensor

1-D tensor of shape (M,), where M is the number of query samples. Lower values indicate lower uncertainty.

select()

Select samples for prediction based on their uncertainty score.

Usage

select(query_logits)

Samples with scores lower than the threshold are selected for prediction, while samples with scores higher than the threshold are excluded.

Parameters

query_logits: torch.Tensor

Logits for samples to select. Shape depends on task.

Returns

dict[str, torch.Tensor]

A dict with keys 'score' (uncertainty scores) and 'selected' (boolean mask where True means the sample is selected).
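The selection rule can be sketched with plain lists standing in for tensors. The helper select_by_threshold is hypothetical, and treating a score exactly equal to the threshold as selected is an assumption, since the description only covers scores strictly lower or higher than the threshold.

```python
def select_by_threshold(scores, threshold):
    """Return the scores plus a mask that is True where score <= threshold.

    Assumption: ties with the threshold count as selected.
    """
    selected = [s <= threshold for s in scores]
    return {"score": list(scores), "selected": selected}

result = select_by_threshold([-9.2, -1.1, -5.0], threshold=-4.0)
# result["selected"] is [True, False, True]
```

The returned dict mirrors the documented keys: 'score' holds the raw uncertainty scores and 'selected' holds the boolean mask.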

set_threshold()

Set a threshold based on a specific quantile on the available scores.

Usage

set_threshold(q=0.99)

Samples with scores higher than this threshold are excluded from prediction.

Parameters

q: float = 0.99

Quantile in the open interval (0, 1) used to compute the threshold from the stored calibration scores.

Raises

ValueError

If no calibration scores are available yet.
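A quantile-based threshold can be illustrated as follows. The sketch uses linear interpolation between order statistics, which is an assumption (the interpolation rule used by the library is not documented here), and quantile_threshold is a hypothetical name.

```python
def quantile_threshold(scores, q=0.99):
    """Return the q-quantile of scores using linear interpolation.

    Raises ValueError when there are no scores or q is outside (0, 1).
    """
    if not scores:
        raise ValueError("No calibration scores are available yet.")
    if not 0.0 < q < 1.0:
        raise ValueError("q must lie in the open interval (0, 1).")
    ordered = sorted(scores)
    pos = q * (len(ordered) - 1)         # fractional index into the sorted scores
    lo = int(pos)                        # index of the lower neighbor
    hi = min(lo + 1, len(ordered) - 1)   # index of the upper neighbor
    frac = pos - lo
    return ordered[lo] + frac * (ordered[hi] - ordered[lo])

calibration = [float(i) for i in range(101)]  # 0.0, 1.0, ..., 100.0
threshold = quantile_threshold(calibration, q=0.99)
```

With q=0.99, roughly the top 1% of calibration scores lie above the threshold, so about 1% of similar samples would be excluded from prediction.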

get_threshold()

Get the current threshold value.

Usage

get_threshold()

plot()

Plot densities for uncertainty scores.

Usage

plot(query_scores=None, bins=100)

By default, this method plots densities for the uncertainty scores. Optionally, it can also plot densities for query_scores.

Parameters

query_scores: torch.Tensor | None = None

Query scores to include in the plot. If None, only the default uncertainty scores are plotted.

bins: int = 100

Number of bins to use for density estimation.