[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True#2677

Open
sudhakarsingh27 wants to merge 17 commits into NVIDIA:main from sudhakarsingh27:fix_return_stats_max_cudnn

Conversation

@sudhakarsingh27
Collaborator

Description

cuDNN recently made it possible to return any subset of {Stats, SumExp, Max}. This PR adapts TE to always get Stats from cuDNN, plus the Max tensor when return_max_logit=True. (Note that Stats = log(SumExp) + Max.)
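For intuition, the identity Stats = log(SumExp) + Max is just the numerically stable log-sum-exp decomposition. A minimal NumPy sketch (illustrative values only, not TE code):

```python
import numpy as np

# One row of attention logits, x = Q @ K.T for a single query position.
x = np.array([1.5, -0.3, 2.7, 0.1])

max_ = x.max()                      # Max
sum_exp = np.exp(x - max_).sum()    # SumExp (shifted by Max for stability)
stats = np.log(sum_exp) + max_      # Stats = log(SumExp) + Max

# Stats equals logsumexp(x), which is all the backward pass needs
# to renormalize softmax without keeping SumExp around.
naive = np.log(np.exp(x).sum())
assert np.isclose(stats, naive)
```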

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • fused_attn_f16_arbitrary_seqlen.cu
    • Removed references to the SumExp tensor, as it's no longer needed: cuDNN returns Stats by default.
    • Set generate_stats=True, which forces cuDNN to always return the Stats tensor (needed in the backward pass).
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py
    • Removed the code that manually computed Stats = log(SumExp) + Max, since cuDNN now returns Stats directly and TE no longer needs SumExp.
  • Corresponding documentation

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 5 commits February 12, 2026 13:12
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 12, 2026

Greptile Summary

This PR successfully adapts TransformerEngine to leverage cuDNN's new capability to return Stats directly (where Stats = log(SumExp) + Max), eliminating the need for manual calculation. The changes are well-coordinated across C++/CUDA and Python layers.

Key Changes:

  • Stats tensor is now always returned from cuDNN (via generate_stats=true)
  • When return_max_logit=True, Max tensor is additionally returned alongside Stats
  • Removed SumExp tensor handling and manual Stats calculation in Python
  • Renamed generate_max_sum_exp to return_max_logit in descriptor struct for clarity
  • Updated tensor ordering from (Max, SumExp) to (Stats, Max)
  • All documentation and comments updated to reflect new behavior

Verification:

  • Forward pass correctly sets up Stats and optional Max tensors
  • Backward pass receives Stats as expected
  • Tensor indexing in Python wrapper correctly updated
  • All previous review comments have been addressed

Confidence Score: 5/5

  • This PR is safe to merge with no blocking issues found
  • All changes are well-coordinated across C++/CUDA and Python layers, previous review feedback has been addressed, tensor ordering is consistent throughout the codebase, and the logic correctly implements the new cuDNN behavior
  • No files require special attention

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: Removed SumExp tensor handling; now always generates Stats from cuDNN and optionally returns Max when return_max_logit=True. Tensor ordering changed to Stats first, then Max.
  • transformer_engine/common/fused_attn/utils.h: Renamed descriptor field from generate_max_sum_exp to return_max_logit for clearer semantics matching actual behavior.
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py: Updated tensor indexing to reflect the new order (Stats, Max); removed the manual Stats calculation, now using Stats from cuDNN directly.
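The tensor-ordering change in the Python wrapper can be sketched as follows (a NumPy stand-in for the torch tensors, with illustrative bshd shapes; variable names are hypothetical, not the exact TE wrapper code):

```python
import numpy as np

b, sq, h, d = 2, 8, 4, 16

# Forward outputs under the new (Stats, Max) ordering for the bshd layout:
#   output_tensors[0] = out   [b, sq, h, d]
#   output_tensors[1] = Stats [b, h, sq, 1]  (always returned now)
#   output_tensors[2] = Max   [b, h, sq, 1]  (only when return_max_logit=True)
out = np.zeros((b, sq, h, d), dtype=np.float16)
stats = np.zeros((b, h, sq, 1), dtype=np.float32)
max_t = np.zeros((b, h, sq, 1), dtype=np.float32)
output_tensors = [out, stats, max_t]

# Stats is taken directly from cuDNN; the old reconstruction
# `stats = Max + log(SumExp)` is gone.
aux_ctx_tensors = [output_tensors[1]]  # handed to the backward pass
```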

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    Start[Forward Pass Starts] --> SetStats[Set generate_stats=true]
    SetStats --> CheckMax{return_max_logit?}
    CheckMax -->|Yes| CreateMax[Create Max tensor<br/>Set sdpa_options.set_logit_max]
    CheckMax -->|No| SkipMax[Skip Max creation]
    CreateMax --> RunSDPA[Run cuDNN SDPA]
    SkipMax --> RunSDPA
    RunSDPA --> ReturnStats[cuDNN returns Stats + O]
    ReturnStats --> CheckMaxReturn{return_max_logit?}
    CheckMaxReturn -->|Yes| ReturnStatsMax[Return: O, Stats, Max, rng_state, ...]
    CheckMaxReturn -->|No| ReturnStatsOnly[Return: O, Stats, rng_state, ...]
    ReturnStatsMax --> PyWrapper[Python Wrapper]
    ReturnStatsOnly --> PyWrapper
    PyWrapper --> CheckMaxPy{return_max_logit?}
    CheckMaxPy -->|Yes| ExtractMax[Extract Stats + Max<br/>Compute max_logit from Max]
    CheckMaxPy -->|No| UseStats[Use Stats directly]
    ExtractMax --> BackwardPass[Stats passed to backward pass]
    UseStats --> BackwardPass

Last reviewed commit: 7363541

Contributor

@greptile-apps bot left a comment

5 files reviewed, 1 comment

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn
Contributor

@greptile-apps bot left a comment

3 files reviewed, no comments

Contributor

@greptile-apps bot left a comment

3 files reviewed, no comments

Contributor

@greptile-apps bot left a comment

3 files reviewed, 1 comment

@greptile-apps
Contributor

greptile-apps bot commented Feb 17, 2026

Additional Comments (1)

transformer_engine/pytorch/cpp_extensions/fused_attn.py
Stale docstring: wrong formula for softmaxStats

The public docstring still describes softmaxStats as log(sum(e^(x - max(x)))), which is log(SumExp). However, with this PR, the returned tensor is cuDNN's Stats = log(SumExp) + Max, not just log(SumExp). This formula was already incorrect before this PR (the old code computed Max + log(SumExp) and stored it as stats), but the PR is an opportunity to correct it.

                       softmaxStats: torch.Tensor
                           log(sum(e^(x - max(x)))) + max(x), where x=Q*K.T (i.e. Stats = log(SumExp) + Max)
                           shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32

Contributor

@greptile-apps bot left a comment

3 files reviewed, 1 comment

stats = output_tensors[1] + torch.log(output_tensors[2])
# thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Collaborator

Do we need the "there's no typo here" :)

Collaborator Author

I deliberately added it because I didn't believe it and checked the shapes myself :P

size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();

Collaborator

@cyanguwa Feb 18, 2026

You might need to make these changes in the "Aux_CTX_Tensors->size == 0" sections in _fwd/bwd_qkvpacked/kvpacked APIs as well. Please check. Thanks!

Collaborator Author

Looks like I don't need to, because nvte_fused...qvpacked are in fused_attn.cpp, which calls fused_attn_f16_arbitrary... just like the regular nvte_fused_fwd/bwd.

Contributor

@greptile-apps bot left a comment

3 files reviewed, 1 comment

# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
Collaborator

@KshitijLakhani Feb 19, 2026

Maybe I understood this incorrectly, but isn't TE now also supposed to receive Max from cuDNN directly (like Stats, but with Stats it is always returned and with Max it can be toggled) rather than calling amax() in TE?

(Sudhakar: Why am I able to update your comment?)

Collaborator Author

@sudhakarsingh27 Feb 20, 2026

cuDNN returns Max ([b, h, sq, 1]), so it's an additional softmax statistic (apparently the subset (Stats, Max) is enough for cuDNN bwd, rather than the full set (Stats, SumExp, Max)).

Further, for muon, we need to do amax on it to get a tensor of shape [h]. return_max_logit in TE controls whether to fetch Max from cuDNN.

Perf-wise, it'd be nice for cuDNN to do an additional reduction and return the [h]-shaped tensor for muon as well, but that's out of the scope of this PR.

(Kshitij: looks like I can as well)
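The per-head reduction discussed above can be sketched as follows (NumPy stand-in; amax_dims = (0, 2, 3) is illustrative for the bshd-style [b, h, sq, 1] Max tensor, not the exact TE code):

```python
import numpy as np

b, h, sq = 2, 4, 8

# Max as returned by cuDNN for bshd/sbhd layouts: [b, h, sq, 1].
max_t = np.random.randn(b, h, sq, 1).astype(np.float32)

# Reduce over every dim except the head dim to get the [h]
# max_logit tensor that the muon use case consumes.
amax_dims = (0, 2, 3)
max_logit = np.amax(max_t, axis=amax_dims)

assert max_logit.shape == (h,)
```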

Contributor

@greptile-apps bot left a comment

3 files reviewed, no comments

Contributor

@greptile-apps bot left a comment

5 files reviewed, 1 comment

@sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from 21ca43a to becc3ad (February 20, 2026 19:41)
Contributor

@greptile-apps bot left a comment

5 files reviewed, 1 comment

@greptile-apps
Contributor

greptile-apps bot commented Feb 20, 2026

Additional Comments (1)

transformer_engine/common/include/transformer_engine/fused_attn.h
Entire file has been reformatted with unintentional 3-space indentation changes. This creates a large diff unrelated to the actual feature changes. Revert the formatting to match the original file structure.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from d4568db to 8f40cab (February 20, 2026 20:00)
Contributor

@greptile-apps bot left a comment

5 files reviewed, no comments

Contributor

@greptile-apps bot left a comment

5 files reviewed, no comments

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>