Build a risk-coverage curve from scores and per-sample errors.
Collects per-sample scores and residuals across multiple update calls, then computes area-under-curve (AUC) summaries with seapig.risk_coverage.risk_coverage.
Parameters
risk: "generalized" | "selective" = "generalized"
Which risk definition to use when computing the curve. Must be either "generalized" or "selective".
n_bins: int = 100
Number of bins used to downsample the curve when computing AUC summaries.
error_metric: torchmetrics.Metric | torchmetrics.MetricCollection | None = None
Metric or collection that computes per-sample residuals. When compute is called, it must return a 1-D tensor of shape (batch,).
error_fn: callable | None = None
Deprecated legacy function with signature (preds, target) -> residuals. Use error_metric instead.
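The two values of the risk parameter differ only in how the accumulated error is normalized. The pure-Python sketch below assumes the common definitions from the selective-classification literature, not confirmed by this page: generalized risk divides the summed residuals of the accepted samples by the total sample count, while selective risk divides by the number of accepted samples. It also assumes samples are accepted in ascending score order. `risk_coverage_points` is a hypothetical helper, not part of seapig.

```python
from typing import List, Tuple


def risk_coverage_points(
    scores: List[float], residuals: List[float], risk: str = "generalized"
) -> List[Tuple[float, float]]:
    """Return (coverage, risk) pairs for every prefix of the ranked samples.

    Hypothetical helper (not part of seapig). Assumes lower scores mean lower
    uncertainty, so samples are accepted in ascending score order.
    """
    if risk not in ("generalized", "selective"):
        raise ValueError("risk must be 'generalized' or 'selective'")
    if len(scores) != len(residuals):
        raise ValueError("need one score per residual")
    n = len(scores)
    order = sorted(range(n), key=lambda i: scores[i])  # most confident first
    points: List[Tuple[float, float]] = []
    cumulative = 0.0
    for k, i in enumerate(order, start=1):
        cumulative += residuals[i]
        coverage = k / n
        if risk == "generalized":
            value = cumulative / n  # error mass of accepted samples over all n samples
        else:
            value = cumulative / k  # mean error over the k accepted samples only
        points.append((coverage, value))
    return points


scores = [0.1, 0.4, 0.2, 0.9]
residuals = [0.0, 1.0, 0.0, 1.0]
print(risk_coverage_points(scores, residuals, "generalized"))
# [(0.25, 0.0), (0.5, 0.0), (0.75, 0.25), (1.0, 0.5)]
```

At full coverage both definitions give the same value. At low coverage, selective risk is averaged over only a handful of samples, so it tends to be noisier than generalized risk.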
Notes
The compute method returns a dict of three scalar tensors keyed rc/auc_empirical, rc/auc_reference, and rc/auc_excess. The last complete curve object (RiskCoverage) is available via get_curve.
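In the example output that follows, rc/auc_excess equals rc/auc_empirical minus rc/auc_reference (0.2052 − 0.1323 = 0.0729). The sketch below reproduces that relationship under assumptions this page does not confirm: the reference curve is taken to be an oracle that ranks samples by their true residuals, the curve uses generalized risk, and the AUC is a simple step-wise sum. All helper names are hypothetical.

```python
from typing import Dict, List


def generalized_risks(ranked_residuals: List[float]) -> List[float]:
    """Generalized risk after accepting 1, 2, ..., n samples in the given order."""
    n = len(ranked_residuals)
    risks: List[float] = []
    cumulative = 0.0
    for r in ranked_residuals:
        cumulative += r
        risks.append(cumulative / n)
    return risks


def step_auc(risks: List[float]) -> float:
    """Right-endpoint Riemann sum: each accepted sample adds 1/n of coverage."""
    return sum(risks) / len(risks)


def auc_summaries(scores: List[float], residuals: List[float]) -> Dict[str, float]:
    # Empirical curve: accept samples in ascending score order.
    empirical = [r for _, r in sorted(zip(scores, residuals), key=lambda p: p[0])]
    # Assumed reference: an oracle that ranks samples by their true residuals.
    oracle = sorted(residuals)
    auc_emp = step_auc(generalized_risks(empirical))
    auc_ref = step_auc(generalized_risks(oracle))
    return {"auc_empirical": auc_emp, "auc_reference": auc_ref, "auc_excess": auc_emp - auc_ref}


print(auc_summaries([0.9, 0.1, 0.2, 0.4], [0.0, 1.0, 0.0, 1.0]))
# {'auc_empirical': 0.375, 'auc_reference': 0.1875, 'auc_excess': 0.1875}
```

Under these assumptions the excess can never be negative, because at every coverage level the oracle accepts the samples with the smallest residuals.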
Examples
import torch
from seapig.metric import RiskCoverageMetric
metric = RiskCoverageMetric(risk="generalized")
preds = torch.rand(50, 1)
target = torch.rand(50, 1)
scores = torch.rand(50)
metric.update(preds, target, scores)
results = metric.compute()
print(results)
# {'rc/auc_empirical': tensor(0.2052), 'rc/auc_reference': tensor(0.1323), 'rc/auc_excess': tensor(0.0729)}
# (exact values vary from run to run because the inputs are random)
_ = metric.get_curve().plot()
Methods
| Name | Description |
| --- | --- |
| compute() | Compute AUC summaries for the accumulated risk–coverage curve. |
| reset() | Reset all accumulated state. |
| update() | Store scores and residuals for later curve computation. |
| get_curve() | Return the last computed curve(s), or None if not computed yet. |
compute()
Compute AUC summaries for the accumulated risk–coverage curve.
Returns a dict with keys rc/auc_empirical, rc/auc_reference, and rc/auc_excess. If multiple residual streams are present (for example, when error_metric is a MetricCollection), each key is prefixed with the corresponding metric name.
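The n_bins parameter controls how the curve is downsampled before these AUC values are computed. One plausible scheme, shown here only as a sketch, resamples the curve at n_bins + 1 evenly spaced coverage values by linear interpolation and integrates with the trapezoidal rule. The interpolation scheme and the helper name `binned_auc` are assumptions; seapig's actual downsampling may differ.

```python
import bisect
from typing import List, Tuple


def binned_auc(points: List[Tuple[float, float]], n_bins: int = 100) -> float:
    """Trapezoidal AUC of a (coverage, risk) curve resampled on n_bins equal-width bins.

    Hypothetical helper; the linear interpolation scheme is an assumption.
    `points` must be sorted by coverage.
    """
    if not points or n_bins < 1:
        raise ValueError("need at least one point and n_bins >= 1")
    coverages = [c for c, _ in points]
    risks = [r for _, r in points]

    def risk_at(x: float) -> float:
        j = bisect.bisect_left(coverages, x)  # first index with coverages[j] >= x
        if j == 0:
            return risks[0]  # clamp below the first sampled coverage
        if j == len(coverages):
            return risks[-1]  # clamp above the last sampled coverage
        x0, x1 = coverages[j - 1], coverages[j]
        y0, y1 = risks[j - 1], risks[j]
        return y0 + (y1 - y0) * (x - x0) / (x1 - x0)

    grid = [i / n_bins for i in range(n_bins + 1)]  # n_bins + 1 edges in [0, 1]
    values = [risk_at(x) for x in grid]
    width = 1.0 / n_bins
    return sum(width * (values[i] + values[i + 1]) / 2 for i in range(n_bins))
```

For a straight-line curve such as `binned_auc([(0.0, 0.0), (1.0, 1.0)])`, the result is 0.5 up to floating-point rounding. A larger n_bins follows a jagged empirical curve more closely at the cost of more interpolation work.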
reset()
Reset all accumulated state.
The generic scores and residuals buffers are cleared, as are any per-metric residual buffers created when error_metric is a MetricCollection. After a reset, accessing a per-metric residual attribute before the next update call will yield an empty tensor. A UserWarning is emitted to make this behaviour explicit.
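Because update only buffers data, splitting a dataset across several update calls should give the same result as a single call, and reset discards everything seen so far. The hypothetical `ToyRiskCoverageAccumulator` below mimics that buffering in pure Python. Unlike RiskCoverageMetric, it takes plain float lists, assumes the absolute error |pred − target| as the residual, and returns only a generalized-risk AUC.

```python
from typing import List


class ToyRiskCoverageAccumulator:
    """Hypothetical stand-in for the buffering done by update() and reset()."""

    def __init__(self) -> None:
        self.scores: List[float] = []
        self.residuals: List[float] = []

    def reset(self) -> None:
        # Clear both buffers, mirroring the documented reset() behaviour.
        self.scores = []
        self.residuals = []

    def update(self, preds: List[float], target: List[float], scores: List[float]) -> None:
        if not (len(preds) == len(target) == len(scores)):
            raise ValueError("preds, target and scores must have the same length")
        # Assumed residual: absolute error per sample.
        self.residuals.extend(abs(p - t) for p, t in zip(preds, target))
        self.scores.extend(scores)

    def compute(self) -> float:
        """Generalized-risk AUC over everything seen since the last reset."""
        n = len(self.scores)
        if n == 0:
            raise RuntimeError("call update() before compute()")
        ranked = [r for _, r in sorted(zip(self.scores, self.residuals), key=lambda p: p[0])]
        cumulative, area = 0.0, 0.0
        for r in ranked:
            cumulative += r
            area += (cumulative / n) / n  # risk at this coverage times step width 1/n
        return area


acc = ToyRiskCoverageAccumulator()
acc.update([1.0, 0.0], [0.0, 0.0], [0.1, 0.2])  # first batch
acc.update([1.0, 0.0], [0.0, 0.0], [0.4, 0.9])  # second batch
print(acc.compute())
# 0.375
```

Feeding the same four samples in a single update call yields the same value; in this sketch, calling compute() right after reset() raises because both buffers are empty.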
update()
Store scores and residuals for later curve computation.
update(preds, target, scores)
Parameters
preds: torch.Tensor
Model outputs. Passed, together with target, to error_metric (or the deprecated error_fn) to compute per-sample residuals.
target: torch.Tensor
Ground-truth targets, paired with preds when computing per-sample residuals.
scores: torch.Tensor
Per-sample uncertainty scores, one per sample (lower values indicate lower uncertainty).
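Since the residuals must form a 1-D tensor of shape (batch,), any non-batch dimensions of preds and target have to be reduced. With preds and target of shape (batch, 1), as in the Examples section, one option is the squared error averaged over the feature dimension. This pure-Python sketch uses nested lists in place of tensors; `per_sample_squared_error` is a hypothetical helper, not a seapig function.

```python
from typing import List, Sequence


def per_sample_squared_error(
    preds: Sequence[Sequence[float]], target: Sequence[Sequence[float]]
) -> List[float]:
    """Hypothetical residual function: one mean squared error per sample (row)."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same batch size")
    residuals: List[float] = []
    for p_row, t_row in zip(preds, target):
        if len(p_row) != len(t_row):
            raise ValueError("each preds row must match its target row")
        # Average over the feature dimension, leaving a single value per sample.
        squared = [(p - t) ** 2 for p, t in zip(p_row, t_row)]
        residuals.append(sum(squared) / len(squared))
    return residuals


print(per_sample_squared_error([[1.0, 2.0], [0.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]]))
# [2.0, 1.0]
```

With 2-D torch tensors, `((preds - target) ** 2).mean(dim=1)` performs the same reduction and returns a tensor of shape (batch,).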
get_curve()
Return the last computed curve(s), or None if not computed yet.
get_curve(metric_name=None)