Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions tests/data/utils/test_compute_shape_offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from monai.data.utils import compute_shape_offset


class TestComputeShapeOffsetRegression(unittest.TestCase):
"""Regression tests for `compute_shape_offset` input-shape handling."""
class TestComputeShapeOffset(unittest.TestCase):
"""Unit tests for :func:`monai.data.utils.compute_shape_offset`."""

def test_pytorch_size_input(self):
"""Validate `torch.Size` input produces expected shape and offset.
Expand All @@ -42,6 +42,39 @@ def test_pytorch_size_input(self):
# 3. Prove it successfully processed the shape by checking its length
self.assertEqual(len(shape), 3)

def setUp(self):
"""Set up a 4x4 identity affine used across all test cases."""
self.affine = np.eye(4)

def test_numpy_array_input(self):
"""Verify compute_shape_offset accepts a numpy array as spatial_shape."""
shape = np.array([64, 64, 64])
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
self.assertEqual(len(out_shape), 3)

def test_list_input(self):
"""Verify compute_shape_offset accepts a plain list as spatial_shape."""
shape = [64, 64, 64]
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
self.assertEqual(len(out_shape), 3)

def test_torch_tensor_input(self):
"""Verify compute_shape_offset accepts a torch.Tensor as spatial_shape.

This path broke in PyTorch >= 2.9 because np.array() relied on the
non-tuple sequence indexing protocol that PyTorch removed. Wrapping with
tuple() fixes it.
"""
shape = torch.tensor([64, 64, 64])
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
self.assertEqual(len(out_shape), 3)

def test_identity_affines_preserve_shape(self):
"""Verify that identity in/out affines produce an output shape matching the input."""
shape = torch.tensor([32, 48, 16])
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
np.testing.assert_allclose(np.array(out_shape, dtype=float), shape.numpy().astype(float), atol=1e-5)


if __name__ == "__main__":
unittest.main()
Loading