Fix/symmetric contraction groupmm bindings #198

Draft
asglover wants to merge 8 commits into main from fix/symmetric-contraction-groupmm-bindings

@asglover
Collaborator

Re-enables the symmetric contraction / groupmm bindings and adds basic pytest tests.

Member

@vbharadwaj-bk left a comment

Overall, looks great. A couple of things:

  1. ROCm, I'm assuming, hasn't been tested. Are you basically adding back code that was taken out? In that case, no problem. Otherwise, I'm a bit wary.
  2. Have you run the test suite / installation on Google Colab with an NVIDIA GPU, or on a platform of your choice? I would do that, since this requires changes to the low-level extension.
  3. No problem adding cuBLAS back to the CMakeLists.txt for the stable extension, since someone wants this. I was really hoping not to do this; without it, we could package up a prebuilt wheel and avoid compilation on the user side entirely. Now we can't do that, and worse, the user's cuBLAS needs to be ABI compatible with PyTorch's installed version. At some point, we should address this.
  4. Future work (not in this commit): elimination of the groupMM custom class. You can switch to a caching scheme similar to the regular OEQ kernels for the cuBLASLt state. This will enable you to compose this with torch.compile (and, incidentally, you should tell the user that this is not compatible with torch.compile yet). Then you can add a symmetric contraction test in the export_test.py file.

      extra_cflags.append("-DCUDA_BACKEND")
  elif torch.version.hip:
-     extra_link_args.extend(["-lhiprtc"])
+     extra_link_args.extend(["-lhiprtc", "-lrocblas"])
Member

Hmm.. have we tested this? This looks reasonable but I'm wary about changing this without a platform to test on.

Member

It should be fine if this code was in our main branch at some point in the past. I excised a lot of this symmetric contraction code at some point.

Collaborator Author

This hasn't been tested. I'm looking, though, and I'm not sure that we actually linked it in the past. We just referred to "rocblas/rocblas.h", and maybe hiprtc found this header. But I think we'll have to link here for the stable ABI. I'll set up a test on AMD's cloud.

Collaborator Author

still waiting for AMD cloud


    extra_include_dirs.append(pybind11.get_include())
except Exception as e:
    getLogger().info(f"Could not locate pybind11 include path: {e}")
Member

Given that pybind11 is required for JIT extension compilation, you should set BUILT_EXTENSION_ERROR here rather than just logging the error.
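A minimal sketch of what that suggestion could look like. It assumes `BUILT_EXTENSION_ERROR` is a module-level variable that records why the extension could not be built (its actual type and use in the codebase are not shown in this thread), and `locate_pybind11_include` is a hypothetical helper name used only for illustration.

```python
from typing import Optional

# Assumption: a module-level slot recording why the extension could not be built.
BUILT_EXTENSION_ERROR: Optional[str] = None


def locate_pybind11_include() -> Optional[str]:
    """Return pybind11's include directory, or record the failure and return None."""
    global BUILT_EXTENSION_ERROR
    try:
        import pybind11  # imported lazily so a missing package is caught here

        return pybind11.get_include()
    except Exception as e:
        # Record the failure instead of only logging it, so callers can surface it later.
        BUILT_EXTENSION_ERROR = f"Could not locate pybind11 include path: {e}"
        return None


include_dir = locate_pybind11_include()
```

With this shape, code that later tries to JIT-compile the extension can check `BUILT_EXTENSION_ERROR` and raise a clear error instead of failing deep inside the compiler invocation.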


find_package(CUDAToolkit QUIET)
find_package(hip QUIET)
find_package(rocblas QUIET)
Member

Ditto, checking here that we had this in our codebase in the past, removed it, and we are putting it back. If so, no problem.

Collaborator Author

It needs to be checked; it was never done through CMake, only through the Python extension.

Comment thread openequivariance/CMakeLists.txt
Comment thread tests/symmetric_contraction_test.py Outdated

@pytest.fixture
def device():
    if not torch.cuda.is_available():
Member

No need for this to be a fixture. Tests should fail if there is no torch.device("cuda"), not skip. I would use torch.device("cuda") inline everywhere.

Member

I would just declare a global:

DEVICE = torch.device("cuda")

Comment thread tests/symmetric_contraction_test.py Outdated


@pytest.fixture
def labels(device):
Member

Ditto, don't need this as a fixture.

Collaborator Author

I kind of like using a fixture, since it makes a new tensor for each test, so nothing strange happens with lifetimes or gradients.

)


def random_like(tensor, seed):
Member

Hmm.. I would set the generator state outside this function. Otherwise every tensor of the same size gets the same random values (?)

Collaborator Author

I think this function returns a tensor with values that are determined by the seed, which is passed into this function, so you would only get tensors with the same values in the same places if you called this function with the same seed. The generator should be the pseudorandom function, which can be the same, and is often the same. I would only switch the generator if it was a very expensive method or something.
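To make the behavior described here concrete, a small sketch follows. It is not the PR's `random_like`: it uses the standard library's `random.Random` as a stand-in for a seeded `torch.Generator` and a plain list as a stand-in for a tensor, so that it runs without PyTorch or a GPU.

```python
import random


def random_like(values, seed):
    """Return a list shaped like `values`, filled from a generator seeded with `seed`."""
    rng = random.Random(seed)  # fresh generator per call, fully determined by the seed
    return [rng.gauss(0.0, 1.0) for _ in values]


template = [0.0] * 4
a = random_like(template, seed=0)
b = random_like(template, seed=0)  # same seed and shape: identical values
c = random_like(template, seed=1)  # different seed: different values
```

Under this scheme, two same-shaped inputs only collide when the caller passes the same seed, which is the point being made above.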

Collaborator Author

The way it hardcodes the random seeds does bother me, though. I'll look into whether there are good practices for getting the random seed from pytest or something, so it can be repeatable when it needs to be but is otherwise random.

Collaborator Author

I want to add hypothesis for property testing eventually, so I'll delay getting fancy with the random seeds here.

Member

Yeah I think you can achieve this just by avoiding passing the seed at all as a fixture and initializing the generator once as a global variable. That way, it's repeatable and you get fresh randomness for every tensor that calls the generator.

Collaborator Author

That would make it dependent on test order and on the number of tests: running more tests would change the random values a given test sees. I feel like the feature set provided by Hypothesis is a good design: it randomly fuzzes the seed every time, except that when a test fails, it stores that seed and adds it to a list that is always replayed. But that should probably be applied evenly across the test suite. For this PR, I can change this to whatever you'd like. I can remove the seeds and just make the generator a global.
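The order dependence of a single shared generator can be shown with a short sketch. This is an illustration only: `random.Random` stands in for a global `torch.Generator`, and `draw` is a hypothetical helper representing one test pulling random inputs.

```python
import random


def draw(rng, n):
    """Pull n values from the shared generator, as one test would."""
    return [rng.random() for _ in range(n)]


# Run 1: "test_b" is the only test that runs.
rng = random.Random(1234)
b_alone = draw(rng, 3)

# Run 2: "test_a" runs first and advances the generator before "test_b".
rng = random.Random(1234)
a_first = draw(rng, 3)
b_after_a = draw(rng, 3)
```

Both runs are repeatable as a whole, but the inputs seen by the second test depend on how many draws happened before it, which is the trade-off between the two positions in this thread.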
