Skip to content
47 changes: 43 additions & 4 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_not_offload
from ...quantization import Recipe
from ...tensor import Quantizer
from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor
from ...tensor.grouped_tensor import GroupedTensor
from ...tensor.storage.grouped_tensor_storage import GroupedTensorStorage
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...constants import MXFP8_BLOCK_SCALING_SIZE
from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU
Expand Down Expand Up @@ -288,10 +290,32 @@ def fuser_forward(
# Group-quantize input tensor and convert dtypes if needed
fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
fc1_input_quantizer.optimize_for_gemm = True
fc1_input_quantizer.internal = True
if isinstance(input_, GroupedTensor) and isinstance(
getattr(input_, "quantizer", None), MXFP8Quantizer
):
grouped_fc1_x = input_
grouped_fc1_x = GroupedTensorStorage(
shape=input_.logical_shape,
dtype=input_.fake_dtype,
num_tensors=input_.num_tensors,
shapes=input_.tensor_shapes,
quantizer=input_.quantizer,
data=input_.rowwise_data,
columnwise_data=input_.columnwise_data,
scale_inv=input_.scale_inv,
columnwise_scale_inv=input_.columnwise_scale_inv,
amax=input_.amax,
columnwise_amax=input_.columnwise_amax,
scale=input_.scale,
first_dims=input_.first_dims,
last_dims=input_.last_dims,
tensor_offsets=input_.tensor_offsets,
offsets=input_.offsets,
scale_inv_offsets=input_.scale_inv_offsets,
columnwise_scale_inv_offsets=input_.columnwise_scale_inv_offsets,
with_gemm_swizzled_scales=input_._with_gemm_swizzled_scales,
row_scaled_nvfp4=input_.row_scaled_nvfp4,
)
else:
fc1_x = maybe_dequantize(input_, dtype)
grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes)
Expand Down Expand Up @@ -419,7 +443,7 @@ def fuser_forward(
# Repack columnwise scales on GPU to preserve group ordering.

# FC2 inputs scales are already swizzled/optimized for GEMM
grouped_fc2_x = GroupedTensor(
grouped_fc2_x = GroupedTensorStorage(
shape=(in_shape[0], fc2_weight_shape[1]),
dtype=dtype,
num_tensors=num_groups,
Expand Down Expand Up @@ -500,6 +524,11 @@ def fuser_forward(
if requires_grad:
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
activation_op = self.basic_ops[1]
cpu_offloading = is_cpu_offload_enabled()
offload_fc1_input = bool(getattr(fc1_op, "fine_grained_activation_offloading", False))
offload_activation_input = bool(
getattr(activation_op, "fine_grained_activation_offloading", False)
)
Comment on lines +528 to +531
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 May 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Do these options give us value? The dense linear op and activation ops don't expose this fine-grained control:

if is_cpu_offload_enabled():
mark_activation_offload(saved_input)

if is_cpu_offload_enabled():
mark_activation_offload(x)

  • For consistency with the rest of the CPU offloading behavior, shouldn't the default be to enable offloading? Disabling offloading should be the explicit path.
  • These secret undocumented attrs are delicate and unmaintainable. Better to make them arguments in the unfused ops. However, this also means we should update the unfused impls so that they disable activation checkpointing if the option is set.

Easiest just to not make this configurable.

activation_is_srelu = isinstance(activation_op, ScaledSReLU)
activation_recompute_in_mlp = bool(
getattr(activation_op, "activation_recompute_in_mlp", False)
Expand All @@ -513,7 +542,7 @@ def fuser_forward(
)
saved_grouped_fc2_x = None if recompute_srelu_fc2_x else grouped_fc2_x

# Save the input ``GroupedTensor``s themselves for the activations.
# Save the input grouped tensor storages themselves for the activations.
for grouped_fc_x in (grouped_fc1_x, saved_grouped_fc2_x):
if grouped_fc_x is not None:
grouped_fc_x.rowwise_data = None
Expand All @@ -525,6 +554,10 @@ def fuser_forward(
fc1_weight_tensors = (
[grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight
)
if cpu_offloading:
if not offload_fc1_input:
mark_not_offload(grouped_fc1_x)
mark_not_offload(*fc1_weight_tensors)
fc1_ctx.save_for_backward(
split_sizes,
base_split_offsets,
Expand All @@ -543,6 +576,8 @@ def fuser_forward(
fc1_ctx.weight_requires_grad = weight_requires_grad

# Activation
if cpu_offloading and not offload_activation_input:
mark_not_offload(activation_in, scales)
activation_ctx.save_for_backward(activation_in, scales)
activation_ctx.extra_input_requires_grad = True
activation_ctx.input_requires_grad = True
Expand All @@ -558,7 +593,11 @@ def fuser_forward(
fc2_weight_tensors = (
[grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight
)
fc2_saved: list[Optional[torch.Tensor]] = [
if cpu_offloading:
if saved_grouped_fc2_x is not None:
mark_not_offload(saved_grouped_fc2_x)
mark_not_offload(*fc2_weight_tensors)
fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [
split_sizes,
base_split_offsets,
split_points,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,21 @@ def restore_from_saved(
self.tensor_offsets = tensors[9]
return tensors[10:]

def get_data_tensors(self):
"""Get tensor fields that may be saved or offloaded."""
return (
self.rowwise_data,
self.columnwise_data,
self.scale_inv,
self.columnwise_scale_inv,
self.amax,
self.columnwise_amax,
self.scale,
self.first_dims,
self.last_dims,
self.tensor_offsets,
)

def clear(self) -> None:
"""
Reset tensor data and clear all buffers.
Expand Down
Loading