From 4260f2a6e09477879136da413d1e8f138121e816 Mon Sep 17 00:00:00 2001 From: Tarek Elgamal Date: Thu, 28 May 2026 17:16:48 -0700 Subject: [PATCH] Bypass the softmax pytorch kernel that upscales to fp32 Summary: Bypasses the PyTorch softmax kernel in `static_attention` that upscales activations to fp32, keeping the softmax computation in fp16. Also updates `norm.py` to handle the fp16 softmax output. Differential Revision: D106729898 --- examples/models/llama/norm.py | 2 +- examples/models/llama/static_attention.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 0b6ed7f5b01..1c2a13cebaa 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -87,7 +87,7 @@ def _norm(self, x): # standard RMSNorm's eps semantics (`rsqrt(mean(x²) + eps)`) and is # large enough to survive fp16 (1e-6 alone underflows in fp16). floor_val = torch.sqrt(torch.tensor(self.dim * self.eps, dtype=x.dtype)) - norm_val = torch.clamp_min( + norm_val = torch.maximum( torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val ) rms_norm_eps0 = ( diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 72ce31438d6..358bc344b7f 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -1406,7 +1406,9 @@ def _forward_mha( # Ungroup, add mask, and regroup attn_grouped = attn_grouped.view(1, self.n_heads, Tq, Tk) attn_grouped = attn_grouped + mask - attn_grouped = F.softmax(attn_grouped, dim=-1) + attn_max = attn_grouped.amax(dim=-1, keepdim=True) + attn_grouped = (attn_grouped - attn_max).exp() + attn_grouped = attn_grouped / attn_grouped.sum(dim=-1, keepdim=True) attn_grouped = attn_grouped.view(n_kv, n_rep, Tq, Tk) # Group v