Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/pytorch/test_cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
LayerNormLinear,
LayerNormMLP,
Linear,
GroupedLinear,
MultiheadAttention,
TransformerLayer,
autocast,
Expand All @@ -23,6 +24,7 @@
is_nvfp4_available,
is_bf16_available,
)
from transformer_engine.pytorch.module.grouped_linear import _GroupedLinear
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
Expand Down Expand Up @@ -157,6 +159,52 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
torch.testing.assert_close(t1, t2, rtol=0, atol=0)


def test_grouped_linear_forwards_fp8_graph_skip_tensor(monkeypatch) -> None:
"""GroupedLinear should propagate the dynamic FP8 weight-update skip flag."""
mod = GroupedLinear(1, 4, 4, bias=False, device="cpu")
skip_fp8_weight_update = torch.tensor([1.0])
captured_non_tensor_args = {}

monkeypatch.setattr(FP8GlobalStateManager, "fp8_graph_capturing", lambda: True)
monkeypatch.setattr(
FP8GlobalStateManager.quantization_state,
"skip_fp8_weight_update_tensor",
skip_fp8_weight_update,
)
monkeypatch.setattr(mod, "is_debug_iter", lambda: False)
monkeypatch.setattr(mod, "prepare_forward", lambda inp, num_gemms=1: inp)
monkeypatch.setattr(mod, "end_forward", lambda: None)
monkeypatch.setattr(mod, "_get_weight_tensors", lambda: [torch.empty(4, 4)])
monkeypatch.setattr(mod, "_get_bias_tensors", lambda: [])
monkeypatch.setattr(
mod,
"_get_quantizers",
lambda: ([None], [None], [None], [None], [None], [None]),
)

def _capture_forward(ctx, inp, non_tensor_args, *weights_and_biases):
del ctx, weights_and_biases
captured_non_tensor_args["value"] = non_tensor_args
return inp, [None]

monkeypatch.setattr(_GroupedLinear, "forward", staticmethod(_capture_forward))

with torch.no_grad():
mod(torch.empty(2, 4), [2], is_first_microbatch=True)

non_tensor_args = captured_non_tensor_args["value"]
# non_tensor_args layout (see GroupedLinear.forward):
# 0: m_splits, 1: apply_bias, 2: is_first_microbatch, 3: fp8, 4: fp8_calibration,
# 5: wgrad_store, 6-11: quantizers (x6), 12: fuse_wgrad_accumulation,
# 13: is_cpu_offload, 14: sequence_parallel, 15: activation_dtype,
# 16: is_grad_enabled, 17: weight_workspaces, 18: cache_weight,
# 19: skip_fp8_weight_update, 20: save_original_input, 21: debug
_IS_FIRST_MICROBATCH_IDX = 2
_SKIP_FP8_WEIGHT_UPDATE_IDX = 19
assert non_tensor_args[_IS_FIRST_MICROBATCH_IDX] is False
assert non_tensor_args[_SKIP_FP8_WEIGHT_UPDATE_IDX] is skip_fp8_weight_update


def generate_data(
model_config: ModelConfig,
dtype: torch.dtype,
Expand Down
11 changes: 10 additions & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,15 @@ def forward(

is_grad_enabled = torch.is_grad_enabled()

if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = (
FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor
)
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
try:
weight_tensors = self._get_weight_tensors()
Expand Down Expand Up @@ -1199,7 +1208,7 @@ def forward(
is_grad_enabled,
weight_workspaces,
cache_weight,
None, # skip_fp8_weight_update
skip_fp8_weight_update,
self.save_original_input,
debug,
)
Expand Down