
[ROCm] Allow bf16/bf16/fp32 in nvte_multi_tensor_gemm dispatcher #573

Open

lizamd wants to merge 2 commits into ROCm:dev from lizamd:fix/ck-grouped-gemm-bf16-fp32-output

Conversation

@lizamd lizamd commented May 4, 2026

The is_supported_dtype check in nvte_multi_tensor_gemm previously required A==B==D for the fp16/bf16 path, which rejected the common bf16/bf16/fp32 case where the GEMM output is fp32 for gradient accumulation. This forced a fallback to multi_stream_cublas_gemm (a per-expert hipblaslt loop), bypassing the CK grouped GEMM kernel entirely on ROCm.

The CK FP16 dispatcher (ck_tile_grouped_gemm_fp16_dispatch) already supports independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY (fp32, fp16, bf16). The wrapper check is the only thing that prevents it from being reached.

Relaxed to require A==B in fp16/bf16 and D in {fp32, fp16, bf16}, which matches what the CK dispatcher actually accepts. Verified on Qwen3-30B-A3B MoE training on MI355X (gfx950): the fallback warning rate drops from ~1040/step (every GEMM) to ~28/step (the ~3% of shapes that the CK kernel itself rejects via Kernel::IsSupportedArgument). Throughput is essentially unchanged in this workload because hipblaslt's per-shape autotuning happens to be competitive with the hardcoded CK tile configs for these MoE shapes; the gain will materialize once the CK dispatcher gains more tile configs (or shape-aware tile selection by aggregate M).

This is a CUDA path file; the same patch applies to the AMD path via hipify. No CUDA-side behavior change since cuBLAS/cutlass dispatch on NVIDIA still requires A==B==D in the cutlass fast-path pre-conditions.
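The relaxed check can be modeled as a small predicate. The sketch below is Python purely for illustration (the real check is C++ inside the multi-tensor GEMM wrapper), and the enum names are stand-ins rather than the actual NVTEDType identifiers:

```python
from enum import Enum, auto

class DType(Enum):
    """Illustrative stand-ins for the dispatcher's dtype enum values."""
    kFloat16 = auto()
    kBFloat16 = auto()
    kFloat32 = auto()

HALF_TYPES = {DType.kFloat16, DType.kBFloat16}
D_TYPES = {DType.kFloat32, DType.kFloat16, DType.kBFloat16}

def is_supported_dtype(a: DType, b: DType, d: DType) -> bool:
    # Old check: a == b == d, which rejected bf16/bf16/fp32.
    # Relaxed check: A and B must match and be fp16/bf16; D may independently
    # be fp32, fp16, or bf16 -- what the CK fp16 dispatcher actually accepts.
    return a == b and a in HALF_TYPES and d in D_TYPES
```

With this predicate, the bf16/bf16/fp32 gradient-accumulation case passes, while mixed A/B inputs are still rejected.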


@lizamd lizamd force-pushed the fix/ck-grouped-gemm-bf16-fp32-output branch 2 times, most recently from 764cb65 to ff19241 Compare May 5, 2026 00:02
@matthiasdiener matthiasdiener added the ci-level 1 CI test level 1 label May 5, 2026
@wenchenvincent (Collaborator)

@matthiasdiener @aris134 Could you review this PR?

@wangye805 (Collaborator) left a comment

Can you edit an existing test or add a new one showing that, with your change, bf16/fp16 inputs with fp32 outputs now go through the CK flow correctly? Also, please paste some benchmarking data into this PR for future reference.

Comment on lines +1166 to +1171
// CK FP16/BF16 grouped GEMM dispatcher (ck_tile_grouped_gemm_fp16_dispatch)
// already supports independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY
// (fp32, fp16, bf16). The previous check required A==B==D, which incorrectly
// rejected the common bf16/bf16/fp32 case (training with fp32 gradient
// accumulation), forcing a fallback to the per-expert hipblaslt loop.
// Relaxed to require A==B in fp16/bf16 and D in {fp32, fp16, bf16}.
I think this explanation may be better suited for the PR description rather than an inline code comment.

@aris134 (Contributor) left a comment
Agreed that the CK dispatch logic supports bf16/f32 combination. I would remove the detailed history comment about the previous fallback behavior which is better suited to the PR itself.

lizamd commented May 6, 2026 via email

Follow-ups (out of scope for this PR):

- Add more CK tile configs (e.g. TileCfg_64x256x64, TileCfg_128x256x64)
  and shape-aware tile selection by aggregate M per call. Currently
  throughput is unchanged on this workload because the existing hipblaslt
  fallback is well-tuned and the 3 hardcoded CK tile configs
  (TileCfg_256x256x64, TileCfg_256x128x64, TileCfg_256x128x64_padding)
  don't fit MoE shapes (highly variable per-expert M) optimally. Real
  CK-grouped-GEMM perf wins will materialize once tile selection adapts
  to M.
- Investigate the ~3% of GEMMs that hit Kernel::IsSupportedArgument
  rejection (likely small per-expert M values that fail tile-size
  constraints in the current TileCfg_256x* instantiations).
@lizamd lizamd force-pushed the fix/ck-grouped-gemm-bf16-fp32-output branch from ff19241 to d416572 Compare May 7, 2026 17:45

lizamd commented May 7, 2026

@wangye805 @aris134 could you check the new commit?

Comment thread: tests/pytorch/test_numerics.py (Outdated)
)
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("layout", ["TN", "NT"])
def test_grouped_gemm_fp32_output(input_dtype, layout):
Can it be done by adding configs/parameters to test_grouped_gemm?

Comment thread: tests/pytorch/test_numerics.py (Outdated)
single_output = False
grad = True

os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
Besides Ilya's comment, please try to use monkeypatch for this env setting. Your current approach assumes the user didn't set NVTE_USE_CUTLASS_GROUPED_GEMM when running pytest.

- Drop the inline comment in cublaslt_gemm.cu (rationale moved to PR body).
- Fold test_grouped_gemm_fp32_output into test_grouped_gemm via a new
  fp32_output parametrize, removing the standalone test function.
- Use pytest's monkeypatch fixture for NVTE_USE_CUTLASS_GROUPED_GEMM
  instead of mutating os.environ directly, so the test no longer assumes
  the user had the env var unset.
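The monkeypatch pattern described above can be sketched as below. The GEMM call is replaced with a stub that only reads the env var, since the real test needs torch and a GPU; the stub names and return strings are illustrative, not the repo's actual API:

```python
import os
import pytest

def grouped_gemm_backend() -> str:
    """Stub for the dispatcher: reports which path the env var would select."""
    if os.environ.get("NVTE_USE_CUTLASS_GROUPED_GEMM") == "1":
        return "ck_grouped"
    return "multi_stream_cublas"

@pytest.mark.parametrize("fp32_output", [False, True])
def test_grouped_gemm(fp32_output, monkeypatch):
    # monkeypatch.setenv saves and restores the variable automatically, so the
    # test no longer assumes the user left NVTE_USE_CUTLASS_GROUPED_GEMM unset
    # in their shell, and other tests see the original environment.
    monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1")
    assert grouped_gemm_backend() == "ck_grouped"
    # In the real test, fp32_output would select out_dtype=torch.float32 for
    # bf16/fp16 inputs, exercising the relaxed dtype check on this same path.
```

This keeps both dtype combinations in one parametrized test instead of a standalone function, matching the reviewer's request.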

Labels

ci-level 1 CI test level 1

6 participants