SelectiveMetric

Evaluate a metric on full, selected, and rejected subsets.

Usage

SelectiveMetric(base)

Wraps a torchmetrics.Metric or torchmetrics.MetricCollection and keeps three independent copies that are updated separately:

  • "full": all samples passed to update.
  • "selected": samples where the provided selection mask is true.
  • "rejected": samples where the selection mask is false.

compute() returns a flat dict[str, torch.Tensor] in which each underlying metric name is prefixed with full/, selected/, or rejected/. If one of the three copies was never updated, each of its entries is a zero torch.Tensor.
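The key layout can be illustrated without torch. The sketch below uses a hypothetical helper, flatten_scoped_results, which is not part of seapig. It assumes each internal copy reports a dict of metric values (as a MetricCollection does), and it uses plain 0.0 floats where the library returns zero tensors. The metric names "Acc" and "F1" are illustrative only.

```python
SCOPES = ("full", "selected", "rejected")


def flatten_scoped_results(scoped, metric_names):
    """Hypothetical helper: build a flat '<scope>/<metric_name>' dict.

    `scoped` maps each scope to a dict of metric values, or to None when that
    copy was never updated. Never-updated copies contribute 0.0 for every
    metric name (the real class uses zero torch tensors instead).
    """
    flat = {}
    for scope in SCOPES:
        values = scoped.get(scope)
        for name in metric_names:
            flat[f"{scope}/{name}"] = 0.0 if values is None else values[name]
    return flat


# Illustrative values: two metrics, and no rejected rows were ever seen.
results = flatten_scoped_results(
    {"full": {"Acc": 0.75, "F1": 0.8}, "selected": {"Acc": 1.0, "F1": 1.0}, "rejected": None},
    ["Acc", "F1"],
)
print(results)
```

Every scope always appears in the output, so downstream logging code sees a stable set of keys even when one subset was empty for the whole run.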

Parameters

base: torchmetrics.Metric | torchmetrics.MetricCollection
Metric (or collection) to wrap. Internally the object is deep-copied three times so that each subset is tracked independently.
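Deep copying matters because metric state usually lives in mutable containers. A minimal illustration, assuming a hypothetical stand-in class CountingMetric (not a torchmetrics type): the three deep copies accumulate independently, whereas a shallow copy.copy shares the underlying list with the original.

```python
import copy


class CountingMetric:
    """Hypothetical stand-in for a stateful metric: stores every value passed to update()."""

    def __init__(self):
        self.seen = []  # mutable state, like a metric's accumulated buffers

    def update(self, values):
        self.seen.extend(values)

    def compute(self):
        return len(self.seen)


base = CountingMetric()

# One independent deep copy per subset.
full, selected, rejected = (copy.deepcopy(base) for _ in range(3))
full.update([1, 2, 3])
selected.update([1])
print(full.compute(), selected.compute(), rejected.compute())  # 3 1 0

# Contrast: a shallow copy shares the same list object with `base`.
shallow = copy.copy(base)
shallow.update([9])
print(base.seen)  # [9]: the original was modified through the shallow copy
```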

Notes

  • The selection mask may be boolean or numeric; numeric values > 0 are treated as selected.
  • Calls that contain no selected (or no rejected) rows do not update the corresponding internal metric for that call.
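Both notes can be captured in a small pure-Python sketch. split_by_mask is a hypothetical helper, not seapig's code: it treats any mask value > 0 (including True) as selected and returns None for an empty subset, which a caller would read as "skip this update". The batch-size check is an extra safeguard added for the sketch.

```python
def split_by_mask(preds, target, selected):
    """Hypothetical helper: split a batch into (selected, rejected) subsets.

    Each subset is a (preds, target) pair of lists, or None when it has no rows,
    meaning the corresponding metric should not be updated on this call.
    """
    if not (len(preds) == len(target) == len(selected)):
        raise ValueError("preds, target and selected must have the same batch size")
    mask = [s > 0 for s in selected]  # True > 0 is True; numeric values > 0 count as selected

    def subset(keep):
        rows = [(p, t) for p, t, m in zip(preds, target, mask) if m == keep]
        if not rows:
            return None  # nothing to update for this subset
        ps, ts = zip(*rows)
        return list(ps), list(ts)

    return subset(True), subset(False)


# Numeric mask: 2 and 0.5 are selected, 0 is rejected.
print(split_by_mask([0.9, 0.2, 0.7], [1, 0, 1], [2, 0, 0.5]))
# Boolean mask with no rejected rows: the rejected subset is None.
print(split_by_mask([0.9, 0.2], [1, 0], [True, True]))
```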

Example

import torch
from torchmetrics import Accuracy
from seapig import SelectiveMetric

base = Accuracy(task="binary")
m = SelectiveMetric(base)
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
target = torch.tensor([[1.0, 1], [1, 0]])
selected = torch.tensor([1, 0], dtype=torch.bool)
m.update(preds, target, selected)
results = m.compute()
print(results)
{'full/BinaryAccuracy': tensor(0.2500), 'selected/BinaryAccuracy': tensor(0.5000), 'rejected/BinaryAccuracy': tensor(0.)}
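The printed values can be checked by hand. The sketch below re-derives them in plain Python, assuming torchmetrics' default binary threshold of 0.5 and that 2-D inputs are scored element-wise over all entries. Note that the rejected 0.0 in this example is a genuine accuracy of zero for row 1, not the zero reported for a copy that was never updated.

```python
def binary_accuracy(pred_rows, target_rows, threshold=0.5):
    """Element-wise accuracy after thresholding probabilities (plain-Python check)."""
    pairs = [
        (int(p > threshold), t)
        for pr, tr in zip(pred_rows, target_rows)
        for p, t in zip(pr, tr)
    ]
    return sum(p == t for p, t in pairs) / len(pairs)


# Same data as the example above.
preds = [[0.9, 0.1], [0.2, 0.8]]
target = [[1, 1], [1, 0]]
selected = [True, False]

sel = [i for i, s in enumerate(selected) if s]
rej = [i for i, s in enumerate(selected) if not s]

full_acc = binary_accuracy(preds, target)  # thresholded [[1, 0], [0, 1]]: 1 of 4 match
sel_acc = binary_accuracy([preds[i] for i in sel], [target[i] for i in sel])  # row 0: 1 of 2 match
rej_acc = binary_accuracy([preds[i] for i in rej], [target[i] for i in rej])  # row 1: 0 of 2 match
print(full_acc, sel_acc, rej_acc)  # 0.25 0.5 0.0
```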

Methods

Name        Description
compute()   Compute metrics and return a dict with prefixed keys.
reset()     Reset the internal metric instances.
update()    Update full, selected, and rejected metrics.

compute()

Compute metrics and return a dict with prefixed keys.

Usage

compute()

Returns a dictionary whose keys have the form <scope>/<metric_name>, where scope is one of full, selected, or rejected. Values are torch.Tensor objects; if a metric instance was never updated, each of its values is a scalar zero tensor.
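For logging it can be convenient to regroup the flat dict by scope. A minimal sketch with a hypothetical helper, group_by_scope (not part of seapig), that splits each key on its first "/"; plain floats stand in for the tensors from the example output.

```python
def group_by_scope(results):
    """Hypothetical helper: turn {'<scope>/<metric_name>': value} into {scope: {metric_name: value}}."""
    grouped = {}
    for key, value in results.items():
        # Split only on the first '/', in case a metric name itself contains '/'.
        scope, name = key.split("/", 1)
        grouped.setdefault(scope, {})[name] = value
    return grouped


# Floats standing in for the tensors in the example output.
flat = {"full/BinaryAccuracy": 0.25, "selected/BinaryAccuracy": 0.5, "rejected/BinaryAccuracy": 0.0}
print(group_by_scope(flat))
```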


reset()

Reset all three internal metric instances (full, selected, and rejected).

Usage

reset()
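The usual torchmetrics lifecycle applies: call update once per batch, compute at the end of an epoch, then reset before the next one. The sketch below uses a hypothetical stand-in, RowCounter, with the same update/compute/reset shape so that it runs without torch; it counts rows per scope instead of computing a real metric.

```python
class RowCounter:
    """Hypothetical stand-in with SelectiveMetric's method shape; counts rows per scope."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all three scopes, mirroring a reset of the internal metric instances.
        self.counts = {"full": 0, "selected": 0, "rejected": 0}

    def update(self, preds, target, selected):
        n_sel = sum(1 for s in selected if s > 0)
        self.counts["full"] += len(selected)
        self.counts["selected"] += n_sel
        self.counts["rejected"] += len(selected) - n_sel

    def compute(self):
        return {f"{scope}/RowCount": float(n) for scope, n in self.counts.items()}


metric = RowCounter()
epochs = [
    [([0.9, 0.2], [1, 0], [1, 0]), ([0.4], [0], [1])],  # epoch 1: two batches
    [([0.7, 0.1, 0.6], [1, 0, 1], [0, 0, 1])],  # epoch 2: one batch
]
history = []
for batches in epochs:
    for preds, target, selected in batches:
        metric.update(preds, target, selected)  # state accumulates across batches
    history.append(metric.compute())
    metric.reset()  # without this, epoch 2 would also include epoch 1's rows
print(history)
```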

update()

Update full, selected, and rejected metrics.

Usage

update(preds, target, selected)

Parameters

preds: torch.Tensor

Model predictions of shape (B, …).

target: torch.Tensor

Target tensor of shape (B, …).

selected: torch.Tensor

Boolean or binary selection mask of shape (B,). Values > 0 are treated as selected.
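Because the mask has shape (B,), it indexes the batch dimension only. A short torch sketch of the general boolean-indexing idiom that splits preds and target into row subsets while keeping all trailing dimensions; this is an assumption about, not a copy of, SelectiveMetric's internals.

```python
import torch

preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # shape (B=3, 2)
target = torch.tensor([[1, 1], [1, 0], [0, 0]])  # shape (B=3, 2)
selected = torch.tensor([1.0, 0.0, 0.3])  # numeric mask, shape (B,)

mask = selected > 0  # tensor([True, False, True])

# A 1-D boolean mask on the first dimension selects whole rows.
sel_preds, sel_target = preds[mask], target[mask]  # shape (2, 2)
rej_preds, rej_target = preds[~mask], target[~mask]  # shape (1, 2)

print(sel_preds.shape, rej_preds.shape)  # torch.Size([2, 2]) torch.Size([1, 2])
```

A mask with the same shape as preds would instead select individual elements and flatten the result, which is why the (B,) shape matters.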