-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

-from cellseg_models_pytorch.modules import SelfAttentionBlock
-
+from .base_modules import Identity
+from .misc_modules import LayerScale
 from .mlp import MlpBlock
 from .patch_embeddings import ContiguousEmbed
+from .self_attention_modules import SelfAttentionBlock

 __all__ = ["Transformer2D", "TransformerLayer"]

@@ -23,10 +25,12 @@ def __init__(
         computation_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
+        layer_scales: Tuple[bool, ...] = (False, False),
         activation: str = "star_relu",
         num_groups: int = 32,
-        slice_size: int = 4,
-        mlp_ratio: int = 4,
+        mlp_ratio: int = 2,
+        slice_size: Optional[int] = 4,
+        patch_embed_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         """Create a transformer for 2D-image-like (B, C, H, W) inputs.
@@ -49,7 +53,7 @@ def __init__(
         n_blocks : int, default=2
             Number of Multihead attention blocks in the transformer.
         block_types : Tuple[str, ...], default=("exact", "exact")
-            The name of the SelfAttentionBlocks in the TransformerLayer.
+            The names/types of the SelfAttentionBlocks in the TransformerLayer.
             Length of the tuple has to equal `n_blocks`.
             Allowed names: ("exact", "linformer").
         computation_types : Tuple[str, ...], default=("basic", "basic")
@@ -60,18 +64,23 @@ def __init__(
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
             Include bias terms in the SelfAttention blocks.
+        layer_scales : bool, default=(False, False)
+            Learnable layer weights for the self-attention matrix.
         activation : str, default="star_relu"
             The activation function applied at the end of the transformer layer fc.
             One of ("geglu", "approximate_gelu", "star_relu").
         num_groups : int, default=32
             Number of groups in the first group-norm op before the input is
             projected to be suitable for self-attention.
-        slice_size : int, default=4
+        mlp_ratio : int, default=2
+            Scaling factor for the number of input features to get the number of
+            hidden features in the final `Mlp` layer of the transformer.
+        slice_size : int, optional, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        mlp_ratio : int, default=4
-            Multiplier that defines the out dimension of the final fc projection
-            layer.
+        patch_embed_kwargs: Dict[str, Any], optional
+            Extra key-word arguments for the patch embedding module. See the
+            `ContiguousEmbed` module for more info.
         """
         super().__init__()
         patch_norm = "gn" if in_channels >= 32 else None
@@ -82,6 +91,7 @@ def __init__(
             num_heads=num_heads,
             normalization=patch_norm,
             norm_kwargs={"num_features": in_channels, "num_groups": num_groups},
+            **patch_embed_kwargs if patch_embed_kwargs is not None else {},
         )
         self.proj_dim = self.patch_embed.proj_dim

@@ -95,6 +105,7 @@ def __init__(
             computation_types=computation_types,
             dropouts=dropouts,
             biases=biases,
+            layer_scales=layer_scales,
             activation=activation,
             slice_size=slice_size,
             mlp_ratio=mlp_ratio,
@@ -130,11 +141,22 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
         # 2. transformer
         x = self.transformer(x, context)

-        # 3. Reshape back to image-like shape and project to original input channels.
-        x = x.reshape(B, H, W, self.proj_dim).permute(0, 3, 1, 2)
+        # 3. Reshape back to image-like shape.
+        p_H = self.patch_embed.get_patch_size(H)
+        p_W = self.patch_embed.get_patch_size(W)
+        x = x.reshape(B, p_H, p_W, self.proj_dim).permute(0, 3, 1, 2)
+
+        # Upsample back to the input dims if the patch size is smaller than the
+        # original input size. Assumes that the input is a square matrix.
+        # NOTE: kernel_size, pad, & stride have to be set correctly for this to work.
+        if p_H < H:
+            scale_factor = H // p_H
+            x = F.interpolate(x, scale_factor=scale_factor, mode="bilinear")
+
+        # 4. project to original input channels
         x = self.proj_out(x)

-        # 4. residual
+        # 5. residual
         return x + residual


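For reference, a minimal sketch of what the new reshape-and-upsample step in `Transformer2D.forward` does, assuming a 64x64 input and a patch embedding whose `get_patch_size` returns a 32x32 token grid (the grid size and `proj_dim` below are illustrative values, not defaults taken from the module):

```python
# Minimal sketch of the reshape + bilinear upsample step. The 64x64 input,
# the 32x32 token grid, and proj_dim=128 are assumed values for illustration.
import torch
import torch.nn.functional as F

B, H, W, proj_dim = 2, 64, 64, 128
p_H = p_W = 32                                  # assumed token-grid size
tokens = torch.randn(B, p_H * p_W, proj_dim)    # stand-in for the transformer output

x = tokens.reshape(B, p_H, p_W, proj_dim).permute(0, 3, 1, 2)  # (B, proj_dim, p_H, p_W)
if p_H < H:
    x = F.interpolate(x, scale_factor=H // p_H, mode="bilinear")
print(x.shape)  # torch.Size([2, 128, 64, 64])
```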
@@ -151,8 +173,9 @@ def __init__(
         computation_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
-        slice_size: int = 4,
-        mlp_ratio: int = 4,
+        layer_scales: Tuple[bool, ...] = (False, False),
+        mlp_ratio: int = 2,
+        slice_size: Optional[int] = 4,
         **kwargs,
     ) -> None:
         """Chain transformer blocks to compose a full generic transformer.
@@ -191,12 +214,14 @@ def __init__(
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
             Include bias terms in the SelfAttention blocks.
-        slice_size : int, default=4
+        layer_scales : bool, default=(False, False)
+            Learnable layer weights for the self-attention matrix.
+        mlp_ratio : int, default=2
+            Scaling factor for the number of input features to get the number of
+            hidden features in the final `Mlp` layer of the transformer.
+        slice_size : int, optional, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        mlp_proj : int, default=4
-            Multiplier that defines the out dimension of the final fc projection
-            layer.
         **kwargs:
             Arbitrary key-word arguments.

@@ -218,7 +243,9 @@ def __init__(
                 f"Illegal args: {illegal_args}"
             )

-        self.tr_blocks = nn.ModuleDict()
+        # self.tr_blocks = nn.ModuleDict()
+        self.tr_blocks = nn.ModuleList()
+        self.layer_scales = nn.ModuleList()
         blocks = list(range(n_blocks))
         for i in blocks:
             cross_dim = cross_attention_dim if i == blocks[-1] else None
@@ -235,7 +262,13 @@ def __init__(
                 slice_size=slice_size,
                 **kwargs,
             )
-            self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = att_block
+            self.tr_blocks.append(att_block)
+
+            # add layer scale. (Optional)
+            ls = LayerScale(query_dim) if layer_scales[i] else Identity()
+            self.layer_scales.append(ls)
+
+            # self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = tr_block

         self.mlp = MlpBlock(
             in_channels=query_dim,
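The `LayerScale` gate wired in above is not shown in this diff; it is presumably the CaiT-style per-channel learnable scaling. A minimal stand-in with the same call signature could look like the sketch below (the class name, `init_values`, and the parameter shape are assumptions, not the repository's actual implementation):

```python
import torch
import torch.nn as nn


class LayerScaleSketch(nn.Module):
    """Hypothetical per-channel learnable scaling (CaiT-style)."""

    def __init__(self, dim: int, init_values: float = 1e-5) -> None:
        super().__init__()
        # one learnable weight per feature channel, initialized near zero
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, dim) token sequence; scale every channel independently
        return x * self.gamma
```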
@@ -263,12 +296,14 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
             Self-attended input tensor. Shape (B, H*W, query_dim).
         """
         n_blocks = len(self.tr_blocks)
-        for i, tr_block in enumerate(self.tr_blocks.values(), 1):
+
+        for i, (tr_block, ls) in enumerate(zip(self.tr_blocks, self.layer_scales), 1):
             # apply context only at the last transformer block
             con = None
             if i == n_blocks:
                 con = context

             x = tr_block(x, con)
+            x = ls(x)

         return self.mlp(x) + x
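A hypothetical usage sketch for the updated `Transformer2D` (not part of this commit): the import path and the leading constructor arguments are assumed from context; only `layer_scales`, `mlp_ratio`, `slice_size`, and `patch_embed_kwargs` come from this diff, and the real signature may require further arguments.

```python
import torch

# assumed import path; the class lives under cellseg_models_pytorch.modules
from cellseg_models_pytorch.modules.transformers import Transformer2D

tr = Transformer2D(
    in_channels=64,                       # assumed argument name
    num_heads=8,                          # assumed argument name
    block_types=("exact", "exact"),
    computation_types=("basic", "basic"),
    dropouts=(0.0, 0.0),
    biases=(False, False),
    layer_scales=(True, True),            # new in this commit
    mlp_ratio=2,
    slice_size=4,
    patch_embed_kwargs=None,              # new in this commit
)
x = torch.randn(2, 64, 32, 32)            # (B, C, H, W)
out = tr(x)                               # residual output, same spatial shape as x
```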