
Commit d83fa56

NabJa and KumoLiu authored
Add dimensionality of heads argument to SABlock (#7664)
Fixes #7661.

### Description

The changes add a parameter (_dim_head_) to set the output dimension of each head in the self-attention block (SABlock). Currently the combined output dimension is fixed to _hidden_size_, and increasing the number of heads simply splits _hidden_size_ equally among them.

### Example

The original implementation automatically determines **_equally_distributed_head_dim_** (qkv * num_heads * equally_distributed_head_dim = 3 * hidden_size, i.e. 3 * 8 * 16 = 384 in this example):

```
block = SABlock(hidden_size=128, num_heads=8)
x = torch.zeros(1, 256, 128)
x = block.qkv(x)
print(x.shape)
x = block.input_rearrange(x)
print(x.shape)

> torch.Size([1, 256, 384])
> torch.Size([3, 1, 8, 256, 16])  # <- (qkv, batch, num_heads, sequence_length, equally_distributed_head_dim)
```

The proposed implementation fixes this by exposing the new argument **_dim_head_**:

```
block_new = SABlock(hidden_size=128, num_heads=8, dim_head=32)
x = torch.zeros(1, 256, 128)
x = block_new.qkv(x)
print(x.shape)
x = block_new.input_rearrange(x)
print(x.shape)

> torch.Size([1, 256, 768])
> torch.Size([3, 1, 8, 256, 32])  # <- (qkv, batch, num_heads, sequence_length, dim_head)
```

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: NabJa <nabil.jabareen@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
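Because `out_proj` now maps the concatenated heads (`inner_dim = dim_head * num_heads`) back to `hidden_size`, the block's output width is the same with or without `dim_head`. A minimal end-to-end shape check, sketched under the assumption that `SABlock.forward` takes a single `(batch, sequence_length, hidden_size)` tensor (the import path follows the file changed below):

```python
import torch

from monai.networks.blocks.selfattention import SABlock

x = torch.zeros(1, 256, 128)  # (batch, sequence_length, hidden_size)

# Default behaviour: dim_head is derived as hidden_size // num_heads = 16.
block_default = SABlock(hidden_size=128, num_heads=8)
print(block_default(x).shape)  # expected: torch.Size([1, 256, 128])

# Wider heads: inner_dim grows to 8 * 32 = 256, but out_proj maps it back to hidden_size.
block_wide = SABlock(hidden_size=128, num_heads=8, dim_head=32)
print(block_wide(x).shape)  # expected: torch.Size([1, 256, 128])
```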
1 parent f278e51 commit d83fa56

File tree

2 files changed: +42 −4 lines changed


monai/networks/blocks/selfattention.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -32,6 +32,7 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        dim_head: int | None = None,
     ) -> None:
         """
         Args:
@@ -40,6 +41,7 @@ def __init__(
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.

         """

@@ -52,14 +54,16 @@ def __init__(
             raise ValueError("hidden size should be divisible by num_heads.")

         self.num_heads = num_heads
-        self.out_proj = nn.Linear(hidden_size, hidden_size)
-        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
+        self.dim_head = hidden_size // num_heads if dim_head is None else dim_head
+        self.inner_dim = self.dim_head * num_heads
+
+        self.out_proj = nn.Linear(self.inner_dim, hidden_size)
+        self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias)
         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
-        self.head_dim = hidden_size // num_heads
-        self.scale = self.head_dim**-0.5
+        self.scale = self.dim_head**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()

```
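To see what the new `inner_dim` means on the output side of the block: `out_rearrange` concatenates the per-head results back to `inner_dim`, and `out_proj` projects that down to `hidden_size`. A short shape trace using the same toy sizes as the PR description (a sketch, not part of the diff; `out_rearrange` and `out_proj` are the attributes defined above):

```python
import torch

from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=8, dim_head=32)

# Stand-in for the per-head attention output: (batch, num_heads, seq_len, dim_head).
attn_out = torch.zeros(1, 8, 256, 32)

y = block.out_rearrange(attn_out)  # (1, 256, 256): heads concatenated to inner_dim
y = block.out_proj(y)              # (1, 256, 128): projected back to hidden_size
print(y.shape)
```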

tests/test_selfattention.py

Lines changed: 34 additions & 0 deletions
```diff
@@ -74,6 +74,40 @@ def test_access_attn_matrix(self):
         matrix_acess_blk(torch.randn(input_shape))
         assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])

+    def test_number_of_parameters(self):
+
+        def count_sablock_params(*args, **kwargs):
+            """Count the number of parameters in a SABlock."""
+            sablock = SABlock(*args, **kwargs)
+            return sum([x.numel() for x in sablock.parameters() if x.requires_grad])
+
+        hidden_size = 128
+        num_heads = 8
+        default_dim_head = hidden_size // num_heads
+
+        # Default dim_head is hidden_size // num_heads
+        nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads)
+        nparams_like_default = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head
+        )
+        self.assertEqual(nparams_default, nparams_like_default)
+
+        # Increasing dim_head should increase the number of parameters
+        nparams_custom_large = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2
+        )
+        self.assertGreater(nparams_custom_large, nparams_default)
+
+        # Decreasing dim_head should decrease the number of parameters
+        nparams_custom_small = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2
+        )
+        self.assertGreater(nparams_default, nparams_custom_small)
+
+        # Increasing the number of heads with the default behaviour should not change the number of params.
+        nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2)
+        self.assertEqual(nparams_default, nparams_default_more_heads)
+

 if __name__ == "__main__":
     unittest.main()
```
