diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 7eb6bde863..635a3e75ce 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -19,11 +19,18 @@ from monai.utils import optional_import from monai.utils.enums import StrEnum +# Valid model name to download from the repository +HF_MONAI_MODELS = frozenset( + ("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50") +) + LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") -class PercetualNetworkType(StrEnum): +class PerceptualNetworkType(StrEnum): + """Types of neural networks that are supported by perceptual loss.""" + alex = "alex" vgg = "vgg" squeeze = "squeeze" @@ -81,7 +88,7 @@ class PerceptualLoss(nn.Module): def __init__( self, spatial_dims: int, - network_type: str = PercetualNetworkType.alex, + network_type: str = PerceptualNetworkType.alex, is_fake_3d: bool = True, fake_3d_ratio: float = 0.5, cache_dir: str | None = None, @@ -95,18 +102,26 @@ def __init__( if spatial_dims not in [2, 3]: raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") - if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: - raise ValueError( - "MedicalNet networks are only compatible with ``spatial_dims=3``." - "Argument is_fake_3d must be set to False." - ) - - if channel_wise and "medicalnet_" not in network_type: + network_type = network_type.lower() + + # Strict validation for MedicalNet + if "medicalnet_" in network_type: + if spatial_dims == 2 or is_fake_3d: + raise ValueError( + "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False." + ) + if not channel_wise: + warnings.warn( + "MedicalNet networks support channel-wise loss. 
Consider setting channel_wise=True.", stacklevel=2 + ) + + # Channel-wise only for MedicalNet + elif channel_wise: raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") - if network_type.lower() not in list(PercetualNetworkType): + if network_type.lower() not in list(PerceptualNetworkType): raise ValueError( - f"Unrecognised criterion entered for Perceptual Loss. Must be one in: {', '.join(PercetualNetworkType)}" + f"Unrecognised criterion entered for Perceptual Loss. Must be one in: {', '.join(PerceptualNetworkType)}" ) if cache_dir: torch.hub.set_dir(cache_dir) @@ -117,12 +132,16 @@ def __init__( self.spatial_dims = spatial_dims self.perceptual_function: nn.Module + + # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity( - net=network_type, verbose=False, channel_wise=channel_wise + net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir ) elif "radimagenet_" in network_type: - self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) + self.perceptual_function = RadImageNetPerceptualSimilarity( + net=network_type, verbose=False, cache_dir=cache_dir + ) elif network_type == "resnet50": self.perceptual_function = TorchvisionModelPerceptualSimilarity( net=network_type, @@ -131,7 +150,9 @@ def __init__( pretrained_state_dict_key=pretrained_state_dict_key, ) else: + # VGG, AlexNet and SqueezeNet are independently handled by LPIPS. self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) + self.is_fake_3d = is_fake_3d self.fake_3d_ratio = fake_3d_ratio self.channel_wise = channel_wise @@ -203,22 +224,31 @@ class MedicalNetPerceptualSimilarity(nn.Module): """ Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. 
"Med3D: Transfer Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from - "Warvito/MedicalNet-models". + "Project-MONAI/perceptual-models". Args: net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``. verbose: if false, mute messages from torch Hub load function. channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels. - Defaults to ``False``. + Defaults to ``False``. + cache_dir: path to cache directory to save the pretrained network weights. """ def __init__( - self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False + self, + net: str = "medicalnet_resnet10_23datasets", + verbose: bool = False, + channel_wise: bool = False, + cache_dir: str | None = None, ) -> None: super().__init__() - torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True) + if net not in HF_MONAI_MODELS: + raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.") + + self.model = torch.hub.load( + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True + ) self.eval() self.channel_wise = channel_wise @@ -267,7 +297,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: for i in range(input.shape[1]): l_idx = i * feats_per_ch r_idx = (i + 1) * feats_per_ch - results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1) + results[:, i, ...] 
= feats_diff[:, l_idx:r_idx, ...].sum(dim=1) else: results = feats_diff.sum(dim=1, keepdim=True) @@ -296,17 +326,22 @@ class RadImageNetPerceptualSimilarity(nn.Module): """ Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class - uses torch Hub to download the networks from "Warvito/radimagenet-models". + uses torch Hub to download the networks from "Project-MONAI/perceptual-models". Args: net: {``"radimagenet_resnet50"``} Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``. verbose: if false, mute messages from torch Hub load function. + cache_dir: path to cache directory to save the pretrained network weights. """ - def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: + def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: super().__init__() - self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True) + if net not in HF_MONAI_MODELS: + raise ValueError(f"Invalid download model name '{net}'. 
Must be one of: {', '.join(HF_MONAI_MODELS)}.") + self.model = torch.hub.load( + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True + ) self.eval() for param in self.parameters(): diff --git a/setup.py b/setup.py index 530b869ea2..2569bbe03a 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ import re import sys import warnings +from typing import Any, cast from packaging import version from setuptools import find_packages, setup @@ -146,6 +147,6 @@ def get_cmds(): cmdclass=get_cmds(), packages=find_packages(exclude=("docs", "examples", "tests", "tests.*")), zip_safe=False, - package_data={"monai": ["py.typed", *jit_extension_source]}, # type: ignore[arg-type] + package_data=cast(Any, {"monai": ["py.typed", *jit_extension_source]}), ext_modules=get_extensions(), ) diff --git a/tests/losses/test_perceptual_loss.py b/tests/losses/test_perceptual_loss.py index b406bd3c69..8d94fdc1ae 100644 --- a/tests/losses/test_perceptual_loss.py +++ b/tests/losses/test_perceptual_loss.py @@ -116,6 +116,16 @@ def test_medicalnet_on_2d_data(self, network_type): with self.assertRaises(ValueError): PerceptualLoss(spatial_dims=2, network_type=network_type) + @parameterized.expand(["squeeze", "alex", "vgg", "radimagenet_resnet50", "resnet50"]) + def test_channel_wise_with_non_medicalnet(self, network_type): + with self.assertRaises(ValueError): + PerceptualLoss(spatial_dims=2, network_type=network_type, channel_wise=True) + + @parameterized.expand(["squeeze", "alex", "vgg", "radimagenet_resnet50", "resnet50"]) + def test_non_medicalnet_3d_without_fake_3d(self, network_type): + with self.assertRaises(ValueError): + PerceptualLoss(spatial_dims=3, network_type=network_type, is_fake_3d=False) + if __name__ == "__main__": unittest.main()