Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,103 @@ def test_nvfp4_quantization_noncontiguous_inputs(
torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0)

torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # Shapes that are whole multiples of the 128-element tile (kTileDim)
        (128, 128),
        (256, 256),
        (512, 512),
        (2048, 2048),
        # Shapes where at least one dimension is not a multiple of 128
        (256, 272),
        (304, 304),
        (320, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
def test_nvfp4_2d_columnwise_only_matches_both_directions(
    x_dtype: torch.dtype,
    M: int,
    N: int,
):
    """Check that columnwise-only 2D NVFP4 quantization is bitwise stable.

    The same random input is quantized twice with 2D scaling enabled: once
    requesting rowwise + columnwise outputs (the reference) and once requesting
    the columnwise output alone (the path under test). The columnwise data,
    the in-bounds columnwise scales and the columnwise amax must match exactly.

    Which kernel is exercised depends on the (dtype, shape) combination:

    - bf16 input whose row and column counts are both multiples of 32 is
      dispatched to the optimized ``quantize_transpose_nvfp4_2D_kernel``
      (instantiated with ``RETURN_ROWWISE=false`` for the columnwise-only call).
    - every other combination falls back to the columnwise-only 2D-amax-only
      pass in ``quantize_transpose_vector_blockwise_fp4.cu``.
    """

    def _assert_bitwise_equal(actual: torch.Tensor, expected: torch.Tensor) -> None:
        # Zero tolerances turn assert_close into an exact-equality check.
        torch.testing.assert_close(actual, expected, atol=0, rtol=0)

    # Seed both the CPU and CUDA generators so the input is reproducible.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    x = torch.randn((M, N), dtype=x_dtype, device="cuda")

    # Settings common to both quantizers; only the requested directions differ.
    shared_cfg = dict(
        fp4_dtype=tex.DType.kFloat4E2M1,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
        with_2d_quantization=True,
        row_scaled_nvfp4=False,
    )

    # Reference: rowwise and columnwise outputs produced by one quantize call.
    quantize_both = NVFP4Quantizer(rowwise=True, columnwise=True, **shared_cfg)
    out_both = quantize_both(x)

    # Path under test: columnwise output only.
    quantize_col_only = NVFP4Quantizer(rowwise=False, columnwise=True, **shared_cfg)
    out_col_only = quantize_col_only(x)

    # Packed FP4 payload. Any divergence in the per-block amax would change the
    # scales and therefore the quantized bytes, so compare the raw uint8 view.
    assert out_both._columnwise_data is not None
    assert out_col_only._columnwise_data is not None
    _assert_bitwise_equal(
        out_col_only._columnwise_data.view(dtype=torch.uint8),
        out_both._columnwise_data.view(dtype=torch.uint8),
    )

    # Columnwise scales. Only the in-bounds region is compared: the scale
    # tensor is padded for alignment, and the padding (rows past N, columns
    # past ceil(M / 16)) is not written by the kernel, so those bytes are
    # uninitialized and can differ between two allocations.
    block_len = 16
    scale_rows = N  # columns of the input become rows of the columnwise scales
    scale_cols = -(-M // block_len)  # ceil(M / block_len)
    assert out_both._columnwise_scale_inv is not None
    assert out_col_only._columnwise_scale_inv is not None
    ref_scales = out_both._columnwise_scale_inv.view(dtype=torch.uint8)
    sut_scales = out_col_only._columnwise_scale_inv.view(dtype=torch.uint8)
    _assert_bitwise_equal(
        sut_scales[:scale_rows, :scale_cols],
        ref_scales[:scale_rows, :scale_cols],
    )

    # Columnwise amax.
    assert out_both._amax_columnwise is not None
    assert out_col_only._amax_columnwise is not None
    _assert_bitwise_equal(out_col_only._amax_columnwise, out_both._amax_columnwise)

    # Sanity: the columnwise-only request must not carry a rowwise payload.
    assert out_col_only._rowwise_data is None
18 changes: 14 additions & 4 deletions transformer_engine/common/cast/dispatch/quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,13 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
"Row-scaled NVFP4 quantization does not produce columnwise output.");
nvfp4::compute_rowwise_amax(*input_tensor, noop_tensor, output_tensor, stream);
}
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
(cols % 32 == 0) && output_tensor->has_data();
// Columnwise-only is supported on the optimized path only for 2D scaling; rowwise-only and
// both-directions keep their existing routing. Columnwise-only 1D and non-bf16 fall back to
// quantize_transpose_vector_blockwise_fp4.
bool use_optimized_kernel =
(dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) &&
(output_tensor->has_data() ||
(output_tensor->has_columnwise_data() && quant_config_cpp.nvfp4_2d_quantization));

// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
Expand Down Expand Up @@ -251,8 +256,13 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
auto dtype = grad_tensor->dtype();
NVTE_CHECK(!output_tensor->row_scaled_nvfp4,
"Backward NVFP4 quantization does not support row-scaled outputs.");
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
(cols % 32 == 0) && output_tensor->has_data();
// Columnwise-only is supported on the optimized path only for 2D scaling; rowwise-only and
// both-directions keep their existing routing. Columnwise-only 1D and non-bf16 fall back to
// quantize_transpose_vector_blockwise_fp4.
bool use_optimized_kernel =
(dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) &&
(output_tensor->has_data() ||
(output_tensor->has_columnwise_data() && quant_config_cpp.nvfp4_2d_quantization));

// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
Expand Down
67 changes: 42 additions & 25 deletions transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ __global__ void __launch_bounds__(THREADS_NUM)
}

template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_ROWWISE, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
Expand Down Expand Up @@ -1128,7 +1128,7 @@ __global__ void __launch_bounds__(THREADS_NUM)
}

// ROWWISE scaling
{
if constexpr (RETURN_ROWWISE) {
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
Expand Down Expand Up @@ -1271,9 +1271,11 @@ __global__ void __launch_bounds__(THREADS_NUM)
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;

ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_ROWWISE) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already inside the if constexpr (RETURN_ROWWISE) scope (starting at line 1131), so it can be removed safely.

ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
}

if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
Expand Down Expand Up @@ -1327,6 +1329,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
// return the transposed data.
// TODO(Frank): Is there a better way to do this?
bool return_transpose = output->has_columnwise_data();
// Columnwise-only (no rowwise output) is supported on the optimized 2D path; the rowwise pass
// and its store are gated out via the RETURN_ROWWISE template bool.
const bool return_rowwise = output->has_data();

if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
Expand All @@ -1343,9 +1348,14 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
CheckOutputTensor(*output, "output", false);

NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
NVTE_CHECK(return_rowwise || return_transpose,
"At least one of rowwise/columnwise NVFP4 output must be allocated.");
NVTE_CHECK(return_rowwise || use_2d_quantization,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit confusing to read, especially if the kernel is extended in the future to support additional quantization schemes. It would be better to restrict the supported combinations explicitly, e.g.
NVTE_CHECK((return_transpose && use_2d_quantization) || (return_rowwise && !use_2d_quantization),

"Columnwise-only NVFP4 requires 2D quantization on the optimized path.");
if (return_rowwise) {
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
}
NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr,
"Row-scaled NVFP4 quantization requires rowwise amax.");
NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(),
Expand All @@ -1372,7 +1382,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_NUM;

const size_t scale_stride = output->scale_inv.shape[1];
const size_t scale_stride = return_rowwise ? output->scale_inv.shape[1] : 0;
const size_t scale_stride_transpose =
return_transpose ? output->columnwise_scale_inv.shape[1] : 0;

Expand Down Expand Up @@ -1405,8 +1415,10 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
sizeof(IType) * 8);

create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
4);
if (return_rowwise) {
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols,
0, 4);
}
if (return_transpose) {
create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
Expand All @@ -1433,21 +1445,26 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,

TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, {
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
auto kernel = quantize_transpose_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE,
ROW_SCALED_NVFP4>;

if constexpr (use_2d_quantization) {
kernel = quantize_transpose_nvfp4_2D_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_rowwise, RETURN_ROWWISE, {
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
// The 1D kernel always produces rowwise output (no RETURN_ROWWISE); the dispatch only
// routes columnwise-only requests here when use_2d_quantization is true.
auto kernel = quantize_transpose_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE,
ROW_SCALED_NVFP4>;

if constexpr (use_2d_quantization) {
kernel = quantize_transpose_nvfp4_2D_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_ROWWISE,
RETURN_TRANSPOSE>;
}

cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
scale_stride, scale_stride_transpose, rng_state);
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
scale_stride, scale_stride_transpose, rng_state);
});
});
}););
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,14 +353,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);

// 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode.
// Instead of static_assert, return early if these invalid modes are detected.
// 2D block scaling is not supported for E8 scaling MXFP4.
// Instead of static_assert, return early if this invalid mode is detected.
if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
return;
}
if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
return;
}
// for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4
// use constexpr to define the size, when not using 2D, use minimal size 1x1
constexpr int kFP4BlockScalingSize = 16;
Expand Down Expand Up @@ -576,6 +573,67 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}

// Step 2b: 2D-amax-only pass for columnwise-only mode.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Step label collision with existing substep

The new outer-level block is named "Step 2.5" at line 576, but that same label is already used at line 522 for the "Write scale_inv" substep inside Step 2's loop (if constexpr (kReturnIdentity)). A future reader scanning the file will find two distinct "Step 2.5" sections with different semantics. Consider renaming the new block to something like "Step 2b" or "Step 2.5 (outer)" to distinguish it from the // Step 2.5: Write scale_inv substep inside the inner loop.

// When only the transposed output is requested but 2D block scaling is enabled, the columnwise
// reads in Step 3 below still need amax_smem populated. Re-run the load + local-amax
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment refers to line ~660, which is now line 637. Let’s maybe remove the line reference entirely to avoid confusion.

// + 2D warp/smem reduction from Step 2 (steps 2.1-2.3), skipping the rowwise scale/quantize/store
// writes that Step 2 normally does. Same amax_smem values as the rowwise-enabled path, so the
// dgrad/wgrad columnwise output of (rowwise=False, columnwise=True, 2D) is bitwise identical to
// the columnwise half of (rowwise=True, columnwise=True, 2D).
if constexpr (!kReturnIdentity && kIs2DBlockScaling) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride; // 4 iterations for kTileDim=128
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1 (amax-only): Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
// Step 2.2 (amax-only): Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3 (amax-only): 2D warp + smem amax reduction (mirrors Step 2's 2D path)
constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
if (tid_in_warp_y == 0) {
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
[warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
}
__syncthreads();

if (data_row_idx % kFP4BlockScalingSize == 0) {
CType amax_2d = 0.0;
for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
amax_2d =
fmaxf(amax_2d, amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
}
amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
}
__syncthreads();
r_s += r_stride;
}
}

// Step 3: Transpose, cast and store to output_t
if constexpr (kReturnTranspose) {
constexpr int c_stride =
Expand Down Expand Up @@ -731,8 +789,6 @@ void quantize_transpose_vector_blockwise_fp4(
NVTE_CHECK(return_identity || return_transpose,
"At least one of return_identity or return_transpose must be true.");

NVTE_CHECK(return_identity || !use_2d_quantization,
"2D block quantization is only supported when return_identity is true.");
NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose),
"Row-scaled NVFP4 quantization only supports rowwise quantization.");
NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization,
Expand Down
Loading