SelectiveMetric

Evaluate a metric on full, selected, and rejected subsets.

Usage

SelectiveMetric(base)

Wraps a torchmetrics.Metric or torchmetrics.MetricCollection and keeps three independent copies that are updated separately:

  • "full": all samples passed to update.
  • "selected": samples where the provided selection mask is true.
  • "rejected": samples where the selection mask is false.

compute() returns a flat dict[str, torch.Tensor] in which each underlying metric name is prefixed with full/, selected/, or rejected/. If a submetric was never updated, its value is a scalar zero tensor.
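The flattening rule can be pictured with a small helper. This is a minimal sketch, not the class's internals: `flatten_results`, its `per_scope` argument, and the fallback `name` used for a single (non-collection) metric are hypothetical, and a never-updated scope is represented here by `None`. For brevity, the never-updated case emits one key under `name` even when the wrapped object is a collection.

```python
import torch


def flatten_results(per_scope, name="metric"):
    """Merge per-scope results into one dict keyed '<scope>/<metric_name>'.

    per_scope maps a scope ('full', 'selected', 'rejected') to either a
    dict of tensors (MetricCollection-style), a single tensor (Metric-style),
    or None when that scope's metric was never updated (hypothetical encoding).
    """
    flat = {}
    for scope, result in per_scope.items():
        if result is None:
            # Never updated: report a scalar zero instead of calling compute().
            flat[f"{scope}/{name}"] = torch.tensor(0.0)
        elif isinstance(result, dict):
            # Collection: prefix every submetric name with the scope.
            for key, value in result.items():
                flat[f"{scope}/{key}"] = value
        else:
            # Single metric: one key per scope under the fallback name.
            flat[f"{scope}/{name}"] = result
    return flat


out = flatten_results(
    {"full": torch.tensor(0.5), "selected": torch.tensor(1.0), "rejected": None},
    name="accuracy",
)
print(sorted(out))  # ['full/accuracy', 'rejected/accuracy', 'selected/accuracy']
```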

Parameters

base: torchmetrics.Metric | torchmetrics.MetricCollection
Metric (or collection) to wrap. Internally the object is deep-copied three times so that each subset is tracked independently.
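Why the deep copies matter can be shown with a toy object. `CountingMetric` below is a hypothetical stand-in for a torchmetrics metric; the point is only that `copy.deepcopy` gives each scope its own state. In this sketch the object passed as `base` is never updated itself.

```python
import copy


class CountingMetric:
    """Hypothetical stand-in for a torchmetrics.Metric: counts samples seen."""

    def __init__(self):
        self.count = 0

    def update(self, n):
        self.count += n

    def compute(self):
        return self.count


base = CountingMetric()
# One independent copy per scope; mutating one never touches the others.
metrics = {scope: copy.deepcopy(base) for scope in ("full", "selected", "rejected")}

metrics["full"].update(4)
metrics["selected"].update(3)
metrics["rejected"].update(1)

print({scope: m.compute() for scope, m in metrics.items()})
# {'full': 4, 'selected': 3, 'rejected': 1}
print(base.count)  # 0: the original object is left untouched
```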

Notes

  • The selection mask may be boolean or numeric; numeric values > 0 are treated as selected.
  • Calls that contain no selected (or no rejected) rows do not update the corresponding internal metric for that call.
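A minimal sketch of the per-call split, assuming the mask has already been converted to a boolean tensor of shape (B,). `split_by_mask` is a hypothetical helper; returning `None` for an empty side is just one way to signal that the corresponding metric should not be updated on that call.

```python
import torch


def split_by_mask(preds, target, mask):
    """Return (selected, rejected), each a (preds, target) pair or None.

    mask is assumed to be a boolean tensor of shape (B,). A side with no rows
    comes back as None so the caller can skip updating that scope this call.
    """
    selected = (preds[mask], target[mask]) if mask.any() else None
    rejected = (preds[~mask], target[~mask]) if (~mask).any() else None
    return selected, rejected


preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
target = torch.tensor([0, 1, 1])

sel, rej = split_by_mask(preds, target, torch.tensor([True, False, True]))
print(sel[1].tolist(), rej[1].tolist())  # [0, 1] [1]

sel, rej = split_by_mask(preds, target, torch.tensor([True, True, True]))
print(rej)  # None: the "rejected" metric would not be updated on this call
```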

Example

import torch
from torchmetrics import Accuracy
# Each row of preds holds per-class probabilities, so use the multiclass task.
base = Accuracy(task="multiclass", num_classes=2)
m = SelectiveMetric(base)
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
target = torch.tensor([0, 1])
mask = torch.tensor([1, 0], dtype=torch.bool)  # row 0 selected, row 1 rejected
m.update(preds, target, mask)
results = m.compute()
# results contains keys prefixed with 'full/', 'selected/', and 'rejected/'

Methods

Name        Description
compute()   Compute metrics and return a dict with prefixed keys.
reset()     Reset the internal metric instances.
update()    Update full, selected, and rejected metrics.
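One way the three methods could fit together is sketched below. `SelectiveSketch` and `ToyAccuracy` are hypothetical and use no torchmetrics code; the fixed `accuracy` key name and the `seen` flags used to detect never-updated scopes are assumptions of this sketch, not documented behavior of SelectiveMetric.

```python
import copy

import torch


class ToyAccuracy:
    """Hypothetical stand-in for a torchmetrics metric (argmax accuracy)."""

    def __init__(self):
        self.reset()

    def update(self, preds, target):
        self.correct += int((preds.argmax(dim=-1) == target).sum())
        self.total += target.numel()

    def compute(self):
        return torch.tensor(self.correct / self.total)

    def reset(self):
        self.correct, self.total = 0, 0


class SelectiveSketch:
    """Minimal sketch of the update/compute/reset cycle; not the real class."""

    SCOPES = ("full", "selected", "rejected")

    def __init__(self, base):
        # Independent copy per scope, plus a flag recording whether it was fed.
        self.metrics = {s: copy.deepcopy(base) for s in self.SCOPES}
        self.seen = {s: False for s in self.SCOPES}

    def _feed(self, scope, preds, target):
        if target.numel() == 0:
            return  # no rows for this scope in this call: skip it
        self.metrics[scope].update(preds, target)
        self.seen[scope] = True

    def update(self, preds, target, selected):
        # Boolean masks pass through; numeric values > 0 count as selected.
        mask = selected if selected.dtype == torch.bool else selected > 0
        self._feed("full", preds, target)
        self._feed("selected", preds[mask], target[mask])
        self._feed("rejected", preds[~mask], target[~mask])

    def compute(self):
        return {
            f"{s}/accuracy": self.metrics[s].compute() if self.seen[s] else torch.tensor(0.0)
            for s in self.SCOPES
        }

    def reset(self):
        for s in self.SCOPES:
            self.metrics[s].reset()
            self.seen[s] = False


m = SelectiveSketch(ToyAccuracy())
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]])
target = torch.tensor([0, 1, 0])
m.update(preds, target, torch.tensor([1, 1, 0]))
print({k: round(v.item(), 3) for k, v in m.compute().items()})
# {'full/accuracy': 0.667, 'selected/accuracy': 1.0, 'rejected/accuracy': 0.0}
m.reset()
```

Tracking whether each scope was ever fed lets compute() return a zero instead of calling compute() on an empty metric, which for this toy class would divide by zero.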

compute()

Compute metrics and return a dict with prefixed keys.

Usage

compute()

Returns a dictionary whose keys have the form <scope>/<metric_name>, where scope is one of full, selected, or rejected. Values are torch.Tensor objects; if a metric instance was never updated, its value is a scalar zero tensor.


reset()

Reset the internal metric instances.

Usage

reset()

update()

Update full, selected, and rejected metrics.

Usage

update(preds, target, selected)

Parameters

preds: torch.Tensor

Model predictions of shape (B, …).

target: torch.Tensor

Target tensor of shape (B, …).

selected: torch.Tensor

Boolean or numeric selection mask of shape (B,). Numeric values > 0 are treated as selected.
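The mask rule can be expressed in a few lines. `normalize_mask` is a hypothetical helper, assuming the mask is a 1-D tensor: boolean masks pass through unchanged and numeric masks are thresholded at zero.

```python
import torch


def normalize_mask(selected):
    """Turn a boolean or numeric (B,) mask into a boolean mask.

    Boolean masks pass through unchanged; for numeric masks, values > 0
    count as selected, matching the rule documented for `selected`.
    """
    if selected.dtype == torch.bool:
        return selected
    return selected > 0


print(normalize_mask(torch.tensor([1, 0, 2])).tolist())         # [True, False, True]
print(normalize_mask(torch.tensor([0.7, -0.1, 0.0])).tolist())  # [True, False, False]
print(normalize_mask(torch.tensor([False, True])).tolist())     # [False, True]
```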