Skip to content

Commit 1645458

Browse files
committed
fix: only instantiate CrossAttentionBlock when with_cross_attention=True
[200~Fixes #8845 TransformerBlock previously instantiated norm_cross_attn and cross_attn unconditionally, even when with_cross_attention=False. These unused modules registered dead parameters in model.parameters(), wasting memory. Wrapped both instantiations in `if with_cross_attention:` to match the existing guard in forward(). Added tests to verify the modules and their parameters are absent when disabled, present when enabled, and that the forward pass with a context tensor works correctly.~ Signed-off-by: chhayankjain <chhayank44@gmail.com>
1 parent 586dea1 commit 1645458

2 files changed

Lines changed: 47 additions & 9 deletions

File tree

monai/networks/blocks/transformerblock.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,20 @@ def __init__(
4646
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
4747
qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
4848
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
49+
causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False.
50+
sequence_length (int | None, optional): sequence length required for causal masking. Defaults to None.
51+
with_cross_attention (bool, optional): whether to include cross-attention layers that attend to an
52+
external context tensor. When False, norm_cross_attn and cross_attn are not instantiated.
53+
Defaults to False.
4954
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
5055
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
5156
include_fc: whether to include the final linear layer. Default to True.
5257
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
5358
59+
Raises:
60+
ValueError: if dropout_rate is not in [0, 1].
61+
ValueError: if hidden_size is not divisible by num_heads.
62+
5463
"""
5564

5665
super().__init__()
@@ -78,15 +87,16 @@ def __init__(
7887
self.norm2 = nn.LayerNorm(hidden_size)
7988
self.with_cross_attention = with_cross_attention
8089

81-
self.norm_cross_attn = nn.LayerNorm(hidden_size)
82-
self.cross_attn = CrossAttentionBlock(
83-
hidden_size=hidden_size,
84-
num_heads=num_heads,
85-
dropout_rate=dropout_rate,
86-
qkv_bias=qkv_bias,
87-
causal=False,
88-
use_flash_attention=use_flash_attention,
89-
)
90+
if with_cross_attention:
91+
self.norm_cross_attn = nn.LayerNorm(hidden_size)
92+
self.cross_attn = CrossAttentionBlock(
93+
hidden_size=hidden_size,
94+
num_heads=num_heads,
95+
dropout_rate=dropout_rate,
96+
qkv_bias=qkv_bias,
97+
causal=False,
98+
use_flash_attention=use_flash_attention,
99+
)
90100

91101
def forward(
92102
self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None

tests/networks/blocks/test_transformerblock.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,34 @@ def test_ill_arg(self):
5353
with self.assertRaises(ValueError):
5454
TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)
5555

56+
@skipUnless(has_einops, "Requires einops")
57+
def test_cross_attention_params_not_registered_when_disabled(self):
58+
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
59+
param_names = [name for name, _ in block.named_parameters()]
60+
self.assertFalse(any("cross_attn" in n for n in param_names))
61+
self.assertFalse(any("norm_cross_attn" in n for n in param_names))
62+
self.assertFalse(hasattr(block, "cross_attn"))
63+
self.assertFalse(hasattr(block, "norm_cross_attn"))
64+
65+
@skipUnless(has_einops, "Requires einops")
66+
def test_cross_attention_params_registered_when_enabled(self):
67+
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
68+
self.assertTrue(hasattr(block, "cross_attn"))
69+
self.assertTrue(hasattr(block, "norm_cross_attn"))
70+
param_names = [name for name, _ in block.named_parameters()]
71+
self.assertTrue(any("cross_attn" in n for n in param_names))
72+
self.assertTrue(any("norm_cross_attn" in n for n in param_names))
73+
74+
@skipUnless(has_einops, "Requires einops")
75+
def test_cross_attention_forward_with_context(self):
76+
hidden_size = 128
77+
block = TransformerBlock(hidden_size=hidden_size, mlp_dim=256, num_heads=4, with_cross_attention=True)
78+
x = torch.randn(2, 16, hidden_size)
79+
context = torch.randn(2, 8, hidden_size)
80+
with eval_mode(block):
81+
out = block(x, context=context)
82+
self.assertEqual(out.shape, x.shape)
83+
5684
@skipUnless(has_einops, "Requires einops")
5785
def test_access_attn_matrix(self):
5886
# input format

0 commit comments

Comments
 (0)