Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
472 commits
Select commit Hold shift + click to select a range
a234867
Changed VERSION to 2.1.0
ptrendx Feb 15, 2025
c6a8df8
[MoE][PyTorch] Add prob permutation to mask-based MoE permutation; Fi…
hxbai Feb 18, 2025
dca7dbd
[JAX] Flax with compute dtype inferred from input dtype. (#1485)
phu0ngng Feb 18, 2025
4a4a6fa
Fix issues for MCore DDP. (#1474)
Victarry Feb 19, 2025
9c5436f
[PyTorch] Fix typo (#1495)
timmoon10 Feb 19, 2025
524668b
[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear (#1488)
yaox12 Feb 19, 2025
b1e948e
Fix TE ops API compatibility with PyTorch versions < 2.4.3 (#1494)
ksivaman Feb 20, 2025
3b64927
[PyTorch] Fix CP implementation with FP8 (#1483)
xrennvidia Feb 20, 2025
1b384b9
[PyTorch] Use same API in optimizer `zero_grad` as PyTorch optimizers…
timmoon10 Feb 22, 2025
7d07a1a
[Pytorch] Added missing assert_dim_for_fp8_exec for Linear
pggPL Feb 24, 2025
6266011
Minor fixes for attention (#1504)
cyanguwa Feb 25, 2025
7b10a04
Fix a crash in NeMo 2.0 during module._apply(lambda t: t.cpu()) (#1502)
guyueh1 Feb 25, 2025
435823b
Adding remove_caches API to Float8Tensor class (#1425)
youngeunkwon0405 Feb 25, 2025
30cea25
Added memory alignment check to cast_fp8_1D (#1507)
Oleg-Goncharov Feb 26, 2025
867ab06
[PyTorch] Skip context parallelism tests if not enough GPUs (#1508)
timmoon10 Feb 26, 2025
4f9cd42
Delete extra tensor objects after restoring float8 tensors (#1500)
sudhakarsingh27 Feb 28, 2025
8efb39d
Fix shape of new quantized tensor in `make_like` (#1515)
ksivaman Feb 28, 2025
f2b09d2
[PyTorch] Set flags in norm modules for Mcore sequence-parallel suppo…
timmoon10 Mar 1, 2025
ad0ee94
Fix installation from PyPI wheels after a source install (#1526)
ksivaman Mar 5, 2025
450146a
[PyTorch] Don't set FP8 data to `None` when saving base tensors (#1548)
ksivaman Mar 7, 2025
8eb1712
Release v2.1
ptrendx Mar 17, 2025
b6a2a48
Changed VERSION to 2.2.0
ptrendx Mar 18, 2025
eeadd43
Update cudnn-frontend to new 1.11.0-rc commit (#1590)
cyanguwa Mar 20, 2025
e9e0cd7
[PyTorch] Enable fp8_primary_weights for current scaling (#1544)
kunlunl Mar 22, 2025
28095af
Fix issues in fused_attn_bwd (#1574)
xrennvidia Mar 24, 2025
06bede8
Ensure weight transpose is valid for Hopper FP8 training (#1596)
guyueh1 Mar 24, 2025
a5eb420
[JAX] Fixing importing in the encoder examples (#1600)
phu0ngng Mar 25, 2025
0ddf331
Remove deprecated interval arg to delayed scaling recipe (#1607)
ksivaman Mar 25, 2025
457cd69
[PyTorch] Use consistent API for fused norm kernels (#1560)
timmoon10 Mar 22, 2025
7694039
Fix mxfp8 columnwise data missing (#1593)
guyueh1 Mar 25, 2025
c2d9275
[PyTorch] Minor fixes for TE 2.2 (#1589)
cyanguwa Mar 25, 2025
80d2177
[PyTorch] Optimize MXFP8 all-gathers (#1581)
timmoon10 Mar 25, 2025
c45f5fd
[PyTorch] Add tests for current scaling; misc related fixes (#1606)
ksivaman Mar 27, 2025
b4706a6
fix a sync race error of softmax_lse in CP+THD+P2P (#1624)
xrennvidia Mar 31, 2025
9577cf5
[JAX] Add fast path for causal masking with segment IDs. (#1601)
mgoldfarb-nvidia Mar 31, 2025
6756466
[PyTorch] Support default process group with FP8 current scaling (#1621)
timmoon10 Mar 31, 2025
b27283a
[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)
phu0ngng Apr 1, 2025
8f11acc
[PyTorch] fix fuse_wgrad_accumulation in LayerNormMLP backward (#1618)
Marks101 Apr 1, 2025
4924444
Revert "[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)"
KshitijLakhani Apr 1, 2025
32062e0
Bugfixes for LayerNormMLP (#1625)
guyueh1 Apr 1, 2025
f546444
[PyTorch] Make breaking change in `InferenceParams.init` more explici…
cyanguwa Apr 1, 2025
069dd8b
[PyTorch] Debug NCCL communication overlapping in linear backward wit…
timmoon10 Apr 1, 2025
02180cd
Fix fp8_buf for Linear and LayerNormLinear (#1633)
ksivaman Apr 3, 2025
8e0853a
Introduce NVSHMEM based communication API for pytorch (#1430)
gdengk Apr 4, 2025
c55e425
[PyTorch] Debug weight matrix usages for dgrad GEMM (#1637)
timmoon10 Apr 4, 2025
3b87080
Release v2.2
ptrendx Apr 18, 2025
fc424e5
Changed VERSION to 2.3.0
ptrendx Apr 18, 2025
234fec7
Revert "Allow NVTEShape to own data." (#1703)
timmoon10 Apr 19, 2025
5f3a162
rtx5090 arch fix support (#1659)
sudhakarsingh27 Apr 21, 2025
4730925
[JAX] WAR for CuDNN MXFP8 norm incorrect result (#1700)
jberchtold-nvidia Apr 21, 2025
a3d464c
RoPE enhancements (#1478)
sudhakarsingh27 Apr 22, 2025
5de3e14
Refactor attention.py part 2 (#1704)
KshitijLakhani Apr 28, 2025
9c8ba5c
Release v2.3
ptrendx May 14, 2025
977b4bc
Changed VERSION to 2.4.0
ptrendx May 17, 2025
c034796
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 3 – tests (#1612)
pggPL May 19, 2025
f52a2ea
Fix README render for uploading package to PyPI (#1798)
ksivaman May 19, 2025
6f5af6a
Enhance recipe compatibility (#1724)
negvet May 19, 2025
8c813f2
Use an empty torch tensor to indicate no fp8 information in extra_sta…
pstjohn May 20, 2025
7fe5d68
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 4 – documentatio…
pggPL May 20, 2025
ea63d61
[PyTorch] Add docstring for CP load balancing (#1802)
cyanguwa May 20, 2025
2a235b1
Add missing docs for C API (#1803)
ksivaman May 21, 2025
fc74c4e
Remove `comm_gemm_overlap` doc (#1815)
ksivaman May 22, 2025
864406c
Add docs for missing FP8 recipes. (#1816)
ksivaman May 22, 2025
ac40601
Fix the failing test cases in the CI (#1806)
ptrendx May 23, 2025
4a3bf4f
Fix multi-framework runtime lib loading (#1825)
ksivaman May 28, 2025
3cd6870
Bump cuDNN FE (#1842)
ksivaman Jun 3, 2025
b43596b
Release v2.4
ptrendx Jun 5, 2025
980c434
Changed VERSION to 2.5.0
ptrendx Jun 13, 2025
efe19c3
[JAX] Grouped GEMM & Dense support MXFP8 and handle empty matrices (#…
huanghua1994 Jun 16, 2025
4a16c2d
[Pytorch] Bugfix in te fusion ce implementation (#1879)
BestJuly Jun 16, 2025
b894f69
[JAX] Fixes for L0_jax_distributed_unittest (#1884)
phu0ngng Jun 17, 2025
82bff47
[JAX] TensorUsage + FP8 GEMM with all layouts handling on BW (#1844)
phu0ngng Jun 18, 2025
9192fb6
[PyTorch] Use FP16 tols for distributed tests with TF32 compute (#1831)
timmoon10 Jun 19, 2025
1e03882
Fix cppunittest test.sh for editable installs (#1869)
jberchtold-nvidia Jun 25, 2025
6f6951e
[PyTorch][MoE] Reduce CPU Overhead By Fuse Torch Empty Calls (#1793)
zhongbozhu Jun 26, 2025
7b9d9a5
[PyTorch|common] Optimize unpadding kernel for FP8 (#1866)
xiaoxi-wangfj Jun 26, 2025
c42614d
[PyTorch Debug] Fix the issue with PP (#1894)
pggPL Jun 26, 2025
968eb0d
[PyTorch Debug] Fixed the empty tensor bug in statistics computation …
pggPL Jun 26, 2025
866953e
[JAX] Use keyword args for jit in_shardings and out_shardings (#1898)
jberchtold-nvidia Jun 26, 2025
8382eed
[PyTorch] Skip KV cache for sm89 and cuDNN < 9.12 (#1895)
cyanguwa Jun 26, 2025
f05f12c
Fix MLA CP Bugs (#1896)
yuzhongw-nvidia Jun 28, 2025
bf5b217
Changed VERSION to 2.6.0
KshitijLakhani Jul 20, 2025
c7d0271
[PyTorch] Remove GH pinned deps (#1961)
ksivaman Jul 21, 2025
787acff
[PyTorch] Reset FP8 weight workspace if usages are invalid (#1972)
timmoon10 Jul 21, 2025
9926245
Fix the condition error when checking fp8 attn in `get_attention_back…
yuzhongw-nvidia Jul 21, 2025
4b537aa
[Common] Skip cuDNN 9.10.0/9.10.1 due to bugs (#1937)
cyanguwa Jul 21, 2025
7ba6cd5
[PyTorch] Debug linear layer when saving original input and using deb…
timmoon10 Jul 22, 2025
b97c2bf
[Common] Improved performance of mxfp8 cast kernels (#1628)
Oleg-Goncharov Jul 22, 2025
a593092
Fix the device for cuDNN/cuBLAS handles (#1974)
cyanguwa Jul 23, 2025
928dfa8
[JAX] Fix current scaling test_helper.py and enable test_helper.py in…
jberchtold-nvidia Jul 23, 2025
13f5796
[JAX] Helper to disable TE custom calls + disable GemmPrimitive for n…
phu0ngng Jul 24, 2025
e02e289
Fix runtime lib loading for cuDNN (#1989)
ksivaman Jul 24, 2025
21d7410
Fix cudnn versioning support in PyTorch DPA and Fused attn (#1991)
KshitijLakhani Jul 24, 2025
0f585e8
[JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parall…
denera Jul 24, 2025
5f1142e
[PyTorch] Optimize cudagraph static_grad_outputs reuse (#1992)
buptzyb Jul 25, 2025
af9467c
Release v2.5
ptrendx Jul 28, 2025
c90a720
Fix the use-after-free bug in unfused normalization (#2002)
ptrendx Jul 29, 2025
0289e76
Changed VERSION to 2.7.0
ptrendx Aug 18, 2025
34150d1
[JAX] Error checking for mesh resource and update GemmPrimitive to us…
jberchtold-nvidia Aug 20, 2025
9f065fa
[PyTorch] Avoid garbage collection when capturing a CUDA Graph (#2092)
timmoon10 Aug 20, 2025
3a4136b
Fix incorrect version checks for atomic GEMM (#2095)
timmoon10 Aug 20, 2025
0168c26
[ TE-JAX ] Expose cp_strategy argument to DPA api (#2090)
kocchop Aug 21, 2025
e94041a
[PyTorch] Debug Mcore wgrad fusion with te.ops (#2097)
timmoon10 Aug 23, 2025
c638ac7
[JAX] Add Shardy warning in GEMM custom call (#2101)
phu0ngng Aug 25, 2025
4572dbe
Revert "[Common] PDL for Quantization Kernels" (#2114)
jberchtold-nvidia Aug 26, 2025
d2615d1
Bump cuDNN FE to 1.14.0 (#2072)
vcherepanov-nv Aug 26, 2025
58c3ac8
Revert "[Common] PDL for Blockwise Quantization" (#2115)
jberchtold-nvidia Aug 26, 2025
f8d2c50
[PyTorch] Add test for TRT integration + fix for mxfp8 export (#2083)
pggPL Aug 20, 2025
d7874aa
Add cuBLASMp-backed GEMM-like API to TE common (#1824)
mk-61 Aug 26, 2025
1d1e8ef
Further relax constraints to cuDNN 9.13 for disabling fused attn for …
KshitijLakhani Aug 27, 2025
9cd6d16
Temporarily remove comm_gemm tests (#2133)
vcherepanov-nv Aug 28, 2025
fedd9dd
[PyTorch] Disable determinism for sm100 (#2130)
cyanguwa Aug 28, 2025
bb7bb2d
Release v2.6
ptrendx Sep 9, 2025
a9f2655
Changed VERSION to 2.8.0
ptrendx Sep 19, 2025
dd707eb
[JAX] Remove import jax.extend.ffi (#2193)
phu0ngng Sep 22, 2025
33b4fa7
[PyTorch] Add sink attention support from cuDNN (#2148)
cyanguwa Sep 22, 2025
6b7f51b
[QA] Add pytest xml report for all tests in qa folder that use pytest…
shengfangd Sep 23, 2025
8edb4e5
[JAX] Local-Amax for Current-Scaling (#2183)
mingxu1067 Sep 23, 2025
408f0de
[JAX] Restore Shardy Rule with CompoundFactor (#2167)
phu0ngng Sep 23, 2025
c5c09c6
[JAX] Update JAX version requirement in pyproject.toml (#2197)
phu0ngng Sep 24, 2025
9195516
[PyTorch] Unpin version of onnxscript and onnxruntime (#2202)
pggPL Sep 26, 2025
276f53e
[JAX] Fix XML filename in the L0_jax_uniitest (#2205)
phu0ngng Sep 27, 2025
e5b715e
[JAX] CollectiveGemm (#2166)
phu0ngng Sep 27, 2025
4c82348
[JAX] Add xml export for `test_multiprocessing_encoder` and `test_cge…
phu0ngng Sep 29, 2025
1b2d089
[JAX] Address tolerance check for current scaling dact dbias (#2211)
jberchtold-nvidia Sep 29, 2025
e2f14e4
[Core][PyTorch] NVFP4 recipe (#2177)
ksivaman Sep 29, 2025
9c9b39b
Release v2.7
ptrendx Sep 30, 2025
4afd391
Fix the segfault in the nvfp4 quantization (#2214)
ptrendx Sep 30, 2025
789c6ca
[PyTorch] Add FP8 attention with current scaling (#2012)
cyanguwa Sep 30, 2025
2db51ab
[JAX] Load modules during initialize for Norm and Act primitives (#2219)
jberchtold-nvidia Sep 30, 2025
264ab86
Fix the cuBLAS workspace alignment (#2223)
ptrendx Oct 1, 2025
40c69e7
[PyTorch] Set usages for linear op quantizers before forward (#2222)
timmoon10 Oct 2, 2025
966a5b9
Changed VERSION to 2.9.0
ptrendx Oct 16, 2025
739c656
[JAX] Fix imports in test for deprecated jax.experimental.pjit (#2274)
KshitijLakhani Oct 17, 2025
c2a643d
Wheels for cuda 13 (#2278)
ksivaman Oct 18, 2025
7e72d41
[JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quan…
jberchtold-nvidia Oct 22, 2025
9b75db3
Include TE core headers in final build (#2291)
ksivaman Oct 23, 2025
8b9849a
Overhaul the compilation for the arch-specific features (#2279)
ptrendx Oct 23, 2025
c4c185d
[PyTorch] Add max_logit support for MuonClip (#2195)
cyanguwa Oct 25, 2025
fa71964
[PyTorch] Fix CI failures due to deterministic attention backend (#2288)
ksivaman Oct 20, 2025
fe9b150
[JAX] Fix: Skip determinism tests for bprop for all sm >=100 (#2315)
KshitijLakhani Oct 30, 2025
0acd0e7
[PyTorch] Fix attention backend and tests for `sm120` (#2320)
ksivaman Oct 30, 2025
9cc089a
[PyT] Bump the min version expected to supported FP8 current scaling …
KshitijLakhani Oct 30, 2025
70f5366
[JAX] Ensure JAX reference impl uses an accurate backend in our tests…
jberchtold-nvidia Oct 30, 2025
f7cae5e
Release v2.8
ptrendx Nov 11, 2025
f7df8d8
Release v2.9
ptrendx Nov 11, 2025
0870ff0
Updated VERSION to 2.10.0
ptrendx Nov 15, 2025
bb399cf
[JAX] Quickstart documentation (#2310)
tdophung Nov 15, 2025
cde9328
Add num_splits support for FA3 backend (#2380)
cyanguwa Nov 17, 2025
b770878
[JAX] Add support for sink attention in JAX (#2225)
pggPL Nov 18, 2025
e39db2a
Show quickstart_jax.ipynb along with quickstart.ipynb on html documen…
tdophung Nov 18, 2025
82a1c17
[PyTorch] Fix small errors (#2396)
pggPL Nov 18, 2025
016f2f2
[PyTorch] fix `test_current_device` test (#2398)
cyanguwa Nov 19, 2025
d551ee7
[PyTorch] Disable Flash Attention backend in Userbuffers tests (#2399)
timmoon10 Nov 19, 2025
645716c
[PyTorch] Reduce CPU overheads (#2377)
ksivaman Nov 17, 2025
4027154
Minor improvements to CPU overhead (#2400)
ksivaman Nov 19, 2025
b932e53
[PyTorch] Fix ONNX export errors (#2406)
pggPL Nov 21, 2025
40e9246
[PyTorch] Fix for CPU offloading (#2403)
pggPL Nov 21, 2025
fb4ad6e
[JAX] Set BSHD as default in Unfused DPA, DPA and MHA API calls (#2392)
KshitijLakhani Nov 21, 2025
353a8ee
[JAX] Remove unnecessary SWA calculation in _segment_ids_pos_to_seqle…
KshitijLakhani Nov 21, 2025
e52bdb4
Enable SWA with CP for THD input format (#2220)
sudhakarsingh27 Nov 21, 2025
7ab2c9c
[PyTorch] Only disable Flash Attention in Userbuffers test on SM 8.0 …
timmoon10 Nov 21, 2025
e589e28
[JAX] Allow DP + FSDP and fixed sr_rng_state partitioning (#2418)
phu0ngng Nov 25, 2025
24fd351
[PyTorch] Avoid initializing recipe state in fusible op base class co…
timmoon10 Nov 26, 2025
981e65e
Extend docs with quantizers/quantized_tensors/custom_recipe (#2428)
negvet Nov 26, 2025
6b815f8
Docs fix (#2301)
pggPL Nov 26, 2025
769ed77
ci: Build and attach bdist wheels to release page (#2138)
ko3n1g Nov 21, 2025
686d502
Changed VERSION to 2.11.0
ptrendx Dec 8, 2025
066f199
Fix runtime lib loading logic (#2297)
ksivaman Dec 9, 2025
e2eca8b
Jax primitives for permutation on single GPU (#2473)
tdophung Dec 9, 2025
22d304c
[PyTorch] Add THD support for max_logit/MuonClip (#2480)
cyanguwa Dec 10, 2025
cda10c4
[PyTorch] Change order of args in another permutation triton kernel …
tdophung Dec 9, 2025
741720c
Check calling convention for amax switch. (#2506)
kwyss-nvidia Dec 15, 2025
c188b53
[PyTorch debug] Fix test for debug tools (#2507)
pggPL Dec 15, 2025
883b75e
Release v2.10
ptrendx Jan 15, 2026
9f3f4ab
Release v2.11
ptrendx Jan 15, 2026
d2fd002
Changed VERSION to 2.12.0
ptrendx Jan 20, 2026
6add8c9
[Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell (#2584)
cyanguwa Jan 20, 2026
cfabd83
[Common] Tuned NVFP4 cast kernel (#2412)
Oleg-Goncharov Jan 21, 2026
42e803d
Fixed the year to 2026 (#2611)
Oleg-Goncharov Jan 21, 2026
d759aa6
[pyTorch] CPU performance optimizations (#2439)
ptrendx Jan 21, 2026
bf4af7e
[JAX] Fix cb.CUDAOptions usage for Triton 3.6.0 (#2610)
jberchtold-nvidia Jan 22, 2026
f49f515
Fix bugs in permutation custom partitioning (#2617)
tdophung Jan 23, 2026
d9b7fc5
[Common] Disabled the tuned NVFP4 kernels (#2615)
Oleg-Goncharov Jan 23, 2026
07f7750
[PyT] Update THD sink attention logic for cudnn >=9.18.0 (#2568)
cuichenx Jan 22, 2026
fdc0168
Add support for SWA (left, right) with FusedAttention (#2477)
sudhakarsingh27 Jan 22, 2026
3da26cd
[JAX] Use "nyu-mll/glue" instead of "glue" for encoder datasets to fi…
jberchtold-nvidia Jan 27, 2026
cad802f
[PyTorch] ONNX test fix + export for FP8 attention (#2598)
pggPL Jan 28, 2026
9bb9d22
[common] Add support for cuBLASLt GEMM for GroupedTensor (#2502)
pggPL Jan 28, 2026
5671fd3
Revert "[common] Add support for cuBLASLt GEMM for GroupedTensor (#25…
KshitijLakhani Jan 28, 2026
8dd2698
Changed VERSION to 2.13.0
ptrendx Feb 18, 2026
1afc001
[PyT] Plumbing correct bias dims from TE to cudnn, while adding suppo…
KshitijLakhani Feb 18, 2026
5c1b415
Update cudnn-frontend to v1.18 (#2689)
cyanguwa Feb 20, 2026
c068e80
Fix race condition in RHT amax kernels (#2695)
ksivaman Feb 21, 2026
bd89e94
Release v2.12
ptrendx Feb 24, 2026
4dea802
Add and verify support for `deterministic` fp8 dpa/mha on SM100 (#2621)
sudhakarsingh27 Feb 24, 2026
7deecab
[Common][PyTorch] Fuse scaling and unscaling of bf16 momentums into k…
yaox12 Feb 24, 2026
5bc39b0
remove deprecated qkv/kv_packed apis (#2696)
sudhakarsingh27 Feb 25, 2026
2e4c522
[Common] Remove volatile keyword in fused router kernel utils (#2683)
denera Feb 26, 2026
20c3855
[Common][PyTorch] Enhance the fused router and unify the precision (#…
yaox12 Feb 27, 2026
2877704
[PyTorch] Fix L3 FA tests (#2709)
cyanguwa Feb 28, 2026
d1e20ee
Changed VERSION to 2.14.0
ptrendx Mar 16, 2026
ed424d3
[NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose…
zhongbozhu Mar 16, 2026
ed1f662
[PyTorch] Backwards compatible single param checkpointing in `Grouped…
ksivaman Mar 16, 2026
89120f7
[JAX][Core] Fix Grouped GEMM cuBLAS version and SM arch checks (#2765)
jberchtold-nvidia Mar 17, 2026
da3fe6b
Update cudnnFE to v1.20.0 (#2774)
ksivaman Mar 18, 2026
3b18ad8
[PyT] Install pytest in onnx L1 test as Pyt container no longer packa…
KshitijLakhani Mar 19, 2026
2fc98ff
[Core] Fix MXFP8 grouped quantize for zero-sized groups in update_tma…
jberchtold-nvidia Mar 19, 2026
86ca26f
[Common] Fix linker error for to_string(DType) in distributed tests (…
vcherepanov-nv Mar 16, 2026
a4f90a2
[PyT] [Common] Enable sm120 support for fused attn if cuDNN is 9.18.1…
KshitijLakhani Mar 22, 2026
2edaf84
Enable fused RMSNorm dLN + add through CUDNN (#2778)
CarlosGomes98 Mar 24, 2026
108ecc8
add blackwell support filter for 9.7<=cudnn<9.18.1 (#2775)
sudhakarsingh27 Mar 24, 2026
788a13b
[PyT][Commong] Disable fused attention for sm120 if determinism is re…
KshitijLakhani Mar 25, 2026
71bbefb
[PyTorch][Fused Attn] Add support for cuDNN to return Softmax `Stats`…
sudhakarsingh27 Mar 25, 2026
da8f7d6
Upgrade cuDNN FE to v1.21.0 (#2799)
ksivaman Mar 25, 2026
5aa4823
Release v2.13
ptrendx Mar 31, 2026
e3e33ac
[PyTorch] Fix bug with PR 2677 (#2819)
sudhakarsingh27 Apr 2, 2026
bc62582
[Common] Persistent Grouped MXFP8 quantization kernel (#2738)
Oleg-Goncharov Apr 2, 2026
b8e17cb
[JAX] Fix: Use jitted kernels for generating THD (and BSHD) segment p…
KshitijLakhani Apr 3, 2026
36e0631
GEMM + Swiglu fused Grouped MLP for MXFP8 (#2769)
ksivaman Apr 3, 2026
5018edf
Optimize FSDP2 Pytest Timings (12 -> 2 mins) (#2787)
vthumbe1503 Mar 24, 2026
849e4aa
[PyTorch] [CI] Capture subprocess stderr in distributed tests for bet…
sudhakarsingh27 Apr 3, 2026
62a72d0
[PyT][Test] Add xfailing FSDP2 memory leak detection tests (#2803)
pstjohn Apr 3, 2026
a4a073b
[FSDP2/Megatron-FSDP/DCP] If model parameters are DTensors, optimizer…
cspades Apr 4, 2026
f031cf8
CPU offloading fix: If Data and Transpose is None depend on super Tor…
vthumbe1503 Apr 7, 2026
0ebe377
Changed version to 2.15.0
ptrendx Apr 20, 2026
41fda45
Release v2.14
ptrendx Apr 21, 2026
f9de736
[PyTorch] Fix cuteDSL kernel incorrect numerics when K is 64 aligned …
ksivaman Apr 21, 2026
8e73f54
Add MXFP8 attention (#2719)
cyanguwa Apr 21, 2026
95edbd4
Fix flash attention version check. (#2910)
bbuschkaemper Apr 22, 2026
a506ec5
Make NS coefficients parameter 2D in Python API (#2904)
vcherepanov-nv Apr 22, 2026
96b26c2
Fix the race in the dbias computation in MXFP8 quantization and group…
ptrendx Apr 24, 2026
366798e
Changed version to 2.14.1
ptrendx Apr 24, 2026
27de67e
Release v2.14.1
ptrendx Apr 24, 2026
45fb909
[PyTorch] Fix CP A2A F16 when NVTE_FP8_DPA_BWD=1 (#2917)
cyanguwa Apr 23, 2026
4b74684
[PyTorch] Fix FA4 selection when FA3 is unavailable. (#2909)
bbuschkaemper Apr 23, 2026
150525e
Fix the race in the dbias computation in MXFP8 quantization and group…
ptrendx Apr 24, 2026
c9ab18a
Remove uncessary ctype being passed to GroupedGEMMQuant kernel (#2922)
vthumbe1503 Apr 24, 2026
9250b77
[Common] Fix "0" literal for compilation (#2934)
cyanguwa Apr 28, 2026
94958be
[Common, PyTorch] Add triton mHC kernels & pytorch APIs (#2790)
kainzhong Apr 28, 2026
6075536
[PyTorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=…
vthumbe1503 Apr 29, 2026
1460477
Correctly pad scaling factor inverses to satisfy cuteDSL requirements…
ksivaman Apr 29, 2026
a429716
[PyTorch] Fusible ops preserve usages in quantized weight tensors (#2…
timmoon10 May 1, 2026
df68421
[PyTorch] Add workaround for cuteDSL stride requirement for zero-toke…
ksivaman May 1, 2026
e688ae4
[JAX] Calculate seqlens and offsets in O(T) space instead of O(T*T) s…
KshitijLakhani May 1, 2026
10bdccc
[PyTorch] Cleanup `cudnn-frontend` requirements for fused grouped MLP…
ksivaman May 1, 2026
3378ef1
[JAX] Fix bf16 precision loss in TestGroupedDense reference dbias (#2…
tdophung Apr 30, 2026
42b8400
[PyTorch] Guard/document single parameter feature for grouped linear …
ksivaman May 4, 2026
cabc6b6
Release v2.15
ptrendx May 13, 2026
51d298b
Avoid full mask allocation in unfused padding causal attention
hungryGeek16 May 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 89 files
+1 −1 CMakeLists.txt
+2 −0 README.md
+359 −199 include/cudnn_frontend/graph_interface.h
+14 −0 include/cudnn_frontend/graph_properties.h
+7 −7 include/cudnn_frontend/node/diagonal_band_mask.h
+23 −2 include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
+38 −5 include/cudnn_frontend/node/sdpa_fp8_bwd.h
+7 −7 include/cudnn_frontend/node/softmax.h
+202 −192 include/cudnn_frontend/plans.h
+1 −1 include/cudnn_frontend_version.h
+1 −0 python/cudnn/README.md
+25 −1 python/cudnn/__init__.py
+137 −61 python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py
+207 −173 ...cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/discrete_B_blockscaled_grouped_gemm_dglu_dbias.py
+146 −61 python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py
+241 −128 ...on/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/discrete_B_blockscaled_grouped_gemm_glu_bias.py
+37 −8 python/cudnn/discrete_grouped_gemm/discrete_kernel_utils.py
+3 −0 python/cudnn/experimental/__init__.py
+3 −0 python/cudnn/experimental/ops/__init__.py
+1,079 −0 python/cudnn/experimental/ops/sdpa.py
+189 −412 python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py
+0 −4,427 python/cudnn/grouped_gemm/grouped_gemm_dglu/continugous_blockscaled_grouped_gemm_dglu_quant_dbias_fusion.py
+159 −97 python/cudnn/grouped_gemm/grouped_gemm_dglu/moe_blockscaled_grouped_gemm_dglu_dbias.py
+4 −2 python/cudnn/grouped_gemm/grouped_gemm_dswiglu/grouped_gemm_dswiglu_quant.py
+202 −403 python/cudnn/grouped_gemm/grouped_gemm_glu/api.py
+0 −3,713 python/cudnn/grouped_gemm/grouped_gemm_glu/continugous_blockscaled_grouped_gemm_glu_quant_bias_fusion.py
+218 −90 python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py
+349 −60 python/cudnn/grouped_gemm/grouped_gemm_quant/api.py
+10 −5 python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py
+6 −4 python/cudnn/grouped_gemm/grouped_gemm_swiglu/grouped_gemm_swiglu_quant.py
+36 −7 python/cudnn/grouped_gemm/moe_kernel_helpers.py
+12 −0 python/cudnn/sdpa/__init__.py
+581 −0 python/cudnn/sdpa/api.py
+438 −0 python/cudnn/sdpa/fmha_backward_sm100_2kernel.py
+3,016 −0 python/cudnn/sdpa/fmha_dkdv_d256_sm100.py
+1,968 −0 python/cudnn/sdpa/fmha_dq_d256_sm100.py
+1,143 −0 python/cudnn/sdpa/fmha_utils.py
+784 −0 python/cudnn/sdpa/utils.py
+24 −0 python/cudnn/wrapper.py
+47 −0 python/pygraph/pygraph.cpp
+23 −2 python/pygraph/pygraph.h
+10 −4 python/pygraph/sdpa.cpp
+2 −4 samples/cpp/misc/serialization.cpp
+2 −2 samples/cpp/sdpa/fp16_fwd_with_max_and_sum_exp.cpp
+2 −1 samples/legacy_samples/fp8_flash_mha_sample.cpp
+2 −2 samples/legacy_samples/fp8_flash_mha_sample.h
+1 −1 samples/legacy_samples/test_list.cpp
+4 −4 test/cpp/tensor.cpp
+9 −1 test/python/conftest.py
+152 −0 test/python/fe_api/test_discrete_grouped_gemm_dswiglu.py
+201 −7 test/python/fe_api/test_discrete_grouped_gemm_dswiglu_utils.py
+148 −0 test/python/fe_api/test_discrete_grouped_gemm_swiglu.py
+15 −1 test/python/fe_api/test_discrete_grouped_gemm_swiglu_utils.py
+3 −0 test/python/fe_api/test_fe_api_utils.py
+384 −0 test/python/fe_api/test_grouped_gemm_dglu.py
+19 −8 test/python/fe_api/test_grouped_gemm_dswiglu_utils.py
+389 −0 test/python/fe_api/test_grouped_gemm_glu.py
+391 −0 test/python/fe_api/test_grouped_gemm_quant.py
+45 −22 test/python/fe_api/test_grouped_gemm_quant_utils.py
+28 −12 test/python/fe_api/test_grouped_gemm_swiglu_utils.py
+157 −0 test/python/fe_api/test_sdpa_bwd.py
+352 −0 test/python/fe_api/test_sdpa_bwd_utils.py
+1 −0 test/python/sdpa/fp16.py
+6 −2 test/python/sdpa/fp8.py
+11 −9 test/python/sdpa/mxfp8.py
+4 −1 test/python/sdpa/mxfp8_ref.py
+1 −0 test/python/sdpa/random_config.py
+579 −0 test/python/test_cudnn_sdpa_op.py
+32 −6 test/python/test_mhas_v2.py
+107 −0 test/python/test_sdpa_fp8_serialization.py
+7 −1 tools/cudnn_repro/README.md
+13 −34 tools/cudnn_repro/cudnn_repro/__main__.py
+44 −0 tools/cudnn_repro/cudnn_repro/repro_command.py
+55 −0 tools/cudnn_repro/cudnn_repro/routing.py
+2 −7 tools/cudnn_repro/cudnn_repro/stage1_annotate.py
+67 −15 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_bwd.py
+168 −0 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_bwd.py
+168 −0 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_fwd.py
+2 −7 tools/cudnn_repro/cudnn_repro/stage2_build_repro.py
+4 −32 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_bwd.py
+26 −0 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fp8_bwd.py
+26 −0 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fp8_fwd.py
+4 −31 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fwd.py
+61 −0 tools/cudnn_repro/cudnn_repro/utils.py
+172 −0 tools/cudnn_repro/tests/test_cudnn_repro_bwd.py
+90 −0 tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py
+229 −0 tools/cudnn_repro/tests/test_cudnn_repro_fp8.py
+25 −0 tools/cudnn_repro/tests/test_cudnn_repro_fp8_closed_loop.py
+94 −0 tools/cudnn_repro/tests/test_cudnn_repro_schema.py
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.15.0.dev0
2.15.0
10 changes: 10 additions & 0 deletions docs/api/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,16 @@ Operation fuser

.. autoapiclass:: transformer_engine.pytorch.ops.SwiGLU

.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_sinkhorn

.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_scale

.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_aggregate

.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_expand_combine

.. autoapifunction:: transformer_engine.pytorch.triton.mhc.mhc_fused_projection

Deprecated functions
--------------------

Expand Down
2 changes: 1 addition & 1 deletion qa/L0_pytorch_debug_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml
pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py"

# standard sanity and numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py"

if [ "$RET" -ne 0 ]; then
Expand Down
Empty file modified qa/L0_pytorch_lint/test.sh
100644 → 100755
Empty file.
8 changes: 5 additions & 3 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
Expand All @@ -37,11 +37,11 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py"
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
Expand All @@ -58,6 +58,8 @@ fi
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py"
# Disable autotuning to make unittests faster. In addition, disable TF32 path to fully align with the pytorch reference implementation's precision
NVTE_DISABLE_TRITON_AUTOTUNING=1 NVIDIA_TF32_OVERRIDE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mhc.xml $TE_PATH/tests/pytorch/test_mhc.py || test_fail "test_mhc.py"

if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ add_executable(test_operator
test_multi_unpadding.cu
test_causal_softmax.cu
test_swizzle.cu
test_multi_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu
../test_common.cu)
Expand Down
Loading