
Add fused_adam, quantized_model_init, and fsdp2 example#2698

Open
pstjohn wants to merge 7 commits into NVIDIA:main from pstjohn:pstjohn/fsdp2-fused-adam

Conversation

@pstjohn (Contributor) commented Feb 22, 2026

Summary

  • Fix FusedAdam to work with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor
  • Fix fuse_wgrad_accumulation guard to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
  • Add examples for quantized_model_init on single-GPU (main.py) and multi-GPU FSDP2 (fully_shard.py)

Note: fuse_wgrad_accumulation remains incompatible with vanilla FSDP2

fuse_wgrad_accumulation still cannot be used with vanilla FSDP2. The feature writes weight gradients directly into main_grad and returns None to autograd, bypassing FSDP2's reduce-scatter. Each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring get_main_grad() into its own reduce-scatter infrastructure. Vanilla FSDP2 does not yet expose an equivalent hook.
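The failure mode can be sketched without GPUs. The toy simulation below (plain Python; the name main_grad mirrors the description above, not the actual TE code) shows why skipping the reduction leaves each rank with a different gradient:

```python
# Toy simulation of the fuse_wgrad_accumulation failure mode described above.
# Two "ranks" each compute a local gradient for the same parameter. Normally
# FSDP2's reduce-scatter averages gradients across ranks; when the fused path
# writes into main_grad and returns None to autograd, that reduction never runs.

local_grads = {0: 2.0, 1: 4.0}  # per-rank gradients for the same parameter

# Fused path: each rank accumulates only its own gradient into main_grad.
main_grad = {rank: g for rank, g in local_grads.items()}

# Without reduction, the ranks disagree on the gradient.
assert main_grad[0] != main_grad[1]

# What FSDP2's reduce-scatter would have produced: the cross-rank average.
reduced = sum(local_grads.values()) / len(local_grads)
print(main_grad, reduced)  # {0: 2.0, 1: 4.0} 3.0
```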

Fixes #2682

pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch 2 times, most recently from 22604c4 to 4d89e04 on February 23, 2026 15:28
pstjohn marked this pull request as ready for review on February 23, 2026 17:27
greptile-apps bot (Contributor) commented Feb 23, 2026

Greptile Summary

This PR enables FusedAdam optimizer and quantized_model_init to work with PyTorch-native FSDP2 (fully_shard) by adding DTensor and QuantizedTensor handling throughout the stack. The changes are well-structured and address a real compatibility gap.

Key Changes

  • FusedAdam DTensor Support: Extracts local tensors from FSDP2's DTensor wrappers before passing to multi_tensor CUDA kernels, and properly handles QuantizedTensor dequantization when initializing optimizer states
  • Pickle Compatibility: Adds __getstate__ methods to quantizers to exclude unpicklable process groups, enabling DCP checkpointing
  • Comprehensive Testing: New test suite covering FP8 master weights, BF16 params, store_param_remainders, DCP save/load, and expected failure cases
  • Documentation & Examples: Clear examples demonstrating single-GPU and multi-GPU FSDP2 usage with quantized model initialization
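The pickle fix follows a standard Python pattern: drop the unpicklable attribute in __getstate__ and re-attach it after loading. A minimal stand-in (not the actual TE quantizer class; a threading.Lock plays the role of the unpicklable process group):

```python
import pickle
import threading


class Quantizer:
    """Minimal stand-in showing the __getstate__ pattern: exclude
    attributes that cannot be pickled, such as a distributed process
    group handle, before serialization."""

    def __init__(self, scale):
        self.scale = scale
        # A threading.Lock is unpicklable, like a torch.distributed
        # ProcessGroup; it stands in for one here.
        self.process_group = threading.Lock()

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("process_group", None)  # exclude the unpicklable handle
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.process_group = None  # must be re-attached after loading


q = pickle.loads(pickle.dumps(Quantizer(scale=0.5)))
print(q.scale, q.process_group)  # 0.5 None
```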

Issues

  • Minor style inconsistency: line 23 of test_torch_fsdp2.py uses list brackets instead of tuple parentheses

The implementation correctly handles the complexity of nested tensor wrappers (DTensor wrapping QuantizedTensor) and properly extracts local tensors at each step. The pickle fixes follow established patterns and are necessary for distributed checkpointing.

Confidence Score: 5/5

  • This PR is safe to merge with only a minor style inconsistency
  • The changes are well-tested with comprehensive test coverage for multiple scenarios (FP8 master weights, BF16, DCP checkpointing), follow existing code patterns, and address a clear compatibility issue. The DTensor/QuantizedTensor handling is correct and necessary for FSDP2 support. Only one minor style issue was found.
  • No files require special attention - all changes follow best practices

Important Files Changed

  • transformer_engine/pytorch/optimizers/fused_adam.py: Adds DTensor and QuantizedTensor support for FSDP2 compatibility; extracts local tensors before passing to multi_tensor kernels and properly handles FP8 master weight initialization.
  • transformer_engine/pytorch/tensor/float8_tensor.py: Adds __getstate__ to Float8CurrentScalingQuantizer to exclude the unpicklable process group from serialized state, enabling DCP checkpointing.
  • tests/pytorch/distributed/test_torch_fsdp2.py: Adds a comprehensive test suite for FSDP2 + FusedAdam with various configurations. Line 23 uses list brackets instead of a tuple (style inconsistency).
  • tests/pytorch/distributed/run_fsdp2_fused_adam.py: New test runner for FSDP2 + FusedAdam scenarios; covers FP8 master weights, BF16, store_param_remainders, DCP checkpointing, and expected failures.
  • examples/pytorch/quantized_model_init/main.py: Single-GPU example demonstrating quantized_model_init + FusedAdam + gradient accumulation fusion with inline documentation.

Last reviewed commit: fa32ac6

greptile-apps bot left a comment: 11 files reviewed, 9 comments

greptile-apps bot left a comment: 6 files reviewed, no comments

@cspades (Member) left a comment: LGTM, clean edits.

# to get a plain float32 copy for the master weight.
local_param = param._local_tensor if isinstance(param, DTensor) else param
if isinstance(local_param, QuantizedTensor):
    master = local_param.dequantize().clone().detach().float()
@cspades (Member) commented Feb 24, 2026:
Should we use dequantize(dtype=torch.float32), to fuse the cast into the de-quantization's output buffer? (Likely not a big deal since I don't think this will change anything numerically, and you only call this function during init and whenever you save and load DCP checkpoints.)

pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from 0103b53 to 3c3dbd2 on February 24, 2026 20:06
greptile-apps bot left a comment: 6 files reviewed, no comments


@XueSongTap commented:
@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from a4d691f to 872caef on February 26, 2026 15:11
("Float8CurrentScaling", fp8.check_fp8_support),
("Float8BlockScaling", fp8.check_fp8_block_scaling_support),
("MXFP8BlockScaling", fp8.check_mxfp8_support),
["NVFP4BlockScaling", fp8.check_nvfp4_support],
greptile-apps bot: Square brackets create a list instead of a tuple, inconsistent with the other entries.
Suggested change
["NVFP4BlockScaling", fp8.check_nvfp4_support],
("NVFP4BlockScaling", fp8.check_nvfp4_support),

Comment on lines +98 to +100
def test_fsdp2_fused_adam_fp8_master_weights(fp_recipe):
"""FusedAdam(master_weights=True) + FSDP2 + quantized_model_init."""
_run_fused_adam_test("fused_adam_fp8_master_weights", fp_recipe)
greptile-apps bot: The fp_recipe fixture returns PascalCase class names (e.g. "DelayedScaling"), but run_fsdp2_fused_adam.py expects snake_case in its argparse choices (["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"]). argparse will reject the command with an error.

Comment on lines +73 to +74
def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
_run_test(fp8_init, sharding_dims, fp_recipe, layer_type)
greptile-apps bot: Same issue: the fp_recipe fixture returns PascalCase class names, but run_fsdp2_model.py argparse expects snake_case (["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"]).

Comment on lines 112 to +113
def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
-    if recipe == "delayed_scaling":
-        return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
-    elif recipe == "current_scaling":
-        return Float8CurrentScaling(fp8_format=fp8_format)
-    elif recipe == "mx_fp8_block_scaling":
-        return MXFP8BlockScaling(fp8_format=fp8_format)
-    else:
-        raise ValueError(f"Unknown quantizer type: {recipe}")
+    return getattr(transformer_engine.common.recipe, recipe)(fp8_format=fp8_format)
greptile-apps bot: This refactoring requires the recipe parameter to be a PascalCase class name (e.g. "DelayedScaling"), but the argparse choices on line 52 still use snake_case (["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"]). When the test passes PascalCase names, argparse will reject them.
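One way to resolve the mismatch is to derive both the pytest parametrization and the runner's argparse choices from a single list of PascalCase class names, so the getattr-based lookup never receives a name argparse would reject. A hedged sketch (names and the SimpleNamespace stand-in for transformer_engine.common.recipe are assumed for illustration, not taken from the PR):

```python
import argparse
import types

# Stand-in for transformer_engine.common.recipe, with dummy recipe
# constructors so the sketch is self-contained.
recipe_mod = types.SimpleNamespace(
    DelayedScaling=lambda fp8_format=None: ("DelayedScaling", fp8_format),
    Float8CurrentScaling=lambda fp8_format=None: ("Float8CurrentScaling", fp8_format),
)

# Single source of truth for recipe names, shared by tests and runner.
RECIPE_NAMES = ["DelayedScaling", "Float8CurrentScaling"]

parser = argparse.ArgumentParser()
parser.add_argument("--fp8-recipe", choices=RECIPE_NAMES)


def get_recipe_from_string(recipe, fp8_format="hybrid"):
    # getattr-based lookup, mirroring the refactor quoted above.
    return getattr(recipe_mod, recipe)(fp8_format=fp8_format)


args = parser.parse_args(["--fp8-recipe", "DelayedScaling"])
print(get_recipe_from_string(args.fp8_recipe))  # ('DelayedScaling', 'hybrid')
```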

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
# Save distributed checkpoint.
checkpoint_dir = "/tmp/te_test_fsdp2_dcp_checkpoint"

if isinstance(recipe, DelayedScaling):
greptile-apps bot: DelayedScaling is not imported. Add from transformer_engine.common.recipe import DelayedScaling to the imports at the top of the file, or use the fully qualified name.

Suggested change
if isinstance(recipe, DelayedScaling):
if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):

@pytest.mark.parametrize("async_save", (False, True))
def test_fsdp2_dcp_output_parity(fp_recipe, async_save):
"""DCP save/load round-trip into a fresh model produces identical outputs."""
_run_fused_adam_test("dcp_output_parity", fp_recipe, async_save)
greptile-apps bot:
_run_fused_adam_test only accepts 2 parameters (test_name, recipe), but 3 are passed here. The async_save parameter cannot be passed through this helper function.

Either remove async_save from the test, or modify _run_fused_adam_test to accept and forward it as a command-line argument.
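The second option amounts to threading the flag through the helper into the runner's command line. A hypothetical sketch (function and flag names are assumed for illustration, not taken from the PR; the runner's argparse would need to define the matching flag):

```python
# Sketch of forwarding async_save from the pytest helper to the test runner
# as a command-line flag, per the suggestion above.
def build_test_cmd(test_name, recipe, async_save=False):
    cmd = [
        "torchrun", "--nproc_per_node=2", "run_fsdp2_fused_adam.py",
        "--test-name", test_name,
        "--fp8-recipe", recipe,
    ]
    if async_save:
        # The runner's argparse must define this flag for it to be accepted.
        cmd.append("--async-save")
    return cmd


cmd = build_test_cmd("dcp_output_parity", "DelayedScaling", async_save=True)
print(cmd[-1])  # --async-save
```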

dist.destroy_process_group()


def test_dcp_output_parity(recipe=None, async_save=False):
greptile-apps bot:
async_save parameter is defined but never used in the function body.



Development

Successfully merging this pull request may close these issues.

Example of quantized_model_init for low-precision compute weights and fp32 main weights with fsdp2

3 participants