Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion openequivariance/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ endfunction()

find_package(CUDAToolkit QUIET)
find_package(hip QUIET)
find_package(rocblas QUIET)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, checking here that we had this in our codebase in the past, removed it, and we are putting it back. If so, no problem.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs to be checked; it was never done through CMake, only through the Python extension.


if(CUDAToolkit_FOUND)
message(STATUS "Building stable extension with CUDA backend.")
Expand All @@ -137,6 +138,7 @@ if(CUDAToolkit_FOUND)
CUDA::cudart
CUDA::cuda_driver
CUDA::nvrtc
CUDA::cublas
Comment thread
vbharadwaj-bk marked this conversation as resolved.
cuda_stub_lib
)
add_stable_extension(oeq_stable_cuda CUDA_BACKEND "${CUDA_LINK_LIBS}")
Expand All @@ -157,13 +159,20 @@ if(hip_FOUND)
CXX_STANDARD 17
)

if(TARGET roc::rocblas)
Comment thread
vbharadwaj-bk marked this conversation as resolved.
set(HIP_BLAS_LIB roc::rocblas)
else()
set(HIP_BLAS_LIB rocblas)
endif()

set(HIP_LINK_LIBS
hiprtc
${HIP_BLAS_LIB}
hip_stub_lib
)
add_stable_extension(torch_stable_hip HIP_BACKEND "${HIP_LINK_LIBS}")
endif()

if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND)
message(WARNING "Neither CUDAToolkit nor HIP was found. The stable extension will not be built.")
endif()
endif()
28 changes: 21 additions & 7 deletions openequivariance/openequivariance/_torch/extlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ def load_jit_extension():
torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"]

include_dirs, extra_link_args = (["backend"], ["-Wl,--no-as-needed"])
extra_include_dirs = []

try:
import pybind11

extra_include_dirs.append(pybind11.get_include())
except Exception as e:
BUILT_EXTENSION_ERROR = (
"Could not locate pybind11 include path required for JIT "
f"OpenEquivariance extension compilation: {e}"
)
return

if LINKED_LIBPYTHON:
extra_link_args.pop()
Expand All @@ -76,7 +88,7 @@ def load_jit_extension():
],
)
if torch.version.cuda:
extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])
extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc", "-lcublas"])

try:
torch_libs, cuda_libs = library_paths("cuda")
Expand All @@ -89,15 +101,17 @@ def load_jit_extension():

extra_cflags.append("-DCUDA_BACKEND")
elif torch.version.hip:
extra_link_args.extend(["-lhiprtc"])
extra_link_args.extend(["-lhiprtc", "-lrocblas"])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.. have we tested this? This looks reasonable but I'm wary about changing this without a platform to test on.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be fine if this code was in our main branch at some point in the past - I excised a lot of this symmetric contraction at some point.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hasn't been tested. I'm looking, though, and I'm not sure that we actually linked it in the past. We just referred to "rocblas/rocblas.h", and maybe hiprtc found this header. But I think we'll have to link here for the stable ABI. I'll set up a test on AMD's cloud.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still waiting for AMD cloud

torch_libs = library_paths("cuda")[0]
extra_link_args.append("-Wl,-rpath," + torch_libs)
extra_cflags.append("-DHIP_BACKEND")

torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
include_dirs = [
oeq_root + "/extension/" + d for d in include_dirs
] + include_paths("cuda")
include_dirs = (
[oeq_root + "/extension/" + d for d in include_dirs]
+ extra_include_dirs
+ include_paths("cuda")
)

with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand Down Expand Up @@ -184,8 +198,8 @@ def torch_ext_so_path():

if BUILT_EXTENSION:
from oeq_utilities import (
# GroupMM_F32,
# GroupMM_F64,
GroupMM_F32,
GroupMM_F64,
DeviceProp,
GPUTimer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,17 @@ def __init__(
internal_weights: bool = True,
num_elements: Optional[int] = None,
weights: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()

self.num_elements = num_elements
self.num_features = irreps_in.count((0, 1))
self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
self.correlation = correlation
dtype = torch.get_default_dtype()
if dtype is None:
dtype = torch.get_default_dtype()

for nu in range(1, correlation + 1):
U_matrix = U_matrix_real(
irreps_in=self.coupling_irreps,
Expand All @@ -262,9 +265,7 @@ def __init__(

# Create weight for product basis
self.weights = torch.nn.ParameterList([])
self.groupMM = GroupMM(
torch.get_default_dtype(), num_elements, self.num_features
)
self.groupMM = GroupMM(dtype, num_elements, self.num_features)
self.num_equivariance = 2 * irrep_out.lmax + 1

for i in range(correlation, 0, -1):
Expand All @@ -274,14 +275,18 @@ def __init__(
if i == correlation:
# Parameters for the product basis
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, self.num_features))
torch.randn(
(num_elements, num_params, self.num_features), dtype=dtype
)
/ num_params
)
self.weights_max = w
else:
# Parameters for the product basis
w = torch.nn.Parameter(
torch.randn((num_elements, num_params, self.num_features))
torch.randn(
(num_elements, num_params, self.num_features), dtype=dtype
)
/ num_params
)
self.weights.append(w)
Expand Down Expand Up @@ -352,10 +357,12 @@ def __init__(
internal_weights: Optional[bool] = None,
shared_weights: Optional[bool] = None,
num_elements: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()

self.num_elements = num_elements
self.dtype = dtype if dtype is not None else torch.get_default_dtype()
if irrep_normalization is None:
irrep_normalization = "component"

Expand Down Expand Up @@ -396,6 +403,7 @@ def __init__(
internal_weights=self.internal_weights,
num_elements=num_elements,
weights=self.shared_weights,
dtype=self.dtype,
)
)

Expand All @@ -412,136 +420,3 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
]
outs_cat = torch.cat(outs, dim=-1)[inverse_perm]
return outs_cat


# --------------------------------------------------------------------------


def test_group_matmul():
torch.manual_seed(0)
num_elements = 10
vpe = 30 # Vectors per element, uniform just for testing
num_features = 20

M = 64
K = 123
ragged_counts = torch.zeros(num_elements, dtype=torch.int64, device="cpu")

for i in range(num_elements):
ragged_counts[i] = vpe

def test_backward_0():
group_mm = GroupMM(torch.float32, num_elements, num_features)
A = torch.randn(num_elements, num_features, M, K).to("cuda")
B = torch.randn(num_elements * vpe, num_features, K).to("cuda")

A.requires_grad = True
B.requires_grad = True

ground_truth = torch.zeros(num_elements * vpe, num_features, M, device="cuda")

# Test the forward pass
for i in range(num_elements):
B_slice = B[vpe * i : vpe * (i + 1)]
ground_truth[vpe * i : vpe * (i + 1)] = (
A[i] @ B_slice.permute(1, 2, 0)
).permute(2, 0, 1)

C_g = torch.randn(num_elements * vpe, num_features, M).to("cuda")
C_g.requires_grad = True

ground_truth.backward(C_g, inputs=[A, B])

A_grad_gt = A.grad.detach().clone()
B_grad_gt = B.grad.detach().clone()

A.grad[:] = 0.0
B.grad[:] = 0.0

C = group_mm.group_gemm(A, B, ragged_counts, M, K, 0)

print(torch.norm(ground_truth - C))

C.backward(C_g, inputs=[A, B])
print(torch.norm(A_grad_gt - A.grad))
print(torch.norm(B_grad_gt - B.grad))

def test_backward_1():
print("TESTING BACKWARD_1!")
group_mm = GroupMM(torch.float32, num_elements, num_features)

A = torch.zeros(num_elements * vpe, num_features, M, device="cuda")
B = torch.randn(num_elements * vpe, num_features, K).to("cuda")
A.requires_grad = True
B.requires_grad = True

ground_truth = torch.zeros(num_elements, num_features, M, K).to("cuda")

for i in range(num_elements):
A_slice = A[vpe * i : vpe * (i + 1)]
B_slice = B[vpe * i : vpe * (i + 1)]

ground_truth[i] = A_slice.permute(1, 2, 0) @ B_slice.permute(1, 0, 2)

C = group_mm.group_gemm(A, B, ragged_counts, M, K, 1)

print(torch.norm(C - ground_truth))

C_g = torch.randn(num_elements, num_features, M, K).to("cuda")
C_g.requires_grad = True

ground_truth.backward(C_g, inputs=[A, B])

A_grad_gt = A.grad.detach().clone()
B_grad_gt = B.grad.detach().clone()

A.grad[:] = 0.0
B.grad[:] = 0.0

C.backward(C_g, inputs=[A, B])

print(torch.norm(A.grad - A_grad_gt))
print(torch.norm(B.grad - B_grad_gt))

def test_double_backward():
torch.autograd.set_detect_anomaly(True)
GroupMM(torch.float32, num_elements, num_features)
A = torch.randn(num_elements, num_features, M, K).to("cuda")
B = torch.randn(num_elements * vpe, num_features, K).to("cuda")

A.requires_grad = True
B.requires_grad = True

ground_truth = torch.zeros(num_elements * vpe, num_features, M, device="cuda")

# Test the forward pass
for i in range(num_elements):
B_slice = B[vpe * i : vpe * (i + 1)]
ground_truth[vpe * i : vpe * (i + 1)] = (
A[i] @ B_slice.permute(1, 2, 0)
).permute(2, 0, 1)

C_g = torch.randn(num_elements * vpe, num_features, M).to("cuda")
C_g.requires_grad = True

ground_truth.backward(C_g, inputs=[A, B], create_graph=True, retain_graph=True)
dummy = torch.norm(A.grad) + torch.norm(B.grad)
dummy_grad = torch.randn_like(dummy)

dummy.backward(gradient=dummy_grad, inputs=[C_g, A, B])

A_grad_gt = A.grad
B_grad_gt = B.grad
C_grad_gt = C_g.grad

print(torch.norm(A_grad_gt))
print(torch.norm(B_grad_gt))
print(torch.norm(C_grad_gt))

test_backward_0()
test_backward_1()
test_double_backward()


if __name__ == "__main__":
test_group_matmul()
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ Stream get_current_stream() {

namespace py=pybind11;
PYBIND11_MODULE(libtorch_tp_jit, m) {
py::class_<GroupMM<float>>(m, "GroupMM_F32")
.def(py::init<int, int>())
.def("group_gemm", &GroupMM<float>::group_gemm_intptr);
py::class_<GroupMM<double>>(m, "GroupMM_F64")
.def(py::init<int, int>())
.def("group_gemm", &GroupMM<double>::group_gemm_intptr);

py::class_<DeviceProp>(m, "DeviceProp")
.def(py::init<int>())
.def_readonly("name", &DeviceProp::name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ Stream get_current_stream() {
#include "nanobind/stl/string.h"
namespace nb = nanobind;
NB_MODULE(EXTENSION_NAME, m) {
nb::class_<GroupMM<float>>(m, "GroupMM_F32")
.def(nb::init<int, int>())
.def("group_gemm", &GroupMM<float>::group_gemm_intptr);
nb::class_<GroupMM<double>>(m, "GroupMM_F64")
.def(nb::init<int, int>())
.def("group_gemm", &GroupMM<double>::group_gemm_intptr);

nb::class_<DeviceProp>(m, "DeviceProp")
.def(nb::init<int>())
.def_ro("name", &DeviceProp::name)
Expand All @@ -107,4 +114,4 @@ Stream get_current_stream() {
.def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed)
.def("clear_L2_cache", &GPUTimer::clear_L2_cache);
}
#endif
#endif
4 changes: 3 additions & 1 deletion openequivariance/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ dependencies = [
"setuptools",
"ninja",
"jinja2",
"numpy"
"numpy",
"pybind11"
]
readme = "README.md"

Expand Down Expand Up @@ -55,6 +56,7 @@ dev = [
"pytest",
"pytest-check",
"pytest-subtests",
"mace-torch",
"torch_geometric",
"cmake"
]
Expand Down
Loading
Loading