-
Notifications
You must be signed in to change notification settings - Fork 735
[Common] Enable NVFP4 2D block scaling in columnwise only #3027
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
56780d1
61a2387
f7953dd
c069cea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -779,7 +779,7 @@ __global__ void __launch_bounds__(THREADS_NUM) | |
| } | ||
|
|
||
| template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &), | ||
| typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE> | ||
| typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_ROWWISE, bool RETURN_TRANSPOSE> | ||
| __global__ void __launch_bounds__(THREADS_NUM) | ||
| quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, | ||
| const __grid_constant__ CUtensorMap tensor_map_output, | ||
|
|
@@ -1128,7 +1128,7 @@ __global__ void __launch_bounds__(THREADS_NUM) | |
| } | ||
|
|
||
| // ROWWISE scaling | ||
| { | ||
| if constexpr (RETURN_ROWWISE) { | ||
| const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; | ||
| #pragma unroll | ||
| for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { | ||
|
|
@@ -1271,9 +1271,11 @@ __global__ void __launch_bounds__(THREADS_NUM) | |
| const size_t global_offset_Y_t = block_offset_Y_t; | ||
| const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; | ||
|
|
||
| ptx::cp_async_bulk_tensor_2d_shared_to_global( | ||
| reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y, | ||
| reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out])); | ||
| if constexpr (RETURN_ROWWISE) { | ||
| ptx::cp_async_bulk_tensor_2d_shared_to_global( | ||
| reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, | ||
| global_offset_Y, reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out])); | ||
| } | ||
|
|
||
| if constexpr (RETURN_TRANSPOSE) { | ||
| ptx::cp_async_bulk_tensor_2d_shared_to_global( | ||
|
|
@@ -1327,6 +1329,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, | |
| // return the transposed data. | ||
| // TODO(Frank): Is there a better way to do this? | ||
| bool return_transpose = output->has_columnwise_data(); | ||
| // Columnwise-only (no rowwise output) is supported on the optimized 2D path; the rowwise pass | ||
| // and its store are gated out via the RETURN_ROWWISE template bool. | ||
| const bool return_rowwise = output->has_data(); | ||
|
|
||
| if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { | ||
| quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); | ||
|
|
@@ -1343,9 +1348,14 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, | |
| CheckOutputTensor(*output, "output", false); | ||
|
|
||
| NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); | ||
| NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); | ||
| NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); | ||
| NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); | ||
| NVTE_CHECK(return_rowwise || return_transpose, | ||
| "At least one of rowwise/columnwise NVFP4 output must be allocated."); | ||
| NVTE_CHECK(return_rowwise || use_2d_quantization, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a bit confusing to read, especially if the kernel is extended in the future to support additional quantization schemes. It would be better to restrict the supported combinations explicitly, e.g. |
||
| "Columnwise-only NVFP4 requires 2D quantization on the optimized path."); | ||
| if (return_rowwise) { | ||
| NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); | ||
| NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); | ||
| } | ||
| NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, | ||
| "Row-scaled NVFP4 quantization requires rowwise amax."); | ||
| NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), | ||
|
|
@@ -1372,7 +1382,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, | |
| const dim3 grid(blocks_X, blocks_Y); | ||
| const size_t block_size = THREADS_NUM; | ||
|
|
||
| const size_t scale_stride = output->scale_inv.shape[1]; | ||
| const size_t scale_stride = return_rowwise ? output->scale_inv.shape[1] : 0; | ||
| const size_t scale_stride_transpose = | ||
| return_transpose ? output->columnwise_scale_inv.shape[1] : 0; | ||
|
|
||
|
|
@@ -1405,8 +1415,10 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, | |
| create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, | ||
| sizeof(IType) * 8); | ||
|
|
||
| create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, | ||
| 4); | ||
| if (return_rowwise) { | ||
| create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, | ||
| 0, 4); | ||
| } | ||
| if (return_transpose) { | ||
| create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, | ||
| BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); | ||
|
|
@@ -1433,21 +1445,26 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, | |
| use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, | ||
|
|
||
| TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { | ||
| TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { | ||
| auto kernel = quantize_transpose_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, | ||
| USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE, | ||
| ROW_SCALED_NVFP4>; | ||
|
|
||
| if constexpr (use_2d_quantization) { | ||
| kernel = quantize_transpose_nvfp4_2D_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, | ||
| USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>; | ||
| } | ||
| TRANSFORMER_ENGINE_SWITCH_CONDITION(return_rowwise, RETURN_ROWWISE, { | ||
| TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { | ||
| // The 1D kernel always produces rowwise output (no RETURN_ROWWISE); the dispatch only | ||
| // routes columnwise-only requests here when use_2d_quantization is true. | ||
| auto kernel = quantize_transpose_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, | ||
| USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE, | ||
| ROW_SCALED_NVFP4>; | ||
|
|
||
| if constexpr (use_2d_quantization) { | ||
| kernel = quantize_transpose_nvfp4_2D_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, | ||
| USE_STOCHASTIC_ROUNDING, RETURN_ROWWISE, | ||
| RETURN_TRANSPOSE>; | ||
| } | ||
|
|
||
| cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); | ||
| kernel<<<grid, block_size, dshmem_size, stream>>>( | ||
| tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, | ||
| scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, | ||
| scale_stride, scale_stride_transpose, rng_state); | ||
| cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); | ||
| kernel<<<grid, block_size, dshmem_size, stream>>>( | ||
| tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, | ||
| scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, | ||
| scale_stride, scale_stride_transpose, rng_state); | ||
| }); | ||
| }); | ||
| });); | ||
| #else | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -353,14 +353,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo | |
| extern __shared__ char smem_base[]; | ||
| SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]); | ||
|
|
||
| // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode. | ||
| // Instead of static_assert, return early if these invalid modes are detected. | ||
| // 2D block scaling is not supported for E8 scaling MXFP4. | ||
| // Instead of static_assert, return early if this invalid mode is detected. | ||
| if constexpr (kIs2DBlockScaling && kIsE8Scaling) { | ||
| return; | ||
| } | ||
| if constexpr (kIs2DBlockScaling && !kReturnIdentity) { | ||
| return; | ||
| } | ||
| // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4 | ||
| // use constexpr to define the size, when not using 2D, use minimal size 1x1 | ||
| constexpr int kFP4BlockScalingSize = 16; | ||
|
|
@@ -576,6 +573,67 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo | |
| } | ||
| } | ||
|
|
||
| // Step 2.5: 2D-amax-only pass for columnwise-only mode. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The new outer-level block is named "Step 2.5" at line 576, but that same label is already used at line 522 for the "Write scale_inv" substep inside Step 2's loop ( |
||
| // When only the transposed output is requested but 2D block scaling is enabled, the columnwise | ||
| // reads in Step 3 (the transpose, cast and store block below) still need amax_smem populated. Re-run the load + local-amax | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The comment refers to line |
||
| // + 2D warp/smem reduction from Step 2 (steps 2.1-2.3), skipping the rowwise scale/quantize/store | ||
| // writes that Step 2 normally does. Same amax_smem values as the rowwise-enabled path, so the | ||
| // dgrad/wgrad columnwise output of (rowwise=False, columnwise=True, 2D) is bitwise identical to | ||
| // the columnwise half of (rowwise=True, columnwise=True, 2D). | ||
| if constexpr (!kReturnIdentity && kIs2DBlockScaling) { | ||
| constexpr int r_stride = | ||
| kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory | ||
| constexpr int num_iterations = kTileDim / r_stride; // 4 iterations for kTileDim=128 | ||
| const int c_s = | ||
| (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory | ||
| int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory | ||
| #pragma unroll | ||
| for (int iter = 0; iter < num_iterations; ++iter) { | ||
| SMemVec smem_vec[kNVecOut / kNVecSMem]; | ||
| // Step 2.1 (amax-only): Load from shared memory to registers | ||
| #pragma unroll | ||
| for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { | ||
| int c = c_s + i; | ||
| int r = r_s; | ||
| smem_vec[i] = smem[r * kSMemCol + c]; | ||
| } | ||
| // Step 2.2 (amax-only): Compute local amax | ||
| CType amax = 0; | ||
| #pragma unroll | ||
| for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { | ||
| #pragma unroll | ||
| for (int j = 0; j < kNVecSMem; ++j) { | ||
| __builtin_assume(amax >= 0); | ||
| amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); | ||
| } | ||
| } | ||
| // Step 2.3 (amax-only): 2D warp + smem amax reduction (mirrors Step 2's 2D path) | ||
| constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32 | ||
| int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 | ||
| int tid_in_warp_x = threadIdx.x % kNumThreadsStore; | ||
| int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; | ||
| CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>( | ||
| amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]); | ||
| int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; | ||
| if (tid_in_warp_y == 0) { | ||
| amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] | ||
| [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| if (data_row_idx % kFP4BlockScalingSize == 0) { | ||
| CType amax_2d = 0.0; | ||
| for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { | ||
| amax_2d = | ||
| fmaxf(amax_2d, amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); | ||
| } | ||
| amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; | ||
| } | ||
| __syncthreads(); | ||
| r_s += r_stride; | ||
| } | ||
| } | ||
|
|
||
| // Step 3: Transpose, cast and store to output_t | ||
| if constexpr (kReturnTranspose) { | ||
| constexpr int c_stride = | ||
|
|
@@ -731,8 +789,6 @@ void quantize_transpose_vector_blockwise_fp4( | |
| NVTE_CHECK(return_identity || return_transpose, | ||
| "At least one of return_identity or return_transpose must be true."); | ||
|
|
||
| NVTE_CHECK(return_identity || !use_2d_quantization, | ||
| "2D block quantization is only supported when return_identity is true."); | ||
| NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose), | ||
| "Row-scaled NVFP4 quantization only supports rowwise quantization."); | ||
| NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is already inside the `if constexpr (RETURN_ROWWISE)` scope (starting at line 1131), so it can be removed safely.