RiskCoverage

Container for risk-coverage results.

Usage

RiskCoverage(
    coverage,
    threshold,
    risk,
    reference,
    excess,
    risk_type,
    auc_empirical,
    auc_reference,
    auc_excess
)

Holds the coverage, score thresholds, empirical and reference risk curves, their difference (excess), and AUC metrics.
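To make these attributes concrete, here is a minimal sketch of how such curves are commonly computed from per-sample confidence scores and losses. It is not risk_coverage itself, and it rests on stated assumptions: higher scores mean more confident predictions, 'selective' risk averages the loss over the accepted samples while 'generalized' risk divides the same cumulative loss by the total number of samples, and the reference curve comes from an oracle that accepts samples in order of increasing loss. The function name risk_coverage_sketch is hypothetical, and the library may differ in details such as tie handling or threshold ordering.

```python
import torch


def risk_coverage_sketch(scores: torch.Tensor, losses: torch.Tensor, risk_type: str = "selective"):
    """Hypothetical sketch: coverage, thresholds, empirical/reference/excess risk curves."""
    if risk_type not in ("generalized", "selective"):
        raise ValueError("risk_type must be 'generalized' or 'selective'")
    n = scores.numel()
    losses = losses.to(torch.float64)

    # Accept samples from most to least confident (assumption: higher score = more confident).
    order = torch.argsort(scores, descending=True)
    threshold = scores[order]  # score of the last accepted sample at each coverage level
    k = torch.arange(1, n + 1, dtype=torch.float64)  # number of accepted samples
    coverage = k / n

    # Cumulative loss of the accepted set under the model's ranking ...
    cum_emp = torch.cumsum(losses[order], dim=0)
    # ... and under an oracle ranking that accepts the lowest losses first.
    cum_ref = torch.cumsum(torch.sort(losses).values, dim=0)

    # Assumed definitions: selective divides by k, generalized divides by n.
    denom = k if risk_type == "selective" else torch.full_like(k, float(n))
    risk = cum_emp / denom
    reference = cum_ref / denom
    excess = risk - reference  # non-negative: the k smallest losses sum to at most any k losses
    return coverage, threshold, risk, reference, excess


# Toy example: four samples with 0/1 losses.
scores = torch.tensor([0.9, 0.8, 0.3, 0.6])
losses = torch.tensor([0.0, 1.0, 1.0, 0.0])
coverage, threshold, risk, reference, excess = risk_coverage_sketch(scores, losses)
```

In this toy example the model accepts the second sample (loss 1) before the fourth (loss 0), so the excess risk is positive at coverage 0.5 and zero elsewhere.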

Attributes

coverage: torch.Tensor

Coverage values in [0, 1].

threshold: torch.Tensor

Sorted score thresholds used to compute coverage.

risk: torch.Tensor

Empirical risk at each coverage level.

reference: torch.Tensor

Reference (optimal) risk at each coverage level.

excess: torch.Tensor

Excess risk (empirical - reference).

risk_type: str

Either 'generalized' or 'selective'; see risk_coverage.

auc_empirical: torch.Tensor

Area under the empirical risk curve (trapezoidal rule).

auc_reference: torch.Tensor

Area under the reference risk curve (trapezoidal rule).

auc_excess: torch.Tensor

Area under the excess risk curve (trapezoidal rule).
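The three AUC attributes integrate a curve over coverage with the trapezoidal rule. The sketch below uses torch.trapezoid on toy tensors (illustrative values, not library output) and assumes all curves share one coverage grid, as the attributes above do. Because the trapezoidal rule is linear in the curve values, the excess AUC then equals the empirical AUC minus the reference AUC. Whether the library adds a point at coverage 0 before integrating is not stated here.

```python
import torch

# Toy curves on a shared coverage grid (illustrative values only).
coverage = torch.tensor([0.25, 0.50, 0.75, 1.00], dtype=torch.float64)
risk = torch.tensor([0.0, 0.5, 1 / 3, 0.5], dtype=torch.float64)
reference = torch.tensor([0.0, 0.0, 1 / 3, 0.5], dtype=torch.float64)
excess = risk - reference

# Trapezoidal rule: sum over intervals of (x[i+1] - x[i]) * (y[i] + y[i+1]) / 2.
auc_empirical = torch.trapezoid(risk, coverage)
auc_reference = torch.trapezoid(reference, coverage)
auc_excess = torch.trapezoid(excess, coverage)

# Linearity: AUC(risk - reference) == AUC(risk) - AUC(reference).
assert torch.isclose(auc_excess, auc_empirical - auc_reference)
```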

Methods

Name          Description
__init__()    Create a RiskCoverage container.
__repr__()    Short representation including AUCs and number of points.
plot()        Return a matplotlib Figure with the requested curves.

__init__()

Create a RiskCoverage container.

Usage

__init__(
    coverage,
    threshold,
    risk,
    reference,
    excess,
    risk_type,
    auc_empirical,
    auc_reference,
    auc_excess
)

All parameters correspond directly to the attributes of the same name. Instances are typically created by risk_coverage rather than by calling the constructor directly.
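Since each parameter maps one-to-one onto an attribute, the container can be pictured as a plain record. The sketch below uses a hypothetical dataclass, RiskCoverageStandIn, that mirrors the documented parameters; it is an illustration, not the library's class, and the real RiskCoverage may add validation or behavior not shown here. The toy values and the use of torch.trapezoid for the AUCs are assumptions for the example.

```python
from dataclasses import dataclass

import torch


@dataclass
class RiskCoverageStandIn:
    """Hypothetical stand-in with the same constructor parameters as RiskCoverage."""

    coverage: torch.Tensor
    threshold: torch.Tensor
    risk: torch.Tensor
    reference: torch.Tensor
    excess: torch.Tensor
    risk_type: str
    auc_empirical: torch.Tensor
    auc_reference: torch.Tensor
    auc_excess: torch.Tensor


# Toy curves on a two-point coverage grid.
coverage = torch.tensor([0.5, 1.0])
risk = torch.tensor([0.0, 0.5])
reference = torch.tensor([0.0, 0.5])
excess = risk - reference

# Every argument is passed by the name of the attribute it populates.
rc = RiskCoverageStandIn(
    coverage=coverage,
    threshold=torch.tensor([0.9, 0.2]),
    risk=risk,
    reference=reference,
    excess=excess,
    risk_type="selective",
    auc_empirical=torch.trapezoid(risk, coverage),
    auc_reference=torch.trapezoid(reference, coverage),
    auc_excess=torch.trapezoid(excess, coverage),
)
```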


__repr__()

Short representation including AUCs and number of points.

Usage

__repr__()

plot()

Return a matplotlib Figure with the requested curves.

Usage

plot(empirical=True, reference=True, excess=True, digits=4)

Parameters
empirical: bool = True

Whether to include the empirical risk curve.

reference: bool = True

Whether to include the reference (optimal) risk curve.

excess: bool = True

Whether to include the excess risk curve.

digits: int = 4

Number of decimal places to show for AUC values in the legend.

Returns

matplotlib.figure.Figure

A figure containing the plotted curves.

Raises

ImportError

If matplotlib is not installed.

ValueError

If all curve flags are False.
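
Examples

# rc: a RiskCoverage instance, typically returned by risk_coverage
fig = rc.plot(empirical=True, reference=False)  # draws the empirical and excess curves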
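The documented behavior of plot() (one flag per curve, AUCs in the legend rounded to digits places, ImportError without matplotlib, ValueError when no curve is selected) can be sketched as a standalone function. plot_sketch is hypothetical and is not the library's implementation. It accepts any object exposing the attributes listed above; here a types.SimpleNamespace with toy values stands in for a RiskCoverage. It builds a matplotlib.figure.Figure directly instead of going through pyplot, so no GUI backend is needed.

```python
from types import SimpleNamespace

import torch


def plot_sketch(rc, empirical=True, reference=True, excess=True, digits=4):
    """Hypothetical sketch of the documented plot() behavior; not the library's code."""
    try:
        # Imported lazily so a missing matplotlib surfaces as ImportError at call time.
        from matplotlib.figure import Figure
    except ImportError as err:
        raise ImportError("plot() requires matplotlib") from err

    # (flag, label, curve, AUC) for each curve the container holds.
    candidates = [
        (empirical, "empirical", rc.risk, rc.auc_empirical),
        (reference, "reference", rc.reference, rc.auc_reference),
        (excess, "excess", rc.excess, rc.auc_excess),
    ]
    selected = [c for c in candidates if c[0]]
    if not selected:
        raise ValueError("At least one of empirical, reference, excess must be True.")

    fig = Figure()  # created without pyplot, so it is not tied to a GUI backend
    ax = fig.subplots()
    x = rc.coverage.detach().cpu().numpy()
    for _, label, curve, auc in selected:
        # Legend entry shows the AUC rounded to `digits` decimal places.
        ax.plot(x, curve.detach().cpu().numpy(), label=f"{label} (AUC = {float(auc):.{digits}f})")
    ax.set_xlabel("coverage")
    ax.set_ylabel("risk")
    ax.legend()
    return fig


# Toy stand-in exposing the attributes plot_sketch reads.
rc = SimpleNamespace(
    coverage=torch.tensor([0.5, 1.0]),
    risk=torch.tensor([0.0, 0.5]),
    reference=torch.tensor([0.0, 0.5]),
    excess=torch.tensor([0.0, 0.0]),
    auc_empirical=torch.tensor(0.125),
    auc_reference=torch.tensor(0.125),
    auc_excess=torch.tensor(0.0),
)
fig = plot_sketch(rc, empirical=True, reference=False)
```

As in the documented example, passing reference=False while leaving excess at its default draws the empirical and excess curves.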

See Also