
Commit 828a491

fix(unetr, vitautoenc): access the attn mat (#6493)
Fixes #6492.

### Description

Expose a `save_attn` option in `UNETR` and both `qkv_bias` and `save_attn` options in `ViTAutoEnc`, so that the attention matrix computed in the self-attention blocks can be accessed; the constructor docstrings are updated with the argument defaults.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: a-parida12 <abhijeet.parida@tum.de>
1 parent d421efc commit 828a491

File tree: 2 files changed (+33, -23 lines)

monai/networks/nets/unetr.py

Lines changed: 14 additions & 11 deletions

@@ -43,23 +43,25 @@ def __init__(
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
         qkv_bias: bool = False,
+        save_attn: bool = False,
     ) -> None:
         """
         Args:
             in_channels: dimension of input channels.
             out_channels: dimension of output channels.
             img_size: dimension of input image.
-            feature_size: dimension of network feature size.
-            hidden_size: dimension of hidden layer.
-            mlp_dim: dimension of feedforward layer.
-            num_heads: number of attention heads.
-            pos_embed: position embedding layer type.
-            norm_name: feature normalization type and arguments.
-            conv_block: bool argument to determine if convolutional block is used.
-            res_block: bool argument to determine if residual block is used.
-            dropout_rate: faction of the input units to drop.
-            spatial_dims: number of spatial dims.
-            qkv_bias: apply the bias term for the qkv linear layer in self attention block
+            feature_size: dimension of network feature size. Defaults to 16.
+            hidden_size: dimension of hidden layer. Defaults to 768.
+            mlp_dim: dimension of feedforward layer. Defaults to 3072.
+            num_heads: number of attention heads. Defaults to 12.
+            pos_embed: position embedding layer type. Defaults to "conv".
+            norm_name: feature normalization type and arguments. Defaults to "instance".
+            conv_block: if convolutional block is used. Defaults to True.
+            res_block: if residual block is used. Defaults to True.
+            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
+            spatial_dims: number of spatial dims. Defaults to 3.
+            qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
+            save_attn: to make accessible the attention in self attention block. Defaults to False.

         Examples::

@@ -101,6 +103,7 @@ def __init__(
             dropout_rate=dropout_rate,
             spatial_dims=spatial_dims,
             qkv_bias=qkv_bias,
+            save_attn=save_attn,
         )
         self.encoder1 = UnetrBasicBlock(
             spatial_dims=spatial_dims,
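A minimal usage sketch (not part of the commit) showing how the new `save_attn` flag can be used with `UNETR`; it assumes, as in MONAI's `SABlock`, that the attention weights of each transformer block are kept on its `attn.att_mat` attribute after a forward pass:

```python
import torch
from monai.networks.nets import UNETR

# Enable attention saving in the ViT backbone (save_attn is the flag added by this commit).
model = UNETR(
    in_channels=1,
    out_channels=2,
    img_size=(96, 96, 96),
    save_attn=True,
)
model.eval()

with torch.no_grad():
    _ = model(torch.randn(1, 1, 96, 96, 96))

# Assumption: each TransformerBlock's self-attention block stores the last
# attention matrix in `att_mat` when save_attn=True.
attn = model.vit.blocks[0].attn.att_mat
print(attn.shape)  # e.g. (batch, num_heads, num_patches, num_patches)
```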

monai/networks/nets/vitautoenc.py

Lines changed: 19 additions & 12 deletions

@@ -46,21 +46,25 @@ def __init__(
         pos_embed: str = "conv",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
     ) -> None:
         """
         Args:
-            in_channels: dimension of input channels or the number of channels for input
+            in_channels: dimension of input channels or the number of channels for input.
             img_size: dimension of input image.
-            patch_size: dimension of patch size.
-            hidden_size: dimension of hidden layer.
-            out_channels: number of output channels.
-            deconv_chns: number of channels for the deconvolution layers.
-            mlp_dim: dimension of feedforward layer.
-            num_layers: number of transformer blocks.
-            num_heads: number of attention heads.
-            pos_embed: position embedding layer type.
-            dropout_rate: faction of the input units to drop.
-            spatial_dims: number of spatial dimensions.
+            patch_size: dimension of patch size
+            out_channels: number of output channels. Defaults to 1.
+            deconv_chns: number of channels for the deconvolution layers. Defaults to 16.
+            hidden_size: dimension of hidden layer. Defaults to 768.
+            mlp_dim: dimension of feedforward layer. Defaults to 3072.
+            num_layers: number of transformer blocks. Defaults to 12.
+            num_heads: number of attention heads. Defaults to 12.
+            pos_embed: position embedding layer type. Defaults to "conv".
+            dropout_rate: faction of the input units to drop. Defaults to 0.0.
+            spatial_dims: number of spatial dimensions. Defaults to 3.
+            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
+            save_attn: to make accessible the attention in self attention block. Defaults to False. Defaults to False.

         Examples::

@@ -89,7 +93,10 @@ def __init__(
             spatial_dims=self.spatial_dims,
         )
         self.blocks = nn.ModuleList(
-            [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
+            [
+                TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
+                for i in range(num_layers)
+            ]
         )
         self.norm = nn.LayerNorm(hidden_size)
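Likewise, a hedged sketch for `ViTAutoEnc` (again assuming the `att_mat` convention on the self-attention block; nothing here beyond `save_attn` and `model.blocks` is taken from the diff):

```python
import torch
from monai.networks.nets import ViTAutoEnc

model = ViTAutoEnc(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    save_attn=True,  # flag added by this commit
)
model.eval()

with torch.no_grad():
    reconstruction, hidden_states = model(torch.randn(1, 1, 96, 96, 96))

# The transformer blocks live directly on the model (see the diff above), so the
# attention of layer i is expected at model.blocks[i].attn.att_mat.
attn_maps = [blk.attn.att_mat for blk in model.blocks]
print(len(attn_maps), attn_maps[0].shape)
```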
