SelectiveMetric
Evaluate a metric on full, selected, and rejected subsets.
Usage
SelectiveMetric(base)
Wraps a torchmetrics.Metric or torchmetrics.MetricCollection and keeps three independent copies that are updated separately:
- "full": all samples passed to update.
- "selected": samples where the provided selection mask is true.
- "rejected": samples where the selection mask is false.
The compute result is a flat dict[str, torch.Tensor] where each underlying metric name is prefixed with full/, selected/, or rejected/. If a submetric was never updated, its value is a zero torch.Tensor.
Parameters
base: torchmetrics.Metric | torchmetrics.MetricCollection - Metric (or collection) to wrap. Internally the object is deep-copied three times so that each subset is tracked independently.
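Deep-copying matters because a metric accumulates state across update calls: three references to one object would all see every sample. A minimal sketch, using a hypothetical ToyCounter class as a stand-in for a torchmetrics.Metric (plain Python, no torch needed):

```python
import copy

SCOPES = ("full", "selected", "rejected")


class ToyCounter:
    """Hypothetical stand-in for a stateful metric: counts samples seen."""

    def __init__(self):
        self.total = 0

    def update(self, n):
        self.total += n


base = ToyCounter()

# Deep copies: each scope owns its own, independent state.
independent = {scope: copy.deepcopy(base) for scope in SCOPES}
# Shared references: every scope points at the same object.
shared = {scope: base for scope in SCOPES}

independent["selected"].update(5)
shared["selected"].update(3)

print(independent["rejected"].total)  # 0: untouched by the "selected" update
print(independent["selected"].total)  # 5
print(shared["rejected"].total)       # 3: the update leaked into every scope
```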
Notes
- The selection mask may be boolean or numeric; numeric values > 0 are treated as selected.
- Calls that contain no selected (or no rejected) rows do not update the corresponding internal metric for that call.
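Both rules can be sketched without torch. The hypothetical route_batch helper below uses plain Python lists in place of tensors; with real tensors the same split would typically be done with boolean indexing such as preds[selected > 0].

```python
def route_batch(preds, target, selected):
    """Hypothetical helper: split one batch into the full/selected/rejected scopes.

    Assumes preds, target and selected have the same length (zip would
    otherwise silently truncate). Returns {scope: (preds, target)}, omitting
    scopes with no rows so their metric is not updated for this call.
    """
    # Booleans and numbers are treated alike: True > 0 and 0.7 > 0 both count as selected.
    keep = [s > 0 for s in selected]
    subsets = {
        "full": (list(preds), list(target)),
        "selected": (
            [p for p, k in zip(preds, keep) if k],
            [t for t, k in zip(target, keep) if k],
        ),
        "rejected": (
            [p for p, k in zip(preds, keep) if not k],
            [t for t, k in zip(target, keep) if not k],
        ),
    }
    return {scope: pair for scope, pair in subsets.items() if pair[1]}


routes = route_batch([0, 1, 1], [0, 1, 0], [1, 0, 0.7])
print(routes["selected"])  # ([0, 1], [0, 0])
print(routes["rejected"])  # ([1], [1])

# A call where every row is selected leaves the rejected scope out entirely.
print(sorted(route_batch([1, 0], [1, 0], [True, True])))  # ['full', 'selected']
```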
Example
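```python
import torch
from torchmetrics import Accuracy

# SelectiveMetric must also be imported from the package that defines it.

# preds holds one row of class scores per sample, so use the multiclass task
# (binary accuracy expects preds with the same shape as target).
base = Accuracy(task="multiclass", num_classes=2)
m = SelectiveMetric(base)
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
target = torch.tensor([0, 1])
mask = torch.tensor([True, False])  # first sample selected, second rejected
m.update(preds, target, mask)
results = m.compute()
# results contains keys like 'full/accuracy', 'selected/accuracy', ...
```
Methods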
| Name | Description |
|---|---|
| compute() | Compute metrics and return a dict with prefixed keys. |
| reset() | Reset the internal metric instances. |
| update() | Update full, selected, and rejected metrics. |
compute()
Compute metrics and return a dict with prefixed keys.
Usage
compute()
Returns a dictionary where each key is <scope>/<metric_name>, with scope one of full, selected, or rejected. Values are torch.Tensor objects. If a metric instance was never updated, its value is a scalar zero tensor.
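The key layout can be sketched with a small hypothetical flatten_results function. It assumes the metric names are known up front (so a never-updated scope can still emit its keys) and uses plain floats where the real method returns torch.Tensor values.

```python
def flatten_results(per_scope, metric_names):
    """Hypothetical sketch of the flat <scope>/<metric_name> output.

    per_scope maps each scope to a {metric_name: value} dict, or to None
    if that scope's metric instance was never updated.
    """
    flat = {}
    for scope in ("full", "selected", "rejected"):
        values = per_scope.get(scope)
        for name in metric_names:
            # Never-updated scopes fall back to zero instead of raising.
            flat[f"{scope}/{name}"] = values[name] if values is not None else 0.0
    return flat


results = flatten_results(
    {"full": {"acc": 0.5}, "selected": {"acc": 1.0}, "rejected": None},
    metric_names=["acc"],
)
print(results)  # {'full/acc': 0.5, 'selected/acc': 1.0, 'rejected/acc': 0.0}
```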
reset()
Reset the internal metric instances.
Usage
reset()
update()
Update full, selected, and rejected metrics.
Usage
update(preds, target, selected)
Parameters
preds: torch.Tensor - Model predictions of shape (B, …).
target: torch.Tensor - Target tensor of shape (B, …).
selected: torch.Tensor - Boolean or numeric selection mask of shape (B,). Values > 0 are treated as selected.
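All three arguments share the leading batch dimension B. The hypothetical check_update_shapes helper below is one way a caller might verify that contract before calling update(); whether SelectiveMetric validates shapes itself is not stated here. It takes plain tuples, and because torch.Size is a tuple subclass, preds.shape can be passed directly.

```python
def check_update_shapes(preds_shape, target_shape, selected_shape):
    """Hypothetical pre-check of the documented update() shape contract.

    Returns the batch size B, or raises ValueError on a mismatch.
    """
    if len(selected_shape) != 1:
        raise ValueError(f"selected must be 1-D with shape (B,), got {selected_shape}")
    batch = selected_shape[0]
    for name, shape in (("preds", preds_shape), ("target", target_shape)):
        # preds and target may have extra trailing dims, but must start with B.
        if len(shape) == 0 or shape[0] != batch:
            raise ValueError(f"{name} must have leading dimension {batch}, got {shape}")
    return batch


print(check_update_shapes((2, 2), (2,), (2,)))  # 2

try:
    check_update_shapes((3, 2), (3,), (2,))
except ValueError as err:
    print(err)  # preds must have leading dimension 2, got (3, 2)
```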