Evaluate a metric on full, selected, and rejected subsets.
Wraps a torchmetrics.Metric or torchmetrics.MetricCollection and keeps three independent copies that are updated separately:
- "full": all samples passed to update.
- "selected": samples where the provided selection mask is true.
- "rejected": samples where the selection mask is false.
The compute result is a flat dict[str, torch.Tensor] where each underlying metric name is prefixed with full/, selected/, or rejected/. If a submetric was never updated, its value is a zero torch.Tensor.
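The bookkeeping described above can be modelled in plain Python. This is a minimal sketch under stated assumptions, not seapig's implementation: `ToyAccuracy` and `ToySelectiveMetric` are hypothetical names, lists stand in for tensors, a plain `0.0` stands in for the zero tensor, and deriving the metric name from the class name is an assumption based on the `BinaryAccuracy` keys shown in the Example section.

```python
import copy


class ToyAccuracy:
    """Hypothetical stand-in for a binary accuracy metric (threshold 0.5)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # preds and target are lists of rows; every element is compared.
        for p_row, t_row in zip(preds, target):
            for p, t in zip(p_row, t_row):
                self.correct += int((p > 0.5) == (t == 1))
                self.total += 1

    def compute(self):
        return self.correct / self.total


class ToySelectiveMetric:
    """Hypothetical sketch of the full/selected/rejected bookkeeping."""

    SCOPES = ("full", "selected", "rejected")

    def __init__(self, base):
        # One independent deep copy of the base metric per scope.
        self.metrics = {s: copy.deepcopy(base) for s in self.SCOPES}
        self.updated = {s: False for s in self.SCOPES}

    def update(self, preds, target, selected):
        # Boolean or numeric mask: values > 0 count as selected.
        rows = [(p, t, float(s) > 0) for p, t, s in zip(preds, target, selected)]
        subsets = {
            "full": rows,
            "selected": [r for r in rows if r[2]],
            "rejected": [r for r in rows if not r[2]],
        }
        for scope, subset in subsets.items():
            if not subset:  # no rows for this scope in this call: skip it
                continue
            self.metrics[scope].update([r[0] for r in subset], [r[1] for r in subset])
            self.updated[scope] = True

    def compute(self):
        out = {}
        for scope, metric in self.metrics.items():
            key = f"{scope}/{type(metric).__name__}"  # assumed naming scheme
            # A scope that never received rows reports zero instead of computing.
            out[key] = metric.compute() if self.updated[scope] else 0.0
        return out

    def reset(self):
        for scope, metric in self.metrics.items():
            metric.reset()
            self.updated[scope] = False


m = ToySelectiveMetric(ToyAccuracy())
m.update([[0.9, 0.1], [0.2, 0.8]], [[1, 1], [1, 0]], [True, False])
print(m.compute())
# {'full/ToyAccuracy': 0.25, 'selected/ToyAccuracy': 0.5, 'rejected/ToyAccuracy': 0.0}
```

Tracking an explicit "updated" flag per scope is one way to return zero for a never-updated scope without calling `compute()` on an empty metric.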
Parameters
base: torchmetrics.Metric | torchmetrics.MetricCollection
Metric (or collection) to wrap. Internally the object is deep-copied three times so that each subset is tracked independently.
Notes
- The selection mask may be boolean or numeric; numeric values > 0 are treated as selected.
- Calls that contain no selected (or no rejected) rows do not update the corresponding internal metric for that call.
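The two notes can be illustrated with a pair of hypothetical helpers (`normalize_mask` and `scopes_to_update` are not part of seapig). The sketch assumes a non-empty batch and uses Python lists in place of tensors.

```python
def normalize_mask(selected):
    """Hypothetical helper: map a boolean or numeric mask to booleans (values > 0 are selected)."""
    return [float(v) > 0 for v in selected]


def scopes_to_update(selected):
    """Hypothetical helper: which internal metrics one update() call would touch."""
    mask = normalize_mask(selected)
    scopes = ["full"]  # "full" receives every row (assumes a non-empty batch)
    if any(mask):
        scopes.append("selected")
    if not all(mask):
        scopes.append("rejected")
    return scopes


print(normalize_mask([0.7, 0, -1.0, 2]))  # [True, False, False, True]
print(scopes_to_update([1, 1, 1]))        # ['full', 'selected']
print(scopes_to_update([False, False]))   # ['full', 'rejected']
```

Negative values fall on the rejected side, since only values strictly greater than zero count as selected.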
Example
```python
import torch
from torchmetrics import Accuracy
from seapig import SelectiveMetric

base = Accuracy(task="binary")
m = SelectiveMetric(base)

preds = torch.tensor([[0.9, 0.1], [0.2, 0.8]])     # probabilities
target = torch.tensor([[1, 1], [1, 0]])            # integer 0/1 labels
selected = torch.tensor([1, 0], dtype=torch.bool)  # row 0 selected, row 1 rejected

m.update(preds, target, selected)
results = m.compute()
print(results)
# {'full/BinaryAccuracy': tensor(0.2500), 'selected/BinaryAccuracy': tensor(0.5000), 'rejected/BinaryAccuracy': tensor(0.)}
```
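The printed values can be checked by hand. `Accuracy(task="binary")` thresholds probabilities at 0.5 by default, so `preds` becomes [[1, 0], [0, 1]], and accuracy is taken over every element of the 2×2 tensors. The plain-Python sketch below reproduces the arithmetic; the `accuracy` helper is hypothetical, and because no prediction sits exactly at 0.5 the direction of the threshold comparison does not affect the result.

```python
preds = [[0.9, 0.1], [0.2, 0.8]]
target = [[1, 1], [1, 0]]
selected = [True, False]


def accuracy(rows_p, rows_t, threshold=0.5):
    # Threshold each probability, then count element-wise matches.
    pairs = [(p > threshold, t == 1) for pr, tr in zip(rows_p, rows_t) for p, t in zip(pr, tr)]
    return sum(a == b for a, b in pairs) / len(pairs)


full = accuracy(preds, target)
sel = accuracy([p for p, s in zip(preds, selected) if s], [t for t, s in zip(target, selected) if s])
rej = accuracy([p for p, s in zip(preds, selected) if not s], [t for t, s in zip(target, selected) if not s])
print(full, sel, rej)  # 0.25 0.5 0.0
```

Only one of the four elements matches overall (row 0, column 0), which is why the full score is 0.25 while the selected row alone scores 0.5.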
Methods
| Name | Description |
|---|---|
| compute() | Compute metrics and return a dict with prefixed keys. |
| reset() | Reset the internal metric instances. |
| update() | Update full, selected, and rejected metrics. |
compute()
Compute metrics and return a dict with prefixed keys.
Returns a dictionary where each key is <scope>/<metric_name> (scope is one of full, selected, rejected). Values are torch.Tensors. If a metric instance was never updated, its value is a scalar zero tensor.
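Because the scope is always the first path segment of a key, callers that prefer nested results can split each key once on "/". A minimal sketch with a hypothetical `group_by_scope` helper; plain floats stand in for the tensors that `compute()` actually returns, and the helper leaves values untouched.

```python
def group_by_scope(results):
    """Hypothetical helper: turn {'scope/metric_name': value} into {scope: {metric_name: value}}."""
    grouped = {}
    for key, value in results.items():
        scope, name = key.split("/", 1)  # split once so the metric name stays intact
        grouped.setdefault(scope, {})[name] = value
    return grouped


flat = {"full/BinaryAccuracy": 0.25, "selected/BinaryAccuracy": 0.5, "rejected/BinaryAccuracy": 0.0}
print(group_by_scope(flat))
# {'full': {'BinaryAccuracy': 0.25}, 'selected': {'BinaryAccuracy': 0.5}, 'rejected': {'BinaryAccuracy': 0.0}}
```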
reset()
Reset the internal metric instances.
update()
Update full, selected, and rejected metrics.
update(preds, target, selected)
Parameters
preds: torch.Tensor
Model predictions of shape (B, …).
target: torch.Tensor
Target tensor of shape (B, …).
selected: torch.Tensor
Boolean or numeric selection mask of shape (B,); values > 0 are treated as selected.
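The documented shapes imply that all three arguments share the leading batch dimension B. The sketch below is a hypothetical caller-side check: `check_update_shapes` is not part of seapig, and this page does not say whether `update` validates shapes itself. It works on shape tuples, so a `torch.Size` such as `preds.shape` can be passed directly.

```python
def check_update_shapes(preds_shape, target_shape, selected_shape):
    """Hypothetical pre-check: preds (B, ...), target (B, ...), selected (B,)."""
    if len(selected_shape) != 1:
        raise ValueError(f"selected must be 1-D with shape (B,), got {tuple(selected_shape)}")
    batch = selected_shape[0]
    for name, shape in (("preds", preds_shape), ("target", target_shape)):
        # Scalars (empty shape) or mismatched leading dimensions are rejected.
        if len(shape) == 0 or shape[0] != batch:
            raise ValueError(f"{name} must have leading dimension {batch}, got {tuple(shape)}")
    return batch


print(check_update_shapes((2, 2), (2, 2), (2,)))  # 2
```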