scores.LogitScore
Base class for logit-based uncertainty scores.
Usage
scores.LogitScore(temperature=None, task="multiclass")
Supports multiclass, binary (single-logit or two-logit), and multilabel tasks, and handles temperature fitting and input normalization in all cases.
Parameters
temperature: float or None = None
Optional temperature to apply to logits. If None, no temperature scaling is applied until a temperature is fitted by calling fit() with labels.
task: ("multiclass", "binary", "multilabel") = "multiclass"
Type of classification task. Determines how scores are computed and which loss is used for temperature fitting.
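Temperature scaling divides the logits by a positive scalar T before they are turned into probabilities: T > 1 flattens the distribution and T < 1 sharpens it. The sketch below illustrates this for the multiclass case in plain Python. The helper names are hypothetical, and treating temperature=None as "leave logits unchanged" follows the parameter description above; this is not the library's implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_softmax(logits, temperature=None):
    """Hypothetical helper: softmax(logits / T); temperature=None means no scaling."""
    if temperature is None:
        return softmax(logits)
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return softmax([z / temperature for z in logits])


logits = [2.0, 1.0, 0.1]
print(max(scaled_softmax(logits)))       # confidence without scaling
print(max(scaled_softmax(logits, 2.0)))  # T > 1 lowers the top probability
```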
Notes
Input shapes and label formats by task:
| Task | Logits shape | Labels shape | Labels dtype |
|---|---|---|---|
| multiclass | (N, C) | (N,) | long |
| binary, single-logit | (N,) or (N, 1) | (N,) | float or long |
| binary, two-logit | (N, 2) | (N,) | long |
| multilabel | (N, C) | (N, C) | float |
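The two binary formats can carry the same information. Under the common log-odds convention (an assumption here; this page does not say how the library interprets a single logit), a single logit z corresponds to the two-logit pair (0, z), because softmax([0, z]) assigns the positive class the probability sigmoid(z). The hypothetical helper below sketches such a normalization on plain Python lists, accepting both (N,) and (N, 1) inputs.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax_row(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def to_two_logit(single_logits):
    """Hypothetical normalization: single-logit binary input -> two-logit rows.

    Accepts shape (N,) (flat list) or (N, 1) (list of 1-element rows) and
    returns rows [0.0, z], assuming z is the log-odds of the positive class.
    """
    flat = [row[0] if isinstance(row, (list, tuple)) else row for row in single_logits]
    return [[0.0, z] for z in flat]


single = [1.5, -0.3]      # shape (N,)
column = [[1.5], [-0.3]]  # shape (N, 1)
assert to_two_logit(single) == to_two_logit(column)
for z, row in zip(single, to_two_logit(single)):
    # softmax([0, z])[1] equals sigmoid(z) up to floating-point error
    print(z, sigmoid(z), softmax_row(row)[1])
```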
Examples
import torch
from seapig.scores.logits import SoftmaxScore
logits = torch.randn(4, 3)  # 4 samples, 3 classes (multiclass task)
score = SoftmaxScore()
score.score(logits)  # one uncertainty score per sample, shape (4,)
Methods
| Name | Description |
|---|---|
| fit() | Fit the score on reference logits. |
| score() | Compute uncertainty scores for query logits. |
| select() | Select samples for prediction based on their uncertainty score. |
fit()
Fit the score on reference logits.
Usage
fit(
X=None,
Y=None,
model=None,
loader=None,
outdir=None,
prefix=None,
*args,
**kwargs
)
This method supports two usage modes:
- Precomputed logits: supply logits directly via X, with optional labels via Y for temperature fitting.
- On-the-fly extraction: supply a model with a .logits() method and a DataLoader (loader) to extract logits automatically.
Use either precomputed logits or model and loader, but not both.
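The "logits OR model and loader" rule and the on-the-fly extraction loop can be sketched as follows. This is a hypothetical helper, not the seapig implementation: plain lists stand in for tensors, batches are assumed to be (inputs, labels) pairs, and model.logits(inputs) is assumed to return one row per input. With real tensors one would typically run the loop under torch.no_grad() and join the per-batch outputs with torch.cat.

```python
def resolve_fit_inputs(X=None, Y=None, model=None, loader=None):
    """Hypothetical helper: enforce 'logits OR model + loader', return (logits, labels).

    Assumptions (not from the seapig API): loader yields (inputs, labels) pairs
    and model.logits(inputs) returns one logit row per input.
    """
    using_logits = X is not None
    using_model = model is not None or loader is not None
    if using_logits and using_model:
        raise ValueError("Pass either X (precomputed logits) or model + loader, not both.")
    if not using_logits and not using_model:
        raise ValueError("Provide X, or both model and loader.")
    if using_logits:
        return X, Y  # precomputed mode: Y may be None (no temperature fitting)
    if model is None or loader is None:
        raise ValueError("On-the-fly extraction needs both model and loader.")
    logits, labels = [], []
    for inputs, targets in loader:  # one batch at a time
        logits.extend(model.logits(inputs))
        labels.extend(targets)
    return logits, labels


class ToyModel:
    """Stand-in for a model exposing a .logits(x) method."""

    def logits(self, batch):
        return [[x, -x] for x in batch]


# A list of (inputs, labels) batches stands in for a DataLoader.
toy_loader = [([0.5, 1.0], [0, 1]), ([2.0], [1])]
logits, labels = resolve_fit_inputs(model=ToyModel(), loader=toy_loader)
print(logits, labels)
```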
Parameters
X: torch.Tensor or None = None
Reference logits. Shape depends on task (see the Notes for LogitScore). Required when not using model and loader.
Y: torch.Tensor or None = None
Optional labels for temperature fitting. Shape and dtype depend on task.
model: torch.nn.Module or None = None
Model with a .logits(x) method. Required when not using precomputed logits.
loader: DataLoader or None = None
DataLoader yielding batches for inference. Required when using model.
outdir: Path or str or None = None
Optional directory to save or load logits. Only used with model and loader.
prefix: str or None = None
Optional prefix for saved files. Only used with model and loader.
Notes
If labels are provided, the temperature is fitted by minimizing the negative log-likelihood (NLL) appropriate to the task.
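To illustrate what "fit the temperature to minimize NLL" means in the multiclass case, the sketch below evaluates the mean NLL of softmax(logits / T) over a grid of temperatures and keeps the best one. Grid search is a deliberately simple stand-in chosen for readability; this page does not say which optimizer the library uses, and the helper names are hypothetical. Binary and multilabel tasks would use a sigmoid-based NLL instead.

```python
import math


def log_softmax(row):
    m = max(row)
    log_norm = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_norm for v in row]


def multiclass_nll(logits, labels, temperature):
    """Mean negative log-likelihood of integer labels under softmax(logits / T)."""
    total = 0.0
    for row, y in zip(logits, labels):
        total -= log_softmax([v / temperature for v in row])[y]
    return total / len(labels)


def fit_temperature(logits, labels, grid=None):
    """Hypothetical helper: pick the grid temperature with the lowest multiclass NLL."""
    if grid is None:
        grid = [0.05 * k for k in range(1, 201)]  # candidate T values in (0, 10]
    return min(grid, key=lambda t: multiclass_nll(logits, labels, t))


# Overconfident toy logits: large margins, but the last prediction is wrong.
ref_logits = [[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0], [8.0, 0.0, 0.0]]
ref_labels = [0, 1, 2, 1]
T = fit_temperature(ref_logits, ref_labels)
print(T, multiclass_nll(ref_logits, ref_labels, 1.0), multiclass_nll(ref_logits, ref_labels, T))
```

Because the toy logits are confidently wrong on one sample, the fitted T comes out well above 1, softening every prediction.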
score()
Compute uncertainty scores for query logits.
Usage
score(query_logits)
Parameters
query_logits: torch.Tensor
Logits for samples to score. Shape depends on task.
Returns
torch.Tensor
1-D tensor of shape (M,), where M is the number of query samples. Lower values indicate lower uncertainty.
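This base class page does not define the score itself; concrete subclasses do. As an assumption for illustration only, the sketch below uses one common choice, 1 minus the maximum softmax probability, which matches the documented contract: one value per query row, with lower values meaning lower uncertainty. It is not necessarily what SoftmaxScore computes, and the helper name is hypothetical.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def one_minus_max_prob(query_logits, temperature=1.0):
    """Hypothetical score: 1 - max softmax probability, one value per query row.

    The output length equals the number of rows (M); values lie in
    [0, 1 - 1/C], and lower means the model is more confident.
    """
    return [1.0 - max(softmax([v / temperature for v in row])) for row in query_logits]


query = [[5.0, 0.0, 0.0], [0.2, 0.1, 0.0]]  # a confident row, then an ambiguous row
scores = one_minus_max_prob(query)
print(scores)  # the first (confident) row gets the lower score
```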
select()
Select samples for prediction based on their uncertainty score.
Usage
select(query_logits)
Samples whose score is lower than the threshold are selected for prediction; samples whose score is higher than the threshold are excluded.
Parameters
query_logits: torch.Tensor
Logits for samples to select. Shape depends on task.
Returns
dict[str, torch.Tensor]
A dict with keys 'score' (uncertainty scores) and 'selected' (boolean mask where True means the sample is selected).
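The selection rule can be sketched as a simple comparison against a threshold. This page does not say where the threshold comes from, so the sketch passes it explicitly as a hypothetical threshold argument; it also assumes that a score exactly equal to the threshold is excluded, which the documentation leaves unspecified. With tensors, the mask would simply be scores < threshold.

```python
def select_by_threshold(scores, threshold):
    """Hypothetical sketch of select(): keep samples whose score is below threshold.

    Returns a dict mirroring the documented output: 'score' holds the scores and
    'selected' is a boolean mask (True = selected for prediction). Excluding a
    score exactly equal to the threshold is an assumption.
    """
    selected = [s < threshold for s in scores]
    return {"score": list(scores), "selected": selected}


out = select_by_threshold([0.05, 0.40, 0.90], threshold=0.5)
print(out["selected"])  # [True, True, False]
```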