scores.EnergyScore
Energy-based uncertainty score.
Usage
scores.EnergyScore()
Computes the free energy of the logit distribution. Lower energy indicates lower uncertainty. Supports multiclass, binary, and multilabel tasks.
Parameters
temperature: float or None = None
  Optional initial temperature. If None, temperature is fitted if labels are provided to fit.
task: ("multiclass", "binary", "multilabel") = "multiclass"
  Task type for score computation.
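As a rough illustration of what "free energy of the logit distribution" usually means for the multiclass case, the sketch below computes -T * log(sum_k exp(z_k / T)) for a single sample using only the standard library. The helper name free_energy is hypothetical, and the formula is the common energy-score definition; whether the library uses exactly this expression (in particular for the binary and multilabel tasks) is an assumption.

```python
import math


def free_energy(logits, temperature=1.0):
    """Hypothetical helper: -T * log(sum_k exp(z_k / T)) for one sample."""
    scaled = [z / temperature for z in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(scaled)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scaled))
    return -temperature * log_sum_exp


confident = free_energy([10.0, 0.0, 0.0])   # one dominant logit
uncertain = free_energy([0.1, 0.0, -0.1])   # nearly flat logits
print(confident < uncertain)  # True: the confident sample has lower energy
```

A dominant logit drives the log-sum-exp up and the energy down, which matches the convention above that lower energy indicates lower uncertainty.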
Examples
import torch
from seapig.scores.logits import EnergyScore
logits = torch.randn(2, 3)
EnergyScore().score(logits)
Methods
| Name | Description |
|---|---|
| fit() | Fit the score on reference logits. |
| score() | Compute energy for query logits (task-aware). |
| select() | Select samples for prediction based on their uncertainty score. |
| set_threshold() | Set a threshold based on a specific quantile on the available scores. |
| get_threshold() | Get the current threshold value. |
| plot() | Plot densities for uncertainty scores. |
fit()
Fit the score on reference logits.
Usage
fit(
X=None,
Y=None,
model=None,
loader=None,
outdir=None,
prefix=None,
*args,
**kwargs
)
This method supports two usage modes:
- Precomputed logits: Supply logits directly via X, with optional labels via Y for temperature fitting.
- On-the-fly extraction: Supply a model with a .logits() method and a DataLoader to extract logits automatically.
You must use either logits OR model+loader, but not both.
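The either/or rule can be pictured as a small argument check. The function below is a hypothetical stand-in, not the library's code: its name, the returned mode strings, and the choice of ValueError are all assumptions made for illustration.

```python
def resolve_fit_mode(X=None, model=None, loader=None):
    """Hypothetical check mirroring the two usage modes of fit()."""
    has_logits = X is not None
    has_model_or_loader = model is not None or loader is not None
    if has_logits and has_model_or_loader:
        # Assumed error type; the source does not say what fit() raises here.
        raise ValueError("Pass either X or model+loader, not both.")
    if has_logits:
        return "precomputed"
    if model is not None and loader is not None:
        return "on_the_fly"
    raise ValueError("Pass X, or both model and loader.")


print(resolve_fit_mode(X=[[2.0, -1.0]]))  # precomputed
```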
Parameters
X: torch.Tensor or None = None
  Reference logits. Shape depends on task (see class docstring). Required when not using model and loader.
Y: torch.Tensor or None = None
  Optional labels for temperature fitting. Shape/type depends on task.
model: torch.nn.Module or None = None
  Model with a .logits(x) method. Required when not using precomputed logits.
loader: DataLoader or None = None
  DataLoader yielding batches for inference. Required when using model.
outdir: Path or str or None = None
  Optional directory to save/load logits. Only used with model and loader.
prefix: str or None = None
  Optional prefix for saved files. Only used with model and loader.
Notes
If labels are provided, temperature is fitted to minimize NLL for the task.
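To make the note concrete, here is a minimal sketch of temperature fitting by NLL minimization for the multiclass case. It swaps in a plain grid search over candidate temperatures instead of whatever optimizer the library actually uses, and the helper names (log_softmax, mean_nll, fit_temperature) and the toy data are hypothetical.

```python
import math


def log_softmax(logits, temperature):
    """Log-probabilities of softmax(logits / temperature)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - lse for s in scaled]


def mean_nll(all_logits, labels, temperature):
    """Average negative log-likelihood of the true labels."""
    total = 0.0
    for logits, y in zip(all_logits, labels):
        total -= log_softmax(logits, temperature)[y]
    return total / len(labels)


def fit_temperature(all_logits, labels, candidates):
    """Grid search (an assumption, not the library's optimizer)."""
    return min(candidates, key=lambda t: mean_nll(all_logits, labels, t))


# Toy reference set: very confident logits, but one of four predictions is wrong.
reference_logits = [[8.0, 0.0], [8.0, 0.0], [0.0, 8.0], [8.0, 0.0]]
reference_labels = [0, 0, 1, 1]
grid = [0.25 * k for k in range(1, 41)]  # 0.25, 0.5, ..., 10.0
best_t = fit_temperature(reference_logits, reference_labels, grid)
print(best_t > 1.0)  # True: overconfident logits are softened
```

Because the toy model is overconfident, the NLL is minimized at a temperature above 1, which flattens the predicted probabilities.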
score()
Compute energy for query logits (task-aware).
Usage
score(query_logits)
Parameters
query_logits: torch.Tensor
  Logits for samples to score. Shape depends on task.
Returns
torch.Tensor
  1-D tensor of shape (M,). Lower values indicate lower uncertainty.
select()
Select samples for prediction based on their uncertainty score.
Usage
select(query_logits)
Samples with scores lower than the threshold are selected for prediction, while samples with scores higher than the threshold are excluded.
Parameters
query_logits: torch.Tensor
  Logits for samples to select. Shape depends on task.
Returns
dict[str, torch.Tensor]
  A dict with keys 'score' (uncertainty scores) and 'selected' (boolean mask where True means the sample is selected).
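The selection rule and the shape of the returned dict can be sketched with plain Python lists standing in for tensors. The function name select_by_threshold is hypothetical, and treating a score exactly equal to the threshold as selected is an assumption, since the description only covers strictly lower and strictly higher scores.

```python
def select_by_threshold(scores, threshold):
    """Hypothetical stand-in for select(), using lists instead of tensors."""
    # Assumption: a score equal to the threshold counts as selected.
    selected = [s <= threshold for s in scores]
    return {"score": list(scores), "selected": selected}


result = select_by_threshold([-9.2, -1.1, -4.0], threshold=-3.0)
print(result["selected"])  # [True, False, True]
```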
set_threshold()
Set a threshold based on a specific quantile on the available scores.
Usage
set_threshold(q=0.99)
Samples with scores higher than this threshold are excluded from prediction.
Parameters
q: float = 0.99
  Quantile in the interval (0, 1) used to compute the threshold from the stored calibration scores. Defaults to 0.99.
Raises
ValueError
  If no calibration scores are available yet.
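A quantile-based threshold of this kind could be computed as in the sketch below. The function name quantile_threshold is hypothetical, linear interpolation between neighbouring order statistics is an assumed convention, and the extra check on q is an illustration rather than documented behaviour.

```python
import math


def quantile_threshold(calibration_scores, q=0.99):
    """Hypothetical helper: q-quantile of stored calibration scores."""
    if not calibration_scores:
        raise ValueError("No calibration scores available yet.")
    if not 0.0 < q < 1.0:
        # Assumed validation; the source only documents the interval (0, 1).
        raise ValueError("q must lie in the open interval (0, 1).")
    ordered = sorted(calibration_scores)
    # Fractional index into the sorted scores, then linear interpolation.
    position = q * (len(ordered) - 1)
    lower = int(math.floor(position))
    upper = min(lower + 1, len(ordered) - 1)
    weight = position - lower
    return ordered[lower] * (1.0 - weight) + ordered[upper] * weight


threshold = quantile_threshold([1.0, 2.0, 3.0, 4.0], q=0.5)
print(threshold)  # 2.5
```

With q=0.99, roughly the top 1% of calibration scores lie above the threshold, so comparably uncertain query samples would be excluded from prediction.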
get_threshold()
Get the current threshold value.
Usage
get_threshold()
plot()
Plot densities for uncertainty scores.
Usage
plot(query_scores=None, bins=100)
By default, this method plots densities for the uncertainty scores. Optionally, it can also plot densities for query_scores.
Parameters
query_scores: torch.Tensor | None = None
  A torch.Tensor representing query scores to include in the plot. Defaults to None.
bins: int = 100
  An int indicating the number of bins to use for density estimation. Defaults to 100.