diff --git a/tests/networks/layers/test_get_layers.py b/tests/networks/layers/test_get_layers.py index 5c020892ed..3114ed7b0b 100644 --- a/tests/networks/layers/test_get_layers.py +++ b/tests/networks/layers/test_get_layers.py @@ -11,12 +11,26 @@ from __future__ import annotations +import re import unittest from parameterized import parameterized from monai.networks.layers import get_act_layer, get_dropout_layer, get_norm_layer + +def _strip_bias_field(text: str) -> str: + """Strip the optional PyTorch >= 2.13 ``, bias=True|False`` repr fragment. + + Args: + text: Layer string representation to normalize. + + Returns: + The representation with any ``, bias=True|False`` removed. + """ + return re.sub(r",\s*bias=(?:True|False)", "", text) + + TEST_CASE_NORM = [ [{"name": ("group", {"num_groups": 1})}, "GroupNorm(1, 1, eps=1e-05, affine=True)"], [ @@ -41,7 +55,7 @@ class TestGetLayers(unittest.TestCase): @parameterized.expand(TEST_CASE_NORM) def test_norm_layer(self, input_param, expected): layer = get_norm_layer(**input_param) - self.assertEqual(f"{layer}", expected) + self.assertEqual(_strip_bias_field(f"{layer}"), _strip_bias_field(expected)) @parameterized.expand(TEST_CASE_ACT) def test_act_layer(self, input_param, expected):