Skip to content
138 changes: 138 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3075,6 +3075,144 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


if IS_HIP_EXTENSION:
    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
    @pytest.mark.parametrize("layout", ["TN", "NN", "NT", "TT"])
    @pytest.mark.parametrize("accumulate", [False, True])
    @pytest.mark.parametrize(
        "pad_dim",
        ["K", "M", "N", "MK", "MKN"],
        ids=lambda d: f"pad{d}",
    )
    def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim, capfd):
        """Test CK grouped GEMM with M, N, or K not aligned to CK tile size.

        CK constraints for bf16/fp16:
          - Contiguous dim of A/B must be dword-aligned (even for 2-byte types).
            RowMajor: contiguous dim is cols (K for A, N for B).
            ColMajor: contiguous dim is rows (M for A, K for B).
          - K tile: 64, M tile: 256, N tile: 128/256

        ``pad_dim`` names the dimensions that get an unaligned size (any
        combination of the letters "M", "N", "K"); every other dimension uses
        a tile-aligned default. The grouped result is compared against
        per-group reference GEMMs, and stderr is inspected for a fallback
        warning emitted by the C++ side.

        The environment variables that enable the CK path and its fallback
        warning are removed in a ``finally`` clause, so a failing assertion,
        ``pytest.xfail`` or ``pytest.fail`` (all of which raise) cannot leak
        them into later tests.
        """
        torch.manual_seed(0)
        z = 8

        # Unaligned values per dimension (all satisfy CK vector-load constraints).
        #   K: even but not multiple of tile (64). Same for all groups.
        #   M: not multiples of tile (256), varies per group.
        #   N: multiple of 16 but not multiple of tile (128).
        unaligned_k = 2016
        unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
        unaligned_n = 2032

        # Aligned defaults.
        k_aligned = 2048
        m_aligned = 256
        n_aligned = 2048

        # Select (un)aligned values based on pad_dim.
        k_val = unaligned_k if "K" in pad_dim else k_aligned
        m_vals = unaligned_m if "M" in pad_dim else [m_aligned] * z
        n_val = unaligned_n if "N" in pad_dim else n_aligned

        total_m = sum(m_vals)
        os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
        os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1"

        try:
            if layout == "TN":
                A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
                B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
                out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
                out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
                m_splits = m_vals
                grad = False
                single_output = True
            elif layout == "NN":
                A = [torch.randn(k_val, n_val, dtype=dtype, device="cuda") for _ in range(z)]
                B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
                out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
                out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
                m_splits = m_vals
                grad = True
                single_output = True
            elif layout == "NT":
                A = list(torch.split(
                    torch.randn(total_m, k_val, dtype=dtype, device="cuda"), m_vals
                ))
                B = list(torch.split(
                    torch.randn(total_m, n_val, dtype=dtype, device="cuda"), m_vals
                ))
                out = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
                out_ref = [o.clone() for o in out]
                m_splits = m_vals
                grad = True
                single_output = False
            else:  # TT
                A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
                B = [torch.randn(k_val, m, dtype=dtype, device="cuda") for m in m_vals]
                out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
                out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
                m_splits = m_vals
                grad = False
                single_output = True

            # Reference: individual GEMMs
            for i in range(z):
                if layout == "TT":
                    # general_gemm doesn't support TT; compute reference manually.
                    ref = B[i].T.to(torch.float32) @ A[i].T.to(torch.float32)
                    if accumulate:
                        out_ref[i] = (out_ref[i].to(torch.float32) + ref).to(dtype)
                    else:
                        out_ref[i] = ref.to(dtype)
                else:
                    general_gemm(
                        A[i],
                        B[i],
                        dtype,
                        grad=grad,
                        accumulate=accumulate,
                        layout=layout,
                        out=out_ref[i],
                    )

            if single_output:
                out_ref = [torch.cat(out_ref)]

            general_grouped_gemm(
                A,
                B,
                out,
                [None] * z,
                dtype,
                m_splits=m_splits,
                grad=grad,
                accumulate=accumulate,
                layout=layout,
                single_output=single_output,
            )

            # The tolerance does not depend on the individual output tensor, so
            # pick it once. The looser bound applies only to accumulating bf16
            # on compute capability (9, 4); the short-circuit keeps the device
            # query off the other paths, as before.
            loose_tolerance = (
                accumulate
                and dtype == torch.bfloat16
                and get_device_compute_capability() == (9, 4)
            )
            tol = 4e-2 if loose_tolerance else 1.5e-2
            for o, o_ref in zip(out, out_ref):
                torch.testing.assert_close(o, o_ref, rtol=tol, atol=tol)

            # Check for CK fallback warnings from C++ (NVTE_WARN writes to std::cerr).
            # capfd captures file-descriptor-level output, including C/C++ stderr.
            captured = capfd.readouterr()
            if "Falling back" in captured.err or "Fallback" in captured.err:
                if "K" in pad_dim and layout != "NN":
                    pytest.xfail(
                        "Known CK_Tile limitation: K-padding with non-NN layouts may fall back to cuBLAS "
                        "(kPadK + ColMajor B bug, or CK_Tile stride alignment requirements)"
                    )
                else:
                    pytest.fail(f"CK_Tile grouped GEMM fell back to cuBLAS:\n{captured.err}")
        finally:
            # Always undo the environment changes, including when the test
            # body raises (assertion failure, xfail, fail, or a runtime error).
            os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
            os.environ.pop("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", None)


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ else()
gemm/ck_grouped_gemm/ck_grouped_gemm.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp
amd_detail/system.cpp)
list(APPEND transformer_engine_cuda_sources
fused_attn_rocm/fused_attn_aotriton.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,16 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs,

if (!Kernel::IsSupportedArgument(kargs)) {
NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. "
"Falling back.");
"transA=", ctx.transA, " transB=", ctx.transB,
" accumulate=", ctx.accumulate, " groups=", ctx.group_num,
". Falling back. "
"CK_Tile constraints for bf16/fp16: "
"contiguous dim of A and B must be dword-aligned (even).");
for (size_t i = 0; i < descs.size(); ++i) {
NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K,
" stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B,
" stride_E=", descs[i].stride_E);
}
return false;
}

Expand Down
Loading
Loading