
NVFP4 primary weight support#2691

Open
WanZzzzzz wants to merge 9 commits into NVIDIA:main from WanZzzzzz:fp4_native_weights

Conversation

@WanZzzzzz

Description

This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:

NVFP4 Partial Cast Kernel (nvfp4_2d_partial_cast)

  • Implements nibble-accurate partial updates for NVFP4 tensors in distributed settings
  • Supports two-level NVFP4 scaling: global FP32 scale + per-block FP8 E4M3 scale
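
The two-level scheme can be sketched in plain Python. This is a simplified model, not the kernel: the constants assume the common NVFP4 recipe (FP4 E2M1 max magnitude 6, FP8 E4M3 max 448), and the E4M3 rounding of the block scale is omitted.

```python
FP4_E2M1_MAX = 6.0    # largest FP4 E2M1 magnitude
FP8_E4M3_MAX = 448.0  # largest FP8 E4M3 magnitude


def nvfp4_scales(global_amax, block_amax):
    """Sketch of two-level NVFP4 scaling (assumed recipe, not the kernel).

    Decode is: value = fp4_code * block_scale * global_scale.
    """
    # Global FP32 scale chosen so every per-block scale fits in E4M3 range.
    global_scale = global_amax / (FP8_E4M3_MAX * FP4_E2M1_MAX)
    # Per-block scale maps the block amax onto the FP4 range; the kernel
    # then rounds it to FP8 E4M3 (rounding omitted here).
    block_scale = block_amax / (FP4_E2M1_MAX * global_scale)
    return global_scale, block_scale
```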

NVFP4 Transpose Kernel (nvfp4_transpose)

  • Custom transpose kernel for nibble-packed NVFP4 data with shared memory optimization
  • Uses vectorized uint2 loads/stores with 64×64 tiles for efficient memory access
  • Handles nibble repacking during transpose (unlike FP8 byte transpose)
  • Enables columnwise data generation for GEMM operations after rowwise AllGather
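
Because two FP4 values share one byte, transposing cannot simply move bytes the way an FP8 transpose does: nibbles must be unpacked, reordered, and repacked. A minimal pure-Python model of that repacking (the nibble order, low nibble = even column, is an assumption here; the CUDA kernel instead works on 64×64 shared-memory tiles with uint2 vector accesses):

```python
def unpack_nibbles(packed, rows, cols):
    """Unpack a row-major nibble-packed matrix; two FP4 codes per byte."""
    vals = []
    for b in packed:
        vals.append(b & 0x0F)         # assumed: low nibble = even column
        vals.append((b >> 4) & 0x0F)  # assumed: high nibble = odd column
    return [vals[r * cols:(r + 1) * cols] for r in range(rows)]


def pack_nibbles(matrix):
    """Repack a matrix of 4-bit codes into bytes, row-major."""
    flat = [v for row in matrix for v in row]
    return bytes(flat[i] | (flat[i + 1] << 4) for i in range(0, len(flat), 2))


def nvfp4_transpose(packed, rows, cols):
    """Transpose a nibble-packed rows x cols matrix of FP4 codes."""
    m = unpack_nibbles(packed, rows, cols)
    t = [[m[r][c] for r in range(rows)] for c in range(cols)]
    return pack_nibbles(t)
```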

Fused Scale Kernel (nvfp4_fused_scale)

  • Fuses per-block scale computation, global amax copy, and FP8 scale expansion into a single kernel
  • Eliminates multiple kernel launches and avoids D2H transfers by accepting tensor pointers
  • Reduces kernel launch overhead in the critical path
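
The per-block scales end up stored as FP8 E4M3 values, so the fused kernel includes an E4M3 rounding step. A plain-Python model of that rounding (a sketch only; the kernel performs it on-GPU in the same launch as the global-amax copy and scale expansion):

```python
import math


def round_to_e4m3(x):
    """Round a non-negative float to the nearest FP8 E4M3 value (sketch).

    E4M3: 3 mantissa bits, max finite 448, min normal 2**-6,
    subnormal step 2**-9. Saturates at 448 instead of producing NaN.
    """
    if x == 0.0:
        return 0.0
    if x >= 448.0:
        return 448.0
    e = max(math.floor(math.log2(x)), -6)  # clamp into subnormal range
    step = 2.0 ** (e - 3)                  # 3 mantissa bits: 8 steps per binade
    return round(x / step) * step          # round() ties to even, like E4M3
```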

Multi-Tensor Dispatch Pattern

  • C++-side loop dispatch for NVFP4 multi-tensor operations
  • Reduces Python–C++ transition overhead compared to per-tensor Python loops
  • Collects metadata in Python and executes batched operations in C++ wrappers
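
The shape of the pattern, with a stub standing in for the batched C++ entry point (names here are illustrative, not the actual extension API):

```python
def batched_cast_cpp_stub(data_ptrs, scale_ptrs, numels):
    """Stand-in for a C++ wrapper that loops over tensors internally."""
    assert len(data_ptrs) == len(scale_ptrs) == len(numels)
    return len(data_ptrs)  # one batched launch instead of N calls


def cast_all(tensors):
    # One pass of metadata collection in Python ...
    data_ptrs = [t["data"] for t in tensors]
    scale_ptrs = [t["scale"] for t in tensors]
    numels = [t["numel"] for t in tensors]
    # ... then a single Python -> C++ transition for the whole bucket,
    # rather than one transition per tensor.
    return batched_cast_cpp_stub(data_ptrs, scale_ptrs, numels)
```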

CPU Overhead Optimizations

  • Batched dtype conversion via torch.cat / torch.split
  • Replaced torch.zeros() with torch.empty() for immediately written buffers
  • Consolidated metadata collection and allocation phases
  • Optimized bucket partitioning for expert parallel buffers
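
The batched-conversion idea, sketched with NumPy as a stand-in for the torch.cat / torch.split calls in the PR: concatenate the shards, convert once, and split back, so there is one conversion kernel instead of one per shard.

```python
import numpy as np


def batched_to_fp32(shards):
    """Convert many shards to float32 via one concatenated conversion.

    NumPy stand-in for the torch.cat / torch.split pattern.
    """
    lengths = [s.size for s in shards]
    merged = np.concatenate(shards).astype(np.float32)  # single conversion
    return np.split(merged, np.cumsum(lengths)[:-1])
```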

Scale Computation Improvements

  • Fixed floating-point precision mismatch between Python and CUDA
  • Uses FP32 constants consistent with CUDA arithmetic
  • Ensures bitwise-identical results between partial and full quantization paths
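
The pitfall being fixed: Python floats are double precision, so a scale computed host-side in Python and then cast can differ in the last mantissa bit from the same expression evaluated in FP32 on the GPU (double rounding). Forcing FP32 at each step keeps the two bitwise identical; a NumPy sketch (the constant here is illustrative, not necessarily the PR's):

```python
import numpy as np


def scale_fp32(amax, fp_max=np.float32(2688.0)):
    """Compute amax / fp_max entirely in float32, matching FP32 CUDA math.

    np.float32(amax) / fp_max performs a single float32-rounded division,
    unlike float(amax) / float(fp_max) followed by a cast to float32.
    """
    return np.float32(amax) / fp_max
```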

New Public API

cast_master_weights_to_nvfp4()

  • Casts FP32 master weights to NVFP4 model weights
  • Handles global and per-block amax reduction across data parallel groups
  • Designed for low CPU overhead in distributed training loops

Testing

  • test_nvfp4_transpose_kernel: verifies correctness for nibble-packed transpose
  • test_nvfp4_partial_cast_matches_full: multi-GPU; partial cast + all-gather equals full cast
  • test_single_gpu_partial_cast_vs_full: single-GPU; offset=0 partial cast matches reference quantizer
  • _test_cast_master_weights_to_nvfp4: 500-iteration training loop with bitwise-identical loss

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:

https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: qiyuw <qiyuw@nvidia.com>
@WanZzzzzz WanZzzzzz mentioned this pull request Feb 19, 2026
@greptile-apps
Contributor

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers, enabling efficient 4-bit weight quantization with coordinated scaling across data parallel ranks.

Key Changes:

  • Implements nibble-accurate partial cast kernels for NVFP4 with two-level scaling (global FP32 + per-block FP8 E4M3)
  • Adds custom transpose kernel handling nibble repacking (unlike FP8 byte transpose)
  • Introduces multi-tensor batching pattern to reduce Python-C++ overhead
  • Provides new public API quantize_master_weights() with backward-compatible cast_master_weights_to_fp8()
  • Optimizes CPU overhead with batched dtype conversion and fused scale kernels

Critical Issue:

  • test_single_gpu_partial_cast_vs_full() computes match variables but never asserts on them (lines 1223-1229), causing the test to always pass regardless of correctness

Confidence Score: 3/5

  • This PR requires fixes before merging due to a non-functional test
  • Score reflects strong implementation quality in the CUDA kernels and Python infrastructure, but the critical test bug prevents validation of correctness. The NVFP4 kernels appear well-designed with proper nibble handling and the multi-tensor batching is a solid optimization. However, without working tests, we cannot verify the partial cast logic produces correct results.
  • Pay close attention to tests/pytorch/distributed/test_cast_master_weights_to_fp8.py - the missing assertions must be added before merge

Important Files Changed

  • tests/pytorch/distributed/test_cast_master_weights_to_fp8.py: missing assertions in test_single_gpu_partial_cast_vs_full (lines 1223-1229), causing the test to always pass
  • transformer_engine/pytorch/tensor/utils.py: adds quantize_master_weights and batched NVFP4 casting logic with multi-tensor optimizations
  • transformer_engine/common/recipe/nvfp4.cu: implements NVFP4 partial cast and transpose CUDA kernels with nibble-accurate updates
  • transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp: C++ wrappers for NVFP4 partial cast operations with input validation
  • transformer_engine/pytorch/csrc/extensions/transpose.cpp: adds NVFP4 transpose, scale operations, and multi-tensor batching infrastructure

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Master Weights FP32] --> B[quantize_master_weights]
    B --> C{Quantizer Type?}
    C -->|NVFP4| D[_cast_master_weights_to_nvfp4_2d]
    C -->|FP8 Delayed| E[_cast_master_weights_to_fp8_delayed_scaling]
    C -->|FP8 Current| F[_cast_master_weights_to_fp8_current_scaling]
    C -->|FP8 Block| G[_cast_master_weights_to_fp8_blockwise_scaling]
    C -->|MXFP8| H[_cast_master_weights_to_fp8_mxfp8_scaling]

    D --> D1[Batched dtype conversion<br/>torch.cat/split]
    D1 --> D2[nvfp4_multi_tensor_compute_partial_amax<br/>Compute per-block + global amax]
    D2 --> D3[AllReduce amax across DP ranks]
    D3 --> D4[nvfp4_compute_global_scale<br/>GPU kernel for scale computation]
    D4 --> D5[nvfp4_multi_tensor_fused_scale<br/>Fuse scale ops + expand to FP8]
    D5 --> D6[nvfp4_multi_tensor_2d_partial_cast<br/>Nibble-accurate partial updates]
    D6 --> I[NVFP4 Model Weights]

    E --> I
    F --> I
    G --> I
    H --> I

    I --> J[AllGather across DP ranks]
    J --> K[post_all_gather_processing]
    K --> K1{NVFP4?}
    K1 -->|Yes| L[_nvfp4_2d_multi_tensor_transpose<br/>Create columnwise data]
    K1 -->|No| M[Other transpose logic]
    L --> N[Ready for GEMM]
    M --> N
```

Last reviewed commit: 6ccb301


```python
    start_offsets,
    group,
    fsdp_shard_model_weights=None,
    manual_post_all_gather_processing=False,
```
Collaborator

We added this kwarg to the FP8 functions for backward compatibility, but there's no point keeping them for these brand-new NVFP4 APIs:

Suggested change (delete this line):

```python
manual_post_all_gather_processing=False,
```

Author

fsdp_shard_model_weights=None is for future FSDP support. It's in the plan.
manual_post_all_gather_processing is also needed for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535

Collaborator
@timmoon10 timmoon10 Feb 20, 2026

I see, that makes sense for now then. Let's change the default to True though since that's preferred.

I want to flag a potential future problem with manual_post_all_gather_processing=False: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:

```python
cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)

y = model(x)  # Float8Tensor internally performs FP8 transpose
```

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

```python
cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)

y = model(x)  # FutureFormatTensor assumes data is interleaved
```

In this case, we should throw an error when the user passes manual_post_all_gather_processing=False, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.

Author

Ok, noted.

Comment on lines 245 to 259
```python
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
    if self.weights_are_nvfp4:
        weight_buffer_length = self.storage_total
        buffer_rank_start = storage_rank_start
        buffer_rank_end = storage_rank_end
    else:
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
else:
    weight_buffer_dtype = weights[0].dtype
    weight_buffer_length = self.offsets[-1]
    buffer_rank_start = rank_start
    buffer_rank_end = rank_end
```
Collaborator
@timmoon10 timmoon10 Feb 20, 2026

Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.

Author
Done.

qiyuw and others added 2 commits February 20, 2026 05:52
Contributor
@greptile-apps greptile-apps bot left a comment

11 files reviewed, 1 comment


Signed-off-by: qiyuw <qiyuw@nvidia.com>

Signed-off-by: qiyuw <qiyuw@nvidia.com>

Signed-off-by: qiyuw <qiyuw@nvidia.com>

Signed-off-by: qiyuw <qiyuw@nvidia.com>
Contributor
@greptile-apps greptile-apps bot left a comment

11 files reviewed, no comments


@timmoon10
Collaborator

/te-ci L1

timmoon10 previously approved these changes Feb 21, 2026
Collaborator
@timmoon10 timmoon10 left a comment

Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.

@timmoon10 timmoon10 self-requested a review February 21, 2026 00:09
@timmoon10 timmoon10 dismissed their stale review February 21, 2026 00:09

Test failures

Signed-off-by: qiyuw <qiyuw@nvidia.com>
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Comment on lines +1223 to +1229
```python
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)

# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)

# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
```
Contributor

Test never validates results - will always pass

The match variables are computed but never asserted, making this test ineffective:

Suggested change:

```python
# Compare amax
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)
assert amax_match, f"Amax mismatch: {test_tensor._amax_rowwise} vs {ref_amax}"

# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)
assert scale_match, f"Scale mismatch: {test_tensor._rowwise_scale_inv} vs {ref_scale}"

# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
assert data_match, "Data mismatch"
```
