-
Notifications
You must be signed in to change notification settings - Fork 9
Fix/symmetric contraction groupmm bindings #198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6fd348d
5ddccf6
931ec8c
ec961d4
9dde7c1
13a4514
f8a0c39
155899f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,18 @@ def load_jit_extension(): | |
| torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] | ||
|
|
||
| include_dirs, extra_link_args = (["backend"], ["-Wl,--no-as-needed"]) | ||
| extra_include_dirs = [] | ||
|
|
||
| try: | ||
| import pybind11 | ||
|
|
||
| extra_include_dirs.append(pybind11.get_include()) | ||
| except Exception as e: | ||
| BUILT_EXTENSION_ERROR = ( | ||
| "Could not locate pybind11 include path required for JIT " | ||
| f"OpenEquivariance extension compilation: {e}" | ||
| ) | ||
| return | ||
|
|
||
| if LINKED_LIBPYTHON: | ||
| extra_link_args.pop() | ||
|
|
@@ -76,7 +88,7 @@ def load_jit_extension(): | |
| ], | ||
| ) | ||
| if torch.version.cuda: | ||
| extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) | ||
| extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc", "-lcublas"]) | ||
|
|
||
| try: | ||
| torch_libs, cuda_libs = library_paths("cuda") | ||
|
|
@@ -89,15 +101,17 @@ def load_jit_extension(): | |
|
|
||
| extra_cflags.append("-DCUDA_BACKEND") | ||
| elif torch.version.hip: | ||
| extra_link_args.extend(["-lhiprtc"]) | ||
| extra_link_args.extend(["-lhiprtc", "-lrocblas"]) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm.. have we tested this? This looks reasonable but I'm wary about changing this without a platform to test on.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be fine if this code was in our main branch at some point in the past - I excised a lot of this symmetric contraction at some point.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This hasn't been tested. I'm looking though and I'm not sure that we actually linked it in the past somehow? We just referred to "rocblas/rocblas.h" and maybe hiprtc found this header. But I think we'll have to link here for the stable ABI. I'll setup a test on AMD's cloud.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still waiting for AMD cloud |
||
| torch_libs = library_paths("cuda")[0] | ||
| extra_link_args.append("-Wl,-rpath," + torch_libs) | ||
| extra_cflags.append("-DHIP_BACKEND") | ||
|
|
||
| torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] | ||
| include_dirs = [ | ||
| oeq_root + "/extension/" + d for d in include_dirs | ||
| ] + include_paths("cuda") | ||
| include_dirs = ( | ||
| [oeq_root + "/extension/" + d for d in include_dirs] | ||
| + extra_include_dirs | ||
| + include_paths("cuda") | ||
| ) | ||
|
|
||
| with warnings.catch_warnings(): | ||
| warnings.simplefilter("ignore") | ||
|
|
@@ -184,8 +198,8 @@ def torch_ext_so_path(): | |
|
|
||
| if BUILT_EXTENSION: | ||
| from oeq_utilities import ( | ||
| # GroupMM_F32, | ||
| # GroupMM_F64, | ||
| GroupMM_F32, | ||
| GroupMM_F64, | ||
| DeviceProp, | ||
| GPUTimer, | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto, checking here that we had this in our codebase in the past, removed it, and we are putting it back. If so, no problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It needs to be checked, it was never done through CMAKE only through the python extension.