From ead680c4d6336e559b14f5913e2794271bdca181 Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla Date: Wed, 20 May 2026 22:51:29 -0700 Subject: [PATCH 1/3] tests/attention: shrink fp8_vs_f16 configs from B=2 to B=1 The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes (B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the full (B, H, S, S) attention matrix in bf16, and keeps a handful of them live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64 the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB cards (H100) and pushes the suite into OOM territory on Blackwell (~91 GB measured with the cuDNN caching allocator residue). Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured on B200 (SM_100, cuDNN 9.23, TE main): per-test peak `torch.cuda.max_memory_allocated`: before: 70.0 GiB (fp8_14) after : 36.1 GiB (fp8_14) -48% per-test peak `nvidia-smi memory.used`: before: 96.8 GiB after : 51.3 GiB -47% test outcome (B200, develop FE, this TE): identical 618F / 2196P / 891S, wall time within ~3% The shrunk configs still exercise every distinct shape/mask/SWA/GQA combination that the originals did -- only B is smaller. The suite now fits comfortably on 80 GB cards. fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small (~few GiB) and the larger batch is useful coverage for padding_causal. Signed-off-by: Vedaanta Agarwalla --- tests/pytorch/attention/test_attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 5c46949f67..4be267d257 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1905,7 +1905,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig( - 2, + 1, 4096, 128, 192, @@ -1920,22 +1920,22 @@ def get_model(dtype, config): attn_mask_type="causal", ), "fp8_11": ModelConfig( - 2, + 1, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal_bottom_right", ), - "fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), - "fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), + "fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_13": ModelConfig(1, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), + "fp8_14": ModelConfig(1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), "fp8_16": ModelConfig( - 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + 1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), "fp8_17": ModelConfig( - 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + 1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" ), "fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), "fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), From 6b4720a4f2e9ce06d964efa03aa1cae7f0a03a6c Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com> Date: Thu, 21 May 2026 15:27:12 -0700 Subject: [PATCH 2/3] address changes recommended by Kshitij Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4be267d257..ecc24b01e4 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1905,37 +1905,37 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig( - 1, - 4096, + 2, + 2048, 128, 192, head_dim_v=128, ), "fp8_10": ModelConfig( - 1, - 4096, + 2, + 2048, 128, 192, head_dim_v=128, attn_mask_type="causal", ), "fp8_11": ModelConfig( - 1, - 4096, + 2, + 2048, 128, 192, head_dim_v=128, attn_mask_type="causal_bottom_right", ), "fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_13": ModelConfig(1, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), - "fp8_14": ModelConfig(1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), + "fp8_14": ModelConfig(2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), "fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), "fp8_16": ModelConfig( 1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), "fp8_17": ModelConfig( - 1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + 2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" ), "fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), "fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), From 56b183702bf322ba94c1f6b893d5e503d73d5383 Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla Date: Thu, 21 May 2026 15:44:52 -0700 Subject: [PATCH 3/3] tests/attention: black format fp8_13 ModelConfig Line was 105 chars; black requires <=100 with the project's preview+ string_processing settings. Signed-off-by: Vedaanta Agarwalla --- tests/pytorch/attention/test_attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index ecc24b01e4..f980421798 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1928,7 +1928,9 @@ def get_model(dtype, config): attn_mask_type="causal_bottom_right", ), "fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), + "fp8_13": ModelConfig( + 2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), "fp8_14": ModelConfig(2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), "fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), "fp8_16": ModelConfig(