diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 53569d90d9..f4b7f30240 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -534,3 +534,103 @@ def test_nvfp4_quantization_noncontiguous_inputs( torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # Aligned tiles + (128, 128), + (256, 256), + (512, 512), + (2048, 2048), + # Padded tiles (non-multiple of kTileDim=128) + (256, 272), + (304, 304), + (320, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_2d_columnwise_only_matches_both_directions( + x_dtype: torch.dtype, + M: int, + N: int, +): + """Bitwise check: 2D NVFP4 with columnwise-only must produce the same + columnwise data/scales as the columnwise half of (rowwise + columnwise) 2D. + + Covers both kernels depending on the (dtype, shape) routing: + - bf16 with rows % 32 == 0 and cols % 32 == 0 routes to the optimized + ``quantize_transpose_nvfp4_2D_kernel`` (instantiated with RETURN_ROWWISE=false), + validating that gating the rowwise pass/store leaves the shared + ``block_amax_matrix`` and columnwise output bitwise-identical to both-directions. + - non-bf16, or cols % 32 != 0, falls back to the columnwise-only 2D-amax-only + pass in ``quantize_transpose_vector_blockwise_fp4.cu``. + """ + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((M, N), dtype=x_dtype, device=device) + + def _make_quantizer(*, rowwise: bool, columnwise: bool) -> NVFP4Quantizer: + return NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=rowwise, + columnwise=columnwise, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + row_scaled_nvfp4=False, + ) + + # Reference: produce both directions in a single kernel call. + q_both = _make_quantizer(rowwise=True, columnwise=True) + out_both = q_both(x) + + # SUT: produce columnwise only (the path that hits the new amax-only pass). + q_col_only = _make_quantizer(rowwise=False, columnwise=True) + out_col_only = q_col_only(x) + + # Columnwise data/scales/amax must be bitwise identical between the two paths. + # If amax_smem is populated differently in the column-only path, scales diverge, + # and the FP4 cast (which divides by encode_scale) produces different bytes. + assert out_both._columnwise_data is not None + assert out_col_only._columnwise_data is not None + torch.testing.assert_close( + out_col_only._columnwise_data.view(dtype=torch.uint8), + out_both._columnwise_data.view(dtype=torch.uint8), + atol=0, + rtol=0, + ) + + # Compare only the valid (in-bounds) region of the columnwise scale tensor. + # The padded tail (rows K..round_up(K, 128), cols ceil(M/16)..round_up(.., 4)) + # exists for cuBLAS alignment and is NEVER written by the kernel — its bytes + # are whatever ``at::empty`` returned, which differs between two allocations. + NVFP4_BLOCK = 16 + valid_outer = N # cols of input == rows of columnwise scale tensor + valid_inner = (M + NVFP4_BLOCK - 1) // NVFP4_BLOCK + assert out_both._columnwise_scale_inv is not None + assert out_col_only._columnwise_scale_inv is not None + col_sx_both = out_both._columnwise_scale_inv.view(dtype=torch.uint8) + col_sx_col_only = out_col_only._columnwise_scale_inv.view(dtype=torch.uint8) + torch.testing.assert_close( + col_sx_col_only[:valid_outer, :valid_inner], + col_sx_both[:valid_outer, :valid_inner], + atol=0, + rtol=0, + ) + + assert out_both._amax_columnwise is not None + assert out_col_only._amax_columnwise is not None + torch.testing.assert_close( + out_col_only._amax_columnwise, out_both._amax_columnwise, atol=0, rtol=0 + ) + + # Sanity: column-only path must not allocate a rowwise output. + assert out_col_only._rowwise_data is None diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 123362ce10..94e831afee 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -108,8 +108,13 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, "Row-scaled NVFP4 quantization does not produce columnwise output."); nvfp4::compute_rowwise_amax(*input_tensor, noop_tensor, output_tensor, stream); } - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); + // Columnwise-only is supported on the optimized path only for 2D scaling; rowwise-only and + // both-directions keep their existing routing. Columnwise-only 1D and non-bf16 fall back to + // quantize_transpose_vector_blockwise_fp4. + bool use_optimized_kernel = + (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && + (output_tensor->has_data() || + (output_tensor->has_columnwise_data() && quant_config_cpp.nvfp4_2d_quantization)); // Launch NVFP4 quantize kernel if (use_optimized_kernel) { @@ -251,8 +256,13 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens auto dtype = grad_tensor->dtype(); NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); + // Columnwise-only is supported on the optimized path only for 2D scaling; rowwise-only and + // both-directions keep their existing routing. Columnwise-only 1D and non-bf16 fall back to + // quantize_transpose_vector_blockwise_fp4. + bool use_optimized_kernel = + (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && + (output_tensor->has_data() || + (output_tensor->has_columnwise_data() && quant_config_cpp.nvfp4_2d_quantization)); // Launch NVFP4 quantize kernel if (use_optimized_kernel) { diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 9e4aef5a1c..feb9c6d287 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -779,7 +779,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_ROWWISE, bool RETURN_TRANSPOSE> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -1128,7 +1128,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } // ROWWISE scaling - { + if constexpr (RETURN_ROWWISE) { const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; #pragma unroll for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { @@ -1271,9 +1271,11 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t global_offset_Y_t = block_offset_Y_t; const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, - reinterpret_cast(&out_data_sh[buff_offset_out])); + if constexpr (RETURN_ROWWISE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, + global_offset_Y, reinterpret_cast(&out_data_sh[buff_offset_out])); + } if constexpr (RETURN_TRANSPOSE) { ptx::cp_async_bulk_tensor_2d_shared_to_global( @@ -1327,6 +1329,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, // return the transposed data. // TODO(Frank): Is there a better way to do this? bool return_transpose = output->has_columnwise_data(); + // Columnwise-only (no rowwise output) is supported on the optimized 2D path; the rowwise pass + // and its store are gated out via the RETURN_ROWWISE template bool. + const bool return_rowwise = output->has_data(); if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); @@ -1343,9 +1348,14 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, CheckOutputTensor(*output, "output", false); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(return_rowwise || return_transpose, + "At least one of rowwise/columnwise NVFP4 output must be allocated."); + NVTE_CHECK(return_rowwise || use_2d_quantization, + "Columnwise-only NVFP4 requires 2D quantization on the optimized path."); + if (return_rowwise) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, "Row-scaled NVFP4 quantization requires rowwise amax."); NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), @@ -1372,7 +1382,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, const dim3 grid(blocks_X, blocks_Y); const size_t block_size = THREADS_NUM; - const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride = return_rowwise ? output->scale_inv.shape[1] : 0; const size_t scale_stride_transpose = return_transpose ? output->columnwise_scale_inv.shape[1] : 0; @@ -1405,8 +1415,10 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(IType) * 8); - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, - 4); + if (return_rowwise) { + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, + 0, 4); + } if (return_transpose) { create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); @@ -1433,21 +1445,26 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_kernel; - - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel; - } + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_rowwise, RETURN_ROWWISE, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + // The 1D kernel always produces rowwise output (no RETURN_ROWWISE); the dispatch only + // routes columnwise-only requests here when use_2d_quantization is true. + auto kernel = quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }); }); });); #else diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index cf9821f1a9..0efef2c7af 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -353,14 +353,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo extern __shared__ char smem_base[]; SMemVec* smem = reinterpret_cast(&smem_base[0]); - // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode. - // Instead of static_assert, return early if these invalid modes are detected. + // 2D block scaling is not supported for E8 scaling MXFP4. + // Instead of static_assert, return early if this invalid mode is detected. if constexpr (kIs2DBlockScaling && kIsE8Scaling) { return; } - if constexpr (kIs2DBlockScaling && !kReturnIdentity) { - return; - } // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4 // use constexpr to define the size, when not using 2D, use minimal size 1x1 constexpr int kFP4BlockScalingSize = 16; @@ -576,6 +573,67 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } + // Step 2.5: 2D-amax-only pass for columnwise-only mode. + // When only the transposed output is requested but 2D block scaling is enabled, the columnwise + // reads in Step 3 (line ~660 below) still need amax_smem populated. Re-run the load + local-amax + // + 2D warp/smem reduction from Step 2 (steps 2.1-2.3), skipping the rowwise scale/quantize/store + // writes that Step 2 normally does. Same amax_smem values as the rowwise-enabled path, so the + // dgrad/wgrad columnwise output of (rowwise=False, columnwise=True, 2D) is bitwise identical to + // the columnwise half of (rowwise=True, columnwise=True, 2D). + if constexpr (!kReturnIdentity && kIs2DBlockScaling) { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; // 4 iterations for kTileDim=128 + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1 (amax-only): Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2 (amax-only): Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3 (amax-only): 2D warp + smem amax reduction (mirrors Step 2's 2D path) + constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32 + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + int tid_in_warp_x = threadIdx.x % kNumThreadsStore; + int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; + CType amax_warp_reduced = groupMax( + amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]); + int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; + if (tid_in_warp_y == 0) { + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] + [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced; + } + __syncthreads(); + + if (data_row_idx % kFP4BlockScalingSize == 0) { + CType amax_2d = 0.0; + for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { + amax_2d = + fmaxf(amax_2d, amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); + } + amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; + } + __syncthreads(); + r_s += r_stride; + } + } + // Step 3: Transpose, cast and store to output_t if constexpr (kReturnTranspose) { constexpr int c_stride = @@ -731,8 +789,6 @@ void quantize_transpose_vector_blockwise_fp4( NVTE_CHECK(return_identity || return_transpose, "At least one of return_identity or return_transpose must be true."); - NVTE_CHECK(return_identity || !use_2d_quantization, - "2D block quantization is only supported when return_identity is true."); NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose), "Row-scaled NVFP4 quantization only supports rowwise quantization."); NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization,