Commit d83fa56
Add dimensionality of heads argument to SABlock (#7664)
Fixes #7661.
### Description
The changes made add a parameter (_dim_head_) to set the output
paramters of all the heads in the Self-attention Block (SABlock).
Currently the output dimension is set to be _hidden_size_ and when
increasing the number of heads this is equally distributed among all
heads.
### Example
The original implementation automatically determines
**_equally_distributed_head_dim_**:
(qkv * num_heds * equally_distributed_head_dim = 3*hidden_size
in this example -> 3 * 8 * 16 = 384)
```
block = SABlock(hidden_size=128, num_heads=8)
x = torch.zeros(1, 256, 128)
x = block.qkv(x)
print(x.shape)
x = block.input_rearrange(x)
print(x.shape)
> torch.Size([1, 256, 384])
> torch.Size([3, 1, 8, 256, 16]) # <- This corresponds to (qkv batch num_heads sequence_length equally_distributed_head_dim)
```
The propesed implementation fixes this by setting the new argument
**_dim_head_:**
```
block_new = SABlock(hidden_size=128, num_heads=8, dim_head=32)
x = torch.zeros(1, 256, 128)
x = block_new.qkv(x)
print(x.shape)
x = block_new.input_rearrange(x)
print(x.shape)
> torch.Size([1, 256, 384])
> torch.Size([3, 1, 8, 256, 32]) # <- This corresponds to (qkv batch num_heads sequence_length dim_head)
```
### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: NabJa <nabil.jabareen@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>1 parent f278e51 commit d83fa56
File tree
2 files changed
+42
-4
lines changed- monai/networks/blocks
- tests
2 files changed
+42
-4
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
| 35 | + | |
35 | 36 | | |
36 | 37 | | |
37 | 38 | | |
| |||
40 | 41 | | |
41 | 42 | | |
42 | 43 | | |
| 44 | + | |
43 | 45 | | |
44 | 46 | | |
45 | 47 | | |
| |||
52 | 54 | | |
53 | 55 | | |
54 | 56 | | |
55 | | - | |
56 | | - | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
57 | 62 | | |
58 | 63 | | |
59 | 64 | | |
60 | 65 | | |
61 | | - | |
62 | | - | |
| 66 | + | |
63 | 67 | | |
64 | 68 | | |
65 | 69 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
74 | 74 | | |
75 | 75 | | |
76 | 76 | | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
77 | 111 | | |
78 | 112 | | |
79 | 113 | | |
0 commit comments