Commit 57d4609

fix(UNETR): access Bias term in SAB block (#5149)
Fixes #5148.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: a-parida12 <abhijeet.parida@tum.de>
Signed-off-by: monai-bot <monai.miccai2019@gmail.com>
Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
1 parent a1e1af5 commit 57d4609

4 files changed: +14, -5 lines


monai/networks/blocks/selfattention.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -23,12 +23,13 @@ class SABlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
     """
 
-    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) -> None:
+    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
         """
         Args:
             hidden_size: dimension of hidden layer.
             num_heads: number of attention heads.
             dropout_rate: faction of the input units to drop.
+            qkv_bias: bias term for the qkv linear layer.
 
         """
 
@@ -42,7 +43,7 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0)
 
         self.num_heads = num_heads
         self.out_proj = nn.Linear(hidden_size, hidden_size)
-        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
+        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
```
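A minimal sketch of what the change enables (the tensor sizes and dropout value below are illustrative assumptions, not part of the diff):

```python
import torch
from monai.networks.blocks.selfattention import SABlock

# Before this commit the qkv projection was hard-coded to bias=False.
block = SABlock(hidden_size=768, num_heads=12, dropout_rate=0.1, qkv_bias=True)
print(block.qkv.bias is not None)  # expected: True

x = torch.randn(2, 196, 768)  # (batch, tokens, hidden_size) -- illustrative shapes
print(block(x).shape)         # expected: torch.Size([2, 196, 768])
```

The default `qkv_bias=False` keeps the previous behaviour for existing code.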

monai/networks/blocks/transformerblock.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -21,13 +21,16 @@ class TransformerBlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
     """
 
-    def __init__(self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0) -> None:
+    def __init__(
+        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
+    ) -> None:
         """
         Args:
             hidden_size: dimension of hidden layer.
             mlp_dim: dimension of feedforward layer.
             num_heads: number of attention heads.
             dropout_rate: faction of the input units to drop.
+            qkv_bias: apply bias term for the qkv linear layer
 
         """
 
@@ -41,7 +44,7 @@ def __init__(self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate:
 
         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
         self.norm1 = nn.LayerNorm(hidden_size)
-        self.attn = SABlock(hidden_size, num_heads, dropout_rate)
+        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
         self.norm2 = nn.LayerNorm(hidden_size)
 
     def forward(self, x):
```
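A hedged sketch of the pass-through at the transformer-block level; `hidden_size`, `mlp_dim`, and the input shape are assumed example values:

```python
import torch
from monai.networks.blocks.transformerblock import TransformerBlock

blk = TransformerBlock(hidden_size=768, mlp_dim=3072, num_heads=12, dropout_rate=0.0, qkv_bias=True)
# The flag is forwarded to the SABlock stored in `attn` (see the diff above).
print(blk.attn.qkv.bias is not None)  # expected: True

x = torch.randn(2, 196, 768)  # illustrative (batch, tokens, hidden_size)
print(blk(x).shape)           # expected: torch.Size([2, 196, 768])
```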

monai/networks/nets/unetr.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -40,6 +40,7 @@ def __init__(
         res_block: bool = True,
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
+        qkv_bias: bool = False,
     ) -> None:
         """
         Args:
@@ -56,6 +57,7 @@ def __init__(
             res_block: bool argument to determine if residual block is used.
             dropout_rate: faction of the input units to drop.
             spatial_dims: number of spatial dims.
+            qkv_bias: apply the bias term for the qkv linear layer in self attention block
 
         Examples::
 
@@ -96,6 +98,7 @@ def __init__(
             classification=self.classification,
             dropout_rate=dropout_rate,
             spatial_dims=spatial_dims,
+            qkv_bias=qkv_bias,
         )
         self.encoder1 = UnetrBasicBlock(
             spatial_dims=spatial_dims,
```
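A hedged usage sketch at the network level; only `qkv_bias` comes from this diff, while the other constructor arguments (`in_channels`, `out_channels`, `img_size`) and their values are assumed here for illustration:

```python
import torch
from monai.networks.nets import UNETR

# Assumed example values; only qkv_bias is introduced by this commit.
net = UNETR(
    in_channels=1,
    out_channels=2,
    img_size=(96, 96, 96),
    qkv_bias=True,  # defaults to False, so existing models keep the old behaviour
)
with torch.no_grad():
    out = net(torch.randn(1, 1, 96, 96, 96))
print(out.shape)  # expected: torch.Size([1, 2, 96, 96, 96])
```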

monai/networks/nets/vit.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -44,6 +44,7 @@ def __init__(
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
         post_activation="Tanh",
+        qkv_bias: bool = False,
     ) -> None:
         """
         Args:
@@ -61,6 +62,7 @@ def __init__(
             spatial_dims: number of spatial dimensions.
             post_activation: add a final acivation function to the classification head when `classification` is True.
                 Default to "Tanh" for `nn.Tanh()`. Set to other values to remove this function.
+            qkv_bias: apply bias to the qkv linear layer in self attention block
 
         Examples::
 
@@ -95,7 +97,7 @@ def __init__(
             spatial_dims=spatial_dims,
         )
         self.blocks = nn.ModuleList(
-            [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
+            [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias) for i in range(num_layers)]
         )
         self.norm = nn.LayerNorm(hidden_size)
         if self.classification:
```
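A hedged sketch of the ViT entry point; `in_channels`, `img_size`, and `patch_size` are assumed example values, while `qkv_bias` and the `blocks` / `attn` / `qkv` attribute names come from the diffs above:

```python
from monai.networks.nets import ViT

# Assumed example values; only qkv_bias is introduced by this commit.
vit = ViT(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    qkv_bias=True,
)
# Each TransformerBlock in the nn.ModuleList now passes the flag to its SABlock,
# so every qkv projection should expose a bias parameter.
print(all(blk.attn.qkv.bias is not None for blk in vit.blocks))  # expected: True
```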
