From 6e01d0a4c11abbf596994f18104804ace8415bef Mon Sep 17 00:00:00 2001 From: hongbinl Date: Wed, 27 May 2026 00:01:25 -0700 Subject: [PATCH 1/9] Support selective offload for fused grouped MLP --- transformer_engine/pytorch/cpu_offload_v1.py | 2 + .../pytorch/ops/fused/forward_grouped_mlp.py | 61 ++++++++++++++++--- .../tensor/storage/grouped_tensor_storage.py | 15 +++++ 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload_v1.py b/transformer_engine/pytorch/cpu_offload_v1.py index fb62546cc0..097719b3be 100644 --- a/transformer_engine/pytorch/cpu_offload_v1.py +++ b/transformer_engine/pytorch/cpu_offload_v1.py @@ -36,6 +36,8 @@ def mark_activation_offload(*tensors, offload: bool = True): tensor._TE_do_not_offload = True else: data_tensors = tensor.get_data_tensors() + if not offload: + tensor._TE_do_not_offload = True for tensor in data_tensors: if tensor is not None: if offload: diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 034d404439..fa99193002 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,10 +13,12 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import mark_not_offload from ...quantization import Recipe from ...tensor import Quantizer from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor from ...tensor.grouped_tensor import GroupedTensor +from ...tensor.storage.grouped_tensor_storage import GroupedTensorStorage from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU @@ -291,10 +293,38 @@ def fuser_forward( if isinstance(input_, GroupedTensor) and isinstance( getattr(input_, "quantizer", None), MXFP8Quantizer ): - grouped_fc1_x = input_ + grouped_fc1_x = GroupedTensorStorage( + shape=input_.logical_shape, + dtype=input_.fake_dtype, + num_tensors=input_.num_tensors, + shapes=input_.tensor_shapes, + quantizer=input_.quantizer, + data=input_.rowwise_data, + columnwise_data=input_.columnwise_data, + scale_inv=input_.scale_inv, + columnwise_scale_inv=input_.columnwise_scale_inv, + amax=input_.amax, + columnwise_amax=input_.columnwise_amax, + scale=input_.scale, + first_dims=input_.first_dims, + last_dims=input_.last_dims, + tensor_offsets=input_.tensor_offsets, + offsets=input_.offsets, + scale_inv_offsets=input_.scale_inv_offsets, + columnwise_scale_inv_offsets=input_.columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=input_._with_gemm_swizzled_scales, + row_scaled_nvfp4=input_.row_scaled_nvfp4, + ) else: fc1_x = maybe_dequantize(input_, dtype) - grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes) + quantizer_internal = fc1_input_quantizer.internal + fc1_input_quantizer.internal = True + try: + grouped_fc1_x = tex.group_quantize( + fc1_x, fc1_input_quantizer, num_groups, split_sizes + ) + finally: + fc1_input_quantizer.internal = quantizer_internal # Pack data tensors # Note: Fused kernel expects tensor with non-contiguous @@ -419,7 +449,7 @@ def fuser_forward( # Repack columnwise scales on GPU to preserve group ordering. # FC2 inputs scales are already swizzled/optimized for GEMM - grouped_fc2_x = GroupedTensor( + grouped_fc2_x = GroupedTensorStorage( shape=(in_shape[0], fc2_weight_shape[1]), dtype=dtype, num_tensors=num_groups, @@ -500,6 +530,13 @@ def fuser_forward( if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] + selective_offload = hasattr(fc1_op, "activation_offloading") or hasattr( + activation_op, "activation_offloading" + ) + offload_fc1_input = bool(getattr(fc1_op, "activation_offloading", False)) + offload_activation_input = bool( + getattr(activation_op, "activation_offloading", False) + ) activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) @@ -513,7 +550,7 @@ def fuser_forward( ) saved_grouped_fc2_x = None if recompute_srelu_fc2_x else grouped_fc2_x - # Save the input ``GroupedTensor``s themselves for the activations. + # Save the input grouped tensor storages themselves for the activations. for grouped_fc_x in (grouped_fc1_x, saved_grouped_fc2_x): if grouped_fc_x is not None: grouped_fc_x.rowwise_data = None @@ -525,6 +562,17 @@ def fuser_forward( fc1_weight_tensors = ( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) + fc2_weight_tensors = ( + [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight + ) + if selective_offload: + if not offload_fc1_input: + mark_not_offload(grouped_fc1_x) + if not offload_activation_input: + mark_not_offload(activation_in, scales) + if saved_grouped_fc2_x is not None: + mark_not_offload(saved_grouped_fc2_x) + mark_not_offload(*fc1_weight_tensors, *fc2_weight_tensors) fc1_ctx.save_for_backward( split_sizes, base_split_offsets, @@ -555,10 +603,7 @@ def fuser_forward( # [split_sizes, base_split_offsets, split_points, # (fc2_scales if _scale_bias), # grouped_fc2_x, *fc2_weight_tensors] - fc2_weight_tensors = ( - [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight - ) - fc2_saved: list[Optional[torch.Tensor]] = [ + fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [ split_sizes, base_split_offsets, split_points, diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 438e124021..c112634024 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -387,6 +387,21 @@ def restore_from_saved( self.tensor_offsets = tensors[9] return tensors[10:] + def get_data_tensors(self): + """Get tensor fields that may be saved or offloaded.""" + return ( + self.rowwise_data, + self.columnwise_data, + self.scale_inv, + self.columnwise_scale_inv, + self.amax, + self.columnwise_amax, + self.scale, + self.first_dims, + self.last_dims, + self.tensor_offsets, + ) + def clear(self) -> None: """ Reset tensor data and clear all buffers. From 7854c210c2ef56ad4f86c9a3b9577eb8f01ffc06 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 07:50:54 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index fa99193002..66342a4706 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -534,9 +534,7 @@ def fuser_forward( activation_op, "activation_offloading" ) offload_fc1_input = bool(getattr(fc1_op, "activation_offloading", False)) - offload_activation_input = bool( - getattr(activation_op, "activation_offloading", False) - ) + offload_activation_input = bool(getattr(activation_op, "activation_offloading", False)) activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) From e7f0e78f2c4b91262247d02ea61705f484eea345 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Wed, 27 May 2026 01:47:59 -0700 Subject: [PATCH 3/9] Keep fused grouped MLP FC1 input internal --- .../pytorch/ops/fused/forward_grouped_mlp.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 66342a4706..5a1452a1ce 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -290,6 +290,7 @@ def fuser_forward( # Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc1_input_quantizer.optimize_for_gemm = True + fc1_input_quantizer.internal = True if isinstance(input_, GroupedTensor) and isinstance( getattr(input_, "quantizer", None), MXFP8Quantizer ): @@ -317,14 +318,9 @@ def fuser_forward( ) else: fc1_x = maybe_dequantize(input_, dtype) - quantizer_internal = fc1_input_quantizer.internal - fc1_input_quantizer.internal = True - try: - grouped_fc1_x = tex.group_quantize( - fc1_x, fc1_input_quantizer, num_groups, split_sizes - ) - finally: - fc1_input_quantizer.internal = quantizer_internal + grouped_fc1_x = tex.group_quantize( + fc1_x, fc1_input_quantizer, num_groups, split_sizes + ) # Pack data tensors # Note: Fused kernel expects tensor with non-contiguous From 84e6424167aea479d961cedb2865c68d1f99cd2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 08:49:38 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 5a1452a1ce..895ef1cda5 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -318,9 +318,7 @@ def fuser_forward( ) else: fc1_x = maybe_dequantize(input_, dtype) - grouped_fc1_x = tex.group_quantize( - fc1_x, fc1_input_quantizer, num_groups, split_sizes - ) + grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes) # Pack data tensors # Note: Fused kernel expects tensor with non-contiguous From eeda00581b67cb1e6e267f71ceb4ddbc2bbd5e10 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Wed, 27 May 2026 01:54:49 -0700 Subject: [PATCH 5/9] Rename fused grouped MLP offload marker --- .../pytorch/ops/fused/forward_grouped_mlp.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 895ef1cda5..3fee6b70d5 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -524,11 +524,15 @@ def fuser_forward( if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] - selective_offload = hasattr(fc1_op, "activation_offloading") or hasattr( - activation_op, "activation_offloading" + fine_grained_activation_offloading = hasattr( + fc1_op, "fine_grained_activation_offloading" + ) or hasattr(activation_op, "fine_grained_activation_offloading") + offload_fc1_input = bool( + getattr(fc1_op, "fine_grained_activation_offloading", False) + ) + offload_activation_input = bool( + getattr(activation_op, "fine_grained_activation_offloading", False) ) - offload_fc1_input = bool(getattr(fc1_op, "activation_offloading", False)) - offload_activation_input = bool(getattr(activation_op, "activation_offloading", False)) activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) @@ -557,7 +561,7 @@ def fuser_forward( fc2_weight_tensors = ( [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight ) - if selective_offload: + if fine_grained_activation_offloading: if not offload_fc1_input: mark_not_offload(grouped_fc1_x) if not offload_activation_input: From 726c84c9aa510041bf330f2926e91e89e1019af6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 08:56:01 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 3fee6b70d5..47204823bd 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -527,9 +527,7 @@ def fuser_forward( fine_grained_activation_offloading = hasattr( fc1_op, "fine_grained_activation_offloading" ) or hasattr(activation_op, "fine_grained_activation_offloading") - offload_fc1_input = bool( - getattr(fc1_op, "fine_grained_activation_offloading", False) - ) + offload_fc1_input = bool(getattr(fc1_op, "fine_grained_activation_offloading", False)) offload_activation_input = bool( getattr(activation_op, "fine_grained_activation_offloading", False) ) From 573a455ed6215b37251b692a1843ccfe8654a34d Mon Sep 17 00:00:00 2001 From: hongbinl Date: Wed, 27 May 2026 02:07:58 -0700 Subject: [PATCH 7/9] Avoid tagging grouped storage wrapper for offload --- transformer_engine/pytorch/cpu_offload_v1.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload_v1.py b/transformer_engine/pytorch/cpu_offload_v1.py index 097719b3be..fb62546cc0 100644 --- a/transformer_engine/pytorch/cpu_offload_v1.py +++ b/transformer_engine/pytorch/cpu_offload_v1.py @@ -36,8 +36,6 @@ def mark_activation_offload(*tensors, offload: bool = True): tensor._TE_do_not_offload = True else: data_tensors = tensor.get_data_tensors() - if not offload: - tensor._TE_do_not_offload = True for tensor in data_tensors: if tensor is not None: if offload: From 9854db1d6aad2e2fa7f14e94ff3be10384a770b6 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Wed, 27 May 2026 02:13:50 -0700 Subject: [PATCH 8/9] Move FC2 offload markers next to FC2 save --- .../pytorch/ops/fused/forward_grouped_mlp.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 47204823bd..44b537df6a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -556,17 +556,10 @@ def fuser_forward( fc1_weight_tensors = ( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) - fc2_weight_tensors = ( - [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight - ) if fine_grained_activation_offloading: if not offload_fc1_input: mark_not_offload(grouped_fc1_x) - if not offload_activation_input: - mark_not_offload(activation_in, scales) - if saved_grouped_fc2_x is not None: - mark_not_offload(saved_grouped_fc2_x) - mark_not_offload(*fc1_weight_tensors, *fc2_weight_tensors) + mark_not_offload(*fc1_weight_tensors) fc1_ctx.save_for_backward( split_sizes, base_split_offsets, @@ -585,6 +578,8 @@ def fuser_forward( fc1_ctx.weight_requires_grad = weight_requires_grad # Activation + if fine_grained_activation_offloading and not offload_activation_input: + mark_not_offload(activation_in, scales) activation_ctx.save_for_backward(activation_in, scales) activation_ctx.extra_input_requires_grad = True activation_ctx.input_requires_grad = True @@ -597,6 +592,13 @@ def fuser_forward( # [split_sizes, base_split_offsets, split_points, # (fc2_scales if _scale_bias), # grouped_fc2_x, *fc2_weight_tensors] + fc2_weight_tensors = ( + [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight + ) + if fine_grained_activation_offloading: + if saved_grouped_fc2_x is not None: + mark_not_offload(saved_grouped_fc2_x) + mark_not_offload(*fc2_weight_tensors) fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [ split_sizes, base_split_offsets, From 7b091a79dba869489b99edd6575598732757c731 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Wed, 27 May 2026 07:10:52 -0700 Subject: [PATCH 9/9] Simplify fused grouped MLP offload checks --- .../pytorch/ops/fused/forward_grouped_mlp.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 44b537df6a..b5b237a398 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,7 +13,7 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_not_offload from ...quantization import Recipe from ...tensor import Quantizer from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor @@ -524,9 +524,7 @@ def fuser_forward( if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] - fine_grained_activation_offloading = hasattr( - fc1_op, "fine_grained_activation_offloading" - ) or hasattr(activation_op, "fine_grained_activation_offloading") + cpu_offloading = is_cpu_offload_enabled() offload_fc1_input = bool(getattr(fc1_op, "fine_grained_activation_offloading", False)) offload_activation_input = bool( getattr(activation_op, "fine_grained_activation_offloading", False) @@ -556,7 +554,7 @@ def fuser_forward( fc1_weight_tensors = ( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) - if fine_grained_activation_offloading: + if cpu_offloading: if not offload_fc1_input: mark_not_offload(grouped_fc1_x) mark_not_offload(*fc1_weight_tensors) @@ -578,7 +576,7 @@ def fuser_forward( fc1_ctx.weight_requires_grad = weight_requires_grad # Activation - if fine_grained_activation_offloading and not offload_activation_input: + if cpu_offloading and not offload_activation_input: mark_not_offload(activation_in, scales) activation_ctx.save_for_backward(activation_in, scales) activation_ctx.extra_input_requires_grad = True @@ -595,7 +593,7 @@ def fuser_forward( fc2_weight_tensors = ( [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight ) - if fine_grained_activation_offloading: + if cpu_offloading: if saved_grouped_fc2_x is not None: mark_not_offload(saved_grouped_fc2_x) mark_not_offload(*fc2_weight_tensors)