From 6fd348d41c96642c44fec4209e9b128aa88a1f62 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Tue, 19 May 2026 04:58:30 +0000 Subject: [PATCH 1/8] missing pybind11 --- openequivariance/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index e58669c..6824635 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -17,7 +17,8 @@ dependencies = [ "setuptools", "ninja", "jinja2", - "numpy" + "numpy", + "pybind11" ] readme = "README.md" From 5ddccf6c1ea8389e056b8f58359d804ba36039d5 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Tue, 19 May 2026 04:59:51 +0000 Subject: [PATCH 2/8] add bindings --- .../openequivariance/extension/libtorch_tp_jit.cpp | 7 +++++++ .../extension/libtorch_tp_jit_stable.cpp | 9 ++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 698a142..173c0a2 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -84,6 +84,13 @@ Stream get_current_stream() { namespace py=pybind11; PYBIND11_MODULE(libtorch_tp_jit, m) { + py::class_>(m, "GroupMM_F32") + .def(py::init()) + .def("group_gemm", &GroupMM::group_gemm_intptr); + py::class_>(m, "GroupMM_F64") + .def(py::init()) + .def("group_gemm", &GroupMM::group_gemm_intptr); + py::class_(m, "DeviceProp") .def(py::init()) .def_readonly("name", &DeviceProp::name) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 440b764..1a26de1 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -92,6 +92,13 @@ Stream get_current_stream() { #include "nanobind/stl/string.h" namespace nb = nanobind; NB_MODULE(EXTENSION_NAME, m) { + nb::class_>(m, "GroupMM_F32") + .def(nb::init()) + .def("group_gemm", &GroupMM::group_gemm_intptr); + nb::class_>(m, "GroupMM_F64") + .def(nb::init()) + .def("group_gemm", &GroupMM::group_gemm_intptr); + nb::class_(m, "DeviceProp") .def(nb::init()) .def_ro("name", &DeviceProp::name) @@ -107,4 +114,4 @@ Stream get_current_stream() { .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) .def("clear_L2_cache", &GPUTimer::clear_L2_cache); } -#endif \ No newline at end of file +#endif From 931ec8cc825ce2a0b4b3f017d33457dfdd1dba2f Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Tue, 19 May 2026 05:00:19 +0000 Subject: [PATCH 3/8] find and add blas during cmake --- openequivariance/CMakeLists.txt | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index 232233f..5230dc0 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -117,6 +117,7 @@ endfunction() find_package(CUDAToolkit QUIET) find_package(hip QUIET) +find_package(rocblas QUIET) if(CUDAToolkit_FOUND) message(STATUS "Building stable extension with CUDA backend.") @@ -137,6 +138,7 @@ if(CUDAToolkit_FOUND) CUDA::cudart CUDA::cuda_driver CUDA::nvrtc + CUDA::cublas cuda_stub_lib ) add_stable_extension(oeq_stable_cuda CUDA_BACKEND "${CUDA_LINK_LIBS}") @@ -157,8 +159,15 @@ if(hip_FOUND) CXX_STANDARD 17 ) + if(TARGET roc::rocblas) + set(HIP_BLAS_LIB roc::rocblas) + else() + set(HIP_BLAS_LIB rocblas) + endif() + set(HIP_LINK_LIBS hiprtc + ${HIP_BLAS_LIB} hip_stub_lib ) add_stable_extension(torch_stable_hip HIP_BACKEND "${HIP_LINK_LIBS}") @@ -166,4 +175,4 @@ endif() if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND) message(WARNING "Neither CUDAToolkit nor HIP was found. The stable extension will not be built.") -endif() \ No newline at end of file +endif() From ec961d47e3289f84548d013b2ed9f157393a7416 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Tue, 19 May 2026 05:00:55 +0000 Subject: [PATCH 4/8] find pybind11 and give useful error if not --- .../_torch/extlib/__init__.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index ef17724..306738d 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -65,6 +65,14 @@ def load_jit_extension(): torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] include_dirs, extra_link_args = (["backend"], ["-Wl,--no-as-needed"]) + extra_include_dirs = [] + + try: + import pybind11 + + extra_include_dirs.append(pybind11.get_include()) + except Exception as e: + getLogger().info(f"Could not locate pybind11 include path: {e}") if LINKED_LIBPYTHON: extra_link_args.pop() @@ -76,7 +84,7 @@ def load_jit_extension(): ], ) if torch.version.cuda: - extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) + extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc", "-lcublas"]) try: torch_libs, cuda_libs = library_paths("cuda") @@ -89,15 +97,17 @@ def load_jit_extension(): extra_cflags.append("-DCUDA_BACKEND") elif torch.version.hip: - extra_link_args.extend(["-lhiprtc"]) + extra_link_args.extend(["-lhiprtc", "-lrocblas"]) torch_libs = library_paths("cuda")[0] extra_link_args.append("-Wl,-rpath," + torch_libs) extra_cflags.append("-DHIP_BACKEND") torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] - include_dirs = [ - oeq_root + "/extension/" + d for d in include_dirs - ] + include_paths("cuda") + include_dirs = ( + [oeq_root + "/extension/" + d for d in include_dirs] + + extra_include_dirs + + include_paths("cuda") + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -184,8 +194,8 @@ def torch_ext_so_path(): if BUILT_EXTENSION: from oeq_utilities import ( - # GroupMM_F32, - # GroupMM_F64, + GroupMM_F32, + GroupMM_F64, DeviceProp, GPUTimer, ) From 9dde7c1a484a1093ee41e63e4100db177c348e86 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Fri, 22 May 2026 05:10:40 +0000 Subject: [PATCH 5/8] add mace-torch for symmetric contraction testing --- openequivariance/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 6824635..1028c8c 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -56,6 +56,7 @@ dev = [ "pytest", "pytest-check", "pytest-subtests", + "mace-torch", "torch_geometric", "cmake" ] From 13a4514e3199209d8132b45ab23ad9f4c83c8f23 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Fri, 22 May 2026 05:10:57 +0000 Subject: [PATCH 6/8] move tests to pytest --- .../symmetric_contraction.py | 153 ++----------- tests/symmetric_contraction_test.py | 208 ++++++++++++++++++ 2 files changed, 222 insertions(+), 139 deletions(-) create mode 100644 tests/symmetric_contraction_test.py diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py b/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py index 504e788..655c13b 100644 --- a/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py @@ -239,6 +239,7 @@ def __init__( internal_weights: bool = True, num_elements: Optional[int] = None, weights: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, ) -> None: super().__init__() @@ -246,7 +247,9 @@ def __init__( self.num_features = irreps_in.count((0, 1)) self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) self.correlation = correlation - dtype = torch.get_default_dtype() + if dtype is None: + dtype = torch.get_default_dtype() + for nu in range(1, correlation + 1): U_matrix = U_matrix_real( irreps_in=self.coupling_irreps, @@ -262,9 +265,7 @@ def __init__( # Create weight for product basis self.weights = torch.nn.ParameterList([]) - self.groupMM = GroupMM( - torch.get_default_dtype(), num_elements, self.num_features - ) + self.groupMM = GroupMM(dtype, num_elements, self.num_features) self.num_equivariance = 2 * irrep_out.lmax + 1 for i in range(correlation, 0, -1): @@ -274,14 +275,18 @@ def __init__( if i == correlation: # Parameters for the product basis w = torch.nn.Parameter( - torch.randn((num_elements, num_params, self.num_features)) + torch.randn( + (num_elements, num_params, self.num_features), dtype=dtype + ) / num_params ) self.weights_max = w else: # Parameters for the product basis w = torch.nn.Parameter( - torch.randn((num_elements, num_params, self.num_features)) + torch.randn( + (num_elements, num_params, self.num_features), dtype=dtype + ) / num_params ) self.weights.append(w) @@ -352,10 +357,12 @@ def __init__( internal_weights: Optional[bool] = None, shared_weights: Optional[bool] = None, num_elements: Optional[int] = None, + dtype: Optional[torch.dtype] = None, ) -> None: super().__init__() self.num_elements = num_elements + self.dtype = dtype if dtype is not None else torch.get_default_dtype() if irrep_normalization is None: irrep_normalization = "component" @@ -396,6 +403,7 @@ def __init__( internal_weights=self.internal_weights, num_elements=num_elements, weights=self.shared_weights, + dtype=self.dtype, ) ) @@ -412,136 +420,3 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): ] outs_cat = torch.cat(outs, dim=-1)[inverse_perm] return outs_cat - - -# -------------------------------------------------------------------------- - - -def test_group_matmul(): - torch.manual_seed(0) - num_elements = 10 - vpe = 30 # Vectors per element, uniform just for testing - num_features = 20 - - M = 64 - K = 123 - ragged_counts = torch.zeros(num_elements, dtype=torch.int64, device="cpu") - - for i in range(num_elements): - ragged_counts[i] = vpe - - def test_backward_0(): - group_mm = GroupMM(torch.float32, num_elements, num_features) - A = torch.randn(num_elements, num_features, M, K).to("cuda") - B = torch.randn(num_elements * vpe, num_features, K).to("cuda") - - A.requires_grad = True - B.requires_grad = True - - ground_truth = torch.zeros(num_elements * vpe, num_features, M, device="cuda") - - # Test the forward pass - for i in range(num_elements): - B_slice = B[vpe * i : vpe * (i + 1)] - ground_truth[vpe * i : vpe * (i + 1)] = ( - A[i] @ B_slice.permute(1, 2, 0) - ).permute(2, 0, 1) - - C_g = torch.randn(num_elements * vpe, num_features, M).to("cuda") - C_g.requires_grad = True - - ground_truth.backward(C_g, inputs=[A, B]) - - A_grad_gt = A.grad.detach().clone() - B_grad_gt = B.grad.detach().clone() - - A.grad[:] = 0.0 - B.grad[:] = 0.0 - - C = group_mm.group_gemm(A, B, ragged_counts, M, K, 0) - - print(torch.norm(ground_truth - C)) - - C.backward(C_g, inputs=[A, B]) - print(torch.norm(A_grad_gt - A.grad)) - print(torch.norm(B_grad_gt - B.grad)) - - def test_backward_1(): - print("TESTING BACKWARD_1!") - group_mm = GroupMM(torch.float32, num_elements, num_features) - - A = torch.zeros(num_elements * vpe, num_features, M, device="cuda") - B = torch.randn(num_elements * vpe, num_features, K).to("cuda") - A.requires_grad = True - B.requires_grad = True - - ground_truth = torch.zeros(num_elements, num_features, M, K).to("cuda") - - for i in range(num_elements): - A_slice = A[vpe * i : vpe * (i + 1)] - B_slice = B[vpe * i : vpe * (i + 1)] - - ground_truth[i] = A_slice.permute(1, 2, 0) @ B_slice.permute(1, 0, 2) - - C = group_mm.group_gemm(A, B, ragged_counts, M, K, 1) - - print(torch.norm(C - ground_truth)) - - C_g = torch.randn(num_elements, num_features, M, K).to("cuda") - C_g.requires_grad = True - - ground_truth.backward(C_g, inputs=[A, B]) - - A_grad_gt = A.grad.detach().clone() - B_grad_gt = B.grad.detach().clone() - - A.grad[:] = 0.0 - B.grad[:] = 0.0 - - C.backward(C_g, inputs=[A, B]) - - print(torch.norm(A.grad - A_grad_gt)) - print(torch.norm(B.grad - B_grad_gt)) - - def test_double_backward(): - torch.autograd.set_detect_anomaly(True) - GroupMM(torch.float32, num_elements, num_features) - A = torch.randn(num_elements, num_features, M, K).to("cuda") - B = torch.randn(num_elements * vpe, num_features, K).to("cuda") - - A.requires_grad = True - B.requires_grad = True - - ground_truth = torch.zeros(num_elements * vpe, num_features, M, device="cuda") - - # Test the forward pass - for i in range(num_elements): - B_slice = B[vpe * i : vpe * (i + 1)] - ground_truth[vpe * i : vpe * (i + 1)] = ( - A[i] @ B_slice.permute(1, 2, 0) - ).permute(2, 0, 1) - - C_g = torch.randn(num_elements * vpe, num_features, M).to("cuda") - C_g.requires_grad = True - - ground_truth.backward(C_g, inputs=[A, B], create_graph=True, retain_graph=True) - dummy = torch.norm(A.grad) + torch.norm(B.grad) - dummy_grad = torch.randn_like(dummy) - - dummy.backward(gradient=dummy_grad, inputs=[C_g, A, B]) - - A_grad_gt = A.grad - B_grad_gt = B.grad - C_grad_gt = C_g.grad - - print(torch.norm(A_grad_gt)) - print(torch.norm(B_grad_gt)) - print(torch.norm(C_grad_gt)) - - test_backward_0() - test_backward_1() - test_double_backward() - - -if __name__ == "__main__": - test_group_matmul() diff --git a/tests/symmetric_contraction_test.py b/tests/symmetric_contraction_test.py new file mode 100644 index 0000000..1ec9d2a --- /dev/null +++ b/tests/symmetric_contraction_test.py @@ -0,0 +1,208 @@ +from unittest.mock import patch + +import pytest +import torch +import torch.nn.functional as F + +from e3nn import o3 + +import openequivariance as oeq +from openequivariance._torch.symmetric_contraction import SymmetricContraction + +mace_symmetric_contraction = pytest.importorskip("mace.modules.symmetric_contraction") +MaceSymmetricContraction = mace_symmetric_contraction.SymmetricContraction + + +IRREPS_IN = o3.Irreps("2x0e + 2x1o") +IRREPS_OUT = o3.Irreps("2x0e + 2x1o") +CORRELATION = 2 +NUM_ELEMENTS = 4 +LABEL_VALUES = [0, 2, 3, 2, 0, 0, 2, 3, 2, 2] + + +@pytest.fixture(params=[torch.float32, torch.float64], ids=["F32", "F64"]) +def dtype(request): + return request.param + + +@pytest.fixture +def device(): + if not torch.cuda.is_available(): + pytest.skip( + "SymmetricContraction requires a CUDA/HIP device exposed through torch.cuda" + ) + if not oeq.BUILT_EXTENSION: + pytest.skip( + f"OpenEquivariance extension is not built: {oeq.BUILT_EXTENSION_ERROR}" + ) + return torch.device("cuda") + + +@pytest.fixture +def labels(device): + return torch.tensor(LABEL_VALUES, device=device, dtype=torch.long) + + +@pytest.fixture +def node_attrs(labels, dtype): + return F.one_hot(labels, num_classes=NUM_ELEMENTS).to(dtype=dtype) + + +@pytest.fixture +def node_feats(device, dtype): + gen = torch.Generator(device=device) + gen.manual_seed(2468) + return torch.randn( + len(LABEL_VALUES), + IRREPS_IN.count((0, 1)), + IRREPS_IN.dim // IRREPS_IN.count((0, 1)), + device=device, + dtype=dtype, + generator=gen, + requires_grad=True, + ) + + +@pytest.fixture +def modules(device, dtype): + torch.manual_seed(12345) + oeq_module = SymmetricContraction( + IRREPS_IN, + IRREPS_OUT, + correlation=CORRELATION, + num_elements=NUM_ELEMENTS, + dtype=dtype, + ).to(device) + + # MACE's original e3nn implementation reads torch.get_default_dtype() + # during construction, so patch that lookup instead of mutating global state. + with patch( + "mace.modules.symmetric_contraction.torch.get_default_dtype", + return_value=dtype, + ): + mace_module = MaceSymmetricContraction( + IRREPS_IN, + IRREPS_OUT, + correlation=CORRELATION, + num_elements=NUM_ELEMENTS, + ).to(device=device, dtype=dtype) + + copy_matching_state(oeq_module, mace_module) + return oeq_module, mace_module + + +def tolerance(dtype): + if dtype == torch.float64: + return {"rtol": 1e-10, "atol": 1e-10} + return {"rtol": 1e-4, "atol": 1e-4} + + +def copy_matching_state(source, target): + source_state = source.state_dict() + target_state = target.state_dict() + for name, value in source_state.items(): + if name in target_state and target_state[name].shape == value.shape: + target_state[name] = value.detach().clone().to(target_state[name]) + target.load_state_dict(target_state) + + +def matching_trainable_parameters(source, target): + source_params = dict(source.named_parameters()) + target_params = dict(target.named_parameters()) + names = [ + name + for name, param in source_params.items() + if param.requires_grad + and name in target_params + and target_params[name].requires_grad + and target_params[name].shape == param.shape + ] + assert names, "No matching trainable parameters found" + return tuple(source_params[name] for name in names), tuple( + target_params[name] for name in names + ) + + +def random_like(tensor, seed): + gen = torch.Generator(device=tensor.device) + gen.manual_seed(seed) + return torch.randn( + tensor.shape, device=tensor.device, dtype=tensor.dtype, generator=gen + ) + + +class TestSymmetricContraction: + def test_matches_mace_forward_backward( + self, modules, node_feats, node_attrs, dtype + ): + oeq_module, mace_module = modules + mace_node_feats = node_feats.detach().clone().requires_grad_() + + oeq_output = oeq_module(node_feats, node_attrs) + mace_output = mace_module(mace_node_feats, node_attrs) + + assert oeq_output.shape == (len(LABEL_VALUES), IRREPS_OUT.dim) + torch.testing.assert_close(oeq_output, mace_output, **tolerance(dtype)) + + output_grad = random_like(oeq_output, seed=4321) + oeq_params, mace_params = matching_trainable_parameters(oeq_module, mace_module) + + oeq_grads = torch.autograd.grad( + oeq_output, (node_feats, *oeq_params), grad_outputs=output_grad + ) + mace_grads = torch.autograd.grad( + mace_output, (mace_node_feats, *mace_params), grad_outputs=output_grad + ) + + for oeq_grad, mace_grad in zip(oeq_grads, mace_grads): + torch.testing.assert_close(oeq_grad, mace_grad, **tolerance(dtype)) + + def test_matches_mace_double_backward(self, modules, node_feats, node_attrs, dtype): + oeq_module, mace_module = modules + mace_node_feats = node_feats.detach().clone().requires_grad_() + + oeq_output = oeq_module(node_feats, node_attrs) + mace_output = mace_module(mace_node_feats, node_attrs) + oeq_output_grad = random_like(oeq_output, seed=9876).requires_grad_() + mace_output_grad = oeq_output_grad.detach().clone().requires_grad_() + + oeq_params, mace_params = matching_trainable_parameters(oeq_module, mace_module) + oeq_tensors = (node_feats, *oeq_params) + mace_tensors = (mace_node_feats, *mace_params) + + oeq_first_grads = torch.autograd.grad( + oeq_output, + oeq_tensors, + grad_outputs=oeq_output_grad, + create_graph=True, + ) + mace_first_grads = torch.autograd.grad( + mace_output, + mace_tensors, + grad_outputs=mace_output_grad, + create_graph=True, + ) + + for oeq_grad, mace_grad in zip(oeq_first_grads, mace_first_grads): + torch.testing.assert_close(oeq_grad, mace_grad, **tolerance(dtype)) + + probes = [ + random_like(grad, seed=1357 + index) + for index, grad in enumerate(oeq_first_grads) + ] + oeq_target = sum( + (grad * probe).sum() for grad, probe in zip(oeq_first_grads, probes) + ) + mace_target = sum( + (grad * probe).sum() for grad, probe in zip(mace_first_grads, probes) + ) + + oeq_second_grads = torch.autograd.grad( + oeq_target, oeq_tensors + (oeq_output_grad,) + ) + mace_second_grads = torch.autograd.grad( + mace_target, mace_tensors + (mace_output_grad,) + ) + + for oeq_grad, mace_grad in zip(oeq_second_grads, mace_second_grads): + torch.testing.assert_close(oeq_grad, mace_grad, **tolerance(dtype)) From f8a0c3963dac92609e0ff25ff78218e24385a88f Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Tue, 26 May 2026 21:08:25 -0700 Subject: [PATCH 7/8] BUILT_EXTENSION_ERROR --- openequivariance/openequivariance/_torch/extlib/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 306738d..393d337 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -72,7 +72,11 @@ def load_jit_extension(): extra_include_dirs.append(pybind11.get_include()) except Exception as e: - getLogger().info(f"Could not locate pybind11 include path: {e}") + BUILT_EXTENSION_ERROR = ( + "Could not locate pybind11 include path required for JIT " + f"OpenEquivariance extension compilation: {e}" + ) + return if LINKED_LIBPYTHON: extra_link_args.pop() From 155899f014adcf14a2fa30ba52d9b1d7fad2c7f7 Mon Sep 17 00:00:00 2001 From: asglover <140220574+asglover@users.noreply.github.com> Date: Tue, 26 May 2026 21:26:31 -0700 Subject: [PATCH 8/8] DEVICE=cuda --- tests/symmetric_contraction_test.py | 31 +++++++++-------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/tests/symmetric_contraction_test.py b/tests/symmetric_contraction_test.py index 1ec9d2a..fdf0861 100644 --- a/tests/symmetric_contraction_test.py +++ b/tests/symmetric_contraction_test.py @@ -6,7 +6,6 @@ from e3nn import o3 -import openequivariance as oeq from openequivariance._torch.symmetric_contraction import SymmetricContraction mace_symmetric_contraction = pytest.importorskip("mace.modules.symmetric_contraction") @@ -18,6 +17,7 @@ CORRELATION = 2 NUM_ELEMENTS = 4 LABEL_VALUES = [0, 2, 3, 2, 0, 0, 2, 3, 2, 2] +DEVICE = torch.device("cuda") @pytest.fixture(params=[torch.float32, torch.float64], ids=["F32", "F64"]) @@ -26,21 +26,8 @@ def dtype(request): @pytest.fixture -def device(): - if not torch.cuda.is_available(): - pytest.skip( - "SymmetricContraction requires a CUDA/HIP device exposed through torch.cuda" - ) - if not oeq.BUILT_EXTENSION: - pytest.skip( - f"OpenEquivariance extension is not built: {oeq.BUILT_EXTENSION_ERROR}" - ) - return torch.device("cuda") - - -@pytest.fixture -def labels(device): - return torch.tensor(LABEL_VALUES, device=device, dtype=torch.long) +def labels(): + return torch.tensor(LABEL_VALUES, device=DEVICE, dtype=torch.long) @pytest.fixture @@ -49,14 +36,14 @@ def node_attrs(labels, dtype): @pytest.fixture -def node_feats(device, dtype): - gen = torch.Generator(device=device) +def node_feats(dtype): + gen = torch.Generator(device=DEVICE) gen.manual_seed(2468) return torch.randn( len(LABEL_VALUES), IRREPS_IN.count((0, 1)), IRREPS_IN.dim // IRREPS_IN.count((0, 1)), - device=device, + device=DEVICE, dtype=dtype, generator=gen, requires_grad=True, @@ -64,7 +51,7 @@ def node_feats(device, dtype): @pytest.fixture -def modules(device, dtype): +def modules(dtype): torch.manual_seed(12345) oeq_module = SymmetricContraction( IRREPS_IN, @@ -72,7 +59,7 @@ def modules(device, dtype): correlation=CORRELATION, num_elements=NUM_ELEMENTS, dtype=dtype, - ).to(device) + ).to(DEVICE) # MACE's original e3nn implementation reads torch.get_default_dtype() # during construction, so patch that lookup instead of mutating global state. @@ -85,7 +72,7 @@ def modules(device, dtype): IRREPS_OUT, correlation=CORRELATION, num_elements=NUM_ELEMENTS, - ).to(device=device, dtype=dtype) + ).to(device=DEVICE, dtype=dtype) copy_matching_state(oeq_module, mace_module) return oeq_module, mace_module