diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 7b9b711c22..97f6cb3b88 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 7b9b711c22b6823e87150213ecd8449260db8610 +Subproject commit 97f6cb3b88cacff507cca1280db5650a457d92b3 diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 34ab1df063..68e69e405e 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.15.0.dev0 +2.15.0 diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 3217d29c3b..db86498005 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -229,6 +229,16 @@ Operation fuser .. autoapiclass:: transformer_engine.pytorch.ops.SwiGLU +.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_sinkhorn + +.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_scale + +.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_aggregate + +.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_expand_combine + +.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_projection + Deprecated functions -------------------- diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index ce65bc4305..3efa462628 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -36,7 +36,7 @@ NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh old mode 100644 new mode 100755 diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 377c9ddb00..22636828f9 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -24,7 +24,7 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" @@ -37,11 +37,11 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py" +NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" @@ -58,6 +58,8 @@ fi python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" +# Disable autotuning to make unittests faster. In addition, disable TF32 path to fully align with the pytorch reference implementation's precision +NVTE_DISABLE_TRITON_AUTOTUNING=1 NVIDIA_TF32_OVERRIDE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mhc.xml $TE_PATH/tests/pytorch/test_mhc.py || test_fail "test_mhc.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index f83c4ae066..a5ea74171d 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -32,6 +32,7 @@ add_executable(test_operator test_multi_unpadding.cu test_causal_softmax.cu test_swizzle.cu + test_multi_swizzle.cu test_swap_first_dims.cu test_grouped_gemm.cu ../test_common.cu) diff --git a/tests/cpp/operator/test_multi_swizzle.cu b/tests/cpp/operator/test_multi_swizzle.cu new file mode 100644 index 0000000000..4984b7783b --- /dev/null +++ b/tests/cpp/operator/test_multi_swizzle.cu @@ -0,0 +1,415 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; + +template +void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output, + const size_t M, const size_t K) { + constexpr int NEW_SF_TILE_DIM_M = SF_TILE_DIM_M / 4; + constexpr int NEW_SF_TILE_DIM_K = SF_TILE_DIM_K * 4; + constexpr int SF_TILE_SIZE = SF_TILE_DIM_M * SF_TILE_DIM_K; + + for (size_t m = 0; m < M; m++) { + for (size_t k = 0; k < K; k++) { + int tile_id_m = m / SF_TILE_DIM_M; + int tile_id_k = k / SF_TILE_DIM_K; + int m_in_tile = m % SF_TILE_DIM_M; + int k_in_tile = k % SF_TILE_DIM_K; + + int row_in_new_tile = m_in_tile % NEW_SF_TILE_DIM_M; + int col_in_new_tile = m_in_tile / NEW_SF_TILE_DIM_M * SF_TILE_DIM_K + k_in_tile; + + int tile_output_ptr = tile_id_m * SF_TILE_DIM_M * K + tile_id_k * SF_TILE_SIZE; + int out_index = tile_output_ptr + row_in_new_tile * NEW_SF_TILE_DIM_K + col_in_new_tile; + if constexpr (row_scaling) + h_output[out_index] = h_input[k + m * K]; + else + h_output[out_index] = h_input[k * M + m]; + } + } +} + +static void zero_scale_inv_padding(uint8_t *buf, + size_t padded_rows, size_t padded_cols, + size_t orig_rows, size_t orig_cols) { + for (size_t r = 0; r < padded_rows; ++r) { + for (size_t c = 0; c < padded_cols; ++c) { + if (r >= orig_rows || c >= orig_cols) { + buf[r * padded_cols + c] = 0; + } + } + } +} + +// =================================================================== +// Multi-tensor swizzle test +// =================================================================== + +void performTestMultiTensorSwizzle(const int num_tensors, const size_t M, const size_t K, + bool rowwise) { + using namespace test; + constexpr size_t BLOCK_SIZE = 32; + const std::vector shape{M, K}; + + std::vector> input_tensors; + std::vector> output_tensors; + std::vector input_handles; + std::vector output_handles; + + for (int i = 0; i < num_tensors; ++i) { + auto input = std::make_unique("input_" + std::to_string(i), shape, + DType::kFloat8E4M3, rowwise, !rowwise, + NVTE_MXFP8_1D_SCALING); + auto output = std::make_unique("output_" + std::to_string(i), shape, + DType::kFloat8E4M3, rowwise, !rowwise, + NVTE_MXFP8_1D_SCALING); + fillUniform(input.get()); + output->set_with_gemm_swizzled_scales(true); + + input->to_cpu(); + if (rowwise) { + const NVTEShape rs = input->rowwise_scale_inv_shape(); + zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(), + rs.data[0], rs.data[1], + M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + } else { + const NVTEShape cs = input->columnwise_scale_inv_shape(); + zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(), + cs.data[0], cs.data[1], + (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + } + input->from_cpu(); + + input_handles.push_back(input->data()); + output_handles.push_back(output->data()); + input_tensors.emplace_back(std::move(input)); + output_tensors.emplace_back(std::move(output)); + } + + nvte_multi_tensor_swizzle_scaling_factors(input_handles.data(), output_handles.data(), + num_tensors, 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + for (int i = 0; i < num_tensors; ++i) { + output_tensors[i]->to_cpu(); + if (rowwise) { + const NVTEShape rs = input_tensors[i]->rowwise_scale_inv_shape(); + const size_t numel = rs.data[0] * rs.data[1]; + std::unique_ptr ref = std::make_unique(numel); + compute_ref_swizzle<128, 4, true>( + input_tensors[i]->rowwise_cpu_scale_inv_ptr(), + ref.get(), rs.data[0], rs.data[1]); + compareResults("multi_tensor_swizzle_row_" + std::to_string(i), + output_tensors[i]->rowwise_cpu_scale_inv_ptr(), + ref.get(), numel); + } else { + const NVTEShape cs = input_tensors[i]->columnwise_scale_inv_shape(); + const size_t numel = cs.data[0] * cs.data[1]; + std::unique_ptr ref = std::make_unique(numel); + compute_ref_swizzle<128, 4, false>( + input_tensors[i]->columnwise_cpu_scale_inv_ptr(), + ref.get(), cs.data[1], cs.data[0]); + compareResults("multi_tensor_swizzle_col_" + std::to_string(i), + output_tensors[i]->columnwise_cpu_scale_inv_ptr(), + ref.get(), numel); + } + } +} + +// =================================================================== +// Multi-tensor unswizzle test (uses single-tensor swizzle to prepare) +// =================================================================== + +void performTestMultiTensorUnswizzle(const int num_tensors, const size_t M, const size_t K, + bool rowwise) { + using namespace test; + constexpr size_t BLOCK_SIZE = 32; + const std::vector shape{M, K}; + + std::vector> orig_tensors, swizzled_tensors, output_tensors; + std::vector swizzled_handles, output_handles; + + for (int i = 0; i < num_tensors; ++i) { + auto orig = std::make_unique("orig_" + std::to_string(i), shape, + DType::kFloat8E4M3, rowwise, !rowwise, + NVTE_MXFP8_1D_SCALING); + auto swizzled = std::make_unique("swizzled_" + std::to_string(i), shape, + DType::kFloat8E4M3, rowwise, !rowwise, + NVTE_MXFP8_1D_SCALING); + auto output = std::make_unique("output_" + std::to_string(i), shape, + DType::kFloat8E4M3, rowwise, !rowwise, + NVTE_MXFP8_1D_SCALING); + fillUniform(orig.get()); + swizzled->set_with_gemm_swizzled_scales(true); + + orig->to_cpu(); + if (rowwise) { + const NVTEShape rs = orig->rowwise_scale_inv_shape(); + zero_scale_inv_padding(orig->rowwise_cpu_scale_inv_ptr(), + rs.data[0], rs.data[1], + M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + } else { + const NVTEShape cs = orig->columnwise_scale_inv_shape(); + zero_scale_inv_padding(orig->columnwise_cpu_scale_inv_ptr(), + cs.data[0], cs.data[1], + (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + } + orig->from_cpu(); + + nvte_swizzle_scaling_factors(orig->data(), swizzled->data(), 0); + + swizzled_handles.push_back(swizzled->data()); + output_handles.push_back(output->data()); + orig_tensors.emplace_back(std::move(orig)); + swizzled_tensors.emplace_back(std::move(swizzled)); + output_tensors.emplace_back(std::move(output)); + } + + nvte_multi_tensor_unswizzle_scaling_factors(swizzled_handles.data(), output_handles.data(), + num_tensors, 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + for (int i = 0; i < num_tensors; ++i) { + orig_tensors[i]->to_cpu(); + output_tensors[i]->to_cpu(); + if (rowwise) { + const NVTEShape rs = orig_tensors[i]->rowwise_scale_inv_shape(); + const size_t numel = rs.data[0] * rs.data[1]; + compareResults("multi_unswizzle_row_" + std::to_string(i), + output_tensors[i]->rowwise_cpu_scale_inv_ptr(), + orig_tensors[i]->rowwise_cpu_scale_inv_ptr(), + numel); + } else { + const NVTEShape cs = orig_tensors[i]->columnwise_scale_inv_shape(); + const size_t numel = cs.data[0] * cs.data[1]; + compareResults("multi_unswizzle_col_" + std::to_string(i), + output_tensors[i]->columnwise_cpu_scale_inv_ptr(), + orig_tensors[i]->columnwise_cpu_scale_inv_ptr(), + numel); + } + } +} + +// =================================================================== +// Multi-tensor swizzle -> unswizzle roundtrip test +// =================================================================== + +void performTestMultiTensorRoundtrip(const int num_tensors, const size_t M, const size_t K, + bool rowwise) { + using namespace test; + constexpr size_t BLOCK_SIZE = 32; + const std::vector shape{M, K}; + + std::vector> orig_tensors, mid_tensors, final_tensors; + std::vector orig_handles, mid_handles, final_handles; + + for (int i = 0; i < num_tensors; ++i) { + auto orig = std::make_unique("orig_" + std::to_string(i), shape, + DType::kFloat8E4M3, rowwise, !rowwise, + NVTE_MXFP8_1D_SCALING); + auto mid = std::make_unique("mid_" + std::to_string(i), shape, + DType::kFloat8E4M3, rowwise, !rowwise, + NVTE_MXFP8_1D_SCALING); + auto fin = std::make_unique("fin_" + std::to_string(i), shape, + DType::kFloat8E4M3, rowwise, !rowwise, + NVTE_MXFP8_1D_SCALING); + fillUniform(orig.get()); + mid->set_with_gemm_swizzled_scales(true); + + orig->to_cpu(); + if (rowwise) { + const NVTEShape rs = orig->rowwise_scale_inv_shape(); + zero_scale_inv_padding(orig->rowwise_cpu_scale_inv_ptr(), + rs.data[0], rs.data[1], + M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + } else { + const NVTEShape cs = orig->columnwise_scale_inv_shape(); + zero_scale_inv_padding(orig->columnwise_cpu_scale_inv_ptr(), + cs.data[0], cs.data[1], + (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + } + orig->from_cpu(); + + orig_handles.push_back(orig->data()); + mid_handles.push_back(mid->data()); + final_handles.push_back(fin->data()); + orig_tensors.emplace_back(std::move(orig)); + mid_tensors.emplace_back(std::move(mid)); + final_tensors.emplace_back(std::move(fin)); + } + + nvte_multi_tensor_swizzle_scaling_factors(orig_handles.data(), mid_handles.data(), + num_tensors, 0); + nvte_multi_tensor_unswizzle_scaling_factors(mid_handles.data(), final_handles.data(), + num_tensors, 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + for (int i = 0; i < num_tensors; ++i) { + orig_tensors[i]->to_cpu(); + final_tensors[i]->to_cpu(); + if (rowwise) { + const NVTEShape rs = orig_tensors[i]->rowwise_scale_inv_shape(); + const size_t numel = rs.data[0] * rs.data[1]; + compareResults("multi_roundtrip_row_" + std::to_string(i), + final_tensors[i]->rowwise_cpu_scale_inv_ptr(), + orig_tensors[i]->rowwise_cpu_scale_inv_ptr(), + numel); + } else { + const NVTEShape cs = orig_tensors[i]->columnwise_scale_inv_shape(); + const size_t numel = cs.data[0] * cs.data[1]; + compareResults("multi_roundtrip_col_" + std::to_string(i), + final_tensors[i]->columnwise_cpu_scale_inv_ptr(), + orig_tensors[i]->columnwise_cpu_scale_inv_ptr(), + numel); + } + } +} + +// =================================================================== +// Test suites +// =================================================================== + +class MultiTensorSwizzleTestSuite + : public ::testing::TestWithParam> {}; + +TEST_P(MultiTensorSwizzleTestSuite, TestMultiTensorSwizzle) { + const auto num_tensors = std::get<0>(GetParam()); + const auto M = std::get<1>(GetParam()); + const auto K = std::get<2>(GetParam()); + const auto rowwise = std::get<3>(GetParam()); + performTestMultiTensorSwizzle(num_tensors, M, K, rowwise); +} + +class MultiTensorUnswizzleTestSuite + : public ::testing::TestWithParam> {}; + +TEST_P(MultiTensorUnswizzleTestSuite, TestMultiTensorUnswizzle) { + const auto num_tensors = std::get<0>(GetParam()); + const auto M = std::get<1>(GetParam()); + const auto K = std::get<2>(GetParam()); + const auto rowwise = std::get<3>(GetParam()); + performTestMultiTensorUnswizzle(num_tensors, M, K, rowwise); +} + +class MultiTensorRoundtripTestSuite + : public ::testing::TestWithParam> {}; + +TEST_P(MultiTensorRoundtripTestSuite, TestMultiTensorRoundtrip) { + const auto num_tensors = std::get<0>(GetParam()); + const auto M = std::get<1>(GetParam()); + const auto K = std::get<2>(GetParam()); + const auto rowwise = std::get<3>(GetParam()); + performTestMultiTensorRoundtrip(num_tensors, M, K, rowwise); +} + +namespace { + +// Shapes that exercise the narrow_k kernel (rowwise) / narrow_m kernel (colwise): +// Narrow-K fires when ALL tensors have scale num_tiles_k < TB_DIM (32), +// i.e. padded ceil(K/32) < 128. +// Narrow-M fires analogously for colwise when padded K < 4096 +// (since colwise m = K padded to 128, num_tiles_m = m / 128 < 32). +// +// Shapes that bypass narrow and use the regular multi_tensor kernel: +// K >= 4096 makes num_tiles_k >= 32 (rowwise) and num_tiles_m >= 32 (colwise). + +std::vector> multi_tensor_test_cases = { + // --- Narrow path cases (K small → narrow_k for row, narrow_m for col) --- + // M and K both aligned to 128 + {3, 256, 256, true}, + {3, 256, 256, false}, + {4, 128, 128, true}, + {4, 128, 128, false}, + // M not divisible by 128 (but must be divisible by 32 for colwise — + // the kernel computes original_K = M / BLOCK_SIZE using floor division) + {3, 192, 256, true}, + {3, 192, 256, false}, + {2, 64, 256, true}, + {2, 64, 256, false}, + // Larger narrow K (num_tiles_k = 8, shared mem = 128 KB) + {2, 128, 1024, true}, + {2, 128, 1024, false}, + // K not divisible by 128 + {3, 256, 160, true}, + {3, 256, 160, false}, + // Neither M nor K divisible by 128 + {3, 192, 160, true}, + {3, 192, 160, false}, + // Minimum sizes (M=32 is the MXFP8 block size minimum for colwise) + {2, 32, 32, true}, + {2, 32, 32, false}, + {4, 32, 64, true}, + {4, 32, 64, false}, + + // --- Non-narrow path cases (K >= 4096 → regular multi_tensor kernel) --- + {3, 256, 4096, true}, + {3, 256, 4096, false}, + {2, 128, 8192, true}, + {2, 128, 8192, false}, +}; + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiTensorSwizzleTestSuite, + ::testing::ValuesIn(multi_tensor_test_cases), + [](const testing::TestParamInfo& info) { + return "n" + std::to_string(std::get<0>(info.param)) + + "_M" + std::to_string(std::get<1>(info.param)) + + "_K" + std::to_string(std::get<2>(info.param)) + + (std::get<3>(info.param) ? "_row" : "_col"); + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiTensorUnswizzleTestSuite, + ::testing::ValuesIn(multi_tensor_test_cases), + [](const testing::TestParamInfo& info) { + return "n" + std::to_string(std::get<0>(info.param)) + + "_M" + std::to_string(std::get<1>(info.param)) + + "_K" + std::to_string(std::get<2>(info.param)) + + (std::get<3>(info.param) ? "_row" : "_col"); + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiTensorRoundtripTestSuite, + ::testing::ValuesIn(multi_tensor_test_cases), + [](const testing::TestParamInfo& info) { + return "n" + std::to_string(std::get<0>(info.param)) + + "_M" + std::to_string(std::get<1>(info.param)) + + "_K" + std::to_string(std::get<2>(info.param)) + + (std::get<3>(info.param) ? "_row" : "_col"); + }); diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 806a2482ab..3fec5062ff 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -248,11 +248,11 @@ void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const const NVTEShape rs = input->rowwise_scale_inv_shape(); zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(), rs.data[0], rs.data[1], - M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + M, divide_round_up(K, BLOCK_SIZE)); const NVTEShape cs = input->columnwise_scale_inv_shape(); zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(), cs.data[0], cs.data[1], - (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + divide_round_up(M, BLOCK_SIZE), K); input->from_cpu(); input_ptrs.push_back(input.get()); @@ -444,11 +444,11 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si const NVTEShape rs = orig->rowwise_scale_inv_shape(); zero_scale_inv_padding(orig->rowwise_cpu_scale_inv_ptr(), rs.data[0], rs.data[1], - M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + M, divide_round_up(K, BLOCK_SIZE)); const NVTEShape cs = orig->columnwise_scale_inv_shape(); zero_scale_inv_padding(orig->columnwise_cpu_scale_inv_ptr(), cs.data[0], cs.data[1], - (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + divide_round_up(M, BLOCK_SIZE), K); orig->from_cpu(); orig_ptrs.push_back(orig.get()); @@ -541,6 +541,253 @@ INSTANTIATE_TEST_SUITE_P( } ); +// Build a "compact" grouped MXFP8 scale_inv buffer for swizzle input. This is +// the layout produced by the grouped MXFP8 quantize kernel: the per-tensor +// stride is `M_per_tensor * padded_K` (rowwise) or `DIVUP(M,32) * padded_K_for_cols` +// (columnwise) -- i.e. NO per-tensor padding rows are inserted. The total buffer +// is rounded up at its very end to a multiple of 128 (rowwise) or 4 (columnwise) +// in the grouped first dim, matching what the C++ allocator hands out. +// +// Each tensor's compact scales are gathered from the unpadded-prefix rows of +// that tensor's per-tensor padded CPU scale buffer. +namespace { + +struct CompactScaleBuffer { + test::CudaPtr<> ptr; + size_t numel{0}; +}; + +CompactScaleBuffer gather_compact_grouped_scale( + const std::vector>& tensors, + size_t M_per_tensor, size_t K_per_tensor, bool rowwise) { + using namespace test; + constexpr size_t BLOCK = 32; + const size_t num_tensors = tensors.size(); + + size_t per_tensor_first_unpadded; + size_t per_tensor_last_padded; + size_t group_first_align; + if (rowwise) { + per_tensor_first_unpadded = M_per_tensor; + per_tensor_last_padded = + round_up_to_nearest_multiple(divide_round_up(K_per_tensor, BLOCK), 4); + group_first_align = 128; + } else { + per_tensor_first_unpadded = divide_round_up(M_per_tensor, BLOCK); + per_tensor_last_padded = round_up_to_nearest_multiple(K_per_tensor, 128); + group_first_align = 4; + } + + const size_t per_tensor_compact_numel = + per_tensor_first_unpadded * per_tensor_last_padded; + const size_t total_first = round_up_to_nearest_multiple( + num_tensors * per_tensor_first_unpadded, group_first_align); + const size_t total_numel = total_first * per_tensor_last_padded; + + std::vector host_buf(total_numel, 0); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + const NVTEShape padded_shape = rowwise ? tensors[i]->rowwise_scale_inv_shape() + : tensors[i]->columnwise_scale_inv_shape(); + NVTE_CHECK(padded_shape.data[1] == per_tensor_last_padded, + "Unexpected per-tensor padded last dim in compact gather."); + const uint8_t* src = rowwise + ? tensors[i]->rowwise_cpu_scale_inv_ptr() + : tensors[i]->columnwise_cpu_scale_inv_ptr(); + uint8_t* dst = host_buf.data() + i * per_tensor_compact_numel; + // Per-tensor padded buffer is row-major (padded_first, padded_last); copy + // only the first `per_tensor_first_unpadded` rows. + std::memcpy(dst, src, per_tensor_compact_numel); + } + + CompactScaleBuffer out; + out.ptr = cuda_alloc(total_numel); + NVTE_CHECK_CUDA(cudaMemcpy(out.ptr.get(), host_buf.data(), + total_numel, cudaMemcpyHostToDevice)); + out.numel = total_numel; + return out; +} + +} // namespace + +// Tests that grouped_swizzle_for_gemm correctly handles a COMPACT input +// scale_inv buffer (no per-tensor padding rows), producing an output in the +// per-tensor padded layout with padded regions zeroed out. This is the layout +// produced by the grouped MXFP8 quantize kernel; previously the swizzle kernel +// asserted the input matched the per-tensor padded packed size, which broke +// grouped MLP weights with M not a multiple of 128. +void performTestGroupedSwizzleMXFP8CompactInput(const int num_tensors, const size_t M, + const size_t K) { + using namespace transformer_engine; + using namespace test; + + std::vector> input_tensors; + std::vector> output_tensors; + std::vector input_ptrs, output_ptrs; + input_tensors.reserve(num_tensors); + output_tensors.reserve(num_tensors); + input_ptrs.reserve(num_tensors); + output_ptrs.reserve(num_tensors); + + constexpr size_t BLOCK_SIZE = 32; + const std::vector shape{M, K}; + for (int i = 0; i < num_tensors; ++i) { + auto input = std::make_unique("input_" + std::to_string(i), shape, + DType::kFloat8E4M3, true, true, + NVTE_MXFP8_1D_SCALING); + auto output = std::make_unique("output_" + std::to_string(i), shape, + DType::kFloat8E4M3, true, true, + NVTE_MXFP8_1D_SCALING); + fillUniform(input.get()); + fillUniform(output.get()); + + // Zero the per-tensor padded regions so the reference (which sees the + // padded layout) and the kernel (which sees the compact layout but writes + // zeros into output padding) agree byte-for-byte. + input->to_cpu(); + const NVTEShape rs = input->rowwise_scale_inv_shape(); + zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(), + rs.data[0], rs.data[1], + M, divide_round_up(K, BLOCK_SIZE)); + const NVTEShape cs = input->columnwise_scale_inv_shape(); + zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(), + cs.data[0], cs.data[1], + divide_round_up(M, BLOCK_SIZE), K); + input->from_cpu(); + + input_ptrs.push_back(input.get()); + output_ptrs.push_back(output.get()); + input_tensors.emplace_back(std::move(input)); + output_tensors.emplace_back(std::move(output)); + } + + // Build a per-tensor padded grouped output via the standard helper, and a + // compact-scale grouped input by overriding the scale_inv buffers of a + // padded grouped input with newly allocated compact buffers. + GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING); + GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING); + + CompactScaleBuffer compact_row = + gather_compact_grouped_scale(input_tensors, M, K, /*rowwise=*/true); + CompactScaleBuffer compact_col = + gather_compact_grouped_scale(input_tensors, M, K, /*rowwise=*/false); + + grouped_input.scale_inv = std::move(compact_row.ptr); + grouped_input.columnwise_scale_inv = std::move(compact_col.ptr); + { + NVTEShape s = nvte_make_shape(&compact_row.numel, 1); + NVTEBasicTensor t{grouped_input.scale_inv.get(), kNVTEFloat8E8M0, s}; + nvte_set_grouped_tensor_param(grouped_input.get_handle(), + kNVTEGroupedRowwiseScaleInv, &t, sizeof(t)); + } + { + NVTEShape s = nvte_make_shape(&compact_col.numel, 1); + NVTEBasicTensor t{grouped_input.columnwise_scale_inv.get(), kNVTEFloat8E8M0, s}; + nvte_set_grouped_tensor_param(grouped_input.get_handle(), + kNVTEGroupedColumnwiseScaleInv, &t, sizeof(t)); + } + + const uint8_t input_swizzled = 0; + nvte_set_grouped_tensor_param(grouped_input.get_handle(), + kNVTEGroupedWithGEMMSwizzledScales, + &input_swizzled, sizeof(input_swizzled)); + const uint8_t output_swizzled = 1; + nvte_set_grouped_tensor_param(grouped_output.get_handle(), + kNVTEGroupedWithGEMMSwizzledScales, + &output_swizzled, sizeof(output_swizzled)); + + const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape(); + const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape(); + const size_t row_numel = row_shape.data[0] * row_shape.data[1]; + const size_t col_numel = col_shape.data[0] * col_shape.data[1]; + + // Memset to a non-zero sentinel so we can detect kernel failures to write + // padded regions (those must be overwritten with zero by the kernel). + NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0xCD, + num_tensors * row_numel)); + NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0xCD, + num_tensors * col_numel)); + + nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(), + grouped_output.get_handle(), 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + std::vector output_row(num_tensors * row_numel); + std::vector output_col(num_tensors * col_numel); + NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), grouped_output.scale_inv.get(), + output_row.size(), cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(), + grouped_output.columnwise_scale_inv.get(), + output_col.size(), cudaMemcpyDeviceToHost)); + + std::vector ref_row(num_tensors * row_numel); + std::vector ref_col(num_tensors * col_numel); + for (int i = 0; i < num_tensors; ++i) { + compute_ref_swizzle<128, 4, true>( + input_tensors[i]->rowwise_cpu_scale_inv_ptr(), + ref_row.data() + i * row_numel, + row_shape.data[0], row_shape.data[1]); + compute_ref_swizzle<128, 4, false>( + input_tensors[i]->columnwise_cpu_scale_inv_ptr(), + ref_col.data() + i * col_numel, + col_shape.data[1], col_shape.data[0]); + } + + compareResults("grouped_swizzle_compact_rowwise", output_row.data(), + ref_row.data(), num_tensors * row_numel); + compareResults("grouped_swizzle_compact_colwise", output_col.data(), + ref_col.data(), num_tensors * col_numel); +} + +class SwizzleGroupedCompactInputTestSuite + : public ::testing::TestWithParam> {}; + +TEST_P(SwizzleGroupedCompactInputTestSuite, TestGroupedSwizzleMXFP8CompactInput) { + const auto num_tensors = std::get<0>(GetParam()); + const auto M = std::get<1>(GetParam()); + const auto K = std::get<2>(GetParam()); + performTestGroupedSwizzleMXFP8CompactInput(num_tensors, M, K); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwizzleGroupedCompactInputTestSuite, + ::testing::Values( + // Aligned M and K. Per-tensor compact stride == per-tensor padded stride, + // so the kernel may use either layout; serves as a sanity check that the + // compact-input plumbing doesn't regress aligned shapes. + std::make_tuple(3, 256, 256), + std::make_tuple(4, 128, 128), + // M NOT divisible by 128 (the original-bug case): per-tensor compact stride + // shrinks vs padded. We pick (num_tensors, M) so that BOTH + // round_up(N * M, 128) != N * round_up(M, 128) (rowwise) + // round_up(N * DIVUP(M,32), 4) != N * round_up(DIVUP(M,32),4) (colwise) + // i.e. compact_total != padded_total on either axis, so the kernel + // unambiguously detects the compact layout. + std::make_tuple(4, 200, 256), + std::make_tuple(4, 65, 256), + std::make_tuple(2, 2880, 2880), // shape from the originally failing workload + // K not divisible by 128 (DIVUP(K,32) padded up to a multiple of 4). + std::make_tuple(3, 256, 160), + std::make_tuple(2, 256, 96), + // Neither M nor K aligned. + std::make_tuple(4, 200, 160), + std::make_tuple(4, 33, 64), + std::make_tuple(2, 1, 32), + // num_tensors * M not aligned to 128 -> exercises trailing alignment slack + // at the end of the compact rowwise buffer. + std::make_tuple(3, 64, 128), + std::make_tuple(5, 33, 96) + ), + [](const testing::TestParamInfo& info) { + return "n" + std::to_string(std::get<0>(info.param)) + + "_M" + std::to_string(std::get<1>(info.param)) + + "_K" + std::to_string(std::get<2>(info.param)); + } +); + class UnswizzleGroupedTestSuite : public ::testing::TestWithParam> {}; @@ -613,6 +860,14 @@ std::vector> num_tiles = { {65, 257}, {65, 258}, {65, 259}, + // Additional narrow-path coverage: narrow_k (row) when num_tiles_K < 32, + // narrow_m (col) when num_tiles_M < 32. + {1, 4}, // narrow_k with 4 K-tiles + {1, 8}, // narrow_k with 8 K-tiles + {4, 1}, // narrow_m with 4 M-tiles + {8, 1}, // narrow_m with 8 M-tiles + {31, 1}, // narrow_m at boundary (31 < TB_DIM=32) + {1, 31}, // narrow_k at boundary (31 < TB_DIM=32) }; // Raw {M, K} data shapes for unswizzle tests. Includes aligned cases (scale dims diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 3e5529c077..b154cd49d9 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1913,11 +1913,24 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype) def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims): - out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims) # Note: we use jnp.sum instead of jnp.mean to make the gradient larger # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to # normalize the output and prevent the gradient from being too large for FP8. - out_sum_list = [jnp.sum(out) for out in out_list] + # + # We pass bias=None here and add bias externally in fp32 so the autodiff + # bias-grad (sum over the m axis of the cotangent) accumulates in fp32. + # If bias is added inside _ref_grouped_dense in bf16, JAX lowers the bias + # backward as a bf16 sum-over-m and loses precision on the largest group, + # producing a >bf16-rtol mismatch against the primitive's grouped_dbias + # (which casts the cotangent to fp32 before segment_sum). Bias is required + # for this helper since it is only used by the grad tests below, which all + # set with_bias=True. + assert bias is not None, "_ref_sum_grouped_dense requires a non-None bias" + out_list = self._ref_grouped_dense(x, kernel, None, group_sizes, contracting_dims) + out_sum_list = [] + for out_i, bias_i in zip(out_list, bias): + out_with_bias_fp32 = out_i.astype(jnp.float32) + bias_i.astype(jnp.float32) + out_sum_list.append(jnp.sum(out_with_bias_fp32)) return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size) def _primitive_sum_grouped_dense( @@ -1926,7 +1939,9 @@ def _primitive_sum_grouped_dense( out = grouped_dense( x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set ) - return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) + # Match the fp32 accumulation in _ref_sum_grouped_dense so loss values are + # comparable and the cotangent dtype on `out` is unambiguous. + return jnp.sum(out.astype(jnp.float32)) / jnp.sqrt(x.size) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) def test_grouped_dense_grad_fp16(self, dtype, input_shape): diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0f36a8816d..8dfea644a5 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -19,8 +19,14 @@ DotProductAttention, Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, +) +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -180,6 +186,7 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + deterministic="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" @@ -188,11 +195,15 @@ def run_dpa_with_cp( is_training = is_training == "True" # set up environment variables and config + if deterministic == "True": + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + else: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" fp8_bwd = fp8_bwd == "True" and dtype == "fp8" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" fp8_dpa = fp8_dpa == "True" and dtype == "fp8" - fp8_mha = fp8_mha == "True" and dtype == "fp8" - f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + fp8_mha = fp8_mha == "True" and dtype == "fp8" and scaling_mode != "mxfp8" + f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True" os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" @@ -247,6 +258,8 @@ def run_dpa_with_cp( fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) # instantiate attention module core_attn = DotProductAttention( @@ -302,10 +315,25 @@ def run_dpa_with_cp( fp8_dtype=tex.DType.kFloat8E5M2, device="cuda", ) + if scaling_mode == "mxfp8": + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] if fp8_mha: - q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -413,7 +441,7 @@ def run_dpa_with_cp( dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) if fp8_mha: - q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) if is_training: q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: @@ -494,6 +522,7 @@ def run_dpa_with_cp( # get outputs tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): @@ -502,11 +531,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for tensor in tensors: + for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) + assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" + assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 38d8626b4b..f735793fff 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -26,6 +26,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import ( _attention_backends, ) +from transformer_engine.pytorch.attention.dot_product_attention import backends as dpa_backends +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils, check_set_window_size, @@ -667,6 +669,75 @@ def test_dpa_mask(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +def test_unfused_thd_padding_causal_uses_sdpa_without_full_mask(monkeypatch): + """Unfused THD padding_causal should avoid materializing a full quadratic mask.""" + reset_rng_states() + batch_size = 2 + num_heads = 2 + head_dim = 16 + seqlens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) + total_seqlen = int(cu_seqlens[-1].item()) + max_seqlen = int(seqlens.max().item()) + + query = torch.randn( + total_seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda", requires_grad=True + ) + key = torch.randn_like(query, requires_grad=True) + value = torch.randn_like(query, requires_grad=True) + softmax_scale = head_dim**-0.5 + + expected = [] + with torch.no_grad(): + for batch_id in range(batch_size): + start = int(cu_seqlens[batch_id].item()) + end = int(cu_seqlens[batch_id + 1].item()) + q = query[start:end].permute(1, 0, 2).unsqueeze(0) + k = key[start:end].permute(1, 0, 2).unsqueeze(0) + v = value[start:end].permute(1, 0, 2).unsqueeze(0) + expected.append( + torch.nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=0.0, is_causal=True, scale=softmax_scale + ) + .squeeze(0) + .permute(1, 0, 2) + .reshape(end - start, -1) + ) + expected = torch.cat(expected, dim=0) + + def fail_get_full_mask(*args, **kwargs): + raise AssertionError("get_full_mask should not be called for this path") + + monkeypatch.setattr(dpa_utils, "get_full_mask", fail_get_full_mask) + + attention = dpa_backends.UnfusedDotProductAttention( + softmax_scale=softmax_scale, + attention_type="self", + attention_dropout=0.0, + ).eval() + output = attention( + {}, + query, + key, + value, + qkv_layout="thd_thd_thd", + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + attn_mask_type="padding_causal", + window_size=(-1, 0), + ) + + torch.testing.assert_close(output, expected, rtol=1e-3, atol=1e-3) + output.float().sum().backward() + assert query.grad is not None + assert key.grad is not None + assert value.grad is not None + + model_configs_bias = { # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), @@ -1936,20 +2007,45 @@ def get_model(dtype, config): return outputs +attn_mask_type = "causal" model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128), - "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), - "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), - "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), - "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), - "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "fp8_9": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + ), + "fp8_10": ModelConfig( + 1, + 4096, + 128, + 192, + head_dim_v=128, + attn_mask_type="causal", + ), + "fp8_11": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + attn_mask_type="causal_bottom_right", + ), + "fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), + "fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), + "fp8_16": ModelConfig( + 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), + "fp8_17": ModelConfig( + 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + ), + "fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + "fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + "fp8_20": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1966,7 +2062,7 @@ def get_model(dtype, config): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_mha_fp8_vs_f16( dtype, model, @@ -1997,6 +2093,12 @@ def test_mha_fp8_vs_f16( fp8_dpa=True, fp8_mha=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.E4M3, + fp8_dpa=True, + fp8_mha=False, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, _ = get_available_attention_backends( @@ -2216,7 +2318,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] @@ -2248,6 +2350,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8_format=recipe.Format.HYBRID, fp8_dpa=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.E4M3, + fp8_dpa=True, + fp8_mha=False, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, _ = get_available_attention_backends( @@ -2319,7 +2427,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 - bwd_names = ["dq", "dk", "dv"] + bwd_names = ["dq", "dk", "dv", "d_softmax_offset"] if flash_attn_supported and fused_attn_supported_f16: logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) @@ -2408,7 +2516,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with quantized_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -2418,6 +2526,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type="self", qkv_format=qkv_format, + softmax_type=config.softmax_type, ).to(dtype=dtype, device="cuda") if not is_training: dpa = dpa.eval() @@ -2453,7 +2562,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim_qk, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -2469,6 +2579,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] if config.dropout_p == 0.0: tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") @@ -2493,6 +2607,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") @@ -2503,6 +2618,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: inp[1], inp[2], qkv_format=qkv_format, + window_size=config.window_size, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=config.max_seqlen_q, @@ -2510,14 +2626,16 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) + d_softmax_offset = None + if is_training and config.softmax_type != "vanilla": + d_softmax_offset = dpa.softmax_offset.grad if is_training: - return out, (inp[0].grad, inp[1].grad, inp[2].grad) - return out, (None, None, None) + return out, (inp[0].grad, inp[1].grad, inp[2].grad, d_softmax_offset) + return out, (None, None, None, d_softmax_offset) model_configs_fp8 = { @@ -2769,6 +2887,8 @@ def forward( quantization_params=qkv_quantizer, use_split_accumulator=_2X_ACC_FPROP, ) + qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd" + o_format = "bshd" if cudnn_frontend_version == 1 else "thd" qkv = qkv.view(-1, 3, h, d) qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous() torch.save(qkv_fp16, "qkv.pt") @@ -2797,7 +2917,8 @@ def forward( attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + qkv_layout=qkv_layout, + o_format=o_format, attn_bias_type="no_bias", attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", rng_gen=None, @@ -2820,6 +2941,8 @@ def forward( ctx.num_heads = num_heads ctx.mask_type = mask_type ctx.dtype = inp.dtype + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer @@ -2837,7 +2960,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_func_ctx(ctx) proj_dgrad = ctx.dO_quantizer(grad_output) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, @@ -2850,7 +2972,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], out, proj_dgrad.view_as(out), ctx.qkv_dtype, - fp8_dtype_backward, ctx.aux_ctx_tensors, FusedAttnBackend["FP8"], None, @@ -2861,7 +2982,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], attn_scale=None, dropout=ctx.p_dropout, fast_zero_fill=ctx.fast_zero_fill, - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + do_format=ctx.o_format, + dqkv_layout=ctx.qkv_layout, attn_bias_type="no_bias", attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", ) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 5aaf67061b..23d1bfdd85 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -17,6 +17,8 @@ from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils @@ -26,6 +28,12 @@ pytest_logging_level = logging.getLevelName(logging.root.level) +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) @@ -39,13 +47,11 @@ "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA - "cp_2_2": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) - ), # GQA + "cp_2_2": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # GQA "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA - "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA + "cp_3_0": ModelConfig(2, 4096, 128, 192, attn_mask_type="causal", head_dim_v=128), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128 @@ -73,7 +79,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"] + configs = ["cp_2_0", "cp_2_2", "cp_3_0", "cp_3_3"] model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} dtypes = ["bf16"] qkv_formats = ["sbhd", "thd"] @@ -94,25 +100,34 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config.context_parallel = True config.cp_comm_type = cp_comm_type - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if ( + config.window_size != (-1, 0) + and config.window_size != (-1, -1) + and cp_comm_type + in [ + "p2p", + "a2a+p2p", + ] + ): + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 + ): pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" ) + + # FlashAttention / CP implementation specific: MLA only with KV P2P if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} @@ -150,8 +165,22 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_2_0": ModelConfig( + 2, + 4096, + 32, + 128, + num_gqa_groups=4, + attn_mask_type="causal", + ), # GQA + "cp_2_1": ModelConfig( + 2, + 4096, + 32, + 128, + attn_mask_type="causal", + window_size=(128, 0), + ), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -189,7 +218,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA - "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal"), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA @@ -206,6 +235,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_4_2": ModelConfig( 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), # GQA + "cp_4_3": ModelConfig( + 2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + ), # GQA } @@ -215,16 +247,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if test_essential: configs = [ "cp_1_0", - "cp_1_1", - "cp_1_4", - "cp_1_5", "cp_2_0", + "cp_2_1", "cp_2_2", - "cp_2_3", "cp_2_4", + "cp_3_1", "cp_3_2", "cp_3_4", "cp_4_2", + "cp_4_3", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] @@ -240,96 +271,81 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) @pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): + config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") + + if get_device_compute_capability() < (9, 0) and qkv_format == "thd": + pytest.skip("Only sm90+ architectures support THD format!") + if get_device_compute_capability() < (9, 0) and dtype == "fp8": + pytest.skip("Only sm90+ architectures support FP8 attention!") - if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+!") - if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") - if dtype == "fp8" and get_device_compute_capability() < (9, 0): - pytest.skip("FP8 attention is only supported on sm90+!") + if dtype == "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("dtype=fp8 requires fp8_dpa=True or fp8_mha=True!") if dtype == "fp8" and not fp8_dpa and fp8_mha: pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") if dtype != "fp8" and fp8_bwd: - pytest.skip("Only fp8 works with fp8_bwd=True!") - - config = model_configs_fused_attn[model] - config.context_parallel = True - config.cp_comm_type = cp_comm_type + pytest.skip("fp8_bwd=True requires dtype=fp8!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("dtype!=fp8 requires fp8_dpa=False and fp8_mha=False!") - if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip("THD format does not support post_scale_bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if dtype == "fp8" and cp_comm_type == "all_gather": - pytest.skip( - "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" - ) if dtype == "fp8" and qkv_format == "thd": - pytest.skip("FP8 attention cannot work with THD format yet!") + pytest.skip("No support for FP8 attention with THD format!") if dtype == "fp8" and config.attn_bias_type != "no_bias": - pytest.skip("FP8 attention cannot work with bias yet!") - if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("FP8 attention cannot work with sliding window yet!") - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): - pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" - ) - if dtype != "fp8" and (fp8_mha or fp8_dpa): - pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") - if dtype == "fp8" and not (fp8_mha or fp8_dpa): - pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") - if dtype != "fp8" and scaling_mode is not None: - pytest.skip("Only fp8 works with scaling_mode != None!") - if dtype == "fp8" and scaling_mode is None: - pytest.skip("fp8 only works with scaling_mode != None!") - if ( - dtype == "fp8" - and scaling_mode == "current" - and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + pytest.skip("No support for FP8 attention with bias!") + + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type in [ + "p2p", + "a2a+p2p", + ]: + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 ): - pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") - if f16_O and (dtype != "fp8" or scaling_mode != "current"): - pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") - if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently only support KV P2P!") - if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently does not support FP8 attention!") - if dtype == "fp8" and config.softmax_type != "vanilla": - pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") - if config.softmax_type != "vanilla" and cp_comm_type != "a2a": pytest.skip( - "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" ) + + if config.softmax_type != "vanilla" and cp_comm_type != "a2a": + pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") if ( - get_cudnn_version() < (9, 18, 0) - and config.softmax_type != "vanilla" + config.softmax_type != "vanilla" and qkv_format == "thd" + and get_cudnn_version() < (9, 18, 0) ): - pytest.skip( - "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" - " non-vanilla softmax types!" - ) + pytest.skip("No support for non-vanilla softmax with THD format and cuDNN < 9.18.0!") + + if dtype == "fp8" and scaling_mode is None: + pytest.skip("dtype=fp8 requires scaling_mode != None!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("dtype!=fp8 requires scaling_mode = None!") + if dtype != "fp8" and not f16_O: + pytest.skip("dtype!=fp8 requires f16_O=True!") + if scaling_mode == "delayed" and f16_O: + pytest.skip("scaling_mode=delayed requires f16_O=False!") + if scaling_mode == "mxfp8" and not f16_O: + pytest.skip("scaling_mode=mxfp8 requires f16_O=True!") + if scaling_mode == "mxfp8" and fp8_mha: + pytest.skip("No support for scaling_mode=mxfp8 with fp8_mha=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -353,6 +369,12 @@ def test_cp_with_fused_attention( Float8CurrentScaling(fp8_dpa=True), DelayedScaling(fp8_dpa=True), ] + if fp8 and scaling_mode == "mxfp8": + fp8_meta["recipe"] = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=True) + fp8_meta["local_recipes"] = [ + MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=True), + ] + # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -362,8 +384,23 @@ def test_cp_with_fused_attention( fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends + if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: + config_copy = copy.deepcopy(config) + config_copy.context_parallel = False + config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" + available_backends, _, fused_attn_backends = get_available_attention_backends( + config_copy, + qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, + qkv_layout="_".join([qkv_format] * 3), + fp8=fp8, + fp8_meta=fp8_meta, + is_training=is_training, + deterministic=_deterministic, + ) + _, fused_attn_supported, _ = available_backends if not fused_attn_supported: pytest.skip("No attention backend available.") @@ -381,6 +418,7 @@ def test_cp_with_fused_attention( scaling_mode=scaling_mode, f16_O=f16_O, is_training=is_training, + deterministic=_deterministic, log_level=pytest_logging_level, ), ) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index bbd0733447..712d83bd1c 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -21,11 +21,12 @@ ) -def newton_schulz_reference(in_x: torch.Tensor, coefficients: list[float]) -> torch.Tensor: +def newton_schulz_reference( + in_x: torch.Tensor, coefficients: list[tuple[float, float, float]] +) -> torch.Tensor: """Local Newton-Schulz reference mirroring the provided Octave update.""" x = in_x.clone() - for i in range(len(coefficients) // 3): - a, b, c = coefficients[3 * i : 3 * (i + 1)] + for a, b, c in coefficients: xxt = x @ x.mT x = a * x + b * xxt @ x + c * xxt @ xxt @ x return x diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 0f40e92183..47507dc384 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -4,7 +4,7 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Iterable, Sequence import functools import io import math @@ -20,7 +20,7 @@ import transformer_engine.pytorch as te import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import ( - _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu, + _cudnn_frontend_version_supported, ) from transformer_engine.pytorch.ops.fused import ( @@ -42,6 +42,7 @@ ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.pytorch.cpp_extensions.gemm import general_grouped_gemm_for_grouped_tensor +from transformer_engine.pytorch.module.base import get_dummy_wgrad import transformer_engine_torch as tex # Import utility functions @@ -199,6 +200,88 @@ def make_reference_and_test_tensors( return ref, test +def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """Convert to an FP64 CPU tensor""" + if tensor is None: + return None + out = tensor.detach() + if isinstance(out, QuantizedTensor): + out = out.dequantize() + out = out.to(dtype=torch.float64, device="cpu") + out = out.requires_grad_(requires_grad=tensor.requires_grad) + return out + + +class MegatronTrainingHelper: + """Test-side stand-in for the Megatron-Core DDP / MegatronFSDP wrapper. + Megatron's DDP wrapper (and MegatronFSDP) owns the per-parameter + ``main_grad`` buffer and the ``overwrite_main_grad`` / + ``grad_added_to_main_grad`` attributes that coordinate + ``fuse_wgrad_accumulation`` with TE modules. These helpers reproduce the + relevant slice of that protocol so TE tests can exercise the + accumulate-into-``main_grad`` code path without pulling in the full + Megatron-Core dependency. + """ + + @staticmethod + def init_main_grad_buffers( + weight_params: Iterable[torch.nn.Parameter], + *, + fill_value: float, + overwrite_main_grad: bool, + zero_out_wgrad: bool = False, + dtype: torch.dtype = torch.float32, + ) -> None: + """Allocate ``main_grad`` and stamp the wrapper attributes on each + param, mirroring what the Megatron DDP/FSDP wrapper does before + backward.""" + for wp in weight_params: + wp.main_grad = torch.full(wp.size(), fill_value, device=wp.device, dtype=dtype) + wp.overwrite_main_grad = overwrite_main_grad + wp.zero_out_wgrad = zero_out_wgrad + wp.grad_added_to_main_grad = False + + @staticmethod + def verify_main_grad_accumulation( + weight_params: Iterable[torch.nn.Parameter], + *, + expected_main_grads: Iterable[torch.Tensor], + rtol: float = 0.0, + atol: float = 0.0, + ) -> None: + """Check that backward produced what the Megatron wrapper expects: + each ``main_grad`` matches ``expected_main_grads``, + ``grad_added_to_main_grad`` was flipped to ``True`` so the wrapper's + post-backward hooks won't double-accumulate, and ``param.grad`` was + replaced by the cached dummy tensor (so a wrapper hook that did + ``main_grad += grad`` would be a no-op rather than double-counting). + """ + for wp, expected in zip(weight_params, expected_main_grads): + torch.testing.assert_close(wp.main_grad.to(expected), expected, rtol=rtol, atol=atol) + + assert wp.grad_added_to_main_grad is True, ( + "weight.grad_added_to_main_grad was not flipped to True; " + "the Megatron DDP/FSDP wrapper hook will double-accumulate." + ) + + # ``.grad`` should be the cached dummy tensor returned by + # ``get_dummy_wgrad`` -- shared storage, not the real wgrad. + expected_dummy = get_dummy_wgrad(list(wp.size()), wp.dtype) + assert ( + wp.grad is not None + ), "weight.grad is None; the Megatron protocol expects a dummy tensor stand-in here." + assert wp.grad.data_ptr() == expected_dummy.data_ptr(), ( + "weight.grad does not share storage with the cached dummy " + "wgrad; downstream wrapper hooks risk double-accumulating." + ) + if getattr(wp, "zero_out_wgrad", False): + assert torch.all(wp.grad == 0), ( + "weight.zero_out_wgrad=True but the dummy weight.grad " + "was not zeroed; downstream hooks reading .grad would " + "see stale bytes from the previous step." + ) + + class TestSequentialContainer: """Tests for sequential container""" @@ -3297,25 +3380,17 @@ def test_layernorm_mlp( y_test = forward(x_test) y_test.backward(dy_test) - def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - """Convert to FP64 CPU tensor""" - if tensor is None: - return None - out = tensor.detach().to(dtype=torch.float64, device="cpu") - out = out.requires_grad_(requires_grad=tensor.requires_grad) - return out - # Check values tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking - torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) - torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) - torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols) - torch.testing.assert_close(to_cpu(norm.bias.grad), norm_b_ref.grad, **tols) - torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols) - torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols) + assert_close(y_test, y_ref, **tols) + assert_close(x_test.grad, x_ref.grad, **tols) + assert_close_grads(norm.weight, norm_w_ref, **tols) + assert_close_grads(norm.bias, norm_b_ref, **tols) + assert_close_grads(ffn2.weight, w2_ref, **tols) + assert_close_grads(ffn1.weight, w1_ref, **tols) if bias: - torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) - torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) + assert_close_grads(ffn1.bias, b1_ref, **tols) + assert_close_grads(ffn2.bias, b2_ref, **tols) @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @@ -3537,33 +3612,20 @@ def test_grouped_mlp( getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) if accumulate_into_main_grad: + # 0.5 sentinel lets us reconstruct ``expected = ref_grad + 0.5`` + # below and detect a missed accumulation. + main_grad_sentinel = 0.5 if single_grouped_weight: - fc1.weight.main_grad = torch.full( - fc1.weight.size(), - 0.5, - device=device, - dtype=torch.float32, - ) - fc2.weight.main_grad = torch.full( - fc2.weight.size(), - 0.5, - device=device, - dtype=torch.float32, - ) + weight_params_for_main_grad = [fc1.weight, fc2.weight] else: - for group_idx in range(group_size): - getattr(fc1, f"weight{group_idx}").main_grad = torch.full( - getattr(fc1, f"weight{group_idx}").size(), - 0.5, - device=device, - dtype=torch.float32, - ) - getattr(fc2, f"weight{group_idx}").main_grad = torch.full( - getattr(fc2, f"weight{group_idx}").size(), - 0.5, - device=device, - dtype=torch.float32, - ) + weight_params_for_main_grad = [ + getattr(fc, f"weight{i}") for fc in (fc1, fc2) for i in range(group_size) + ] + MegatronTrainingHelper.init_main_grad_buffers( + weight_params_for_main_grad, + fill_value=main_grad_sentinel, + overwrite_main_grad=False, + ) del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test # Fuse ops and perform forward and backward pass @@ -3580,10 +3642,7 @@ def test_grouped_mlp( quantization == "mxfp8" and dtype in (torch.bfloat16, torch.float16) and glu_interleave_size == 32 - and ( - activation != "scaled_clamped_qgeglu" - or _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() - ) + and _cudnn_frontend_version_supported() ): if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): forward_ops = module._module_groups[0]._forward_ops @@ -3639,32 +3698,24 @@ def test_grouped_mlp( fc1_w_ref_grad = torch.stack([w.grad for w in fc1_ws_ref], dim=0) fc2_w_ref_grad = torch.stack([w.grad for w in fc2_ws_ref], dim=0) if accumulate_into_main_grad: - if single_grouped_weight: - fc1_w_test_grad = fc1.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5 - fc2_w_test_grad = fc2.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5 - else: - fc1_w_test_grad = torch.stack( - [ - getattr(fc1, f"weight{group_idx}").main_grad.to( - dtype=torch.float64, device="cpu" - ) - - 0.5 - for group_idx in range(group_size) - ], - dim=0, - ) - fc2_w_test_grad = torch.stack( - [ - getattr(fc2, f"weight{group_idx}").main_grad.to( - dtype=torch.float64, device="cpu" - ) - - 0.5 - for group_idx in range(group_size) - ], - dim=0, - ) - assert_close(fc1_w_test_grad, fc1_w_ref_grad, **tols) - assert_close(fc2_w_test_grad, fc2_w_ref_grad, **tols) + # main_grad should accumulate the ref wgrad onto the 0.5 sentinel. + # Per-param expected views must line up with + # ``weight_params_for_main_grad`` registered above. + fc1_expected = ( + [fc1_w_ref_grad + main_grad_sentinel] + if single_grouped_weight + else [g + main_grad_sentinel for g in fc1_w_ref_grad] + ) + fc2_expected = ( + [fc2_w_ref_grad + main_grad_sentinel] + if single_grouped_weight + else [g + main_grad_sentinel for g in fc2_w_ref_grad] + ) + MegatronTrainingHelper.verify_main_grad_accumulation( + weight_params_for_main_grad, + expected_main_grads=fc1_expected + fc2_expected, + **tols, + ) elif single_grouped_weight: assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) @@ -3694,12 +3745,6 @@ def test_grouped_mlp_single_weight_numerics( pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") - if activation == "scaled_clamped_qgeglu" and not ( - _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() - ): - pytest.skip( - "ScaledClampedQGeGLU fused grouped MLP requires nvidia-cudnn-frontend >= 1.23.0" - ) split_sizes = [split_alignment * (i + 1) for i in range(group_size)] random.shuffle(split_sizes) @@ -3884,6 +3929,153 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: torch.testing.assert_close(fc1_db_false, fc1_db_true, **bias_tols) torch.testing.assert_close(fc2_db_false, fc2_db_true, **bias_tols) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) + @pytest.mark.parametrize("zero_out_wgrad", (False, True)) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_grouped_mlp_overwrite_main_grad( + self, + *, + single_grouped_weight: bool, + delay_wgrad_compute: bool, + zero_out_wgrad: bool, + dtype: torch.dtype = torch.bfloat16, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, + ) -> None: + """End-to-end check that the fused grouped-MLP backward writes the + wgrad into ``weight.main_grad`` correctly under the MegatronFSDP + ``overwrite_main_grad=True`` convention. + ``test_grouped_mlp`` already covers the standard Megatron-LM + ``fuse_wgrad_accumulation`` (DDP) path where the wgrad GEMM + *accumulates* into ``main_grad``. This test focuses exclusively on + the MegatronFSDP variant where the wgrad GEMM must *overwrite* + ``main_grad`` (because FSDP has already ReduceScattered the previous + accumulation), so ``main_grad`` after backward equals ``wgrad`` + regardless of the prior contents. + + Also exercises the MegatronFSDP ``zero_out_wgrad`` flag, which is + independent of ``main_grad`` and only controls whether the dummy + ``param.grad`` returned to autograd is zeroed (so downstream hooks + that read ``.grad`` don't see stale bytes from the cached dummy). + """ + + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") + if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") + + recipe = make_recipe("mxfp8") + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) + dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + fc1_ws_base = [ + torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 + ) + for _ in range(group_size) + ] + fc2_ws_base = [ + torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 + ) + for _ in range(group_size) + ] + + def _build_module(*, accumulate_into_main_grad: bool): + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te_ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + fc2 = te_ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + module = te_ops.Sequential(fc1, scaled_act, fc2) + + with torch.no_grad(): + if single_grouped_weight: + fc1_weights = ( + fc1.weight.quantized_tensors or fc1.weight.split_into_quantized_tensors() + ) + fc2_weights = ( + fc2.weight.quantized_tensors or fc2.weight.split_into_quantized_tensors() + ) + for group_idx in range(group_size): + fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) + else: + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) + return module, fc1, fc2 + + def _weight_params(fc): + if single_grouped_weight: + return [fc.weight] + return [getattr(fc, f"weight{i}") for i in range(group_size)] + + def _run_backward(module, fc1, fc2): + x = x_base.detach().clone().requires_grad_(True) + probs = probs_base.detach().clone().requires_grad_(True) + with te.autocast(enabled=True, recipe=recipe): + y = module(x, split_sizes, probs, split_sizes) + y.backward(dy_base) + if delay_wgrad_compute: + fc1.backward_dw() + fc2.backward_dw() + + # Reference run: vanilla autograd, no Megatron protocol. + ref_module, ref_fc1, ref_fc2 = _build_module(accumulate_into_main_grad=False) + _run_backward(ref_module, ref_fc1, ref_fc2) + ref_fc1_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc1)] + ref_fc2_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc2)] + + # Test run: main_grad fusion with overwrite_main_grad=True (MegatronFSDP). + # NaN sentinel makes a missed write loud (would surface as NaN diff). + test_module, test_fc1, test_fc2 = _build_module(accumulate_into_main_grad=True) + for fc in (test_fc1, test_fc2): + MegatronTrainingHelper.init_main_grad_buffers( + _weight_params(fc), + fill_value=float("nan"), + overwrite_main_grad=True, + zero_out_wgrad=zero_out_wgrad, + ) + _run_backward(test_module, test_fc1, test_fc2) + + # main_grad must be overwritten to exactly the ref wgrad (bitwise: + # the wgrad GEMM is deterministic across the two runs because the + # quantized weights and inputs are identical). + MegatronTrainingHelper.verify_main_grad_accumulation( + _weight_params(test_fc1), expected_main_grads=ref_fc1_grads + ) + MegatronTrainingHelper.verify_main_grad_accumulation( + _weight_params(test_fc2), expected_main_grads=ref_fc2_grads + ) + @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("single_grouped_weight", (False, True)) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @@ -3909,12 +4101,6 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( pytest.skip("MXFP8 fused grouped MLP is not supported on this system") if dtype not in (torch.bfloat16, torch.float16): pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") - if activation == "scaled_clamped_qgeglu" and not ( - _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() - ): - pytest.skip( - "ScaledClampedQGeGLU fused grouped MLP requires nvidia-cudnn-frontend >= 1.23.0" - ) split_sizes = [split_alignment * (i + 1) for i in range(group_size)] random.shuffle(split_sizes) @@ -4543,6 +4729,232 @@ def fuse_ops( torch.testing.assert_close(dw_test, w_ref.grad, **tols) +class TestTrainingLoops: + + def _linear_train_stage( + self, + module: te.ops.Linear, + *, + steps: int = 3, + in_shape: Sequence[int], + out_shape: Sequence[int], + dtype: torch.type, + device: torch.device, + quantization: Optional[str], + recipe: Optional[transformer_engine.common.recipe.Recipe], + ) -> None: + """Perform training steps with linear op""" + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantization is not None: + tols = quantization_tols(quantization) + + for _ in range(steps): + # Update parameters with random values to simulate + # optimizer step or FSDP param all-gather + with torch.no_grad(): + module.weight.copy_(torch.empty_like(module.weight).uniform_()) + module.bias.copy_(torch.empty_like(module.bias).uniform_()) + for param in module.parameters(): + param.grad = None + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w_ref = to_cpu(module.weight) + b_ref = to_cpu(module.bias) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) + y_ref.backward(dy_ref) + + # Implementation with linear op + with te.autocast(enabled=quantization is not None, recipe=recipe): + y_test = module(x_test) + y_test.backward(dy_test) + + # Check results + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(module.weight, w_ref, **tols) + assert_close_grads(module.bias, b_ref, **tols) + + @torch.inference_mode + def _linear_infer_stage( + self, + module: te.ops.Linear, + *, + steps: int = 3, + in_shape: Sequence[int], + dtype: torch.type, + device: torch.device, + quantization: Optional[str], + recipe: Optional[transformer_engine.common.recipe.Recipe], + ) -> None: + """Perform inference steps with linear op""" + + # Parameter reference values + w_ref = to_cpu(module.weight) + b_ref = to_cpu(module.bias) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantization is not None: + tols = quantization_tols(quantization) + + for _ in range(steps): + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) + + # Implementation with linear op + with te.autocast(enabled=quantization is not None, recipe=recipe): + y_test = module(x_test) + + # Check results + assert_close(y_test, y_ref, **tols) + + @pytest.mark.parametrize("stages", (["train", "infer"] * 2, ["infer", "train"] * 2)) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear_training_loop( + self, + *, + stages: Sequence[str], + weight_shape: tuple[int, int] = (32, 32), + in_shape: Sequence[int] = (32, -1), + dtype: Optional[torch.dtype] = None, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Training loops with linear op""" + if dtype is None: + dtype = torch.bfloat16 if is_bf16_available() else torch.float32 + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and quantized_weight: + pytest.skip("Quantization scheme is not specified") + + # Construct module with random weights + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + module = te.ops.Linear( + in_features, + out_features, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + for param in module.parameters(): + param.copy_(torch.empty_like(param).uniform_()) + + # Training loop stages + for stage in stages: + if stage == "train": + self._linear_train_stage( + module, + in_shape=in_shape, + out_shape=out_shape, + dtype=dtype, + device=device, + quantization=quantization, + recipe=recipe, + ) + elif stage == "infer": + self._linear_infer_stage( + module, + in_shape=in_shape, + dtype=dtype, + device=device, + quantization=quantization, + recipe=recipe, + ) + else: + raise ValueError(f"Unrecognized stage ({stage})") + + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear_inference_loop( + self, + *, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Sequence[int] = (32, -1), + dtype: Optional[torch.dtype] = None, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Inference loop with linear op""" + if dtype is None: + dtype = torch.bfloat16 if is_bf16_available() else torch.float32 + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and quantized_weight: + pytest.skip("Quantization scheme is not specified") + + # Construct module with random weights + recipe = make_recipe(quantization) + with ( + torch.inference_mode(), + te.quantized_model_init(enabled=quantized_weight, recipe=recipe), + ): + module = te.ops.Linear( + in_features, + out_features, + device=device, + dtype=dtype, + ) + for param in module.parameters(): + param.copy_(torch.empty_like(param).uniform_()) + + # Inference loop + self._linear_infer_stage( + module, + in_shape=in_shape, + dtype=dtype, + device=device, + quantization=quantization, + recipe=recipe, + ) + + def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: if not mxfp8_available: pytest.skip(reason_for_no_mxfp8) @@ -4658,7 +5070,6 @@ def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: norm_const_tensor=None, prob_tensor=inputs["prob_tensor"], acc_dtype=torch.float32, - c_dtype=torch.bfloat16, d_dtype=torch.bfloat16, cd_major="n", sf_vec_size=32, diff --git a/tests/pytorch/test_mhc.py b/tests/pytorch/test_mhc.py new file mode 100644 index 0000000000..541ce9a8c2 --- /dev/null +++ b/tests/pytorch/test_mhc.py @@ -0,0 +1,497 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from dataclasses import dataclass +import pytest +import torch +import torch.nn.functional as F + +from utils import reset_rng_states +from transformer_engine.pytorch.triton.mhc import ( + mhc_fused_sinkhorn, + mhc_fused_scale, + mhc_fused_aggregate, + mhc_fused_expand_combine, + mhc_fused_projection, +) + +# Disable TF32 for matmul to ensure consistency between the fused and reference implementations +torch.backends.cuda.matmul.allow_tf32 = False + + +def mhc_projection_ref(x, phi): + """ + Reference operator for mHC's projection building operation. + + x: (M, nC) where M = s * b + phi: (2n + n^2, nC), which consists of the following matrices + - phi_pre: (n, nC) + - phi_post: (n, nC) + - phi_res: (n^2, nC) + n: number of Hyper Connection streams + C: hidden dimension per stream + """ + x_dtype = x.dtype + x = x.to(torch.float32) + phi = phi.to(torch.float32) + + Hs = x @ phi.T # (M, 2n + n^2) + + x_fp32 = x.to(torch.float32) # Use fp32 for better numerical stability in variance calculation + ms = (x_fp32 * x_fp32).mean(dim=1) + + return Hs.to(x_dtype), ms + + +def mhc_scale_ref(H, alpha, beta, ms, n): + """ + Reference operator for mHC's H matrices scaling operation + + :param: H: (M, 2n + n^2), the unprocessed H matrices where M = s * b + :param: alpha: (3,), three scalar parameters + :param: beta: (1, 2n + n^2), bias term + :param: r: (M,), the denominator for RMSNorm + :param: n: int, the width of Hyper-Connection + + :return Hs: (M, 2n + n^2), the processed H matrices + """ + + input_dtype = H.dtype + H = H.to(torch.float32) + alpha = alpha.to(torch.float32) + beta = beta.to(torch.float32) + eps = torch.finfo(torch.float32).eps + rms = torch.sqrt(ms + eps) # (M,) + rms = rms.to(torch.float32) + + H_pre = H[:, :n] # (M, n) + H_post = H[:, n : 2 * n] # (M, n) + H_res = H[:, 2 * n :] # (M, n^2) + + beta_pre = beta[0, :n] + beta_post = beta[0, n : 2 * n] + beta_res = beta[0, 2 * n : 2 * n + n * n] + + alpha_pre, alpha_post, alpha_res = alpha[0], alpha[1], alpha[2] + + H_pre = H_pre * alpha_pre + H_post = H_post * alpha_post + H_res = H_res * alpha_res + + H_pre = H_pre / rms[:, None] + H_post = H_post / rms[:, None] + H_res = H_res / rms[:, None] + + H_pre = H_pre + beta_pre + H_post = H_post + beta_post + H_res = H_res + beta_res + + H_pre = F.sigmoid(H_pre) + H_post = 2 * F.sigmoid(H_post) + + return H_pre.to(input_dtype), H_post.to(input_dtype), H_res.to(input_dtype) + + +def mhc_sinkhorn_ref(H_res, n=4, iterations=20): + """ + Reference operator for mHC's Sinkhorn-Knopp algorithm to convert a matrix into a doubly stochastic matrix. + Calculated in log space for numerical stability. + + :param H_res: a tensor of shape (s, b, n, n) + :return: a tensor of shape (s, b, n, n) + """ + s, b = H_res.shape[:2] + device = H_res.device + dtype = H_res.dtype + + H_res_f = H_res.to( + torch.float32 + ).clone() # Use float32 for better numerical stability during Sinkhorn iterations + + log_mu = torch.zeros(s, b, n, device=device, dtype=torch.float32) + log_nu = torch.zeros(s, b, n, device=device, dtype=torch.float32) + + f = torch.zeros(s, b, n, device=device, dtype=torch.float32) + g = torch.zeros(s, b, n, device=device, dtype=torch.float32) + + for _ in range(iterations): + # Update f: logsumexp over the column dimension (3) + f = log_mu - torch.logsumexp(H_res_f + g.unsqueeze(2), dim=3) + # Update g: logsumexp over the row dimension (2) + g = log_nu - torch.logsumexp(H_res_f + f.unsqueeze(3), dim=2) + + log_P = f.unsqueeze(3) + H_res_f + g.unsqueeze(2) + H_res_out = torch.exp(log_P).to(dtype) # Convert back to original dtype + + return H_res_out + + +def mhc_aggregate_ref(x, H_pre, n): + """ + Reference operator for applying mHC's aggregation transformation + + x: (s, b, C, n) + H_pre: (s, b, n) + """ + H_pre = H_pre.contiguous() + + s, b, C, n = x.shape + H_pre = H_pre.view(s, b, n, 1) + + out = (x @ H_pre).view(s, b, C) + + return out + + +def mhc_expand_combine_ref(f, bias, H_post, x, H_res, n): + """ + Reference operator for applying mHC's expansion and combination transformation + + f: (s, b, C) + bias: (C,) or None + H_post: (s, b, n) + x: (s, b, C, n) + H_res: (s, b, n, n) + """ + + s, b, C, n = x.shape + + # My triton kernels use FMA and MMA instructions with fp32 accumulator for bf16 test cases + # which has better numerical stability than this pytorch implementation + # To match the kernel's accuracy we need to cast to fp32 here to match kernels' result + input_dtype = f.dtype + f = f.to(torch.float32) + bias = bias.to(torch.float32) if bias is not None else None + H_post = H_post.to(torch.float32) + x = x.to(torch.float32) + H_res = H_res.to(torch.float32) + + if bias is not None: + f = f + bias[None, None, :] + + f = f.view(s, b, C, 1) + H_post = H_post.view(s, b, 1, n) + + out = f @ H_post + x @ H_res # (s, b, C, n) + + return out.to(input_dtype) + + +@dataclass +class MHCConfig: + s: int = 2048 # Sequence length + b: int = 32 # Batch size + C: int = 1024 # Hidden dimension + n: int = 4 # Number of Hyper Connection streams + + allow_n = [ + 4, + ] + + def __init__(self, b, s, C, n=4): + assert n in self.allow_n, f"n must be one of {self.allow_n}" + self.b = b + self.s = s + self.C = C + self.n = n + + @staticmethod + def desc(cfg): + return f"b{cfg.b}_s{cfg.s}_C{cfg.C}_n{cfg.n}" + + +mhc_configs = [ + MHCConfig(8, 32, 32), + MHCConfig(8, 128, 16 * 64), + MHCConfig( + 4, + 128, + 16 * 64, + ), + MHCConfig(2, 2048, 24 * 128), + MHCConfig( + 1, + 2048, + 24 * 128, + ), + MHCConfig( + 13, + 1, + 16 * 128, + ), + MHCConfig( + 7, + 1, + 16 * 256, + ), + MHCConfig( + 8, + 1, + 16 * 192, + ), + MHCConfig( + 8, + 128, + 5129, + ), + MHCConfig( + 8, + 512, + 8000, + ), + MHCConfig( + 4, + 1024, + 8192, + ), + MHCConfig( + 2, + 4096, + 8192, + ), + MHCConfig( + 8, + 128, + 16384, + ), +] + + +def get_tols(dtype): + if dtype == torch.bfloat16: + tols = dict(atol=2.5e-2, rtol=2.5e-2) + else: + tols = dict(atol=5e-3, rtol=5e-3) + return tols + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +def test_mhc_projection(cfg: MHCConfig, dtype): + reset_rng_states() + + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + nC = n * C + N = 2 * n + n * n + + tols = get_tols(dtype) + use_tf32 = False + + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) + phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") + + x_ref = x.detach().clone().requires_grad_(True) + phi_ref = phi.detach().clone().requires_grad_(True) + + ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref) + fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(x, phi, use_tf32) + fused_out_Hs = fused_out_Hs_padded[:, :N] + + torch.testing.assert_close(fused_out_Hs, ref_out_Hs, **tols) + torch.testing.assert_close(fused_out_ms, ref_out_ms, **tols) + (ref_out_Hs.sum() + ref_out_ms.sum()).backward() + (fused_out_Hs.sum() + fused_out_ms.sum()).backward() + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + torch.testing.assert_close(phi.grad, phi_ref.grad, **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +def test_mhc_scale(cfg: MHCConfig, dtype): + reset_rng_states() + + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + N = 2 * n + n * n + + tols = get_tols(dtype) + + H_padded = torch.randn(s * b, 32, device="cuda", requires_grad=True, dtype=dtype) + H = H_padded[:, :N] + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + ms_raw = torch.randn(s * b, device="cuda", dtype=dtype).abs() + 1.0 + ms = ms_raw.detach().clone().requires_grad_(True) + + H_ref = H.detach().clone().requires_grad_(True) + alpha_ref = alpha.detach().clone().requires_grad_(True) + beta_ref = beta.detach().clone().requires_grad_(True) + ms_ref = ms.detach().clone().requires_grad_(True) + + ref_out = mhc_scale_ref(H_ref[:, :N], alpha_ref, beta_ref, ms_ref, n) + fused_out = mhc_fused_scale(H_padded, alpha, beta, ms, n) + + for i in range(3): + torch.testing.assert_close(fused_out[i], ref_out[i], **tols) + + torch.cat([ref_out[i] for i in range(3)], dim=-1).sum().backward() + torch.cat([fused_out[i] for i in range(3)], dim=-1).sum().backward() + + torch.testing.assert_close(H_padded.grad[:, :N], H_ref.grad, **tols) + torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols) + torch.testing.assert_close(beta.grad, beta_ref.grad, **tols) + torch.testing.assert_close(ms.grad, ms_ref.grad, **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +def test_mhc_combined(cfg: MHCConfig, dtype): + reset_rng_states() + + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + N = 2 * n + n * n + nC = n * C + + tols = get_tols(dtype) + use_tf32 = False + + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) + phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") + + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + + x_ref = x.detach().clone().requires_grad_(True) + phi_ref = phi.detach().clone().requires_grad_(True) + + alpha_ref = alpha.detach().clone().requires_grad_(True) + beta_ref = beta.detach().clone().requires_grad_(True) + + ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref) + fused_out_H_padded, fused_out_r = mhc_fused_projection(x, phi, use_tf32) + + ref_H_pre, ref_H_post, ref_H_res = mhc_scale_ref( + ref_out_H[:, :N], alpha_ref, beta_ref, ref_out_r, n + ) + fused_H_pre, fused_H_post, fused_H_res = mhc_fused_scale( + fused_out_H_padded, alpha, beta, fused_out_r, n + ) + + def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref): + dtype = x_ref.dtype + x_ref = x_ref.to(torch.float32) + phi_ref = phi_ref.to(torch.float32) + alpha_ref = alpha_ref.to(torch.float32) + beta_ref = beta_ref.to(torch.float32) + + # Check if after spliting RMSNorm to two steps in projection and scaling, + # theresult is close to applying RMSNorm in the correct order + x_rmsnorm = F.rms_norm(x_ref, normalized_shape=(nC,)) + H = x_rmsnorm @ phi_ref.T + H_pre = H[:, :n] + H_post = H[:, n : 2 * n] + H_res = H[:, 2 * n :] + + out_pre = H_pre * alpha_ref[0] + beta_ref[:, :n] + out_post = H_post * alpha_ref[1] + beta_ref[:, n : 2 * n] + out_res = H_res * alpha_ref[2] + beta_ref[:, 2 * n :] + + out_pre = out_pre.sigmoid() + out_post = 2 * out_post.sigmoid() + out_res = out_res + + return out_pre.to(dtype), out_post.to(dtype), out_res.to(dtype) + + combined_H_pre, combined_H_post, combined_H_res = mhc_combined( + x_ref, phi_ref, alpha_ref, beta_ref + ) + + torch.testing.assert_close(combined_H_pre, ref_H_pre, **tols) + torch.testing.assert_close(combined_H_post, ref_H_post, **tols) + torch.testing.assert_close(combined_H_res, ref_H_res, **tols) + + torch.testing.assert_close(combined_H_pre, fused_H_pre, **tols) + torch.testing.assert_close(combined_H_post, fused_H_post, **tols) + torch.testing.assert_close(combined_H_res, fused_H_res, **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +@pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"]) +def test_mhc_sinkhorn(cfg: MHCConfig, dtype, recompute): + reset_rng_states() + + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + + tols = get_tols(dtype) + + x = torch.randn(s, b, n, n, device="cuda", requires_grad=True, dtype=dtype) + x_ref = x.detach().clone().requires_grad_(True) + + ref_out = mhc_sinkhorn_ref(x_ref, n) + fused_out = mhc_fused_sinkhorn(x, n, recompute) + + torch.testing.assert_close(fused_out, ref_out, **tols) + + ref_out.sum().backward() + fused_out.sum().backward() + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +def test_mhc_aggregate(cfg: MHCConfig, dtype): + reset_rng_states() + + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + + tols = get_tols(dtype) + + x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype) + H_pre = torch.randn(s, b, n, device="cuda", requires_grad=True, dtype=dtype) + + x_ref = x.detach().clone().requires_grad_(True) + H_pre_ref = H_pre.detach().clone().requires_grad_(True) + + ref_out = mhc_aggregate_ref(x_ref, H_pre_ref, n) + fused_out = mhc_fused_aggregate(x, H_pre, n, False) + + torch.testing.assert_close(fused_out, ref_out, **tols) + + ref_out.sum().backward() + fused_out.sum().backward() + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + torch.testing.assert_close(H_pre.grad, H_pre_ref.grad, **tols) + + +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +@pytest.mark.parametrize("with_bias", [True, False], ids=["with_bias", "no_bias"]) +def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias): + reset_rng_states() + + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + + tols = get_tols(dtype) + + f = torch.randn(s, b, C, device="cuda", requires_grad=True, dtype=dtype) + bias = None + if with_bias: + bias = torch.randn(C, device="cuda", requires_grad=True, dtype=dtype) + H_post = torch.randn(s, b, n, device="cuda", requires_grad=True, dtype=dtype) + x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype) + H_res = torch.randn(s, b, n, n, device="cuda", requires_grad=True, dtype=dtype) + + f_ref = f.detach().clone().requires_grad_(True) + bias_ref = None if bias is None else bias.detach().clone().requires_grad_(True) + H_post_ref = H_post.detach().clone().requires_grad_(True) + x_ref = x.detach().clone().requires_grad_(True) + H_res_ref = H_res.detach().clone().requires_grad_(True) + + ref_out = mhc_expand_combine_ref(f_ref, bias_ref, H_post_ref, x_ref, H_res_ref, n) + fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n, False) + + torch.testing.assert_close(fused_out, ref_out, **tols) + + ref_out.sum().backward() + fused_out.sum().backward() + + torch.testing.assert_close(f.grad, f_ref.grad, **tols) + torch.testing.assert_close(H_post.grad, H_post_ref.grad, **tols) + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + torch.testing.assert_close(H_res.grad, H_res_ref.grad, **tols) + if bias is not None: + torch.testing.assert_close(bias.grad, bias_ref.grad, **tols) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index fd9a6416ec..8f8852edc2 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -198,6 +198,10 @@ def reset_rng_states() -> None: def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): + if a is None and b is None: + logging.debug(f"{name_a} vs {name_b}: both are None") + return + if not is_fp8: torch.testing.assert_close(a, b, atol=atol, rtol=rtol) return diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a21c1ee7e6..53f9773a73 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -179,7 +179,6 @@ list(APPEND transformer_engine_cuda_sources transpose/quantize_transpose_vector_blockwise.cu transpose/swap_first_dims.cu dropout/dropout.cu - fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu fused_attn/fused_attn_f16_max512_seqlen.cu @@ -210,6 +209,7 @@ list(APPEND transformer_engine_cuda_sources comm_gemm_overlap/userbuffers/userbuffers.cu) list(APPEND transformer_engine_cuda_arch_specific_sources + fused_attn/flash_attn.cu activation/gelu.cu activation/glu.cu activation/relu.cu diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index ce6917aa42..aa697d4bfe 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -109,11 +109,15 @@ __device__ __forceinline__ void process_colwise_stage( const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_X = scales_offset_X_colwise; + const bool colwise_scale_is_within_bounds = global_scales_offset_X < cols; + size_t scale_idx = 0; if constexpr (WITH_GEMM_SWIZZLED_SCALES) { const size_t tensor_base_row = tensor_base_for_scales / cols; const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; - const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t cols_padded = DIVUP(cols, static_cast(scale_tensor_alignment_X_colwise)) * + static_cast(scale_tensor_alignment_X_colwise); + const size_t tensor_scales_offset_colwise_base = tensor_base_row * cols_padded / SCALE_DIM_Y; const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; scale_idx = tensor_scales_offset_colwise_base + transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( @@ -164,7 +168,9 @@ __device__ __forceinline__ void process_colwise_stage( const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - scales_colwise[scale_idx] = biased_exponent; + // OOB padded region needs to be zeroed out. + scales_colwise[scale_idx] = + colwise_scale_is_within_bounds ? biased_exponent : static_cast(0); const bf16 block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::bf16x2 block_scale_inverse_bf16_x2 = {block_scale_inverse, block_scale_inverse}; @@ -234,7 +240,9 @@ __device__ __forceinline__ void process_colwise_stage( const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - scales_colwise[scale_idx] = biased_exponent; + // OOB padded region needs to be zeroed out. + scales_colwise[scale_idx] = + colwise_scale_is_within_bounds ? biased_exponent : static_cast(0); const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll @@ -393,9 +401,9 @@ __device__ __forceinline__ void process_rowwise_stage( } else { scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; } - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } + // OOB padded region needs to be zeroed out. + scales_rowwise[scale_idx] = + rowwise_scale_is_within_bounds ? biased_exponent : static_cast(0); const bf16 block_scale_inverse_bf16 = ptx::exp2f_rcp(biased_exponent); const ptx::bf16x2 block_scale_inverse_bf16_x2 = {block_scale_inverse_bf16, @@ -705,6 +713,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel if constexpr (COLWISE_SCALING) { thread_partial_dbias = partial_dbias_colwise; } else { + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); float *partial_dbias_rowwise = reinterpret_cast(dshmem); constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index f36b071081..a0ae7dde82 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -498,6 +498,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (COLWISE_SCALING) { thread_partial_dbias = partial_dbias_colwise; } else { + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] // HEIGHT = THREADS_Y // WIDTH = THREADS_X * (SCALE_DIM_X + 1) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 6e207370dd..c1b3f8f427 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -946,6 +946,26 @@ struct TypeInfo { } \ } +#define TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(INTEGER_ELTS_NUM, type, ...) \ + switch (INTEGER_ELTS_NUM) { \ + case 1: { \ + using type = int; \ + { __VA_ARGS__ } \ + } break; \ + case 2: { \ + using type = int2; \ + { __VA_ARGS__ } \ + } break; \ + case 4: { \ + using type = int4; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported number of integer elements ", INTEGER_ELTS_NUM, \ + ". Expected one of: 1, 2, or 4."); \ + } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { @@ -1003,7 +1023,7 @@ size_t typeToSize(const DType type); size_t typeToNumBits(const DType type); void CheckNoopTensor(const Tensor &t, const std::string &name); -void CheckInputTensor(const Tensor &t, const std::string &name); +void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_shapes = true); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); /*! \brief Update a tensor's FP8 scale-inverse diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 6c66746e62..38bf09f810 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -4,12 +4,30 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include + #include "../common.h" +#include "../util/cuda_driver.h" +#include "../util/cuda_runtime.h" +#include "../util/ptx.cuh" +#include "../utils.cuh" #include "transformer_engine/fused_attn.h" namespace transformer_engine { + +// ============================================================================ +// prepare_flash_attn: SBH3D <-> BSHD_BSHD_BSHD for the FlashAttention backend +// ============================================================================ + namespace flash_attention { +/// Packed vector of N elements of T; alignment matches a single wide load/store of N * sizeof(T) bytes. +template +struct alignas(sizeof(T) * N) Vec { + T data[N]; +}; + constexpr int warp_size = 32; constexpr int type_size = 2; // FP16 or BF16 constexpr int nvec = sizeof(uint64_t) / type_size; @@ -35,8 +53,8 @@ __launch_bounds__(block_size) __global__ T *my_output = qkv + offset_output; for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size); - *out = *reinterpret_cast(my_input + i * load_size * 3); + Vec *const out = reinterpret_cast *>(my_output + i * load_size); + *out = *reinterpret_cast *>(my_input + i * load_size * 3); } } @@ -61,8 +79,8 @@ __launch_bounds__(block_size) __global__ T *my_output = qkv + offset_output; for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); - *out = *reinterpret_cast(my_input + i * load_size); + Vec *const out = reinterpret_cast *>(my_output + i * load_size * 3); + *out = *reinterpret_cast *>(my_input + i * load_size); } } @@ -134,6 +152,696 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream } } // namespace flash_attention + +// ============================================================================ +// multi_tensor_transpose_to_bhsd: BSHD/SBHD -> BHSD +// ============================================================================ + +namespace multi_tensor_transpose_to_bhsd { + +using flash_attention::Vec; + +constexpr int kMaxPermuteTensors = 16; + +struct PermuteSlot { + const void *input; + void *output; + size_t S, H, D_in, D_out; +}; + +struct PermuteParams { + PermuteSlot slots[kMaxPermuteTensors]; +}; + +struct TmaMapParams { + CUtensorMap maps[kMaxPermuteTensors]; +}; + +// ---------- path 3: fallback_not_vec_aligned ---------- + +__device__ __forceinline__ void copy_row_bytes(const char *__restrict__ src, char *__restrict__ dst, + size_t D_bytes) { + size_t off = 0; + for (; off + 16 <= D_bytes; off += 16) { + uint4 tmp; + memcpy(&tmp, src + off, 16); + memcpy(dst + off, &tmp, 16); + } + for (; off + 8 <= D_bytes; off += 8) { + uint2 tmp; + memcpy(&tmp, src + off, 8); + memcpy(dst + off, &tmp, 8); + } + for (; off + 4 <= D_bytes; off += 4) { + unsigned int tmp; + memcpy(&tmp, src + off, 4); + memcpy(dst + off, &tmp, 4); + } + for (; off + 2 <= D_bytes; off += 2) { + uint16_t tmp; + memcpy(&tmp, src + off, 2); + memcpy(dst + off, &tmp, 2); + } + for (; off < D_bytes; ++off) dst[off] = src[off]; +} + +__device__ __forceinline__ void copy_and_pad_row_bytes(const char *__restrict__ src, + char *__restrict__ dst, size_t D_bytes, + size_t D_out_bytes) { + copy_row_bytes(src, dst, D_bytes); + for (size_t off = D_bytes; off < D_out_bytes; ++off) dst[off] = 0; +} + +constexpr int TRANSPOSE_TILE = 32; +constexpr int TRANSPOSE_BLOCK = 256; +constexpr int TRANSPOSE_WARPS = TRANSPOSE_BLOCK / 32; // 8 + +template +__launch_bounds__(TRANSPOSE_BLOCK) __global__ + void transpose_to_bhsd_fallback_not_vec_aligned_kernel(PermuteParams params, size_t b, + unsigned int s_tiles) { + const auto &slot = params.slots[blockIdx.z]; + const T *__restrict__ in = reinterpret_cast(slot.input); + T *__restrict__ out = reinterpret_cast(slot.output); + const size_t S = slot.S; + const size_t H = slot.H; + const size_t D = slot.D_in; + const size_t D_out = slot.D_out; + const size_t D_bytes = D * sizeof(T); + const size_t D_out_bytes = D_out * sizeof(T); + const size_t D_smem_pad = (D_bytes + 3u) & ~size_t(3); + + const size_t tile_s = static_cast(blockIdx.x) % static_cast(s_tiles); + const size_t b_i = static_cast(blockIdx.x) / static_cast(s_tiles); + if (b_i >= b) return; + const size_t tile_h = static_cast(blockIdx.y); + + const size_t s_base = tile_s * TRANSPOSE_TILE; + const size_t h_base = tile_h * TRANSPOSE_TILE; + + extern __shared__ char smem[]; + const size_t smem_row = static_cast(TRANSPOSE_TILE) * D_smem_pad + 4; + + // ---- Phase 1: global → smem (sweep consecutive H → coalesced reads) ---- + for (unsigned int warp_off = threadIdx.x >> 5; warp_off < TRANSPOSE_TILE; + warp_off += TRANSPOSE_WARPS) { + const size_t local_s = warp_off; + const size_t local_h = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + const char *__restrict__ src; + if constexpr (kIsBshd) + src = reinterpret_cast(in + b_i * S * H * D + s_i * H * D + h_i * D); + else + src = reinterpret_cast(in + s_i * b * H * D + b_i * H * D + h_i * D); + copy_row_bytes(src, smem + local_s * smem_row + local_h * D_smem_pad, D_bytes); + } + } + + __syncthreads(); + + // ---- Phase 2: smem → global (sweep consecutive S → coalesced writes, with padding) ---- + for (unsigned int warp_off = threadIdx.x >> 5; warp_off < TRANSPOSE_TILE; + warp_off += TRANSPOSE_WARPS) { + const size_t local_h = warp_off; + const size_t local_s = threadIdx.x & 31u; + const size_t s_i = s_base + local_s; + const size_t h_i = h_base + local_h; + if (s_i < S && h_i < H) { + copy_and_pad_row_bytes( + smem + local_s * smem_row + local_h * D_smem_pad, + reinterpret_cast(out + b_i * H * S * D_out + h_i * S * D_out + s_i * D_out), + D_bytes, D_out_bytes); + } + } +} + +// ---------- path 2: fallback_vec_aligned ---------- + +constexpr int fallback_permute_threads = 1024; + +template +__device__ __forceinline__ void permute_vec_loop(const T *__restrict__ in, T *__restrict__ out, + size_t b, size_t S, size_t H, size_t D, + size_t D_out, size_t b_i, size_t h_i, + size_t s_begin, size_t S_chunk) { + const size_t out_base = b_i * H * S * D_out + h_i * S * D_out; + const size_t d_vec = D / static_cast(N); + const size_t total_work = S_chunk * d_vec; + for (size_t w = static_cast(threadIdx.x); w < total_work; + w += static_cast(blockDim.x)) { + const size_t s_local = w / d_vec; + const size_t s_i = s_begin + s_local; + const size_t d_off = (w % d_vec) * static_cast(N); + const T *__restrict__ in_ptr; + if constexpr (kIsBshd) { + in_ptr = in + b_i * (S * H * D) + s_i * (H * D) + h_i * D + d_off; + } else { + in_ptr = in + s_i * (b * H * D) + b_i * (H * D) + h_i * D + d_off; + } + T *__restrict__ out_ptr = out + out_base + s_i * D_out + d_off; + *reinterpret_cast *>(out_ptr) = *reinterpret_cast *>(in_ptr); + } + if (D_out > D) { + const size_t pad_elems = D_out - D; + const size_t total_pad = S_chunk * pad_elems; + for (size_t w = static_cast(threadIdx.x); w < total_pad; + w += static_cast(blockDim.x)) { + const size_t s_local = w / pad_elems; + const size_t s_i = s_begin + s_local; + const size_t d_off = D + (w % pad_elems); + out[out_base + s_i * D_out + d_off] = static_cast(0.f); + } + } +} + +template +__launch_bounds__(fallback_permute_threads) __global__ + void transpose_to_bhsd_fallback_vec_aligned_kernel(PermuteParams params, size_t b, + unsigned int permute_s_splits, + size_t h_grid) { + const auto &slot = params.slots[blockIdx.z]; + const T *__restrict__ in = reinterpret_cast(slot.input); + T *__restrict__ out = reinterpret_cast(slot.output); + const size_t S = slot.S; + const size_t H = slot.H; + const size_t D = slot.D_in; + const size_t D_out = slot.D_out; + + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; + if (b_i >= b) return; + if (h_i >= H) return; + + const unsigned int s_part = blockIdx.y; + const size_t s_begin = (S * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (S * static_cast(s_part + 1)) / static_cast(permute_s_splits); + if (s_begin >= s_end) return; + const size_t S_chunk = s_end - s_begin; + + const size_t D_bytes = D * sizeof(T); + + if (D_bytes % 16 == 0) { + constexpr size_t N = 16 / sizeof(T); + permute_vec_loop(in, out, b, S, H, D, D_out, b_i, h_i, s_begin, S_chunk); + return; + } + if (D_bytes % 8 == 0) { + constexpr size_t N = 8 / sizeof(T); + permute_vec_loop(in, out, b, S, H, D, D_out, b_i, h_i, s_begin, S_chunk); + return; + } + if constexpr (sizeof(T) <= 4) { + if (D_bytes % 4 == 0) { + constexpr size_t N = 4 / sizeof(T); + permute_vec_loop(in, out, b, S, H, D, D_out, b_i, h_i, s_begin, S_chunk); + return; + } + } +} + +// ---------- path 1: TMA ---------- + +constexpr int tma_permute_threads = 128; +constexpr int tma_permute_s_tile_default = 32; + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( + void *dst_shmem, const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, + uint32_t c3, uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t dst = __cvta_generic_to_shared(dst_shmem); + uint32_t bar = __cvta_generic_to_shared(mbar); + asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" ::"r"(dst), + "l"(tensor_map), "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(bar) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_global_to_shared requires SM 10.0+."); +#endif +} + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( + const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, + void *src_shmem) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t src = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" + " [%0, {%1, %2, %3, %4}], [%5];" ::"l"(tensor_map), + "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(src) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_shared_to_global requires SM 10.0+."); +#endif +} + +static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, uint64_t dim0, + uint64_t dim1, uint64_t dim2, uint64_t dim3, uint32_t box0, + uint32_t box1, uint32_t box2, uint32_t box3) { + cuda_driver::ensure_context_exists(); + static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { + void *ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(ptr); + }(); + + CUtensorMapDataType tma_dtype; + size_t elem_bytes; + switch (dtype) { + case DType::kFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bytes = 2; + break; + case DType::kBFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bytes = 2; + break; + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kFloat8E8M0: + case DType::kByte: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bytes = 1; + break; + default: + NVTE_ERROR("create_4D_tensor_map: unsupported dtype ", to_string(static_cast(dtype))); + } + + constexpr uint32_t rank = 4; + uint64_t size[rank] = {dim0, dim1, dim2, dim3}; + uint64_t stride[rank - 1] = { + dim0 * elem_bytes, + dim0 * dim1 * elem_bytes, + dim0 * dim1 * dim2 * elem_bytes, + }; + uint32_t boxSize[rank] = {box0, box1, box2, box3}; + uint32_t elemStride[rank] = {1, 1, 1, 1}; + + const auto oob_fill = (tma_dtype == CU_TENSOR_MAP_DATA_TYPE_UINT8) + ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + : CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; + + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + oob_fill)); +} + +template +__device__ __forceinline__ void issue_tma_load_strided(T *smem_buf, const CUtensorMap *tma, + size_t h_i, size_t s_tile, size_t b_i, + uint64_t *mbar, size_t tile_bytes) { + ptx::mbarrier_arrive_expect_tx(mbar, static_cast(tile_bytes)); + if constexpr (kIsBshd) { + cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), + static_cast(s_tile), + static_cast(b_i), mbar); + } else { + cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), + static_cast(b_i), + static_cast(s_tile), mbar); + } +} + +__device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { + asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(val.x), "r"(val.y), + "r"(val.z), "r"(val.w) + : "memory"); +} +// TMA loads from strided input to smem + non-temporal stores to contiguous output in gmem + +template +__launch_bounds__(tma_permute_threads) __global__ + void transpose_to_bhsd_kernel(const __grid_constant__ TmaMapParams tma_maps, + PermuteParams params, size_t b, size_t h_grid, + unsigned int permute_s_splits, size_t s_tile_size) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const auto &slot = params.slots[blockIdx.z]; + const CUtensorMap *tma_in = &tma_maps.maps[blockIdx.z]; + T *__restrict__ tensor_out = reinterpret_cast(slot.output); + const size_t Sdim = slot.S; + const size_t Hdim = slot.H; + const size_t Ddim = slot.D_in; + const size_t Ddim_out = slot.D_out; + + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; + + if (b_i >= b) return; + if (h_i >= Hdim) return; + + const unsigned int s_part = blockIdx.y; + const size_t s_begin = + (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); + if (s_begin >= s_end) return; + + const size_t out_base = b_i * Hdim * Sdim * Ddim_out + h_i * Sdim * Ddim_out; + + extern __shared__ __align__(128) char smem_raw[]; + T *smem = reinterpret_cast(smem_raw); + + __shared__ __align__(8) uint64_t mbar; + const bool is_leader = (threadIdx.x == 0); + + if (is_leader) { + ptx::mbarrier_init(&mbar, static_cast(blockDim.x)); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + const size_t S_TILE = s_tile_size; + const uint32_t tile_bytes = static_cast(S_TILE * Ddim * sizeof(T)); + int parity = 0; + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + if (is_leader) { + issue_tma_load_strided(smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); + } else { + ptx::mbarrier_arrive(&mbar); + } + + ptx::mbarrier_wait_parity(&mbar, parity); + parity ^= 1; + + T *__restrict__ out_ptr = tensor_out + out_base + s_tile * Ddim_out; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + if (Ddim_out == Ddim) { + const size_t total_elems = tile_rows * Ddim; + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + uint4 v = *reinterpret_cast(smem + i); + st_global_cs_uint4(reinterpret_cast(out_ptr + i), v); + } + } else { + const size_t total_out_elems = tile_rows * Ddim_out; + for (size_t i = threadIdx.x * vec_elems; i < total_out_elems; + i += static_cast(blockDim.x) * vec_elems) { + const size_t row = i / Ddim_out; + const size_t col = i % Ddim_out; + uint4 v; + if (col + vec_elems <= Ddim) { + v = *reinterpret_cast(smem + row * Ddim + col); + } else { + memset(&v, 0, sizeof(v)); + const size_t smem_off = row * Ddim + col; + size_t copy_elems = (col < Ddim) ? (Ddim - col) : 0; + if (copy_elems > 0) memcpy(&v, smem + smem_off, copy_elems * sizeof(T)); + } + st_global_cs_uint4(reinterpret_cast(out_ptr + i), v); + } + } + + __syncthreads(); + } + + if (is_leader) { + ptx::mbarrier_invalid(&mbar); + } +#endif +} + +// 4D TMA descriptor: +// [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] +// [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] + +static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, size_t b, size_t s, + size_t h, size_t d, size_t s_tile, bool is_bshd) { + if (is_bshd) { + create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), + static_cast(s), static_cast(b), + static_cast(d), 1, static_cast(s_tile), 1); + } else { + create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), + static_cast(b), static_cast(s), + static_cast(d), 1, 1, static_cast(s_tile)); + } +} + +void multi_tensor_transpose_to_bhsd(Tensor *inputs, Tensor *outputs, size_t num_tensors, + NVTE_QKV_Format original_format, cudaStream_t stream) { + using namespace transformer_engine; + if (num_tensors == 0) return; + NVTE_CHECK(num_tensors <= static_cast(kMaxPermuteTensors), "num_tensors must be in [1, ", + kMaxPermuteTensors, "], got ", num_tensors, "."); + + const bool is_bshd = (original_format == NVTE_QKV_Format::NVTE_BSHD); + const DType dtype = inputs[0].dtype(); + const size_t elem_size = typeToSize(dtype); + const size_t b = outputs[0].shape()[0]; + + PermuteParams params{}; + size_t s_max = 0, h_max = 0, s_min = SIZE_MAX; + size_t d_in_max = 0, d_out_max = 0; + bool any_not_vec_aligned = false; + bool all_tma_ok = true; + + for (size_t i = 0; i < num_tensors; ++i) { + const size_t H = outputs[i].shape()[1]; + const size_t S = outputs[i].shape()[2]; + const size_t D_in = inputs[i].shape()[inputs[i].shape().size() - 1]; + const size_t D_out = outputs[i].shape()[3]; + params.slots[i] = {inputs[i].data.dptr, outputs[i].data.dptr, S, H, D_in, D_out}; + s_max = std::max(s_max, S); + h_max = std::max(h_max, H); + s_min = std::min(s_min, S); + d_in_max = std::max(d_in_max, D_in); + d_out_max = std::max(d_out_max, D_out); + if ((D_in * elem_size) % 4 != 0) any_not_vec_aligned = true; + const size_t inner = D_in * elem_size; + if (inner < 32 || inner % 16 != 0) all_tma_ok = false; + } + + if (all_tma_ok) { + const int sm = cuda::sm_arch(cuda::current_device()); + if (sm < 100) { + all_tma_ok = false; + } else { + switch (dtype) { + case DType::kFloat16: + case DType::kBFloat16: + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kFloat8E8M0: + case DType::kByte: + break; + default: + all_tma_ok = false; + } + } + } + + // Dispatch order: + // 1. TMA path: SM 10.0+, D_in*elem >= 32 && 16-aligned, supported dtype, + // and s_tile*D_in*elem is uint4-aligned. + // 2. Fallback path (vec-aligned): vectorized loads/stores when D_in*elem % 4 == 0. + // 3. Fallback path (not-vec-aligned): shared-memory transpose when D_in*elem % 4 != 0. + if (all_tma_ok) { + const size_t s_tile = std::min(static_cast(tma_permute_s_tile_default), s_min); + bool tma_aligned = true; + for (size_t i = 0; i < num_tensors && tma_aligned; ++i) { + if ((s_tile * params.slots[i].D_in * elem_size) % sizeof(uint4) != 0) tma_aligned = false; + } + + if (tma_aligned) { + TmaMapParams tma_maps{}; + for (size_t i = 0; i < num_tensors; ++i) { + const auto &slot = params.slots[i]; + create_strided_tensor_map(tma_maps.maps[i], const_cast(slot.input), dtype, b, + slot.S, slot.H, slot.D_in, s_tile, is_bshd); + } + + const unsigned int permute_s_splits = std::max(1u, static_cast(s_min / s_tile)); + dim3 grid(static_cast(b * h_max), permute_s_splits, + static_cast(num_tensors)); + const size_t smem_bytes = s_tile * d_in_max * elem_size; + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + dtype, dtype_t, auto kernel = transpose_to_bhsd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>(tma_maps, params, b, h_max, + permute_s_splits, s_tile);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + dtype, dtype_t, auto kernel = transpose_to_bhsd_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>(tma_maps, params, b, h_max, + permute_s_splits, s_tile);); + } + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } + } + + if (!any_not_vec_aligned) { + const unsigned int permute_s_splits = std::max( + 1u, static_cast(s_min / static_cast(fallback_permute_threads))); + dim3 grid(static_cast(b * h_max), permute_s_splits, + static_cast(num_tensors)); + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + dtype, dtype_t, + transpose_to_bhsd_fallback_vec_aligned_kernel + <<>>(params, b, permute_s_splits, h_max);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + dtype, dtype_t, + transpose_to_bhsd_fallback_vec_aligned_kernel + <<>>(params, b, permute_s_splits, h_max);); + } + } else { + const unsigned int st = + static_cast((s_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + const unsigned int ht = + static_cast((h_max + TRANSPOSE_TILE - 1) / TRANSPOSE_TILE); + dim3 grid(static_cast(b) * st, ht, static_cast(num_tensors)); + const size_t D_pad = (d_in_max * elem_size + 3u) & ~size_t(3); + const size_t smem_bytes = + static_cast(TRANSPOSE_TILE) * (static_cast(TRANSPOSE_TILE) * D_pad + 4); + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + dtype, dtype_t, + transpose_to_bhsd_fallback_not_vec_aligned_kernel + <<>>(params, b, st);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + dtype, dtype_t, + transpose_to_bhsd_fallback_not_vec_aligned_kernel + <<>>(params, b, st);); + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace multi_tensor_transpose_to_bhsd + +// =================================================================================== +// multi_tensor_pad_last_dim: pad the last dim of multiple tensors to certain alignment +// =================================================================================== + +namespace multi_tensor_pad_last_dim { + +constexpr int pad_threads_per_block = 256; +constexpr int kMaxPadTensors = 16; + +struct PadLastDimArgs { + const uint8_t *input; + uint32_t *output; + size_t n_uint32; + uint32_t in_row_bytes; + uint32_t out_row_uint32; +}; + +struct MultiPadParams { + PadLastDimArgs tensors[kMaxPadTensors]; +}; + +__launch_bounds__(pad_threads_per_block) __global__ + void multi_tensor_pad_last_dim_kernel(MultiPadParams params) { + const auto &a = params.tensors[blockIdx.y]; + + for (size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < a.n_uint32; + idx += static_cast(gridDim.x) * blockDim.x) { + const uint32_t col_byte = (idx % a.out_row_uint32) * 4; + const size_t row = idx / a.out_row_uint32; + const uint8_t *__restrict__ src = a.input + row * static_cast(a.in_row_bytes); + + uint32_t val; + if (col_byte + 4 <= a.in_row_bytes) { + memcpy(&val, src + col_byte, 4); + } else if (col_byte >= a.in_row_bytes) { + val = 0; + } else { + val = 0; + memcpy(&val, src + col_byte, a.in_row_bytes - col_byte); + } + a.output[idx] = val; + } +} + +void launch_pad_batch(MultiPadParams ¶ms, int kernel_count, size_t max_n_uint32, + cudaStream_t stream) { + if (kernel_count == 0) return; + constexpr int threads = pad_threads_per_block; + const int blocks_x = static_cast( + std::min(DIVUP(max_n_uint32, static_cast(threads)), static_cast(65535))); + dim3 grid(blocks_x, kernel_count); + multi_tensor_pad_last_dim_kernel<<>>(params); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void multi_tensor_pad_last_dim(Tensor *inputs, Tensor *outputs, size_t num_tensors, + cudaStream_t stream) { + using namespace transformer_engine; + + if (num_tensors == 0) return; + + MultiPadParams params{}; + size_t max_n_uint32 = 0; + int kernel_count = 0; + + for (size_t i = 0; i < num_tensors; ++i) { + auto &inp = inputs[i]; + auto &out = outputs[i]; + + NVTE_CHECK(inp.data.shape.size() == 2, "Expected 2D input tensor at index ", i, "."); + NVTE_CHECK(out.data.shape.size() == 2, "Expected 2D output tensor at index ", i, "."); + NVTE_CHECK(inp.data.dtype == out.data.dtype, "Dtype mismatch at index ", i, "."); + + const size_t rows = inp.data.shape[0]; + const size_t in_cols = inp.data.shape[1]; + const size_t out_cols = out.data.shape[1]; + + NVTE_CHECK(out.data.shape[0] == rows, "Row count mismatch at index ", i, "."); + NVTE_CHECK(out_cols >= in_cols, "out_cols < in_cols at index ", i, "."); + + if (rows == 0) continue; + + if (in_cols == out_cols) { + const size_t total_bytes = rows * in_cols * typeToSize(inp.data.dtype); + NVTE_CHECK_CUDA(cudaMemcpyAsync(out.data.dptr, inp.data.dptr, total_bytes, + cudaMemcpyDeviceToDevice, stream)); + continue; + } + + if (kernel_count == kMaxPadTensors) { + launch_pad_batch(params, kernel_count, max_n_uint32, stream); + params = MultiPadParams{}; + kernel_count = 0; + max_n_uint32 = 0; + } + + const size_t elem_size = typeToSize(inp.data.dtype); + const auto in_row_bytes = static_cast(in_cols * elem_size); + const auto out_row_bytes = static_cast(out_cols * elem_size); + NVTE_CHECK(out_row_bytes % 4 == 0, "Padded row size in bytes (", out_row_bytes, + ") must be a multiple of 4."); + + const uint32_t out_row_uint32 = out_row_bytes / 4; + const size_t n_uint32 = rows * out_row_uint32; + + params.tensors[kernel_count] = {reinterpret_cast(inp.data.dptr), + reinterpret_cast(out.data.dptr), n_uint32, + in_row_bytes, out_row_uint32}; + max_n_uint32 = std::max(max_n_uint32, n_uint32); + ++kernel_count; + } + + launch_pad_batch(params, kernel_count, max_n_uint32, stream); +} + +} // namespace multi_tensor_pad_last_dim } // namespace transformer_engine void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) { @@ -153,3 +861,40 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET *convertNVTETensorCheck(v), *convertNVTETensorCheck(qkv), stream); } + +void nvte_multi_tensor_transpose_to_bhsd(NVTETensor *inputs, NVTETensor *outputs, + size_t num_tensors, NVTE_QKV_Format original_format, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_transpose_to_bhsd); + NVTE_CHECK(original_format == NVTE_QKV_Format::NVTE_BSHD || + original_format == NVTE_QKV_Format::NVTE_SBHD, + "nvte_multi_tensor_transpose_to_bhsd: only BSHD/SBHD -> BHSD is currently " + "supported."); + using namespace transformer_engine; + + std::vector in_vec(num_tensors), out_vec(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + in_vec[i] = *convertNVTETensorCheck(inputs[i]); + out_vec[i] = *convertNVTETensorCheck(outputs[i]); + } + constexpr size_t kBatch = multi_tensor_transpose_to_bhsd::kMaxPermuteTensors; + for (size_t offset = 0; offset < num_tensors; offset += kBatch) { + const size_t batch = std::min(num_tensors - offset, kBatch); + multi_tensor_transpose_to_bhsd::multi_tensor_transpose_to_bhsd( + in_vec.data() + offset, out_vec.data() + offset, batch, original_format, stream); + } +} + +void nvte_multi_tensor_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_pad_last_dim); + using namespace transformer_engine; + + std::vector in_vec(num_tensors), out_vec(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + in_vec[i] = *convertNVTETensorCheck(inputs[i]); + out_vec[i] = *convertNVTETensorCheck(outputs[i]); + } + multi_tensor_pad_last_dim::multi_tensor_pad_last_dim(in_vec.data(), out_vec.data(), num_tensors, + stream); +} diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3d6e3a0aac..141767b803 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -131,6 +131,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Layout_Group::NVTE_SD_SD_SD; default: NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), " in nvte_get_qkv_layout_group."); @@ -172,6 +174,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), " in nvte_get_qkv_format."); @@ -192,6 +196,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), " in nvte_get_q_format."); @@ -212,6 +218,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), " in nvte_get_kv_format."); @@ -269,9 +277,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.21: d_qk=192, d_v=128 + (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 && + head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && + // pre-9.21: {bshd, sbhd}, {vanilla} + // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} + ((cudnn_runtime_version < 92100 && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || + (cudnn_runtime_version >= 92100 && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && + !requires_64bit_ragged_offset && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { if (cudnn_runtime_version >= 8900) { @@ -410,12 +431,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + q_format == NVTE_QKV_Format::NVTE_BHSD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + kv_format == NVTE_QKV_Format::NVTE_BHSD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window @@ -565,7 +589,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { @@ -587,23 +612,24 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + auto *q_dims = input_Q->data.shape.data(); + auto *k_dims = input_K->data.shape.data(); + auto *v_dims = input_V->scaling_mode != NVTE_MXFP8_1D_SCALING + ? input_V->data.shape.data() + : input_V->columnwise_data.shape.data(); + AttentionShape q_shape(q_format, q_dims); + AttentionShape k_shape(kv_format, k_dims); + AttentionShape v_shape(kv_format, v_dims); + size_t b = q_shape.b(), h_q = q_shape.h(), d_qk = q_shape.d(), t_q = q_shape.t(); + size_t h_kv = k_shape.h(), t_kv = k_shape.t(), d_v = v_shape.d(); if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; + b = input_cu_seqlens_q->data.shape[0] - 1; + } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_kv->data.shape[0] - 1; } + int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -642,38 +668,26 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, - input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + return_max_logit, attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, + input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " - "\n"); -#endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, + attn_scale, dropout, qkv_layout, o_format, qkv_scale_inv_format, bias_type, + attn_mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, input_Q, input_K, input_V, input_SoftmaxOffset, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } @@ -687,11 +701,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -712,22 +728,20 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + auto *q_dims = input_Q->data.shape.data(); + auto *k_dims = input_K->data.shape.data(); + auto *v_dims = input_V->data.shape.data(); + AttentionShape q_shape(q_format, q_dims); + AttentionShape k_shape(kv_format, k_dims); + AttentionShape v_shape(kv_format, v_dims); + size_t b = q_shape.b(), h_q = q_shape.h(), d_qk = q_shape.d(), t_q = q_shape.t(); + size_t h_kv = k_shape.h(), t_kv = k_shape.t(), d_v = v_shape.d(); if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; + b = input_cu_seqlens_q->data.shape[0] - 1; + } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_kv->data.shape[0] - 1; } auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); @@ -740,17 +754,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8900) size_t i = 0; Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); @@ -763,30 +772,36 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, - input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); -#else - const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); -#endif + qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, + input_V, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, + output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + handle); } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, input_K, - input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, - output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif + size_t i = 0; + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_ZInv = nullptr; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_SoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + const Tensor *input_dO_f16 = nullptr; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, + qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format, + do_scale_inv_format, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, + input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, output_dQ, + output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eed6740740..6df7ad35c8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -19,7 +19,6 @@ #include "fused_attn_f16_arbitrary_seqlen.h" #include "utils.h" -#if (CUDNN_VERSION >= 8900) #define Q_ID 1 #define K_ID 2 #define V_ID 3 @@ -54,11 +53,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, + void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -80,8 +79,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); - NVTE_QKV_Format q_format = nvte_get_q_format(layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -89,7 +88,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const int sm_arch_ = cuda::sm_arch(device_id); bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); @@ -135,7 +134,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( scaling_factor, is_training, dropout_probability, - layout, + qkv_layout, + o_format, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout_NOT_SET, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, @@ -202,17 +206,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); if (is_paged_kv) { generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_v, d_qk, k_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(num_pages_v, hg, page_size_k, page_size_v, d_v, v_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); } else { - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); } @@ -368,7 +372,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options)); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); if (is_ragged_q) { @@ -513,7 +517,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; } - const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); cu_seqlens_padded_to_offsets<<>>( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, @@ -551,7 +555,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, @@ -578,8 +583,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); - NVTE_QKV_Format q_format = nvte_get_q_format(layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -587,7 +592,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const int sm_arch_ = cuda::sm_arch(device_id); bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); @@ -632,7 +637,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( scaling_factor, true, dropout_probability, - layout, + qkv_layout, + o_format, + do_format, + dqkv_layout, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, @@ -703,13 +713,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_O_Matrix); q = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -1024,7 +1034,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; } - const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); cu_seqlens_padded_to_offsets<<>>( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, @@ -1067,13 +1077,14 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1202,12 +1213,12 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, - is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, - devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, + devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1228,6 +1239,7 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, @@ -1300,12 +1312,12 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, - devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, + devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, + devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1322,4 +1334,3 @@ void fused_attn_arbitrary_seqlen_bwd( } } } // namespace transformer_engine -#endif // CUDNN_VERSION >= 8900 diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 4dd7f3d1da..8f79b5bb4a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -17,25 +17,26 @@ #include "transformer_engine/fused_attn.h" namespace transformer_engine { -#if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, @@ -46,7 +47,6 @@ void fused_attn_arbitrary_seqlen_bwd( const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); -#endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 336e3d5386..d5151a51f1 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -16,7 +16,6 @@ #include "fused_attn_f16_max512_seqlen.h" #include "utils.h" -#if (CUDNN_VERSION >= 8901) #define Q_ID 1 #define K_ID 2 #define V_ID 3 @@ -1342,4 +1341,3 @@ void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, } } } // namespace transformer_engine -#endif // CUDNN_VERSION >= 8901 diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index 3b30c6e716..1e59d4dc8f 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -17,7 +17,6 @@ #include "transformer_engine/fused_attn.h" namespace transformer_engine { -#if (CUDNN_VERSION >= 8901) void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -37,7 +36,6 @@ void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, Tensor *output_dBias, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); -#endif // CUDNN_VERSION >= 8901 } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 80e64370f9..d97f388459 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -15,7 +15,6 @@ namespace fused_attn { using namespace transformer_engine; -#if (CUDNN_VERSION >= 8900) std::unordered_map tensor_name_to_uid = {{"Q", 1}, {"K", 2}, {"V", 3}, @@ -1652,16 +1651,20 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, - void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, + void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, + void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, + NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -1669,19 +1672,27 @@ void fused_attn_fp8_fwd_impl_v1( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); auto bias_b = b; auto bias_h = h; auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || - o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + NVTE_CHECK(!is_mxfp8 || cudnn_runtime_version >= 92100, + "MXFP8 fused attention requires cuDNN 9.21.0 or later!"); try { FADescriptor_v1 descriptor{b, @@ -1689,8 +1700,8 @@ void fused_attn_fp8_fwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -1704,13 +1715,18 @@ void fused_attn_fp8_fwd_impl_v1( scaling_factor, is_training, dropout_probability, - layout, + qkv_layout, + o_format, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout_NOT_SET, + qkv_scale_inv_format, + NVTE_QKV_Format_NOT_SET, bias_type, mask_type, - NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, - true, + softmax_type, + window_size_left, + window_size_right, + bottom_right_diagonal, true, qkv_tensor_type, o_tensor_type, @@ -1736,6 +1752,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // amax_o std::shared_ptr, // Stats std::shared_ptr, // bias + std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // dropout_seed @@ -1762,31 +1779,28 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr descale_q, descale_k, descale_v; std::shared_ptr descale_s, scale_s, scale_o; - std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr bias, softmax_offset, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + // Q, K, V, attn_scale + std::vector q_strides(4), k_strides(4), v_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), + k_strides.data(), v_strides.data(), qkv_layout); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_strides) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_strides) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_strides) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -1794,21 +1808,61 @@ void fused_attn_fp8_fwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - - if (is_delayed_scaling) { - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); - } - if (is_current_scaling) { - scale_o = mha_graph->tensor(1.0f); + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Scale_o + if (is_delayed_scaling || is_current_scaling) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_o"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } + } else if (is_mxfp8) { + NVTE_QKV_Format q_scale_inv_format = (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) + ? qkv_scale_inv_format + : nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_scale_inv_format = (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) + ? qkv_scale_inv_format + : nvte_get_kv_format(qkv_layout); + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_scale_inv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_scale_inv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, + v_scale_strides.data(), kv_scale_inv_format); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_attributes sdpa_options; @@ -1818,6 +1872,20 @@ void fused_attn_fp8_fwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 92100) { + if (window_size_left != -1) { + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } + } + // sdpa_options.set_alibi_mask(is_alibi); // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -1855,19 +1923,41 @@ void fused_attn_fp8_fwd_impl_v1( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( - Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_sink_token(softmax_offset); + } - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); - amax_o->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); + std::shared_ptr O, Stats, amax_s, amax_o; + if (is_delayed_scaling || is_current_scaling) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_s = outputs[2]; + amax_o = outputs[3]; + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } else if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_o = outputs[2]; + } - amax_s->set_output(true) + std::vector o_strides(4); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); + O->set_output(true) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_strides) + .set_data_type(o_tensor_type); + amax_o->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); @@ -1890,10 +1980,15 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // O std::shared_ptr, // amax_s std::shared_ptr> // amax_o - key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, - scale_s, scale_o, attn_scale, O, amax_s, amax_o); + key_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, + nullptr, attn_scale, O, nullptr, amax_o) + : std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto softmax_offset_tuple = + is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -1904,17 +1999,17 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, - attn_scale, O, amax_s, amax_o, Stats, bias, seq_q, seq_kv, dropout_seed, dropout_offset] = - get_graph(sdpa_fp8_fprop_cache, descriptor); + attn_scale, O, amax_s, amax_o, Stats, bias, softmax_offset, seq_q, seq_kv, dropout_seed, + dropout_offset] = get_graph(sdpa_fp8_fprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -1937,17 +2032,19 @@ void fused_attn_fp8_fwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_s, devPtrDescaleS}, - {scale_s, devPtrScaleS}, {attn_scale, &scaling_factor}, {O, devPtrO}, - {amax_s, devPtrAmaxS}, - {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; if (is_delayed_scaling) { variant_pack[scale_o] = devPtrScaleO; } + if (is_delayed_scaling || is_current_scaling) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[amax_s] = devPtrAmaxS; + variant_pack[amax_o] = devPtrAmaxO; + } /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -1972,6 +2069,10 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -1980,20 +2081,27 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, - void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, - void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, - void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, - void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, + void* devPtrdK, void* devPtrdV, void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, + void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, + void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, + void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, + void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, + void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -2001,20 +2109,28 @@ void fused_attn_fp8_bwd_impl_v1( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); auto bias_b = b; auto bias_h = h; - const auto cudnn_runtime_version = cudnnGetVersion(); auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + NVTE_CHECK(!is_mxfp8 || cudnn_runtime_version >= 92100, + "MXFP8 fused attention requires cuDNN 9.21.0 or later!"); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2024,8 +2140,8 @@ void fused_attn_fp8_bwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -2039,13 +2155,18 @@ void fused_attn_fp8_bwd_impl_v1( scaling_factor, true, dropout_probability, - layout, + qkv_layout, + o_format, + do_format, + dqkv_layout, + qkv_scale_inv_format, + do_scale_inv_format, bias_type, mask_type, - NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, - true, + softmax_type, + window_size_left, + window_size_right, + bottom_right_diagonal, deterministic, qkv_tensor_type, o_tensor_type, @@ -2056,18 +2177,25 @@ void fused_attn_fp8_bwd_impl_v1( namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::shared_ptr, // Q + std::shared_ptr, // Q_t + std::shared_ptr, // K + std::shared_ptr, // K_t + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO + std::shared_ptr, // dO_t + std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q + std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k + std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO + std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2084,6 +2212,8 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dP std::shared_ptr, // bias std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // dropout_seed @@ -2108,54 +2238,54 @@ void fused_attn_fp8_bwd_impl_v1( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, + attn_scale; + std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, + descale_v; std::shared_ptr descale_s, descale_o; - std::shared_ptr descale_dP, descale_dO; + std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; std::shared_ptr scale_dQ, scale_dK, scale_dV; - std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr bias, dBias, softmax_offset, d_softmax_offset; + std::shared_ptr seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - q = mha_graph->tensor(fe::graph::Tensor_attributes() + // Q, K, V, O, dO, stats, attn_scale + std::vector q_strides(4), k_strides(4), v_strides(4), o_strides(4), dO_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), + k_strides.data(), v_strides.data(), qkv_layout); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_strides.data(), do_format); + Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_strides) + .set_data_type(qkv_tensor_type)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_strides) + .set_data_type(qkv_tensor_type)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_strides) + .set_data_type(qkv_tensor_type)); + O = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_strides) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_strides) + .set_data_type(do_tensor_type)); + Stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") .set_dim({b, h, s_q, 1}) .set_stride({h * s_q, s_q, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -2163,33 +2293,136 @@ void fused_attn_fp8_bwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - if (is_O_in_F16) { - descale_o = mha_graph->tensor(1.0f); - } else { - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); - } - descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - - if (is_delayed_scaling) { - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); - } - if (is_current_scaling) { - scale_dQ = mha_graph->tensor(1.0f); - scale_dK = mha_graph->tensor(1.0f); - scale_dV = mha_graph->tensor(1.0f); + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Descale_dP, Scale_dP, Descale_o, Descale_dO, Scale_dQ, Scale_dK, Scale_dV + if (is_delayed_scaling || is_current_scaling) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + if (is_current_scaling && is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } + descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } + } else if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + NVTE_QKV_Format q_scale_inv_format = + (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? qkv_scale_inv_format : q_format; + NVTE_QKV_Format kv_scale_inv_format = + (qkv_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? qkv_scale_inv_format : kv_format; + NVTE_QKV_Format do_scale_format_ = + (do_scale_inv_format != NVTE_QKV_Format_NOT_SET) ? do_scale_inv_format : do_format; + // Q_t, K_t, dO_t, dO_f16 + std::vector q_t_strides(4), k_t_strides(4), dO_t_strides(4); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_strides.data(), q_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_strides.data(), do_format); + Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_strides) + .set_data_type(qkv_tensor_type)); + K_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_strides) + .set_data_type(qkv_tensor_type)); + dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_t") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_t_strides) + .set_data_type(do_tensor_type)); + dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_f16") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_strides) + .set_data_type(o_tensor_type)); + // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + std::vector q_scale_strides(4), q_t_scale_strides(4), k_scale_strides(4), + k_t_scale_strides(4), v_scale_strides(4), dO_scale_strides(4), dO_t_scale_strides(4); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_scale_inv_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, + q_t_scale_strides.data(), q_scale_inv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_scale_inv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, + k_t_scale_strides.data(), kv_scale_inv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, + v_scale_strides.data(), kv_scale_inv_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, + dO_scale_strides.data(), do_scale_format_); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, + dO_t_scale_strides.data(), do_scale_format_); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_q_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q_t") + .set_dim({b, h, padded.s_q_scale_padded, padded.d_qk_padded}) + .set_stride(q_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k_t") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_qk_padded}) + .set_stride(k_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_padded, padded.d_v_scale_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO") + .set_dim({b, h, padded.s_q_padded, padded.d_v_scale_padded}) + .set_stride(dO_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO_t") + .set_dim({b, h, padded.s_q_scale_padded, padded.d_v_padded}) + .set_stride(dO_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; @@ -2198,6 +2431,20 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 92100) { + if (window_size_left != -1) { + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } + } + // sdpa_backward_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -2251,40 +2498,75 @@ void fused_attn_fp8_bwd_impl_v1( sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( - q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_sink_token(softmax_offset); + d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("d_softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_dsink_token(d_softmax_offset); + } - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); - amax_dQ->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true) + std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; + if (is_delayed_scaling || is_current_scaling) { + std::tie(dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP) = + std::apply([](const auto&... elems) { return std::make_tuple(elems...); }, + mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, + scale_dV, scale_dP, sdpa_backward_options)); + } else if (is_mxfp8) { + std::tie(dQ, dK, dV, amax_dQ, amax_dK, amax_dV) = std::apply( + [](const auto&... elems) { return std::make_tuple(elems...); }, + mha_graph->sdpa_fp8_backward(Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, + descale_q_t, descale_k, descale_k_t, descale_v, descale_dO, + descale_dO_t, sdpa_backward_options)); + } + std::vector dq_strides(4), dk_strides(4), dv_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, dq_strides.data(), + dk_strides.data(), dv_strides.data(), dqkv_layout); + dQ->set_output(true) + .set_dim({b, h, s_q, d_qk}) + .set_stride(dq_strides) + .set_data_type(dqkv_tensor_type); + dK->set_output(true) + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(dk_strides) + .set_data_type(dqkv_tensor_type); + dV->set_output(true) + .set_dim({b, hg, s_kv, d_v}) + .set_stride(dv_strides) + .set_data_type(dqkv_tensor_type); + amax_dQ->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true) + amax_dK->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true) + amax_dV->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + if (is_delayed_scaling || is_current_scaling) { + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } - dO->set_data_type(do_tensor_type); - dQ->set_data_type(dqkv_tensor_type); - dK->set_data_type(dqkv_tensor_type); - dV->set_data_type(dqkv_tensor_type); - - std::tuple, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO std::shared_ptr, // attn_scale std::shared_ptr, // descale_q @@ -2307,10 +2589,16 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dV std::shared_ptr> // amax_dP key_tensors_tuple = std::make_tuple( - q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); + auto mxfp8_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto softmax_offset_tuple = is_softmax_offset + ? std::make_tuple(softmax_offset, d_softmax_offset) + : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -2322,17 +2610,18 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, + bias_tuple, softmax_offset_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - - auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, - dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, + descale_k_t, descale_dO_t, bias, dBias, softmax_offset, d_softmax_offset, seq_q, seq_kv, + dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2349,37 +2638,47 @@ void fused_attn_fp8_bwd_impl_v1( // build variant pack std::unordered_map, void*> variant_pack = { - {q, devPtrQ}, - {k, devPtrK}, - {v, devPtrV}, - {o, devPtrO}, - {stats, devPtrM}, + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, devPtrO}, + {Stats, devPtrM}, {dO, devPtrdO}, {attn_scale, &scaling_factor}, {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, {descale_dO, devPtrDescaledO}, - {descale_s, devPtrDescaleS}, - {descale_dP, devPtrDescaledP}, - {scale_s, devPtrScaleS}, - {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV}, - {amax_dQ, devPtrAmaxdQ}, - {amax_dK, devPtrAmaxdK}, - {amax_dV, devPtrAmaxdV}, - {amax_dP, devPtrAmaxdP}, }; - + if (is_delayed_scaling || is_current_scaling) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[descale_dP] = devPtrDescaledP; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[scale_dP] = devPtrScaledP; + variant_pack[amax_dP] = devPtrAmaxdP; + variant_pack[amax_dQ] = devPtrAmaxdQ; + variant_pack[amax_dK] = devPtrAmaxdK; + variant_pack[amax_dV] = devPtrAmaxdV; + } + if (is_delayed_scaling || (is_current_scaling && !is_O_in_F16)) { + variant_pack[descale_o] = devPtrDescaleO; + } if (is_delayed_scaling) { variant_pack[scale_dQ] = devPtrScaledQ; variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - if (!is_O_in_F16) { - variant_pack[descale_o] = devPtrDescaleO; + if (is_mxfp8) { + variant_pack[Q_t] = devPtrQ_t; + variant_pack[K_t] = devPtrK_t; + variant_pack[dO_f16] = devPtrdO_f16; + variant_pack[dO_t] = devPtrdO_t; + variant_pack[descale_q_t] = devPtrDescaleQ_t; + variant_pack[descale_k_t] = devPtrDescaleK_t; + variant_pack[descale_dO_t] = devPtrDescaledO_t; } /* if (is_bias) { @@ -2410,70 +2709,100 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); } -} - -#endif +} // NOLINT(readability/fn_size) } // namespace fused_attn -#if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V -void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, - const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { +void fused_attn_fp8_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, + const Tensor* input_SoftmaxOffset, Tensor* input_output_S, Tensor* output_O, + NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - void* devPtrQ = input_Q->data.dptr; - void* devPtrK = input_K->data.dptr; - void* devPtrV = input_V->data.dptr; - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - + void *devPtrQ = nullptr, *devPtrK = nullptr, *devPtrV = nullptr; + void *devPtrDescaleQ = nullptr, *devPtrDescaleK = nullptr, *devPtrDescaleV = nullptr; + void *devPtrO = nullptr, *devPtrAmaxO = nullptr, *devPtrScaleO = nullptr; + void *devPtrAmaxS = nullptr, *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr; + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrO = output_O->data.dptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrV = input_V->data.dptr; + devPtrDescaleV = input_V->scale_inv.dptr; + devPtrScaleO = output_O->scale.dptr; + devPtrAmaxS = input_output_S->amax.dptr; + devPtrScaleS = input_output_S->scale.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; + devPtrAmaxO = output_O->amax.dptr; + } else if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrV = input_V->columnwise_data.dptr; + devPtrDescaleV = input_V->columnwise_scale_inv.dptr; + } + void* devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + int i = 0; + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_ZInv->data.dtype = DType::kFloat32; + } + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + int i = 0; + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; + devPtrZInv = nullptr; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrZInv = output_ZInv->data.dptr; + } + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); void* devPtrcuSeqlensKV = @@ -2488,17 +2817,20 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, + devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, + qkv_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, @@ -2521,24 +2853,35 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } } // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, - const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, - const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, - const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { +void fused_attn_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, + const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, + const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, + const Tensor* input_S, const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, + const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, + Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; void* devPtrV = input_V->data.dptr; void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; + void* devPtrDescaleK = input_K->scale_inv.dptr; + void* devPtrDescaleV = input_V->scale_inv.dptr; + void *devPtrQ_t = nullptr, *devPtrK_t = nullptr, *devPtrDescaleQ_t = nullptr, + *devPtrDescaleK_t = nullptr; + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrQ_t = input_Q->columnwise_data.dptr; + devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; + devPtrK_t = input_K->columnwise_data.dptr; + devPtrDescaleK_t = input_K->columnwise_scale_inv.dptr; + } void* devPtrO = input_O->data.dptr; const DType O_type = input_O->data.dtype; @@ -2548,25 +2891,46 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; + void *devPtrdO_t = nullptr, *devPtrdO_f16 = nullptr, *devPtrDescaledO_t = nullptr; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrdO_t = input_dO->columnwise_data.dptr; + devPtrdO_f16 = input_dO_f16->data.dptr; + devPtrDescaledO_t = input_dO->columnwise_scale_inv.dptr; + } void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; + void* devPtrZInv = (input_ZInv != nullptr) ? input_ZInv->data.dptr : nullptr; + + void *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr, *devPtrAmaxdP = nullptr, + *devPtrScaledP = nullptr, *devPtrDescaledP = nullptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrScaleS = input_S->scale.dptr; + devPtrDescaleS = input_S->scale_inv.dptr; + devPtrAmaxdP = input_output_dP->amax.dptr; + devPtrScaledP = input_output_dP->scale.dptr; + devPtrDescaledP = input_output_dP->scale_inv.dptr; + } - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; + void* devPtrSoftmaxOffset = nullptr; + void* devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void* devPtrdQ = output_dQ->data.dptr; void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; - void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dQ->amax.dptr; - void* devPtrAmaxdV = output_dQ->amax.dptr; - void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dQ->scale.dptr; - void* devPtrScaledV = output_dQ->scale.dptr; + void *devPtrAmaxdQ = nullptr, *devPtrAmaxdK = nullptr, *devPtrAmaxdV = nullptr, + *devPtrScaledQ = nullptr, *devPtrScaledK = nullptr, *devPtrScaledV = nullptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrAmaxdQ = output_dQ->amax.dptr; + devPtrAmaxdK = output_dK->amax.dptr; + devPtrAmaxdV = output_dV->amax.dptr; + devPtrScaledQ = output_dQ->scale.dptr; + devPtrScaledK = output_dK->scale.dptr; + devPtrScaledV = output_dV->scale.dptr; + } void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); @@ -2582,21 +2946,29 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, deterministic, devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, - devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrSoftmaxOffset, + devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, + devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + input_dO->scaling_mode, qkv_scale_inv_format, do_scale_inv_format, workspace->data.dptr, + &workspace_size, stream, handle); + } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + // remove this when cuDNN FE supports FP8 + THD + NVTE_CHECK(input_ZInv != nullptr && input_ZInv->data.dptr != nullptr, + "ZInv tensor required for FP8 fused attention backward with T3HD layout."); fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, @@ -2619,5 +2991,4 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou return; } } -#endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 225e700eff..aaf5039eeb 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -12,29 +12,31 @@ #include "transformer_engine/transformer_engine.h" namespace transformer_engine { -#if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V -void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); +void fused_attn_fp8_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_SoftmaxOffset, Tensor *input_output_S, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); -#endif // end of CUDNN>=8900 +void fused_attn_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, + const Tensor *input_S, const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, + const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index a897b09330..f37eeb0c68 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,6 +293,27 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_transpose_dim_idx] = d; + strideA[hidden_transpose_dim_idx] = 1; + } + break; } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 1ec1616c4a..c3736a6c65 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -14,6 +14,7 @@ #include #include +#include "../common.h" #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -27,11 +28,198 @@ enum NVTE_QKV_Matrix { NVTE_K_Matrix = 1, // keys NVTE_K_Matrix_Transpose = 2, // keys transposed NVTE_V_Matrix = 3, // values - NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_V_Matrix_Transpose = 4, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output }; +// Padded sizes for MXFP8 layout (s_q/s_kv/d_qk/d_v and their scaled dimensions) +struct MXFP8PaddedSizes { + int64_t s_q_padded; + int64_t s_kv_padded; + int64_t s_q_scale; + int64_t s_kv_scale; + int64_t s_q_scale_padded; + int64_t s_kv_scale_padded; + int64_t d_qk_padded; + int64_t d_v_padded; + int64_t d_qk_scale; + int64_t d_v_scale; + int64_t d_qk_scale_padded; + int64_t d_v_scale_padded; +}; + +// Pad s and d for MXFP8 quantization +inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { + constexpr int64_t block_size = 32; + MXFP8PaddedSizes p; + p.s_q_padded = DIVUP_TO_MULTIPLE(s_q, 128); + p.s_kv_padded = DIVUP_TO_MULTIPLE(s_kv, 128); + p.s_q_scale = DIVUP(s_q, block_size); + p.s_kv_scale = DIVUP(s_kv, block_size); + p.s_q_scale_padded = DIVUP_TO_MULTIPLE(p.s_q_scale, 4); + p.s_kv_scale_padded = DIVUP_TO_MULTIPLE(p.s_kv_scale, 4); + p.d_qk_padded = DIVUP_TO_MULTIPLE(d_qk, 128); + p.d_v_padded = DIVUP_TO_MULTIPLE(d_v, 128); + p.d_qk_scale = DIVUP(d_qk, block_size); + p.d_v_scale = DIVUP(d_v, block_size); + p.d_qk_scale_padded = DIVUP_TO_MULTIPLE(p.d_qk_scale, 4); + p.d_v_scale_padded = DIVUP_TO_MULTIPLE(p.d_v_scale, 4); + return p; +} + +// Get matrix strides for a 4D tensor [batch_size, num_heads, sequence_len, head_dim] given a QKV format. +// strides must point to at least 4 int64_t elements. +inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strides, NVTE_QKV_Format format) { + constexpr int b_dim = 0; + constexpr int h_dim = 1; + constexpr int s_dim = 2; + constexpr int d_dim = 3; + + switch (format) { + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_THD: + strides[b_dim] = s * h * d; + strides[h_dim] = d; + strides[s_dim] = h * d; + strides[d_dim] = 1; + break; + case NVTE_QKV_Format::NVTE_SBHD: + strides[b_dim] = h * d; + strides[h_dim] = d; + strides[s_dim] = b * h * d; + strides[d_dim] = 1; + break; + case NVTE_QKV_Format::NVTE_BHSD: + strides[b_dim] = h * s * d; + strides[h_dim] = s * d; + strides[s_dim] = d; + strides[d_dim] = 1; + break; + default: + NVTE_CHECK(false, "Invalid format."); + break; + } +} + +// get matrix strides based on layout and matrix type +inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, int64_t s_q, + int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t *q_strides, int64_t *k_strides, + int64_t *v_strides, NVTE_QKV_Layout layout) { + constexpr int b_dim = 0; + constexpr int h_dim = 1; + constexpr int s_dim = 2; + constexpr int d_dim = 3; + const NVTE_QKV_Format q_format = nvte_get_q_format(layout); + const NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + q_strides[b_dim] = 3 * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + q_strides[b_dim] = 3 * h * d_qk; + q_strides[h_dim] = 3 * d_qk; + q_strides[s_dim] = b * 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = 2 * d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + q_strides[b_dim] = s_q * 3 * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + q_strides[b_dim] = s_q * 3 * h * d_qk; + q_strides[h_dim] = 3 * d_qk; + q_strides[s_dim] = 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + k_strides[b_dim] = s_kv * 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + k_strides[b_dim] = s_kv * 2 * hg * d_qk; + k_strides[h_dim] = 2 * d_qk; + k_strides[s_dim] = 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_strides, kv_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_v, v_strides, kv_format); + break; + default: + NVTE_CHECK(false, "Invalid layout."); + break; + } +} + void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); @@ -106,7 +294,12 @@ struct FADescriptor_v1 { float attnScale; bool isTraining; float dropoutProbability; - NVTE_QKV_Layout layout; + NVTE_QKV_Layout qkv_layout; + NVTE_QKV_Format o_format; + NVTE_QKV_Format do_format; + NVTE_QKV_Layout dqkv_layout; + NVTE_QKV_Format qkv_scale_inv_format; + NVTE_QKV_Format do_scale_inv_format; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; NVTE_Softmax_Type softmax_type; @@ -123,17 +316,19 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, - bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, + bias_skv, attnScale, isTraining, dropoutProbability, qkv_layout, o_format, + do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, return_max_logit) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, - rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, - rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.qkv_layout, + rhs.o_format, rhs.do_format, rhs.dqkv_layout, rhs.qkv_scale_inv_format, + rhs.do_scale_inv_format, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, + rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, + rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.return_max_logit); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 8d9adeb620..912dc32d35 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,6 +52,8 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ + NVTE_BHSD_BHSD_BHSD = 25, /*!< BHSD_BHSD_BHSD layout */ + NVTE_QKV_Layout_NOT_SET, /*!< Not set */ }; /*! \enum NVTE_QKV_Layout_Group @@ -70,6 +72,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, + /*! SD_SD_SD QKV layouts, e.g. BHSD_BHSD_BHSD */ + NVTE_SD_SD_SD = 6, }; /*! \enum NVTE_QKV_Format @@ -90,6 +94,10 @@ enum NVTE_QKV_Format { NVTE_THD_2BSHD = 5, /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, + /*! BHSD QKV format, e.g. BHSD_BHSD_BHSD */ + NVTE_BHSD = 7, + /*! Not set */ + NVTE_QKV_Format_NOT_SET, }; /*! \enum NVTE_Bias_Type @@ -274,6 +282,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. + * \param[in] qkv_scale_inv_format Format of scale-inverse tensors for QKV; + * if NVTE_QKV_Format_NOT_SET, inferred from qkv_layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -292,7 +303,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); @@ -347,6 +359,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. + * \param[in] do_format Output gradient's format. + * \param[in] dqkv_layout QKV gradient tensors' layout. + * \param[in] qkv_scale_inv_format Format of scale-inverse tensors for QKV; + * if NVTE_QKV_Format_NOT_SET, inferred from qkv_layout. + * \param[in] do_scale_inv_format Format of scale-inverse tensors for dO; + * if NVTE_QKV_Format_NOT_SET, inferred from the output layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -366,11 +385,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * @@ -584,8 +605,81 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv, cudaStream_t stream); +/*! \brief Transpose multiple tensors from BSHD/SBHD to BHSD. + * + * Each input tensor is 4D in BSHD or SBHD layout, and the corresponding output tensor + * is 4D in BHSD layout. Output tensors are pre-allocated and may have a larger last dimension. + * + * \param[in] inputs List of input tensors. + * \param[in,out] outputs List of output tensors. + * \param[in] num_tensors Number of tensors in the list. + * \param[in] original_format Original QKV format (NVTE_BSHD or NVTE_SBHD). + * \param[in] stream CUDA stream. + */ +void nvte_multi_tensor_transpose_to_bhsd(NVTETensor *inputs, NVTETensor *outputs, + size_t num_tensors, NVTE_QKV_Format original_format, + cudaStream_t stream); + +/*! \brief Pad the last dimension of multiple 2D tensors with zeros in one kernel launch. + * + * Each tensor copies a row-major (rows, in_cols) input to a (rows, out_cols) output, + * zero-filling the region [in_cols, out_cols) in every row. + * Outputs must be pre-allocated with out_cols >= in_cols and matching dtype. + * + * \param[in] inputs List of input tensors. + * \param[in,out] outputs List of output tensors. + * \param[in] num_tensors Number of tensors in the list. + * \param[in] stream CUDA stream. + */ +void nvte_multi_tensor_pad_last_dim(NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" -#endif + +#include +#include +#include + +/*! \brief Parses a QKV tensor shape into canonical (b, h, s, d, t) dimensions + * and converts between QKV formats. + */ +class AttentionShape { + public: + inline AttentionShape(NVTE_QKV_Format fmt, const size_t *shape) : canonical_{} { + auto [ndim, order] = dim_order(fmt); + for (size_t i = 0; i < ndim; ++i) canonical_[order[i]] = shape[i]; + } + + size_t b() const { return canonical_[0]; } + size_t h() const { return canonical_[1]; } + size_t s() const { return canonical_[2]; } + size_t d() const { return canonical_[3]; } + size_t t() const { return canonical_[4]; } + + inline void to_format(NVTE_QKV_Format dst_fmt, size_t *dst_shape) const { + auto [ndim, order] = dim_order(dst_fmt); + for (size_t i = 0; i < ndim; ++i) dst_shape[i] = canonical_[order[i]]; + } + + private: + static inline std::pair> dim_order(NVTE_QKV_Format fmt) { + switch (fmt) { + case NVTE_QKV_Format::NVTE_BSHD: + return {4, {0, 2, 1, 3}}; // b s h d + case NVTE_QKV_Format::NVTE_SBHD: + return {4, {2, 0, 1, 3}}; // s b h d + case NVTE_QKV_Format::NVTE_BHSD: + return {4, {0, 1, 2, 3}}; // b h s d + case NVTE_QKV_Format::NVTE_THD: + return {3, {4, 1, 3, -1}}; // t h d + default: + return {0, {}}; + } + } + size_t canonical_[5] = {}; +}; + +#endif // __cplusplus #endif diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 4e28de3beb..396093b543 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -32,10 +32,10 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud /*! \brief Swizzling scaling factors into the required interleaved layout for GEMM * - * \param[in] inputs Input tensors with non-swizzled scale_inv. - * \param[in,out] outputs Output tensors which hosts swizzled scale_inv. - * \param[in] num_tensors Number of input and output tensors. - * \param[in] stream CUDA stream used for the operation. + * \param[in] inputs Input tensors with non-swizzled scale_inv. + * \param[in,out] outputs Output tensors which hosts swizzled scale_inv. + * \param[in] num_tensors Number of input and output tensors. + * \param[in] stream CUDA stream used for the operation. * * Requirements: * - scale_inv is stored in row-major. @@ -45,6 +45,17 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Same as nvte_multi_tensor_swizzle_scaling_factors, but skips + * scale_inv shape/padding validation. + * + * Use this variant when the data and scale_inv tensors intentionally have + * different shapes, e.g. when scale_invs have been transposed for attention. + */ +void nvte_multi_tensor_swizzle_scaling_factors_unchecked(const NVTETensor* inputs, + NVTETensor* outputs, + const size_t num_tensors, + cudaStream_t stream); + /*! \brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major * * \param[in] input Input tensor with swizzled scale_inv. diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 6c59776245..ad4a130928 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -21,6 +21,17 @@ namespace { constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int NVFP4_BLOCK_SIZE = 16; +int get_max_dynamic_smem() { + static int max_smem = -1; + if (max_smem < 0) { + int device; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + NVTE_CHECK_CUDA( + cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + } + return max_smem; +} + constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; @@ -80,7 +91,11 @@ __device__ inline void regs_unshuffle_with_bit_shifts(LType* regs_vec) { for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; } -template +// IS_PADDED_K / IS_PADDED_M select the boundary-block specialization at compile +// time so the inner load loop avoids the per-iteration runtime checks. The +// caller computes the runtime predicates from blockIdx/gridDim once per block +// (uniform across the block) and dispatches to the right specialization. +template __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, const int bid_x, @@ -106,9 +121,6 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; } - bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); - bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); - const int input_offset = bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; const int32_t* input_i32 = reinterpret_cast(input) + input_offset; @@ -121,19 +133,37 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, extern __shared__ int slm[]; // load, global -> regs + // Each register read for a given i is along the M direction at K-coord + // (bid_x * TB_DIM * SF_TILE_DIM_K + threadIdx.y * SF_TILE_DIM_K + i). When that + // K-coord is past original_K, the entire register is out of the per-tensor data + // region (which may be the unpadded compact extent), so we must NOT issue the + // __ldg there -- it could read past the per-tensor buffer (and, for the last + // tensor in a grouped allocation, past the end of the allocation entirely). LType regs_vec[N_SF_PER_TD_PER_TILE]; if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && threadIdx.y < k_tiles_in_tb) { + const int k_base = bid_x * TB_DIM * SF_TILE_DIM_K + threadIdx.y * SF_TILE_DIM_K; #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { const int thread_offset = (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; + const int k_coord = k_base + i; + if constexpr (IS_PADDED_K) { + if (k_coord >= original_K) { + // Entire register is past original_K: zero directly without loading. + uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); +#pragma unroll + for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; + continue; + } + } regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); - // Pad zeros - if (padding_m || padding_k) { + // Per-byte M masking is still needed when only part of the register is past + // original_M (i.e. K-coord is in range but the M position spans the boundary). + if constexpr (IS_PADDED_M) { for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (input_offset + thread_offset) * sizeof(int) + j; - if (index / M >= original_K || index % M >= original_M) { + if (index % M >= original_M) { reinterpret_cast(regs_vec + i)[j] = 0; } } @@ -172,12 +202,43 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, } } +// Dispatch helper: pick the right (IS_PADDED_K, IS_PADDED_M) col-scaling impl +// specialization at runtime based on the per-block padding predicates. The +// branching here is uniform across all threads in the block, so the indirect +// path each block takes still inlines cleanly. +template +__device__ __forceinline__ void dispatch_swizzle_col_scaling_kernel_impl( + const void* input, void* output, const int M, const int K, const int original_M, + const int original_K, const int bid_x, const int bid_y, const int grid_dim_x, + const int grid_dim_y, const bool padding_k, const bool padding_m) { + if (padding_k && padding_m) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else if (padding_k) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else if (padding_m) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } +} + template __global__ void __launch_bounds__(TB_DIM* TB_DIM) swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K, const int original_M, const int original_K) { - swizzle_col_scaling_kernel_impl( - input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y, + padding_k, padding_m); } template @@ -213,7 +274,11 @@ __device__ inline void regs_unshuffle(LType* regs_vec) { for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; } -template +// IS_PADDED_K / IS_PADDED_M select the boundary-block specialization at compile +// time so the inner load loop avoids the per-iteration runtime checks. The +// caller computes the runtime predicates from blockIdx/gridDim once per block +// (uniform across the block) and dispatches to the right specialization. +template __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, const int bid_x, @@ -232,9 +297,6 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; } - bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); - bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); - const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; const int* input_i32 = reinterpret_cast(input) + input_offset; int* output_i32 = reinterpret_cast(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 + @@ -243,17 +305,35 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, extern __shared__ int4 slm_v4i[]; // load, global -> regs + // Each register read for a given i is along the K direction at row + // (bid_y * SF_TILE_DIM_M + i * TB_DIM + threadIdx.y). When that row is past + // original_M, the entire register is out of the per-tensor data region (which + // may be the unpadded compact extent), so we must NOT issue the __ldg there -- + // it could read past the per-tensor buffer (and, for the last tensor in a + // grouped allocation, past the end of the allocation entirely). LType regs_vec[N_SF_PER_TD_PER_TILE]; if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int row = bid_y * SF_TILE_DIM_M + i * TB_DIM + threadIdx.y; const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + if constexpr (IS_PADDED_M) { + if (row >= original_M) { + // Entire register is past original_M: zero directly without loading. + uint8_t* zero_bytes = reinterpret_cast(regs_vec + i); +#pragma unroll + for (int j = 0; j < static_cast(sizeof(LType)); j++) zero_bytes[j] = 0; + continue; + } + } regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); - if (padding_m || padding_k) { - // Pad zeros + // Per-byte K masking is still needed when only part of the register is past + // original_K (i.e. row is in range but the K position spans the boundary). + if constexpr (IS_PADDED_K) { +#pragma unroll for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (input_offset + thread_offset) * sizeof(int) + j; - if (index / K >= original_M || index % K >= original_K) { + if (index % K >= original_K) { reinterpret_cast(regs_vec + i)[j] = 0; } } @@ -282,6 +362,202 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, } } +// Dispatch helper: pick the right (IS_PADDED_K, IS_PADDED_M) row-scaling impl +// specialization at runtime based on the per-block padding predicates. The +// branching here is uniform across all threads in the block, so the indirect +// path each block takes still inlines cleanly. +template +__device__ __forceinline__ void dispatch_swizzle_row_scaling_kernel_impl( + const void* input, void* output, const int M, const int K, const int original_M, + const int original_K, const int bid_x, const int bid_y, const int grid_dim_x, + const int grid_dim_y, const bool padding_k, const bool padding_m) { + if (padding_k && padding_m) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else if (padding_k) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else if (padding_m) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y, + padding_k, padding_m); +} + +// Narrow-K specialization for row scaling swizzle. +// When K is small (num_tiles_k < TB_DIM), the standard kernel wastes threadIdx.x +// because there aren't enough K-tiles to distribute across threads. +// This kernel repurposes the thread dimensions: threadIdx.x iterates rows within +// an M-tile, threadIdx.y indexes M-tiles within the block, processing TB_DIM +// M-tiles per block with full thread utilization. +template +__device__ void swizzle_row_scaling_narrow_k_kernel_impl(const void* input, void* output, + const int M, const int K, + const int original_M, const int original_K, + const int bid, const int grid_dim) { + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + const int K_i32 = K / 4; + const int num_tiles_m = M / SF_TILE_DIM_M; + + const int m_tile = bid * blockDim.y + threadIdx.y; + const bool active = (m_tile < num_tiles_m); + + extern __shared__ int4 slm_v4i[]; + const int slm_tile_v4i = K_i32 * (SF_TILE_SIZE_I32 / 4); + + if (active) { + const bool padding_m = (m_tile == num_tiles_m - 1) && (original_M < M); + const bool padding_k = (original_K < K); + + int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; + + for (int k = 0; k < K_i32; k++) { + const int input_base = m_tile * SF_TILE_DIM_M * K_i32 + k; + const int* input_i32 = reinterpret_cast(input) + input_base; + + int regs[N_SF_PER_TD_PER_TILE]; +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int row = i * TB_DIM + threadIdx.x; + regs[i] = __ldg(input_i32 + row * K_i32); + if (padding_m || padding_k) { + for (int j = 0; j < 4; j++) { + const int byte_row = m_tile * SF_TILE_DIM_M + row; + const int byte_col = k * 4 + j; + if (byte_row >= original_M || byte_col >= original_K) { + reinterpret_cast(®s[i])[j] = 0; + } + } + } + } + + my_slm[k * (SF_TILE_SIZE_I32 / 4) + threadIdx.x] = *reinterpret_cast(regs); + } + } + + __syncthreads(); + + if (active) { + int4* my_slm = slm_v4i + threadIdx.y * slm_tile_v4i; + int4* out_v4i = + reinterpret_cast(reinterpret_cast(output) + m_tile * SF_TILE_DIM_M * K_i32); + + for (int i = threadIdx.x; i < slm_tile_v4i; i += blockDim.x) { + out_v4i[i] = my_slm[i]; + } + } +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_row_scaling_narrow_k_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + swizzle_row_scaling_narrow_k_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, gridDim.x); +} + +// Narrow-M variant of the column scaling swizzle kernel, for when num_tiles_m < TB_DIM. +// Analogous to the narrow-K row kernel: when the M dimension is small, the normal +// col kernel underutilizes threads in the load phase because threadIdx.x covers M +// positions with vectorized loads, leaving many threads idle. This kernel repurposes +// thread dimensions: threadIdx.y indexes K-tiles within the block, threadIdx.x covers +// one int32 column of an M-tile, and M-tiles are iterated serially. +template +__device__ void swizzle_col_scaling_narrow_m_kernel_impl(const void* input, void* output, + const int M, const int K, + const int original_M, const int original_K, + const int bid, const int grid_dim) { + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4; + constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K; + + const int M_i32 = M / 4; + const int K_i32 = K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int num_tiles_k = K / SF_TILE_DIM_K; + + const int k_tile = bid * blockDim.y + threadIdx.y; + const bool active = (k_tile < num_tiles_k); + const int remaining = num_tiles_k - bid * static_cast(blockDim.y); + const int k_tiles_in_block = remaining <= 0 ? 0 : (remaining < TB_DIM ? remaining : TB_DIM); + + extern __shared__ int slm_narrow_m[]; + + if (active) { + const bool padding_k = (k_tile == num_tiles_k - 1) && (original_K < K); + const int32_t* input_i32 = reinterpret_cast(input); + + for (int m_tile = 0; m_tile < num_tiles_m; m_tile++) { + const bool padding_m = (m_tile == num_tiles_m - 1) && (original_M < M); + + int regs[N_SF_PER_TD_PER_TILE]; +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int k_row = k_tile * SF_TILE_DIM_K_I32 + i; + const int m_col = m_tile * SF_TILE_DIM_M_I32 + threadIdx.x; + regs[i] = __ldg(input_i32 + k_row * M_i32 + m_col); + if (padding_m || padding_k) { + for (int j = 0; j < 4; j++) { + if (m_col * 4 + j >= original_M || k_row >= original_K) { + reinterpret_cast(®s[i])[j] = 0; + } + } + } + } + + regs_shuffle_with_bit_shifts(regs); + + int tM = threadIdx.x * N_SF_PER_TD_PER_TILE; + int* slm_tile = + slm_narrow_m + m_tile * TB_DIM * SF_TILE_SIZE_I32 + threadIdx.y * SF_TILE_SIZE_I32; +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 + + ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] = regs[i]; + } + } + } + + __syncthreads(); + + const int linear_id = threadIdx.y * blockDim.x + threadIdx.x; + for (int m_tile = 0; m_tile < num_tiles_m; m_tile++) { + int4* out_v4i = reinterpret_cast(reinterpret_cast(output) + + m_tile * SF_TILE_DIM_M_I32 * K_i32 + + bid * TB_DIM * SF_TILE_SIZE_I32); + int4* slm_v4i = reinterpret_cast(slm_narrow_m + m_tile * TB_DIM * SF_TILE_SIZE_I32); + const int n_v4i = k_tiles_in_block * SF_TILE_SIZE_I32 / 4; + for (int j = linear_id; j < n_v4i; j += blockDim.x * blockDim.y) { + out_v4i[j] = slm_v4i[j]; + } + } +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_col_scaling_narrow_m_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + swizzle_col_scaling_narrow_m_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, gridDim.x); +} + template __device__ void unswizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int bid_x, const int bid_y, @@ -422,14 +698,6 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) } } -template -__global__ void __launch_bounds__(TB_DIM* TB_DIM) - swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, - const int original_M, const int original_K) { - swizzle_row_scaling_kernel_impl( - input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); -} - constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB struct MultiSwizzleArgs { // (input) Data buffers for input scaling factors @@ -460,14 +728,21 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) grouped_swizzle_row_scaling_uniform_shape_kernel(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, - const size_t scale_stride_bytes) { + const size_t input_stride_bytes, + const size_t output_stride_bytes) { const int tensor_id = blockIdx.z; + // Input and output strides may differ: input is in the kernel-produced "compact" + // layout (per-tensor stride = original_M * padded_k * elem_size) when callers + // pass the unswizzled grouped scale buffer as-is, while the output is always in + // the per-tensor padded ("swizzle-ready") layout (padded_m * padded_k * elem_size). const uint8_t* input_base = - reinterpret_cast(input) + tensor_id * scale_stride_bytes; - uint8_t* output_base = reinterpret_cast(output) + tensor_id * scale_stride_bytes; - swizzle_row_scaling_kernel_impl( + reinterpret_cast(input) + tensor_id * input_stride_bytes; + uint8_t* output_base = reinterpret_cast(output) + tensor_id * output_stride_bytes; + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_row_scaling_kernel_impl( input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, - gridDim.y); + gridDim.y, padding_k, padding_m); } template @@ -475,14 +750,20 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) grouped_swizzle_col_scaling_uniform_shape_kernel(const void* input, void* output, const int M, const int K, const int original_M, const int original_K, - const size_t scale_stride_bytes) { + const size_t input_stride_bytes, + const size_t output_stride_bytes) { const int tensor_id = blockIdx.z; + // See the rowwise kernel for stride semantics. For columnwise the per-tensor + // compact stride is DIVUP(original_K, 1) * padded_m * elem_size (i.e. the + // unpadded scale-row count in the K direction times the padded M extent). const uint8_t* input_base = - reinterpret_cast(input) + tensor_id * scale_stride_bytes; - uint8_t* output_base = reinterpret_cast(output) + tensor_id * scale_stride_bytes; - swizzle_col_scaling_kernel_impl( + reinterpret_cast(input) + tensor_id * input_stride_bytes; + uint8_t* output_base = reinterpret_cast(output) + tensor_id * output_stride_bytes; + const bool padding_m = (blockIdx.y == gridDim.y - 1) && (original_M < M); + const bool padding_k = (blockIdx.x == gridDim.x - 1) && (original_K < K); + dispatch_swizzle_col_scaling_kernel_impl( input_base, output_base, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, - gridDim.y); + gridDim.y, padding_k, padding_m); } template @@ -583,8 +864,11 @@ __global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_ const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; - swizzle_row_scaling_kernel_impl( - input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + const bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + const bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + dispatch_swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y, padding_k, + padding_m); } template @@ -613,8 +897,55 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; - swizzle_col_scaling_kernel_impl( - input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + const bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + const bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + dispatch_swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y, padding_k, + padding_m); +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + multi_tensor_swizzle_row_scaling_narrow_k_kernel(MultiSwizzleArgs kernel_args) { + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + const int flat_bid = bid - kernel_args.block_range[tensor_id]; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim = DIVUP(num_tiles_m, TB_DIM); + + swizzle_row_scaling_narrow_k_kernel_impl( + input, output, M, K, original_M, original_K, flat_bid, grid_dim); +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + multi_tensor_swizzle_col_scaling_narrow_m_kernel(MultiSwizzleArgs kernel_args) { + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + const int flat_bid = bid - kernel_args.block_range[tensor_id]; + const int num_tiles_k = K / SF_TILE_DIM_K; + const int grid_dim = DIVUP(num_tiles_k, TB_DIM); + + swizzle_col_scaling_narrow_m_kernel_impl( + input, output, M, K, original_M, original_K, flat_bid, grid_dim); } } // namespace @@ -737,13 +1068,6 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s // Perform row-wise swizzle if (rowwise_swizzle) { - int vec_load_size = (num_tiles_k - 1) % 4 + 1; - /* there is no int3 and misaligned if using int4/int2 */ - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - int original_M{0}, original_K{0}; void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr}; switch (scaling_mode) { @@ -772,79 +1096,114 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s NVTE_ERROR("Invalid scaling mode"); } - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); - break; - case 2: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); - break; - case 1: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; + const int narrow_k_slm_size = + TB_DIM * num_tiles_k * SF_TILE_DIM_M * SF_TILE_DIM_K * static_cast(sizeof(int8_t)); + if (num_tiles_k < TB_DIM && narrow_k_slm_size <= get_max_dynamic_smem()) { + // Narrow-K: batch TB_DIM M-tiles per block, fully utilizing all threads. + dim3 num_blocks_narrow(DIVUP(num_tiles_m, TB_DIM)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_narrow_k_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, narrow_k_slm_size)); + swizzle_row_scaling_narrow_k_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + } else { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } } NVTE_CHECK_CUDA(cudaGetLastError()); } // Perform column-wise swizzle if (columnwise_swizzle) { - int vec_load_size = (num_tiles_m - 1) % 4 + 1; - if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); const int original_M = input->flat_last_dim(); const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, k, - original_M, original_K); - break; - case 2: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, k, - original_M, original_K); - break; - case 1: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, k, - original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; + const int narrow_m_slm_size = + TB_DIM * num_tiles_m * SF_TILE_DIM_M * SF_TILE_DIM_K * static_cast(sizeof(int8_t)); + if (num_tiles_m < TB_DIM && narrow_m_slm_size <= get_max_dynamic_smem()) { + // Narrow-M: batch TB_DIM K-tiles per block, fully utilizing all threads. + dim3 num_blocks_narrow(DIVUP(num_tiles_k, TB_DIM)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_narrow_m_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, narrow_m_slm_size)); + swizzle_col_scaling_narrow_m_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k, original_M, + original_K); + } else { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -853,83 +1212,138 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s template void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, const int vec_load_size, const bool is_rowwise, + const bool use_narrow_k, const bool use_narrow_m, cudaStream_t stream) { - int n_tiles_in_tb = TB_DIM * vec_load_size; - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - /* Calculate number of CUDA blocks needed for each tensor. - * We have to do it here because we have to iterate over all tensors in this batch to - * get the minimum vec_load_size. - */ - for (size_t j = 0; j < kernel_args.num_tensors; j++) { - const int m = kernel_args.m_list[j]; - const int k = kernel_args.k_list[j]; - int num_tiles_m = m / SF_TILE_DIM_M; - int num_tiles_k = k / SF_TILE_DIM_K; - if (is_rowwise) { - kernel_args.block_range[j + 1] = - kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; - } else { - kernel_args.block_range[j + 1] = - kernel_args.block_range[j] + - DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + // cudaFuncSetAttribute is a host-synchronous driver call; cache the max shared memory + // setting per kernel variant so we only pay the cost when slm_size actually increases. + auto set_smem_if_needed = [](auto kernel_fn, int slm, int& cached) { + if (cached < slm) { + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, slm)); + cached = slm; } - } - // Launch kernel - const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + }; + dim3 block_size(TB_DIM, TB_DIM); - if (is_rowwise) { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_row_scaling_kernel - <<>>(kernel_args); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_row_scaling_kernel - <<>>(kernel_args); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_row_scaling_kernel - <<>>(kernel_args); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; + + if (is_rowwise && use_narrow_k) { + // Narrow-K path: each block handles TB_DIM M-tiles with full thread utilization. + // slm_size depends on num_tiles_k, which can vary per tensor — use the max. + int max_num_tiles_k = 0; + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int num_tiles_m = kernel_args.m_list[j] / SF_TILE_DIM_M; + const int num_tiles_k = kernel_args.k_list[j] / SF_TILE_DIM_K; + max_num_tiles_k = std::max(max_num_tiles_k, num_tiles_k); + kernel_args.block_range[j + 1] = kernel_args.block_range[j] + DIVUP(num_tiles_m, TB_DIM); } + int slm_size = TB_DIM * max_num_tiles_k * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + + static int cached_narrow_k = -1; + set_smem_if_needed( + multi_tensor_swizzle_row_scaling_narrow_k_kernel, slm_size, + cached_narrow_k); + multi_tensor_swizzle_row_scaling_narrow_k_kernel + <<>>(kernel_args); + } else if (!is_rowwise && use_narrow_m) { + // Narrow-M path: each block handles TB_DIM K-tiles with full thread utilization. + // slm_size depends on num_tiles_m, which can vary per tensor — use the max. + int max_num_tiles_m = 0; + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int num_tiles_m = kernel_args.m_list[j] / SF_TILE_DIM_M; + const int num_tiles_k = kernel_args.k_list[j] / SF_TILE_DIM_K; + max_num_tiles_m = std::max(max_num_tiles_m, num_tiles_m); + kernel_args.block_range[j + 1] = kernel_args.block_range[j] + DIVUP(num_tiles_k, TB_DIM); + } + int slm_size = TB_DIM * max_num_tiles_m * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + + static int cached_narrow_m = -1; + set_smem_if_needed( + multi_tensor_swizzle_col_scaling_narrow_m_kernel, slm_size, + cached_narrow_m); + multi_tensor_swizzle_col_scaling_narrow_m_kernel + <<>>(kernel_args); } else { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_col_scaling_kernel - <<>>(kernel_args); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_col_scaling_kernel - <<>>(kernel_args); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - multi_tensor_swizzle_col_scaling_kernel - <<>>(kernel_args); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; + int n_tiles_in_tb = TB_DIM * vec_load_size; + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + /* Calculate number of CUDA blocks needed for each tensor. + * We have to do it here because we have to iterate over all tensors in this batch to + * get the minimum vec_load_size. + */ + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int m = kernel_args.m_list[j]; + const int k = kernel_args.k_list[j]; + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + if (is_rowwise) { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; + } else { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + + DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + } + } + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + + static int cached_row_int4 = -1, cached_row_int2 = -1, cached_row_int1 = -1; + static int cached_col_int4 = -1, cached_col_int2 = -1, cached_col_int1 = -1; + + if (is_rowwise) { + switch (vec_load_size) { + case 4: + set_smem_if_needed( + multi_tensor_swizzle_row_scaling_kernel, slm_size, + cached_row_int4); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 2: + set_smem_if_needed( + multi_tensor_swizzle_row_scaling_kernel, slm_size, + cached_row_int2); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 1: + set_smem_if_needed( + multi_tensor_swizzle_row_scaling_kernel, slm_size, + cached_row_int1); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } else { + switch (vec_load_size) { + case 4: + set_smem_if_needed( + multi_tensor_swizzle_col_scaling_kernel, slm_size, + cached_col_int4); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 2: + set_smem_if_needed( + multi_tensor_swizzle_col_scaling_kernel, slm_size, + cached_col_int2); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 1: + set_smem_if_needed( + multi_tensor_swizzle_col_scaling_kernel, slm_size, + cached_col_int1); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -1019,7 +1433,8 @@ void launch_multi_tensor_unswizzle_scaling_factors(MultiSwizzleArgs& kernel_args } void multi_tensor_swizzle_scaling_factors(const std::vector& input, - std::vector& output, cudaStream_t stream) { + std::vector& output, cudaStream_t stream, + bool check_scale_inv_shapes) { auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; @@ -1038,8 +1453,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, // We don't allow empty tensors. They should be filtered out before calling this function. NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty."); - CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); - CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); + CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]", + check_scale_inv_shapes); + CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]", + check_scale_inv_shapes); all_has_data = all_has_data && input[i]->scale_inv.has_data(); all_has_columnwise_data = (all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data()); @@ -1060,16 +1477,18 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; int vec_load_size = 4; + bool all_narrow_k = true; for (size_t i = 0; i < num_tensors; i++) { //Launch kernel if argument struct is full if (kernel_args.num_tensors == kMaxTensorsPerKernel) { // There is no int3 and misaligned if using int4/int2. if (vec_load_size == 3) vec_load_size = 1; launch_multi_tensor_swizzle_scaling_factors( - kernel_args, vec_load_size, true, stream); + kernel_args, vec_load_size, true, all_narrow_k, false, stream); // Reset the argument struct and vec_load_size kernel_args.num_tensors = 0; vec_load_size = 4; + all_narrow_k = true; } int m, k; @@ -1103,6 +1522,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, } int num_tiles_k = k / SF_TILE_DIM_K; + const int narrow_k_slm = + TB_DIM * num_tiles_k * SF_TILE_DIM_M * SF_TILE_DIM_K * static_cast(sizeof(int8_t)); + all_narrow_k = + all_narrow_k && (num_tiles_k < TB_DIM) && (narrow_k_slm <= get_max_dynamic_smem()); int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; // We use the minimum vec_load_size across all tensors. // TODO(zhongbo): fix vec_load_size for NVFP4 @@ -1132,7 +1555,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, // There is no int3 and misaligned if using int4/int2. if (vec_load_size == 3) vec_load_size = 1; launch_multi_tensor_swizzle_scaling_factors( - kernel_args, vec_load_size, true, stream); + kernel_args, vec_load_size, true, all_narrow_k, false, stream); } if (columnwise_swizzle) { @@ -1143,16 +1566,18 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; int vec_load_size = 4; + bool all_narrow_m = true; for (size_t i = 0; i < num_tensors; i++) { //Launch kernel if argument struct is full if (kernel_args.num_tensors == kMaxTensorsPerKernel) { // There is no int3 and misaligned if using int4/int2. if (vec_load_size == 3) vec_load_size = 1; launch_multi_tensor_swizzle_scaling_factors( - kernel_args, vec_load_size, false, stream); + kernel_args, vec_load_size, false, false, all_narrow_m, stream); // Reset the argument struct and vec_load_size kernel_args.num_tensors = 0; vec_load_size = 4; + all_narrow_m = true; } const int m = input[i]->columnwise_scale_inv.shape[1]; const int k = input[i]->columnwise_scale_inv.shape[0]; @@ -1166,7 +1591,12 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, "Input.columnwise_scale_inv size is not equal to " "Output.columnwise_scale_inv size!"); + int num_tiles_m = m / SF_TILE_DIM_M; int num_tiles_k = k / SF_TILE_DIM_K; + const int narrow_m_slm = + TB_DIM * num_tiles_m * SF_TILE_DIM_M * SF_TILE_DIM_K * static_cast(sizeof(int8_t)); + all_narrow_m = + all_narrow_m && (num_tiles_m < TB_DIM) && (narrow_m_slm <= get_max_dynamic_smem()); int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; // We use the minimum vec_load_size across all tensors. vec_load_size = std::min(vec_load_size, vec_load_size_i); @@ -1184,7 +1614,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, // There is no int3 and misaligned if using int4/int2. if (vec_load_size == 3) vec_load_size = 1; launch_multi_tensor_swizzle_scaling_factors( - kernel_args, vec_load_size, false, stream); + kernel_args, vec_load_size, false, false, all_narrow_m, stream); } } @@ -1529,7 +1959,24 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen input_list.push_back(convertNVTETensorCheck(inputs[i])); output_list.push_back(convertNVTETensorCheck(outputs[i])); } - multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); + multi_tensor_swizzle_scaling_factors(input_list, output_list, stream, + /*check_scale_inv_shapes=*/true); +} + +void nvte_multi_tensor_swizzle_scaling_factors_unchecked(const NVTETensor* inputs, + NVTETensor* outputs, + const size_t num_tensors, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors_unchecked); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + std::vector input_list, output_list; + for (size_t i = 0; i < num_tensors; i++) { + input_list.push_back(convertNVTETensorCheck(inputs[i])); + output_list.push_back(convertNVTETensorCheck(outputs[i])); + } + multi_tensor_swizzle_scaling_factors(input_list, output_list, stream, + /*check_scale_inv_shapes=*/false); } void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, @@ -1596,23 +2043,56 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* const size_t padded_m = round_up_to_multiple(m, 128); const size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); - const size_t scale_elems = padded_m * padded_k; + // Per-tensor scale-element counts: + // - "padded" layout: each tensor occupies padded_m * padded_k elements + // (total buffer = num_tensors * padded_m * padded_k). + // - "compact" layout (what the grouped MXFP8 quantize kernel actually writes): + // per-tensor stride is m * padded_k (rowwise) or DIVUP(k,32) * padded_m + // (columnwise) and the total buffer the C++ allocator hands out has its + // grouped first dim padded up to a multiple of 128 (rowwise) or 4 + // (columnwise) -- so the buffer may be slightly larger than + // num_tensors * compact_scale_elems, with trailing alignment slack at + // the very end (never read because of the per-tensor row/k guard in the + // kernel impl). + // The output is always written in the padded layout. The input may be in + // either layout; the kernel handles the compact case safely by using + // different per-tensor strides for input vs output and skipping loads past + // the per-tensor extent. + const size_t padded_scale_elems = padded_m * padded_k; + const size_t compact_scale_elems = + rowwise ? m * padded_k : DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)) * padded_m; + const size_t compact_total_scale_elems = + rowwise ? round_up_to_multiple(input->num_tensors * m, 128) * padded_k + : round_up_to_multiple( + input->num_tensors * DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4) * + padded_m; const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) : typeToSize(input->columnwise_scale_inv.dtype); - const size_t scale_stride_bytes = scale_elems * scale_elem_size; - if (rowwise) { - NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems, - "Grouped input scale_inv size does not match expected packed size."); - NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems, - "Grouped output scale_inv size does not match expected packed size."); + const size_t input_scale_numel = + rowwise ? input->scale_inv.numel() : input->columnwise_scale_inv.numel(); + const size_t output_scale_numel = + rowwise ? output->scale_inv.numel() : output->columnwise_scale_inv.numel(); + + bool input_is_compact; + if (input_scale_numel == input->num_tensors * padded_scale_elems) { + input_is_compact = false; + } else if (input_scale_numel == compact_total_scale_elems) { + input_is_compact = true; } else { - NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems, - "Grouped input columnwise_scale_inv size does not match expected packed size."); - NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems, - "Grouped output columnwise_scale_inv size does not match expected packed size."); + NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"), + " size does not match expected packed size (got ", input_scale_numel, + ", expected either ", input->num_tensors * padded_scale_elems, + " (per-tensor padded) or ", compact_total_scale_elems, " (compact))."); } + NVTE_CHECK(output_scale_numel == input->num_tensors * padded_scale_elems, "Grouped output ", + (rowwise ? "scale_inv" : "columnwise_scale_inv"), + " size does not match expected per-tensor padded size."); + + const size_t input_stride_bytes = + (input_is_compact ? compact_scale_elems : padded_scale_elems) * scale_elem_size; + const size_t output_stride_bytes = padded_scale_elems * scale_elem_size; const int num_tiles_m = padded_m / SF_TILE_DIM_M; const int num_tiles_k = padded_k / SF_TILE_DIM_K; @@ -1635,69 +2115,25 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr; if (rowwise) { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - } + TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_row_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_row_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); + }); } else { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - } + TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_col_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_col_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + input_stride_bytes, output_stride_bytes); + }); } NVTE_CHECK_CUDA(cudaGetLastError()); }; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index eacd10eb30..1261879a8b 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -120,7 +120,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, - "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", t.columnwise_scale_inv.shape, ")"); } } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) { @@ -144,7 +144,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { } } -void CheckInputTensor(const Tensor &t, const std::string &name) { +void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_shapes) { const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 input needs to have scale_inv @@ -195,7 +195,9 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { } NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!"); - CheckScaleTensorShape(t, name); + if (check_scale_inv_shapes) { + CheckScaleTensorShape(t, name); + } } void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py new file mode 100644 index 0000000000..965bb437ff --- /dev/null +++ b/transformer_engine/common/triton/mhc.py @@ -0,0 +1,1693 @@ +# pylint: disable=missing-function-docstring + +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""triton kernels for mHC (manifold Hyper-Connection) operations""" + +import itertools +import os + +import triton +import triton.language as tl + + +def projection_config_fwd(): + block_m = [64, 128] + block_k = [1024] + step_k = [32, 64] + warps = [4] + stages = [3, 4] + + configs = [] + for m, bk, sk, w, s in itertools.product(block_m, block_k, step_k, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk, "STEP_SIZE_K": sk}, + num_warps=w, + num_stages=s, + ) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +def projection_config_bwd(): + block_m = [32, 128] + block_k = [128] + warps = [2] + stages = [2, 3, 4] + + configs = [] + for m, bk, w, s in itertools.product(block_m, block_k, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk}, num_warps=w, num_stages=s) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune(configs=projection_config_fwd(), key=["M", "K"], reset_to_zero=["h_ptr", "ms_ptr"]) +@triton.jit +def _mhc_projection_fwd_fused( + x_ptr, # (M, K) + phi_ptr, # (N, K) + h_ptr, # (M, 32) + ms_ptr, # (M,) + M, + N, + K, + stride_xm, + stride_xk: tl.constexpr, + stride_phin, + stride_phik: tl.constexpr, + stride_hm: tl.constexpr, + stride_hn: tl.constexpr, + stride_ms: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + STEP_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + precision: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + tl.assume(pid_m >= 0) + tl.assume(pid_k >= 0) + tl.assume(stride_xm > 0) + tl.assume(stride_xk == 1) + tl.assume(stride_phin == K) + tl.assume(stride_phik == 1) + tl.assume(stride_hm == 32) + tl.assume(stride_hn == 1) + tl.assume(stride_ms == 1) + + tl.assume(BLOCK_SIZE_M % 32 == 0) + tl.assume(BLOCK_SIZE_K % 32 == 0) + tl.assume(BLOCK_SIZE_N == 32) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_full = tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m < M + + h_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + ms_acc = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + k_base = pid_k * BLOCK_SIZE_K + for k_start in range(0, tl.cdiv(BLOCK_SIZE_K, STEP_SIZE_K)): + k_offs = k_base + k_start * STEP_SIZE_K + tl.arange(0, STEP_SIZE_K) + mask_k = k_offs < K + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + k_offs[None, :] * stride_phik + phi = tl.load( + phi_ptrs, + mask=(offs_n_full[:, None] < N) & mask_k[None, :], + other=0.0, + cache_modifier=".ca", + ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + ms_acc += tl.sum(x * x, axis=1) + h_acc = tl.dot( + x, tl.trans(phi, (1, 0)), h_acc, input_precision=precision, out_dtype=tl.float32 + ) + + h_ptrs = h_ptr + offs_m[:, None] * stride_hm + offs_n_full[None, :] * stride_hn + tl.atomic_add(h_ptrs, h_acc, mask=mask_m[:, None], sem="relaxed") + + offs_ms = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_ms = offs_ms < M + offs_ms %= M + ms_ptrs = ms_ptr + offs_ms * stride_ms + ms = ms_acc / tl.cast(K, tl.float32) + tl.atomic_add(ms_ptrs, ms, mask=masks_ms, sem="relaxed") + + +@triton.autotune( + configs=projection_config_bwd(), + key=["M", "K"], +) +@triton.jit +def _mhc_projection_bwd_fused( + x_ptr, + grad_x_ptr, # (M, K) + phi_ptr, # (N, K) + grad_h_ptr, # (M, N) + grad_ms_ptr, # (M,) + M, + N, + K, + stride_xm, + stride_xk: tl.constexpr, + stride_grad_xm, + stride_grad_xk: tl.constexpr, + stride_phin, + stride_phik: tl.constexpr, + stride_grad_phin, + stride_grad_phik: tl.constexpr, + stride_grad_hm: tl.constexpr, + stride_grad_hn: tl.constexpr, + stride_grad_ms: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + precision: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + tl.assume(pid_m >= 0) + tl.assume(pid_k >= 0) + tl.assume(stride_xm > 0) + tl.assume(stride_xk == 1) + tl.assume(stride_grad_hm == 32) + tl.assume(stride_grad_hn == 1) + tl.assume(stride_phin == K) + tl.assume(stride_phik == 1) + tl.assume(stride_grad_phin == K) + tl.assume(stride_grad_phik == 1) + tl.assume(stride_grad_ms == 1) + + tl.assume(BLOCK_SIZE_M % 32 == 0) + tl.assume(BLOCK_SIZE_K % 32 == 0) + tl.assume(BLOCK_SIZE_N == 32) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n_full = tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m < M + mask_k = offs_k < K + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + grad_h_ptrs = ( + grad_h_ptr + offs_m[:, None] * stride_grad_hm + offs_n_full[None, :] * stride_grad_hn + ) + grad_h = tl.load( + grad_h_ptrs, mask=mask_m[:, None] & (offs_n_full[None, :] < N), other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + + phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + offs_k[None, :] * stride_phik + offs_ms = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + grad_ms_ptrs = grad_ms_ptr + offs_ms * stride_grad_ms + + phi = tl.load( + phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + grad_ms = tl.load( + grad_ms_ptrs, mask=offs_ms < M, other=0.0, cache_modifier=".ca" + ) # (BLOCK_SIZE_M,) + + grad_x = x * (grad_ms * 2 / tl.cast(K, tl.float32))[:, None] + grad_x = tl.dot( + grad_h, phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_k[None, :] * stride_grad_xk + grad_x = grad_x.to(x.dtype) + tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_k[None, :]) + + +def scale_config(): + block_m = [128] + warps = [4] + stages = [1, 2, 4] + + configs = [] + for m, w, s in itertools.product(block_m, warps, stages): + configs.append(triton.Config({"BLOCK_SIZE_M": m}, num_warps=w, num_stages=s)) + + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=scale_config(), + key=["M"], +) +@triton.jit +def _mhc_scale_fwd_fused( + h_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + a_ptr, # (3,) + b_ptr, # (2n + n^2) + ms_ptr, # (M,) + out_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + M, + n, + stride_hm, + stride_hn, + stride_a, + stride_b, + stride_ms, + stride_out_m, + stride_out_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + eps: tl.constexpr, +): + pid = tl.program_id(0) + + tl.assume(M > 0) + tl.assume(n == 4) + tl.assume(stride_hm == 32) + tl.assume(stride_hn == 1) + tl.assume(stride_out_m == 32) + tl.assume(stride_out_n == 1) + tl.assume(stride_a == 1) + tl.assume(stride_b == 1) + tl.assume(stride_ms == 1) + tl.assume(BLOCK_SIZE_N == 32) + + N = 2 * n + n * n + + offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + cols = tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m < M + + # Expand a to BLOCK_SIZE_N length + offs_a = tl.zeros_like(cols) + offs_a = tl.where((cols >= n) & (cols < 2 * n), 1, offs_a) + offs_a = tl.where((cols >= 2 * n) & (cols < 2 * n + n * n), 2, offs_a) + # Pick a[0] from a for the first 4 columns, a[1] for the next 4 columns, and a[2] for the rest of the columns + a = tl.load( + a_ptr + offs_a * stride_a, mask=offs_a < 3, other=0.0 + ) # a[2*n + n*n:] is filled with garbage + a = tl.where(cols < N, a, 0.0) # Mask out the garbage values in a + + b = tl.load(b_ptr + cols * stride_b, mask=cols < N, other=0.0) # (BLOCK_SIZE_N,) + ms = tl.load(ms_ptr + offs_m * stride_ms, mask=mask_m, other=0.0) # (BLOCK_SIZE_M,) + # In projection kernel we use split-K so we only have the accumulated ms, + # and now we need to take sqrt on the accumulated ms to obtain the RMSNorm denominator. + rms = tl.sqrt(ms + eps) + + h = tl.load( + h_ptr + offs_m[:, None] * stride_hm + cols[None, :] * stride_hn, + mask=mask_m[:, None], + other=0.0, + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + + h = a[None, :] * h + h = tl.fma( + h, 1.0 / rms[:, None], b[None, :] + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N), where the first 2n columns are H_pre and H_post, and the rest are H_res + h_sigmoid_pre = tl.sigmoid(h) + h_sigmoid_post = 2 * h_sigmoid_pre + + # Use this mask to select h[:, :2n] + h = tl.where(cols[None, :] < n, h_sigmoid_pre, h) + h = tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), h_sigmoid_post, h) + + tl.store( + out_ptr + offs_m[:, None] * stride_out_m + cols[None, :] * stride_out_n, + h, + mask=mask_m[:, None], + ) + + +@triton.autotune( + configs=scale_config(), + key=["M"], + reset_to_zero=["grad_a_ptr", "grad_b_ptr"], +) +@triton.jit +def _mhc_scale_bwd_fused( + grad_out_ptr, + out_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + grad_h_ptr, + h_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + grad_a_ptr, + a_ptr, # (3,) + grad_b_ptr, # (2n + n^2,) + grad_ms_ptr, + ms_ptr, # (M,) + M, + n, + stride_grad_out_m, + stride_grad_out_n, + stride_out_m, + stride_out_n, + stride_grad_hm, + stride_grad_hn, + stride_hm, + stride_hn, + stride_grad_a, + stride_a, + stride_grad_b, + stride_grad_ms, + stride_ms, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + eps: tl.constexpr, +): + pid = tl.program_id(0) + + tl.assume(M > 0) + tl.assume(n == 4) + tl.assume(stride_grad_out_m == 32) + tl.assume(stride_grad_out_n == 1) + tl.assume(stride_out_m == 32) + tl.assume(stride_out_n == 1) + tl.assume(stride_grad_hm == 32) + tl.assume(stride_grad_hn == 1) + tl.assume(stride_hm == 32) + tl.assume(stride_hn == 1) + tl.assume(stride_grad_a == 1) + tl.assume(stride_a == 1) + tl.assume(stride_grad_b == 1) + tl.assume(stride_grad_ms == 1) + tl.assume(stride_ms == 1) + tl.assume(BLOCK_SIZE_N == 32) + + N = 2 * n + n * n + + offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + cols = tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m < M + mask_n = cols < N + + # Expand a to BLOCK_SIZE_N length + offs_a = tl.zeros_like(cols) + offs_a = tl.where((cols >= n) & (cols < 2 * n), 1, offs_a) + offs_a = tl.where((cols >= 2 * n) & (cols < 2 * n + n * n), 2, offs_a) + # Pick a[0] from a for the first 4 columns, a[1] for the next 4 columns, and a[2] for the rest of the columns + a = tl.load( + a_ptr + offs_a * stride_a, mask=offs_a < 3, other=0.0 + ) # a[2*n + n*n:] is filled with garbage + a = tl.where(cols < N, a, 0.0) # Mask out the garbage values in a + + ms_offsets = offs_m + ms_mask = mask_m + ms = tl.load(ms_ptr + ms_offsets * stride_ms, mask=ms_mask, other=1.0) # (BLOCK_SIZE_M,) + rms = tl.sqrt(ms + eps) + + grad_out = tl.load( + grad_out_ptr + offs_m[:, None] * stride_grad_out_m + cols[None, :] * stride_grad_out_n, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + out = tl.load( + out_ptr + offs_m[:, None] * stride_out_m + cols[None, :] * stride_out_n, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + h = tl.load( + h_ptr + offs_m[:, None] * stride_hm + cols[None, :] * stride_hn, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + + # Gradiient of H before H_pre and H_post go through sigmoid + grad_out_out = grad_out * out + grad_h_pre = grad_out_out * (1 - out) + grad_h_post = grad_out_out * 0.5 * (2 - out) + grad_h = grad_out + grad_h = tl.where(cols[None, :] < n, grad_h_pre, grad_h) + grad_h = tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_h_post, grad_h) + + grad_a = tl.sum(h * grad_h / rms[:, None], axis=0).to(a.dtype) + # Write grad_a[0:4].sum to grad_a_ptr[0], grad_a[4:8].sum to grad_a_ptr[1], and grad_a[8:24].sum to grad_a_ptr[2] + tl.atomic_add(grad_a_ptr, tl.where(cols[None, :] < n, grad_a, 0.0).sum(), sem="relaxed") + tl.atomic_add( + grad_a_ptr + stride_grad_a, + tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_a, 0.0).sum(), + sem="relaxed", + ) + tl.atomic_add( + grad_a_ptr + 2 * stride_grad_a, + tl.where((cols[None, :] >= 2 * n) & (cols[None, :] < 2 * n + n * n), grad_a, 0.0).sum(), + sem="relaxed", + ) + + grad_b = tl.sum(grad_h, axis=0).to(a.dtype) + tl.atomic_add(grad_b_ptr + cols * stride_grad_b, grad_b, mask=cols < N, sem="relaxed") + + grad_rms = (tl.sum((-grad_h * h * a[None, :]), axis=1) / (rms * rms)).to(rms.dtype) + grad_ms = grad_rms / (2 * rms) + tl.store(grad_ms_ptr + ms_offsets * stride_grad_ms, grad_ms, mask=ms_mask) + + grad_h = a[None, :] * grad_h / rms[:, None] + tl.store( + grad_h_ptr + offs_m[:, None] * stride_grad_hm + cols[None, :] * stride_grad_hn, + grad_h, + mask=mask_m[:, None] & mask_n[None, :], + ) + + +def sinkhorn_config(): + block = [256, 1024] + warps = [2, 8] + stages = [2, 4] + configs = [] + for b, w, s in itertools.product(block, warps, stages): + configs.append(triton.Config({"BLOCK_SIZE": b}, num_warps=w, num_stages=s)) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=sinkhorn_config(), + key=["M"], +) +@triton.jit +def _mhc_sinkhorn_fwd_fused_recompute( + x_ptr, # (M, n*n) + output_ptr, # (M, n*n) + stride_xm, + stride_xn, + stride_out_m, + stride_out_n, + M, + n: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + iters, +): + pid = tl.program_id(0) + + tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n") + tl.assume(M > 0 and iters > 0) + tl.assume(n == 4) + + BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n) + + offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE) + offs_nn = tl.arange(0, n * n) + mask_batch = offs_batch < M + + x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn + x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + x = tl.reshape(x, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) + + log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + + f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + + for _ in range(iters): + # Update f: logsumexp over the column dimension (1) + f = x + g[:, None, :] # Broadcast g to (BATCH_SIZE, n, n) + f_max = tl.max(f, axis=2) + f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2)) # logsumexp over columns + f = log_mu - f - f_max + + # Update g: logsumexp over the row dimension (2) + g = x + f[:, :, None] # Broadcast f to (BATCH_SIZE, n, n) + g_max = tl.max(g, axis=1) + g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1)) # logsumexp over rows + g = log_nu - g - g_max + + log_P = f[:, :, None] + x + g[:, None, :] + log_P = tl.reshape( + log_P, + ( + BATCH_SIZE, + n * n, + ), + ) + P = tl.exp(log_P) + + output_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n + tl.store(output_ptrs, P, mask=mask_batch[:, None]) + + +@triton.autotune( + configs=sinkhorn_config(), + key=["M"], +) +@triton.jit +def _mhc_sinkhorn_bwd_fused_recompute( + grad_out_ptr, + output_ptr, + grad_x_ptr, + x_ptr, + hist_f_ptr, + hist_g_ptr, + stride_grad_out_m, + stride_grad_out_n, + stride_out_m, + stride_out_n, + stride_grad_xm, + stride_grad_xn, + stride_xm, + stride_xn, + M, + n: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + iters, +): + pid = tl.program_id(0) + + tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n") + tl.assume(M > 0 and iters > 0) + tl.assume(n == 4) + + BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n) # Assume there's no remainder for simplicity + + offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE) + offs_nn = tl.arange(0, n * n) + offs_n_hist = tl.arange(0, n) + mask_batch = offs_batch < M + + x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn + x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + x = tl.reshape(x, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) + + P_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n + P = tl.load(P_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + P = tl.reshape(P, (BATCH_SIZE, n, n)) + + grad_out_ptrs = ( + grad_out_ptr + + offs_batch[:, None] * stride_grad_out_m + + offs_nn[None, :] * stride_grad_out_n + ) + grad_out = tl.load(grad_out_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + grad_out = tl.reshape(grad_out, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) + + sbn = M * n + + # Recompute the full history of f and g + log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + + f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + + f_hist_ptrs = hist_f_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] + g_hist_ptrs = hist_g_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] + tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) + tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) + + for iter_idx in range(iters): + # Update f: logsumexp over the column dimension (1) + f = x + g[:, None, :] # Broadcast g to (BATCH_SIZE, n, n) + f_max = tl.max(f, axis=2) + f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2)) # logsumexp over columns + f = log_mu - f - f_max + + f_hist_ptrs = ( + hist_f_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) + + # Update g: logsumexp over the row dimension (2) + g = x + f[:, :, None] # Broadcast f to (BATCH_SIZE, n, n) + g_max = tl.max(g, axis=1) + g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1)) # logsumexp over rows + g = log_nu - g - g_max + + g_hist_ptrs = ( + hist_g_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) + + # Backward pass + grad_log_P = grad_out * P # (BATCH_SIZE, n, n) + zeros = tl.zeros_like(grad_log_P) + grad_g = tl.sum(grad_log_P, axis=1) # (BATCH_SIZE, n) + grad_x = grad_log_P + + g_hist_ptrs = hist_g_ptr + iters * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + g = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0) + g = tl.reshape(g, (BATCH_SIZE, n)) + + for iter_idx in range(iters, 0, -1): + f_hist_ptrs = hist_f_ptr + iter_idx * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + f = tl.load(f_hist_ptrs, mask=mask_batch[:, None], other=0.0) + f = tl.reshape(f, (BATCH_SIZE, n)) + + g_hist_ptrs = ( + hist_g_ptr + (iter_idx - 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + g_next = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0) + g_next = tl.reshape(g_next, (BATCH_SIZE, n)) + + term_g = -grad_g[:, None, :] * tl.exp(f[:, :, None] + x + g[:, None, :]) + grad_f = tl.sum(term_g + grad_log_P, axis=2) # (BATCH_SIZE, n) + # Only the last iteration's f will contribute to gradients with both grad_g1 and grad_log_P + grad_log_P = zeros # Zero out grad_log_P for next iterations + + g = g_next + + term_f = -grad_f[:, :, None] * tl.exp(f[:, :, None] + x + g[:, None, :]) + grad_g = tl.sum(term_f, axis=1) # (BATCH_SIZE, n) + + grad_x += term_f + term_g + + grad_x_ptrs = ( + grad_x_ptr + offs_batch[:, None] * stride_grad_xm + offs_nn[None, :] * stride_grad_xn + ) + tl.store( + grad_x_ptrs, + tl.reshape( + grad_x, + ( + BATCH_SIZE, + n * n, + ), + ), + mask=mask_batch[:, None], + ) + + +@triton.autotune( + configs=sinkhorn_config(), + key=["M"], +) +@triton.jit +def _mhc_sinkhorn_fwd_fused( + x_ptr, # (M, n*n) + output_ptr, # (M, n*n) + hist_f_ptr, # (iters+1, M, n) + hist_g_ptr, # (iters+1, M, n) + stride_xm, + stride_xn, + stride_out_m, + stride_out_n, + M, + n: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + iters, +): + pid = tl.program_id(0) + + tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n") + tl.assume(M > 0 and iters > 0) + tl.assume(n == 4) + + BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n) # Assume there's no remainder for simplicity + + offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE) + offs_nn = tl.arange(0, n * n) + offs_n_hist = tl.arange(0, n) + mask_batch = offs_batch < M + + x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn + x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + x = tl.reshape(x, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) + + log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + + f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + + sbn = M * n + + # Store the initial f and g to history + f_hist_ptrs = hist_f_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] + g_hist_ptrs = hist_g_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] + tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) + tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) + + for iter_idx in range(iters): + # Update f: logsumexp over the column dimension (1) + f = x + g[:, None, :] # Broadcast g to (BATCH_SIZE, n, n) + f_max = tl.max(f, axis=2) + f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2)) # logsumexp over columns + f = log_mu - f - f_max + + f_hist_ptrs = ( + hist_f_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) + + # Update g: logsumexp over the row dimension (2) + g = x + f[:, :, None] # Broadcast f to (BATCH_SIZE, n, n) + g_max = tl.max(g, axis=1) + g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1)) # logsumexp over rows + g = log_nu - g - g_max + + g_hist_ptrs = ( + hist_g_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) + + log_P = f[:, :, None] + x + g[:, None, :] + log_P = tl.reshape( + log_P, + ( + BATCH_SIZE, + n * n, + ), + ) + P = tl.exp(log_P) + + output_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n + tl.store(output_ptrs, P, mask=mask_batch[:, None]) + + +@triton.autotune( + configs=sinkhorn_config(), + key=["M"], +) +@triton.jit +def _mhc_sinkhorn_bwd_fused( + grad_out_ptr, # (M, n*n) + output_ptr, # (M, n*n) + grad_x_ptr, # (M, n*n) + x_ptr, # (M, n*n) + hist_f_ptr, # (iters+1, M, n) + hist_g_ptr, # (iters+1, M, n) + stride_grad_out_m, + stride_grad_out_n, + stride_out_m, + stride_out_n, + stride_grad_xm, + stride_grad_xn, + stride_xm, + stride_xn, + M, + n: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + iters, +): + pid = tl.program_id(0) + + tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n") + tl.assume(M > 0 and iters > 0) + tl.assume(n == 4) + + BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n) # Assume there's no remainder for simplicity + + offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE) + offs_nn = tl.arange(0, n * n) + offs_n_hist = tl.arange(0, n) + mask_batch = offs_batch < M + + x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn + x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + x = tl.reshape(x, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) + + P_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n + P = tl.load(P_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + P = tl.reshape(P, (BATCH_SIZE, n, n)) + + grad_out_ptrs = ( + grad_out_ptr + + offs_batch[:, None] * stride_grad_out_m + + offs_nn[None, :] * stride_grad_out_n + ) + grad_out = tl.load(grad_out_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) + grad_out = tl.reshape(grad_out, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) + + sbn = M * n + + # Backward pass + grad_log_P = grad_out * P # (BATCH_SIZE, n, n) + zeros = tl.zeros_like(grad_log_P) + grad_g = tl.sum(grad_log_P, axis=1) # (BATCH_SIZE, n) + grad_x = grad_log_P + + g_hist_ptrs = hist_g_ptr + iters * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + g = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0) + g = tl.reshape(g, (BATCH_SIZE, n)) + + for iter_idx in range(iters, 0, -1): + f_hist_ptrs = hist_f_ptr + iter_idx * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + f = tl.load(f_hist_ptrs, mask=mask_batch[:, None], other=0.0) + f = tl.reshape(f, (BATCH_SIZE, n)) + + g_hist_ptrs = ( + hist_g_ptr + (iter_idx - 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + g_next = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0) + g_next = tl.reshape(g_next, (BATCH_SIZE, n)) + + term_g = -grad_g[:, None, :] * tl.exp(f[:, :, None] + x + g[:, None, :]) + grad_f = tl.sum(term_g + grad_log_P, axis=2) # (BATCH_SIZE, n) + # Only the last iteration's f will contribute to gradients with both grad_g1 and grad_log_P + grad_log_P = zeros # Zero out grad_log_P for next iterations + + g = g_next + + term_f = -grad_f[:, :, None] * tl.exp(f[:, :, None] + x + g[:, None, :]) + grad_g = tl.sum(term_f, axis=1) # (BATCH_SIZE, n) + + grad_x += term_f + term_g + + grad_x_ptrs = ( + grad_x_ptr + offs_batch[:, None] * stride_grad_xm + offs_nn[None, :] * stride_grad_xn + ) + tl.store( + grad_x_ptrs, + tl.reshape( + grad_x, + ( + BATCH_SIZE, + n * n, + ), + ), + mask=mask_batch[:, None], + ) + + +def aggregate_config(): + block_m = [1, 2, 4] + block_c = [64, 128, 256] + warps = [1, 2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for m, c, w, s in itertools.product(block_m, block_c, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=aggregate_config(), + key=["M", "C"], +) +@triton.jit +def _mhc_aggregate_fwd( + x_ptr, # # (M, C, n) + H_pre_ptr, # (M, n) + output_ptr, # (M, C) + M, + C, + n: tl.constexpr, + stride_xm, + stride_xCn, + stride_output_m, + stride_output_c, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_C: tl.constexpr, +): + """ + output = x @ H_pre: (M, C, n) @ (M, n, 1) = (M, C, 1) + """ + pid_m = tl.program_id(1) + pid_c = tl.program_id(0) + + tl.static_assert(n == 4) + tl.assume(M > 0) + tl.assume(C > 0) + tl.assume(n == 4) + tl.assume(stride_xm > 0 and stride_xCn == 1) + tl.assume(stride_output_m > 0 and stride_output_c == 1) + + tl.assume(BLOCK_SIZE_C % 32 == 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) + offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) + mask_m = offs_m < M + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + offs_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + H_pre = tl.load( + H_pre_ptr + offs_H_pre, mask=offs_H_pre < M * n, other=0.0, cache_modifier=".ca" + ) # (BLOCK_SIZE_M * n) + H_pre = H_pre.reshape(BLOCK_SIZE_M, 2, 2) + H_pre01, H_pre23 = tl.split(H_pre) + H_pre0, H_pre1 = tl.split(H_pre01) + H_pre2, H_pre3 = tl.split(H_pre23) # (BLOCK_SIZE_M, 1) + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) + + x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)) + x01, x23 = tl.split(x) + x0, x1 = tl.split(x01) + x2, x3 = tl.split(x23) # (BLOCK_SIZE_M, BLOCK_SIZE_C) + + # x @ H_pre: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) + # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4: + # x @ H_pre = x[:, :, 0] * H_pre[:, 0] + # + x[:, :, 1] * H_pre[:, 1] + # + x[:, :, 2] * H_pre[:, 2] + # + x[:, :, 3] * H_pre[:, 3] + out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) + out_acc = tl.fma(x0, H_pre0[:, None], out_acc) + out_acc = tl.fma(x1, H_pre1[:, None], out_acc) + out_acc = tl.fma(x2, H_pre2[:, None], out_acc) + out_acc = tl.fma(x3, H_pre3[:, None], out_acc) + + out = out_acc.to(x.dtype) + + output_ptrs = output_ptr + offs_m[:, None] * stride_output_m + offs_c[None, :] * stride_output_c + tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_c[None, :]) + + +@triton.autotune(configs=aggregate_config(), key=["M", "C"], reset_to_zero=["grad_H_pre_ptr"]) +@triton.jit +def _mhc_aggregate_bwd( + grad_output_ptr, # (M, C) + H_pre_ptr, # (M, n) + grad_H_pre_ptr, # (M, n) + x_ptr, # (M, C, n) + grad_x_ptr, # # (M, C, n) + M, + C, + n: tl.constexpr, + stride_grad_output_m, + stride_grad_output_c, + stride_xm, + stride_xCn, + stride_grad_xm, + stride_grad_xCn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_C: tl.constexpr, + precision: tl.constexpr, +): + """ + Forward: + out = x @ H_pre: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) + Backward: + grad_H_pre = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) = (BLOCK_SIZE_M, n, 1) + grad_H_pre.T = grad_output.T @ x: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + which is easier to compute since transposing grad_H_pre and grad_output is just view change + grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + """ + pid_m = tl.program_id(1) + pid_c = tl.program_id(0) + + tl.static_assert(n == 4) + tl.assume(M > 0) + tl.assume(C > 0) + tl.assume(n == 4) + tl.assume(stride_xm > 0 and stride_xCn == 1) + tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1) + tl.assume(stride_grad_output_m > 0 and stride_grad_output_c == 1) + + tl.assume(BLOCK_SIZE_C % 32 == 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) + offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) + mask_m = offs_m < M + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + grad_output_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_c[None, :] * stride_grad_output_c + ) + grad_output = tl.load( + grad_output_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C) + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) + + grad_H_pre = tl.dot( + tl.reshape(grad_output, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), + tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), + input_precision=precision, + out_dtype=tl.float32, + ) + grad_H_pre = tl.reshape(grad_H_pre, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + offs_grad_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + grad_H_pre_ptrs = grad_H_pre_ptr + offs_grad_H_pre + tl.atomic_add(grad_H_pre_ptrs, grad_H_pre, mask=offs_grad_H_pre < M * n, sem="relaxed") + + H_pre_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + H_pre = tl.load( + H_pre_ptr + H_pre_offs, mask=H_pre_offs < M * n, other=0.0, cache_modifier=".ca" + ) # (BLOCK_SIZE_M * n) + H_pre = tl.reshape(H_pre, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + + # grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + grad_x = grad_output[:, :, None] * H_pre[:, None, :] # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) + + grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + tl.store( + grad_x_ptrs, + grad_x, + mask=mask_m[:, None] & mask_cn[None, :], + ) + + +def expand_combine_config(): + block_m = [1, 2, 4] + block_c = [128, 256] + warps = [1, 2] + stages = [1, 2, 3, 4] + + configs = [] + for m, c, w, s in itertools.product(block_m, block_c, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +@triton.autotune( + configs=expand_combine_config(), + key=["M", "C"], +) +@triton.jit +def _mhc_expand_combine_fwd( + f_ptr, # (M, C) + H_post_ptr, # (M, n) + x_ptr, # (M, C, n) + H_res_ptr, # (M, n, n) + output_ptr, # # (M, C, n) + M, + C, + n: tl.constexpr, + stride_fm, + stride_fc, + stride_xm, + stride_xCn, + stride_output_m, + stride_output_Cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_C: tl.constexpr, +): + """ + output = f @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + """ + pid_m = tl.program_id(1) + pid_c = tl.program_id(0) + + tl.static_assert(n == 4) + tl.assume(M > 0) + tl.assume(C > 0) + tl.assume(n == 4) + tl.assume(stride_fm > 0 and stride_fc == 1) + tl.assume(stride_xm > 0 and stride_xCn == 1) + tl.assume(stride_output_m > 0 and stride_output_Cn == 1) + + tl.assume(BLOCK_SIZE_C % 32 == 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) + offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) + mask_m = offs_m < M + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc + f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + + offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + H_post = tl.load( + H_post_ptr + offs_H_post, mask=offs_H_post < M * n, other=0.0, cache_modifier=".ca" + ) + H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + + # Residual connection path: res_out = f @ H_post: + # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) + # Due to broadcasting, it's equivalent to a multiplicaiton + out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) + out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc) + + H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) + H_res = tl.load( + H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca" + ) + H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + # Manifold connection path: manifold_out = H_res @ x: + # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4: + # x @ H_res = x[:, :, 0] @ H_res[:, 0, :] + # + x[:, :, 1] @ H_res[:, 1, :] + # + x[:, :, 2] @ H_res[:, 2, :] + # + x[:, :, 3] @ H_res[:, 3, :] + + x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)) + x01, x23 = tl.split( + x_reshape + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) + x0, x1 = tl.split(x01) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + x2, x3 = tl.split(x23) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + + H_resT = tl.reshape(tl.trans(H_res, (0, 2, 1)), (BLOCK_SIZE_M, n, 2, 2)) + H_res01, H_res23 = tl.split(H_resT) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + + out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc) + out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc) + out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc) + out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc) + + out = out_acc.to(x.dtype) + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + + output_ptrs = ( + output_ptr + offs_m[:, None] * stride_output_m + offs_cn[None, :] * stride_output_Cn + ) + tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) + + +@triton.autotune( + configs=expand_combine_config(), + key=["M", "C"], + reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr"], +) +@triton.jit +def _mhc_expand_combine_bwd( + grad_output_ptr, # (M, C, n) + f_ptr, # (M, C) + H_post_ptr, # (M, n) + x_ptr, # (M, C, n) + H_res_ptr, # (M, n, n) + grad_H_post_ptr, # (M, n) + grad_f_ptr, # (M, C) + grad_H_res_ptr, # (M, n, n) + grad_x_ptr, # (M, C, n) + M, + C, + n: tl.constexpr, + stride_grad_output_m, + stride_grad_output_Cn, + stride_fm, + stride_fc, + stride_xm, + stride_xCn, + stride_grad_fm, + stride_grad_fc, + stride_grad_xm, + stride_grad_xCn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_C: tl.constexpr, + precision: tl.constexpr, +): + """ + Each block + It reads + - (BLOCK_SIZE_M, BLOCK_SIZE_C) of f, which is the output of the attention / FFN module + - (BLOCK_SIZE_M, n) of H_post, which is applied for the transformation of the attention / FFN output + - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of x, which is the skip connection's input + - (BLOCK_SIZE_M, n*n) of H_res, which is applied for the transformation of the skip connection + and writes + - (BLOCK_SIZE_M, n) of grad_H_post + - (BLOCK_SIZE_M, BLOCK_SIZE_C) of grad_f + - (BLOCK_SIZE_M, n, n) of grad_H_res + - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of grad_x + + Forward: + out = f @ H_post + x @ H_res + Backward: + GEMM: + grad_H_post = f.T @ grad_output: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + Not GEMM: + grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) + grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + """ + + pid_m = tl.program_id(1) + pid_c = tl.program_id(0) + + tl.static_assert(n == 4) + tl.assume(M > 0) + tl.assume(C > 0) + tl.assume(n == 4) + tl.assume(stride_fm > 0 and stride_fc == 1) + tl.assume(stride_xm > 0 and stride_xCn == 1) + tl.assume(stride_grad_output_m > 0 and stride_grad_output_Cn == 1) + tl.assume(stride_grad_fm > 0 and stride_grad_fc == 1) + tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1) + + tl.assume(BLOCK_SIZE_C % 32 == 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) + offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) + mask_m = offs_m < M + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc + f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + + H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) + H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + + H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) + H_res = tl.load( + H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 + ) # (BLOCK_SIZE_M, n, n) + H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) + + grad_out_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_cn[None, :] * stride_grad_output_Cn + ) + grad_out = tl.load( + grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) + grad_out = tl.reshape( + grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_post = tl.dot( + tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post + tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + grad_H_res = tl.dot( + tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 + ) # (BLOCK_SIZE_M, n, n) + grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) + offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) + grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res + tl.atomic_add( + grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" + ) + + grad_out_reshape = tl.reshape( + grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) + grad_out01, grad_out23 = tl.split( + grad_out_reshape + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) + grad_out0, grad_out1 = tl.split( + grad_out01 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + grad_out2, grad_out3 = tl.split( + grad_out23 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + + # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) + # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: + # grad_f = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) + # + grad_out[:, :, 1] @ H_post.T[:, 1, :] + # + grad_out[:, :, 2] @ H_post.T[:, 2, :] + # + grad_out[:, :, 3] @ H_post.T[:, 3, :] + # where H_post.T[:, i, :] = H_post[:, :, i] + H_post = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) + H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) + H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + + grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) + # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) + grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) + grad_f = grad_f_acc.to(f.dtype) + + grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc + tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) + + # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) + # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul + # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] + # + grad_out[:, :, 1] @ H_res.T[:, 1, :] + # + grad_out[:, :, 2] @ H_res.T[:, 2, :] + # + grad_out[:, :, 3] @ H_res.T[:, 3, :] + # where H_res.T[:, i, :] = H_res[:, :, i] + # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] + + H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) + H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + + grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) + grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) + + grad_x = grad_x_acc.to(x.dtype) + grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + + grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) + + +@triton.autotune( + configs=expand_combine_config(), + key=["M", "C"], +) +@triton.jit +def _mhc_expand_combine_with_bias_fwd( + f_ptr, # (M, C) + bias_ptr, # (C,) + H_post_ptr, # (M, n) + x_ptr, # (M, C, n) + H_res_ptr, # (M, n, n) + output_ptr, # # (M, C, n) + M, + C, + n: tl.constexpr, + stride_fm, + stride_fc, + stride_bias, + stride_xm, + stride_xCn, + stride_output_m, + stride_output_Cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_C: tl.constexpr, +): + """ + output = (f + bias[None, :, None]) @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + """ + pid_m = tl.program_id(1) + pid_c = tl.program_id(0) + + tl.static_assert(n == 4) + tl.assume(M > 0) + tl.assume(C > 0) + tl.assume(n == 4) + tl.assume(stride_fm > 0 and stride_fc == 1) + tl.assume(stride_bias == 1) + tl.assume(stride_xm > 0 and stride_xCn == 1) + tl.assume(stride_output_m > 0 and stride_output_Cn == 1) + + tl.assume(BLOCK_SIZE_C % 32 == 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) + offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) + mask_m = offs_m < M + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc + f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + + offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + H_post = tl.load( + H_post_ptr + offs_H_post, mask=offs_H_post < M * n, other=0.0, cache_modifier=".ca" + ) + H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + + # Residual connection path: res_out = f @ H_post + bias @ H_post: + # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) + # Due to broadcasting, it's equivalent to a multiplicaiton + out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) + out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc) + out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc) + + H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) + H_res = tl.load( + H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca" + ) + H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + # Manifold connection path: manifold_out = H_res @ x: + # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4: + # x @ H_res = x[:, :, 0] @ H_res[:, 0, :] + # + x[:, :, 1] @ H_res[:, 1, :] + # + x[:, :, 2] @ H_res[:, 2, :] + # + x[:, :, 3] @ H_res[:, 3, :] + + x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)) + x01, x23 = tl.split( + x_reshape + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) + x0, x1 = tl.split(x01) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + x2, x3 = tl.split(x23) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + + H_resT = tl.reshape(tl.trans(H_res, (0, 2, 1)), (BLOCK_SIZE_M, n, 2, 2)) + H_res01, H_res23 = tl.split(H_resT) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + + out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc) + out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc) + out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc) + out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc) + + out = out_acc.to(x.dtype) + out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + + output_ptrs = ( + output_ptr + offs_m[:, None] * stride_output_m + offs_cn[None, :] * stride_output_Cn + ) + tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) + + +@triton.autotune( + configs=expand_combine_config(), + key=["M", "C"], + reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr", "grad_bias_ptr"], +) +@triton.jit +def _mhc_expand_combine_with_bias_bwd( + grad_output_ptr, # (M, C, n) + f_ptr, # (M, C) + bias_ptr, # (C,) + H_post_ptr, # (M, n) + x_ptr, # (M, C, n) + H_res_ptr, # (M, n, n) + grad_H_post_ptr, # (M, n) + grad_f_ptr, # (M, C) + grad_bias_ptr, # (C,) + grad_H_res_ptr, # (M, n, n) + grad_x_ptr, # (M, C, n) + M, + C, + n: tl.constexpr, + stride_grad_output_m, + stride_grad_output_Cn, + stride_fm, + stride_fc, + stride_bias, + stride_xm, + stride_xCn, + stride_grad_fm, + stride_grad_fc, + stride_grad_bias, + stride_grad_xm, + stride_grad_xCn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_C: tl.constexpr, + precision: tl.constexpr, +): + """ + Each block + It reads + - (BLOCK_SIZE_M, BLOCK_SIZE_C) of f, which is the output of the attention / FFN module + - (BLOCK_SIZE_M, n) of H_post, which is applied for the transformation of the attention / FFN output + - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of x, which is the skip connection's input + - (BLOCK_SIZE_M, n*n) of H_res, which is applied for the transformation of the skip connection + and writes + - (BLOCK_SIZE_M, n) of grad_H_post + - (BLOCK_SIZE_M, BLOCK_SIZE_C) of grad_f + - (BLOCK_SIZE_M, n, n) of grad_H_res + - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of grad_x + + Forward: + out = f @ H_post + x @ H_res + Backward: + GEMM: + grad_H_post = f.T @ grad_output: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + Not GEMM: + grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) + grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + """ + + pid_m = tl.program_id(1) + pid_c = tl.program_id(0) + + tl.static_assert(n == 4) + tl.assume(M > 0) + tl.assume(C > 0) + tl.assume(n == 4) + tl.assume(stride_fm > 0 and stride_fc == 1) + tl.assume(stride_bias == 1) + tl.assume(stride_xm > 0 and stride_xCn == 1) + tl.assume(stride_grad_output_m > 0 and stride_grad_output_Cn == 1) + tl.assume(stride_grad_fm > 0 and stride_grad_fc == 1) + tl.assume(stride_grad_bias == 1) + tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1) + + tl.assume(BLOCK_SIZE_C % 32 == 0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) + offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) + mask_m = offs_m < M + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc + f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + + bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + + H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) + H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + + H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) + H_res = tl.load( + H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 + ) # (BLOCK_SIZE_M, n, n) + H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) + + grad_out_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_cn[None, :] * stride_grad_output_Cn + ) + grad_out = tl.load( + grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) + grad_out = tl.reshape( + grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_post = tl.dot( + tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + grad_H_post = tl.dot( + tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), + acc=grad_H_post, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post + tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + + # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + grad_H_res = tl.dot( + tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 + ) # (BLOCK_SIZE_M, n, n) + grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) + offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) + grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res + tl.atomic_add( + grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" + ) + + grad_out_reshape = tl.reshape( + grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) + grad_out01, grad_out23 = tl.split( + grad_out_reshape + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) + grad_out0, grad_out1 = tl.split( + grad_out01 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + grad_out2, grad_out3 = tl.split( + grad_out23 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) + + # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) + # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: + # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) + # + grad_out[:, :, 1] @ H_post.T[:, 1, :] + # + grad_out[:, :, 2] @ H_post.T[:, 2, :] + # + grad_out[:, :, 3] @ H_post.T[:, 3, :] + # where H_post.T[:, i, :] = H_post[:, :, i] + H_post = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) + H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) + H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + + grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) + # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) + grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) + grad_f = grad_f_acc.to(f.dtype) + + grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc + tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) + + grad_bias = tl.sum(grad_f_acc, axis=0) # (BLOCK_SIZE_C,) + grad_bias_ptrs = grad_bias_ptr + offs_c * stride_grad_bias + tl.atomic_add(grad_bias_ptrs, grad_bias, mask=mask_c, sem="relaxed") + + # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) + # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul + # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] + # + grad_out[:, :, 1] @ H_res.T[:, 1, :] + # + grad_out[:, :, 2] @ H_res.T[:, 2, :] + # + grad_out[:, :, 3] @ H_res.T[:, 3, :] + # where H_res.T[:, i, :] = H_res[:, :, i] + # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] + + H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) + H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + + grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) + grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) + + grad_x = grad_x_acc.to(x.dtype) + grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) + + grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6adba23a8f..fdfa47da8f 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -48,7 +48,9 @@ .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ - .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD) \ + .value("NVTE_QKV_Format_NOT_SET", NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -74,7 +76,8 @@ .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ + .value("NVTE_BHSD_BHSD_BHSD", NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 29d0848381..f54a043fd2 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -11,7 +11,6 @@ from jax.ad_checkpoint import checkpoint_name import jax import jax.numpy as jnp -from flax.linen import make_attention_mask from transformer_engine_jax import NVTE_Bias_Type from transformer_engine_jax import NVTE_Mask_Type @@ -541,6 +540,149 @@ def run_length_fill(segment_ids) -> jnp.ndarray: return run_length_segment_id_shape.reshape(orig_shape) +def _get_seqlens_offsets_thd( + segment_ids_q, + segment_ids_kv, + segment_pos_q, + segment_pos_kv, + attn_mask_type, + max_segments_per_seq, +): + """O(T * max_segments_per_seq) replacement for the older O(T^2) mask-based slow path. + Returns (q_seqlen, kv_seqlen, q_offset, kv_offset) values to match the reference older mask-based path: + segment_mask = make_attention_mask(q_ids, kv_ids, equal) + segment_mask_with_id = make_attention_mask(q_ids, kv_ids, equal * q_id) + attn_mask = segment_mask AND (causal_or_brcm_or_none) + attn_mask_with_id = where(attn_mask, segment_mask_with_id, 0) + row_ids = reduce_max(attn_mask_with_id, axis=kv) # [B, T_q] + col_ids = reduce_max(attn_mask_with_id, axis=q) # [B, T_kv] + seqlens/offsets = bincount(...) / find_offsets(...) + The two reductions are expressed equivalently as per-segment aggregates: + - causal: row_ids[q] = q_seg_id iff seg_pos_q[q] >= min(seg_pos_kv over same-seg KV) + - brcm: row_ids[q] = q_seg_id iff (run_len_q - seg_pos_q) >= + min(run_len_kv - seg_pos_kv over same-seg KV) + - padding: row_ids[q] = q_seg_id iff q_seg_id appears in KV + (and symmetrically for col_ids with max/<=). + """ + + # Example: For striping P2P causal attention (but this logic also applies for non-CP fused attn) + # pre-striping and sharding: segment_ids = [[1 1 1 1 2 2 2 2]], segment_pos = [[0 1 2 3 0 1 2 3]] + # post-striping and sharding (striped CP=2, Q from rank 0 × KV from rank 1, max_segments_per_seq=2): + # segment_ids_q = [1 1 2 2] segment_pos_q = [0 2 0 2] → q_key = [0 2 0 2] + # segment_ids_kv = [1 1 2 2] segment_pos_kv = [1 3 1 3] → kv_key = [1 3 1 3] + # Q-side — kv_agg[s] = min(kv_key over same-seg KV), fill = max_fill_val = 5 (assumed to be large enough): + # scatter (rows = kv tokens, cols = segs): + # [5 1 5 / 5 3 5 / 5 5 1 / 5 5 3] → reduce min → kv_agg = [5 1 1] + # q_ok = q_key >= kv_agg[seg_ids_q] = [0 2 0 2] >= [1 1 1 1] = [F T F T] + # KV-side — q_agg[s] = max(q_key over same-seg Q), fill = neg_fill_val = -1 (assumed to be small enough): + # scatter: [-1 0 -1 / -1 2 -1 / -1 -1 0 / -1 -1 2] → reduce max → q_agg = [-1 2 2] + # kv_ok = kv_key <= q_agg[seg_ids_kv] = [1 3 1 3] <= [2 2 2 2] = [T F T F] + # Outer combiner: + # row_ids = [0 1 0 2] col_ids = [1 0 2 0] + # q_seqlen = [1 1] kv_seqlen = [1 1] + # q_offset = [1 3 -1] kv_offset = [0 2 -1] + def _row_and_col_ids(): + if attn_mask_type.is_bottom_right(): + # BRCM: mask[q][kv] = (same seg) AND (q_key <= kv_key). + rl_q = run_length_fill(segment_ids_q) + rl_kv = run_length_fill(segment_ids_kv) + q_key = (rl_q - segment_pos_q).astype(jnp.int32) + kv_key = (rl_kv - segment_pos_kv).astype(jnp.int32) + + # Use large positive and negative values as fill values for the KV keys and Q keys respectively + max_fill_val = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32) + neg_fill_val = jnp.asarray(-1, dtype=jnp.int32) + # Creates a one-hot encoding mask of the KV segment ids (size [B, T_kv, max_segments_per_seq+1]) + # i.e. each row has only one True value, which is the segment id of the row. + kv_oh = jax.nn.one_hot(segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_) + # Mask the KV keys with the valid segment ids (size [B, T_kv, 1]) + kv_key_masked = jnp.where(segment_ids_kv != 0, kv_key, neg_fill_val)[..., None] + # Scatter each KV key (i.e. seg pos) into it's own segment column + kv_agg = jnp.where(kv_oh, kv_key_masked, neg_fill_val) + kv_agg = jnp.max(kv_agg, axis=-2) + # Define causal relationship: Q is attended iff q_key <= max(kv_key over same-seg KV) + q_has_match = q_key <= jnp.take_along_axis( + kv_agg, segment_ids_q.astype(jnp.int32), axis=-1 + ) + + # Symmetric to the Q case, but with KV and Q swapped + q_oh = jax.nn.one_hot(segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_) + q_key_masked = jnp.where(segment_ids_q != 0, q_key, max_fill_val)[..., None] + q_agg = jnp.where(q_oh, q_key_masked, max_fill_val) + q_agg = jnp.min(q_agg, axis=-2) + # Define causal relationship: KV is attended iff kv_key >= min(q_key over same-seg Q) + kv_has_match = kv_key >= jnp.take_along_axis( + q_agg, segment_ids_kv.astype(jnp.int32), axis=-1 + ) + elif attn_mask_type.is_causal(): + # CM: mask[q][kv] = (same_seg) AND (q_pos >= kv_pos). + q_key = segment_pos_q.astype(jnp.int32) + kv_key = segment_pos_kv.astype(jnp.int32) + + # Use large positive and negative values as a fill value for the KV keys and Q keys respectively + max_fill_val = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32) + neg_fill_val = jnp.asarray(-1, dtype=jnp.int32) + + # Creates a one-hot encoding mask of the KV segment ids (size [B, T_kv, max_segments_per_seq+1]) + # i.e. each row has only one True value, which is the segment id of the row. + kv_oh = jax.nn.one_hot(segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_) + # Mask the KV keys with the valid segment ids (size [B, T_kv, 1]) + kv_key_masked = jnp.where(segment_ids_kv != 0, kv_key, max_fill_val)[..., None] + # Scatter each KV key (i.e. seg pos) into it's own segment column + kv_agg = jnp.where(kv_oh, kv_key_masked, max_fill_val) + kv_agg = jnp.min(kv_agg, axis=-2) + # Define causal relationship: Q is attended iff q_key >= min(kv_key over same-seg KV) + q_has_match = q_key >= jnp.take_along_axis( + kv_agg, segment_ids_q.astype(jnp.int32), axis=-1 + ) + + # Symmetric to the Q case, but with KV and Q swapped + q_oh = jax.nn.one_hot(segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_) + q_key_masked = jnp.where(segment_ids_q != 0, q_key, neg_fill_val)[..., None] + q_agg = jnp.where(q_oh, q_key_masked, neg_fill_val) + q_agg = jnp.max(q_agg, axis=-2) + # Define causal relationship: KV is attended iff kv_key <= max(q_key over same-seg Q) + kv_has_match = kv_key <= jnp.take_along_axis( + q_agg, segment_ids_kv.astype(jnp.int32), axis=-1 + ) + else: + # Padding-only: row_ids[q] = q_seg_id iff q_seg_id is present in KV (and q not pad). + kv_seg_ids_present = jax.nn.one_hot( + segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_ + ).any(axis=-2) + q_seg_ids_present = jax.nn.one_hot( + segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_ + ).any(axis=-2) + q_has_match = jnp.take_along_axis( + kv_seg_ids_present, segment_ids_q.astype(jnp.int32), axis=-1 + ) & (segment_ids_q != 0) + kv_has_match = jnp.take_along_axis( + q_seg_ids_present, segment_ids_kv.astype(jnp.int32), axis=-1 + ) & (segment_ids_kv != 0) + + row_ids = jnp.where(q_has_match, segment_ids_q, 0).astype(jnp.int32) + col_ids = jnp.where(kv_has_match, segment_ids_kv, 0).astype(jnp.int32) + return row_ids, col_ids + + row_ids, col_ids = _row_and_col_ids() + + bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1)) + q_seqlen = bincount_vmap(row_ids)[..., 1:] + kv_seqlen = bincount_vmap(col_ids)[..., 1:] + + def _find_offsets(x): + same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) + first_column = x[..., :1] != 0 + boundaries = jnp.concatenate([first_column, same_as_previous], axis=-1) + return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))( + boundaries + ).squeeze(-1) + + q_offset = _find_offsets(row_ids) + kv_offset = _find_offsets(col_ids) + return q_seqlen, kv_seqlen, q_offset, kv_offset + + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, segment_ids_kv, @@ -550,9 +692,52 @@ def _segment_ids_pos_to_seqlens_offsets( window_size, max_segments_per_seq, ): + """Compute per-segment seqlens and start offsets(currently only used for THD) + Given segment-id and segment-position tensors for Q and KV, + returns the four metadata tensors cuDNN needed for variable-length attention: + q_seqlen : [..., max_segments_per_seq] # valid Q tokens per segment + kv_seqlen : [..., max_segments_per_seq] # valid KV tokens per segment + q_offset : [..., max_segments_per_seq + 1] # start index of each Q segment + kv_offset : [..., max_segments_per_seq + 1] # start index of each KV segment + + Args: + segment_ids_q: int32 [..., T_q] per-token segment id; 0 == padding + segment_ids_kv: int32 [..., T_kv] same convention as segment_ids_q + segment_pos_q: int32 [..., T_q] per-token position inside its segment + segment_pos_kv: int32 [..., T_kv] same convention as segment_pos_q + attn_mask_type: AttnMaskType. Selects the mask predicate used to decide + which positions are valid (top-left causal vs + bottom-right causal vs. padding-only) + window_size: Optional sliding-window tuple ``(left, right)`` or None + Used here only as a fast-path eligibility hint + max_segments_per_seq: maximum number of segments expected per row + Used to size the bincount / argwhere outputs + + Routing (only invoked for THD qkv_layout): + 1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``. + O(T) per row. Counts all segment tokens via bincount on + segment_ids and trims at most one token per segment at the + boundary. Used for: + - top-left CAUSAL / PADDING_CAUSAL with ``window_size is None`` + - SWA with ``window_size == (-1, -1)`` and not bottom-right + Bottom-right causal cross-attention is excluded: the boundary + trim leaves kv_seqlen short by one per active segment, which + shifts the BRCM bottom-right alignment by one KV per Q row. + + 2. Slow path -- ``_get_seqlens_offsets_thd``. + O(T * max_segments_per_seq) per row. Per-segment min/max + aggregation that is equivalent to the older O(T^2) + mask-based reference for top-left causal, bottom-right causal, + and padding-only masks. Required under ring attention where + ``segment_ids_q != segment_ids_kv`` in rotated steps. + + Returns: + Tuple ``(q_seqlen, kv_seqlen, q_offset, kv_offset)`` with shapes as + above. Inactive segment slots are filled with 0 in seqlens and -1 + in offsets. + """ # TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here. # Computing the full mask is expensive due to quadratic expansion of Q * KV masking. - # Assumptions for cudnn causal mask correctness. # 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0] # 2. No intra-segment padding, only inter-segment paddding allowed @@ -561,82 +746,30 @@ def _segment_ids_pos_to_seqlens_offsets( # 0 x x # 4 x x x x x # 8 x x x x x x x x - # # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to # examine only O(Q+KV) elements. - - # For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation - # using the segment ids and pos along with mask type (causal or brcm) is sufficient. - # It does not need to involve SW for this mask's creation - - # Currently, this function is only exercised for THD qkv_layout. - - # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - if (attn_mask_type.is_causal() and window_size is None) or ( - window_size == (-1, -1) and not attn_mask_type.is_bottom_right() - ): + # The fast causal path encodes TOP-LEFT causal semantics via + # valid[q][kv] = (segment_pos_q >= segment_pos_kv) + # which is only equivalent to BRCM when s_q == s_kv (self-attention). For + # cross-attention (s_q != s_kv), BRCM diverges from top-left causal, so we + # must route bottom-right masks to the slow path. + + # Fast path: O(T) per row. + if ( + attn_mask_type.is_causal() and not attn_mask_type.is_bottom_right() and window_size is None + ) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()): return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq ) - - # (1 = attend, 0 = masked) - segment_mask = make_attention_mask( - segment_ids_q, - segment_ids_kv, - jnp.equal, - ) - segment_mask_with_id = make_attention_mask( + # Slow path: O(T * max_segments_per_seq) per row. + return _get_seqlens_offsets_thd( segment_ids_q, segment_ids_kv, - lambda x, y: jnp.equal(x, y) * x, - ) - # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied - attn_mask = segment_mask - if attn_mask_type.is_bottom_right(): - run_length_out_q = run_length_fill(segment_ids_q) - run_length_out_kv = run_length_fill(segment_ids_kv) - # Example for brcm: - # run_length_out_q: [3 3 3 0 4 4 4 4] - # segment_pos_q: [0 1 2 3 0 1 2 3] - # segment_ids_q: [1 1 1 0 2 2 2 2] - # run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10] - # segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9] - # segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2] - # brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] - # [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] - # [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]] - # attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] - # [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]] - bottom_right_causal_mask = make_attention_mask( - run_length_out_q - segment_pos_q, - run_length_out_kv - segment_pos_kv, - jnp.less_equal, - ) - attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask) - elif attn_mask_type.is_causal(): - causal_mask = make_attention_mask( - segment_pos_q, - segment_pos_kv, - jnp.greater_equal, - ) - attn_mask = jnp.logical_and(segment_mask, causal_mask) - - attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) - q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( - attn_mask_with_id, max_segments_per_seq + segment_pos_q, + segment_pos_kv, + attn_mask_type, + max_segments_per_seq, ) - return q_seqlen, kv_seqlen, q_offset, kv_offset def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type): diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 92e67ac191..76f2d92891 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -145,19 +145,28 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; + auto q_shape = is_ragged + ? std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim} + : std::vector{input_batch, q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto k_shape = is_ragged + ? std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim} + : std::vector{input_batch, kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + auto v_shape = is_ragged + ? std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim} + : std::vector{input_batch, kv_max_seqlen, num_gqa_groups, v_head_dim}; auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + auto o_shape = is_ragged ? std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim} + : std::vector{input_batch, q_max_seqlen, attn_heads, v_head_dim}; auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); // F16 doesn't use this tensor auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto o_tensor = TensorWrapper(nullptr, o_shape, dtype); auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); @@ -168,7 +177,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; size_t min_num_segments = input_batch; @@ -191,9 +199,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(), - nullptr); + scaling_factor, dropout_probability, qkv_layout, nvte_get_q_format(qkv_layout), + NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -257,7 +265,8 @@ static void FusedAttnForwardImpl( /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 - auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim}; + auto o_shape = is_ragged ? std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim} + : std::vector{input_batch, q_max_seqlen, attn_heads, v_head_dim}; auto o_tensor = TensorWrapper(output, o_shape, dtype); /* Prepare RNG state */ @@ -285,9 +294,15 @@ static void FusedAttnForwardImpl( void *q_ptr = q; void *k_ptr = k; void *v_ptr = v; - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + auto q_shape = is_ragged + ? std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim} + : std::vector{input_batch, q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = is_ragged + ? std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim} + : std::vector{input_batch, kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = is_ragged + ? std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim} + : std::vector{input_batch, kv_max_seqlen, num_gqa_groups, v_head_dim}; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { // QKV packed in q: [batch*seqlen, 3, heads, dim] @@ -328,8 +343,9 @@ static void FusedAttnForwardImpl( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream); + scaling_factor, dropout_probability, qkv_layout, nvte_get_q_format(qkv_layout), + NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -418,17 +434,26 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; + auto q_shape = is_ragged + ? std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim} + : std::vector{input_batch, q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto k_shape = is_ragged + ? std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim} + : std::vector{input_batch, kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + auto v_shape = is_ragged + ? std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim} + : std::vector{input_batch, kv_max_seqlen, num_gqa_groups, v_head_dim}; auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim}; + auto output_shape = is_ragged + ? std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim} + : std::vector{input_batch, q_max_seqlen, attn_heads, v_head_dim}; auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); @@ -443,7 +468,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( TensorWrapper query_workspace_tensor; - auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; size_t min_num_segments = input_batch; @@ -469,18 +493,19 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, false, - query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, nvte_get_q_format(qkv_layout), + nvte_get_q_format(qkv_layout), qkv_layout, NVTE_QKV_Format_NOT_SET, NVTE_QKV_Format_NOT_SET, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, false, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -503,7 +528,9 @@ static void FusedAttnBackwardImpl( FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim}; + auto output_shape = is_ragged + ? std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim} + : std::vector{input_batch, q_max_seqlen, attn_heads, v_head_dim}; auto output_tensor = TensorWrapper(output, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); @@ -530,7 +557,7 @@ static void FusedAttnBackwardImpl( bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); - /* Call the underly NVTE API */ + /* Call the underlying NVTE API */ // Prepare Q, K, V pointers and shapes based on layout void *q_ptr = q; void *k_ptr = k; @@ -538,9 +565,15 @@ static void FusedAttnBackwardImpl( void *dq_ptr = dq; void *dk_ptr = dk; void *dv_ptr = dv; - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + auto q_shape = is_ragged + ? std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim} + : std::vector{input_batch, q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = is_ragged + ? std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim} + : std::vector{input_batch, kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = is_ragged + ? std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim} + : std::vector{input_batch, kv_max_seqlen, num_gqa_groups, v_head_dim}; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { // QKV packed in q: [batch*seqlen, 3, heads, dim] @@ -596,17 +629,18 @@ static void FusedAttnBackwardImpl( } } - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dsoftmax_offset_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream); + nvte_fused_attn_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), + dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, nvte_get_q_format(qkv_layout), + nvte_get_q_format(qkv_layout), qkv_layout, NVTE_QKV_Format_NOT_SET, NVTE_QKV_Format_NOT_SET, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_input_tensors); } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index ecf3af2bf0..d1c77b2277 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,6 +29,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensorStorage, prepare_for_saving, @@ -36,7 +37,6 @@ ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.constants import ( - TE_DType, QKVLayouts, dist_group_type, ) @@ -72,6 +72,7 @@ print_quantizers, ConvertTHDtoBSHD, ConvertBSHDtoTHD, + mxfp8_quantize_fast_path, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, @@ -193,15 +194,27 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - q_fp8, k_fp8, v_fp8 = combine_and_quantize( - qkv_layout, query_layer, key_layer, value_layer, quantizer + # always in sbhd_sbhd_sbhd shape at this point + q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( + qkv_layout, + query_layer, + key_layer, + value_layer, + quantizer, + keep_same_data_and_scale_inv_format=True, ) tensors = combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype ) + if isinstance(quantizer, MXFP8Quantizer): + # bhsd_bhsd_bhsd after combine_and_quantize; permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: - t_fp8 = quantizer(tensor1) - tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + if quantizer is not None: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) else: tensors = (tensor1, tensor2, tensor3) ctx.quantizer = quantizer @@ -213,16 +226,28 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou def backward(ctx, grad1, grad2, grad3): # pylint: disable=missing-function-docstring if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: - dt_fp8 = ctx.quantizer(grad1) - tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + if ctx.quantizer is not None: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + else: + tensors = grad1, grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] - dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( - ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer + # always in sbhd_sbhd_sbhd shape at this point + dq_fp8, dk_fp8, dv_fp8, new_qkv_layout, _ = combine_and_quantize( + ctx.qkv_layout, + query_grad, + key_grad, + value_grad, + ctx.quantizer, + keep_same_data_and_scale_inv_format=True, ) tensors = combine_and_dequantize( - ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype + new_qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) + if isinstance(ctx.quantizer, MXFP8Quantizer): + # bhsd_bhsd_bhsd after combine_and_quantize; permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None @@ -317,6 +342,111 @@ def fast_setattr(self, name: str, value: Any) -> None: """Fast attribute set for non-parameter fields.""" self.__dict__[name] = value + def _use_varlen_sdpa( + self, + attn_mask_type: str, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], + window_size: Optional[Tuple[int, int]], + core_attention_bias_type: str, + alibi_slopes: Optional[torch.Tensor], + fp8: bool, + ) -> bool: + """Whether PyTorch SDPA can replace unfused attention without materializing masks.""" + if self.attention_type != "self": + return False + if attn_mask_type != "padding_causal": + return False + if window_size not in [None, (-1, 0), (-1, -1)]: + return False + if attn_mask_type == "padding_causal" and attention_mask is None: + return False + if isinstance(attention_mask, tuple): + return False + return ( + core_attention_bias_type == "no_bias" + and self.attention_dropout.p == 0.0 + and alibi_slopes is None + and self.softmax_type == "vanilla" + and not self.return_max_logit + and not fp8 + ) + + def _format_context( + self, + context_layer: torch.Tensor, + q_format: str, + max_seqlen_q: int, + batch_size: int, + cu_seqlens_q: Optional[torch.Tensor], + ) -> torch.Tensor: + """Convert context from [b, h, sq, d] to the requested output layout.""" + if q_format == "sbhd": + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + return context_layer.view(max_seqlen_q, batch_size, -1) + if q_format == "bshd": + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + return context_layer.view(batch_size, max_seqlen_q, -1) + if q_format == "thd": + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = ConvertBSHDtoTHD.apply(context_layer, cu_seqlens_q) + return context_layer.view(context_layer.shape[0], -1) + raise ValueError(f"Unsupported q_format = {q_format}!") + + def _forward_varlen_sdpa( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + batch_size: int, + max_seqlen_q: int, + cu_seqlens_q: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + scale: float, + ) -> torch.Tensor: + """Run causal self-attention without expanding padding masks to [b, 1, sq, sk].""" + context_layer = torch.zeros( + batch_size, + query_layer.size(2), + max_seqlen_q, + value_layer.size(3), + dtype=query_layer.dtype, + device=query_layer.device, + ) + + if attention_mask is not None: + seqlens_q = attention_mask.logical_not()[:, 0, 0, :].sum(dim=1) + else: + seqlens_q = torch.full( + (batch_size,), max_seqlen_q, dtype=torch.int64, device=query_layer.device + ) + + dropout_p = self.attention_dropout.p if self.training else 0.0 + with self.attention_dropout_ctx(): + for batch_id in range(batch_size): + seqlen_q = int(seqlens_q[batch_id].item()) + if seqlen_q == 0: + continue + query = query_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0) + key = key_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0) + value = value_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0) + context_layer[batch_id, :, :seqlen_q, :] = F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=dropout_p, + is_causal=True, + scale=scale, + ).squeeze(0) + + return self._format_context( + context_layer, + q_format, + max_seqlen_q, + batch_size, + cu_seqlens_q, + ) + def forward( self, _alibi_cache: Dict[str, Any], @@ -409,26 +539,9 @@ def forward( max_seqlen_kv, self.attention_type, ) - attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( - dpa_utils.get_full_mask( - max_seqlen_q, - max_seqlen_kv, - attn_mask_type=attn_mask_type, - attention_mask=attention_mask, - window_size=window_size, - attention_type=self.attention_type, - bottom_right_alignment=( - attn_mask_type not in ["causal", "padding_causal"] - if bottom_right_diagonal is None - else bottom_right_diagonal - ), - ) - ) - - batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 - # [b, np, sq, sk] + # [b, h, sq, sk] output_size = ( query_layer.size(1), query_layer.size(2), @@ -447,12 +560,47 @@ def forward( int(query_layer.shape[2] / value_layer.shape[2]), dim=2 ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) + scale = self.softmax_scale + if apply_qk_layer_scaling: + scale /= self.layer_number + + if self._use_varlen_sdpa( + attn_mask_type, + attention_mask, + window_size, + core_attention_bias_type, + alibi_slopes, + fp8, + ): + return self._forward_varlen_sdpa( + query_layer, + key_layer, + value_layer, + q_format, + batch_size, + max_seqlen_q, + cu_seqlens_q, + attention_mask, + self.softmax_scale, + ) - # preallocting result tensor: [b * np, sq, sk] + attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( + dpa_utils.get_full_mask( + max_seqlen_q, + max_seqlen_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + window_size=window_size, + attention_type=self.attention_type, + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), + ) + ) + + # preallocting result tensor: [b * h, sq, sk] matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], @@ -461,19 +609,16 @@ def forward( device=torch.cuda.current_device(), ) - scale = self.softmax_scale - if apply_qk_layer_scaling: - scale /= self.layer_number - if fp8: + # get fp8 recipe for DPA + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=S_quantizer.dtype, device="cuda" @@ -481,25 +626,50 @@ def forward( dP_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=dP_quantizer.dtype, device="cuda" ) + # disable swizzle for MXFP8Quantizer + for quantizer in [ + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ]: + if isinstance(quantizer, MXFP8Quantizer): + quantizer.optimize_for_gemm = False + quantizer.internal = False - if "2" in qkv_layout or "3" in qkv_layout: - qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) - qkv_layout = "_".join([qkv_format] * 3) + # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + query_layer, + key_layer, + value_layer, + QKV_quantizer, + "QKV_quantizer", + "sbhd_sbhd_sbhd", ) # quantize and dequantize dQKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + query_layer, + key_layer, + value_layer, + dQKV_quantizer, + "dQKV_quantizer", + "sbhd_sbhd_sbhd", ) - # Raw attention scores. [b * np, sq, sk] + # [sq, b, h, d] -> [sq, b * h, d] + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, h, d] -> [sk, b * h, d] + key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) + + # Raw attention scores. [b * h, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] beta=0.0, alpha=scale, ).view(*output_size) @@ -507,8 +677,8 @@ def forward( elif core_attention_bias_type == "pre_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" matmul_result = torch.bmm( - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] ) matmul_result = matmul_result.view(*output_size) + core_attention_bias matmul_result *= scale @@ -533,8 +703,8 @@ def forward( ) matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] beta=0.0, alpha=scale, ) @@ -551,13 +721,13 @@ def forward( # max attention score max_logit = None if self.return_max_logit: - # matmul_result [b, np, sq, dk], max_logit [np] + # matmul_result [b, h, sq, dk], max_logit [h] max_logit = matmul_result if attn_mask_type != "no_mask": max_logit = self.mask_func(matmul_result, attention_mask) max_logit = torch.amax(max_logit, dim=(0, 2, 3)) - # add attention sink to the last column: [b, np, sq, sk+1] + # add attention sink to the last column: [b, h, sq, sk+1] if self.softmax_type != "vanilla": matmul_result = torch.cat( [ @@ -582,7 +752,7 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) - # remove attention sink: [b, np, sq, sk] + # remove attention sink: [b, h, sq, sk] if self.softmax_type != "vanilla": attention_probs = attention_probs[..., :-1] @@ -592,7 +762,7 @@ def forward( attention_probs = self.attention_dropout(attention_probs) # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] + # [sk, b, h, d] --> [b, h, sq, d] output_size = ( value_layer.size(1), value_layer.size(2), @@ -600,10 +770,10 @@ def forward( value_layer.size(3), ) - # change view [sk, b * np, hn] + # change view [sk, b * h, d] value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] + # change view [b * h, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) if fp8: @@ -612,37 +782,37 @@ def forward( attention_probs, None, None, S_quantizer, "S_quantizer", None ) - # matmul: [b * np, sq, hn] + # matmul: [b * h, sq, d] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] + # change view [b, h, sq, d] context_layer = context_layer.view(*output_size) if q_format == "sbhd": - # [b, np, sq, hn] --> [sq, b, np, hn] + # [b, h, sq, d] --> [sq, b, h, d] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - context_layer = context_layer.view(seqlen, batch_size, -1) + # [sq, b, h, d] --> [sq, b, hd] + context_layer = context_layer.view(max_seqlen_q, batch_size, -1) if q_format == "bshd": - # [b, np, sq, hn] --> [b, sq, np, hn] + # [b, h, sq, d] --> [b, sq, h, d] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # [b, sq, np, hn] --> [b, sq, hp] - context_layer = context_layer.view(batch_size, seqlen, -1) + # [b, sq, h, d] --> [b, sq, hd] + context_layer = context_layer.view(batch_size, max_seqlen_q, -1) if q_format == "thd": - # [b, np, sq, hn] --> [b, sq, np, hn] + # [b, h, sq, d] --> [b, sq, h, d] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # [b, sq, np, hn] --> [tq, np, hn] + # [b, sq, h, d] --> [tq, h, d] context_layer = ConvertBSHDtoTHD.apply( context_layer, cu_seqlens_q, ) - # [tq, np, hn] --> [tq, hp] + # [tq, h, d] --> [tq, hd] context_layer = context_layer.view(context_layer.shape[0], -1) if fp8: @@ -936,14 +1106,14 @@ def forward( batch_size * context_len, ) - use_flash_attn_4 = False - if flash_attention_backend is not None and flash_attention_backend > PkgVersion("4.0.0b"): - use_flash_attn_4 = True - use_flash_attn_3 = False - if flash_attention_backend is not None and PkgVersion( - "3.0.0b" - ) < flash_attention_backend < PkgVersion("4.0.0"): - use_flash_attn_3 = True + # FA4 prereleases such as 4.0.0b8 sort below 4.0.0, so key off the major + # version instead of a stable-version range check when selecting the API. + use_flash_attn_4 = ( + flash_attention_backend is not None and flash_attention_backend.major == 4 + ) + use_flash_attn_3 = ( + flash_attention_backend is not None and flash_attention_backend.major == 3 + ) if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): @@ -1254,21 +1424,26 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - # input types are inferred from the real data while output types are controlled by fp8_output - # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) + # qkv_layout may change due to MXFP8 quantization + # o_format should stay the same as original q_format + original_qkv_layout = qkv_layout + _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) + + # input types are inferred from real data while output types are controlled by fp8_output + # fp8_output should be set upstream assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output - # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) - # whether bwd kernel in FP8: + # whether fwd kernel will be run in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel will be run in FP8: is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # get nominal data type for out @@ -1277,16 +1452,20 @@ def forward( out_nominal_dtype = q.dtype max_logit = None + qkv_scale_inv_format = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E4M3 + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; + # dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer, used_in_backward=is_training + ) # print quantizers print_quantizers( @@ -1304,6 +1483,7 @@ def forward( # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *_ = fused_attn_fwd( is_training, max_seqlen_q, @@ -1326,6 +1506,8 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, + qkv_scale_inv_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1336,20 +1518,34 @@ def forward( cuda_graph=is_graph_capturing(), ) - # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ - out = out_ - - if isinstance(out_, Float8Tensor): - if not is_output_fp8 or not is_bwd_fp8: - out = out_.dequantize().view(out_.shape) - else: - if is_output_fp8 or ( + out_f16 = out_ + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if isinstance(out_, QuantizedTensorStorage): + if not is_output_fp8 or bwd_requires_o_f16: + out_f16 = out_.dequantize().view(out_.shape) + else: + if is_output_fp8 or bwd_requires_o_fp8: out_fp8 = O_quantizer(out_) # print quantizers @@ -1365,21 +1561,25 @@ def forward( ) # return appropriate tensors - out_ret = out_fp8 if is_output_fp8 else out + out_ret = out_fp8 if is_output_fp8 else out_f16 - # save appropriate tensors + # save q, k, v, o tensors fp8_tensors = (None, None, None, None) - qkvo_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) if is_bwd_fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) - qkvo_tensors = (None, None, None, out) - else: + f16_tensors = (None, None, None, out_f16) + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: if is_input_fp8: q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) - qkvo_tensors = (q, k, v, out) + f16_tensors = (q, k, v, out_f16) else: # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( @@ -1404,6 +1604,8 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, + None, attn_bias_type, attn_mask_type, softmax_type, @@ -1414,10 +1616,10 @@ def forward( return_max_logit, is_graph_capturing(), ) - out = out_ + out_f16 = out_ out_ret = out_ fp8_tensors = (None, None, None, None) - qkvo_tensors = (q, k, v, out) + f16_tensors = (q, k, v, out_f16) nvtx_range_pop(f"{nvtx_label}") @@ -1431,7 +1633,7 @@ def forward( if ctx.fp8: tensor_list = fp8_tensors else: - tensor_list = [q, k, v, out] + tensor_list = [q, k, v, out_f16] mark_activation_offload(*tensor_list) mark_activation_offload(*aux_ctx_tensors) @@ -1441,7 +1643,7 @@ def forward( tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, - *qkvo_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, @@ -1489,9 +1691,17 @@ def forward( ctx.qkv_layout = reload_layout[:-1] else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout + ctx.o_format = o_format + ctx.qkv_scale_inv_format = qkv_scale_inv_format + # dqkv should have the same layout as the original qkv + ctx.dqkv_layout = original_qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type @@ -1511,14 +1721,24 @@ def forward( def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring - # d_out is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): - d_out = ctx.dO_quantizer(d_out) - if not ctx.use_FAv2_bwd: - d_out._data = d_out._data.contiguous() - elif not ctx.use_FAv2_bwd: + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() + d_out_fp8 = None + do_format = ctx.o_format + do_scale_inv_format = None + if ctx.fp8: + if isinstance(d_out, QuantizedTensorStorage): + d_out_fp8 = d_out + elif isinstance(ctx.dO_quantizer, MXFP8Quantizer): + (d_out_fp8,), do_scale_inv_format = mxfp8_quantize_fast_path( + [(d_out, ctx.dO_quantizer)], + do_format, + ) + else: + d_out_fp8 = ctx.dO_quantizer(d_out) ( q_fp8, k_fp8, @@ -1579,14 +1799,6 @@ def backward(ctx, d_out, *_args): dqkv_nominal_dtype = ctx.nominal_dtype if ctx.fp8: - # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - if ctx.is_output_fp8: - d_out_fp8 = d_out - else: - d_out_fp8 = ctx.dO_quantizer(d_out) - # print quantizers print_quantizers( "FusedAttnFunc.backward >> before: ", @@ -1599,27 +1811,31 @@ def backward(ctx, d_out, *_args): ctx.dP_quantizer, ) - # get tex.DType for dq, dk, dv data - dqkv_te_dtype = d_out_fp8._fp8_dtype - - # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, + # DelayedScaling/Float8CurrentScaling/MXFP8BlockScaling: + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16, # fp8_dtype = tex.DType.kFloat8E4M3 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 - # out_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # DelayedScaling: + # out_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # - # dq_, dk_, dv_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # dq_, dk_, dv_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_ = ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8 - ) + # Float8CurrentScaling: + # out_: NVTE_DPA_FP8CS_O_in_F16=1: + # torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # NVTE_DPA_FP8CS_O_in_F16=0: + # Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: + # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_ = out + if ctx.fp8_recipe.mxfp8(): + out_ = out + aux_ctx_tensors.append(d_out) dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1631,7 +1847,6 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1643,6 +1858,11 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + do_format, + ctx.dqkv_layout, + ctx.qkv_scale_inv_format, + do_scale_inv_format, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1651,23 +1871,22 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) - # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ - is_float8tensor = isinstance(dq_, Float8Tensor) - if is_float8tensor and not ctx.is_input_fp8: + is_quantized_tensor = isinstance(dq_, QuantizedTensorStorage) + if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( - ctx.qkv_layout, + ctx.dqkv_layout, dq_, dk_, dv_, src_nominal_dtype=dq_.dtype, ) - if not is_float8tensor and ctx.is_input_fp8: + if not is_quantized_tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv = combine_and_quantize( - ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer + dq, dk, dv, _, _ = combine_and_quantize( + ctx.dqkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) # print quantizers @@ -1684,7 +1903,6 @@ def backward(ctx, d_out, *_args): else: if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) - dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, @@ -1697,7 +1915,6 @@ def backward(ctx, d_out, *_args): out, d_out, dqkv_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1709,6 +1926,11 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + do_format, + ctx.dqkv_layout, + None, + None, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1873,9 +2095,9 @@ def forward( fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend ), "No fused attention backend supports this input combination!" assert all( - x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, QuantizedTensorStorage) for x in [query_layer, key_layer, value_layer] - ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors." + ), "FusedAttention only supports FP16 and BF16 data types, or QuantizedTensors." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "FusedAttention only supports CUDA tensors." @@ -1981,7 +2203,7 @@ def forward( " with FP8!" ) if fp8_recipe.float8_current_scaling() and context_parallel: - all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + all_quantizers = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) for q in all_quantizers: if isinstance(q, Float8CurrentScalingQuantizer): q.with_amax_reduction = True diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..1313119817 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -22,13 +22,11 @@ ) from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.graph import is_graph_capturing -from transformer_engine.pytorch.constants import ( - dist_group_type, - TE_DType, -) +from transformer_engine.pytorch.constants import dist_group_type from transformer_engine.pytorch.distributed import ( get_distributed_world_size, get_distributed_rank, @@ -48,6 +46,7 @@ combine_and_quantize, combine_and_dequantize, print_quantizers, + mxfp8_quantize_fast_path, ) _cu_seqlens_info_with_cp_cache = {} @@ -59,6 +58,18 @@ _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +def get_bsh_dims(tensor_format): + """Get batch dimension and sequence dimension from tensor format""" + if tensor_format in ["bshd", "sbhd", "bhsd"]: + batch_dim = tensor_format.index("b") + seq_dim = tensor_format.index("s") + head_dim = tensor_format.index("h") + else: # tensor_format == "thd" + batch_dim = seq_dim = tensor_format.index("t") + head_dim = tensor_format.index("h") + return batch_dim, seq_dim, head_dim + + def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm ): @@ -237,10 +248,10 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication before attention compute.""" # [cp, b, s, h//cp, d] -> [b, cp, s, h//cp, d] - # or [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] + # [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] x = x.movedim(0, seq_dim).contiguous() # [b, cp, s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] - # or [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) # reorder the sequence chunks x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) @@ -251,12 +262,12 @@ def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_siz def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication after attention compute.""" # [b, cp*2, s//2, h//cp, d] -> [cp*2, b, s//2, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.movedim(seq_dim, 0).contiguous() # reorder the sequence chunks x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) # [cp*2, b, s//2, h//cp, d] -> [cp, 2, b, s//2, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -410,15 +421,32 @@ def flash_attn_a2a_communicate( cp_stream: torch.cuda.Stream, before_attn: bool, qkv_format: str = "bshd", - cu_seqlens_padded: torch.Tensor = None, + cu_seqlens_q_padded: torch.Tensor = None, + cu_seqlens_kv_padded: torch.Tensor = None, + a2a_input_names: List[str] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """A2A communication for context parallelism.""" - - assert ( - qkv_format != "thd" or cu_seqlens_padded is not None - ), "cu_seqlens_padded is required for THD format!" + assert a2a_input_names in [ + ["q", "k", "v"], + ["out"], + ["dout"], + ["dq", "dk", "dv"], + ], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" + if a2a_input_names in [["out"], ["dout"]]: + assert qkv_format != "thd" or cu_seqlens_q_padded is not None, ( + f"flash_attn_a2a_communicate requires cu_seqlens_q_padded for {a2a_input_names} with" + " THD format!" + ) + if a2a_input_names in [["q", "k", "v"], ["dq", "dk", "dv"]]: + assert qkv_format != "thd" or ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), ( + "flash_attn_a2a_communicate requires cu_seqlens_q_padded and cu_seqlens_kv_padded for" + f" {a2a_input_names} with THD format!" + ) a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + _, _, head_dim = get_bsh_dims(qkv_format) if before_attn: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -430,18 +458,24 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # reorder the sequence chunks x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] + # [b, h//cp, cp*2, s//2, d] -> [b, h//cp, cp*s, d] a2a_outputs[i - 2] = x.view( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) else: # qkv_format == "thd" - # [cp, t, np//cp, hn] -> [cp*t, np//cp, hn] + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i - 2] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) + # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks a2a_outputs[i - 2] = reorder_seq_chunks_after_a2a_before_attn_thd( @@ -450,14 +484,21 @@ def flash_attn_a2a_communicate( if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] - # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] - # or [t, np, hn] -> [t, cp, np//cp, hn] - x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] - # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] - # or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn] - a2a_inputs[i] = x.movedim(-3, 0).contiguous() + # [b, s, h, d] -> [b, s, cp, h//cp, d] + # [s, b, h, d] -> [s, b, cp, h//cp, d] + # [b, h, s, d] -> [b, cp, h//cp, s, d] + # [t, h, d] -> [t, cp, h//cp, d] + x = x.view( + *x.shape[:head_dim], + cp_size, + x.shape[head_dim] // cp_size, + *x.shape[head_dim + 1 :], + ) + # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] + # [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] + # [b, cp, h//cp, s, d] -> [cp, b, h//cp, s, d] + # [t, cp, h//cp, d] -> [cp, t, h//cp, d] + a2a_inputs[i] = x.movedim(head_dim, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -467,30 +508,57 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - if qkv_format in ["bshd", "sbhd"]: - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + if qkv_format in ["bshd", "sbhd", "bhsd"]: + # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [b, h//cp, cp*s, d] -> [b, h//cp, cp*2, s//2, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) else: # qkv_format == "thd" + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) # reorder the sequence chunks x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) - # [cp*t, np//cp, hn] -> [cp, t, np//cp, hn] + # [cp*t, h//cp, d] -> [cp, t, h//cp, d] a2a_inputs[i] = x.view(cp_size, -1, *x.shape[-2:]) if i > 1: with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] - # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] - # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] - # or [t, cp, np//cp, hn] -> [t, np, hn] + # [cp, 2, b, s//2, h//cp, d] -> [2, b, s//2, cp, h//cp, d] + # [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [cp, 2, b, h//cp, s//2, d] -> [2, b, cp, h//cp, s//2, d] + # [cp, t, h//cp, d] -> [t, cp, h//cp, d] + tmp_list = list(qkv_format) + if "t" not in qkv_format: + tmp_list.insert(0, "2") + tmp_list.insert(0, "c") + tmp_format = "".join(tmp_list) + head_dim_ = tmp_format.index("h") - 1 + tmp_list.insert(head_dim_, tmp_list.pop(0)) + x = x.movedim(0, head_dim_) + # [2, b, s//2, cp, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] + # [2, s//2, b, cp, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [2, b, cp, h//cp, s//2, d] -> [b, cp, h//cp, 2, s//2, d] + # [t, cp, h//cp, d] -> [t, cp, h//cp, d] + if "t" not in qkv_format: + tmp_format = "".join(tmp_list) + seq_dim_ = tmp_format.index("s") - 1 + tmp_list.insert(seq_dim_, tmp_list.pop(0)) + x = x.movedim(0, seq_dim_) + else: + seq_dim_ = 0 + x = x.contiguous() + # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] + # [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] + # [b, cp, h//cp, 2, s//2, d] -> [b*h, s, d] + # [t, cp, h//cp, d] -> [t, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs @@ -775,13 +843,16 @@ def cp_p2p_fwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step, O_quantizer_per_step, rank, @@ -867,11 +938,18 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_kv_padded_ = cu_seqlens_kv_padded fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout + qkv_scale_inv_format = None if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + else: + q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -888,7 +966,8 @@ def cp_p2p_fwd_fused_attn( fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs, @@ -897,10 +976,14 @@ def cp_p2p_fwd_fused_attn( **fp8_meta_kwargs, return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, ) if fp8: - softmax_lse_per_step, _, rng_states = aux_ctx_tensors + if qkv_layout != "t3hd": + softmax_lse_per_step, rng_states = aux_ctx_tensors + else: + softmax_lse_per_step, _, rng_states = aux_ctx_tensors else: softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors attn_bias = rest[0] if len(rest) > 0 else None @@ -1065,15 +1148,19 @@ def cp_p2p_bwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, + do_format, + dqkv_layout, attn_mask_type, attn_bias_type, deterministic, fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, S_quantizer, dP_quantizer_per_step, dQKV_quantizer_per_step, + QKV_quantizer_per_step, + dO_quantizer_per_step, q_part, k_part, v_part, @@ -1083,11 +1170,14 @@ def cp_p2p_bwd_fused_attn( ): """Per-tile backward call of CP P2P with FusedAttention backend""" if fp8: - aux_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - step - 1], - ] + if qkv_layout == "t3hd": + aux_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] else: aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] @@ -1106,11 +1196,14 @@ def cp_p2p_bwd_fused_attn( elif section == "upper-triangle": q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] if fp8: - aux_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - step - 1], - ] + if qkv_layout == "t3hd": + aux_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] else: aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] @@ -1122,17 +1215,37 @@ def cp_p2p_bwd_fused_attn( aux_tensors += [attn_biases[cp_size - step - 1]] fp8_meta_kwargs = {} + qkv_scale_inv_format = None + do_scale_inv_format = None if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip( - [q_fp8, kv_fp8, kv_fp8], - [q_part, k_part, v_part], + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + else: + q_part, k_part, v_part, qkv_layout, qkv_scale_inv_format = combine_and_quantize( + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer_per_step, + used_in_forward=False, + used_in_backward=True, + ) + if not fp8_recipe.mxfp8(): + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + else: + aux_tensors.append(dout_part) + (dout_part,), do_scale_inv_format = mxfp8_quantize_fast_path( + [(dout_part, dO_quantizer_per_step)], + do_format, ) - ] - if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): - out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) - dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step @@ -1148,7 +1261,6 @@ def cp_p2p_bwd_fused_attn( out_part, dout_part, bwd_nominal_dtype, - bwd_output_te_dtype, aux_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded_, @@ -1156,10 +1268,15 @@ def cp_p2p_bwd_fused_attn( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, + do_format=do_format, + dqkv_layout=dqkv_layout, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, deterministic=deterministic, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, + do_scale_inv_format=do_scale_inv_format, **fp8_meta_kwargs, ) @@ -1313,16 +1430,15 @@ def forward( ) # set up attention args - enable_mla = k.shape[-1] != v.shape[-1] - causal = "causal" in attn_mask_type - if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - + causal = "causal" in attn_mask_type + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = q.shape[:-1] + v.shape[-1:] batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None @@ -1337,13 +1453,10 @@ def forward( else: cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size - max_seqlen_q = max_seqlen_q // cp_size max_seqlen_kv = max_seqlen_kv // cp_size cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] - - fused_attn_backend = None amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] O_quantizer_per_step = [None for _ in range(cp_size)] @@ -1352,17 +1465,17 @@ def forward( assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." fwd_nominal_dtype = q.dtype - is_input_fp8 = isinstance(q, Float8Tensor) + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output - is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + _use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1"))) + is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; # may be different from fp8_meta["recipe"] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - ( QKV_quantizer, O_quantizer, @@ -1370,43 +1483,58 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) - q_f16 = None + # q, k, v a2a: gather s and split h + # FP8DS/CS: Float8Tensor -> torch.uint8 -> Float8Tensor + # MXFP8/F16: fwd_nominal_dtype q_fp8, k_fp8, v_fp8 = (None, None, None) - # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if fp8 and is_input_fp8: - QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = (q._data, k._data, v._data) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + [q, k, v], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + cp_group_a2a, + cp_stream, + True, + qkv_format=qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["q", "k", "v"], ) - if fp8 and is_input_fp8: + if fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) ] q, k, v = q_fp8, k_fp8, v_fp8 + post_a2a_o_shape = q.shape[:-1] + v.shape[-1:] # convert qkv to the right type + q_f16 = None + fused_attn_backend = None if fp8: assert use_fused_attention, "FP8 is only supported with Fused Attention!" fused_attn_backend = FusedAttnBackend["FP8"] - if is_input_fp8: # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - else: + elif not fp8_recipe.mxfp8(): # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # print quantizers @@ -1427,10 +1555,11 @@ def forward( # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + S_quantizer_per_step[i] = S_quantizer.copy() if S_quantizer is not None else None O_quantizer_per_step[i] = O_quantizer.copy() - O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not fp8_recipe.mxfp8(): + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype @@ -1482,7 +1611,6 @@ def forward( attn_bias_ = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) @@ -1557,17 +1685,22 @@ def forward( # synchronize fwd results correction across steps fwd_results_correction_done = torch.cuda.Event() + # q, k, v, o: + # causal: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # non-causal: [b, s, h, d] or [s, b, h, d] p2p_comm_buffers = [None for _ in range(cp_size)] k_shape = k.shape k_numel = k.numel() v_shape = v.shape + o_shape = q.shape[:-1] + v.shape[-1:] p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] # P2P communication and compute: each rank has cp_size steps - # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype - # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 + # MXFP8/F16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype + # FP8DS/CS attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None + o_format = qkv_format for i in range(cp_size + 1): if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): @@ -1621,13 +1754,16 @@ def forward( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step[i], O_quantizer_per_step[i], rank, @@ -1775,8 +1911,8 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: - # [b, h, sq, 1] -> [b, h, sq] or - # [t, h, 1] -> [t, np] + # [b, h, sq, 1] -> [b, h, sq] + # [t, h, 1] -> [t, h] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: softmax_lse_per_step[i - 1] = ( @@ -1788,21 +1924,16 @@ def forward( out_per_step[i - 1] = out_per_step[i - 1].dequantize( dtype=torch.float32 ) - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": - if enable_mla: - out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - v_shape - ) + if fp8: + out = torch.zeros_like(out_per_step[0]).view(o_shape) else: - # MHA or GQA - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( - q.shape - ) + out = torch.zeros(o_shape, dtype=q.dtype, device=q.device) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -1842,7 +1973,7 @@ def forward( # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: if i == 0: out = flash_attn_fwd_out_correction_init( out_per_step[0], @@ -1850,10 +1981,7 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - if enable_mla: - out = out.view(v_shape) - else: - out = out.view(q.shape) + out = out.view(o_shape) else: flash_attn_fwd_out_correction( out.view(*out_per_step[i].shape), @@ -1862,7 +1990,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1873,7 +2001,7 @@ def forward( softmax_lse_in_packed_format, ) else: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: flash_attn_fwd_second_half_out_correction( out, out_per_step[i], @@ -1881,7 +2009,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1891,35 +2019,31 @@ def forward( True, softmax_lse_in_packed_format, ) - - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - ctx.batch_size = out.shape[0] - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - ctx.batch_size = out.shape[1] + out = out.view(post_a2a_o_shape) + out_part = out.to(fwd_nominal_dtype) if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False + out, + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + cp_group_a2a, + cp_stream, + False, + qkv_format=o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["out"], ) - if use_fused_attention: - if qkv_format == "bshd": - # [b*s, h, d] -> [b, s, h, d] - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, h, d] -> [s, b, h, d] - out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + out = out.view(orig_o_shape) if return_max_logit: max_logit = flash_attn_a2a_communicate_softmax_offset( max_logit, 0, cp_size_a2a, cp_group_a2a, cp_stream, False ) - elif not use_fused_attention: - out = out.view(-1, *out.shape[-2:]) # update FP8 quantizers: amax across cp_size steps - if fp8 and use_fused_attention: + if fp8 and use_fused_attention and not fp8_recipe.mxfp8(): amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) O_quantizer.amax.copy_(amax_cp_fwd[1]) @@ -1940,20 +2064,21 @@ def forward( # prepare for return and ctx saves out_fp8 = None out_f16 = out.to(fwd_nominal_dtype) - if fp8 and ( - is_output_fp8 - or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + if (fp8 and is_output_fp8) or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() ): out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 ctx.layer_number = layer_number ctx.fp8_recipe = fp8_recipe - ctx.fp8 = fp8 and is_bwd_fp8 + ctx.fp8 = is_bwd_fp8 kv_fp8 = None kv = p2p_comm_buffers[-1] - if fp8: + if fp8 and not fp8_recipe.mxfp8(): q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8], [q, kv]) @@ -1961,17 +2086,28 @@ def forward( # q, kv, out fp8_tensors = (None, None, None) f16_tensors = (None, None, None) + out_f16 = out_part if ctx.fp8: # fwd: fp8, bwd: fp8, save all fp8 fp8_tensors = (q_fp8, kv_fp8, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: f16_tensors = (None, None, out_f16) - elif fp8 and is_input_fp8: + elif fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) + elif fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): # fwd: fp8, bwd: f16, save all f16 # dequantize fp8 inputs q_f16 = q_fp8.dequantize() kv_f16 = kv_fp8.dequantize() f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8 and is_input_fp8 and fp8_recipe.mxfp8(): + # fwd: fp8, bwd: f16, save all f16 + # there is already an F16 version of the inputs + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q, k, v) + kv_f16 = torch.cat((k_f16.view(-1), v_f16.view(-1)), dim=-1) + f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8 and not is_input_fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) elif fp8: # fwd: fp8, bwd: f16, save all f16 # inputs are already in f16 @@ -2009,7 +2145,6 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape @@ -2022,12 +2157,19 @@ def forward( ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 - ctx.enable_mla = enable_mla + ctx.orig_q_shape = orig_q_shape + ctx.orig_k_shape = orig_k_shape + ctx.orig_v_shape = orig_v_shape + ctx.orig_o_shape = orig_o_shape + ctx.post_a2a_o_shape = post_a2a_o_shape ctx.k_numel = k_numel ctx.k_shape = k_shape ctx.v_shape = v_shape - + ctx.o_shape = o_shape + ctx.qkv_format = qkv_format + ctx.qkv_layout = qkv_layout ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -2036,14 +2178,14 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop(f"{nvtx_label}") - if return_max_logit: return out_ret, max_logit return out_ret @@ -2057,8 +2199,13 @@ def backward(ctx, dout, *_args): nvtx_range_push(f"{nvtx_label}") # dout is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): + # but in the case it's not, convert it to FP8 (except for MXFP8) before any operation + if ( + ctx.fp8 + and ctx.is_output_fp8 + and not isinstance(dout, QuantizedTensorStorage) + and not ctx.fp8_recipe.mxfp8() + ): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -2098,7 +2245,6 @@ def backward(ctx, dout, *_args): # set up attention args causal = "causal" in ctx.attn_mask_type seq_dim = None - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") @@ -2137,13 +2283,13 @@ def backward(ctx, dout, *_args): if ctx.softmax_lse_in_packed_format: softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() # [b, h, sq//2] -> [b, h, sq//2, 1] or - # [t//2, np] -> [t//2, h, 1] + # [t//2, h] -> [t//2, h, 1] softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse = softmax_lse.transpose(0, 1).contiguous() # [b, h, sq] -> [b, h, sq, 1] or - # [t, np] -> [t, h, 1] + # [t, h] -> [t, h, 1] softmax_lse.unsqueeze_(-1) # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 @@ -2158,28 +2304,29 @@ def backward(ctx, dout, *_args): buffer_dtype = torch.uint8 dq_buffer = None dout_fp8 = None - bwd_output_te_dtype = None dkv_buffer = None if ctx.fp8: - assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" fused_attn_backend = FusedAttnBackend["FP8"] - q, kv, out = ( - q_fp8._data, - kv_fp8._data, - ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8._data - ), - ) + if not ctx.fp8_recipe.mxfp8(): + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype # dout: torch.Tensor, dtype=torch.uint8 - if ctx.is_output_fp8: + if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout - else: + elif not ctx.fp8_recipe.mxfp8(): dout_fp8 = ctx.dO_quantizer(dout) - dout = dout_fp8._data + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data # print quantizers print_quantizers( @@ -2193,9 +2340,6 @@ def backward(ctx, dout, *_args): ctx.dP_quantizer, ) - # dout_fp8._fp8_dtype - bwd_output_te_dtype = ctx.dO_quantizer.dtype - # create buffers for reduction in float32 if ctx.fp8_recipe.delayed(): dq_buffer = torch.empty( @@ -2203,7 +2347,7 @@ def backward(ctx, dout, *_args): dtype=buffer_dtype, device=q.device, ) - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_buffer = torch.empty( q.shape, dtype=torch.float32, @@ -2217,7 +2361,7 @@ def backward(ctx, dout, *_args): ) dkv_recv_buffer = torch.empty_like(dkv_send_buffer) p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dkv_buffer = torch.zeros( kv.shape, dtype=torch.float32, @@ -2230,10 +2374,13 @@ def backward(ctx, dout, *_args): # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dP_quantizer_per_step[i] = ( + ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None + ) dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() - dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not ctx.fp8_recipe.mxfp8(): + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) @@ -2244,34 +2391,28 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # communicate for the 'a2a' part of 'a2a+p2p' + dout = dout.view(*ctx.orig_o_shape) if cp_size_a2a > 1: - if not ctx.use_fused_attention: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( cp_size_a2a, out.device ) - out, dout = flash_attn_a2a_communicate( - [out, dout], + dout = flash_attn_a2a_communicate( + dout, chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True, + qkv_format=ctx.qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["dout"], ) - - if ctx.enable_mla: - out = out.view(*ctx.v_shape) - dout = dout.view(*ctx.v_shape) - else: - # MHA or GQA - out = out.view(*q.shape) - dout = dout.view(*q.shape) + out = out.view(*ctx.o_shape) + dout = dout.view(*ctx.o_shape) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -2368,10 +2509,11 @@ def backward(ctx, dout, *_args): kv_fp8, ( out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or ctx.fp8_recipe.mxfp8() else out_fp8 ), - dout_fp8, + dout_fp8 if not ctx.fp8_recipe.mxfp8() else dout, softmax_lse, softmax_lse_, rng_states, @@ -2388,16 +2530,20 @@ def backward(ctx, dout, *_args): fused_attn_backend, ctx.softmax_scale, ctx.dropout_p, - qkv_layout, + ctx.qkv_layout, + ctx.qkv_format, + ctx.qkv_format, + ctx.qkv_layout, ctx.attn_mask_type, ctx.attn_bias_type, ctx.deterministic, ctx.fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, ctx.S_quantizer, dP_quantizer_per_step[i], dQKV_quantizer_per_step[i], + ctx.QKV_quantizer, + ctx.dO_quantizer, ] else: flash_attn_inputs = [ @@ -2471,7 +2617,7 @@ def backward(ctx, dout, *_args): if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8_recipe.delayed(): dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] # copy dq_ into the right buffer position @@ -2555,7 +2701,7 @@ def backward(ctx, dout, *_args): # dkv correction if ctx.fp8 and ctx.fp8_recipe.delayed(): dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] - elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + elif ctx.fp8 and (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()): dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] @@ -2645,9 +2791,10 @@ def backward(ctx, dout, *_args): # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: - amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + if not ctx.fp8_recipe.mxfp8(): + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) dq = dq_buffer if ctx.fp8_recipe.delayed(): @@ -2661,7 +2808,7 @@ def backward(ctx, dout, *_args): for x in [dq, dk, dv] ] dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.qkv_layout, dq, dk, dv, @@ -2670,7 +2817,7 @@ def backward(ctx, dout, *_args): ) dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dk = dkv[: ctx.k_numel].view(ctx.k_shape) dv = dkv[ctx.k_numel :].view(ctx.v_shape) @@ -2686,7 +2833,7 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _, _ = combine_and_quantize(ctx.qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8: # print quantizers @@ -2704,7 +2851,8 @@ def backward(ctx, dout, *_args): if cp_size_a2a > 1: if ctx.fp8 and ctx.is_input_fp8: dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv - dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) + if not ctx.fp8_recipe.mxfp8(): + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2714,16 +2862,22 @@ def backward(ctx, dout, *_args): ctx.cp_group_a2a, ctx.cp_stream, False, + qkv_format=ctx.qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["dq", "dk", "dv"], ) - if ctx.fp8 and ctx.is_input_fp8: + if ctx.fp8 and ctx.is_input_fp8 and not ctx.fp8_recipe.mxfp8(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) ] - if ctx.qkv_format == "bshd": - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif ctx.qkv_format == "sbhd": - dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + dq, dk, dv = [ + x.view(y) + for x, y in zip( + [dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape] + ) + ] if attn_dbias is not None: # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] @@ -2821,27 +2975,42 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + o_format = qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) - qkv_dtype = q.dtype - - causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type - assert not padding, f"{attn_mask_type} mask type is not supported!" - if use_fused_attention and causal and "bottom_right" not in attn_mask_type: - attn_mask_type = attn_mask_type + "_bottom_right" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" - assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert qkv_format != "thd", f"No support for cp_comm_type='all_gather' and {qkv_format=}." + assert ( + "padding" not in attn_mask_type + ), f"No support for cp_comm_type='all_gather' and {attn_mask_type=}." + assert ( + attn_bias_type == "no_bias" + ), f"No support for cp_comm_type='all_gather' and {attn_bias_type=}." assert ( - use_fused_attention or fa_utils.v2_3_plus - ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + window_size == (-1, 0) + or window_size == (-1, -1) + or use_fused_attention + or fa_utils.v2_3_plus + ), ( + "cp_comm_type='all_gather' only supports SWA through FusedAttention or FlashAttention" + f" >= 2.3. Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + ) + assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( + "cp_comm_type='all_gather' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q =" + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}." + ) flash_attn_fwd = None if not use_fused_attention: @@ -2874,14 +3043,6 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - seq_dim = qkv_format.index("s") - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" - max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) if use_fused_attention or qkv_format == "thd": @@ -2890,30 +3051,91 @@ def forward( cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) else: cu_seqlens_q_padded = None + if use_fused_attention and attn_mask_type == "causal": + attn_mask_type = attn_mask_type + "_bottom_right" + causal = "causal" in attn_mask_type - # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] - k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + # FP8 setup + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) + is_output_fp8 = fp8_output + _use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1"))) + is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + ( + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + fwd_nominal_dtype = q.dtype + q_fp8, k_fp8, v_fp8 = (q, k, v) if is_input_fp8 else (None, None, None) + q_f16, k_f16, v_f16 = (None, None, None) if is_input_fp8 else (q, k, v) + fused_attn_backend = None + fp8_meta_kwargs = {} + if fp8: + assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + if not is_input_fp8 and not fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer + elif use_fused_attention: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + orig_q_shape, _, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] + + # q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # reshape: split s + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) + # s dim first for all-gather + # [b, s, h, d]/[s, b, h, d] -> [s, b, h, d] + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] - # [s, b, h, d] -> [cp, s, b, h, d] + # gather along s: [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s:[cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # pick out specific chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # reshape/flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) cp_stream.wait_stream(torch.cuda.current_stream()) + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # k_ag: [cp*s, b, h, d] + # v_ag: [cp*s, b, h, d] + # out_f16: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + q_shape, k_shape, v_shape = q.shape, k.shape, v.shape + o_shape = q.shape[:-1] + v.shape[-1:] + out_f16 = torch.empty(o_shape, dtype=fwd_nominal_dtype, device=q.device) + # create two streams to resolve wave quantization issue of Flash Attn in each step flash_attn_streams = [torch.cuda.current_stream(), cp_stream] - + # prepare per-step tensors local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] kv_seq_range_per_step = [None, None] window_size_per_step = [None, None] @@ -2921,16 +3143,15 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - out = torch.empty_like(q) max_logit_per_step = [None, None] max_logit = None for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( local_seq_chunk_ids[i], @@ -2950,13 +3171,30 @@ def forward( cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( k.shape[1], max_seqlen_kv_, k.device ) - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] if use_fused_attention: + new_qkv_layout = qkv_layout + qkv_scale_inv_format = None + if fp8: + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + else: + q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = ( + combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) + ) ( out_per_step[i], - [softmax_lse_per_step[i], rng_states[i]], + aux_ctx_tensors, *max_logit_, ) = fused_attn_fwd( is_training, @@ -2964,14 +3202,15 @@ def forward( max_seqlen_kv_, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - qkv_dtype, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + q_part, + k_part, + v_part, + fwd_nominal_dtype, + fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -2980,9 +3219,20 @@ def forward( window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, + **fp8_meta_kwargs, ) + if fp8: + if qkv_layout != "t3hd": + softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *_ = aux_ctx_tensors if return_max_logit: max_logit_per_step[i] = max_logit_[0] + if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): + out_per_step[i] = out_per_step[i].dequantize(dtype=fwd_nominal_dtype) else: fa_forward_args_thd = get_fa_args( True, @@ -2999,9 +3249,9 @@ def forward( fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( - q_, - k_, - v_, + q_part, + k_part, + v_part, *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, @@ -3017,61 +3267,152 @@ def forward( if not use_flash_attn_3: rng_states[i] = fa_outputs[3] + # out_per_step[i]: fwd_nominal_dtype, [b, s//2, h, d] or [s//2, b, h, d] + # out_f16: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # max_logit_per_step[i]: torch.float32, [h] + # max_logit: torch.float32, [h] if return_max_logit and i == 0: max_logit = torch.clone(max_logit_per_step[0]) if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if qkv_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 1]) - elif qkv_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1]) + if o_format == "bshd": + out_f16[:, i - 1].copy_(out_per_step[i - 1]) + elif o_format == "sbhd": + out_f16[i - 1].copy_(out_per_step[i - 1]) if return_max_logit: max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) + + # all reduce max_logit across ranks if return_max_logit: torch.distributed.all_reduce( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - if use_fused_attention: - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - else: - out = out.view(-1, *out.shape[-2:]) + # out_f16: fwd_nominal_dtype + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + out_f16 = out_f16.view(orig_o_shape) - ctx.save_for_backward( - q, - k, - v, + # prepare for forward output and backward saves of out + out_fp8 = None + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if (fp8 and is_output_fp8) or bwd_requires_o_fp8: + out_fp8 = O_quantizer(out_f16) + out_ret = out_fp8 if is_output_fp8 else out_f16 + + # save tensors for backward + ctx.fp8 = is_bwd_fp8 + ctx.fp8_recipe = fp8_recipe + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + # True: q split along s; k/v with s first, i.e. [s, b, h, d] + # False: original [b, s, h, d] or [s, b, h, d] + ctx.qkv_reshaped = True + # no load-balance related token shuffling; original token order in q/k/v/out_f16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # out_f16/out_fp8: [b, s, h, d] or [s, b, h, d] + if ctx.fp8: + # q_fp8_save: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k_fp8_save: [s, b, h, d] + # v_fp8_save: [s, b, h, d] + q_fp8_save, k_fp8_save, v_fp8_save = None, None, None + if fp8_recipe.delayed() or fp8_recipe.float8_current_scaling(): + q_fp8_save = Float8Tensor.make_like(q_fp8, data=q, dtype=fwd_nominal_dtype) + k_fp8_save = Float8Tensor.make_like(k_fp8, data=k, dtype=fwd_nominal_dtype) + v_fp8_save = Float8Tensor.make_like(v_fp8, data=v, dtype=fwd_nominal_dtype) + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 + # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v in FP8, o in f16 + # MXFP8: q/k/v/o all in f16 + if fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) + f16_tensors = (None, None, None, out_f16) + elif fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_f16) + elif fp8: + # convert q/k/v to F16 if necessary, and save q/k/v/o all in F16 and original format + if is_input_fp8: + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + f16_tensors = (q_f16, k_f16, v_f16, out_f16) + ctx.qkv_reshaped = False + else: + # save all in F16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # out_f16: [b, s, h, d] or [s, b, h, d] + f16_tensors = (q, k, v, out_f16) + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_q_padded, *cu_seqlens_kv_per_step, - *out_per_step, *softmax_lse_per_step, *rng_states, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects - ctx.qkv_dtype = qkv_dtype + ctx.qkv_format = qkv_format + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_format = qkv_format + ctx.dqkv_layout = qkv_layout + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.q_shape = q_shape + ctx.k_shape = k_shape + ctx.v_shape = v_shape + ctx.o_shape = o_shape ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step + ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.O_quantizer = O_quantizer.copy() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: - return out, max_logit - return out + return out_ret, max_logit + return out_ret @staticmethod def backward(ctx, dout, *_args): @@ -3080,22 +3421,94 @@ def backward(ctx, dout, *_args): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (*saved_tensors,) = ctx.saved_tensors - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] - cu_seqlens_kv_per_step = saved_tensors[5:7] - out_per_step = saved_tensors[7:9] - softmax_lse_per_step = saved_tensors[9:11] - rng_states = saved_tensors[11:13] + cu_seqlens_kv_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] + ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_q_padded, + cu_seqlens_kv_per_step[0], + cu_seqlens_kv_per_step[1], + softmax_lse_per_step[0], + softmax_lse_per_step[1], + rng_states[0], + rng_states[1], + ) = restore_from_func_ctx(ctx) kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step - seq_dim = ctx.qkv_format.index("s") - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(ctx.qkv_format) + _, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) + _, seq_dim_o, _ = get_bsh_dims(ctx.o_format) + causal = "causal" in ctx.attn_mask_type - dout = dout.view(q.shape) - dq = torch.empty_like(q) - dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) - dv = torch.zeros_like(dk) + # set up dout: + # FP8DS/CS: torch.uint8, [b, s, h, d] or [s, b, h, d] + # MXFP8/F16: torch.float16 or torch.bfloat16, [b, s, h, d] or [s, b, h, d] + dout_fp8 = None + if ctx.fp8: + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): + dout = ctx.dO_quantizer(dout) + dout_fp8 = dout + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + dout = dout.view(ctx.o_shape) + + # set up q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): + q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + if not ctx.qkv_reshaped: + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] + + # set up out: + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): torch.uint8 + # FP8CS+_dpa_fp8_cs_o_in_f16: torch.float16 or torch.bfloat16 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + if ctx.fp8 and ( + ctx.fp8_recipe.delayed() + or (ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ): + out = out_fp8._data + out = out.view(ctx.o_shape) + + # set up dq, dk, dv: + # dq: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dk: fwd_nominal_dtype, [cp*s, b, h, d] + # dv: fwd_nominal_dtype, [cp*s, b, h, d] + dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + dk = torch.zeros( + (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=k.device, + ) + dv = torch.zeros( + (ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=v.device, + ) dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3105,23 +3518,22 @@ def backward(ctx, dout, *_args): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, h, d] -> [cp, s, b, h, d] + # gather k and v along s: [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s: [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # select appropriate chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) - local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - + # set up flash_attn_bwd flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3153,57 +3565,132 @@ def backward(ctx, dout, *_args): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], kv_seq_range_per_step[i][1], ) max_seqlen_kv = seq_end_idx - seq_start_idx - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] - out_ = out_per_step[i] - dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + out_part = out.select(seq_dim_o, i).contiguous() + dout_part = dout.select(seq_dim_o, i).contiguous() if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + if ctx.fp8 and ctx.qkv_layout == "t3hd": + aux_ctx_tensors = [ + softmax_lse_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + ] + else: + aux_ctx_tensors = [ + softmax_lse_per_step[i], + rng_states[i], + ] + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8_meta_kwargs = {} + new_qkv_layout = ctx.qkv_layout + do_format = ctx.o_format + qkv_scale_inv_format = None + do_scale_inv_format = None + if ctx.fp8: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o/do all in FP8 + # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v/do in FP8, o in f16 + # MXFP8: q/k/v/do all in MXFP8, o/do_f16 in F16 + if not ctx.fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=ctx.fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + if ctx.fp8_recipe.delayed() or ( + ctx.fp8_recipe.float8_current_scaling() + and not _dpa_fp8_cs_o_in_f16 + ): + out_part = Float8Tensor.make_like( + out_fp8, data=out_part, dtype=ctx.fwd_nominal_dtype + ) + dout_part = Float8Tensor.make_like( + dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype + ) + else: + q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = ( + combine_and_quantize( + ctx.qkv_layout, + q_part, + k_part, + v_part, + ctx.QKV_quantizer, + used_in_forward=False, + used_in_backward=True, + ) + ) + aux_ctx_tensors.append(dout_part) + (dout_part,), do_scale_inv_format = mxfp8_quantize_fast_path( + [(dout_part, ctx.dO_quantizer)], + do_format, + ) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - out_, - dout_, - ctx.qkv_dtype, - TE_DType[dout.dtype], + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.fwd_nominal_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=ctx.o_format, + do_format=do_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, + do_scale_inv_format=do_scale_inv_format, + **fp8_meta_kwargs, ) + if ctx.fp8 and all( + isinstance(x, QuantizedTensorStorage) + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ): + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + x.dequantize(dtype=ctx.fwd_nominal_dtype) + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ] else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - torch.empty_like(x) for x in [q_, k_, v_] + torch.empty_like(x) for x in [q_part, k_part, v_part] ] fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - ctx.qkv_format, + ctx.dqkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=ctx.max_seqlen_q, @@ -3220,29 +3707,34 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] if ctx.use_flash_attn_3: - fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type + fa_backward_kwargs["is_causal"] = causal else: - fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type + fa_backward_kwargs["causal"] = causal flash_attn_bwd( - dout_, - q_, - k_, - v_, - out_, + dout_part, + q_part, + k_part, + v_part, + out_part, softmax_lse_per_step[i], *fa_backward_args_thd, **fa_backward_kwargs, ) if i > 0: + # dq/dk/dv, dq_per_step/dk_per_step/dv_per_step: ctx.fwd_nominal_dtype with torch.cuda.stream(flash_attn_streams[i - 1]): - if ctx.qkv_format == "bshd": + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dq_per_step[i]: [b, s//2, h, d] or [s//2, b, h, d] + if ctx.dqkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) - elif ctx.qkv_format == "sbhd": + elif ctx.dqkv_format == "sbhd": dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] + # dk/dv: [cp*s, b, h, d] + # dk_per_step[i - 1]/dv_per_step[i - 1]: [s_range, b, h, d] or [b, s_range, h, d] + # move s to first dim: [s_range, b, h, d] dk_per_step[i - 1], dv_per_step[i - 1] = [ - x.movedim(seq_dim, 0).contiguous() + x.movedim(seq_dim_dqkv, 0).contiguous() for x in [dk_per_step[i - 1], dv_per_step[i - 1]] ] # wait until dkv update of last step is done @@ -3252,6 +3744,7 @@ def backward(ctx, dout, *_args): kv_seq_range_per_step[i - 1][0], kv_seq_range_per_step[i - 1][1], ) + # add to dk/dv: [cp*s, b, h, d] dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) if i < len(local_seq_chunk_ids): @@ -3259,23 +3752,33 @@ def backward(ctx, dout, *_args): torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s:[cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + # put back together the right chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) + # reduce scatter: [cp*s, b, h, d] -> [s, b, h, d] dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) - dk = dk.movedim(0, seq_dim).contiguous() - dv = dv.movedim(0, seq_dim).contiguous() - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") + # reshape to original format: + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dk: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dv: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + dq = dq.view(*dq.shape[:seq_dim_dqkv], -1, *dq.shape[(seq_dim_dqkv + 2) :]) + dk = dk.movedim(0, seq_dim_dqkv).contiguous() + dv = dv.movedim(0, seq_dim_dqkv).contiguous() + + # quantize if necessary + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv, _, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( None, dq, @@ -3298,6 +3801,10 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, + None, + None, ) @@ -3342,24 +3849,43 @@ def forward( ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) - + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + original_qkv_layout = qkv_layout + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] + o_format = qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + _, seq_dim_o, _ = get_bsh_dims(o_format) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type + + if qkv_format in ["bshd", "sbhd"]: + assert ( + "padding" not in attn_mask_type + ), f"No support for cp_comm_type='a2a', {attn_mask_type=} and {qkv_format=}." assert ( - not padding or qkv_format == "thd" - ), f"{attn_mask_type} mask type is not supported for BSHD and SBHD!" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" - assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + attn_bias_type == "no_bias" + ), f"No support for cp_comm_type='a2a' and {attn_bias_type=}." assert ( window_size == (-1, 0) or window_size == (-1, -1) or use_fused_attention or fa_utils.v2_3_plus - ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + ), ( + "cp_comm_type='a2a' only supports SWA through FusedAttention or FlashAttention >= 2.3." + f" Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + ) + assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( + "cp_comm_type='a2a' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q =" + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." + ) + assert q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0, ( + "cp_comm_type='a2a' requires num_heads % cp_size == 0 for Q, K, V. Found num_heads_q =" + f" {q.shape[-2]}, num_heads_kv = {k.shape[-2]}, cp_size = {cp_size}." + ) flash_attn_fwd = None if not use_fused_attention: @@ -3399,89 +3925,116 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert ( - q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 - ), "The number of attention heads needs to be divisible by CP size!" - - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - if qkv_format in ["bshd", "sbhd"]: - batch_dim = qkv_format.index("b") - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - batch_dim = seq_dim = qkv_format.index("t") - - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" - assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output - is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + _use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1"))) + is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; # may be different from fp8_meta["recipe"] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) q_fp8, k_fp8, v_fp8 = (None, None, None) + fp8_meta_kwargs = {} if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - if is_input_fp8: - q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_quantizer - else: - assert False, "FP8 is only supported with Fused Attention!" + assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = FusedAttnBackend["FP8"] + if is_input_fp8: + q_fp8, k_fp8, v_fp8 = q, k, v + elif not fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer else: if use_fused_attention: - fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + # q, k, v: + # FP8DS/FP8CS: torch.uint8 + # MXFP8: torch.float16 or torch.bfloat16 + # F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, - seq_dim, + seq_dim_qkv, cp_size, cp_group, cp_stream, before_attn=True, qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["q", "k", "v"], ) + + # softmax_offset: split h + # [1, h, 1, 1] -> [1, h//cp, 1, 1] if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( softmax_offset, 1, cp_size, cp_group, cp_stream, True ) - out_fp8 = None - out_f16 = None - batch_size = q.shape[batch_dim] + # _part: inputs to attention kernel and saved for backward + # note: they have post a2a shapes q_part, k_part, v_part = q, k, v - out_part = None + out_part, out_fp8, out_f16 = None, None, None + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( + is_bwd_fp8 + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + qkv_scale_inv_format = None if use_fused_attention: if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format = combine_and_quantize( + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer, + used_in_backward=is_training, + ) + q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] + else: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3496,6 +4049,7 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3507,25 +4061,20 @@ def forward( softmax_offset=softmax_offset, return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=qkv_scale_inv_format, ) - if isinstance(out_, Float8Tensor): - out_fp8 = out_ - out_ = out_._data - if is_bwd_fp8 and not ( - fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - ): - out_part = out_fp8 - else: - out_part = out_fp8.dequantize(dtype=fwd_nominal_dtype) - else: - out_f16 = out_ - out_part = out_ - if ( - fp8 - and is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): - out_part = O_quantizer(out_) + # construct out_part for backward + # out_fp8 and out_f16 store the FP8 or F16 tensor for backward saves + out_fp8 = out_ + out_f16 = out_ + if bwd_requires_o_fp8: + if not isinstance(out_, QuantizedTensorStorage): + out_fp8 = O_quantizer(out_) + out_part = out_fp8 + if bwd_requires_o_f16: + if isinstance(out_, QuantizedTensorStorage): + out_f16 = out_.dequantize(dtype=fwd_nominal_dtype) + out_part = out_f16 else: fa_forward_args_thd = get_fa_args( True, @@ -3553,60 +4102,95 @@ def forward( aux_ctx_tensors = [softmax_lse, rng_state] out_part = out_ + # a2a: split s and gather h + # [b, s, h//cp, d] -> [b*s//cp, h, d] + # [s, b, h//cp, d] -> [s//cp*b, h, d] + # [t, h//cp, d] -> [t//cp, h, d] + if isinstance(out_, Float8TensorStorage): + out_ = out_._data chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, - seq_dim, + seq_dim_o, cp_size, cp_group, cp_stream, before_attn=False, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["out"], ) - if return_max_logit: - max_logit = flash_attn_a2a_communicate_softmax_offset( - *max_logit, 0, cp_size, cp_group, cp_stream, False - ) - - if use_fused_attention: - if qkv_format == "bshd": - # [b*s, h, d] -> [b, s, h, d] - out_ = out_.view(batch_size, -1, *out_.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, h, d] -> [s, b, h, d] - out_ = out_.view(-1, batch_size, *out_.shape[-2:]) + # [b*s//cp, h, d] -> [b, s//cp, h, d] + # [s//cp*b, h, d] -> [s//cp, b, h, d] + # [t//cp, h, d] -> [t//cp, h, d] + out_ = out_.view(orig_o_shape) - if fp8 and use_fused_attention: - if fp8_recipe.float8_current_scaling(): - out_f16 = out_ - if is_output_fp8: - out_fp8 = O_quantizer(out_) + # out_ret: output tensor for forward pass + # out_fp8 and out_f16 are reused here to store the FP8 or F16 tensor for forward returns + if fp8: if fp8_recipe.delayed(): out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) - if not is_output_fp8: + if is_output_fp8: + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): + out_fp8 = O_quantizer(out_) + out_f16 = out_ + else: + if fp8_recipe.delayed(): out_f16 = out_fp8.dequantize(dtype=fwd_nominal_dtype) + else: + out_f16 = out_ else: out_f16 = out_ - out_ret = out_fp8 if is_output_fp8 else out_f16 - ctx.fp8 = fp8 and is_bwd_fp8 + # all gather max logit + if return_max_logit: + max_logit = flash_attn_a2a_communicate_softmax_offset( + *max_logit, 0, cp_size, cp_group, cp_stream, False + ) + + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.qkv_scale_inv_format = qkv_scale_inv_format + ctx.dqkv_layout = original_qkv_layout + ctx.dqkv_format = qkv_format + ctx.orig_q_shape = orig_q_shape + ctx.orig_k_shape = orig_k_shape + ctx.orig_v_shape = orig_v_shape + ctx.orig_o_shape = orig_o_shape + + # save tensors for backward + ctx.fp8 = is_bwd_fp8 fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) - if ctx.fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: - fp8_tensors = (q_part, k_part, v_part, None) - f16_tensors = (None, None, None, out_part) + if is_training: + if ctx.fp8: + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 + # (FP8CS+_dpa_fp8_cs_o_in_f16) or MXFP8: q/k/v in FP8, o in F16 + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): + fp8_tensors = (q_part, k_part, v_part, None) + f16_tensors = (None, None, None, out_part) + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): + fp8_tensors = (q_part, k_part, v_part, out_part) + elif fp8: + # FP8DS/CS: convert post-a2a FP8 q/k/v to F16; out_part already in F16 + # MXFP8: save post-a2a pre-quantization F16 q/k/v; out_part already in F16 + if fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_part) + ctx.qkv_layout = original_qkv_layout + else: + q_part, k_part, v_part = combine_and_dequantize( + qkv_layout, q_part, k_part, v_part + ) + f16_tensors = (q_part, k_part, v_part, out_part) else: - fp8_tensors = (q_part, k_part, v_part, out_part) - elif fp8: - q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) - f16_tensors = (q_part, k_part, v_part, out_part) - else: - f16_tensors = (q_part, k_part, v_part, out_part) - + # all tensors are in F16 + f16_tensors = (q_part, k_part, v_part, out_part) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *f16_tensors, @@ -3618,16 +4202,13 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects - ctx.out_shape = out_ret.shape - ctx.batch_size = batch_size ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.deterministic = deterministic @@ -3649,11 +4230,13 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if return_max_logit: return out_ret, max_logit @@ -3681,60 +4264,53 @@ def backward(ctx, dout, *_args): *aux_ctx_tensors, ) = restore_from_func_ctx(ctx) - qkv_format = ctx.qkv_format - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - causal = "causal" in ctx.attn_mask_type - - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - seq_dim = qkv_format.index("t") - + _, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) + _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) bwd_nominal_dtype = ctx.fwd_nominal_dtype - dqkv_te_dtype = None fused_attn_backend = None - dout_fp8 = dout + causal = "causal" in ctx.attn_mask_type + + dout_fp8 = None + fp8_meta_kwargs = {} if ctx.fp8: - if ctx.use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorStorage): - dout = ctx.dO_quantizer(dout) - dout_fp8 = dout - dqkv_te_dtype = dout._fp8_dtype + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = FusedAttnBackend["FP8"] + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): + dout = ctx.dO_quantizer(dout) + dout_fp8 = dout + if not ctx.fp8_recipe.mxfp8(): dout = dout._data - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer - fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer - - else: - assert False, "FP8 is only supported with Fused Attention!" + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: - fp8_meta_kwargs = {} - dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - - if not ctx.use_fused_attention: - if qkv_format in ["bshd", "sbhd"]: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) - else: - dout = dout.view(*ctx.out_shape) - + dout = dout.view(*ctx.orig_o_shape) + + # dout: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( dout, chunk_ids_for_a2a, - seq_dim, + seq_dim_do, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=True, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=ctx.o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["dout"], ) flash_attn_bwd = None @@ -3752,7 +4328,7 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size_right"] = ctx.window_size[1] fa_backward_kwargs["deterministic"] = ctx.deterministic else: - if qkv_format == "thd": + if ctx.o_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( _flash_attn_varlen_bwd, ) @@ -3779,12 +4355,23 @@ def backward(ctx, dout, *_args): dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: + do_format = ctx.o_format + do_scale_inv_format = None q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or ctx.fp8_recipe.mxfp8(): out_part = out - dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + if not ctx.fp8_recipe.mxfp8(): + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + else: + aux_ctx_tensors.append(dout) + (dout_part,), do_scale_inv_format = mxfp8_quantize_fast_path( + [(dout, ctx.dO_quantizer)], + do_format, + ) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3796,23 +4383,27 @@ def backward(ctx, dout, *_args): out_part, dout_part, bwd_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + do_format=do_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=ctx.window_size, deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + qkv_scale_inv_format=ctx.qkv_scale_inv_format, + do_scale_inv_format=do_scale_inv_format, **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if isinstance(dq, Float8Tensor): + if all(isinstance(x, Float8TensorStorage) for x in [dq, dk, dv]): dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv dq, dk, dv = [x._data for x in [dq, dk, dv]] else: @@ -3821,7 +4412,7 @@ def backward(ctx, dout, *_args): fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - qkv_format, + ctx.dqkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, @@ -3847,24 +4438,33 @@ def backward(ctx, dout, *_args): **fa_backward_kwargs, ) + # dq, dk, dv: + # FP8DS: torch.uint8 + # FP8CS/MXFP8/F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, - seq_dim, + seq_dim_dqkv, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=False, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=ctx.dqkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["dq", "dk", "dv"], ) + dq, dk, dv = [ + x.view(y) + for x, y in zip([dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape]) + ] - if qkv_format == "bshd": - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif qkv_format == "sbhd": - dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] - + # d_bias, d_softmax_offset d_bias = None d_softmax_offset = None if ctx.use_fused_attention: @@ -3876,9 +4476,14 @@ def backward(ctx, dout, *_args): d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False ) + # convert dq, dk, dv to appropriate types if ctx.fp8: - if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if ( + ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() + ) and ctx.is_input_fp8: + dq, dk, dv, _, _ = combine_and_quantize( + ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer + ) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) @@ -3886,7 +4491,7 @@ def backward(ctx, dout, *_args): ] if not ctx.is_input_fp8: dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.dqkv_layout, dq, dk, dv, @@ -3894,7 +4499,6 @@ def backward(ctx, dout, *_args): ) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") - return ( None, dq, @@ -4069,17 +4673,6 @@ def attn_forward_func_with_cp( "all_gather", ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" - enable_mla = k.shape[-1] != v.shape[-1] - assert not enable_mla or cp_comm_type in [ - "p2p", - "a2a+p2p", - ], f"Context parallelism does not support MLA with {cp_comm_type=}!" - - if fp8 and fp8_meta is not None: - if fp8_meta["recipe"].fp8_dpa: - assert ( - softmax_type == "vanilla" - ), f"Context parallelism does not support {softmax_type=} with FP8 attention!" assert ( softmax_type == "vanilla" or use_fused_attention ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!" @@ -4131,7 +4724,16 @@ def attn_forward_func_with_cp( elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_flash_attn_3] + args += [ + window_size, + cp_group, + cp_stream, + use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, + ] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": args += [ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 588c708e10..17e9a337a4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -19,6 +19,7 @@ Recipe, DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, ) from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.quantization import ( @@ -30,7 +31,7 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.constants import ( @@ -98,19 +99,26 @@ +-------------------+-----------+-----------------------------------------------------------------------------------+ | Linear | Attention | Configuration | +===================+===========+===================================================================================+ -| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); | -| | | export NVTE_DPA_FP8_RECIPE="F16" | +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS, NVFP4 or MXFP8 to autocast(); | +| /MXFP8 | | export NVTE_DPA_FP8_RECIPE="F16" | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8DS | Pass FP8DS to autocast(); | +| FP8DS | FP8DS | Pass FP8DS to autocast(); | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8DS | Pass FP8CS to autocast(); | +| FP8CS | FP8DS | Pass FP8CS to autocast(); | | | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | +| MXFP8 | FP8DS | Pass MXFP8 to autocast(); | +| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear MXFP8; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | | | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | @@ -118,19 +126,27 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8CS | Pass FP8DS to autocast(); | +| FP8DS | FP8CS | Pass FP8DS to autocast(); | | | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| | | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8CS | Pass FP8CS to autocast(); | +| FP8CS | FP8CS | Pass FP8CS to autocast(); | | | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | | | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | +| MXFP8 | FP8CS | Pass MXFP8 to autocast(); | +| | | Attention creates a new FP8CS recipe based on fp8_format, fp8_dpa, fp8_mha from | +| | | linear MXFP8, and: | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | | | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | | | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | @@ -139,6 +155,18 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS/FP8CS | MXFP8 | Pass FP8DS/FP8CS to autocast(); | +| | | Attention creates a new MXFP8 recipe based on fp8_format, fp8_dpa, fp8_mha from | +| | | linear FP8DS/FP8CS | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| MXFP8 | MXFP8 | Pass MXFP8 to autocast(); | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | MXFP8 | Pass NVFP4 to autocast(); | +| | | Attention MXFP8 reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ """ _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} @@ -600,7 +628,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False fp8_recipe.fp8_dpa = False fp8_recipe.fp8_mha = False - elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": + elif ( + fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8() + ) and _dpa_fp8_recipe == "DelayedScaling": # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe fake_recipe = DelayedScaling( fp8_format=fp8_recipe.fp8_format, @@ -653,6 +683,25 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) fp8_recipe_dpa = fake_recipe fp8_recipes = [fp8_recipe, fp8_recipe_dpa] + elif fp8_recipe.mxfp8() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe + fake_recipes = [ + Float8CurrentScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ), + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling": # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP @@ -673,11 +722,26 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False - if not fp8_recipe_dpa.float8_per_tensor_scaling(): - assert not ( - fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha - ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" + elif ( + fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() + ) and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs @@ -1203,7 +1267,9 @@ def forward( cu_seqlens_kv_padded = None # get qkv's memory layout - if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): + if all( + isinstance(x, Float8TensorStorage) for x in [query_layer, key_layer, value_layer] + ): ( qkv_layout, query_layer._data, @@ -1365,6 +1431,7 @@ def forward( attention_dropout=self.attention_dropout, context_parallel=context_parallel, cp_comm_type=self.cp_comm_type, + cp_size=cp_size, deterministic=self.deterministic, is_training=self.training, fp8=self.fp8, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 20228ddb80..16817b0402 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,13 +35,18 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8TensorStorage +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage + from transformer_engine.pytorch.quantization import get_fp8_te_dtype -from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.constants import TE_DType, MXFP8_BLOCK_SCALING_SIZE from transformer_engine.pytorch.utils import ( @@ -231,6 +236,8 @@ class AttentionParams: Whether context parallelism is used or not. cp_comm_type : str, default = "p2p" The communication type of context parallelism. + cp_size : int, default = 1 + The group size of context parallelism. deterministic : bool, default = False Whether to run `DotProductAttention` with determinism or not. is_training : bool, default = True @@ -272,6 +279,7 @@ class AttentionParams: attention_dropout: float = 0.0 context_parallel: bool = False cp_comm_type: str = "p2p" + cp_size: int = 1 deterministic: bool = False is_training: bool = True fp8: bool = False @@ -349,6 +357,7 @@ def get_attention_backend( attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel cp_comm_type = attention_params.cp_comm_type + cp_size = attention_params.cp_size # pylint: disable=unused-variable deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 @@ -368,9 +377,9 @@ def get_attention_backend( cudnn_version = get_cudnn_version() run_config = { "transformer_engine_version": te.__version__, - "compute_capability": ( - "sm" + str(10 * device_compute_capability[0] + device_compute_capability[1]) - ), + "compute_capability": "sm" + + str(10 * device_compute_capability[0] + device_compute_capability[1]), + "cuda_version": torch.version.cuda, "flash_attn_version": ( str(FlashAttentionUtils.version) if FlashAttentionUtils.is_installed @@ -464,9 +473,14 @@ def get_attention_backend( # On SM90, prefer FA3 over FA4 when FA3 is available. # FA3 is more mature on Hopper; FA4's SM90 backward has limitations # (MLA, non-standard head dims, SplitKV). - if use_flash_attention_4 and use_flash_attention_3 and device_compute_capability == (9, 0): - if FlashAttentionUtils.v4_is_installed: - logger.debug("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90") + if ( + device_compute_capability == (9, 0) + and use_flash_attention_3 + and FlashAttentionUtils.v3_is_installed + and use_flash_attention_4 + and FlashAttentionUtils.v4_is_installed + ): + logger.debug("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90") use_flash_attention_4 = False # Filter: Data type @@ -488,21 +502,30 @@ def get_attention_backend( if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [ torch.Tensor, Float8Tensor, + Float8TensorStorage, ]: if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug( - "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor}. ", + "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s." + " Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}," + " qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage}. ", qkv_dtype, qkv_type, ) use_flash_attention_3 = False + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in ( + torch.Tensor, + Float8Tensor, + Float8TensorStorage, + MXFP8Tensor, + MXFP8TensorStorage, + ): if use_fused_attention: logger.debug( - "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor}. ", + "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. Supported:" + " qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, qkv_type =" + " {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor," + " MXFP8TensorStorage}. ", qkv_dtype, qkv_type, ) @@ -510,6 +533,9 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for FP8 attention") use_flash_attention_2 = False @@ -520,6 +546,12 @@ def get_attention_backend( if FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False + if use_flash_attention_3 and not ( + fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() + ): + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for %s", fp8_recipe.__class__.__name__) + use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() @@ -527,15 +559,21 @@ def get_attention_backend( if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False - fp8_recipe = fp8_meta["recipe"] - if fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] + if use_fused_attention and fp8_recipe.delayed(): + if ( + device_compute_capability >= (10, 0) + and deterministic + and cudnn_version < (9, 18, 0) + ): + logger.debug( + "Disabling FusedAttention for FP8 delayed scaling on arch >= sm100 with" + " determinism for cuDNN < 9.18.0" + ) + use_fused_attention = False if use_fused_attention and fp8_recipe.float8_current_scaling(): if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") use_fused_attention = False - # TODO(cyanguwa): Modify the min cuDNN version supporting FP8 current scaling - # determinism for Blackwell else: if cudnn_version < (9, 14, 0): logger.debug( @@ -545,10 +583,27 @@ def get_attention_backend( else: if deterministic and cudnn_version < (9, 18, 0): logger.debug( - "Disabling FusedAttention for FP8 current scaling requiring determinism" - " with cuDNN < 9.18.0" + "Disabling FusedAttention for FP8 current scaling with determinism" + " for cuDNN < 9.18.0" ) use_fused_attention = False + if use_fused_attention and fp8_recipe.mxfp8(): + if device_compute_capability < (10, 0): + logger.debug("Disabling FusedAttention for MXFP8 on arch < sm100") + use_fused_attention = False + elif fp8_recipe.fp8_mha: + logger.debug("Disabling FusedAttention for MXFP8 with fp8_mha=True") + use_fused_attention = False + else: + if cudnn_version < (9, 21, 0): + logger.debug("Disabling FusedAttention for MXFP8 with cuDNN < 9.21.0") + use_fused_attention = False + elif qkv_format == "thd": + logger.debug("Disabling FusedAttention for MXFP8 with qkv_format = thd") + use_fused_attention = False + if use_fused_attention and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()): + logger.debug("Disabling FusedAttention for %s", fp8_recipe.__class__.__name__) + use_fused_attention = False if device_compute_capability == (12, 0): if use_flash_attention: @@ -837,29 +892,36 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) use_flash_attention = False if fp8 and fp8_meta["recipe"].fp8_dpa: - logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) - use_fused_attention = False - logger.debug( - "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type - ) - use_unfused_attention = False - if qkv_format == "thd": - if cudnn_version < (9, 18, 0): + if use_fused_attention and ( + device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0) + ): logger.debug( - "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" - " version < 9.18", + "Disabling FusedAttention for softmax_type = %s in FP8 on sm < 100 with cuDNN" + " version < 9.21", softmax_type, ) use_fused_attention = False - if context_parallel: - if cp_comm_type != "a2a": + if use_unfused_attention: logger.debug( - "Disabling FusedAttention for context parallelism with softmax_type = %s and" - " cp_comm_type = %s", + "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type, - cp_comm_type, ) - use_fused_attention = False + use_unfused_attention = False + if qkv_format == "thd" and cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" + " version < 9.18", + softmax_type, + ) + use_fused_attention = False + if context_parallel and cp_comm_type != "a2a": + logger.debug( + "Disabling FusedAttention for context parallelism with softmax_type = %s and" + " cp_comm_type = %s", + softmax_type, + cp_comm_type, + ) + use_fused_attention = False # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends @@ -946,10 +1008,50 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: + elif fp8 and fp8_meta["recipe"].fp8_dpa and qkv_format == "thd": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " attention and THD format" + ) + use_fused_attention = False + elif fp8 and fp8_meta["recipe"].fp8_dpa and core_attention_bias_type != "no_bias": logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" - " MLA attention" + " attention and bias" + ) + use_fused_attention = False + elif core_attention_bias_type != "no_bias" and cp_comm_type != "p2p": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias" + " and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD" + " format and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif ( + window_size is not None + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + and cp_comm_type in ["p2p", "a2a+p2p"] + ): + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with sliding" + " window attention and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif cp_comm_type in ["a2a", "a2a+p2p"] and (num_heads % 2 != 0 or num_gqa_groups % 2 != 0): + logger.debug( + "Disabling FusedAttention as cp_comm_type = %s requires num_heads and" + " num_gqa_groups divisible by 2 (got num_heads = %s, num_gqa_groups = %s)", + cp_comm_type, + num_heads, + num_gqa_groups, ) use_fused_attention = False @@ -1004,9 +1106,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): + if ( + fp8 + and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha) + and (device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0)) + ): logger.debug( "Disabling FusedAttention as it does not support sliding window attention for FP8" + " on sm < 100 with cuDNN version < 9.21" ) use_fused_attention = False elif attention_dropout != 0.0: @@ -1150,8 +1257,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and window_size is not None - and window_size[0] != -1 - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] ): logger.debug( "Disabling FusedAttention as only sub-backend %s does not support " @@ -2256,28 +2363,45 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers): +def get_attention_quantizers(fp8, fp8_recipe, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + QKV_quantizer.internal = False QKV_quantizer.set_usage(rowwise=True, columnwise=False) - O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.internal = False + O_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] + dO_quantizer.internal = False dO_quantizer.set_usage(rowwise=True, columnwise=False) - dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.internal = True dP_quantizer.set_usage(rowwise=True, columnwise=False) - dP_quantizer.interal = True + + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.internal = False + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + + if fp8_recipe.mxfp8(): + QKV_quantizer.columnwise_usage = True + QKV_quantizer.optimize_for_gemm = True + S_quantizer = None + O_quantizer.columnwise_usage = True + + dO_quantizer.columnwise_usage = True + dO_quantizer.optimize_for_gemm = True + dP_quantizer = None + dQKV_quantizer.columnwise_usage = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer @@ -2331,18 +2455,289 @@ def print_quantizers( type_str = "DS" elif isinstance(q, Float8CurrentScalingQuantizer): type_str = "CS" - print( - f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" - f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + elif isinstance(q, MXFP8Quantizer): + type_str = "MXFP8" + if type_str in ["DS", "CS"]: + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) + else: + print(f"{label} >> {names[i]:14s}: {type_str}") + + +def transpose_to_bhsd_htd_pytorch(tensor, src_format): + """Permute to BHSD or HTD format using native PyTorch operations.""" + if src_format in ("bhsd", "htd"): + return tensor + dim_s = src_format.find("s") if "s" in src_format else src_format.find("t") + dim_others = [i for i in range(tensor.ndim) if i != dim_s] + new_dims = [*dim_others[:-1], dim_s, dim_others[-1]] + return tensor.permute(*new_dims).contiguous() + + +def mxfp8_quantize_fast_path(tensor_quantizer_pairs, src_format): + """MXFP8 attention requires quantization along S and D dimensions. This fast path + quantizes tensors without swizzle, and pads, permutes and swizzles the scale_invs + to achieve faster speed due to the smaller sizes of scale_invs compare to the data. + The output tensors have _rowwise_data and _columnwise_data in src_format, and + _rowwise_scale_inv and _columnwise_scale_inv in BHSD format. + + Parameters + ---------- + tensor_quantizer_pairs : list of (torch.Tensor, MXFP8Quantizer) + Each pair is a tensor and its quantizer (with the desired + rowwise_usage / columnwise_usage already set). + src_format : str + Layout of input tensors: ``"bshd"`` or ``"sbhd"``. + All tensors in the list must have the same src_format. + Returns + ------- + fp8_tensors : list of MXFP8Tensors + Data in ``src_format``, scale_inv in BHSD format. + scale_inv_format : str + Always ``"bhsd"``. + """ + if not tensor_quantizer_pairs: + return [], src_format + assert src_format in ( + "bshd", + "sbhd", + ), f"mxfp8_quantize_fast_path only supports bshd/sbhd, got {src_format!r}." + _s_dim = {"bshd": 1, "sbhd": 0} + _d_dim = {"bshd": 3, "sbhd": 3} + + fp8_tensors = [] + for tensor, quantizer in tensor_quantizer_pairs: + original_shape = tensor.shape + rs_shape = list(original_shape) + rs_shape[_d_dim[src_format]] //= MXFP8_BLOCK_SCALING_SIZE + cs_shape = list(original_shape) + cs_shape[_s_dim[src_format]] //= MXFP8_BLOCK_SCALING_SIZE + + # view tensor as 2D for quantization + # BSHD -> (B*S, H*D) + # SBHD -> (S, B*H*D) + if src_format == "bshd": + tensor = tensor.view(*tensor.shape[:2], -1) + else: + tensor = tensor.view(tensor.shape[0], -1) + + # quantize + orig_optimize = quantizer.optimize_for_gemm + quantizer.optimize_for_gemm = False + fp8_tensor = quantizer(tensor) + quantizer.optimize_for_gemm = orig_optimize + + # reshape rowwise/columnwise data to original shape + fp8_tensor._rowwise_data = ( + fp8_tensor._rowwise_data.view(original_shape) + if fp8_tensor._rowwise_data is not None + else None + ) + fp8_tensor._columnwise_data = ( + fp8_tensor._columnwise_data.view(original_shape) + if fp8_tensor._columnwise_data is not None + else None + ) + fp8_tensor._rowwise_scale_inv = ( + fp8_tensor._rowwise_scale_inv.view(rs_shape) + if fp8_tensor._rowwise_scale_inv is not None + else None + ) + fp8_tensor._columnwise_scale_inv = ( + fp8_tensor._columnwise_scale_inv.view(cs_shape) + if fp8_tensor._columnwise_scale_inv is not None + else None + ) + fp8_tensors.append(fp8_tensor) + + # ---- Pad + permute + swizzle scale_inv to BHSD ---- + rs_list = [t._rowwise_scale_inv for t in fp8_tensors] + cs_list = [t._columnwise_scale_inv for t in fp8_tensors] + + def _align_up(x, a): + return ((x + a - 1) // a) * a + + def _bhsd_shape(src_4d, d_pad): + if src_format == "sbhd": + S, B, H, _ = src_4d.shape + else: + B, S, H, _ = src_4d.shape + return (B, H, S, d_pad) + + def _build_outputs(scale_list, alignment): + entries = [] + total = 0 + for s in scale_list: + if s is None: + entries.append(None) + continue + d_pad = _align_up(s.shape[-1], alignment) + shape = _bhsd_shape(s, d_pad) + numel = 1 + for dim in shape: + numel *= dim + entries.append((total, numel, shape)) + total += numel + if total == 0: + return [None] * len(scale_list) + device = next(s for s in scale_list if s is not None).device + buf = torch.empty(total, dtype=torch.uint8, device=device) + return [buf[e[0] : e[0] + e[1]].view(e[2]) if e is not None else None for e in entries] + + # allocate buffers with padding in mind + rs_outs = _build_outputs(rs_list, 4) + cs_outs = _build_outputs(cs_list, 128) + + # permute scale_invs to BHSD; batched + rs_permuted = tex.multi_tensor_transpose_to_bhsd( + rs_list, + original_format=src_format, + outputs=rs_outs, + ) + cs_permuted = tex.multi_tensor_transpose_to_bhsd( + cs_list, + original_format=src_format, + outputs=cs_outs, + ) + + # build output tensors + result = [] + for t, rp, cp in zip(fp8_tensors, rs_permuted, cs_permuted): + rp = rp.view(-1, rp.shape[-1]) if rp is not None else None + cp = cp.view(-1, cp.shape[-1]) if cp is not None else None + result.append( + MXFP8Tensor( + shape=t.shape, + dtype=t.dtype, + rowwise_data=t._rowwise_data, + rowwise_scale_inv=rp, + columnwise_data=t._columnwise_data, + columnwise_scale_inv=cp, + quantizer=t._quantizer, + requires_grad=False, + fp8_dtype=t._fp8_dtype, + with_gemm_swizzled_scales=t._with_gemm_swizzled_scales, ) + ) + # swizzle in place; batched + tex.multi_tensor_swizzle_scales_for_gemm_unchecked_(result, True, False) + tex.multi_tensor_swizzle_scales_for_gemm_unchecked_(result, False, True) + for t in result: + t._with_gemm_swizzled_scales = True + + return result, "bhsd" + + +def combine_and_quantize( + qkv_layout, + q, + k, + v, + qkv_quantizer, + used_in_forward=True, + used_in_backward=False, + keep_same_data_and_scale_inv_format=False, +): + """Combine Q, K, V tensors based on qkv_layout and quantize them together.""" + if isinstance(qkv_quantizer, MXFP8Quantizer): + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) + assert qkv_format in ("bshd", "sbhd"), ( + "combine_and_quantize only supports bshd/sbhd for MXFP8 quantization, got" + f" {qkv_format!r}." + ) + + _s_dim = {"sbhd": 0, "bshd": 1} + _d_dim = {"sbhd": 3, "bshd": 3} + d_qk = q.shape[_d_dim[qkv_format]] + d_v = v.shape[_d_dim[qkv_format]] + s_q = q.shape[_s_dim[q_format]] + s_kv = v.shape[_s_dim[kv_format]] + assert s_q % 128 == 0 and s_kv % 128 == 0 and d_qk % 32 == 0 and d_v % 32 == 0, ( + "MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32" + f" == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." + ) + + if qkv_layout not in ("bshd_bshd_bshd", "sbhd_sbhd_sbhd"): + keep_same_data_and_scale_inv_format = True + + # ---- Fast path: quantize in original layout, permute scale_inv to BHSD, then swizzle ---- + if not keep_same_data_and_scale_inv_format: + q_quantizer, k_quantizer, v_quantizer = [qkv_quantizer.copy() for _ in range(3)] + if used_in_forward and not used_in_backward: + q_quantizer.rowwise_usage = True + q_quantizer.columnwise_usage = False + k_quantizer.rowwise_usage = True + k_quantizer.columnwise_usage = False + v_quantizer.rowwise_usage = False + v_quantizer.columnwise_usage = True + elif (not used_in_forward) and used_in_backward: + q_quantizer.rowwise_usage = True + q_quantizer.columnwise_usage = True + k_quantizer.rowwise_usage = True + k_quantizer.columnwise_usage = True + v_quantizer.rowwise_usage = True + v_quantizer.columnwise_usage = False + (q_fp8, k_fp8, v_fp8), qkv_scale_inv_format = mxfp8_quantize_fast_path( + [(q, q_quantizer), (k, k_quantizer), (v, v_quantizer)], qkv_format + ) + return q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format + + # ---- Slow path: permute data to BHSD, then quantize with swizzle ---- + if qkv_layout in ("bshd_bshd_bshd", "sbhd_sbhd_sbhd"): + q, k, v = tex.multi_tensor_transpose_to_bhsd( + [q, k, v], + original_format=qkv_format, + ) + else: + q = transpose_to_bhsd_htd_pytorch(q, q_format) + k = transpose_to_bhsd_htd_pytorch(k, kv_format) + v = transpose_to_bhsd_htd_pytorch(v, kv_format) + qkv_layout = "bhsd_bhsd_bhsd" + qkv_scale_inv_format = "bhsd" + + original_shapes = [x.shape for x in [q, k, v]] + q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] + + q_quantizer, k_quantizer, v_quantizer = [qkv_quantizer.copy() for _ in range(3)] + if used_in_forward and not used_in_backward: + q_quantizer.rowwise_usage = True + q_quantizer.columnwise_usage = False + k_quantizer.rowwise_usage = True + k_quantizer.columnwise_usage = False + v_quantizer.rowwise_usage = False + v_quantizer.columnwise_usage = True + elif (not used_in_forward) and used_in_backward: + q_quantizer.rowwise_usage = True + q_quantizer.columnwise_usage = True + k_quantizer.rowwise_usage = True + k_quantizer.columnwise_usage = True + v_quantizer.rowwise_usage = True + v_quantizer.columnwise_usage = False + q_fp8, k_fp8, v_fp8 = [ + quant(x) for quant, x in zip([q_quantizer, k_quantizer, v_quantizer], [q, k, v]) + ] + + for fp8_tensor, shape in zip([q_fp8, k_fp8, v_fp8], original_shapes): + fp8_tensor._rowwise_data = ( + fp8_tensor._rowwise_data.view(shape) + if fp8_tensor._rowwise_data is not None + else None + ) + fp8_tensor._columnwise_data = ( + fp8_tensor._columnwise_data.view(shape) + if fp8_tensor._columnwise_data is not None + else None + ) + + return q_fp8, k_fp8, v_fp8, qkv_layout, qkv_scale_inv_format -def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): - """Combine q,k,v based on qkv_layout and quantize them together""" - # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_layout = qkv_layout.replace("paged_kv_", "") qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype + # 1: qkv packed, 2: kv packed, 3: qkv separate match qkv_group: case 1: dim = qkv_layout.find("3") @@ -2382,24 +2777,28 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): for x in [q_data, k_data, v_data] ] - return q_fp8, k_fp8, v_fp8 + return q_fp8, k_fp8, v_fp8, qkv_layout, None def combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None ): """Combine q,k,v based on qkv_layout and dequantize them together""" - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_layout = qkv_layout.replace("paged_kv_", "") - qkv_group = len(qkv_layout.split("_")) - if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + if all(isinstance(x, QuantizedTensorStorage) for x in [q_fp8, k_fp8, v_fp8]): src_nominal_dtype = q_fp8.dtype else: assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" if des_nominal_dtype is None: des_nominal_dtype = src_nominal_dtype + if all(isinstance(x, (MXFP8Tensor, MXFP8TensorStorage)) for x in [q_fp8, k_fp8, v_fp8]): + q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x in [q_fp8, k_fp8, v_fp8]] + return q, k, v + + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] + # 1: qkv packed, 2: kv packed, 3: qkv separate match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d95d327c78..afc4622b22 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -795,15 +795,31 @@ def forward( fp8_dpa = fp8_recipe.fp8_dpa fp8_mha = fp8_recipe.fp8_mha float8_current_scaling = fp8_recipe.float8_current_scaling() + mxfp8_scaling = fp8_recipe.mxfp8() else: fp8_dpa = _dpa_fp8_recipe_dpa fp8_mha = _dpa_fp8_recipe_mha float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" - # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe - qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling - # DPA: always produce FP8 output when fp8=True to take advantage of the O amax - dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) - # Proj Gemm: match DPA output except for Float8CurrentScaling + mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" + + # QKV Gemm: do not produce FP8 output when fp8_mha = True if + # 1. RoPE is on: RoPE is only implemented in F16 currently + # 2. FP8CS recipe: due to cuBLAS limitation, FP8CS Gemms can not produce FP8 output + # 3. MXFP8 recipe: QKV Gemm produces QKV in bs(hd), sb(hd), t(hd) shapes, quantization of which would be along + # s/b/t and (hd) dimensions, whereas MXFP8 attention requires quantization along s and d, e.g. bhsd, sbhd, thd + qkv_fp8_output = ( + fp8 + and fp8_mha + and rotary_pos_emb is None + and not float8_current_scaling + and not mxfp8_scaling + ) + # DPA: produce FP8 output to take advantage of O amax from DPA; Projection Gemm can take FP8 or F16 inputs + # 1. FP8DS/FP8CS recipe: produce FP8 output + # 2. MXFP8 recipe: produce F16 output; again, due to quantization dimensions mismatch + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) and not mxfp8_scaling + # Projection Gemm: match DPA output except + # 1. FP8CS recipe: produce F16 grads; again, due to cuBLAS limitation proj_fp8_grad = dpa_fp8_output and not float8_current_scaling layernorm_output = None diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 06bfb6ef3c..01e139da46 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -35,6 +35,7 @@ } QKVFormat = { + None: NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET, "bshd": NVTE_QKV_Format.NVTE_BSHD, "sbhd": NVTE_QKV_Format.NVTE_SBHD, "thd": NVTE_QKV_Format.NVTE_THD, @@ -42,6 +43,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, + "bhsd": NVTE_QKV_Format.NVTE_BHSD, } QKVLayout = { @@ -70,6 +72,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, + "bhsd_bhsd_bhsd": NVTE_QKV_Layout.NVTE_BHSD_BHSD_BHSD, } AttnBiasType = { @@ -134,6 +137,8 @@ def fused_attn_fwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", + qkv_scale_inv_format: str = None, attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -203,6 +208,11 @@ def fused_attn_fwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} + qkv_scale_inv_format : str, default = None + format of the scale-inverse tensors for QKV; {"sbhd", "bshd", "thd", "bhsd"}; + if None, defaults to the format inferred from qkv_layout. attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -251,7 +261,7 @@ def fused_attn_fwd( M: torch.Tensor max(Q*K.T) shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - ZInv: torch.Tensor + ZInv: torch.Tensor, only allocated for T3HD path 1/sum(e^(x - max(x))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen @@ -302,17 +312,6 @@ def fused_attn_fwd( rng_elts_per_thread = ( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - if s_quantizer is None: - raise ValueError( - "s_quantizer is required for FP8 fused attention forward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if o_quantizer is None: - raise ValueError( - "o_quantizer is required for FP8 fused attention forward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -326,6 +325,8 @@ def fused_attn_fwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], + QKVFormat[qkv_scale_inv_format], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -415,7 +416,6 @@ def fused_attn_bwd( o: torch.Tensor, d_o: torch.Tensor, fake_dtype: torch.dtype, - dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, @@ -427,6 +427,11 @@ def fused_attn_bwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", + do_format: str = "sbhd", + dqkv_layout: str = "sbh3d", + qkv_scale_inv_format: str = None, + do_scale_inv_format: str = None, attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -465,8 +470,6 @@ def fused_attn_bwd( fake_dtype : tex.DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype - dqkv_dtype : tex.DType - data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. aux_ctx_tensors = [M, ZInv, rng_state] @@ -482,6 +485,9 @@ def fused_attn_bwd( Quantizer object for the intermediate value dP. dqkv_quantizer : Quantizer, default = None Quantizer object for the output values of the fused_attn_bwd. + attn_scale : float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim_qk) as the default dropout : float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -493,6 +499,21 @@ def fused_attn_bwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} + do_format : str, default = "sbhd" + format of dO; {"sbhd", "bshd", "thd"} + dqkv_layout : str, default = "sbh3d" + layout of dQ, dK and dV; + {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + qkv_scale_inv_format : str, default = None + format of the scale-inverse tensors for QKV; {"sbhd", "bshd", "thd", "bhsd"}; + if None, defaults to the format inferred from qkv_layout. + do_scale_inv_format : str, default = None + format of the scale-inverse tensors for dO; {"sbhd", "bshd", "thd", "bhsd"}; + if None, defaults to the format inferred from the output layout. attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -553,29 +574,6 @@ def fused_attn_bwd( f" for backend={fused_attention_backend}." ) - if fused_attention_backend == FusedAttnBackend["FP8"]: - if s_quantizer is None: - raise ValueError( - "s_quantizer is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if dp_quantizer is None: - raise ValueError( - "dp_quantizer is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if dqkv_dtype is None: - raise ValueError( - "dqkv_dtype is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if len(aux_ctx_tensors) != 3: - raise ValueError( - "aux_ctx_tensors must be [M, ZInv, rng_state] for FP8 fused attention," - f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" - f" (backend={fused_attention_backend})." - ) - output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -583,6 +581,11 @@ def fused_attn_bwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], + QKVFormat[do_format], + QKVLayout[dqkv_layout], + QKVFormat[qkv_scale_inv_format], + QKVFormat[do_scale_inv_format], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -597,7 +600,6 @@ def fused_attn_bwd( o, d_o, fake_dtype, - dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index fb5783dfcb..929be8906f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -84,11 +84,11 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, bool bottom_right_diagonal, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const py::handle Q, const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, @@ -98,11 +98,13 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -111,6 +113,13 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); +std::vector> multi_tensor_transpose_to_bhsd( + std::vector> inputs, const std::string &original_format, + std::vector> outputs = {}); + +std::vector multi_tensor_pad_last_dim(std::vector inputs, + int64_t alignment); + at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, @@ -572,6 +581,13 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, void inplace_swizzle_scale_for_gemm(py::handle &tensor); +void inplace_multi_tensor_swizzle_scales_for_gemm(std::vector &tensors, + bool rowwise_usage, bool columnwise_usage); + +void inplace_multi_tensor_swizzle_scales_for_gemm_unchecked(std::vector &tensors, + bool rowwise_usage, + bool columnwise_usage); + void grouped_swizzle_for_gemm(py::handle &tensor, bool rowwise, bool columnwise); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index ff60bb87bb..8a2e54a733 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -57,7 +57,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( // helper function for S and dP quantizers std::pair quantizer_helper(py::handle quantizer, const std::vector &shape, DType dtype, - bool create_hp_tensor_for_cs, + bool create_hp_tensor, std::optional data) { std::unique_ptr T_quantizer = convert_quantizer(quantizer); TensorWrapper te_T; @@ -78,7 +78,7 @@ std::pair quantizer_helper(py::handle quantizer, } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // current scaling auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); - if (create_hp_tensor_for_cs) { + if (create_hp_tensor) { if (data.has_value()) { std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); @@ -91,6 +91,20 @@ std::pair quantizer_helper(py::handle quantizer, !data.has_value(), "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // MXFP8 + if (create_hp_tensor) { + if (data.has_value()) { + std::tie(te_T, py_T) = NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = NoneQuantizer(py::none()).create_tensor(shape, dtype); + } + } else { + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK(!data.has_value(), + "MXFP8Quantizer::create_tensor() does not take data tensor as input!"); + } } return {std::move(te_T), std::move(py_T)}; } @@ -98,11 +112,11 @@ std::pair quantizer_helper(py::handle quantizer, // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, bool bottom_right_diagonal, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const py::handle Q, const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, @@ -134,8 +148,13 @@ std::vector fused_attn_fwd( std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; + o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + AttentionShape o_parsed(q_format, o_shape_tmp.data()); + size_t h = o_parsed.h(), d = o_parsed.d(); + o_parsed.to_format(o_format, o_shape.data()); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -146,9 +165,7 @@ std::vector fused_attn_fwd( TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { @@ -156,7 +173,7 @@ std::vector fused_attn_fwd( } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (o_format == NVTE_QKV_Format::NVTE_THD) { te_O.zero_(at::cuda::getCurrentCUDAStream()); } } else { @@ -235,9 +252,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, + qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -260,7 +277,7 @@ std::vector fused_attn_fwd( // f16_arbitrary: // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] // return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] + // fp8 : M [b, h, sq, 1], optional ZInv [b, h, sq, 1] (T3HD path), rng_state [2] size_t i = 0; at::Tensor output_tensor; // intermediate softmax tensor, S or M (for fp8) @@ -268,8 +285,10 @@ std::vector fused_attn_fwd( allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); - // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor - if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // fp8 T3HD has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor + if (((qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) || + return_max_logit) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); @@ -295,9 +314,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, + qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory @@ -310,11 +329,13 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -343,25 +364,37 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d_qk = q_shape[q_shape.size() - 1]; - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - + const DType dqkv_fake_dtype = GetTransformerEngineDType(fake_dtype); + size_t ndim_q = q_shape.size(); + size_t ndim_kv = k_shape.size(); + std::vector dQ_shape(ndim_q), dK_shape(ndim_kv), dV_shape(ndim_kv); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + NVTE_QKV_Format dq_format = nvte_get_q_format(dqkv_layout); + NVTE_QKV_Format dkv_format = nvte_get_kv_format(dqkv_layout); + AttentionShape q_parsed(q_format, q_shape.data()); + size_t h_q = q_parsed.h(), d_qk = q_parsed.d(); + q_parsed.to_format(dq_format, dQ_shape.data()); + AttentionShape k_parsed(kv_format, k_shape.data()); + size_t h_kv = k_parsed.h(); + k_parsed.to_format(dkv_format, dK_shape.data()); + AttentionShape v_parsed(kv_format, v_shape.data()); + size_t d_v = v_parsed.d(); + v_parsed.to_format(dkv_format, dV_shape.data()); at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - std::vector tmp_shape; - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + // FP16/BF16: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.float16/torch.bfloat16 + // FP8DS: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.uint8 + // FP8CS/MXFP8: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.float16/torch.bfloat16 + auto options = torch::TensorOptions().dtype(fake_dtype).device(torch::kCUDA); + if (detail::IsFloat8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { - options = options.dtype(fake_dtype); - } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); + std::vector tmp_shape; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -378,7 +411,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -392,9 +425,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -407,9 +440,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -420,39 +453,51 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + tmp_shape = std::vector(dK_shape.begin(), dK_shape.end()); dK = torch::empty(tmp_shape, options); - tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + tmp_shape = std::vector(dV_shape.begin(), dV_shape.end()); dV = torch::empty(tmp_shape, options); break; default: NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); - std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); - std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, dQ_shape, dqkv_fake_dtype, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, dK_shape, dqkv_fake_dtype, true, dK); + std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, dV_shape, dqkv_fake_dtype, true, dV); // construct NVTE tensors - if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + if (detail::IsFloat8Quantizers(dqkv_quantizer.ptr())) { // FP8 - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && - dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { - mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dK.fill_(0); - dV.fill_(0); + if (set_zero) { + if (dq_format == NVTE_QKV_Format::NVTE_THD) { + if (((h_q * d_qk) % block_size == 0) && dQ.is_contiguous()) { + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + } + } + if (dkv_format == NVTE_QKV_Format::NVTE_THD) { + if (((h_kv * d_qk) % block_size == 0) && ((h_kv * d_v) % block_size == 0) && + dK.is_contiguous() && dV.is_contiguous()) { + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dK.fill_(0); + dV.fill_(0); + } } } - } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + } else if (dqkv_quantizer.is_none() || + detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || + detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { + if (dq_format == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); + } + if (dkv_format == NVTE_QKV_Format::NVTE_THD) { dK.fill_(0); dV.fill_(0); } @@ -538,7 +583,8 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format, + do_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -555,7 +601,8 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format, + do_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -614,6 +661,135 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } +std::vector> multi_tensor_transpose_to_bhsd( + std::vector> inputs, const std::string &original_format, + std::vector> outputs) { + NVTE_CHECK(original_format == "sbhd" || original_format == "bshd", + "multi_tensor_transpose_to_bhsd: only BSHD/SBHD -> BHSD is currently supported. " + "Got original_format=\"", + original_format, "\"."); + const auto original_format_enum = (original_format == "sbhd") ? NVTE_SBHD : NVTE_BSHD; + + if (inputs.empty()) return {}; + + const bool has_outputs = !outputs.empty(); + if (has_outputs) { + NVTE_CHECK(outputs.size() == inputs.size(), "multi_tensor_transpose_to_bhsd: outputs.size() (", + outputs.size(), ") != inputs.size() (", inputs.size(), ")."); + } + + std::vector te_ins, te_outs; + std::vector> result(inputs.size(), std::nullopt); + + for (size_t i = 0; i < inputs.size(); ++i) { + if (!inputs[i].has_value()) continue; + + auto &input = inputs[i].value(); + NVTE_CHECK(input.is_cuda() && input.dim() == 4, "multi_tensor_transpose_to_bhsd: input ", i, + " must be a 4D CUDA tensor."); + input = input.contiguous(); + NVTE_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16 || + input.scalar_type() == at::ScalarType::Byte, + "multi_tensor_transpose_to_bhsd: unsupported dtype at index ", i, "."); + + at::Tensor output; + if (has_outputs && outputs[i].has_value()) { + output = outputs[i].value(); + } else { + int64_t B, S, H, D; + if (original_format_enum == NVTE_SBHD) { + S = input.size(0); + B = input.size(1); + H = input.size(2); + D = input.size(3); + } else { + B = input.size(0); + S = input.size(1); + H = input.size(2); + D = input.size(3); + } + output = at::empty({B, H, S, D}, input.options()); + } + + te_ins.push_back(makeTransformerEngineTensor(input)); + te_outs.push_back(makeTransformerEngineTensor(output)); + result[i] = output; + } + + if (!te_ins.empty()) { + std::vector nvte_ins(te_ins.size()), nvte_outs(te_outs.size()); + for (size_t j = 0; j < te_ins.size(); ++j) { + nvte_ins[j] = te_ins[j].data(); + nvte_outs[j] = te_outs[j].data(); + } + nvte_multi_tensor_transpose_to_bhsd(nvte_ins.data(), nvte_outs.data(), te_ins.size(), + original_format_enum, at::cuda::getCurrentCUDAStream()); + } + + return result; +} + +std::vector multi_tensor_pad_last_dim(std::vector inputs, + int64_t alignment) { + const auto align = static_cast(alignment); + NVTE_CHECK(align > 0, "multi_tensor_pad_last_dim: alignment must be > 0."); + NVTE_CHECK(!inputs.empty(), "multi_tensor_pad_last_dim: inputs must not be empty."); + + auto stream = at::cuda::getCurrentCUDAStream(); + std::vector outputs; + outputs.reserve(inputs.size()); + + std::vector kernel_indices; + + for (size_t i = 0; i < inputs.size(); ++i) { + auto &input = inputs[i]; + + NVTE_CHECK(input.dim() == 2, "multi_tensor_pad_last_dim: expected 2D input at index ", i, + ", got ", input.dim(), "D."); + NVTE_CHECK(input.is_cuda(), "multi_tensor_pad_last_dim: input must be a CUDA tensor at index ", + i, "."); + input = input.contiguous(); + + const int64_t rows = input.size(0); + const int64_t in_cols = input.size(1); + const int64_t padded_cols = + static_cast(DIVUP_TO_MULTIPLE(static_cast(in_cols), align)); + + if (in_cols == padded_cols) { + outputs.push_back(input); + continue; + } + + at::Tensor output = at::empty({rows, padded_cols}, input.options()); + outputs.push_back(output); + kernel_indices.push_back(outputs.size() - 1); + } + + if (kernel_indices.empty()) return outputs; + + std::vector te_in_wrappers, te_out_wrappers; + te_in_wrappers.reserve(kernel_indices.size()); + te_out_wrappers.reserve(kernel_indices.size()); + + for (size_t idx : kernel_indices) { + te_in_wrappers.push_back(makeTransformerEngineTensor(inputs[idx])); + te_out_wrappers.push_back(makeTransformerEngineTensor(outputs[idx])); + } + + std::vector nvte_inputs(te_in_wrappers.size()); + std::vector nvte_outputs(te_out_wrappers.size()); + for (size_t i = 0; i < te_in_wrappers.size(); ++i) { + nvte_inputs[i] = te_in_wrappers[i].data(); + nvte_outputs[i] = te_out_wrappers[i].data(); + } + + nvte_multi_tensor_pad_last_dim(nvte_inputs.data(), nvte_outputs.data(), te_in_wrappers.size(), + stream); + + return outputs; +} + /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 27d26d3dab..eb7576d905 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -391,6 +391,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Multi-tensor unpadding", py::call_guard()); m.def("swizzle_scales_for_gemm_", &transformer_engine::pytorch::inplace_swizzle_scale_for_gemm, "Convert tensor block scales into GEMM swizzled format"); + m.def("multi_tensor_swizzle_scales_for_gemm_", + &transformer_engine::pytorch::inplace_multi_tensor_swizzle_scales_for_gemm, + "Convert multiple tensors' block scales into GEMM swizzled format", py::arg("tensors"), + py::arg("rowwise_usage"), py::arg("columnwise_usage")); + m.def( + "multi_tensor_swizzle_scales_for_gemm_unchecked_", + &transformer_engine::pytorch::inplace_multi_tensor_swizzle_scales_for_gemm_unchecked, + "Convert multiple tensors' block scales into GEMM swizzled format (skip scale shape checks)", + py::arg("tensors"), py::arg("rowwise_usage"), py::arg("columnwise_usage")); m.def("grouped_swizzle_for_gemm", &transformer_engine::pytorch::grouped_swizzle_for_gemm, "In-place swizzle of grouped tensor scales for GEMM", py::arg("tensor"), py::arg("rowwise"), py::arg("columnwise")); @@ -401,6 +410,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", py::call_guard()); + m.def("multi_tensor_transpose_to_bhsd", + &transformer_engine::pytorch::multi_tensor_transpose_to_bhsd, + "Permute multiple tensors from BSHD/SBHD to BHSD.", py::arg("inputs"), + py::arg("original_format"), py::arg("outputs") = std::vector>{}, + py::call_guard()); + m.def("multi_tensor_pad_last_dim", &transformer_engine::pytorch::multi_tensor_pad_last_dim, + "Pad multiple tensors' last dimension to a common alignment.", py::arg("inputs"), + py::arg("alignment"), py::call_guard()); m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd, "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd, diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index a6b4e7569d..d8ab830c48 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -141,9 +141,11 @@ std::tuple, std::optional> swizzle_scales_ return {std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)}; } -std::optional multi_tensor_swizzle_scales_for_gemm( +namespace { + +std::optional multi_tensor_swizzle_scales_for_gemm_impl( std::vector &tensors, bool rowwise_usage, - bool columnwise_usage) { + bool columnwise_usage, bool check_scale_inv_shapes) { // Checks and trivial cases NVTE_CHECK(rowwise_usage != columnwise_usage, "Expect exactly one of rowwise_usage=", rowwise_usage, @@ -243,9 +245,15 @@ std::optional multi_tensor_swizzle_scales_for_gemm( // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(), - inputs_nvte_raw.size(), - at::cuda::getCurrentCUDAStream()); + if (check_scale_inv_shapes) { + nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(), + inputs_nvte_raw.size(), + at::cuda::getCurrentCUDAStream()); + } else { + nvte_multi_tensor_swizzle_scaling_factors_unchecked( + inputs_nvte_raw.data(), outputs_nvte_raw.data(), inputs_nvte_raw.size(), + at::cuda::getCurrentCUDAStream()); + } }); // Update tensors with swizzled scales @@ -269,6 +277,22 @@ std::optional multi_tensor_swizzle_scales_for_gemm( return std::move(output_scales_pyt); } +} // anonymous namespace + +std::optional multi_tensor_swizzle_scales_for_gemm( + std::vector &tensors, bool rowwise_usage, + bool columnwise_usage) { + return multi_tensor_swizzle_scales_for_gemm_impl(tensors, rowwise_usage, columnwise_usage, + /*check_scale_inv_shapes=*/true); +} + +std::optional multi_tensor_swizzle_scales_for_gemm_unchecked( + std::vector &tensors, bool rowwise_usage, + bool columnwise_usage) { + return multi_tensor_swizzle_scales_for_gemm_impl(tensors, rowwise_usage, columnwise_usage, + /*check_scale_inv_shapes=*/false); +} + at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, bool rowwise) { // Check input tensor @@ -379,16 +403,39 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW tensor_offsets.data_ptr, static_cast(tensor_offsets.dtype), tensor_offsets.shape); } + // Per-tensor logical dimensions (uniform-shape grouped tensor). + const size_t num_tensors = input.num_tensors(); + const auto logical_shape_nvte = input.logical_shape(); + NVTE_CHECK(logical_shape_nvte.ndim >= 2, + "Grouped GEMM swizzle expects logical_shape with ndim >= 2."); + const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors; + const size_t per_tensor_last_dim = logical_shape_nvte.data[logical_shape_nvte.ndim - 1]; + constexpr size_t kMxfp8BlockSize = 32; + + // Output is always allocated in the per-tensor padded ("swizzle-ready") layout + // so the cuDNN grouped GEMM consumer sees the correct stride between experts. + // The swizzle kernel itself handles converting from the kernel-emitted compact + // layout (per-tensor first dim is the unpadded value) to this padded layout. + auto compute_padded_grouped_scale_shape = [&](bool rowwise) { + const size_t m = rowwise ? per_tensor_first_dim : per_tensor_last_dim; + const size_t k = rowwise ? per_tensor_last_dim : per_tensor_first_dim; + const size_t padded_m = ceildiv(m, size_t{128}) * 128; + const size_t padded_k = ceildiv(ceildiv(k, kMxfp8BlockSize), size_t{4}) * 4; + return std::vector{num_tensors * padded_m, padded_k}; + }; + if (swizzle_rowwise) { const auto data = input.get_rowwise_data(); const auto data_dtype = static_cast(data.dtype); const auto scales_dtype = static_cast(row_scales.dtype); swizzle_input.set_rowwise_data(nullptr, data_dtype, data.shape); swizzle_input.set_rowwise_scale_inv(row_scales.data_ptr, scales_dtype, row_scales.shape); - rowwise_scales_pyt = allocateSpace(row_scales.shape, scales_dtype, false); + const auto padded_shape = compute_padded_grouped_scale_shape(/*rowwise=*/true); + rowwise_scales_pyt = allocateSpace(padded_shape, scales_dtype, false); + NVTEShape padded_shape_nvte = nvte_make_shape(padded_shape.data(), padded_shape.size()); swizzle_output.set_rowwise_data(nullptr, data_dtype, data.shape); swizzle_output.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, - row_scales.shape); + padded_shape_nvte); } if (swizzle_columnwise) { const auto data = input.get_columnwise_data(); @@ -396,10 +443,12 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW const auto scales_dtype = static_cast(col_scales.dtype); swizzle_input.set_columnwise_data(nullptr, data_dtype, data.shape); swizzle_input.set_columnwise_scale_inv(col_scales.data_ptr, scales_dtype, col_scales.shape); - columnwise_scales_pyt = allocateSpace(col_scales.shape, scales_dtype, false); + const auto padded_shape = compute_padded_grouped_scale_shape(/*rowwise=*/false); + columnwise_scales_pyt = allocateSpace(padded_shape, scales_dtype, false); + NVTEShape padded_shape_nvte = nvte_make_shape(padded_shape.data(), padded_shape.size()); swizzle_output.set_columnwise_data(nullptr, data_dtype, data.shape); swizzle_output.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype, - col_scales.shape); + padded_shape_nvte); } swizzle_output.set_with_gemm_swizzled_scales(true); @@ -410,12 +459,13 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW if (swizzle_rowwise) { const auto scales_dtype = static_cast(row_scales.dtype); - input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, row_scales.shape); + input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, + getTensorShape(*rowwise_scales_pyt)); } if (swizzle_columnwise) { const auto scales_dtype = static_cast(col_scales.dtype); input.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype, - col_scales.shape); + getTensorShape(*columnwise_scales_pyt)); } input.set_with_gemm_swizzled_scales(true); return SwizzledGroupedScales{std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)}; @@ -443,6 +493,105 @@ void grouped_swizzle_for_gemm(py::handle &tensor, bool rowwise, bool columnwise) } } +namespace { + +void inplace_multi_tensor_swizzle_scales_for_gemm_impl(std::vector &tensors, + bool rowwise_usage, bool columnwise_usage, + bool check_scale_inv_shapes) { + NVTE_CHECK(rowwise_usage != columnwise_usage, + "Expect exactly one of rowwise_usage and columnwise_usage."); + if (tensors.empty()) { + return; + } + + // Convert Python tensors to TensorWrappers, filtering those that need swizzling + std::vector swizzle_indices; + std::vector wrappers_to_swizzle; + + for (size_t i = 0; i < tensors.size(); ++i) { + auto tw = makeTransformerEngineTensor(tensors[i], py::none()); + + if (i == 0) { + switch (tw.scaling_mode()) { + case NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: + break; + case NVTE_INVALID_SCALING: + NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); + default: + return; + } + } + + if (tw.get_with_gemm_swizzled_scales()) { + continue; + } + const auto scales_nvte = + rowwise_usage ? tw.get_rowwise_scale_inv() : tw.get_columnwise_scale_inv(); + if (scales_nvte.data_ptr == nullptr || + (scales_nvte.shape.ndim == 1 && scales_nvte.shape.data[0] == 0)) { + continue; + } + + swizzle_indices.push_back(i); + wrappers_to_swizzle.push_back(std::move(tw)); + } + + if (wrappers_to_swizzle.empty()) { + return; + } + + // Delegate to core C++ function + auto swizzle_fn = check_scale_inv_shapes ? multi_tensor_swizzle_scales_for_gemm + : multi_tensor_swizzle_scales_for_gemm_unchecked; + auto output_buffer = swizzle_fn(wrappers_to_swizzle, rowwise_usage, columnwise_usage); + if (!output_buffer.has_value()) { + return; + } + + // Update Python objects with properly-shaped views into the contiguous output buffer + const uint8_t *base = reinterpret_cast(output_buffer->data_ptr()); + for (size_t j = 0; j < wrappers_to_swizzle.size(); ++j) { + const auto scales_nvte = rowwise_usage ? wrappers_to_swizzle[j].get_rowwise_scale_inv() + : wrappers_to_swizzle[j].get_columnwise_scale_inv(); + + const size_t offset = reinterpret_cast(scales_nvte.data_ptr) - base; + const auto dtype = static_cast(scales_nvte.dtype); + const size_t num_elements = product(scales_nvte.shape, 0, scales_nvte.shape.ndim); + const size_t num_bytes = + ceildiv(num_elements * transformer_engine::pytorch::typeToNumBits(dtype), size_t(8)); + + std::vector torch_shape; + for (size_t d = 0; d < scales_nvte.shape.ndim; ++d) { + torch_shape.push_back(static_cast(scales_nvte.shape.data[d])); + } + auto scale_view = + output_buffer->narrow(0, static_cast(offset), static_cast(num_bytes)) + .view(torch_shape); + + if (rowwise_usage) { + tensors[swizzle_indices[j]].attr("_rowwise_scale_inv") = py::cast(scale_view); + } else { + tensors[swizzle_indices[j]].attr("_columnwise_scale_inv") = py::cast(scale_view); + } + } +} + +} // anonymous namespace + +void inplace_multi_tensor_swizzle_scales_for_gemm(std::vector &tensors, + bool rowwise_usage, bool columnwise_usage) { + inplace_multi_tensor_swizzle_scales_for_gemm_impl(tensors, rowwise_usage, columnwise_usage, + /*check_scale_inv_shapes=*/true); +} + +void inplace_multi_tensor_swizzle_scales_for_gemm_unchecked(std::vector &tensors, + bool rowwise_usage, + bool columnwise_usage) { + inplace_multi_tensor_swizzle_scales_for_gemm_impl(tensors, rowwise_usage, columnwise_usage, + /*check_scale_inv_shapes=*/false); +} + void inplace_swizzle_scale_for_gemm(py::handle &tensor) { // Convert Python tensor to C++ tensor auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none()); diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 88f76a7cb1..132db4075f 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -33,6 +33,9 @@ std::optional multi_tensor_swizzle_scales_for_gemm(std::vector multi_tensor_swizzle_scales_for_gemm_unchecked( + std::vector& tensors, bool rowwise_usage, bool columnwise_usage); + using SwizzledGroupedScales = std::pair, std::optional>; /*! \brief Swizzle grouped tensor scales for GEMM if needed. diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 720a274119..18b7049233 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -31,6 +31,7 @@ clear_tensor_data, init_method_constant, requires_grad, + resolve_grouped_linear_single_param_flags, get_nvtx_range_context, ) from ..distributed import ( @@ -673,11 +674,15 @@ class GroupedLinear(TransformerEngineBaseModule): single_grouped_weight : bool, default = False If set to ``True``, grouped weights are stored as a single grouped parameter instead of one parameter per GEMM. - EXPERIMENTAL and subject to change. + EXPERIMENTAL and subject to change. Gated by the + ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var + is not set this argument is forced to ``False`` with a warning. single_grouped_bias : bool, default = False If set to ``True``, grouped biases are stored as a single grouped bias instead of one bias per GEMM. - EXPERIMENTAL and subject to change. + EXPERIMENTAL and subject to change. Gated by the + ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var + is not set this argument is forced to ``False`` with a warning. Notes ----- @@ -726,6 +731,9 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input + single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags( + single_grouped_weight, single_grouped_bias + ) self.single_grouped_weight = single_grouped_weight self.single_grouped_bias = single_grouped_bias if ub_overlap_rs or ub_overlap_ag: diff --git a/transformer_engine/pytorch/newton_schulz.py b/transformer_engine/pytorch/newton_schulz.py index 2367897565..1cbe6ebfbf 100644 --- a/transformer_engine/pytorch/newton_schulz.py +++ b/transformer_engine/pytorch/newton_schulz.py @@ -5,7 +5,7 @@ """Distributed Newton-Schulz matrix orthogonalization via cuSolverMp.""" from itertools import chain, cycle, islice, repeat -from typing import Iterator, List, Literal, Optional, Sequence +from typing import Iterator, Literal, Optional, Sequence import torch import torch.distributed as dist @@ -63,13 +63,14 @@ NSCoeffT = Literal[_COEFFICIENT_SETS.keys()] CoeffIterMode = Literal["cycle", "repeat_last"] +CoeffT = tuple[float, float, float] def get_coefficient_iterator( steps: int, - coefficient_sets: Sequence[tuple[float, float, float]], + coefficient_sets: Sequence[CoeffT], mode: CoeffIterMode = "cycle", -) -> Iterator[tuple[float, float, float]]: +) -> Iterator[CoeffT]: """Iterate through coefficient sets with configurable end behavior using itertools. Args: @@ -89,7 +90,7 @@ def get_coefficient_iterator( if not coefficient_sets: raise ValueError("coefficient_sets must be non-empty.") - base: Iterator[tuple[float, float, float]] + base: Iterator[CoeffT] if mode == "cycle": base = cycle(coefficient_sets) elif mode == "repeat_last": @@ -101,7 +102,7 @@ def get_coefficient_iterator( return islice(base, steps) -def get_coefficients(steps: int, coefficient_type: NSCoeffT = "quintic") -> List[float]: +def get_coefficients(steps: int, coefficient_type: NSCoeffT = "quintic") -> list[CoeffT]: """Return the coefficient schedule for Newton-Schulz. Parameter ``coefficient_type`` can be one of the following @@ -119,7 +120,7 @@ def get_coefficients(steps: int, coefficient_type: NSCoeffT = "quintic") -> List coeff_iter = get_coefficient_iterator( steps, _COEFFICIENT_SETS[coefficient_type], mode=iter_mode ) - return list(chain.from_iterable(coeff_iter)) + return list(coeff_iter) class CusolverMpCtx: @@ -159,7 +160,7 @@ def newton_schulz( x: torch.Tensor, ctx: CusolverMpCtx, num_iterations: int = 5, - coefficients: Optional[List[float]] = None, + coefficients: Optional[Sequence[CoeffT]] = None, ) -> None: """Compute Newton-Schulz matrix orthogonalization in-place on a distributed matrix. @@ -173,16 +174,23 @@ def newton_schulz( cuSolverMp context created by :func:`cusolvermp_ctx_create`. num_iterations : int, optional Number of Newton-Schulz iterations. Default: 5. - coefficients : list of float, optional + coefficients : sequence of tuple[float, float, float], optional Polynomial coefficients for the Newton-Schulz iteration. """ if coefficients is None: coefficients = get_coefficients(num_iterations) - if len(coefficients) != num_iterations * 3: + if len(coefficients) != num_iterations: raise ValueError( f"Unexpected number of coefficients: {len(coefficients)} for" f" {num_iterations} iterations" ) + flat_coefficients: list[float] = [] + for i, coeff in enumerate(coefficients): + if len(coeff) != 3: + raise ValueError( + f"Expected coefficient tuple of length 3 at iteration {i}, got {len(coeff)}" + ) + flat_coefficients.extend(coeff) if x.dim() != 2: raise ValueError(f"Expected 2D tensor, got {x.dim()}D") @@ -197,4 +205,4 @@ def newton_schulz( m = x.size(0) n = x.size(1) * ctx.nranks - tex.newton_schulz(ctx._ptr, m, n, x, num_iterations, coefficients) + tex.newton_schulz(ctx._ptr, m, n, x, num_iterations, flat_coefficients) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index e21915a5a6..beef6fe52f 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -21,17 +21,11 @@ @functools.lru_cache(maxsize=1) -def _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() -> bool: - """Check cuDNN FE min version with fixed numerics for qgeglu.""" - try: - return PkgVersion(get_pkg_version("nvidia-cudnn-frontend")) >= PkgVersion("1.23.0") - except PackageNotFoundError: - return False - +def _cudnn_frontend_version_supported() -> bool: + """Check cuDNN frontend is at least 1.23.0. -@functools.lru_cache(maxsize=1) -def _nvidia_cudnn_frontend_supports_wgrad() -> bool: - """Check cuDNN FE min version for grouped GEMM wgrad kernel.""" + All grouped MLP fused-kernel features require cuDNN frontend 1.23.0. + """ try: return PkgVersion(get_pkg_version("nvidia-cudnn-frontend")) >= PkgVersion("1.23.0") except PackageNotFoundError: @@ -140,8 +134,6 @@ def fuse_grouped_mlp_ops( constructor accepting ``fc1``, ``glu_op``, ``fc2`` keyword args. The ``glu_op`` must be :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledSwiGLU` or :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledClampedQGeGLU`. - May also expose ``is_fc1_bias_supported()`` and/or - ``is_fc2_bias_supported()`` classmethods for bias eligibility. Returns ------- @@ -159,13 +151,6 @@ def fuse_grouped_mlp_ops( if recipe is None or not recipe.mxfp8(): return ops - fc1_bias_ok = ( - not hasattr(fused_op_cls, "is_fc1_bias_supported") or fused_op_cls.is_fc1_bias_supported() - ) - fc2_bias_ok = ( - not hasattr(fused_op_cls, "is_fc2_bias_supported") or fused_op_cls.is_fc2_bias_supported() - ) - out = [] window, ops = ops[:3], ops[3:] while len(window) == 3: @@ -179,7 +164,6 @@ def fuse_grouped_mlp_ops( matches_pattern = False elif isinstance(window[1], ScaledClampedQGeGLU) and ( abs(window[1]._clamped.alpha - 1.702) > 0.001 - or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() ): matches_pattern = False elif window[0].num_groups != window[2].num_groups: @@ -193,10 +177,6 @@ def fuse_grouped_mlp_ops( matches_pattern = False elif window[1].glu_interleave_size != 32: matches_pattern = False - elif window[0].has_bias and not fc1_bias_ok: - matches_pattern = False - elif window[2].has_bias and not fc2_bias_ok: - matches_pattern = False if matches_pattern: op = fused_op_cls( diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 17594726cc..19fcf62ced 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -329,8 +329,6 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: super().pre_fuser_forward(requires_grad=requires_grad) if FP8GlobalStateManager.is_fp8_enabled(): # Configure quantizer usages - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. weight_requires_grad = requires_grad and self.weight.requires_grad columnwise_usage = weight_requires_grad if FP8GlobalStateManager.get_fp8_recipe().backward_override is not None: @@ -339,13 +337,13 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - weight_quantizer.set_usage(rowwise=True, columnwise=False) + weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad) grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) - # Configure input/grad output tensor + # Configure input/grad output quantizers # Note: These tensors are only used internally. If there is no # tensor-parallel communication, they are only used for GEMM. input_quantizer = self.get_quantizer("forward", 0) @@ -370,21 +368,15 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: # Configure weight quantizer # Note: This function may be called in base class constructor, - # before any basic linear attrs have been set. + # before basic linear attrs have been set. weight_quantizer = self.get_quantizer("forward", 1) - if weight_quantizer is None: - pass - elif is_quantized_tensor(getattr(self, "weight", None)): - # Make sure weight param has correct quantizer - weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) - weight_quantizer.internal = False - self.weight.update_quantizer(weight_quantizer.copy()) - else: - # Use internal tensors if quantized weights will not be - # exposed externally - weight_quantizer.internal = ( - not FP8GlobalStateManager.with_fp8_parameters() - and not getattr(self, "_with_quantized_weight", False) + weight = getattr(self, "weight", None) + if weight_quantizer is not None: + # Determine if quantized weight is exposed as parameter + weight_quantizer.internal = not ( + FP8GlobalStateManager.with_fp8_parameters() + or getattr(self, "_with_quantized_weight", False) + or is_quantized_tensor(weight) ) # Recipe-specific configuration @@ -416,6 +408,18 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: grad_output_quantizer.with_amax_reduction = True grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + # Update quantizer in quantized weight tensor + if weight_quantizer is not None and is_quantized_tensor(weight): + if weight._quantizer is not None: + # Preserve existing usages in weight tensor. Even if a + # usage is currently unnecessary, the weight tensor + # may be used elsewhere. + weight_quantizer.set_usage( + rowwise=weight._quantizer.rowwise_usage, + columnwise=weight._quantizer.columnwise_usage, + ) + weight.update_quantizer(weight_quantizer.copy()) + @staticmethod def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index a1d40a30ec..302fa3384b 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -30,6 +30,7 @@ canonicalize_dtype, clear_tensor_data, devices_match, + resolve_grouped_linear_single_param_flags, round_up_to_nearest_multiple, ) from .._common import is_quantized_tensor, maybe_dequantize @@ -75,11 +76,17 @@ class GroupedLinear(BasicOperation): ``main_grad`` instead of accumulating. single_grouped_weight : bool, default = ``False`` Store all expert weights as one ``GroupedTensor`` parameter ``weight``. + EXPERIMENTAL and subject to change. Gated by the + ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var + is not set this argument is forced to ``False`` with a warning. delay_wgrad_compute : bool, default = ``False`` Whether to delay weight gradient computation single_grouped_bias : bool, default = ``False`` If ``True`` (and ``bias=True``), store all expert biases as one ``GroupedTensor`` parameter named ``bias`` instead of ``bias0``..``bias{N-1}``. + EXPERIMENTAL and subject to change. Gated by the + ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var + is not set this argument is forced to ``False`` with a warning. scale_bias : bool, default = ``False`` If ``True`` (and ``bias=True``), expects a probability tensor as an additional extra input and adds ``bias * scales`` instead of ``bias`` @@ -120,6 +127,9 @@ def __init__( self.num_groups: int = num_groups self.in_features: int = in_features self.out_features: int = out_features + single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags( + single_grouped_weight, single_grouped_bias + ) self.single_grouped_weight: bool = single_grouped_weight self.single_grouped_bias: bool = single_grouped_bias self.use_bias: bool = bias @@ -616,14 +626,12 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: weight_requires_grad = requires_grad and weight_requires_grad # Configure quantizer usages - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. for group_idx in range(self.num_groups): input_quantizer = self.get_quantizer("forward", 2 * group_idx) weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) grad_output_quantizer = self.get_quantizer("backward", group_idx) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=False) + weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: @@ -638,32 +646,29 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: if grad_output_quantizer is not None: grad_output_quantizer.internal = True - # Handle weight quantizer + # Get weight tensor # Note: This function may be called in base class constructor, - # before any basic linear attrs have been set. - weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) - if weight_quantizer is None: - pass - elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)): - # Make sure weight param has correct quantizer - weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) - weight_quantizer.internal = False - if self.single_grouped_weight: - self.weight.quantizer = weight_quantizer.copy() - else: - getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) + # before any grouped linear attrs have been set. + weight = None + weight_is_quantized = False + if getattr(self, "single_grouped_weight", False): + weight = getattr(self, "weight", None) + weight_is_quantized = weight is not None and weight.quantizer is not None else: - # Use internal tensors if quantized weights will not be - # exposed externally - weight_quantizer.internal = ( - not FP8GlobalStateManager.with_fp8_parameters() - and not getattr(self, "_with_quantized_weight", False) - and not self.single_grouped_weight + weight = getattr(self, f"weight{group_idx}", None) + weight_is_quantized = is_quantized_tensor(weight) + + # Configure weight quantizer + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + if weight_quantizer is not None: + # Determine if quantized weight is exposed as parameter + weight_quantizer.internal = not ( + FP8GlobalStateManager.with_fp8_parameters() + or getattr(self, "_with_quantized_weight", False) + or weight_is_quantized ) # Recipe-specific configuration - # Note: This function may be called in base class constructor, - # before any basic linear attrs have been set. if recipe is not None: if recipe.float8_current_scaling(): input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale @@ -677,6 +682,29 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: recipe.fp8_quant_bwd_grad.amax_epsilon ) + # Update quantizer in quantized weight tensor + if weight_quantizer is not None and weight_is_quantized: + # Get quantizer from weight tensor + weight_tensor_quantizer = ( + weight.quantizer if self.single_grouped_weight else weight._quantizer + ) + + # Preserve existing usages in weight tensor. Even if a + # usage is currently unnecessary, the weight tensor + # may be used elsewhere. + if weight_tensor_quantizer is not None: + weight_quantizer.set_usage( + rowwise=weight_tensor_quantizer.rowwise_usage, + columnwise=weight_tensor_quantizer.columnwise_usage, + ) + + # Update weight tensor + if self.single_grouped_weight: + if group_idx == 0: + weight.quantizer = weight_quantizer.copy() + else: + weight.update_quantizer(weight_quantizer.copy()) + def op_forward(self, *args, **kwargs): raise RuntimeError( f"{self.__class__.__name__} operation has " diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 3eb57c3563..8bea63c82f 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -7,7 +7,6 @@ from __future__ import annotations from collections.abc import Callable import functools -import inspect import math import os from typing import Optional @@ -15,7 +14,6 @@ import torch import transformer_engine_torch as tex -from ...module.base import get_dummy_wgrad from ...quantization import Recipe from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer @@ -25,13 +23,13 @@ from ..fuser import register_backward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( - _nvidia_cudnn_frontend_supports_wgrad, + _cudnn_frontend_version_supported, fuse_grouped_mlp_ops, maybe_dequantize, validate_grouped_mlp_dims, ) from ...cpp_extensions import general_grouped_gemm_for_grouped_tensor -from ...module.base import _2X_ACC_WGRAD +from ...module.base import _2X_ACC_WGRAD, get_dummy_wgrad from ...triton.grouped_dbias_dscales import _compute_grouped_dbias_dscales @@ -58,17 +56,41 @@ def _cudnn_compute_wgrad( fp8_dtype = torch.float8_e4m3fn - # a_tensor = DY^T = (out_features, total_tokens) row-major - a_tensor = grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T - # b_tensor = X = (total_tokens, in_features) column-major - b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features) - - sfa_tensor = grouped_dy.columnwise_scale_inv.view(out_features, -1).view( - dtype=torch.float8_e8m0fnu - ) - sfb_tensor = grouped_x.columnwise_scale_inv.view(in_features, -1).view( - dtype=torch.float8_e8m0fnu - ) + sfa_leading_dim = ((out_features + 127) // 128) * 128 + sfb_leading_dim = ((in_features + 127) // 128) * 128 + + if total_tokens == 0: + # A workaround for the case with zero-token experts. + # Even for this case, cuteDSL still requires the same + # stride requirements for the input and scale tensors. + device = grouped_dy.columnwise_data.device + a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype, device=device) + b_tensor = torch.empty_strided( + (0, in_features), (in_features, 1), dtype=fp8_dtype, device=device + ) + sfa_tensor = torch.empty_strided( + (sfa_leading_dim, 0), + (16, 1), + dtype=torch.float8_e8m0fnu, + device=device, + ) + sfb_tensor = torch.empty_strided( + (sfb_leading_dim, 0), + (16, 1), + dtype=torch.float8_e8m0fnu, + device=device, + ) + else: + a_tensor = ( + grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T + ) + b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features) + sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view( + dtype=torch.float8_e8m0fnu + ) + sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view( + dtype=torch.float8_e8m0fnu + ) # Prepare wgrad output if single_grouped_weight: @@ -107,20 +129,6 @@ def _cudnn_compute_wgrad( ) -@functools.lru_cache(maxsize=1) -def _dglu_wrapper_has_generate_dbias_arg() -> bool: - """True if cudnn-frontend SM100 dGLU wrapper accepts ``generate_dbias``.""" - try: - from cudnn import grouped_gemm_dglu_wrapper_sm100 # pylint: disable=import-outside-toplevel - except ImportError: - return False - try: - params = inspect.signature(grouped_gemm_dglu_wrapper_sm100).parameters - except (TypeError, ValueError): - return False - return "generate_dbias" in params - - def _compute_grad_params( fc_op, ctx, @@ -171,13 +179,12 @@ def _compute_grad_params( f" {tuple(main_grad.stride())}" ) from e accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if accumulate_into_main_grad: - grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data( - num_tensors=num_groups, - tensor_shape=weight_shape, - rowwise_data=main_grad, - dtype=main_grad.dtype, - ) + grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data( + num_tensors=num_groups, + tensor_shape=weight_shape, + rowwise_data=main_grad, + dtype=main_grad.dtype, + ) if grouped_wgrad is None: grouped_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( @@ -235,7 +242,9 @@ def _compute_grad_params( packed_wgrad = None if not delay_wgrad: packed_wgrad = grouped_wgrad.rowwise_data.view(num_groups, *weight_shape) - if accumulate_into_main_grad and hasattr(weight_param, "grad_added_to_main_grad"): + if fc_op._accumulate_into_main_grad and hasattr( + weight_param, "grad_added_to_main_grad" + ): weight_param.grad_added_to_main_grad = True packed_wgrad = get_dummy_wgrad( list(weight_param.size()), @@ -244,9 +253,9 @@ def _compute_grad_params( ) w_list = [packed_wgrad] else: - if delay_wgrad or accumulate_into_main_grad: + if delay_wgrad or fc_op._accumulate_into_main_grad: w_list = [None] * num_groups - if accumulate_into_main_grad: + if fc_op._accumulate_into_main_grad: for idx in range(num_groups): wp = getattr(fc_op, f"weight{idx}") if hasattr(wp, "grad_added_to_main_grad"): @@ -297,10 +306,11 @@ def grouped_gemm_quant_kernel(cls) -> Callable: @functools.lru_cache(maxsize=None) def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]: """CuTe DSL kernel for grouped GEMM wgrad on SM100+. - Returns ``None`` when the cuDNN front-end package is older than - 1.23.0. + + Returns ``None`` when the environment variable + ``NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP`` is set to ``1``. """ - if not _nvidia_cudnn_frontend_supports_wgrad(): + if int(os.environ.get("NVTE_DISABLE_CUTEDSL_WGRAD_FUSED_GROUPED_MLP", "0")) >= 1: return None from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module @@ -314,6 +324,8 @@ def is_supported(cls) -> bool: return False if get_device_compute_capability()[0] != 10: return False + if not _cudnn_frontend_version_supported(): + return False try: cls.grouped_gemm_dglu_kernel() cls.grouped_gemm_quant_kernel() @@ -321,13 +333,6 @@ def is_supported(cls) -> bool: return False return True - @classmethod - def is_fc1_bias_supported(cls) -> bool: - """Whether cudnn-frontend exposes ``generate_dbias`` on the dGLU SM100 wrapper (FC1 bias grad only).""" - if not cls.is_supported(): - return False - return _dglu_wrapper_has_generate_dbias_arg() - def __init__( self, *, @@ -685,7 +690,6 @@ def fuser_backward( "norm_const_tensor": None, "prob_tensor": torch.ones((out_shape[0], 1, 1), dtype=torch.float32, device=device), "acc_dtype": torch.float32, - "c_dtype": dtype, "d_dtype": dtype, "cd_major": "n", "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 90c4204f06..599e5f96ae 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -7,7 +7,6 @@ from __future__ import annotations from collections.abc import Callable, Iterable import functools -import inspect import os from typing import Any, Optional @@ -24,6 +23,7 @@ from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( + _cudnn_frontend_version_supported, fuse_grouped_mlp_ops, is_quantized_tensor, maybe_dequantize, @@ -76,6 +76,8 @@ def is_supported(cls) -> bool: return False if get_device_compute_capability()[0] != 10: return False + if not _cudnn_frontend_version_supported(): + return False try: cls.grouped_gemm_glu_kernel() cls.grouped_gemm_quant_kernel() @@ -83,42 +85,6 @@ def is_supported(cls) -> bool: return False return True - @classmethod - @functools.lru_cache(maxsize=1) - def is_fc1_bias_supported(cls) -> bool: - """Whether cudnn-frontend exposes ``bias_tensor`` on the grouped GEMM GLU SM100 wrapper (FC1).""" - if not cls.is_supported(): - return False - try: - from cudnn import ( - grouped_gemm_glu_wrapper_sm100, - ) # pylint: disable=import-outside-toplevel - except ImportError: - return False - try: - params = inspect.signature(grouped_gemm_glu_wrapper_sm100).parameters - except (TypeError, ValueError): - return False - return "bias_tensor" in params - - @classmethod - @functools.lru_cache(maxsize=1) - def is_fc2_bias_supported(cls) -> bool: - """Whether cudnn-frontend exposes ``bias_tensor`` on the grouped GEMM Quant SM100 wrapper (FC2).""" - if not cls.is_supported(): - return False - try: - from cudnn import ( - grouped_gemm_quant_wrapper_sm100, - ) # pylint: disable=import-outside-toplevel - except ImportError: - return False - try: - params = inspect.signature(grouped_gemm_quant_wrapper_sm100).parameters - except (TypeError, ValueError): - return False - return "bias_tensor" in params - def __init__( self, *, @@ -433,18 +399,16 @@ def fuser_forward( "sfa_tensor": fc1_kernel_out["sfd_row_tensor"], "padded_offsets": split_points, "alpha_tensor": alpha_tensor.float(), + "bias_tensor": fc2_bias_packed, "norm_const_tensor": None, "prob_tensor": fc2_scales_tensor, "acc_dtype": torch.float32, - "c_dtype": dtype, "d_dtype": dtype, "cd_major": "n", "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, "current_stream": current_stream, "use_dynamic_sched": True, } - if self.is_fc2_bias_supported(): - fc2_quant_kwargs["bias_tensor"] = fc2_bias_packed if fc2_op.single_grouped_weight: # Clone and swizzle scales for GEMM (original stays unmodified for save_for_backward) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 7e2fea45f3..5f12c3ed8c 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -635,7 +635,7 @@ def make_grouped_tensor( total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] for i, s in enumerate(shape): - scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_inv_shape = quantizer.get_scale_shape(s, True) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) @@ -872,15 +872,25 @@ def split_into_quantized_tensors( # populate scale_inv_offsets from the tensor offsets if self.scale_inv is not None and self.scale_inv_offsets is None: - if recipe.nvfp4(): - self.scale_inv_offsets = self.tensor_offsets // 16 - if recipe.mxfp8(): - self.scale_inv_offsets = self.tensor_offsets // 32 + if recipe.nvfp4() or recipe.mxfp8() or recipe.float8_block_scaling(): + cum = 0 + scale_inv_offsets = [0] + for i in range(self.num_tensors): + tensor_shape = self.tensor_shapes[i] + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + cum += math.prod(scale_shape) + scale_inv_offsets.append(cum) + self.scale_inv_offsets = scale_inv_offsets if self.columnwise_scale_inv is not None and self.columnwise_scale_inv_offsets is None: - if recipe.nvfp4(): - self.columnwise_scale_inv_offsets = self.tensor_offsets // 16 - if recipe.mxfp8(): - self.columnwise_scale_inv_offsets = self.tensor_offsets // 32 + if recipe.nvfp4() or recipe.mxfp8() or recipe.float8_block_scaling(): + cum = 0 + columnwise_scale_inv_offsets = [0] + for i in range(self.num_tensors): + tensor_shape = self.tensor_shapes[i] + scale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cum += math.prod(scale_shape) + columnwise_scale_inv_offsets.append(cum) + self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets for i in range(self.num_tensors): quantizer = self.quantizer diff --git a/transformer_engine/pytorch/triton/__init__.py b/transformer_engine/pytorch/triton/__init__.py index d86cededd7..6d3141253d 100644 --- a/transformer_engine/pytorch/triton/__init__.py +++ b/transformer_engine/pytorch/triton/__init__.py @@ -3,3 +3,4 @@ # See LICENSE for license information. """PyTorch wrappers for Triton kernels.""" +from transformer_engine.pytorch.triton import mhc diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py new file mode 100644 index 0000000000..987216e327 --- /dev/null +++ b/transformer_engine/pytorch/triton/mhc.py @@ -0,0 +1,999 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""PyTorch wrapper functions for mHC (manifold Hyper-Connection) Triton kernels.""" + +import os +import torch +import triton + +from transformer_engine.common.triton.mhc import ( + _mhc_scale_fwd_fused, + _mhc_scale_bwd_fused, + _mhc_expand_combine_with_bias_fwd, + _mhc_expand_combine_with_bias_bwd, + _mhc_expand_combine_fwd, + _mhc_expand_combine_bwd, + _mhc_aggregate_fwd, + _mhc_aggregate_bwd, + _mhc_projection_fwd_fused, + _mhc_projection_bwd_fused, + _mhc_sinkhorn_fwd_fused, + _mhc_sinkhorn_fwd_fused_recompute, + _mhc_sinkhorn_bwd_fused, + _mhc_sinkhorn_bwd_fused_recompute, +) +from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm + + +def check_deterministic(operator: str): + """ + Checks if the non-deterministic algorithm is allowed for the given operator. If not, raises an assertion error with instructions on how to allow it. + Since atomic add is used in this mHC implementation, it breaks the determinism guarantee due to non-associativity of floating point addition. + """ + allow_nondeterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "1" + assert allow_nondeterministic, ( + f"[{operator}]: This operation uses atomic add which violates determinism. Set" + " NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 to allow this non-deterministic behavior." + ) + + +def mhc_fused_sinkhorn( + H_res: torch.Tensor, n: int = 4, recompute_hist: bool = True, iters: int = 20 +): + """ + Sinkhorn operation to compute the final H_res matrix (see eq. 19, section 4.3.1 of the DeepSeek mHC paper): + + The Sinkhorn operation conducts an iterative normalization process that alternately rescales rows and columns to sum to 1. + This kernel performs this operation in the log space for numerical stability. + + Parameters + ---------- + H_res : torch.Tensor + input H_res matrix of shape (s, b, n, n) that needs to be normalized into a doubly stochastic matrix. + n : int + number of hyper connections, where only n=4 is supported in the current implementation + recompute_hist : bool + whether to recompute the intermediate history in the backward pass to save memory + iters : int + number of Sinkhorn iterations, according to the DeepSeek paper 20 is enough for convergence + + Returns + ------- + out : torch.Tensor + out of shape (s, b, n, n), which is the final H_res after Sinkhorn normalization + """ + assert n == 4, "Only n=4 is supported in this implementation" + out = mHCSinkhornOp.apply(H_res, n, recompute_hist, iters) + return out + + +def mhc_fused_scale( + H: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor, ms: torch.Tensor, n: int +): + """ + Fused scale operation to compute the scaled H matrices (see eq. 16-18, section 4.3.1 of the DeepSeek mHC paper): + + H_pre = H[:, 0:n] * alpha[0] / sqrt(ms) + beta[0:n] + H_post = H[:, n:2n] * alpha[1] / sqrt(ms) + beta[n:2n] + H_res = H[:, 2n:2n+n*n] * alpha[2] / sqrt(ms) + beta[2n:2n+n*n] + + H_pre = sigmoid(H_pre) + H_post = 2*sigmoid(H_post) + + Parameters + ---------- + H : torch.Tensor + input H matrix of shape (M, 32), where M=s*b, and only the first N elements in the last dimension are valid + alpha : torch.Tensor + scaling factor for H, of shape (3,), where + alpha[0] is applied to H[:, 0:n] for H_pre + alpha[1] is applied to H[:, n:2n] for H_post + alpha[2] is applied to H[:, 2n:2n+n*n] for H_res + beta : torch.Tensor + bias term for H, of shape (1, 2*n+n*n), where + beta[0, 0:n] is applied to H[:, 0:n] for H_pre + beta[0, n:2n] is applied to H[:, n:2n] for H_post + beta[0, 2n:2n+n*n] is applied to H[:, 2n:2n+n*n] for H_res + ms : torch.Tensor + mean square for each row of H from the projection kernel, of shape (M,), used for RMSNorm scaling + n : int + number of hyper connections, where only n=4 is supported in the current implementation + + Returns + ------- + h_pre : torch.Tensor + Scaled H_pre of shape (M, n), which aggregates (s, b, C, n) input of a Hyper Connection block into (s, b, n) as the input of attention / MLP + h_post : torch.Tensor + Scaled H_post of shape (M, n), which expands the output of attention / MLP of shape (s, b, n) back to (s, b, C, n) for the residual connection + h_res : torch.Tensor + Scaled H_res of shape (M, n*n), which mixes the n streams of the (s, b, C, n) input of a Hyper Connection block + + """ + assert n == 4, "Only n=4 is supported in this implementation" + check_deterministic("mhc_fused_scale") + out = mHCScaleFusedOp.apply(H, alpha, beta, ms, n) + h_pre = out[..., :n] + h_post = out[..., n : 2 * n] + h_res = out[..., 2 * n : n * n + 2 * n] + return h_pre, h_post, h_res + + +def mhc_fused_aggregate(x: torch.Tensor, H_pre: torch.Tensor, n: int, use_tf32: bool = True): + """ + Aggregate operation to merge n activation streams into one (see section 4.3.1 of the DeepSeek mHC paper): + out = x @ H_pre: (s, b, C, n) @ (s, b, n, 1) -> (s, b, C, 1) -> (s, b, C) after squeezing the last dimension + + Parameters + ---------- + x : torch.Tensor + input activation tensor of shape (s, b, C, n), + where s is the sequence length, b is the batch size, C is the hidden dimension per hyper connection, and n is the number of hyper connections. Note that C is equal to the original hidden dimension divided by n. + H_pre: torch.Tensor + input H_pre matrix of shape (s, b, n) + n: int + number of hyper connections, where only n=4 is supported in the current implementation + use_tf32: bool + whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. + This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail + + Returns + ------- + out: torch.Tensor + output activation tensor of shape (s, b, C), which is the aggregated output after merging n hyper connections + """ + assert n == 4, "Only n=4 is supported in this implementation" + check_deterministic("mhc_fused_aggregate") + out = mHCAggregateOp.apply(x, H_pre, n, use_tf32) + return out + + +def mhc_fused_expand_combine( + f: torch.Tensor, + bias: torch.Tensor, + H_post: torch.Tensor, + x: torch.Tensor, + H_res: torch.Tensor, + n: int, + use_tf32: bool = True, +): + """ + Expand and combine operation for merging n hyper connections (see section 4.3.1 of the DeepSeek mHC paper): + + out = (f [+ bias]) @ H_post + x @ H_res: (s, b, C, 1) @ (s, b, 1, n) + (s, b, C, n) @ (s, b, n, n) -> (s, b, C, n) + + Parameters + ---------- + f : torch.Tensor + input activation tensor of shape (s, b, C), which is the output from the attention / FFN sub-layer in a transformer block + bias : torch.Tensor or None + optional bias tensor of shape (C,) from the last linear layer, where f + bias is fused in this kernel for better performance + H_post : torch.Tensor + input H_post matrix of shape (s, b, n) + x : torch.Tensor + input activation tensor of shape (s, b, C, n), which is the hyper connection input before the aggregation operation + H_res : torch.Tensor + input H_res matrix of shape (s, b, n, n) + n : int + number of hyper connections + use_tf32 : bool + whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. + This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail + + Returns + ------- + out : torch.Tensor + out of shape (s, b, C, n), which is the expanded and combined output after merging n hyper connections + """ + assert n == 4, "Only n=4 is supported in this implementation" + check_deterministic("mhc_fused_expand_combine") + out = mHCExpandCombineOp.apply( + f, + bias, + H_post, + x, + H_res, + n, + use_tf32, + ) + return out + + +def mhc_fused_projection(x: torch.Tensor, phi: torch.Tensor, use_tf32: bool = True): + """ + Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, section 4.3.1 of the DeepSeek mHC paper): + + H = x @ phi^T: (M, K) @ (K, N) -> (M, N), which is padded to (M, 32) for better memory access pattern in the next kernels. + ms = mean(x^2, dim=-1): (M,) + + Note: the current implementation only supports n=4 + + Parameters + ---------- + x : torch.Tensor + input tensor of shape (M, K), where M=s*b is the batch size and K=nC is the hidden dimension after expansion. + phi : torch.Tensor + projection matrix of shape (N, K), where N=2n+n*n (=24 for n=4) + use_tf32 : bool + whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. + This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail. + + Returns + ------- + H : torch.Tensor + Projected matrix of shape (M, 32), where only the first N elements in the last dimension are valid. + ms : torch.Tensor + Mean square of shape (M,), which is used for RMSNorm in the next kernel. + """ + assert ( + phi.shape[0] == 24 + ), "Currently only n=4 is supported, which means phi should have 24 in its first dimension" + check_deterministic("mhc_fused_projection") + H, ms = mHCProjectionOp.apply(x, phi, use_tf32) + return H, ms + + +class mHCProjectionOp(torch.autograd.Function): + """ + PyTorch operator for the fused projection operation in mHC, whose wrapper API is mhc_fused_projection. + """ + + @staticmethod + def forward(ctx, x, phi, use_tf32=True): + """ + The forward pass of the fused projection operation. Computes H = x @ phi^T and the mean + square ms = mean(x^2, dim=-1) for RMSNorm in a single fused kernel. + + Parameters: + ctx : The context object. + x (tensor): The input tensor of shape (M, K), where M=s*b is the flattened batch dimension and K=nC is the hidden dimension after expansion. + phi (tensor): The projection matrix of shape (N, K), where N=2n+n*n (=24 for n=4). + use_tf32 (bool): Whether to use TF32 precision for matmul operations. If False, uses IEEE for better precision. + + Returns: + tuple: A tuple of (H, ms) where H is the projected matrix of shape (M, 32) padded for memory alignment (only the first N elements are valid), and ms is the mean square of shape (M,) in FP32. + """ + x = x.contiguous() + phi = phi.contiguous() + + ctx.use_tf32 = use_tf32 + ctx.dtype = x.dtype + + M, K = x.shape + device = x.device + + N = phi.shape[0] + + # Pad H to (s, b, 32) for better memory access pattern in the kernel, but only the first N elements in the last dimension are valid + H = torch.zeros((M, 32), device=device, dtype=torch.float32) + ms = torch.zeros( + (M,), device=device, dtype=torch.float32 + ) # Mean square for x, used to compute RMSNorm in the next kernel + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(K, META["BLOCK_SIZE_K"]), + ) + + _mhc_projection_fwd_fused[grid]( + x_ptr=x, # (M, K) + phi_ptr=phi, # (N, K) + h_ptr=H, # (M, 32) + ms_ptr=ms, # (M,) + M=M, + N=N, + K=K, + stride_xm=K, + stride_xk=1, + stride_phin=K, + stride_phik=1, + stride_hm=32, + stride_hn=1, + stride_ms=1, + BLOCK_SIZE_N=32, + precision="tf32" if use_tf32 else "ieee", + ) + + ctx.save_for_backward(x, phi, ms) + ctx.phi_dtype = phi.dtype + + return H.to(ctx.dtype), ms # Keep ms in fp32 + + @staticmethod + def backward(ctx, grad_H, grad_ms): + """ + The backward pass of the fused projection operation. Computes gradients for x and phi. + + grad_phi = grad_H^T @ x, truncated to the first N rows. + grad_x = grad_H @ phi + 2 * x * grad_ms / K, where the second term is the gradient contribution from + the mean square computation fused in the forward pass. + + Parameters: + ctx : The context object with saved tensors. + grad_H (tensor): The gradient of the loss with respect to H, of shape (M, 32). + grad_ms (tensor): The gradient of the loss with respect to the mean square, of shape (M,). + + Returns: + tuple: A tuple with the gradients (grad_x, grad_phi, None). + """ + x, phi, ms = ctx.saved_tensors + M, K = x.shape + device = x.device + + N = phi.shape[0] + + grad_H = grad_H.contiguous().view(M, -1) + grad_ms = grad_ms.contiguous().view( + M, + ) + ms = ms.contiguous().view( + M, + ) + + grad_x = torch.empty((M, K), device=device, dtype=x.dtype) + + grad_x = torch.empty((M, K), device=device, dtype=x.dtype) + grad_phi = general_gemm(x, grad_H, out_dtype=torch.float32, layout="NT")[0][:N, :].to( + phi.dtype + ) # (2n + n^2, M) @ (M, nC) = (2n + n^2, nC); grad_H's last dim is padded to 32 + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(K, META["BLOCK_SIZE_K"]), + ) + + _mhc_projection_bwd_fused[grid]( + x_ptr=x, + grad_x_ptr=grad_x, # (M, K) + phi_ptr=phi, # (N, K) + grad_h_ptr=grad_H, # (M, 32) + grad_ms_ptr=grad_ms, # (M,) + M=M, + N=N, + K=K, + stride_xm=K, + stride_xk=1, + stride_grad_xm=K, + stride_grad_xk=1, + stride_phin=K, + stride_phik=1, + stride_grad_phin=K, + stride_grad_phik=1, + stride_grad_hm=32, + stride_grad_hn=1, + stride_grad_ms=1, + BLOCK_SIZE_N=32, + precision="tf32" if ctx.use_tf32 else "ieee", + ) + + return grad_x.to(ctx.dtype), grad_phi.to(ctx.dtype), None + + +class mHCScaleFusedOp(torch.autograd.Function): + """ + PyTorch operator for the fused scale operation in mHC, whose wrapper API is mhc_fused_scale. + """ + + @staticmethod + def forward(ctx, H, alpha, beta, ms, n): + """ + The forward pass of the fused scale operation. Applies RMSNorm scaling, bias, and activation + functions to produce H_pre, H_post, and H_res: + + H_pre = sigmoid(H[:, 0:n] * alpha[0] / sqrt(ms) + beta[0:n]) + H_post = 2 * sigmoid(H[:, n:2n] * alpha[1] / sqrt(ms) + beta[n:2n]) + H_res = H[:, 2n:2n+n*n] * alpha[2] / sqrt(ms) + beta[2n:2n+n*n] + + Parameters: + ctx : The context object. + H (tensor): The input H matrix of shape (M, 32), where only the first N=2n+n*n elements are valid. + alpha (tensor): The scaling factors of shape (3,), one for each of H_pre, H_post, H_res. + beta (tensor): The bias terms of shape (1, 2n+n*n). + ms (tensor): The mean square from the projection kernel, of shape (M,), used for RMSNorm scaling. + n (int): The number of hyper connections (only n=4 is supported). + + Returns: + tensor: The scaled output of shape (M, 32), where only the first N elements are valid. + """ + + ctx.dtype = H.dtype + H = H.to(torch.float32) + alpha = alpha.to(torch.float32) + beta = beta.to(torch.float32) + ms = ms.to(torch.float32) + + M, _ = H.shape + + H = H.contiguous() + beta = beta.contiguous() + ms = ms.contiguous() + + out = torch.empty( + (M, 32), device=H.device, dtype=H.dtype + ) # Pad the output to 32 in the last dimension + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),) + + _mhc_scale_fwd_fused[grid]( + h_ptr=H, # (M, N), which is padded to (M, 32) + b_ptr=beta, # (N,) + a_ptr=alpha, # (N,) + ms_ptr=ms, # (M,) + out_ptr=out, # (M, N), which is padded to (M, 32) + M=M, + n=n, + stride_hm=32, + stride_hn=1, + stride_a=1, + stride_b=1, + stride_ms=1, + stride_out_m=32, + stride_out_n=1, # strides for out, which is padded to 32 in the last dimension + BLOCK_SIZE_N=32, + eps=torch.finfo(ms.dtype).eps, + ) + + ctx.save_for_backward(H, alpha, ms, out) + ctx.n = n + + return out.to(ctx.dtype) # Cast back to the original dtype of H + + @staticmethod + def backward(ctx, grad_out): + """ + The backward pass of the fused scale operation. Computes gradients for H, alpha, beta, and ms + by backpropagating through the sigmoid activations, RMSNorm scaling, and bias additions. + + Parameters: + ctx : The context object with saved tensors. + grad_out (tensor): The gradient of the loss with respect to the output, of shape (M, 32). + + Returns: + tuple: A tuple with the gradients (grad_H, grad_alpha, grad_beta, grad_ms, None). + """ + H, alpha, ms, out = ctx.saved_tensors + n = ctx.n + + grad_out = grad_out.contiguous() + grad_out = grad_out.to(torch.float32) + + M, _ = grad_out.shape + N = 2 * n + n * n + + grad_h = torch.zeros( + (M, 32), device=grad_out.device, dtype=grad_out.dtype + ) # Pad the grad_h to 32 in the last dimension + grad_alpha = torch.zeros((3,), device=grad_out.device, dtype=grad_out.dtype) + grad_beta_padded = torch.zeros((1, 32), device=grad_out.device, dtype=grad_out.dtype) + grad_beta = grad_beta_padded[ + :, :N + ] # Use only the first N elements for grad_beta, the rest are just padding + grad_ms = torch.zeros((M,), device=grad_out.device, dtype=grad_out.dtype) + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),) + + _mhc_scale_bwd_fused[grid]( + grad_out_ptr=grad_out, + out_ptr=out, + grad_h_ptr=grad_h, + h_ptr=H, + grad_a_ptr=grad_alpha, + a_ptr=alpha, + grad_b_ptr=grad_beta, + grad_ms_ptr=grad_ms, + ms_ptr=ms, + M=M, + n=n, + stride_grad_out_m=32, + stride_grad_out_n=1, + stride_out_m=32, + stride_out_n=1, + stride_grad_hm=32, + stride_grad_hn=1, + stride_hm=32, + stride_hn=1, + stride_grad_a=1, + stride_a=1, + stride_grad_b=1, + stride_grad_ms=1, + stride_ms=1, + BLOCK_SIZE_N=32, + eps=torch.finfo(ms.dtype).eps, + ) + + return ( + grad_h.to(ctx.dtype), + grad_alpha.to(ctx.dtype), + grad_beta.to(ctx.dtype), + grad_ms.to(ctx.dtype), + None, + ) + + +class mHCSinkhornOp(torch.autograd.Function): + """ + PyTorch operator for the Sinkhorn operation in mHC, whose wrapper API is mhc_fused_sinkhorn. + """ + + @staticmethod + def forward(ctx, H_res, n=4, recompute_hist=True, iters=20): + """ + The forward pass of the Sinkhorn operation. Performs iterative row-column normalization + in log space to convert H_res into a doubly stochastic matrix. Each iteration alternately + rescales rows and columns to sum to 1: + + f = log_mu - logsumexp(H_res + g, dim=cols) + g = log_nu - logsumexp(H_res + f, dim=rows) + output = exp(f + H_res + g) + + Parameters: + ctx : The context object. + H_res (tensor): The input H_res matrix of shape (s, b, n, n). + n (int): The number of hyper connections (only n=4 is supported). + recompute_hist (bool): Whether to recompute the intermediate f/g history in the backward pass to save memory. If False, stores history buffers of shape (iters+1, s, b, n). + iters (int): The number of Sinkhorn iterations (20 is enough for convergence per the DeepSeek paper). + + Returns: + tensor: The doubly stochastic matrix of shape (s, b, n, n). + """ + + s, b, _, _ = H_res.shape + + ctx.dtype = H_res.dtype + H_res = H_res.to(torch.float32) + + H_res = H_res.contiguous().view(s * b, n * n) + + hist_f, hist_g = None, None + if not recompute_hist: + # History buffers: (iters+1, s, b, n) + hist_f = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype) + hist_g = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype) + H_res_out = torch.empty_like(H_res) # (s*b, n*n) + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (triton.cdiv(s * b * n * n, META["BLOCK_SIZE"]),) + + if recompute_hist: + _mhc_sinkhorn_fwd_fused_recompute[grid]( + x_ptr=H_res, + output_ptr=H_res_out, + stride_xm=n * n, + stride_xn=1, + stride_out_m=n * n, + stride_out_n=1, + M=s * b, + n=n, + iters=iters, + ) + else: + _mhc_sinkhorn_fwd_fused[grid]( + x_ptr=H_res, + output_ptr=H_res_out, + hist_f_ptr=hist_f, + hist_g_ptr=hist_g, + stride_xm=n * n, + stride_xn=1, + stride_out_m=n * n, + stride_out_n=1, + M=s * b, + n=n, + iters=iters, + ) + + if recompute_hist: + ctx.save_for_backward(H_res, H_res_out) + else: + ctx.save_for_backward(H_res, H_res_out, hist_f, hist_g) + ctx.recompute_hist = recompute_hist + ctx.iters = iters + ctx.n = n + + H_res_out = H_res_out.view(s, b, n, n) + return H_res_out.to(ctx.dtype) # Cast back to the original dtype of H + + @staticmethod + def backward(ctx, grad_out): + """ + The backward pass of the Sinkhorn operation. Backpropagates through the iterative + normalization by reversing through the f/g update steps. If recompute_hist is True, + the forward pass history is recomputed to save memory. + + Parameters: + ctx : The context object with saved tensors. + grad_out (tensor): The gradient of the loss with respect to the output, of shape (s, b, n, n). + + Returns: + tuple: A tuple with the gradients (grad_H_res, None, None, None). + """ + + s, b, n, _ = grad_out.shape + M = s * b + + hist_f, hist_g = None, None + recompute_hist = ctx.recompute_hist + iters = ctx.iters + if recompute_hist: + H_res, H_res_out = ctx.saved_tensors + hist_f = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype) + hist_g = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype) + else: + H_res, H_res_out, hist_f, hist_g = ctx.saved_tensors + + n = ctx.n + + grad_res_out = grad_out.clone().contiguous().view(M, n * n) + + grad_res = torch.empty_like(H_res) + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (triton.cdiv(M * n * n, META["BLOCK_SIZE"]),) + + if recompute_hist: + _mhc_sinkhorn_bwd_fused_recompute[grid]( + grad_out_ptr=grad_res_out, + output_ptr=H_res_out, + grad_x_ptr=grad_res, + x_ptr=H_res, + hist_f_ptr=hist_f, + hist_g_ptr=hist_g, + stride_grad_out_m=n * n, + stride_grad_out_n=1, + stride_out_m=n * n, + stride_out_n=1, + stride_grad_xm=n * n, + stride_grad_xn=1, + stride_xm=n * n, + stride_xn=1, + M=M, + n=n, + iters=iters, + ) + else: + _mhc_sinkhorn_bwd_fused[grid]( + grad_out_ptr=grad_res_out, + output_ptr=H_res_out, + grad_x_ptr=grad_res, + x_ptr=H_res, + hist_f_ptr=hist_f, + hist_g_ptr=hist_g, + stride_grad_out_m=n * n, + stride_grad_out_n=1, + stride_out_m=n * n, + stride_out_n=1, + stride_grad_xm=n * n, + stride_grad_xn=1, + stride_xm=n * n, + stride_xn=1, + M=M, + n=n, + iters=iters, + ) + + grad_res = grad_res.view(s, b, n, n) + + return grad_res.to(ctx.dtype), None, None, None + + +class mHCAggregateOp(torch.autograd.Function): + """ + PyTorch operator for the aggregate operation in mHC, whose wrapper API is mhc_fused_aggregate. + """ + + @staticmethod + def forward(ctx, x, H_pre, n, use_tf32=True): + """ + The forward pass of the aggregate operation. Merges n activation streams into one by + computing a weighted sum using H_pre: + + out = x @ H_pre: (s, b, C, n) @ (s, b, n, 1) -> (s, b, C) + + Parameters: + ctx : The context object. + x (tensor): The input activation tensor of shape (s, b, C, n). + H_pre (tensor): The pre-connection matrix of shape (s, b, n), used as weights for aggregation. + n (int): The number of hyper connections (only n=4 is supported). + use_tf32 (bool): Whether to use TF32 precision for matmul operations. + + Returns: + tensor: The aggregated output of shape (s, b, C). + """ + + x = x.contiguous() + H_pre = H_pre.contiguous() + + s, b, C, n = x.shape + nC = n * C + M = s * b + + out = torch.empty((s, b, C), device=x.device, dtype=x.dtype) + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(C, META["BLOCK_SIZE_C"]), + triton.cdiv(M, META["BLOCK_SIZE_M"]), + ) + + _mhc_aggregate_fwd[grid]( + x_ptr=x, + H_pre_ptr=H_pre, + output_ptr=out, + M=M, + C=C, + n=n, + stride_xm=nC, + stride_xCn=1, + stride_output_m=C, + stride_output_c=1, + ) + + ctx.save_for_backward(x, H_pre) + ctx.n = n + ctx.use_tf32 = use_tf32 + + return out + + @staticmethod + def backward(ctx, grad_output): + """ + The backward pass of the aggregate operation. Computes gradients for x and H_pre: + + grad_x[:, :, :, i] = grad_output * H_pre[:, :, i] for each stream i + grad_H_pre[:, :, i] = sum_C(grad_output * x[:, :, :, i]) for each stream i + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The gradient of the loss with respect to the output, of shape (s, b, C). + + Returns: + tuple: A tuple with the gradients (grad_x, grad_H_pre, None, None). + """ + grad_output = grad_output.contiguous() + + x, H_pre = ctx.saved_tensors + n = ctx.n + + s, b, C, n = x.shape + nC = n * C + assert n == 4, "Only n=4 is supported in this implementation" + M = s * b + + grad_x = torch.empty_like(x) + grad_H_pre = torch.zeros( + (s, b, n), dtype=torch.float32, device=H_pre.device + ) # We need to use atomic_add for this so we need higher precision + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(C, META["BLOCK_SIZE_C"]), + triton.cdiv(M, META["BLOCK_SIZE_M"]), + ) + + _mhc_aggregate_bwd[grid]( + grad_output_ptr=grad_output, + H_pre_ptr=H_pre, + grad_H_pre_ptr=grad_H_pre, + x_ptr=x, + grad_x_ptr=grad_x, + M=M, + C=C, + n=n, + stride_grad_output_m=C, + stride_grad_output_c=1, + stride_xm=nC, + stride_xCn=1, + stride_grad_xm=nC, + stride_grad_xCn=1, + precision="tf32" if ctx.use_tf32 else "ieee", + ) + + grad_H_pre = grad_H_pre.to(H_pre.dtype) # Cast back to the original dtype of H_pre + + return grad_x, grad_H_pre, None, None + + +class mHCExpandCombineOp(torch.autograd.Function): + """ + PyTorch operator for the expand and combine operation in mHC, whose wrapper API is mhc_fused_expand_combine. + """ + + @staticmethod + def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): + """ + The forward pass of the expand and combine operation. Expands the sub-layer output f back + to n streams using H_post, and combines with the residual connections using H_res: + + out = (f [+ bias]) @ H_post + x @ H_res: (s, b, C, 1) @ (s, b, 1, n) + (s, b, C, n) @ (s, b, n, n) -> (s, b, C, n) + + Parameters: + ctx : The context object. + f (tensor): The sub-layer output tensor of shape (s, b, C). + bias (tensor or None): Optional bias tensor of shape (C,) from the last linear layer, fused in this kernel. + H_post (tensor): The post-connection matrix of shape (s, b, n). + x (tensor): The hyper connection input tensor of shape (s, b, C, n) before aggregation. + H_res (tensor): The residual connection matrix of shape (s, b, n, n). + n (int): The number of hyper connections (only n=4 is supported). + use_tf32 (bool): Whether to use TF32 precision for matmul operations. + + Returns: + tensor: The expanded and combined output of shape (s, b, C, n). + """ + + x = x.contiguous() + f = f.contiguous() + if bias is not None: + bias = bias.contiguous() + H_post = H_post.contiguous() + H_res = H_res.contiguous() + + s, b, C, n = x.shape + Cn = C * n + M = s * b + + out = torch.empty((s, b, C, n), device=x.device, dtype=x.dtype) + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(C, META["BLOCK_SIZE_C"]), + triton.cdiv(M, META["BLOCK_SIZE_M"]), + ) + + if bias is None: + _mhc_expand_combine_fwd[grid]( + f_ptr=f, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + output_ptr=out, + M=M, + C=C, + n=n, + stride_fm=C, + stride_fc=1, + stride_xm=Cn, + stride_xCn=1, + stride_output_m=Cn, + stride_output_Cn=1, + ) + else: + _mhc_expand_combine_with_bias_fwd[grid]( + f_ptr=f, + bias_ptr=bias, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + output_ptr=out, + M=M, + C=C, + n=n, + stride_fm=C, + stride_fc=1, + stride_bias=1, + stride_xm=Cn, + stride_xCn=1, + stride_output_m=Cn, + stride_output_Cn=1, + ) + + ctx.n = n + ctx.have_bias = bias is not None + if bias is not None: + ctx.save_for_backward(f, bias, H_post, x, H_res) + else: + ctx.save_for_backward(f, H_post, x, H_res) + ctx.use_tf32 = use_tf32 + + return out + + @staticmethod + def backward(ctx, grad_output): + """ + The backward pass of the expand and combine operation. Computes gradients for f, bias, + H_post, x, and H_res by backpropagating through the outer product and matrix multiply: + + grad_f = sum_n(grad_output * H_post) [+ reduce grad_bias over (s, b)] + grad_H_post[:, :, i] = sum_C(grad_output[:, :, :, i] * (f [+ bias])) + grad_x = grad_output @ H_res^T + grad_H_res[:, :, i, j] = sum_C(grad_output[:, :, :, j] * x[:, :, :, i]) + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The gradient of the loss with respect to the output, of shape (s, b, C, n). + + Returns: + tuple: A tuple with the gradients (grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None). + """ + grad_output = grad_output.contiguous() + s, b, C, n = grad_output.shape + + if ctx.have_bias: + f, bias, H_post, x, H_res = ctx.saved_tensors + else: + bias = None + f, H_post, x, H_res = ctx.saved_tensors + M = s * b + + grad_f = torch.empty_like(f) + grad_bias = torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None + grad_H_post = torch.zeros_like( + H_post, dtype=torch.float32 + ) # We need to use atomic_add for this so we need higher precision + grad_x = torch.empty_like(x) + grad_H_res = torch.zeros_like( + H_res, dtype=torch.float32 + ) # We need to use atomic_add for this so we need higher precision + + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(C, META["BLOCK_SIZE_C"]), + triton.cdiv(M, META["BLOCK_SIZE_M"]), + ) + + if bias is None: + _mhc_expand_combine_bwd[grid]( + grad_output_ptr=grad_output, + f_ptr=f, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + grad_H_post_ptr=grad_H_post, + grad_f_ptr=grad_f, + grad_H_res_ptr=grad_H_res, + grad_x_ptr=grad_x, + M=M, + C=C, + n=n, + stride_grad_output_m=n * C, + stride_grad_output_Cn=1, + stride_fm=C, + stride_fc=1, + stride_xm=n * C, + stride_xCn=1, + stride_grad_fm=C, + stride_grad_fc=1, + stride_grad_xm=n * C, + stride_grad_xCn=1, + precision="tf32" if ctx.use_tf32 else "ieee", + ) + else: + _mhc_expand_combine_with_bias_bwd[grid]( + grad_output_ptr=grad_output, + f_ptr=f, + bias_ptr=bias, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + grad_H_post_ptr=grad_H_post, + grad_f_ptr=grad_f, + grad_bias_ptr=grad_bias, + grad_H_res_ptr=grad_H_res, + grad_x_ptr=grad_x, + M=M, + C=C, + n=n, + stride_grad_output_m=n * C, + stride_grad_output_Cn=1, + stride_fm=C, + stride_fc=1, + stride_bias=1, + stride_xm=n * C, + stride_xCn=1, + stride_grad_fm=C, + stride_grad_fc=1, + stride_grad_bias=1, + stride_grad_xm=n * C, + stride_grad_xCn=1, + precision="tf32" if ctx.use_tf32 else "ieee", + ) + + grad_H_post = grad_H_post.to(H_post.dtype) # Cast back to the original dtype of H_post + grad_H_res = grad_H_res.to(H_res.dtype) # Cast back to the original dtype of H_res + if bias is not None: + grad_bias = grad_bias.to(bias.dtype) + + return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index a76f205acc..250daec67f 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -7,6 +7,7 @@ import functools import math import os +import warnings from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from contextlib import nullcontext import numpy as np @@ -81,6 +82,36 @@ def get_device_compute_capability() -> Tuple[int, int]: return _get_device_compute_capability(torch.cuda.current_device()) +def resolve_grouped_linear_single_param_flags( + single_grouped_weight: bool, + single_grouped_bias: bool, +) -> Tuple[bool, bool]: + """Gate ``single_grouped_weight`` / ``single_grouped_bias`` on ``NVTE_GROUPED_LINEAR_SINGLE_PARAM``.""" + if not (single_grouped_weight or single_grouped_bias): + return single_grouped_weight, single_grouped_bias + + env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0 + if not env_enabled: + warnings.warn( + f"GroupedLinear was constructed with single_grouped_weight={single_grouped_weight} " + f"and single_grouped_bias={single_grouped_bias}, but the " + "NVTE_GROUPED_LINEAR_SINGLE_PARAM environment variable is not set. " + "Disabling single grouped weight/bias and falling back to per-expert parameters.", + UserWarning, + stacklevel=3, + ) + return False, False + + warnings.warn( + "GroupedLinear is using single_grouped_weight/single_grouped_bias. " + "This feature is experimental, may change in future " + "releases, and is known to be non-deterministic in certain cases.", + UserWarning, + stacklevel=3, + ) + return single_grouped_weight, single_grouped_bias + + def attention_mask_func( attention_scores: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: