RiskCoverageMetric

Build a risk-coverage curve from scores and per-sample errors.

Usage

RiskCoverageMetric()

Collects per-sample scores and per-sample residuals across multiple update calls and computes summary area-under-curve values using seapig.risk_coverage.risk_coverage.

Parameters

risk: (generalized, selective) = "generalized"

Which risk definition to use when computing the curve. Must be either 'generalized' or 'selective'.
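
To make the two options concrete, here is a minimal pure-Python sketch of a risk-coverage curve. It assumes the definitions common in the selective-prediction literature: selective risk is the mean error over the retained samples, and generalized risk is the summed error of the retained samples divided by the total sample count. The function name risk_coverage_curve is hypothetical, and this is not the implementation in seapig.risk_coverage.

```python
def risk_coverage_curve(scores, errors, risk="generalized"):
    """Illustrative risk-coverage curve (assumed definitions, not seapig's code)."""
    if risk not in ("generalized", "selective"):
        raise ValueError("risk must be 'generalized' or 'selective'")
    if len(scores) != len(errors):
        raise ValueError("scores and errors must have the same length")

    n = len(scores)
    # Retain the least uncertain samples first (lower score = lower uncertainty).
    order = sorted(range(n), key=lambda i: scores[i])

    coverage, risks = [], []
    running_error = 0.0
    for k, i in enumerate(order, start=1):
        running_error += errors[i]
        coverage.append(k / n)
        if risk == "selective":
            # Assumed definition: mean error over the k retained samples.
            risks.append(running_error / k)
        else:
            # Assumed definition: retained error mass over all n samples.
            risks.append(running_error / n)
    return coverage, risks


scores = [0.1, 0.9, 0.5, 0.3]
errors = [0.0, 1.0, 0.5, 0.5]
cov, sel = risk_coverage_curve(scores, errors, risk="selective")
_, gen = risk_coverage_curve(scores, errors, risk="generalized")
```

Under these assumed definitions the two curves agree at full coverage and differ elsewhere by a factor equal to the coverage.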

n_bins: int = 100

Number of bins used to downsample the curve when computing AUC summaries.
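
The exact downsampling scheme is not described here. The sketch below shows one plausible approach, purely as an assumption: linearly interpolate the curve onto n_bins evenly spaced coverage levels, then integrate with the trapezoidal rule. The helper names downsample_curve and trapezoid_auc are hypothetical and do not come from the library.

```python
def downsample_curve(coverage, risk, n_bins):
    """Interpolate a curve onto n_bins evenly spaced coverage levels (assumed scheme)."""
    if n_bins < 2 or len(coverage) < 2:
        raise ValueError("need at least two bins and two curve points")
    lo, hi = coverage[0], coverage[-1]
    grid = [lo + (hi - lo) * i / (n_bins - 1) for i in range(n_bins)]

    binned = []
    j = 0  # index of the left end of the current curve segment
    for c in grid:
        # Advance to the segment [coverage[j], coverage[j + 1]] containing c.
        while j + 1 < len(coverage) - 1 and coverage[j + 1] < c:
            j += 1
        c0, c1 = coverage[j], coverage[j + 1]
        t = 0.0 if c1 == c0 else (c - c0) / (c1 - c0)
        binned.append(risk[j] + t * (risk[j + 1] - risk[j]))
    return grid, binned


def trapezoid_auc(x, y):
    """Area under the piecewise-linear curve through the points (x, y)."""
    return sum((x[i + 1] - x[i]) * (y[i + 1] + y[i]) / 2 for i in range(len(x) - 1))


coverage = [0.25, 0.5, 0.75, 1.0]
risk = [0.0, 0.1, 0.2, 0.3]
grid, binned = downsample_curve(coverage, risk, n_bins=7)
auc = trapezoid_auc(grid, binned)
```

With a scheme like this, a larger n_bins follows the original curve more closely at the cost of more points to integrate.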

error_metric: torchmetrics.Metric | torchmetrics.MetricCollection | None = None

Metric or collection that computes per-sample residuals. It must return a 1-D tensor of shape (batch,) when compute is called.

error_fn: callable | None = None

Deprecated legacy function (preds, target) -> residuals. Use error_metric instead.

Notes

The compute method returns a dict containing three tensors under the keys rc/auc_empirical, rc/auc_reference, and rc/auc_excess. The most recently computed full curve object (RiskCoverage) is available via get_curve.

Examples

import torch
from seapig.metric import RiskCoverageMetric

metric = RiskCoverageMetric(risk="generalized")
preds = torch.rand(50, 1)
target = torch.rand(50, 1)
scores = torch.rand(50)
metric.update(preds, target, scores)
results = metric.compute()
print(results)
_ = metric.get_curve().plot()

Output (exact values vary between runs because the inputs are random):

{'rc/auc_empirical': tensor(0.2052), 'rc/auc_reference': tensor(0.1323), 'rc/auc_excess': tensor(0.0729)}

Methods

Name         Description
compute()    Compute AUC summaries for the accumulated risk–coverage curve.
reset()      Reset all accumulated state.
update()     Store scores and residuals for later curve computation.
get_curve()  Return the last computed curve(s), or None if not computed yet.

compute()

Compute AUC summaries for the accumulated risk–coverage curve.

Usage

compute()

Returns a dict with the keys rc/auc_empirical, rc/auc_reference, and rc/auc_excess. If multiple residual streams are present (for example, when error_metric is a MetricCollection), each key is prefixed with the corresponding metric name.
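
The values printed in the Examples section are consistent with rc/auc_excess being the difference between rc/auc_empirical and rc/auc_reference. The short check below reproduces that arithmetic with plain floats. This relation is inferred from that single printed output and is an assumption, not a documented guarantee.

```python
# Values copied from the printed example output, as plain floats.
results = {
    "rc/auc_empirical": 0.2052,
    "rc/auc_reference": 0.1323,
    "rc/auc_excess": 0.0729,
}

# Assumed relation: excess = empirical - reference.
gap = results["rc/auc_empirical"] - results["rc/auc_reference"]

# The printed values are rounded to four decimals, so compare with a tolerance.
consistent = abs(gap - results["rc/auc_excess"]) < 1e-4
```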


reset()

Reset all accumulated state.

Usage

reset()

The generic scores and residuals buffers are cleared, as are any per-metric residual buffers created when error_metric is a MetricCollection. After a reset, accessing a per-metric residual attribute before the next update call will yield an empty tensor. A UserWarning is emitted to make this behaviour explicit.


update()

Store scores and residuals for later curve computation.

Usage

update(preds, target, scores)

Parameters

preds: torch.Tensor

Model outputs. Passed together with target to error_metric (or the deprecated error_fn) to compute per-sample residuals.

target: torch.Tensor

Ground-truth targets. Passed together with preds to error_metric (or the deprecated error_fn) to compute per-sample residuals.

scores: torch.Tensor

Per-sample uncertainty scores (lower values indicate lower uncertainty).
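
The accumulate-then-compute pattern can be pictured with a toy buffer written in plain Python lists. The class ToyRiskCoverageBuffer and the helper abs_error are hypothetical names used only for illustration. The real metric stores tensors, and its internals may differ from this sketch.

```python
class ToyRiskCoverageBuffer:
    """Toy stand-in for the accumulation behaviour (not the library's class)."""

    def __init__(self, error_fn):
        # error_fn follows the legacy (preds, target) -> residuals shape.
        self.error_fn = error_fn
        self.scores = []
        self.residuals = []

    def update(self, preds, target, scores):
        residuals = self.error_fn(preds, target)
        if len(residuals) != len(scores):
            raise ValueError("expected one score per residual")
        # Append this batch; nothing is computed until later.
        self.residuals.extend(residuals)
        self.scores.extend(scores)

    def reset(self):
        # Drop everything accumulated so far.
        self.scores.clear()
        self.residuals.clear()


def abs_error(preds, target):
    """Per-sample absolute error, one value per sample."""
    return [abs(p - t) for p, t in zip(preds, target)]


buf = ToyRiskCoverageBuffer(abs_error)
buf.update([0.5, 1.0], [0.0, 1.0], [0.3, 0.1])  # first batch
buf.update([0.25], [1.0], [0.8])                # second batch
```

After the two calls the buffer holds three scores and three residuals in arrival order, and calling reset() empties both lists.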

get_curve()

Return the last computed curve(s), or None if not computed yet.

Usage

get_curve(metric_name=None)

See Also

RiskCoverage