Add ConformalRiskCalibrator and ConformalRiskPredictor for conformal risk control prediction sets#8939
Add ConformalRiskCalibrator and ConformalRiskPredictor for conformal risk control prediction sets#8939txmed82 wants to merge 2 commits into
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (4)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (3)
📝 WalkthroughWalkthroughAdds Estimated code review effort🎯 4 (Complex) | ⏱️ ~70 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
85b8a37 to
2fd9a20
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (4)
monai/metrics/conformal_risk.py (4)
253-253: ⚡ Quick winAdd
strict=Truetozip().Ensures
_scoresand_labelshave matching lengths; helps catch logic errors early.- for scores_i, labels_i in zip(self._scores, self._labels): + for scores_i, labels_i in zip(self._scores, self._labels, strict=True):🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/metrics/conformal_risk.py` at line 253, Add the `strict=True` parameter to the zip() call that iterates over self._scores and self._labels in the for loop. This will ensure that both iterables have matching lengths and raise a ValueError if they don't, helping catch logic errors early.Source: Linters/SAST tools
50-50: ⚡ Quick winDead import:
tqdmis never used.Neither
tqdmnorhas_tqdmappear elsewhere in the file.🧹 Remove unused import
-tqdm, has_tqdm = optional_import("tqdm", name="tqdm")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/metrics/conformal_risk.py` at line 50, The optional_import call for tqdm on line 50 imports both tqdm and has_tqdm variables, but neither of these variables is used anywhere else in the conformal_risk.py file. Remove the entire line containing the tqdm optional_import statement since it introduces dead code that serves no purpose.
277-279: ⚡ Quick winMissing docstring for
reset().Per coding guidelines, all definitions should have docstrings.
def reset(self) -> None: + """Clear accumulated calibration data and reset internal state.""" self._scores, self._labels = [], [] self._num_classes = None🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/metrics/conformal_risk.py` around lines 277 - 279, The reset() method is missing a docstring as required by coding guidelines. Add a docstring to the reset() method that describes its purpose, which is to clear the internal state by resetting the scores list, labels list, and num_classes attribute back to their initial values. Follow the project's docstring format conventions.Source: Coding guidelines
76-76: ⚖️ Poor tradeoffSilent clamping may mask upstream label bugs.
Out-of-range labels are silently clamped to
[0, C-1]. Consider logging a warning when clamping occurs, so upstream data issues surface during debugging.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/metrics/conformal_risk.py` at line 76, The labels_flat.clamp operation silently adjusts out-of-range labels without any logging, which can hide upstream data issues. Before applying the clamp operation on labels_flat, detect if any values fall outside the valid range [0, c-1] by checking the minimum and maximum values. If out-of-range values are detected, log a warning message that indicates clamping is occurring and optionally includes details about how many or what proportion of labels were affected. Then proceed with the clamp operation as currently written.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/metrics/conformal_risk.py`:
- Around line 59-65: The function `_set_from_threshold` in the file is dead code
that is never called, not exported, and has no tests associated with it. Remove
the entire function definition including the docstring to clean up the codebase
and reduce maintenance burden.
---
Nitpick comments:
In `@monai/metrics/conformal_risk.py`:
- Line 253: Add the `strict=True` parameter to the zip() call that iterates over
self._scores and self._labels in the for loop. This will ensure that both
iterables have matching lengths and raise a ValueError if they don't, helping
catch logic errors early.
- Line 50: The optional_import call for tqdm on line 50 imports both tqdm and
has_tqdm variables, but neither of these variables is used anywhere else in the
conformal_risk.py file. Remove the entire line containing the tqdm
optional_import statement since it introduces dead code that serves no purpose.
- Around line 277-279: The reset() method is missing a docstring as required by
coding guidelines. Add a docstring to the reset() method that describes its
purpose, which is to clear the internal state by resetting the scores list,
labels list, and num_classes attribute back to their initial values. Follow the
project's docstring format conventions.
- Line 76: The labels_flat.clamp operation silently adjusts out-of-range labels
without any logging, which can hide upstream data issues. Before applying the
clamp operation on labels_flat, detect if any values fall outside the valid
range [0, c-1] by checking the minimum and maximum values. If out-of-range
values are detected, log a warning message that indicates clamping is occurring
and optionally includes details about how many or what proportion of labels were
affected. Then proceed with the clamp operation as currently written.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e3510c75-6712-4d38-bb82-842dd7c3fb24
📒 Files selected for processing (4)
docs/source/metrics.rstmonai/metrics/__init__.pymonai/metrics/conformal_risk.pytests/metrics/test_conformal_risk.py
| def _set_from_threshold(scores: torch.Tensor, lam: float) -> torch.Tensor: | ||
| """Boolean prediction set ``{ y : score(y) <= lam }``, shape (..., C). | ||
|
|
||
| ``scores`` is the non-conformity score tensor with class as the last dim; | ||
| ``lam`` is a scalar threshold. Returns a bool tensor of the same shape. | ||
| """ | ||
| return scores <= lam |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if _set_from_threshold is used anywhere in the codebase
rg -n '_set_from_threshold' --type pyRepository: Project-MONAI/MONAI
Length of output: 173
🏁 Script executed:
cat -n monai/metrics/conformal_risk.py | head -80 | tail -25Repository: Project-MONAI/MONAI
Length of output: 1375
🏁 Script executed:
rg '_set_from_threshold' --type py -g '**test**'Repository: Project-MONAI/MONAI
Length of output: 45
🏁 Script executed:
rg '__all__' monai/metrics/conformal_risk.py -A 20Repository: Project-MONAI/MONAI
Length of output: 806
Remove unused _set_from_threshold function.
The function is never called, not exported, and untested. Dead code should be removed.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/metrics/conformal_risk.py` around lines 59 - 65, The function
`_set_from_threshold` in the file is dead code that is never called, not
exported, and has no tests associated with it. Remove the entire function
definition including the docstring to clean up the codebase and reduce
maintenance burden.
There was a problem hiding this comment.
Actionable comments posted: 7
🧹 Nitpick comments (2)
monai/metrics/conformal_risk.py (1)
277-279: ⚡ Quick winAdd a Google-style docstring for
reset.This public method is missing a docstring.
Proposed fix
def reset(self) -> None: + """Clear accumulated calibration data. + + Returns: + None. + """ self._scores, self._labels = [], [] self._num_classes = NoneAs per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/metrics/conformal_risk.py` around lines 277 - 279, The reset method in the ConformalRisk class is missing a Google-style docstring. Add a docstring to the reset method that describes its purpose of resetting internal state variables (_scores, _labels, and _num_classes to their initial values). Include a brief one-line summary followed by a more detailed description if needed. Since the method takes no arguments and returns None, focus on describing what state is being reset and why this operation is performed.Source: Coding guidelines
tests/metrics/test_conformal_risk.py (1)
187-187: ⚡ Quick winAssert the returned
probscontract.
probs_outis unpacked but unused. Assert it instead of leaving Ruff RUF059.Proposed fix
sets, mask, probs_out = predictor(probs) + assert_allclose(probs_out, probs, atol=0) self.assertEqual(sets.shape, (1, 3, 3))🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/metrics/test_conformal_risk.py` at line 187, The variable `probs_out` is unpacked from the predictor call but remains unused, which triggers the RUF059 linting rule. Replace the unused variable with an assertion that validates the returned `probs_out` meets the expected contract, such as checking its shape, type, or value ranges that align with the test's expectations. This both eliminates the unused variable warning and improves test coverage by verifying the predictor's output contract.Source: Linters/SAST tools
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/metrics/conformal_risk.py`:
- Around line 75-76: Instead of silently clamping invalid labels to valid class
indices using the clamp operation, add validation to reject or raise an error
when labels contain invalid values outside the expected range of 0 to c-1. This
needs to be fixed in two locations: the labels_flat.clamp call around line 76
and the similar operation around line 226. Replace the clamping logic with a
check that validates all labels fall within the valid range and either raises an
exception or skips invalid samples, ensuring that corrupted labels do not
silently propagate through the loss, coverage, and lambda_hat calculations.
- Around line 259-262: The code materializes the entire lambda grid evaluation
at once in the line with `sets = scores_i.unsqueeze(0) <= lam_grid.view(-1, 1,
1)`, which creates a tensor of shape (n_lam, P_i, C) that can cause
out-of-memory errors for large 3D volumes. Refactor this block to chunk the
lambda grid into smaller batches and iterate through them in a loop, processing
each chunk separately and accumulating the risk contribution from each chunk
into risk_sum. This prevents materializing the full tensor at once while still
computing the correct cumulative risk.
- Around line 41-48: The `__all__` list in the conformal_risk.py file has items
that are not in alphabetical order, which violates Ruff's sorting requirements.
Specifically, swap the positions of "compute_set_size" and "compute_coverage" in
the `__all__` list so that "compute_coverage" appears before "compute_set_size",
matching alphabetical order.
- Around line 193-197: The validation of lam_grid is incomplete and allows two
problematic cases: empty grids and unsorted grids. In the validation condition
that checks if lam_grid is a 1-D tensor with values in [0, 1], add two
additional checks: verify that lam_grid is not empty (check the length or size),
and verify that lam_grid is sorted in ascending order (you can use
torch.is_nonincreasing or check if the differences between consecutive elements
are non-negative). Update the error message in the ValueError to reflect all
validation requirements. This will prevent crashes at line 271 and ensure
correct behavior at line 273 when finding the infimum.
- Line 262: After calling self.loss_fn with sets_shaped and labels_rep on the
line where risk_sum is updated, add validation to ensure the output is valid.
Validate that the returned loss tensor has the expected shape matching
labels_rep, contains no NaN values, and all values are within the valid range of
0 to 1. If any validation fails, raise a descriptive error to prevent silent
failures in the CRC bound calculation. This validation should occur immediately
after the loss_fn call and before accumulating the result into risk_sum.
- Around line 319-323: The set_threshold method currently accepts tensors of
arbitrary shape, but it should enforce that the lam parameter is a scalar tensor
with a value in the range [0, 1]. After the existing isinstance check for
torch.Tensor, add validation to ensure the tensor has a scalar shape (empty
dimensions) and that its value falls within [0, 1], raising a ValueError if
either condition is violated. This prevents invalid broadcasts over spatial
dimensions that could produce incorrect conformal sets.
- Line 253: In the for loop iterating over self._scores and self._labels using
zip(), add the parameter strict=True to the zip() call to enforce that both
sequences have the same length. This will catch any synchronization issues
between these two buffers if they become out of sync in the future due to
unequal append operations.
---
Nitpick comments:
In `@monai/metrics/conformal_risk.py`:
- Around line 277-279: The reset method in the ConformalRisk class is missing a
Google-style docstring. Add a docstring to the reset method that describes its
purpose of resetting internal state variables (_scores, _labels, and
_num_classes to their initial values). Include a brief one-line summary followed
by a more detailed description if needed. Since the method takes no arguments
and returns None, focus on describing what state is being reset and why this
operation is performed.
In `@tests/metrics/test_conformal_risk.py`:
- Line 187: The variable `probs_out` is unpacked from the predictor call but
remains unused, which triggers the RUF059 linting rule. Replace the unused
variable with an assertion that validates the returned `probs_out` meets the
expected contract, such as checking its shape, type, or value ranges that align
with the test's expectations. This both eliminates the unused variable warning
and improves test coverage by verifying the predictor's output contract.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7ae55399-50d4-479b-8bf4-a56352f5e0cf
📒 Files selected for processing (4)
docs/source/metrics.rstmonai/metrics/__init__.pymonai/metrics/conformal_risk.pytests/metrics/test_conformal_risk.py
✅ Files skipped from review due to trivial changes (2)
- monai/metrics/init.py
- docs/source/metrics.rst
…risk control prediction sets Signed-off-by: Colin Son <txmed82@users.noreply.github.com>
… lambda loop; sort __all__; scalar/range check set_threshold; zip(strict=True); docstrings - conformal_risk.py: reject out-of-range labels instead of silent clamp (lines ~76, ~226) - conformal_risk.py: chunk the lambda grid in calibrate() to avoid materializing (n_lam, P_i, C) at once - conformal_risk.py: validate lam_grid is non-empty and sorted ascending (prevents IndexError and wrong infimum) - conformal_risk.py: validate loss_fn output shape and NaN after each call - conformal_risk.py: enforce set_threshold lam is scalar in [0, 1] - conformal_risk.py: zip(strict=True) over _scores/_labels - conformal_risk.py: alphabetical __all__ (RUF022), reset() docstring - test_conformal_risk.py: assert predictor returns input probs unchanged (RUF059) Signed-off-by: Colin Son <txmed82@users.noreply.github.com>
|
Reviewed this alongside #8938 — the risk-control math is correct, including the parts that are easy to get subtly wrong:
Controlling an image-level loss (rather than per-voxel) is also the principled choice here: image exchangeability is exactly what the CRC theorem needs, so within-image voxel correlation doesn't break the guarantee (unlike a per-voxel marginal treatment). One robustness item and minor nits: 1. A custom non-monotone 2. Nits.
Solid PR — glad to see CRC going into MONAI. |
Fixes #8935 (part 2 of 2).
Description
Adds image-level conformal risk control to
monai/metrics/, the loss-boundedcounterpart to the marginal-coverage
ConformalPredictorin #8938.monai/metrics/conformal_risk.py:ConformalRiskCalibrator: calibrate one global thresholdlambda_hatthatbounds an image-level loss (
miscoverageorfalse_negative) on a held-outsplit via the finite-sample selection of Angelopoulos et al. 2022, giving
E[L] <= alpha. Handles classification and variable-size segmentation;include_background=Falsedrops background voxels.ConformalRiskPredictor: applylambda_hatand return the prediction-setmask plus a per-voxel uncertainty mask (set size > 1).
Coverage/SetSize:CumulativeIterationMetrics for the coverage vs.set-size trade-off.
API docs added to
docs/source/metrics.rst.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.