-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Fix DiceFocalLoss to apply activation before removing background #8947
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -803,9 +803,9 @@ def __init__( | |||||
| """ | ||||||
| super().__init__() | ||||||
| self.dice = DiceLoss( | ||||||
| sigmoid=sigmoid, | ||||||
| softmax=softmax, | ||||||
| other_act=other_act, | ||||||
| sigmoid=False, | ||||||
| softmax=False, | ||||||
| other_act=None, | ||||||
| squared_pred=squared_pred, | ||||||
| jaccard=jaccard, | ||||||
| reduction=reduction, | ||||||
|
|
@@ -822,6 +822,9 @@ def __init__( | |||||
| self.lambda_focal = lambda_focal | ||||||
| self.to_onehot_y = to_onehot_y | ||||||
| self.include_background = include_background | ||||||
| self.sigmoid = sigmoid | ||||||
| self.softmax = softmax | ||||||
| self.other_act = other_act | ||||||
|
|
||||||
| def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||||||
| """ | ||||||
|
|
@@ -846,6 +849,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |||||
| else: | ||||||
| target = one_hot(target, num_classes=n_pred_ch) | ||||||
|
|
||||||
| # Apply activation before removing background to ensure softmax/sigmoid works correctly | ||||||
| if self.sigmoid: | ||||||
| input = torch.sigmoid(input) | ||||||
| elif self.softmax: | ||||||
| if n_pred_ch == 1: | ||||||
| warnings.warn("single channel prediction, `softmax=True` ignored.") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win Set Ruff flags this Proposed fix- warnings.warn("single channel prediction, `softmax=True` ignored.")
+ warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.15.18)[warning] 857-857: No explicit Set (B028) 🤖 Prompt for AI AgentsSource: Linters/SAST tools |
||||||
| else: | ||||||
| input = torch.softmax(input, 1) | ||||||
| elif self.other_act is not None: | ||||||
| input = self.other_act(input) | ||||||
|
|
||||||
| if not self.include_background: | ||||||
| if n_pred_ch == 1: | ||||||
| warnings.warn("single channel prediction, `include_background=False` ignored.") | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Preserve activation exclusivity validation.
With
DiceLossactivation disabled, configs likesigmoid=True, softmax=Trueare now silently accepted and resolved byif/elif.Proposed fix
self.to_onehot_y = to_onehot_y self.include_background = include_background + if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: + raise ValueError("Only one of sigmoid=True, softmax=True, or other_act may be specified.") self.sigmoid = sigmoid self.softmax = softmax self.other_act = other_actAs per path instructions, "Examine code for logical error or inconsistencies".
📝 Committable suggestion
🤖 Prompt for AI Agents
Source: Path instructions