
pass params_dtype to qk_norm creation#2718

Open
pstjohn wants to merge 1 commit into NVIDIA:main from pstjohn:pstjohn/qk-norm-dtype

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Feb 28, 2026

Previously layers would fail with

            assert (
>               query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
E           AssertionError: Queries, keys and values must have the same data type!

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py:1063: AssertionError

if you created a layer with dtype != float32. This change ensures that the dtype of the layernorm (qk-norm) modules matches that of the base attention layer.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
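The failure mode and the fix can be illustrated with a plain PyTorch sketch (illustrative only, not the actual transformer_engine code; `check_qkv_dtypes` mirrors the assertion quoted above):

```python
import torch

def check_qkv_dtypes(q, k, v):
    # Mirrors the assertion in dot_product_attention.py quoted above.
    assert (
        q.dtype == k.dtype and q.dtype == v.dtype
    ), "Queries, keys and values must have the same data type!"

params_dtype = torch.bfloat16  # layer created with dtype != float32
hidden = 16
x = torch.randn(2, 4, hidden, dtype=params_dtype)

# The fix: create the qk-norm modules with the same params_dtype as the
# rest of the attention layer, so normalized queries/keys keep that dtype
# and match the (un-normalized) values.
q_norm = torch.nn.LayerNorm(hidden, dtype=params_dtype)
k_norm = torch.nn.LayerNorm(hidden, dtype=params_dtype)

q, k, v = q_norm(x), k_norm(x), x
check_qkv_dtypes(q, k, v)  # all bfloat16, assertion passes
```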
@greptile-apps
Contributor

greptile-apps bot commented Feb 28, 2026

Greptile Summary

This PR fixes a dtype mismatch bug that occurred when creating MultiheadAttention layers with params_dtype != torch.float32. The issue caused an assertion error in dot product attention because query/key normalization layers were not being initialized with the correct dtype.

Key changes:

  • Threads params_dtype parameter through to _create_qk_norm_modules method
  • Passes params_dtype to RMSNorm and LayerNorm constructors (L2Normalization correctly skipped as it's parameter-free)
  • Updates test to parametrize over params_dtype with torch.float32 and torch.bfloat16
  • Ensures all test tensors (hidden_states, encoder_output, rotary_pos_emb) use the correct dtype
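The test changes above might be sketched as follows (the helper name and tensor shapes are hypothetical, not the actual test_qk_norm.py code):

```python
import torch

# Hypothetical stand-in for the test's input construction: every tensor
# (hidden_states, encoder_output, rotary embeddings, etc.) must be built
# with the parametrized dtype, or the QKV dtype assertion fires.
def make_inputs(params_dtype, seq=4, batch=2, hidden=16):
    hidden_states = torch.randn(seq, batch, hidden, dtype=params_dtype)
    encoder_output = torch.randn(seq, batch, hidden, dtype=params_dtype)
    return hidden_states, encoder_output

# pytest.mark.parametrize over params_dtype would drive this in the suite.
for params_dtype in (torch.float32, torch.bfloat16):
    hidden_states, encoder_output = make_inputs(params_dtype)
    assert hidden_states.dtype == encoder_output.dtype == params_dtype
```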

The fix is straightforward and directly addresses the root cause. Tests now verify the fix works for both float32 and bfloat16 dtypes.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The change is minimal, well-targeted, and directly fixes the reported bug. The implementation correctly threads the dtype parameter to normalization layers, and the updated tests verify the fix works for both float32 and bfloat16.
  • No files require special attention

Important Files Changed

Filename: Overview
transformer_engine/pytorch/attention/multi_head_attention.py: Threads the params_dtype parameter through to the _create_qk_norm_modules method and passes it to the RMSNorm/LayerNorm constructors, fixing the dtype mismatch
tests/pytorch/test_qk_norm.py: Adds params_dtype parametrization to test both float32 and bfloat16, ensuring all test tensors use the correct dtype

Last reviewed commit: 8061e42

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments


Member

@ksivaman ksivaman left a comment

LGTM
