diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py index 7d7e447c54..ab580d0160 100644 --- a/monai/losses/cldice.py +++ b/monai/losses/cldice.py @@ -20,9 +20,11 @@ from monai.losses.dice import DiceLoss from monai.networks import one_hot -from monai.utils import LossReduction +from monai.utils import LossReduction, optional_import from monai.utils.deprecate_utils import deprecated_arg +binary_thinning_3d, _has_thinning = optional_import("binary_thinning_3d") + def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore """ @@ -129,6 +131,7 @@ def __init__( softmax: bool = False, other_act: Callable | None = None, reduction: LossReduction | str = LossReduction.MEAN, + use_hard_target: bool = False, ) -> None: """ Args: @@ -151,6 +154,8 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + use_hard_target: if True, use the exact CUDA 3D binary thinning for the target skeleton instead of soft skeletonization. + Requires binary_thinning_3d_cuda package and a CUDA 3D target. Defaults to False. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -181,6 +186,7 @@ def __init__( self.sigmoid = sigmoid self.softmax = softmax self.other_act = other_act + self.use_hard_target = use_hard_target @deprecated_arg("y_pred", since="1.5", removed="1.8", new_name="input", msg_suffix="please use `input` instead.") @deprecated_arg("y_true", since="1.5", removed="1.8", new_name="target", msg_suffix="please use `target` instead.") @@ -226,7 +232,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") skel_pred = soft_skel(input, self.iter) - skel_true = soft_skel(target, self.iter) + if self.use_hard_target: + if not (target.dim() == 5 and _has_thinning and target.is_cuda): + raise ValueError( + "use_hard_target=True but conditions not met. " + "Requires 5D CUDA tensor and binary_thinning_3d_cuda package." + ) + skel_true = (target > 0).to(torch.uint8).contiguous() + for b in range(target.shape[0]): + for c in range(target.shape[1]): + binary_thinning_3d.binary_thinning(skel_true[b, c], 0) + skel_true = skel_true.to(target.dtype) + else: + skel_true = soft_skel(target, self.iter) # Compute per-batch clDice by reducing over channel and spatial dimensions # reduce_axis includes all dimensions except batch (dim 0) @@ -279,6 +297,7 @@ def __init__( softmax: bool = False, other_act: Callable | None = None, reduction: LossReduction | str = LossReduction.MEAN, + use_hard_target: bool = False, ) -> None: """ Args: @@ -304,6 +323,8 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + use_hard_target: if True, use the exact CUDA 3D binary thinning for the target skeleton instead of soft skeletonization. + Requires MONAI C++ extensions and a 3D target. Defaults to False. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -336,6 +357,7 @@ def __init__( softmax=softmax, other_act=other_act, reduction=reduction, + use_hard_target=use_hard_target, ) self.alpha = alpha self.to_onehot_y = to_onehot_y diff --git a/setup.cfg b/setup.cfg index d987141d0b..db4bd22fd4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,6 +90,7 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0, <5.3.0 + binary_thinning_3d_cuda nibabel = nibabel ninja = @@ -179,6 +180,8 @@ huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0, <5.3.0 +binary_thinning = + binary_thinning_3d_cuda # segment-anything = # segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything diff --git a/tests/losses/test_cldice_loss.py b/tests/losses/test_cldice_loss.py index cb17cb81ad..71a036109b 100644 --- a/tests/losses/test_cldice_loss.py +++ b/tests/losses/test_cldice_loss.py @@ -85,6 +85,19 @@ def test_cuda(self): result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda()) np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + @skip_if_no_cuda + def test_hard_target(self): + """Test SoftclDiceLoss with use_hard_target=True using binary thinning on 3D CUDA tensors.""" + # Skip if binary_thinning not available + from monai.losses.cldice import _has_thinning + if not _has_thinning: + self.skipTest("binary_thinning_3d_cuda not available") + + loss = SoftclDiceLoss(use_hard_target=True) + # MUST BE 3D for hard target logic to trigger! (shape: B, N, H, W, D) + result = loss(ONES_3D["input"].cuda(), ONES_3D["target"].cuda()) + np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + def test_reduction_shapes(self): input_tensor = torch.ones((4, 2, 8, 8)) target = torch.ones((4, 2, 8, 8))