diff --git a/tests/unit/linalg/test_gramian.py b/tests/unit/linalg/test_gramian.py index 4a02a9f4..6fc22512 100644 --- a/tests/unit/linalg/test_gramian.py +++ b/tests/unit/linalg/test_gramian.py @@ -1,6 +1,7 @@ from pytest import mark +from torch.testing import assert_close from utils.asserts import assert_is_psd_matrix -from utils.tensors import randn_ +from utils.tensors import randn_, tensor_ from torchjd._linalg import compute_gramian, is_matrix, normalize, regularize @@ -25,6 +26,59 @@ def test_gramian_is_psd(shape: list[int]): assert_is_psd_matrix(gramian) +def test_compute_gramian_scalar_input_0(): + t = tensor_(5.0) + gramian = compute_gramian(t, contracted_dims=0) + expected = tensor_(25.0) + + assert_close(gramian, expected) + + +def test_compute_gramian_vector_input_0(): + t = tensor_([2.0, 3.0]) + gramian = compute_gramian(t, contracted_dims=0) + expected = tensor_([[4.0, 6.0], [6.0, 9.0]]) + + assert_close(gramian, expected) + + +def test_compute_gramian_vector_input_1(): + t = tensor_([2.0, 3.0]) + gramian = compute_gramian(t, contracted_dims=1) + expected = tensor_(13.0) + + assert_close(gramian, expected) + + +def test_compute_gramian_matrix_input_0(): + t = tensor_([[1.0, 2.0], [3.0, 4.0]]) + gramian = compute_gramian(t, contracted_dims=0) + expected = tensor_( + [ + [[[1.0, 3.0], [2.0, 4.0]], [[2.0, 6.0], [4.0, 8.0]]], + [[[3.0, 9.0], [6.0, 12.0]], [[4.0, 12.0], [8.0, 16.0]]], + ] + ) + + assert_close(gramian, expected) + + +def test_compute_gramian_matrix_input_1(): + t = tensor_([[1.0, 2.0], [3.0, 4.0]]) + gramian = compute_gramian(t, contracted_dims=1) + expected = tensor_([[5.0, 11.0], [11.0, 25.0]]) + + assert_close(gramian, expected) + + +def test_compute_gramian_matrix_input_2(): + t = tensor_([[1.0, 2.0], [3.0, 4.0]]) + gramian = compute_gramian(t, contracted_dims=2) + expected = tensor_(30.0) + + assert_close(gramian, expected) + + @mark.parametrize( "shape", [