RiskCoverageMetric
Build a risk-coverage curve from scores and per-sample errors.
Usage
RiskCoverageMetric(risk="generalized", n_bins=100, error_fn=None)
Collects per-sample scores and per-sample residuals across multiple update calls and computes summary area-under-curve (AUC) values using seapig.risk_coverage.risk_coverage.
Parameters
risk: {"generalized", "selective"} = "generalized"
Which risk definition to use when computing the curve. Must be either "generalized" or "selective".
n_bins: int = 100
Number of bins used to downsample the curve when computing AUC summaries.
error_fn: callable or None = None
Function (preds, target) -> residuals that reduces model predictions and targets to a 1-D tensor of per-sample residuals. If None, per-sample mean absolute error is used.
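As a rough illustration of the two risk options, the plain-Python sketch below builds (coverage, risk) points from scores and residuals. It assumes the common definitions: selective risk averages the residuals of the accepted samples, while generalized risk sums them and divides by the total number of samples. It also assumes samples are accepted in ascending score order, since lower scores indicate higher confidence. The helper risk_coverage_points is hypothetical and not part of seapig; the library's handling of ties and curve endpoints may differ.

```python
from typing import List, Tuple


def risk_coverage_points(
    scores: List[float], residuals: List[float], risk: str = "generalized"
) -> List[Tuple[float, float]]:
    """Return (coverage, risk) pairs, accepting samples from most to least confident.

    Hypothetical helper. Assumes lower scores mean higher confidence, so samples
    are accepted in ascending score order; ties keep input order (stable sort).
    """
    if risk not in ("generalized", "selective"):
        raise ValueError("risk must be 'generalized' or 'selective'")
    if len(scores) != len(residuals):
        raise ValueError("scores and residuals must have the same length")
    n = len(scores)
    order = sorted(range(n), key=lambda i: scores[i])
    points = []
    cumulative_error = 0.0
    for k, i in enumerate(order, start=1):
        cumulative_error += residuals[i]
        coverage = k / n
        # Generalized risk divides by all n samples; selective risk only by the k accepted ones.
        denom = n if risk == "generalized" else k
        points.append((coverage, cumulative_error / denom))
    return points


scores = [0.1, 0.4, 0.2, 0.9]
residuals = [0.0, 1.0, 0.5, 2.0]
print(risk_coverage_points(scores, residuals, "generalized"))
print(risk_coverage_points(scores, residuals, "selective"))
```

At full coverage both definitions reduce to the mean residual. They differ at partial coverage: generalized risk shrinks toward zero as fewer samples are accepted, while selective risk reflects the average error of whatever is accepted.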
Notes
The compute method returns a dict of three tensors keyed rc/auc_empirical, rc/auc_reference, and rc/auc_excess. The most recently computed full curve (a RiskCoverage object) is available via get_curve.
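The n_bins parameter controls how finely the curve is sampled before integration. One possible approach, sketched below, resamples the curve by linear interpolation onto n_bins evenly spaced coverage values and applies the trapezoidal rule. The grid choice and interpolation scheme are assumptions, and binned_auc is a hypothetical helper. The source does not define the reference curve; reading rc/auc_excess as the empirical AUC minus the reference AUC is suggested by the key names but is not confirmed here.

```python
from typing import List, Tuple


def binned_auc(points: List[Tuple[float, float]], n_bins: int = 100) -> float:
    """Area under a (coverage, risk) curve after resampling it onto n_bins points.

    Hypothetical helper: linearly interpolates risk on an evenly spaced coverage
    grid spanning the curve's coverage range, then integrates with the trapezoidal rule.
    """
    if not points:
        raise ValueError("points must be non-empty")
    if n_bins < 2:
        raise ValueError("n_bins must be at least 2")
    points = sorted(points)
    xs = [c for c, _ in points]
    ys = [r for _, r in points]
    lo, hi = xs[0], xs[-1]
    grid = [lo + (hi - lo) * j / (n_bins - 1) for j in range(n_bins)]

    def interp(x: float) -> float:
        # Find the segment [xs[i], xs[i + 1]] containing x and interpolate linearly.
        for i in range(len(xs) - 1):
            if xs[i] <= x <= xs[i + 1]:
                if xs[i + 1] == xs[i]:
                    return ys[i]
                t = (x - xs[i]) / (xs[i + 1] - xs[i])
                return ys[i] + t * (ys[i + 1] - ys[i])
        # x at (or rounding just past) the right end of the curve.
        return ys[-1]

    values = [interp(x) for x in grid]
    step = (hi - lo) / (n_bins - 1)
    return sum(step * (values[j] + values[j + 1]) / 2 for j in range(n_bins - 1))


curve = [(0.25, 0.0), (0.5, 0.125), (0.75, 0.375), (1.0, 0.875)]
print(binned_auc(curve, n_bins=4))
print(binned_auc(curve, n_bins=100))
```

Larger n_bins follows the piecewise-linear curve more closely; smaller values trade accuracy for a cheaper, smoother summary.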
Examples
```python
import torch
from seapig.metric import RiskCoverageMetric

metric = RiskCoverageMetric(risk="generalized")

# Random predictions, targets, and per-sample scores for 50 samples.
preds = torch.rand(50, 1)
target = torch.rand(50, 1)
scores = torch.rand(50)

metric.update(preds, target, scores)
result = metric.compute()
# result contains keys: 'rc/auc_empirical', 'rc/auc_reference', 'rc/auc_excess'
```
Attributes
| Name | Description |
|---|---|
| full_state_update | Whether update needs the full accumulated state to compute per-batch values (False for this metric). |
full_state_update
full_state_update: bool = False
Class attribute indicating whether update needs access to the full accumulated metric state when computing per-batch values. It is False for this metric.
Methods
| Name | Description |
|---|---|
| compute() | Compute AUC summaries for the accumulated risk–coverage curve. |
| get_curve() | Return the last computed RiskCoverage object (or None if not computed). |
| reset() | Reset the accumulated scores and residuals. |
| update() | Store scores and residuals for later curve computation. |
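To show how state accumulates across update calls, here is a toy plain-Python stand-in for the update/compute/reset cycle. It assumes the default error_fn behaves as per-sample mean absolute error over each row, and it replaces the real AUC machinery with a right Riemann sum over the generalized-risk curve. ToyRiskCoverageAccumulator and default_error_fn are hypothetical names; the real compute returns a dict of three tensors rather than a single float.

```python
from typing import Callable, List, Optional, Sequence


def default_error_fn(
    preds: Sequence[Sequence[float]], target: Sequence[Sequence[float]]
) -> List[float]:
    # Hypothetical: per-sample mean absolute error, averaging |p - t| across each sample's values.
    return [
        sum(abs(p - t) for p, t in zip(row_p, row_t)) / len(row_p)
        for row_p, row_t in zip(preds, target)
    ]


class ToyRiskCoverageAccumulator:
    """Hypothetical stand-in illustrating the update/compute/reset state cycle."""

    def __init__(self, error_fn: Optional[Callable] = None) -> None:
        self.error_fn = error_fn if error_fn is not None else default_error_fn
        self.reset()

    def reset(self) -> None:
        # Drop everything accumulated by previous update calls.
        self.scores: List[float] = []
        self.residuals: List[float] = []

    def update(self, preds, target, scores) -> None:
        # Reduce this batch to per-sample residuals and append to the running state.
        residuals = self.error_fn(preds, target)
        if len(residuals) != len(scores):
            raise ValueError("expected one score per sample")
        self.scores.extend(scores)
        self.residuals.extend(residuals)

    def compute(self) -> float:
        n = len(self.scores)
        if n == 0:
            # No data accumulated yet: return a zero summary.
            return 0.0
        # Accept samples from most to least confident (ascending score).
        order = sorted(range(n), key=lambda i: self.scores[i])
        cumulative, area = 0.0, 0.0
        for i in order:
            cumulative += self.residuals[i]
            # Generalized risk at coverage k/n times a bin width of 1/n.
            area += (cumulative / n) * (1.0 / n)
        return area


acc = ToyRiskCoverageAccumulator()
acc.update([[1.0], [0.0]], [[0.0], [0.0]], [0.9, 0.1])  # first batch
acc.update([[0.5]], [[0.0]], [0.5])                      # second batch
print(acc.compute())
acc.reset()
print(acc.compute())
```

With torch tensors, a custom error_fn equivalent to this default could be written as lambda preds, target: (preds - target).abs().flatten(1).mean(dim=1), assuming the inputs have at least two dimensions (for example shape (N, 1)).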
compute()
Compute AUC summaries for the accumulated risk–coverage curve.
Usage
compute()
Returns a dict with keys rc/auc_empirical, rc/auc_reference, and rc/auc_excess. If no data has been accumulated, all three values are zero tensors on the metric's device.
get_curve()
Return the last computed RiskCoverage object (or None if not computed).
Usage
get_curve()
reset()
Reset the accumulated scores and residuals.
Usage
reset()
update()
Store scores and residuals for later curve computation.
Usage
update(preds, target, scores)
Parameters
preds: torch.Tensor
Model predictions. Passed to error_fn together with target to compute per-sample residuals.
target: torch.Tensor
Ground-truth targets. Passed to error_fn together with preds to compute per-sample residuals.
scores: torch.Tensor
Per-sample confidence scores; lower values indicate higher confidence.
See Also
- seapig.risk_coverage.risk_coverage: The underlying curve computation.
- seapig.risk_coverage.RiskCoverage: Container for curve results.