Add fused_adam, quantized_model_init, and fsdp2 example #2698

pstjohn wants to merge 7 commits into NVIDIA:main
Conversation
Force-push: 22604c4 to 4d89e04
Greptile Summary

This PR enables FusedAdam to work with PyTorch-native FSDP2 and quantized model initialization.

Key Changes
Issues
The implementation correctly handles the complexity of nested tensor wrappers (DTensor wrapping QuantizedTensor) and properly extracts local tensors at each step. The pickle fixes follow established patterns and are necessary for distributed checkpointing.

Confidence Score: 5/5
Important Files Changed
Last reviewed commit: fa32ac6
```python
# to get a plain float32 copy for the master weight.
local_param = param._local_tensor if isinstance(param, DTensor) else param
if isinstance(local_param, QuantizedTensor):
    master = local_param.dequantize().clone().detach().float()
```
Should we use dequantize(dtype=torch.float32) to fuse the cast into the de-quantization's output buffer? (Likely not a big deal, since I don't think this will change anything numerically, and you only call this function during init and whenever you save and load DCP checkpoints.)
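A torch-free toy of the unwrapping pattern in the diff above (the stub classes are stand-ins for DTensor and QuantizedTensor, not real TE/PyTorch types): strip the outer distributed wrapper first, then dequantize the inner quantized payload.

```python
# Toy illustration (no torch): unwrap nested wrappers outer-to-inner,
# mirroring the DTensor -> QuantizedTensor -> plain value extraction.
class DTensorStub:
    """Stand-in for torch.distributed.tensor.DTensor."""
    def __init__(self, local):
        self._local_tensor = local

class QuantizedStub:
    """Stand-in for a QuantizedTensor holding an int value and a scale."""
    def __init__(self, value, scale):
        self.q, self.scale = value, scale
    def dequantize(self):
        return self.q * self.scale

def extract_plain(param):
    # Step 1: peel the distributed wrapper to get the local shard.
    local = param._local_tensor if isinstance(param, DTensorStub) else param
    # Step 2: dequantize if the local shard is quantized.
    if isinstance(local, QuantizedStub):
        return local.dequantize()
    return local

p = DTensorStub(QuantizedStub(2, 0.5))
assert extract_plain(p) == 1.0   # unwrapped and dequantized
assert extract_plain(5) == 5     # plain values pass through untouched
```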
Force-push: 0103b53 to 3c3dbd2
@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-push: a4d691f to 872caef
```python
("Float8CurrentScaling", fp8.check_fp8_support),
("Float8BlockScaling", fp8.check_fp8_block_scaling_support),
("MXFP8BlockScaling", fp8.check_mxfp8_support),
["NVFP4BlockScaling", fp8.check_nvfp4_support],
```
Square brackets create a list instead of a tuple, inconsistent with the other entries.
```diff
- ["NVFP4BlockScaling", fp8.check_nvfp4_support],
+ ("NVFP4BlockScaling", fp8.check_nvfp4_support),
```
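For illustration, one concrete behavioral difference behind this style nit (a standalone sketch, not from the PR): both literal forms unpack identically in a parametrize-style loop, but only the tuple is hashable, so mixing forms breaks any later set or dict usage of the entries.

```python
tuple_entry = ("NVFP4BlockScaling", "check_nvfp4_support")
list_entry = ["NVFP4BlockScaling", "check_nvfp4_support"]

# Both forms unpack the same way, so the tests still run:
for name, check in (tuple_entry, list_entry):
    assert name == "NVFP4BlockScaling"

# But only the tuple is hashable:
assert hash(tuple_entry) is not None
try:
    hash(list_entry)
    raised = False
except TypeError:
    raised = True
assert raised  # lists are unhashable
```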
```python
def test_fsdp2_fused_adam_fp8_master_weights(fp_recipe):
    """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init."""
    _run_fused_adam_test("fused_adam_fp8_master_weights", fp_recipe)
```
The fp_recipe fixture returns PascalCase class names (e.g. "DelayedScaling"), but run_fsdp2_fused_adam.py expects snake_case in its argparse choices (["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"]), so argparse will reject the command with an error.
```python
def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
    _run_test(fp8_init, sharding_dims, fp_recipe, layer_type)
```
Same issue: fp_recipe returns PascalCase class names, but run_fsdp2_model.py's argparse expects snake_case (["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"]).
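A minimal standalone reproduction of this mismatch (not the PR's actual script; the flag name is an assumption, the choices are copied from the review comment): argparse validates values against choices and exits with an error on anything else, so the PascalCase fixture value never reaches the test body.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--fp8-recipe",
    choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"],
)

# snake_case values pass validation:
args = parser.parse_args(["--fp8-recipe", "delayed_scaling"])
assert args.fp8_recipe == "delayed_scaling"

# PascalCase values (what the fp_recipe fixture supplies) are rejected;
# argparse prints a usage error and raises SystemExit:
try:
    parser.parse_args(["--fp8-recipe", "DelayedScaling"])
    rejected = False
except SystemExit:
    rejected = True
assert rejected
```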
```diff
 def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
-    if recipe == "delayed_scaling":
-        return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
-    elif recipe == "current_scaling":
-        return Float8CurrentScaling(fp8_format=fp8_format)
-    elif recipe == "mx_fp8_block_scaling":
-        return MXFP8BlockScaling(fp8_format=fp8_format)
-    else:
-        raise ValueError(f"Unknown quantizer type: {recipe}")
+    return getattr(transformer_engine.common.recipe, recipe)(fp8_format=fp8_format)
```
This refactoring requires the recipe parameter to be a PascalCase class name (e.g. "DelayedScaling"), but the argparse choices on line 52 still use snake_case (["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"]). When the test passes PascalCase names, argparse will reject them.
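One way to reconcile the two conventions (a hypothetical helper, not part of the PR): a naive snake_case-to-PascalCase conversion mishandles acronym-heavy names like MXFP8BlockScaling, so an explicit mapping keyed on the existing argparse choices is safer.

```python
def snake_to_pascal(name: str) -> str:
    """Naive conversion: capitalize each underscore-separated part."""
    return "".join(part.capitalize() for part in name.split("_"))

# Naive conversion works for simple names...
assert snake_to_pascal("delayed_scaling") == "DelayedScaling"
# ...but mangles acronyms: it yields "MxFp8BlockScaling",
# not the real class name "MXFP8BlockScaling".
assert snake_to_pascal("mx_fp8_block_scaling") == "MxFp8BlockScaling"

# Safer: an explicit mapping from the script's argparse choices to the
# class names used in the original if/elif chain.
RECIPE_CLASS_NAMES = {
    "delayed_scaling": "DelayedScaling",
    "current_scaling": "Float8CurrentScaling",
    "mx_fp8_block_scaling": "MXFP8BlockScaling",
}
assert RECIPE_CLASS_NAMES["mx_fp8_block_scaling"] == "MXFP8BlockScaling"
```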
```python
# Save distributed checkpoint.
checkpoint_dir = "/tmp/te_test_fsdp2_dcp_checkpoint"

if isinstance(recipe, DelayedScaling):
```
DelayedScaling is not imported. Add from transformer_engine.common.recipe import DelayedScaling to the imports at the top of the file.
```diff
- if isinstance(recipe, DelayedScaling):
+ if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
```
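A standalone illustration of this failure mode (not the PR's code): isinstance against a name that was never imported raises NameError the first time the branch executes, not at import time.

```python
# DelayedScaling is deliberately neither defined nor imported here.
def check(recipe):
    # This line parses fine; the NameError only fires at call time.
    return isinstance(recipe, DelayedScaling)

try:
    check(object())
    failed = False
except NameError:
    failed = True
assert failed  # missing import surfaces only when the branch runs
```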
```python
@pytest.mark.parametrize("async_save", (False, True))
def test_fsdp2_dcp_output_parity(fp_recipe, async_save):
    """DCP save/load round-trip into a fresh model produces identical outputs."""
    _run_fused_adam_test("dcp_output_parity", fp_recipe, async_save)
```
_run_fused_adam_test only accepts 2 parameters (test_name, recipe), but 3 are passed here. The async_save parameter cannot be passed through this helper function.
Either remove async_save from the test, or modify _run_fused_adam_test to accept and forward it as a command-line argument.
```python
dist.destroy_process_group()


def test_dcp_output_parity(recipe=None, async_save=False):
```
The async_save parameter is defined but never used in the function body.
Summary
- Enables FusedAdam to work with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor
- Guards fuse_wgrad_accumulation to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
- Adds examples of quantized_model_init on single-GPU (main.py) and multi-GPU FSDP2 (fully_shard.py)

Note: fuse_wgrad_accumulation remains incompatible with vanilla FSDP2

fuse_wgrad_accumulation still cannot be used with vanilla FSDP2. The feature writes weight gradients directly into main_grad and returns None to autograd, bypassing FSDP2's reduce-scatter, so each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring get_main_grad() into its own reduce-scatter infrastructure; vanilla FSDP2 does not yet expose an equivalent hook.

Fixes #2682
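The incompatibility above can be simulated in a toy, torch-free model (all names here are stand-ins, not TE or FSDP2 APIs): gradients written into a side buffer and hidden from the framework never reach the reduction step, so each rank keeps only its own unreduced value.

```python
# Toy model of the failure mode: the framework only reduces gradients
# it is handed back by autograd. Writing into main_grad and returning
# None means the reducer never sees them.

def autograd_step(local_grad, main_grad_buffer, fuse_wgrad_accumulation):
    if fuse_wgrad_accumulation:
        main_grad_buffer[0] += local_grad  # written straight to main_grad
        return None                        # autograd/FSDP2 receives nothing
    return local_grad

def fsdp2_reduce(grads_per_rank):
    """Stand-in for reduce-scatter: average the grads the framework saw."""
    received = [g for g in grads_per_rank if g is not None]
    return sum(received) / len(received) if received else None

local_grads = [1.0, 3.0]          # two "ranks"
buffers = [[0.0], [0.0]]          # per-rank main_grad buffers

# Fused path: reducer receives nothing; buffers stay unreduced per rank.
returned = [autograd_step(g, b, True) for g, b in zip(local_grads, buffers)]
assert fsdp2_reduce(returned) is None
assert buffers == [[1.0], [3.0]]  # each rank holds only its own gradient

# Unfused path: reducer averages across ranks as intended.
returned = [autograd_step(g, b, False) for g, b in zip(local_grads, buffers)]
assert fsdp2_reduce(returned) == 2.0
```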