diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 37e4468d4b..80da8bafcc 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -44,9 +44,58 @@ def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: return (x[slicing_s] - x[slicing_e]) / 2.0 +def spatial_gradient_squared(x: torch.Tensor, dim_1: int, dim_2: int) -> torch.Tensor: + """ + Calculate the second-order partial derivative of ``x`` with respect to spatial dims + ``dim_1`` and ``dim_2`` using compact central finite differences. + + For ``dim_1 == dim_2`` the pure second derivative uses the ``[1, -2, 1]`` stencil: + ``d2x[i] = x[i+1] - 2 * x[i] + x[i-1]``. + + For ``dim_1 != dim_2`` the mixed partial uses the compact 4-point stencil: + ``d2x[i, j] = (x[i+1, j+1] - x[i+1, j-1] - x[i-1, j+1] + x[i-1, j-1]) / 4``. + + Every spatial dimension is sliced to ``[1:-1]`` so the output shape is independent of + ``(dim_1, dim_2)``; this lets terms be summed together. Requires ``x.shape[d] > 2`` + for every spatial dim ``d``. + + Args: + x: the shape should be BCH(WD). + dim_1: first spatial dimension index. + dim_2: second spatial dimension index. + + Returns: + Tensor with batch and channel axes preserved and every spatial axis sliced to + ``[1:-1]``. + """ + slice_inner = slice(1, -1) + slice_plus = slice(2, None) + slice_minus = slice(None, -2) + slice_all = slice(None) + + def _idx(overrides: dict) -> list: + out: list = [slice_all, slice_all] + for d in range(2, x.ndim): + out.append(overrides.get(d, slice_inner)) + return out + + if dim_1 == dim_2: + return x[_idx({dim_1: slice_plus})] - 2 * x[_idx({})] + x[_idx({dim_1: slice_minus})] + return ( + x[_idx({dim_1: slice_plus, dim_2: slice_plus})] + - x[_idx({dim_1: slice_plus, dim_2: slice_minus})] + - x[_idx({dim_1: slice_minus, dim_2: slice_plus})] + + x[_idx({dim_1: slice_minus, dim_2: slice_minus})] + ) / 4.0 + + class BendingEnergyLoss(_Loss): """ - Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference. + Calculate the bending energy based on second-order differentiation of ``pred``. + + Pure second derivatives use the compact ``[1, -2, 1]`` stencil; mixed partials use a + compact 4-point central scheme. Both span three voxels per axis, so each spatial + dimension of ``pred`` only needs to be greater than 2. For more information, see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb. @@ -79,41 +128,41 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. ValueError: When ``pred`` is not 3-d, 4-d or 5-d. - ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2. ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. """ if pred.ndim not in [3, 4, 5]: raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") for i in range(pred.ndim - 2): - if pred.shape[-i - 1] <= 4: - raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") + if pred.shape[-i - 1] <= 2: + raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") if pred.shape[1] != pred.ndim - 2: raise ValueError( f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " f"does not match number of spatial dimensions, {pred.ndim - 2}" ) - # first order gradient - first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] - # spatial dimensions in a shape suited for broadcasting below if self.normalize: spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) - energy = torch.tensor(0) - for dim_1, g in enumerate(first_order_gradient): - dim_1 += 2 + # Initialize on pred.device so a GPU `pred` does not get added to a CPU + # accumulator, and as a float so an integer-dtype `pred` still produces a + # floating-point energy (the compact pure-derivative stencil has no + # division, so a Long input would otherwise propagate as Long and fail + # `torch.mean` at the reduction step). + energy = torch.tensor(0.0, device=pred.device) + for dim_1 in range(2, pred.ndim): + d2 = spatial_gradient_squared(pred, dim_1, dim_1) if self.normalize: - g *= pred.shape[dim_1] / spatial_dims - energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2 - else: - energy = energy + spatial_gradient(g, dim_1) ** 2 + d2 = d2 * (pred.shape[dim_1] ** 2 / spatial_dims) + energy = energy + d2**2 for dim_2 in range(dim_1 + 1, pred.ndim): + d2_mixed = spatial_gradient_squared(pred, dim_1, dim_2) if self.normalize: - energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2 - else: - energy = energy + 2 * spatial_gradient(g, dim_2) ** 2 + d2_mixed = d2_mixed * (pred.shape[dim_1] * pred.shape[dim_2] / spatial_dims) + energy = energy + 2 * d2_mixed**2 if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average diff --git a/tests/losses/deform/test_bending_energy.py b/tests/losses/deform/test_bending_energy.py index 2e8ab32dbd..5e713b3e47 100644 --- a/tests/losses/deform/test_bending_energy.py +++ b/tests/losses/deform/test_bending_energy.py @@ -23,6 +23,7 @@ TEST_CASES = [ [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + [{}, {"pred": torch.ones((1, 3, 3, 3, 3), device=device)}, 0.0], [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0], [ {"normalize": False}, @@ -64,11 +65,11 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) with self.assertRaisesRegex(ValueError, "All spatial dimensions"): - loss.forward(torch.ones((1, 3, 4, 5, 5), device=device)) + loss.forward(torch.ones((1, 3, 2, 5, 5), device=device)) with self.assertRaisesRegex(ValueError, "All spatial dimensions"): - loss.forward(torch.ones((1, 3, 5, 4, 5))) + loss.forward(torch.ones((1, 3, 5, 2, 5))) with self.assertRaisesRegex(ValueError, "All spatial dimensions"): - loss.forward(torch.ones((1, 3, 5, 5, 4))) + loss.forward(torch.ones((1, 3, 5, 5, 2))) # number of vector components unequal to number of spatial dims with self.assertRaisesRegex(ValueError, "Number of vector components"):