From 80304fa3a6d9dfbc825bb0170e30d927808e9eae Mon Sep 17 00:00:00 2001 From: allenphilipj Date: Thu, 28 May 2026 13:33:12 +0100 Subject: [PATCH 1/2] Fix GroupedLinear FP8 graph weight update flag Signed-off-by: allenphilipj --- tests/pytorch/test_cuda_graphs.py | 40 +++++++++++++++++++ .../pytorch/module/grouped_linear.py | 11 ++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index bb4a4e3857..9cbe5e39a7 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -12,6 +12,7 @@ LayerNormLinear, LayerNormMLP, Linear, + GroupedLinear, MultiheadAttention, TransformerLayer, autocast, @@ -23,6 +24,7 @@ is_nvfp4_available, is_bf16_available, ) +from transformer_engine.pytorch.module.grouped_linear import _GroupedLinear from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe @@ -157,6 +159,44 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) torch.testing.assert_close(t1, t2, rtol=0, atol=0) +def test_grouped_linear_forwards_fp8_graph_skip_tensor(monkeypatch) -> None: + """GroupedLinear should propagate the dynamic FP8 weight-update skip flag.""" + mod = GroupedLinear(1, 4, 4, bias=False, device="cpu") + skip_fp8_weight_update = torch.tensor([1.0]) + captured_non_tensor_args = {} + + monkeypatch.setattr(FP8GlobalStateManager, "fp8_graph_capturing", lambda: True) + monkeypatch.setattr( + FP8GlobalStateManager.quantization_state, + "skip_fp8_weight_update_tensor", + skip_fp8_weight_update, + ) + monkeypatch.setattr(mod, "is_debug_iter", lambda: False) + monkeypatch.setattr(mod, "prepare_forward", lambda inp, num_gemms=1: inp) + monkeypatch.setattr(mod, "end_forward", lambda: None) + monkeypatch.setattr(mod, "_get_weight_tensors", lambda: [torch.empty(4, 4)]) + monkeypatch.setattr(mod, "_get_bias_tensors", lambda: []) + monkeypatch.setattr( + mod, + "_get_quantizers", + lambda: ([None], [None], [None], [None], [None], [None]), + ) + + def _capture_forward(ctx, inp, non_tensor_args, *weights_and_biases): + del ctx, weights_and_biases + captured_non_tensor_args["value"] = non_tensor_args + return inp, [None] + + monkeypatch.setattr(_GroupedLinear, "forward", staticmethod(_capture_forward)) + + with torch.no_grad(): + mod(torch.empty(2, 4), [2], is_first_microbatch=True) + + non_tensor_args = captured_non_tensor_args["value"] + assert non_tensor_args[2] is False + assert non_tensor_args[19] is skip_fp8_weight_update + + def generate_data( model_config: ModelConfig, dtype: torch.dtype, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 627144345c..92b2d1fa2f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -1143,6 +1143,15 @@ def forward( is_grad_enabled = torch.is_grad_enabled() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) + else: + skip_fp8_weight_update = None + if skip_fp8_weight_update is not None: + is_first_microbatch = False + inp = self.prepare_forward(inp, num_gemms=self.num_gemms) try: weight_tensors = self._get_weight_tensors() @@ -1199,7 +1208,7 @@ def forward( is_grad_enabled, weight_workspaces, cache_weight, - None, # skip_fp8_weight_update + skip_fp8_weight_update, self.save_original_input, debug, ) From a1dc637c273c6ee73fd1ab1617f5b6603c068a33 Mon Sep 17 00:00:00 2001 From: allenphilipj Date: Thu, 28 May 2026 13:54:18 +0100 Subject: [PATCH 2/2] Update tests/pytorch/test_cuda_graphs.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: allenphilipj --- tests/pytorch/test_cuda_graphs.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 9cbe5e39a7..775ba443f1 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -193,8 +193,16 @@ def _capture_forward(ctx, inp, non_tensor_args, *weights_and_biases): mod(torch.empty(2, 4), [2], is_first_microbatch=True) non_tensor_args = captured_non_tensor_args["value"] - assert non_tensor_args[2] is False - assert non_tensor_args[19] is skip_fp8_weight_update + # non_tensor_args layout (see GroupedLinear.forward): + # 0: m_splits, 1: apply_bias, 2: is_first_microbatch, 3: fp8, 4: fp8_calibration, + # 5: wgrad_store, 6-11: quantizers (x6), 12: fuse_wgrad_accumulation, + # 13: is_cpu_offload, 14: sequence_parallel, 15: activation_dtype, + # 16: is_grad_enabled, 17: weight_workspaces, 18: cache_weight, + # 19: skip_fp8_weight_update, 20: save_original_input, 21: debug + _IS_FIRST_MICROBATCH_IDX = 2 + _SKIP_FP8_WEIGHT_UPDATE_IDX = 19 + assert non_tensor_args[_IS_FIRST_MICROBATCH_IDX] is False + assert non_tensor_args[_SKIP_FP8_WEIGHT_UPDATE_IDX] is skip_fp8_weight_update def generate_data(