HipKittens MXFP8 GEMM Support #566
Conversation
)
if use_bias:
    pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)

key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)
if ws is None:
    ws = torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
    _workspace_cache[key] = ws
return ws
def check_mxfp8_workspace(device: int, needed: int) -> None:
    """Grow the workspace to the required size."""
    key = (device, False, False)
    ws = _workspace_cache.get(key)
    if ws is not None and ws.shape[0] >= needed:
        return
    _workspace_cache[key] = torch.empty(needed, dtype=torch.uint8, device=device)
I have concerns about the proposed workspace cache system:
1). In non-MoE runs, it will try to allocate the largest size kitten_gemm needs, replacing previously allocated smaller buffers and relying on PyTorch garbage collection to deallocate them. The biggest single buffer will then stay in the process from the second iteration onward.
2). For MoE runs, sizes are dynamic, so the cache contents can probably still change after the warm-up runs.
If we can force TE upstream to always provide you the TN layout, then can we remove this dynamic workspace entirely?
I understand your concern, but I think we are OK for current models.
1.) This is correct: we only keep the largest workspace, relying on PyTorch GC to delete the old one. This only affects iteration 1.
2.) Since the workspace is shared by all GEMMs in the model, I think this is unlikely. For example, with DeepSeek 671B at BS=2, the largest non-MoE workspace needed is for the dense layers' FFN, where the wgrad GEMM needs 200 MB compared to a theoretical maximum MoE GEMM size of 72 MB, so this wouldn't occur. For a full MoE model like Qwen 235B, we still don't run into this issue, as the largest non-MoE GEMM would use 96 MB vs. a 44 MB worst case for MoE.
It is possible that a model exists, or could exist, where the MoE GEMM is the largest, but convergence theory would imply that we hit the maximum allocation threshold fairly quickly in a many-layer model, and it almost certainly wouldn't affect the performance of a full training run.
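For context, the grow-only pattern under discussion amounts to the sketch below; the composition of the calls is illustrative, not the actual TE call site.

# Before launching a hipKittens MXFP8 GEMM, grow the shared workspace if this
# GEMM needs more bytes than any previous one on this device.
needed = _hipkittens_workspace_bytes(m, n, k, layout)  # per-GEMM requirement
check_mxfp8_workspace(device, needed)                  # grow-only resize
ws = _workspace_cache[(device, False, False)]          # largest workspace seen so far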
Emm, in addition to my other comment on the possibility of removing this dynamic workspace directly, if we really need this dynamic buffer:
1). Let's try allocating the buffer without a cache to see if it really hurts e2e training before working on this delicate buffer cache?
2). Convergence theory usually works in theory papers with input distributional assumptions. I agree it works fine for our Qwen or DS runs, but our library may run into strange corner cases when used by customers.
We can do this, but I believe it doesn't remove the need for a workspace that grows dynamically with the largest needed size.
The convergence I was referring to is that we have an upper bound on the largest expert in a model. In the scenario where the MoE layer is the largest, every time we see a new largest expert we are less likely to see an even larger one, so we are very unlikely to keep spending time allocating memory for the workspace later on. I also think the memory allocation here is a negligible overhead, given that the same workspace is reused.
Right, this does not change the need for a dynamic workspace. If PyTorch-native buffer allocation does not hurt us much, our code will be cleaner and easier to maintain.
Looks like it doesn't hurt much -- maybe 5 TFLOPS or so lost on average.
Emm, a 5 TFLOPS drop vs. cleaner and easier-to-maintain code; I'm okay with both options. @ipanfilo do you have comments on this?
if (!use_mxfp8 && params.force_hipblaslt) {
  GTEST_SKIP() << "force_hipblaslt only relevant for MXFP8";
}
if (use_mxfp8) {
Add a new const bool use_hipblaslt_fp8 = (!use_mxfp8 || params.force_hipblaslt); this combination is used below for many skips. And all of this should live below, under ifdef HIP_PLATFORM_AMD, under has_fp8.
I wanted to avoid the skips completely, so I split the test instantiation into non-MXFP8 and MXFP8 variants.
Nevertheless, the same condition is used multiple times below. Maybe you can instead have use_hipkittens_mxfp8 = (use_mxfp8 && !params.force_hipblaslt) for better clarity.
[](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
  return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
  return MKN(std::get<0>(info.param)) + "x" +
         std::to_string(std::get<1>(info.param)) + "x" +
What is the point? They are set to false only.
  GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (params.use_bias || params.use_gelu) {
  if (params.force_hipblaslt) {
It is skipped below anyway; if you are adding it for the future, move it after the more generic one.
Sorry, this and the Dq test name changes are artifacts from my attempt to enable bias and gelu for this test. I ran into issues with gelu for the non-FP8 GEMM in hipBLASlt and decided to just focus on the non-Dq tests. I have reverted things.
#include <hip/hip_runtime.h>
#include <cstddef>

enum KittensDType {
Is it copied from some hipKittens enum? Add a comment then.
These values come from the NVTE values; I have added a comment to that effect.
And where are they used?
return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)
Why don't we rely on torch memory caching?
I have made this change. I will need to do an E2E run to make sure that performance isn't affected, but it should be OK given my understanding of torch.empty().
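For reference, the cache-free approach relies on PyTorch's caching allocator; a minimal sketch, assuming the existing get_cublas_workspace_size_bytes() helper (the wrapper name here is illustrative):

import torch

def _gemm_workspace(device: int) -> torch.Tensor:
    # Allocate per call; once the previous tensor is freed, PyTorch's caching
    # allocator reuses the same device block, so repeated same-sized requests
    # do not hit hipMalloc again.
    return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)

The trade-off discussed above is a possible small allocator overhead (the ~5 TFLOPS noted) against simpler code.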
size_t sa_tr_bytes = align_up((size_t)M * scale_K, 256);
size_t sb_tr_bytes = align_up((size_t)N * scale_K, 256);
size_t sa_pk_bytes = align_up((size_t)k_iters * M * sizeof(uint32_t), 256);
size_t sb_pk_bytes = (size_t)k_iters * N * sizeof(uint32_t);
For my own understanding, can you explain why sb_pk_bytes does not require 256-alignment like the others?
Here, we are aligning the end of each buffer so that the next address is 256-aligned, not the current one. Since sb_pk_bytes is the last buffer, we don't need to pad it.
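To make the packing concrete, here is a small illustrative sketch (not taken from the PR) of how consecutive sub-buffers are laid out so that every offset after the first stays 256-byte aligned, which is why only the non-final sizes are rounded up:

def align_up(x: int, a: int = 256) -> int:
    # Round x up to the next multiple of a.
    return (x + a - 1) // a * a

def pack_offsets(sizes):
    # Pad every buffer's size except the last so the *next* offset stays
    # aligned; the final buffer needs no trailing padding.
    offsets, cur = [], 0
    for i, size in enumerate(sizes):
        offsets.append(cur)
        cur += align_up(size) if i + 1 < len(sizes) else size
    return offsets, cur

# e.g. sa_tr, sb_tr, sa_pk are padded; sb_pk, being last, is not.
offsets, total_bytes = pack_offsets([1000, 2000, 3000, 4000])
assert all(off % 256 == 0 for off in offsets)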
namespace transformer_engine {
namespace jax {
Nit: there are a few whitespace-only changes in these files, not sure if they are necessary.
I have removed this, thanks
    Path(__file__).resolve().parent.parent
    / "3rdparty" / "hipkittens" / "include" / "kittens.cuh"
)
if "gfx950" in rocm_archs and hipkittens_header.exists():
PyTorch/JAX extensions do not contain any GPU code; they delegate all of that to TE core, and the hipKittens sources are added to TE common too. Why is this build-time setting needed?
This is an artifact from when I was running into issues with CI not finding pybind-exposed functions from hipKittens. The issue was elsewhere, and I forgot to remove this. I will remove it, thanks!
NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")");
NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
It looks like just a spacing change. Please revert if that is the case.
transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator,
math_sm_count, use_service_stream ? ss_ctl.stream : stream, handle);
#ifdef USE_HIPKITTENS_GEMM
bool is_mxfp8 = inputA->scaling_mode == NVTE_MXFP8_1D_SCALING
Move it out of the ifdef and use it in the ifs that currently check the same condition.
NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
#ifndef USE_HIPKITTENS_GEMM
It is checked below in the else branch of the hipkittens condition.
if (use_hipkittens) {
  auto param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

  hipStream_t s = use_service_stream ? ss_ctl.stream : stream;
The same as with is_mxfp8: there is no point in having it defined for one branch only.
}

auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8);
size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
@@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16)

INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
If you end up having a separate prefix for MXFP8, it has to be used for this suite for consistency.
@@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
test_case_sizes_mxfp8 is only used for DqGEMMTest; is it intentional to add sizes there?
num_cublas_streams = get_num_compute_streams()


def _hipkittens_workspace_bytes(m: int, n: int, k: int, layout: str) -> int:
Should it check the environment to figure out whether hipKittens is enabled?
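If so, a minimal sketch of such a gate, with a hypothetical environment-variable name (the real switch in TE, if any, may be named differently):

import os

def _hipkittens_enabled() -> bool:
    # Hypothetical gate; the variable name and default are assumptions,
    # not TE's actual flag.
    return os.getenv("TE_USE_HIPKITTENS", "1").lower() not in ("0", "false")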
Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASlt and offers additional epilogues such as BIAS and GELU AUX.
Requires a workspace sized relative to the model, often larger than hipBLASlt's, but with significant performance improvements. Only builds for gfx950, and requires M and N to be multiples of 256.
Adds the hipKittens header library as a submodule.
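For reference, the eligibility condition for the hipKittens path, as reflected in the tests above, amounts to the following sketch (function signature illustrative):

def hipkittens_eligible(m: int, n: int, k: int, arch: str) -> bool:
    # gfx950 only; M and N must be multiples of 256 and K at least 256,
    # matching the hipkittens_eligible check in the pytest changes above.
    return arch == "gfx950" and m % 256 == 0 and n % 256 == 0 and k >= 256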