diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 0b6ed7f5b01..1c2a13cebaa 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -87,7 +87,7 @@ def _norm(self, x): # standard RMSNorm's eps semantics (`rsqrt(mean(x²) + eps)`) and is # large enough to survive fp16 (1e-6 alone underflows in fp16). floor_val = torch.sqrt(torch.tensor(self.dim * self.eps, dtype=x.dtype)) - norm_val = torch.clamp_min( + norm_val = torch.maximum( torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val ) rms_norm_eps0 = ( diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 72ce31438d6..358bc344b7f 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -1406,7 +1406,9 @@ def _forward_mha( # Ungroup, add mask, and regroup attn_grouped = attn_grouped.view(1, self.n_heads, Tq, Tk) attn_grouped = attn_grouped + mask - attn_grouped = F.softmax(attn_grouped, dim=-1) + attn_max = attn_grouped.amax(dim=-1, keepdim=True) + attn_grouped = (attn_grouped - attn_max).exp() + attn_grouped = attn_grouped / attn_grouped.sum(dim=-1, keepdim=True) attn_grouped = attn_grouped.view(n_kv, n_rep, Tq, Tk) # Group v