RiskCoverageMetric

Build a risk-coverage curve from scores and per-sample errors.

Usage

RiskCoverageMetric()

Collects per-sample scores and per-sample residuals across multiple update calls and computes summary area-under-curve values using seapig.risk_coverage.risk_coverage.

Parameters

risk: (generalized, selective) = "generalized"

Which risk definition to use when computing the curve. Must be either 'generalized' or 'selective'.
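The two options can be illustrated with a minimal pure-Python sketch. It assumes the definitions commonly used in the selective-prediction literature: selective risk is the mean residual over the accepted samples, and generalized risk is the summed residual of the accepted samples divided by the total sample count. The helper name risk_coverage_sketch is hypothetical, and the actual seapig.risk_coverage.risk_coverage implementation may differ in its details.

```python
def risk_coverage_sketch(scores, residuals, risk="generalized"):
    """Illustrative only; not seapig's implementation."""
    if risk not in ("generalized", "selective"):
        raise ValueError("risk must be 'generalized' or 'selective'")
    n = len(scores)
    # Lower score means higher confidence, so accept samples in ascending score order.
    order = sorted(range(n), key=lambda i: scores[i])
    coverages, risks = [], []
    running = 0.0
    for k, i in enumerate(order, start=1):
        running += residuals[i]
        coverages.append(k / n)
        if risk == "selective":
            # Mean residual over the k accepted samples.
            risks.append(running / k)
        else:
            # Accepted residual mass divided by all n samples (assumed definition).
            risks.append(running / n)
    return coverages, risks


scores = [0.1, 0.9, 0.4, 0.2]
residuals = [0.0, 1.0, 0.5, 0.5]
cov, sel = risk_coverage_sketch(scores, residuals, risk="selective")
_, gen = risk_coverage_sketch(scores, residuals, risk="generalized")
```

Under these assumed definitions, both curves meet at full coverage, where each equals the mean residual over all samples.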

n_bins: int = 100

Number of bins used to downsample the curve when computing AUC summaries.

error_fn: callable or None = None

Function (preds, target) -> residuals that reduces model predictions and targets to a 1-D tensor of per-sample residuals. If None, the default is per-sample mean absolute error.
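As a sketch of what a custom error_fn might look like, the function below computes a per-sample mean squared error. The name per_sample_mse is hypothetical, and the reduction over all non-batch dimensions is an assumption about what "per-sample" means for multi-dimensional outputs.

```python
import torch


def per_sample_mse(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical error_fn: squared error averaged over all non-batch dims."""
    squared = (preds - target) ** 2
    # Collapse every dimension except the first (batch) one, then average,
    # which yields a 1-D tensor with one residual per sample.
    return squared.reshape(squared.shape[0], -1).mean(dim=1)


preds = torch.tensor([[1.0, 2.0], [0.0, 0.0]])
target = torch.tensor([[1.0, 0.0], [3.0, 4.0]])
residuals = per_sample_mse(preds, target)
```

A function of this shape would then be passed at construction time, for example RiskCoverageMetric(error_fn=per_sample_mse).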

Notes

The compute method returns a dict of three tensors, keyed rc/auc_empirical, rc/auc_reference, and rc/auc_excess. The most recently computed full curve object (RiskCoverage) is available via get_curve.

Examples

import torch
from seapig.metric import RiskCoverageMetric
metric = RiskCoverageMetric(risk="generalized")
preds = torch.rand(50, 1)
target = torch.rand(50, 1)
scores = torch.rand(50)
metric.update(preds, target, scores)
result = metric.compute()
# result contains keys: 'rc/auc_empirical', 'rc/auc_reference', 'rc/auc_excess'

Attributes

Name Description
full_state_update Boolean flag; defaults to False.

full_state_update

full_state_update: bool = False

Methods

Name Description
compute() Compute AUC summaries for the accumulated risk–coverage curve.
get_curve() Return the last computed RiskCoverage object (or None if not computed).
reset() Reset the accumulated scores and residuals.
update() Store scores and residuals for later curve computation.

compute()

Compute AUC summaries for the accumulated risk–coverage curve.

Usage

compute()

Returns a dict with keys rc/auc_empirical, rc/auc_reference, and rc/auc_excess. If no data has been accumulated, an all-zero mapping with the same keys is returned on the correct device.


get_curve()

Return the last computed RiskCoverage object (or None if not computed).

Usage

get_curve()

reset()

Reset the accumulated scores and residuals.

Usage

reset()

update()

Store scores and residuals for later curve computation.

Usage

update(preds, target, scores)

Parameters

preds: torch.Tensor

Model outputs. Passed together with target to error_fn to compute per-sample residuals.

target: torch.Tensor

Targets corresponding to preds. Passed together with preds to error_fn to compute per-sample residuals.

scores: torch.Tensor

Per-sample confidence scores (lower values indicate higher confidence).

See Also