scores.EntropyScore
Entropy-based uncertainty score.
Usage
scores.EntropyScore()
Computes the predictive entropy of the output distribution. Lower entropy indicates lower uncertainty. Supports multiclass, binary, and multilabel tasks.
Parameters
temperature: float or None = None
Optional initial temperature. If `None`, the temperature is fitted when labels are provided to `fit()`.
task: ("multiclass", "binary", "multilabel") = "multiclass"
Task type for score computation.
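As a rough illustration of what predictive entropy means here, the sketch below computes it from raw logits for the multiclass case (softmax) and the binary case (sigmoid), with an optional temperature that divides the logits. It is written in plain Python rather than torch so that it runs without dependencies. The function names are hypothetical and not part of the library, and how the library aggregates entropy across labels in the multilabel case is not stated above, so that case is left out.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of one row of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def multiclass_entropy(logits, temperature=1.0):
    """Entropy (in nats) of the softmax distribution for one sample."""
    probs = softmax(logits, temperature)
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def binary_entropy(logit, temperature=1.0):
    """Entropy (in nats) of the Bernoulli distribution given by a sigmoid."""
    p = 1.0 / (1.0 + math.exp(-logit / temperature))
    return -sum(q * math.log(q) for q in (p, 1.0 - p) if q > 0.0)

# Equal logits give the maximum entropy for three classes, log(3).
uniform = multiclass_entropy([0.0, 0.0, 0.0])
# One dominant logit gives an entropy close to zero.
confident = multiclass_entropy([10.0, 0.0, 0.0])
# A temperature above 1 flattens the distribution and raises the entropy.
softened = multiclass_entropy([10.0, 0.0, 0.0], temperature=5.0)
```

A higher temperature always moves the distribution toward uniform, which is why a fitted temperature changes the scores without changing the predicted class.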
Examples
import torch
from seapig.scores.logits import EntropyScore
logits = torch.randn(2, 3)
EntropyScore().score(logits)
Methods
| Name | Description |
|---|---|
| fit() | Fit the score on reference logits. |
| score() | Compute predictive entropy for each sample (task-aware). |
| select() | Select samples for prediction based on their uncertainty score. |
| set_threshold() | Set a threshold based on a specific quantile on the available scores. |
| get_threshold() | Get the current threshold value. |
| plot() | Plot densities for uncertainty scores. |
fit()
Fit the score on reference logits.
Usage
fit(
X=None,
Y=None,
model=None,
loader=None,
outdir=None,
prefix=None,
*args,
**kwargs
)
This method supports two usage modes:
- Precomputed logits: Supply logits directly via `X`, with optional labels via `Y` for temperature fitting.
- On-the-fly extraction: Supply a `model` with a `.logits()` method and a `DataLoader` to extract logits automatically.
You must use either logits OR model+loader, but not both.
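The either/or rule can be pictured as a small argument check. The helper below is hypothetical and only illustrates the rule as described; the library's actual validation logic and the exception type it raises are assumptions, not documented behavior.

```python
def resolve_fit_mode(X=None, model=None, loader=None):
    """Return which fit() mode the arguments describe, or raise ValueError.

    Hypothetical helper for illustration; not part of the library.
    """
    has_logits = X is not None
    has_model_args = model is not None or loader is not None
    if has_logits and has_model_args:
        # Mixing precomputed logits with model/loader is ambiguous.
        raise ValueError("Pass either X or model+loader, not both.")
    if has_logits:
        return "precomputed"
    if model is not None and loader is not None:
        return "on_the_fly"
    # Nothing usable was passed, or only one of model/loader.
    raise ValueError("Pass X, or both model and loader.")
```

Calling `resolve_fit_mode(X=logits)` returns `"precomputed"`, while passing `X` together with a `model` raises an error in this sketch.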
Parameters
X: torch.Tensor or None = None
Reference logits. Shape depends on task (see class docstring). Required when not using `model` and `loader`.
Y: torch.Tensor or None = None
Optional labels for temperature fitting. Shape/type depends on task.
model: torch.nn.Module or None = None
Model with a `.logits(x)` method. Required when not using precomputed logits.
loader: DataLoader or None = None
DataLoader yielding batches for inference. Required when using `model`.
outdir: Path or str or None = None
Optional directory to save/load logits. Only used with `model` and `loader`.
prefix: str or None = None
Optional prefix for saved files. Only used with `model` and `loader`.
Notes
If labels are provided, the temperature is fitted to minimize the negative log-likelihood (NLL) for the task.
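To make the idea of fitting a temperature by minimizing NLL concrete, here is a minimal multiclass sketch in plain Python. It uses a simple grid search, which is an assumption chosen for readability; the optimizer the library actually uses is not described here, and the function names are hypothetical.

```python
import math

def nll(logits_rows, labels, temperature):
    """Mean negative log-likelihood of integer labels under softmax(logits / T)."""
    total = 0.0
    for row, label in zip(logits_rows, labels):
        scaled = [z / temperature for z in row]
        peak = max(scaled)
        # log-sum-exp gives the log of the softmax normalizer in a stable way.
        log_norm = peak + math.log(sum(math.exp(z - peak) for z in scaled))
        total += log_norm - scaled[label]
    return total / len(labels)

def fit_temperature(logits_rows, labels, grid=None):
    """Pick the temperature with the lowest NLL from a grid of candidates."""
    if grid is None:
        grid = [0.05 * k for k in range(1, 201)]  # 0.05, 0.10, ..., 10.0
    return min(grid, key=lambda t: nll(logits_rows, labels, t))

# Overconfident toy logits: class 0 is always favored, but one of four labels is 1.
logits = [[4.0, 0.0], [4.0, 0.0], [4.0, 0.0], [4.0, 0.0]]
labels = [0, 0, 0, 1]
best_t = fit_temperature(logits, labels)
```

In this toy case the model is right three times out of four, so the NLL is lowest when the softened confidence is 0.75, which corresponds to a temperature near 4 / ln(3), about 3.64.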
score()
Compute predictive entropy for each sample (task-aware).
Usage
score(query_logits)
Parameters
query_logits: torch.Tensor
Logits for samples to score. Shape depends on task.
Returns
torch.Tensor
1-D tensor of shape `(M,)`, one score per query sample. Lower values indicate lower uncertainty.
select()
Select samples for prediction based on their uncertainty score.
Usage
select(query_logits)
Samples with scores lower than the threshold are selected for prediction, while samples with scores higher than the threshold are excluded.
Parameters
query_logits: torch.Tensor
Logits for samples to select. Shape depends on task.
Returns
dict[str, torch.Tensor]
A dict with keys `'score'` (uncertainty scores) and `'selected'` (boolean mask where `True` means the sample is selected).
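The selection rule amounts to comparing each score against the threshold. The sketch below mirrors the documented return shape using plain lists instead of tensors. How a score exactly equal to the threshold is handled is not stated above, so treating it as selected is an assumption, and `select_by_threshold` is a hypothetical name.

```python
def select_by_threshold(scores, threshold):
    """Return the scores plus a boolean mask that is True for selected samples."""
    # Assumption: a score exactly equal to the threshold counts as selected.
    selected = [s <= threshold for s in scores]
    return {"score": list(scores), "selected": selected}

# Two low-entropy samples pass; the high-entropy one is excluded.
result = select_by_threshold([0.05, 0.40, 1.05], threshold=0.5)
```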
set_threshold()
Set a threshold based on a specific quantile on the available scores.
Usage
set_threshold(q=0.99)
Samples with scores higher than this threshold are excluded from prediction.
Parameters
q: float = 0.99
Quantile in the interval `(0, 1)` used to compute the threshold from the stored calibration scores. Defaults to `0.99`.
Raises
ValueError
If no calibration scores are available yet.
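A quantile-based threshold can be sketched as follows, in plain Python. Linear interpolation between neighboring sorted scores is an assumption; the interpolation rule the library uses is not documented here, and `quantile_threshold` is a hypothetical name.

```python
def quantile_threshold(calibration_scores, q=0.99):
    """Quantile of the calibration scores, by linear interpolation."""
    if not calibration_scores:
        # Mirrors the documented error when nothing has been fitted yet.
        raise ValueError("No calibration scores available.")
    ordered = sorted(calibration_scores)
    position = q * (len(ordered) - 1)
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    weight = position - lower
    return ordered[lower] * (1.0 - weight) + ordered[upper] * weight

# 101 evenly spaced scores from 0.0 to 10.0.
calibration_scores = [0.1 * k for k in range(101)]
threshold = quantile_threshold(calibration_scores, q=0.99)
```

With `q=0.99`, roughly the 1% of calibration samples with the highest entropy fall above the threshold, so a similar share of in-distribution queries would be expected to be excluded.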
get_threshold()
Get the current threshold value.
Usage
get_threshold()
plot()
Plot densities for uncertainty scores.
Usage
plot(query_scores=None, bins=100)
By default, this method plots densities for the uncertainty scores. Optionally, it can also plot densities for `query_scores`.
Parameters
query_scores: torch.Tensor | None = None
A `torch.Tensor` representing query scores to include in the plot. Defaults to `None`.
bins: int = 100
An `int` indicating the number of bins to use for density estimation. Defaults to `100`.