Conversation
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Greptile Summary

This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers, enabling efficient 4-bit weight quantization with coordinated scaling across data parallel ranks.

Key Changes:

Critical Issue:

Confidence Score: 3/5

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Master Weights FP32] --> B[quantize_master_weights]
    B --> C{Quantizer Type?}
    C -->|NVFP4| D[_cast_master_weights_to_nvfp4_2d]
    C -->|FP8 Delayed| E[_cast_master_weights_to_fp8_delayed_scaling]
    C -->|FP8 Current| F[_cast_master_weights_to_fp8_current_scaling]
    C -->|FP8 Block| G[_cast_master_weights_to_fp8_blockwise_scaling]
    C -->|MXFP8| H[_cast_master_weights_to_fp8_mxfp8_scaling]
    D --> D1[Batched dtype conversion<br/>torch.cat/split]
    D1 --> D2[nvfp4_multi_tensor_compute_partial_amax<br/>Compute per-block + global amax]
    D2 --> D3[AllReduce amax across DP ranks]
    D3 --> D4[nvfp4_compute_global_scale<br/>GPU kernel for scale computation]
    D4 --> D5[nvfp4_multi_tensor_fused_scale<br/>Fuse scale ops + expand to FP8]
    D5 --> D6[nvfp4_multi_tensor_2d_partial_cast<br/>Nibble-accurate partial updates]
    D6 --> I[NVFP4 Model Weights]
    E --> I
    F --> I
    G --> I
    H --> I
    I --> J[AllGather across DP ranks]
    J --> K[post_all_gather_processing]
    K --> K1{NVFP4?}
    K1 -->|Yes| L[_nvfp4_2d_multi_tensor_transpose<br/>Create columnwise data]
    K1 -->|No| M[Other transpose logic]
    L --> N[Ready for GEMM]
    M --> N
```

Last reviewed commit: 6ccb301
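The scale-coordination step in the middle of the flowchart (partial amax → all-reduce → global scale) can be sketched on a single device. Names here are illustrative, and the `(448 × 6) / amax` encode-scale formula follows the usual NVFP4 two-level scaling convention (per-block E4M3 scales under one FP32 global scale); it is not the PR's exact kernel code:

```python
import torch

FP4_E2M1_MAX = 6.0    # largest magnitude representable in NVFP4 (E2M1)
FP8_E4M3_MAX = 448.0  # largest magnitude representable in E4M3 (block-scale storage)

def shard_amax(shard: torch.Tensor, block: int = 16):
    """Per-block amax over a rank's flat shard, plus the shard-level amax."""
    blocks = shard.abs().reshape(-1, block)
    return blocks.amax(dim=1), shard.abs().amax()

# Each DP rank sees only its own master-weight shard; the shard amaxes are
# combined with an all-reduce(MAX) so every rank derives the same scale.
shards = [torch.linspace(-r - 1.0, r + 1.0, 64) for r in range(4)]  # stand-in for 4 ranks
rank_amaxes = [shard_amax(s)[1] for s in shards]
global_amax = torch.stack(rank_amaxes).max()  # emulates dist.all_reduce(..., op=MAX)

# Global encode scale: chosen so the largest per-block scale fits in E4M3 range.
global_scale = (FP4_E2M1_MAX * FP8_E4M3_MAX) / global_amax
```

In the actual multi-GPU path the `max` over `rank_amaxes` is a `torch.distributed.all_reduce` over the DP group, which is why every rank ends up quantizing against the same global scale.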
```python
start_offsets,
group,
fsdp_shard_model_weights=None,
manual_post_all_gather_processing=False,
```
We added these kwargs to the FP8 functions for backward compatibility, but there's no point keeping them for these brand-new NVFP4 APIs:
```python
manual_post_all_gather_processing=False,
```
`fsdp_shard_model_weights=None` is for future FSDP support. It's in the plan.
`manual_post_all_gather_processing` is also needed, for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535
I see, that makes sense for now then. Let's change the default to `True` though, since that's preferred.
I want to flag a potential future problem with `manual_post_all_gather_processing=False`: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:

```python
cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)
y = model(x)  # Float8Tensor internally performs FP8 transpose
```

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

```python
cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)
y = model(x)  # FutureFormatTensor assumes data is interleaved
```

In this case, we should throw an error when the user passes `manual_post_all_gather_processing=False`, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.
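A minimal sketch of that guard, reusing the hypothetical `futureformat` names from the example above (not a real TE API):

```python
def cast_master_weights_to_futureformat(*args, manual_post_all_gather_processing=True, **kwargs):
    """Hypothetical cast entry point for a format that cannot repair its
    layout lazily inside the model's forward pass."""
    if not manual_post_all_gather_processing:
        # Fail loudly instead of silently producing an interleaved layout
        # that downstream GEMMs would misread.
        raise ValueError(
            "FutureFormat cannot post-process automatically; pass "
            "manual_post_all_gather_processing=True and call "
            "fix_futureformat_interleaving(...) after the all-gather."
        )
    # ... actual cast would go here ...
```

The point is only the failure mode: an unsupported automatic mode should be an immediate error at the cast call, not a silent correctness bug after the all-gather.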
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py (two outdated comments, resolved)
```python
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
    if self.weights_are_nvfp4:
        weight_buffer_length = self.storage_total
        buffer_rank_start = storage_rank_start
        buffer_rank_end = storage_rank_end
    else:
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
else:
    weight_buffer_dtype = weights[0].dtype
    weight_buffer_length = self.offsets[-1]
    buffer_rank_start = rank_start
    buffer_rank_end = rank_end
```
Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.
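One way to read that suggestion (illustrative names, not the PR's actual code): compute a storage layout for every recipe, trivially equal to the logical layout unless the recipe packs data, so the allocation site never branches on NVFP4:

```python
import torch

def plan_storage(quantized, nvfp4, offsets, storage_total=None, storage_span=None,
                 logical_span=(0, 0), dtype=torch.float32):
    """Return (dtype, buffer_length, rank_span) for the flat weight buffer.

    The "storage" layout is computed in all cases; it only differs from the
    logical layout for NVFP4, which packs two 4-bit values per byte.
    """
    if quantized and nvfp4:
        return torch.uint8, storage_total, storage_span
    buffer_dtype = torch.uint8 if quantized else dtype
    return buffer_dtype, offsets[-1], logical_span

# The allocation site then uses the plan blindly, with no recipe branches:
buf_dtype, length, (start, end) = plan_storage(
    quantized=True, nvfp4=True, offsets=[0, 128, 256],
    storage_total=128, storage_span=(0, 64))
weight_buffer = torch.empty(length, dtype=buf_dtype)
```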
/te-ci L1
timmoon10 left a comment:
Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.
```python
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)

# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)

# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
```
Test never validates results - will always pass
The match variables are computed but never asserted, making this test ineffective:
Suggested change:

```python
# Compare amax
amax_match = torch.equal(test_tensor._amax_rowwise, ref_amax)
assert amax_match, f"Amax mismatch: {test_tensor._amax_rowwise} vs {ref_amax}"
# Compare scale
scale_match = torch.equal(test_tensor._rowwise_scale_inv, ref_scale)
assert scale_match, f"Scale mismatch: {test_tensor._rowwise_scale_inv} vs {ref_scale}"
# Compare data
data_match = torch.equal(test_tensor._rowwise_data, ref_data)
assert data_match, "Data mismatch"
```
Description
This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.
Type of change
Changes
This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:
- NVFP4 Partial Cast Kernel (`nvfp4_2d_partial_cast`)
- NVFP4 Transpose Kernel (`nvfp4_transpose`): `uint2` loads/stores with 64×64 tiles for efficient memory access
- Fused Scale Kernel (`nvfp4_fused_scale`)
- Multi-Tensor Dispatch Pattern
- CPU Overhead Optimizations: batched dtype conversion via `torch.cat`/`torch.split`; replaced `torch.zeros()` with `torch.empty()` for immediately written buffers
- Scale Computation Improvements
- New Public API: `cast_master_weights_to_nvfp4()`

Testing

- `test_nvfp4_transpose_kernel`
- `test_nvfp4_partial_cast_matches_full`
- `test_single_gpu_partial_cast_vs_full`
- `_test_cast_master_weights_to_nvfp4`

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:
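The `torch.cat`/`torch.split` optimization listed under CPU overhead amounts to trading N small dtype-conversion kernel launches for one large one. A generic sketch of the idea (not the PR's exact helper):

```python
import torch

def batched_cast(tensors, dtype):
    """Convert a list of tensors with one fused cast: flatten and
    concatenate, convert once, then split back to the original shapes."""
    flat = torch.cat([t.reshape(-1) for t in tensors])
    chunks = torch.split(flat.to(dtype), [t.numel() for t in tensors])
    return [c.reshape(t.shape) for c, t in zip(chunks, tensors)]

weights = [torch.randn(4, 4), torch.randn(8)]
casted = batched_cast(weights, torch.bfloat16)
```

On GPU this replaces one `.to(dtype)` launch per parameter with a single launch over the concatenated buffer, which is where the CPU-overhead savings come from at large parameter counts.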
https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads
Checklist: