Feat/selective offload on srelu fuser #3047

Open
lhb8125 wants to merge 9 commits into NVIDIA:main from lhb8125:feat/selective-offload-on-srelu-fuser

Conversation

Contributor

@lhb8125 lhb8125 commented May 27, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 27, 2026
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from dba3531 to 53e6511 Compare May 27, 2026 07:26
Contributor

greptile-apps Bot commented May 27, 2026

Greptile Summary

This PR adds selective CPU activation offloading to the SReLU fused op (_ForwardGroupedMLP_CuTeGEMMBase_MXFP8). In the pre-quantized MXFP8 input path the FC1 input is now wrapped in a GroupedTensorStorage (previously used directly as GroupedTensor), and the GroupedTensorStorage type is also used for the FC2 GEMM input. A new get_data_tensors() method on GroupedTensorStorage enables the V1 offload path to enumerate component tensors without destructive side-effects.

  • FC1 input and activation-function input are selectively offloaded based on fine_grained_activation_offloading on the respective ops; weights are always pinned to GPU.
  • FC2 GEMM input (saved_grouped_fc2_x) is always pinned to GPU with mark_not_offload, without a corresponding fine_grained_activation_offloading check, making selective offloading of that tensor impossible through the new API.
  • In the non-V1 offload path, mark_not_offload internally calls prepare_for_saving/restore_from_saved on GroupedTensorStorage objects, which clears quantized_tensors as a side effect and does not restore it; this works safely today because those objects are freshly constructed, but the pattern is fragile for future callers.
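The gating described in the bullets above can be sketched as follows. This is a minimal illustration, not the PR's code: `FakeTensor`, `plan_offload`, the dictionary keys, and the `mark_not_offload` stand-in are all hypothetical, and only the `fine_grained_activation_offloading` attribute name and the `_TE_do_not_offload` marker are taken from the review text.

```python
from types import SimpleNamespace


class FakeTensor:
    """Hypothetical stand-in for a saved tensor; carries only a name."""

    def __init__(self, name):
        self.name = name


def mark_not_offload(*tensors):
    # Stand-in for the real helper: tag tensors so they stay on the GPU.
    for t in tensors:
        t._TE_do_not_offload = True


def plan_offload(fc1_op, activation_op, saved):
    """Return the names of saved tensors still eligible for CPU offload."""
    offload_fc1_input = bool(
        getattr(fc1_op, "fine_grained_activation_offloading", False)
    )
    offload_activation_input = bool(
        getattr(activation_op, "fine_grained_activation_offloading", False)
    )
    if not offload_fc1_input:
        mark_not_offload(saved["fc1_x"])
    if not offload_activation_input:
        mark_not_offload(saved["activation_in"])
    # Weights and the FC2 GEMM input are pinned unconditionally,
    # which is the asymmetry the summary points out.
    mark_not_offload(saved["fc1_w"], saved["fc2_x"], saved["fc2_w"])
    return sorted(
        name
        for name, t in saved.items()
        if not getattr(t, "_TE_do_not_offload", False)
    )
```

With the flag set only on the FC1 op, only the FC1 input remains eligible for offload; no flag on either op can make `fc2_x` eligible.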

Confidence Score: 4/5

The selective offloading logic is functional for FC1 input and activation tensors; the main gap is that the FC2 GEMM input has no corresponding selective-offload control and is always pinned to GPU.

The core mechanism works correctly end-to-end. The FC2 input is unconditionally prevented from offloading with no API to override this, creating an asymmetry that may be an oversight. The non-V1 path uses prepare_for_saving destructively on GroupedTensorStorage objects but is safe here because quantized_tensors is always None at the point of the call.

Pay particular attention to forward_grouped_mlp.py, around the FC2 mark_not_offload block at lines 596-599, where the fine_grained_activation_offloading gate is missing.

Important Files Changed

Filename | Overview
--- | ---
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Adds selective CPU offloading control: converts the pre-quantized MXFP8 input path from GroupedTensor to GroupedTensorStorage, marks FC1/activation tensors with fine_grained_activation_offloading, but unconditionally prevents FC2 input offloading with no symmetrical selective control.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py | Adds get_data_tensors() to expose the same 10 tensor fields as prepare_for_saving, enabling the V1 offload path to non-destructively mark GroupedTensorStorage component tensors.
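The difference between a non-destructive accessor and a destructive save step can be sketched with a toy class. Everything here is an assumption for illustration: `ToyGroupedStorage` and its three field names are hypothetical, whereas the real GroupedTensorStorage is described as exposing 10 tensor fields. Only the method names mirror the ones mentioned in the review.

```python
class ToyGroupedStorage:
    """Toy stand-in for a grouped storage object with three tensor fields."""

    _FIELDS = ("data", "scale_inv", "offsets")  # hypothetical field names

    def __init__(self, data, scale_inv, offsets):
        self.data = data
        self.scale_inv = scale_inv
        self.offsets = offsets

    def get_data_tensors(self):
        # Non-destructive: return references and leave the object untouched.
        return [getattr(self, f) for f in self._FIELDS]

    def prepare_for_saving(self):
        # Destructive: hand the tensors out and clear the fields.
        tensors = self.get_data_tensors()
        for f in self._FIELDS:
            setattr(self, f, None)
        return tensors, self

    def restore_from_saved(self, tensors):
        # Put the tensors back in the same field order.
        for f, t in zip(self._FIELDS, tensors):
            setattr(self, f, t)
```

In this toy model, code that only needs to tag the component tensors can call `get_data_tensors()` and never has to pair a clear with a restore.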

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward fused op] --> B[Construct GroupedTensorStorage for fc1_x and fc2_x]
    B --> C{cpu_offloading enabled?}
    C -- No --> G[save_for_backward via OperationContext]
    C -- Yes --> D{offload_fc1_input?}
    D -- False --> E[mark_not_offload grouped_fc1_x]
    D -- True --> F[fc1_x eligible for offload]
    E --> H[mark_not_offload fc1 weights always]
    F --> H
    H --> I{offload_activation_input?}
    I -- False --> J[mark_not_offload activation_in and scales]
    I -- True --> K[activation tensors eligible for offload]
    J --> L[mark_not_offload saved_grouped_fc2_x always]
    K --> L
    L --> M[mark_not_offload fc2 weights always]
    M --> G
    G --> N[fuser.py prepare_for_saving decomposes objects]
    N --> O[PyTorch save_for_backward push_tensor hook]
    O --> P{_TE_do_not_offload set?}
    P -- Yes --> Q[Stay on GPU]
    P -- No --> R[Offload to CPU]

Reviews (4): Last reviewed commit: "Simplify fused grouped MLP offload check..."

Comment on lines +542 to +546
selective_offload = hasattr(fc1_op, "activation_offloading") or hasattr(
    activation_op, "activation_offloading"
)
offload_fc1_input = bool(getattr(fc1_op, "activation_offloading", False))
offload_activation_input = bool(getattr(activation_op, "activation_offloading", False))
Contributor

P1 Selective-offload gate never activates unless callers set activation_offloading on op objects

hasattr(fc1_op, "activation_offloading") checks for a dynamic attribute on the op module. mark_activation_offload (both V1 and non-V1) sets activation_offloading on tensors, not on op instances, and neither GroupedLinear nor ScaledSReLU declares this attribute. As written, selective_offload will always be False and none of the new marking logic will execute unless callers set fc1_op.activation_offloading = True externally. If this is intentional, the attribute name, type, and expected caller pattern should be documented; if not, the gate condition needs to match how the attribute is actually assigned.
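A small reproduction of the concern, using toy stand-ins rather than the real classes: `ToyOp`, `ToyTensor`, and `selective_offload_enabled` are hypothetical, and the `mark_activation_offload` below only imitates the behavior described in this comment (setting the attribute on tensors).

```python
class ToyOp:
    """Stand-in for an op module that declares no offloading attribute."""


class ToyTensor:
    """Stand-in for an activation tensor."""


def mark_activation_offload(*tensors):
    # Stand-in: per the comment, the marker lands on tensors, not on ops.
    for t in tensors:
        t.activation_offloading = True


def selective_offload_enabled(fc1_op, activation_op):
    # Same shape as the gate under review.
    return hasattr(fc1_op, "activation_offloading") or hasattr(
        activation_op, "activation_offloading"
    )
```

Marking a tensor leaves the gate closed; it opens only once a caller assigns the attribute on an op instance directly.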

@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from 2c59510 to 6e01d0a Compare May 27, 2026 07:49
Member

@timmoon10 timmoon10 left a comment

Overall looks good, but with one design suggestion.

Followup tasks after merging this PR:

  • Enable activation checkpointing in the unfused grouped linear op.
  • Update activation checkpointing to support v2 infrastructure from #1762, which is opt-out rather than opt-in.

Comment on lines +528 to +531
offload_fc1_input = bool(getattr(fc1_op, "fine_grained_activation_offloading", False))
offload_activation_input = bool(
    getattr(activation_op, "fine_grained_activation_offloading", False)
)
Member

@timmoon10 timmoon10 May 28, 2026

  • Do these options give us value? The dense linear op and activation ops don't expose this fine-grained control:

if is_cpu_offload_enabled():
    mark_activation_offload(saved_input)

if is_cpu_offload_enabled():
    mark_activation_offload(x)

  • For consistency with the rest of the CPU offloading behavior, shouldn't the default be to enable offloading? Disabling offloading should be the explicit path.
  • These secret undocumented attrs are delicate and unmaintainable. Better to make them arguments in the unfused ops. However, this also means we should update the unfused impls so that they disable activation checkpointing if the option is set.

Easiest just to not make this configurable.
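For comparison, the alternative raised in the bullets above (a documented argument with offloading as the default, and opting out as the explicit path) might look roughly like this. It is only a sketch under assumptions: `ToyLinearOp`, the `offload_activations` argument, and `should_pin_to_gpu` are hypothetical names, not part of Transformer Engine.

```python
class ToyLinearOp:
    """Stand-in op with a documented, opt-out offloading argument."""

    def __init__(self, offload_activations=True):
        # Hypothetical argument: offloading is on unless explicitly disabled.
        self.offload_activations = offload_activations


def should_pin_to_gpu(op, cpu_offload_enabled):
    """Pin the saved input only when offloading is on globally
    but this op has explicitly opted out."""
    return cpu_offload_enabled and not op.offload_activations
```

An op constructed with defaults follows the global offloading setting, and only `ToyLinearOp(offload_activations=False)` keeps its saved input on the GPU.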
