
fix unfused padding causal sdpa #3063

Open: hungryGeek16 wants to merge 472 commits into NVIDIA:main from hungryGeek16:fix-unfused-padding-causal-sdpa



hungryGeek16 commented May 31, 2026

Adds a targeted PyTorch SDPA fallback for unfused THD padding_causal self-attention so TransformerEngine does not materialize the full quadratic padding/causal mask. Includes a regression test that fails if get_full_mask is called on this path.
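To make the idea concrete, here is a minimal NumPy sketch of per-sequence causal attention on THD-packed inputs. It assumes q, k and v are [total_tokens, heads, head_dim] arrays and that cu_seqlens holds cumulative sequence offsets; the function names are hypothetical, and NumPy stands in for torch.nn.functional.scaled_dot_product_attention. Each sequence only sees its own keys under a causal rule, so no [b, 1, sq, sk] mask spanning the whole batch is ever built.

```python
import numpy as np


def causal_attention(q, k, v, scale):
    """Causal self-attention for ONE sequence; q, k, v have shape [s, h, d]."""
    s = q.shape[0]
    scores = np.einsum("qhd,khd->hqk", q, k) * scale      # [h, s, s]
    future = np.triu(np.ones((s, s), dtype=bool), k=1)    # True above the diagonal
    scores = np.where(future, -np.inf, scores)            # token i sees keys 0..i only
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", probs, v)            # [s, h, d]


def varlen_causal_attention(q, k, v, cu_seqlens, scale):
    """Hypothetical THD varlen loop: q, k, v are packed [total_tokens, h, d] arrays."""
    out = np.zeros_like(q)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        if end > start:  # skip empty sequences (softmax over zero keys is undefined)
            out[start:end] = causal_attention(
                q[start:end], k[start:end], v[start:end], scale
            )
    return out
```

In PyTorch, the loop body would instead transpose each slice to [1, h, s, d] and call scaled_dot_product_attention with is_causal=True, which can avoid even the small per-sequence mask; the explicit triu mask here only makes the causal rule visible.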

ptrendx and others added 30 commits February 14, 2025 17:10
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
…x FP8 related codes (NVIDIA#1468)

* add prob permute; fix fp8tensor

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert unnecessary changes in UT

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* remove unnecessary probs dtype convert

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* keep the output nums if probs is not provided

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine the doc string

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* fix lint

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* use fp32 compute type

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* style fix

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* fix empty input return

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* separate prob related functions out

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
flax module with compute dtype inferred from the inputs

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
* Fix issues for MCore DDP.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Remove force data release for CPU offloading.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Add preserved attributes.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add main_grad to preserved attributes.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Change prepare_for_saving to original tensor and add .data to CPU hook.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Fix for LayernormLinear in FP8.

Signed-off-by: Dennis Liu <denliu@nvidia.com>

---------

Signed-off-by: Dennis Liu <denliu@nvidia.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Fix typo

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update tests

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* Fix te sequential for older pytorch versions

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* commit some debug code

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add more debug info

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* debug code commit and typo fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* a typo fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove debug info

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not return lse

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add amax_per_step for quantizers of CP

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FP8 + CP

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* dtype fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@login-preos01.a51.clusters.nvidia.com>
…NVIDIA#1466)

Use same API in optimizer zero_grad as PyT optimizers

Signed-off-by: Tim Moon <tmoon@nvidia.com>
* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* reshape inp

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* minor fixes for attention

Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…IA#1502)

* Fix a crash with module._apply(lambda t: t.cpu())

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* Add comments

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* Make sure tensor is moved to dst device before quantizer quantizes

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

---------

Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* add remove_caches api

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* Update transformer_engine/pytorch/tensor/float8_tensor.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* explicit delete

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Added TMA alignment check to cast_fp8_1D

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use tensor const-ref instead of tensor const-ptr

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
)

* Skip context parallelism tests if not enough GPUs

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Apply suggestions from code review

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* delete extra tensor objects after restoring float8 tensors

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nit fix

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix the leak in float8tensor and mxfloat8tensor classes

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* uncomment the fix

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix quantized tensor shape

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add shape to make_like; add test for chunk

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix typo from suggestion

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…rt (NVIDIA#1528)

Set flag in norm modules for Mcore sequence-parallel support

Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix wheel install after src install

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix JAX imports

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* switch order of dirs for finding so

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use existing dir src build

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…IA#1548)

Don't set data to null

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
update cudnn-frontend to its new 1.11.0-rc

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* Enable fp8_primary_weights for current scaling

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use different cast_master_weights_to_fp8 functions depending on the type of quantizer

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* All amaxes of model_weights should participate in reduce-max

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Clear _high_precision_init_val automatically in cast_master_weights_to_fp8 function

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Merge all all-reduce on amaxes into one NCCL kernel

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add unit tests for multi_tensor_compute_scale_and_scale_inv and preserve_high_precision_init_val

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Fix conflicts

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add unit test for cast_master_weights_to_fp8

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use mock group to initialize fp8_autocast to avoid reduction of amax_history by fp8_autocast_exit

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove with_computing_amax and with_computing_scale

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move replace_raw_data from QuantizedTensor to utils.py

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove allow_empty_output argument from nvte_compute_amax and always set it to true

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Rename import guard of recipe_common.cuh to align with other import guards

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add unit test for replace_raw_data

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add test_replace_raw_data into qa/L0_pytorch_unittest/test.sh

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Minor changes in comments

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add randomness to the unit test of replace_raw_data

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (Maybe need revert) Add tex.quantize_to_fragment

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (Maybe needs to revert) Use nvte_quantize_noop in quantize_to_fragment

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix lint error

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move high_precision_init_val test and replace_raw_data test to test_sanity.py

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove test_fp8_model_init.py and test_replace_raw_data.py

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove cast_master_weights_to_fp8 and replace_raw_data from __all__ of tensor.__init__.py

Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move FP8 casting logic back from C++ tex funcs to Python

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unimplemented function from header

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
* fix dtypes of fused_attn_bwd in CP+A2A

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix dtypes of fused_attn_bwd in CP+P2P

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix amax_per_step

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* clone scaling factors of fwd quantizers

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix fwd quantizers of CP+P2P

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* dequantize fp8 out in CP unit test

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* delete redundant None in FusedAttnFunc bwd

Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Update usage of weightmat before saving for backward

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix for layernorm mlp

Signed-off-by: Guyue Huang <guyueh@nvidia.com>

---------

Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
import te before te_jax

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Do not suppress MXFP8 norm in Python wrapper func

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support FP8 current scaling in tex norm functions

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use single envvar to enable cuDNN MXFP8 norm kernels

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug compilation error

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix full-tile requirement for MXFP8 norm kernels

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused imports

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing imports

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix mxfp8 columnwise data missing when switching from validation to training

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* Fix when you interleave training and inference

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* refact

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rm useless code

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* Update transformer_engine/pytorch/module/base.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: guyueh1 <140554423+guyueh1@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>
Signed-off-by: guyueh1 <140554423+guyueh1@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
bbuschkaemper and others added 21 commits April 22, 2026 16:35
Signed-off-by: Björn Buschkämper <bjoern.buschkaemper@gmail.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
…ed quantization kernel (NVIDIA#2921)

Fix the race in the dbias computation

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
fix fp8 and is_bwd_fp8 relationship

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Fix FA4 selection when FA3 is unavailable.

Signed-off-by: Björn Buschkämper <bjoern.buschkaemper@gmail.com>
…ed quantization kernel (NVIDIA#2921)

Fix the race in the dbias computation

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
…A#2922)

* remove ctype to eliminate memory usage from the cudnn kernel

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Remove c_dtype from fusible ops test

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* [Common, PyTorch] Add triton mHC kernels & pytorch operators

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* fix

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* make linter happy

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* ah OK

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* new configs to improve perf

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* add APIs to docs

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* fix typos, check deterministic, refactor

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* fix

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reset rng for all tests

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* add docstring

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* fix api doc

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* whoops

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* grad_x doesn't have to zero

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* force pytorch to not use bf16 for reduction

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* use TE's general_gemm instead

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Looks like this is how to make TE use fp32 acc

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…True (NVIDIA#2936)

* fix

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zero_out should also be tested

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@gb-nvl-059-compute03.nvidia.com>
…NVIDIA#2924)

* Fix contiguous path for k=2880

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review suggestion from @Oleg-Goncharov

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add test for swizzle + padding fusion

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…IDIA#2929)

* Avoid removing usages from quantized weight in linear op

Quantized weight tensor may be used across steps, so removing a usage is not safe.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak test to catch bug when alternating train and infer steps

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid removing usages from quantized weights in grouped linear op

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore pre-forward quantizer config in ops

Turns out we still need this in case the quantizer is used before the forward, e.g. in previous ops or CPU offloading.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Blindly preserve quantizer usages in quantized weight params.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…n expert (NVIDIA#2947)

Add workaround for cuteDSL stride requirement for zero token expert

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…pace for THD sequences (NVIDIA#2522)

* Get seqlens and offsets in O(N) space instead of O(N*N) space

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Re enable fast causal path

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Fix: seqoffsets calculation for THD

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Clean up code. Add new comments. Fix unnecessary passing of seg pos to the seqoffsets calculation API

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Optimize and fix the slow O(T*T) path for seqlens and seqoffsets calculation for THD non-cp and Cp p2p ring
    - Newer path is O(T*max_segments) per seq
    - Newer path works well with CP p2p ring

    Fix BRCM cross attn by routing to new slow path rather than fast causal path

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix lint failure

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

---------

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij  Janardan Lakhani <klakhani@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: JAX Toolbox <jax@nvidia.com>
…NVIDIA#2948)

* Switch to cuDNN-FE min version 1.23.0 to enable fused grouped MLP

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…IDIA#2942)

* accumulate bias in fp32 instead of bf16 in ref impl dbias to avoid accumulated numerical error

Signed-off-by: tdophung <tdophung@nvidia.com>
…VIDIA#2955)

* Better documentation for single param and envvar guard

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix doc

Signed-off-by: ksivamani <ksivamani@nvidia.com>

* Fix test envvar

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: ksivamani <ksivamani@nvidia.com>
github-actions bot added the community-contribution label (PRs from external contributors outside the core maintainers, representing community-driven work) May 31, 2026
greptile-apps bot commented May 31, 2026

Greptile Summary

This PR avoids materializing a full quadratic [b, 1, sq, sk] attention mask for padding_causal unfused attention: when the fast-path conditions checked by _use_varlen_sdpa are met, it routes through a new per-sample F.scaled_dot_product_attention(is_causal=True) loop (_forward_varlen_sdpa). It also carries several unrelated improvements: MXFP8 attention backend filtering in get_attention_backend, FP8EmulationFunc null-quantizer guards, MXFP8 documentation updates, a new cp_size field in AttentionParams, and cuda_version added to the attention run-config log.
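For a sense of scale, the arithmetic below estimates the dense mask size. It assumes a boolean mask at 1 byte per element (the size of torch.bool) and uses illustrative batch and sequence sizes, not measurements from this PR; both helper names are hypothetical.

```python
def full_mask_bytes(batch, seqlen_q, seqlen_kv, bytes_per_elem=1):
    """Size of a dense [b, 1, sq, sk] mask; the 1 is the broadcast head dimension."""
    return batch * 1 * seqlen_q * seqlen_kv * bytes_per_elem


def per_sequence_mask_bytes(seqlens, bytes_per_elem=1):
    """Upper bound if each sequence built only its own [s_i, s_i] causal mask."""
    return sum(s * s for s in seqlens) * bytes_per_elem


GiB = 2**30
# Dense mask is sized by max_seqlen for every batch entry, padding included.
print(full_mask_bytes(8, 32768, 32768) / GiB)                # 8.0
# One long sequence plus seven short ones, masked per sequence.
print(per_sequence_mask_bytes([32768] + [1024] * 7) / GiB)   # 1.0068359375
```

The dense mask grows with batch × max_seqlen², so a single long sequence in a padded batch makes every batch entry pay for it; a per-sample causal call avoids that padding cost.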

  • Unfused padding_causal fast path: _use_varlen_sdpa gates a new _forward_varlen_sdpa method that iterates over batch items and calls PyTorch SDPA with is_causal=True, entirely bypassing get_full_mask. A new test (test_unfused_thd_padding_causal_uses_sdpa_without_full_mask) patches get_full_mask to assert it is never called and checks numerical correctness.
  • MXFP8 FP8-emulation fixes: FP8EmulationFunc now handles None quantizers in the S_quantizer/O_quantizer/dO_quantizer/dP_quantizer branches and applies a BSHD→SBHD permute when the quantizer is an MXFP8Quantizer.
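The gate can be pictured as a plain predicate. The sketch below is a hypothetical free-function version of _use_varlen_sdpa, reconstructed from the conditions listed in the flowchart and the snippet quoted in the P2 review comment; the real method's signature and checks may differ, and "no_bias" as the no-bias value is an assumption.

```python
from typing import Optional, Tuple


def use_varlen_sdpa(
    attn_mask_type: str,
    attention_mask: object,
    window_size: Optional[Tuple[int, int]],
    core_attention_bias_type: str,
    alibi_slopes: object,
    fp8: bool,
    attention_dropout: float = 0.0,
    is_self_attention: bool = True,
) -> bool:
    """Hypothetical stand-alone version of the fast-path gate (not TE's real method)."""
    if attn_mask_type != "padding_causal":
        return False
    # Only "no sliding window" settings: causal (-1, 0) or unbounded (-1, -1).
    if window_size not in (None, (-1, 0), (-1, -1)):
        return False
    # attn_mask_type is already known to be "padding_causal" here, so no re-check.
    if attention_mask is None:
        return False
    # Assumed spelling for "no bias"; any bias, including ALiBi, takes the slow path.
    if core_attention_bias_type != "no_bias" or alibi_slopes is not None:
        return False
    # FP8 emulation, dropout and cross-attention stay on the existing unfused path.
    if fp8 or attention_dropout > 0.0 or not is_self_attention:
        return False
    return True
```

Keeping the gate conservative means any configuration the per-sample loop does not model falls back to the original get_full_mask path rather than silently changing numerics.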

Confidence Score: 3/5

The new fast path produces correct results under normal configuration, but silently computes wrong attention scores when NVTE_APPLY_QK_LAYER_SCALING=1 is in use.

The new varlen-SDPA branch passes self.softmax_scale to _forward_varlen_sdpa instead of the locally-modified scale variable. Any deployment with NVTE_APPLY_QK_LAYER_SCALING=1 and padding_causal self-attention will silently receive incorrect attention output from the new fast path, with no error or warning.

transformer_engine/pytorch/attention/dot_product_attention/backends.py — specifically the scale argument at the _forward_varlen_sdpa call site.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Adds _use_varlen_sdpa / _forward_varlen_sdpa fast path for padding_causal; passes self.softmax_scale instead of the locally modified scale variable, silently breaking apply_qk_layer_scaling. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Adds MXFP8 backend filtering, cp_size field to AttentionParams, cuda_version to run_config, and refactors FA3/FA4 SM90 preference logic; changes appear correct. |
| tests/pytorch/attention/test_attention.py | Adds test_unfused_thd_padding_causal_uses_sdpa_without_full_mask, which verifies both correctness and that get_full_mask is not called on the new fast path. |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Documentation updates for MXFP8 recipe combinations; no functional logic changes. |
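The "get_full_mask is not called" check follows a common unittest.mock pattern: replace the function with a mock that raises, run the code, then assert the mock was never invoked. The toy below uses a hypothetical utils namespace and attention_forward function, not TE's modules; in a real test the name must be patched where the caller looks it up (the standard unittest.mock rule).

```python
import types
from unittest import mock

# Hypothetical stand-in for the module that provides get_full_mask.
utils = types.SimpleNamespace(
    get_full_mask=lambda *args, **kwargs: "dense [b, 1, sq, sk] mask"
)


def attention_forward(use_varlen_sdpa: bool) -> str:
    """Toy forward: the slow path builds the dense mask, the fast path never does."""
    if use_varlen_sdpa:
        return "per-sample SDPA output"
    mask = utils.get_full_mask()
    return f"unfused output using {mask}"


def call_without_full_mask(fn, *args):
    """Run fn with get_full_mask patched to fail loudly if it is ever called."""
    with mock.patch.object(
        utils, "get_full_mask", side_effect=AssertionError("get_full_mask was called")
    ) as guard:
        result = fn(*args)
    guard.assert_not_called()
    return result
```

Raising from the mock, rather than only checking call_count afterwards, makes the failure point at the exact call site that tried to build the mask.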

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[UnfusedDotProductAttention.forward] --> B{Convert input to sbhd}
    B --> C{padding in mask type and mask is None?}
    C -->|Yes| D[get_padding_mask: build mask from cu_seqlens]
    C -->|No| E[use existing mask]
    D --> F[_use_varlen_sdpa?]
    E --> F
    F -->|YES: padding_causal, self-attn, no bias, no dropout, no fp8, no alibi| G[_forward_varlen_sdpa: batch loop with SDPA is_causal=True, no full mask allocation]
    F -->|NO| H[get_full_mask: materialize full b x sq x sk mask]
    G --> I[_format_context: convert b,h,sq,d to output layout]
    H --> J[FP8 emulation if needed]
    J --> K[baddbmm QK matmul]
    K --> L[FusedScaleMaskSoftmax]
    L --> M[dropout + AV bmm]
    M --> N[reshape to output format]
    I --> Z[output tensor]
    N --> Z

Reviews (1): Last reviewed commit: "Avoid full mask allocation in unfused pa..."

Comment on lines +575 to +585
return self._forward_varlen_sdpa(
    query_layer,
    key_layer,
    value_layer,
    q_format,
    batch_size,
    max_seqlen_q,
    cu_seqlens_q,
    attention_mask,
    self.softmax_scale,
)

P1 Wrong scale passed to fast-path SDPA

self.softmax_scale is passed here, but the locally modified scale variable (which already incorporates the apply_qk_layer_scaling division by layer_number) is computed a few lines above and is the value that should be forwarded. When NVTE_APPLY_QK_LAYER_SCALING=1 is set, scale = self.softmax_scale / self.layer_number, but the new varlen-SDPA path silently uses the unscaled value, producing wrong attention scores for every call that hits this branch with that environment variable enabled.

Suggested change

Current:
return self._forward_varlen_sdpa(
    query_layer,
    key_layer,
    value_layer,
    q_format,
    batch_size,
    max_seqlen_q,
    cu_seqlens_q,
    attention_mask,
    self.softmax_scale,
)

Suggested:
if self._use_varlen_sdpa(
    attn_mask_type,
    attention_mask,
    window_size,
    core_attention_bias_type,
    alibi_slopes,
    fp8,
):
    return self._forward_varlen_sdpa(
        query_layer,
        key_layer,
        value_layer,
        q_format,
        batch_size,
        max_seqlen_q,
        cu_seqlens_q,
        attention_mask,
        scale,
    )

Comment on lines +357 to +362
if attn_mask_type != "padding_causal":
    return False
if window_size not in [None, (-1, 0), (-1, -1)]:
    return False
if attn_mask_type == "padding_causal" and attention_mask is None:
    return False

P2 The first attn_mask_type != "padding_causal" guard already returns False for any value other than "padding_causal", so the later if attn_mask_type == "padding_causal" and … condition is always true at that point — the attn_mask_type == part is dead code and creates a misleading impression that the branch could be False.

Suggested change

Current:
if attn_mask_type != "padding_causal":
    return False
if window_size not in [None, (-1, 0), (-1, -1)]:
    return False
if attn_mask_type == "padding_causal" and attention_mask is None:
    return False

Suggested:
if attn_mask_type != "padding_causal":
    return False
if window_size not in [None, (-1, 0), (-1, -1)]:
    return False
if attention_mask is None:
    return False

