SelectiveInferenceTask
Wrap a trained LightningModule to attach selection results during inference.
Usage
SelectiveInferenceTask()
The wrapper calls the wrapped model in inference mode and combines its predictions with selection outputs produced by a provided ConfidenceScore.
Key behavior:
- The wrapped task must provide an .embed(x) method. The wrapper calls task.embed(x) to produce embeddings used by the score.
- The wrapped task is copied and set to eval() during initialization to avoid accidental training side effects.
- If the wrapped task defines test_metrics (a Metric or MetricCollection), it will be wrapped by SelectiveMetric so metrics are computed only on selected examples.
- If rc_metric (a RiskCoverageMetric) is provided, the wrapper will update it during test steps; the final risk-coverage values are available via get_risk_coverage_curve().
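A minimal sketch of what "metrics are computed only on selected examples" means, using plain Python lists instead of tensors. The helper `selective_accuracy` is hypothetical and is not part of seapig; the actual SelectiveMetric may handle the selection mask differently.

```python
from typing import Optional, Sequence


def selective_accuracy(
    predictions: Sequence[int],
    targets: Sequence[int],
    selected: Sequence[bool],
) -> Optional[float]:
    """Accuracy over the examples whose `selected` flag is True (illustrative only)."""
    kept = [(p, t) for p, t, s in zip(predictions, targets, selected) if s]
    if not kept:
        return None  # nothing was selected, so the metric is undefined
    return sum(p == t for p, t in kept) / len(kept)


preds = [1, 0, 1, 1]
labels = [1, 1, 1, 0]
mask = [True, False, True, False]

print(selective_accuracy(preds, labels, mask))        # 1.0 (only examples 0 and 2 count)
print(selective_accuracy(preds, labels, [True] * 4))  # 0.5 (all four examples count)
```

Rejecting the two low-confidence examples removes both errors here, which is the effect a selective metric is meant to expose.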
Parameters
- task: LightningModule
  A trained LightningModule whose forward(x) returns predictions. The module must implement embed(x) to produce embeddings for scoring.
- score: ConfidenceScore
  A seapig ConfidenceScore instance providing the ConfidenceScore.select method.
- input_key: INPUT_KEYS | None = None
  Key used to extract inputs from an incoming batch. If None (default), the first element of the batch is used (positional index 0). When a string is given it must be one of: 'image', 'input', 'images', 'inputs', 'x'.
- target_key: TARGET_KEYS | None = None
  Key used to extract targets from an incoming batch. If None (default), the second element of the batch is used (positional index 1). When a string is given it must be one of: 'mask', 'label', 'masks', 'labels', 'targets', 'target', 'y', 'y_true'.
- acc_test_outputs: bool = False
  If True, per-batch outputs (predictions merged with selection results) are accumulated in the test_outputs list for later inspection. If False (default), outputs are not accumulated and metrics are logged as usual.
- rc_metric: RiskCoverageMetric | None = None
  Optional RiskCoverageMetric that will be updated during testing.
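The input_key / target_key rule can be illustrated with a small helper. This is a sketch under assumptions: `extract_from_batch` is a hypothetical name, the strings stand in for tensors, and the wrapper's real extraction code is not shown in this reference.

```python
from typing import Any, Mapping, Optional, Sequence, Union

Batch = Union[Mapping[str, Any], Sequence[Any]]


def extract_from_batch(batch: Batch, key: Optional[str], position: int) -> Any:
    """Use `key` for dict-style batches, or fall back to a positional index."""
    if key is None:
        return batch[position]  # default: index 0 for inputs, index 1 for targets
    return batch[key]


# Tuple-style batch: the defaults (None) pick elements 0 and 1.
tuple_batch = ("images-tensor", "labels-tensor")
x = extract_from_batch(tuple_batch, key=None, position=0)
y = extract_from_batch(tuple_batch, key=None, position=1)

# Dict-style batch: keys taken from the allowed values listed above.
dict_batch = {"image": "images-tensor", "mask": "masks-tensor"}
x2 = extract_from_batch(dict_batch, key="image", position=0)
y2 = extract_from_batch(dict_batch, key="mask", position=1)
```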
Examples
from seapig import SelectiveInferenceTask
from seapig.scores import EuclideanScore
# `model` is a trained LightningModule that implements embed(x).
score = EuclideanScore()
# score.fit(X=train_embeddings)  # fit before wrapping
selective_task = SelectiveInferenceTask(task=model, score=score)
Methods
| Name | Description |
|---|---|
| forward() | Run the wrapped model and attach selection results. |
| get_risk_coverage_curve() | Return the latest computed risk-coverage curve, or None if not available. |
| on_test_epoch_end() | Log final computed test metrics once at the end of testing. |
| predict_step() | Perform prediction and return predictions with selection outputs. |
| test_step() | Perform a test step and include selection outputs. |
forward()
Run the wrapped model and attach selection results.
Usage
forward(x)
Steps performed:
- Calls the wrapped model. If a torch.Tensor is returned, it is placed under the key 'predictions'.
- Computes embeddings with task.embed(x) and calls score.select(embs).
- Merges the prediction mapping and the selection mapping and returns the result.
Returns
- dict[str, torch.Tensor]
  A dict containing the model predictions and the selection outputs returned by the score ('score' and 'selected').
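The three steps of forward() can be sketched with plain Python objects. Everything here is a hypothetical stand-in: `selective_forward`, `toy_model`, `toy_embed`, and `toy_select` are not seapig names, and lists replace tensors so the example runs without torch.

```python
from typing import Any, Callable, Dict, List


def selective_forward(
    model: Callable[[Any], Any],
    embed: Callable[[Any], Any],
    select: Callable[[Any], Dict[str, Any]],
    x: Any,
) -> Dict[str, Any]:
    """Mirror the documented steps: predict, embed and select, then merge."""
    out = model(x)
    # A bare output (a tensor in the real wrapper) is stored under 'predictions'.
    preds = out if isinstance(out, dict) else {"predictions": out}
    selection = select(embed(x))  # expected to hold 'score' and 'selected'
    return {**preds, **selection}


def toy_model(x: List[int]) -> List[int]:
    return [v * 2 for v in x]


def toy_embed(x: List[int]) -> List[float]:
    return [float(v) for v in x]


def toy_select(embs: List[float]) -> Dict[str, Any]:
    # Assumed rule for illustration: accept examples whose score is below 2.5.
    return {"score": embs, "selected": [e < 2.5 for e in embs]}


result = selective_forward(toy_model, toy_embed, toy_select, [1, 2, 3])
print(sorted(result))  # ['predictions', 'score', 'selected']
```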
get_risk_coverage_curve()
Return the latest computed risk-coverage curve, or None if not available.
Usage
get_risk_coverage_curve()
on_test_epoch_end()
Log final computed test metrics once at the end of testing.
Usage
on_test_epoch_end()
predict_step()
Perform prediction and return predictions with selection outputs.
Usage
predict_step(batch, batch_idx, dataloader_idx=0)
The wrapper calls forward(x) and returns the combined mapping produced by the wrapped model and the score. This mapping typically contains the model’s predictions and the selection outputs (e.g. score and selected).
test_step()
Perform a test step and include selection outputs.
Usage
test_step(batch, batch_idx, dataloader_idx=0)
Behavior:
- Extracts inputs and targets from the batch using input_key / target_key.
- Calls forward(x) to get predictions augmented with selection results.
- If a SelectiveMetric was created (from the wrapped task’s test_metrics), it is updated with (predictions, targets, selected_mask) and logged.
- If rc_metric is provided, it is updated with (predictions, targets, score) and its values are logged; the final risk-coverage values are logged in on_test_epoch_end().
- If acc_test_outputs was enabled at construction, the per-batch outputs are appended to the test_outputs list for later inspection.
Notes
This method does not return a value; metrics are updated and logged via Lightning’s logging utilities.
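To make the risk-coverage values tracked through rc_metric more concrete, the sketch below computes a curve from per-example errors and confidence values using one common definition: accept the most confident examples first and report the mean error of the accepted set at each coverage level. This is an assumption for illustration, not the documented behavior of RiskCoverageMetric; `risk_coverage_curve` is a hypothetical helper, and the sketch assumes that a higher value means more confidence (for a distance-like score the ordering would be reversed).

```python
from typing import List, Sequence, Tuple


def risk_coverage_curve(
    errors: Sequence[float],
    confidence: Sequence[float],
) -> List[Tuple[float, float]]:
    """Return (coverage, risk) pairs, accepting the most confident examples first."""
    order = sorted(range(len(errors)), key=lambda i: confidence[i], reverse=True)
    curve = []
    total_error = 0.0
    for kept, i in enumerate(order, start=1):
        total_error += errors[i]
        # coverage = fraction accepted; risk = mean error among accepted examples
        curve.append((kept / len(errors), total_error / kept))
    return curve


errors = [0.0, 1.0, 0.0, 1.0]      # 1.0 marks a wrong prediction
confidence = [0.9, 0.2, 0.8, 0.6]  # assumed: higher means more trustworthy
curve = risk_coverage_curve(errors, confidence)

for coverage, risk in curve:
    print(f"coverage={coverage:.2f} risk={risk:.2f}")
```

In this toy data the two most confident examples are correct, so risk stays at zero up to 50% coverage and rises only once less confident examples are accepted.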