Describe the bug
Hi,
I noticed that TransformerBlock always instantiates cross-attention layers even when with_cross_attention=False.
Current implementation:
self.with_cross_attention = with_cross_attention
self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
    hidden_size=hidden_size,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    qkv_bias=qkv_bias,
    causal=False,
    use_flash_attention=use_flash_attention,
)
but in forward:
if self.with_cross_attention:
    x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
So when with_cross_attention=False, the cross-attention modules are never used, but their parameters are still registered and show up in model.parameters().
To Reproduce
With with_cross_attention=False, after a backward pass:
for name, p in model.named_parameters():
    if p.grad is None:
        print("[NO GRAD]", name)
[NO GRAD] encoder.vit.vit.blocks.1.norm_cross_attn.weight
[NO GRAD] encoder.vit.vit.blocks.1.norm_cross_attn.bias
[NO GRAD] encoder.vit.vit.blocks.1.cross_attn.out_proj.weight
[NO GRAD] encoder.vit.vit.blocks.1.cross_attn.out_proj.bias
[NO GRAD] encoder.vit.vit.blocks.1.cross_attn.to_q.weight
[NO GRAD] encoder.vit.vit.blocks.1.cross_attn.to_q.bias
[NO GRAD] encoder.vit.vit.blocks.1.cross_attn.to_k.weight
[NO GRAD] encoder.vit.vit.blocks.1.cross_attn.to_k.bias
[NO GRAD] encoder.vit.vit.blocks.1.cross_attn.to_v.weight
[NO GRAD] encoder.vit.vit.blocks.1.cross_attn.to_v.bias
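As a quick sanity check, the dead weight also shows up in the parameter count. A minimal sketch, assuming TransformerBlock is importable and accepts the constructor arguments shown above (mlp_dim and the sizes here are arbitrary):

blk_off = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8, with_cross_attention=False)
blk_on = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8, with_cross_attention=True)

# Both counts currently match: the unused cross-attention parameters
# are registered (and saved in checkpoints) even when the flag is off.
print(sum(p.numel() for p in blk_off.parameters()))
print(sum(p.numel() for p in blk_on.parameters()))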
Expected behavior
No parameters should be registered for the cross-attention module if the transformer block does not use cross-attention.
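One possible fix is to guard module creation on the flag. A minimal sketch, reusing the attribute names and CrossAttentionBlock arguments from the snippet above:

self.with_cross_attention = with_cross_attention
if with_cross_attention:
    # Only register cross-attention parameters when they will actually be used.
    self.norm_cross_attn = nn.LayerNorm(hidden_size)
    self.cross_attn = CrossAttentionBlock(
        hidden_size=hidden_size,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        qkv_bias=qkv_bias,
        causal=False,
        use_flash_attention=use_flash_attention,
    )

Note that this would change the state_dict keys for checkpoints saved with with_cross_attention=False (the dangling norm_cross_attn/cross_attn entries disappear), so loading older checkpoints may require strict=False.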