diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index de10e4b2d4..410ee1e0bc 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -233,9 +233,10 @@ def __init__( self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) self.num_bins = num_bins self.kernel_type = kernel_type + self.bin_centers: torch.Tensor if self.kernel_type == "gaussian": self.preterm = 1 / (2 * sigma**2) - self.bin_centers = bin_centers[None, None, ...] + self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) diff --git a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py index ff7851ed1c..41816a4410 100644 --- a/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py +++ b/tests/losses/image_dissimilarity/test_global_mutual_information_loss.py @@ -116,6 +116,15 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0. class TestGlobalMutualInformationLossIll(unittest.TestCase): + def test_gaussian_bin_centers_registered_buffer(self): + loss = GlobalMutualInformationLoss(kernel_type="gaussian", num_bins=16) + + self.assertIn("bin_centers", dict(loss.named_buffers())) + self.assertFalse(loss.bin_centers.requires_grad) + + loss = loss.to(dtype=torch.float64) + self.assertEqual(loss.bin_centers.dtype, torch.float64) + @parameterized.expand( [ (torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)), # mismatched_simple_dims