-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -13,50 +13,77 @@ def __init__(
         self,
         enc_channels: Tuple[int, ...],
         out_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
-        style_channels: int = None,
-        n_conv_layers: Tuple[int, ...] = (1, 1, 1, 1, 1),
-        n_conv_blocks: Tuple[Tuple[int, ...], ...] = ((2,), (2,), (2,), (2,), (2,)),
-        long_skip: str = "unet",
-        n_transformers: Tuple[int, ...] = None,
-        n_transformer_blocks: Tuple[Tuple[int], ...] = ((1,), (1,), (1,), (1,), (1,)),
+        long_skip: Union[None, str, Tuple[str, ...]] = "unet",
+        n_conv_layers: Union[None, int, Tuple[int, ...]] = 1,
+        n_transformers: Union[None, int, Tuple[int, ...]] = None,
+        n_conv_blocks: Union[int, Tuple[Tuple[int, ...], ...]] = 2,
+        n_transformer_blocks: Union[int, Tuple[Tuple[int], ...]] = 1,
         stage_params: Optional[Tuple[Dict, ...]] = None,
+        style_channels: Optional[int] = None,
         **kwargs,
     ) -> None:
         """Build a generic U-net-like decoder.

+        That is, stacks decoder stages that are composed as follows (see the
+        `Examples` section below):
+
+        DecoderStage:
+            - UpSample(up_method)
+            - LongSkip(long_skip_method)
+            - ConvLayer (optional)
+                - ConvBlock(conv_block_method)
+            - TransformerLayer (optional)
+                - TransformerBlock(transformer_block_method)
+
         Parameters
         ----------
         enc_channels : Tuple[int, ...]
             Number of channels at each encoder layer.
         out_channels : Tuple[int, ...], default=(256, 128, 64, 32, 16)
             Number of channels at each decoder layer output.
-        style_channels : int, default=None
-            Number of style vector channels. If None, style vectors are ignored.
-        n_conv_layers : Tuple[int, ...], default=(1, 1, 1, 1, 1)
-            The number of conv layers inside each of the decoder stages.
-        n_conv_blocks : Tuple[Tuple[int, ...], ...] = ((2,), (2,), (2,), (2,), (2,))
-            The number of blocks inside each conv-layer at each decoder stage.
-        long_skip : str, default="unet"
-            long skip method to be used. One of: "unet", "unetpp", "unet3p",
-            "unet3p-lite", None
-        n_transformers : Tuple[int, ...], optional, default=None
-            The number of transformer layers inside each of the decoder stages.
-        n_transformer_blocks : Tuple[Tuple[int]] = ((1,), (1,), (1,), (1,), (1,))
+        long_skip : Union[None, str, Tuple[str, ...]], default="unet"
+            Long-skip method to be used. The argument can be given as a tuple,
+            where each value sets the long-skip method for the corresponding
+            decoder stage, allowing long-skip methods to be mixed within the
+            decoder.
+            Allowed: "cross-attn", "unet", "unetpp", "unet3p", "unet3p-lite", None
+        n_conv_layers : Union[None, int, Tuple[int, ...]], default=1
+            The number of convolution layers inside each decoder stage. The
+            argument can be given as a tuple, where each value sets the number
+            of conv-layers for the corresponding stage, allowing differently
+            sized stages within the decoder. If None, no conv-layers are
+            included in the decoder.
+        n_transformers : Union[None, int, Tuple[int, ...]], optional
+            The number of transformer layers inside each decoder stage. The
+            argument can be given as a tuple, where each value sets the number
+            of transformer-layers for the corresponding stage, allowing
+            differently sized stages within the decoder. If None, no
+            transformer layers are included in the decoder.
+        n_conv_blocks : Union[int, Tuple[Tuple[int, ...], ...]], default=2
+            The number of blocks inside each conv-layer at each decoder stage.
+            The argument can be given as a nested tuple, where each value sets
+            the number of `ConvBlock`s inside a single `ConvLayer`, allowing
+            differently sized blocks inside each conv-layer in the decoder.
+        n_transformer_blocks : Union[int, Tuple[Tuple[int], ...]], default=1
             The number of transformer blocks inside each transformer-layer at each
-            decoder stage.
+            decoder stage. The argument can be given as a nested tuple, where
+            each value sets the number of `SelfAttention` blocks inside a
+            single `TransformerLayer`, allowing differently sized transformer
+            blocks inside each transformer-layer in the decoder.
         stage_params : Optional[Tuple[Dict, ...]], default=None
             The keyword args for each of the distinct decoder stages. Includes
             the parameters for the long-skip connections and for the decoder's
             convolutional and transformer layers themselves. See the
             `DecoderStage` documentation for more info.
+        style_channels : Optional[int], default=None
+            Number of style vector channels. If None, style vectors are
+            ignored. Also ignored if `n_conv_layers` is None, since style
+            vectors are applied inside `ConvBlock`s.

         Raises
         ------
         ValueError:
             If there is a mismatch between encoder and decoder channel lengths.
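+
+        Examples
+        --------
+        A minimal, illustrative sketch. The enclosing class is referred to as
+        `Decoder` and the channel counts are made up for the example:
+
+        >>> decoder = Decoder(  # doctest: +SKIP
+        ...     enc_channels=(512, 256, 128, 64, 32),
+        ...     out_channels=(256, 128, 64, 32, 16),
+        ...     long_skip=("unet", "unet", "unetpp", "unetpp", "unet"),
+        ...     n_conv_layers=1,  # one conv-layer in every decoder stage
+        ...     n_conv_blocks=2,  # two ConvBlocks inside every conv-layer
+        ...     n_transformers=None,  # no transformer layers
+        ... )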
         """
         super().__init__()
-        self.long_skip = long_skip

         if not len(out_channels) == len(enc_channels):
             raise ValueError(
@@ -70,66 +97,105 @@ def __init__(

         # scaling factor assumed to be 2 for the spatial dims and the input
         # has to be divisible by 32. 256 used here just for convenience.
-        depth = len(out_channels)
-        out_dims = [256 // 2 ** i for i in range(depth)][::-1]
+        self.depth = len(out_channels)
+        out_dims = [256 // 2 ** i for i in range(self.depth)][::-1]
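+        # e.g. (illustrative): with depth == 5, out_dims == [16, 32, 64, 128, 256]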

-        # Build decoder
-        for i in range(depth - 1):
-            # number of conv layers
-            n_clayers = None
-            if n_conv_layers is not None:
-                n_clayers = n_conv_layers[i]
-
-            # number of conv blocks inside each layer
-            n_cblocks = None
-            if n_conv_blocks is not None:
-                n_cblocks = n_conv_blocks[i]
-
-            # number of transformer layers
-            n_tr_layers = None
-            if n_transformers is not None:
-                n_tr_layers = n_transformers[i]
-
-            # number of transformer blocks inside transformer layers
-            n_tr_blocks = None
-            if n_transformer_blocks is not None:
-                n_tr_blocks = n_transformer_blocks[i]
+        # set layer-level tuple-args
+        self.long_skips = self._layer_tuple(long_skip)
+        n_conv_layers = self._layer_tuple(n_conv_layers)
+        n_transformers = self._layer_tuple(n_transformers)

+        # set block-level tuple-args
+        n_conv_blocks = self._block_tuple(n_conv_blocks, n_conv_layers)
+        n_transformer_blocks = self._block_tuple(n_transformer_blocks, n_transformers)
+
+        # Build decoder
+        for i in range(self.depth - 1):
             decoder_block = DecoderStage(
                 stage_ix=i,
                 dec_channels=tuple(out_channels),
                 dec_dims=tuple(out_dims),
                 skip_channels=skip_channels,
+                long_skip=self._tup_arg(self.long_skips, i),
+                n_conv_layers=self._tup_arg(n_conv_layers, i),
+                n_conv_blocks=self._tup_arg(n_conv_blocks, i),
+                n_transformers=self._tup_arg(n_transformers, i),
+                n_transformer_blocks=self._tup_arg(n_transformer_blocks, i),
                 style_channels=style_channels,
-                long_skip=long_skip,
-                n_conv_layers=n_clayers,
-                n_conv_blocks=n_cblocks,
-                n_transformers=n_tr_layers,
-                n_transformer_blocks=n_tr_blocks,
                 **stage_params[i] if stage_params is not None else {"k": None},
             )
             self.add_module(f"decoder_stage{i + 1}", decoder_block)

         self.out_channels = decoder_block.out_channels

+    def _tup_arg(self, tup: Tuple[Any, ...], ix: int) -> Union[None, int, str]:
+        """Return None if the given tuple-arg is None, else the value at `ix`."""
+        ret = None
+        if tup is not None:
+            ret = tup[ix]
+        return ret
+
+    def _layer_tuple(
+        self, arg: Union[None, str, int, Tuple[Any, ...]]
+    ) -> Union[None, Tuple[Any, ...]]:
+        """Return a non-nested tuple or None for layer-related arguments."""
+        ret = None
+        if isinstance(arg, (list, tuple)):
+            ret = tuple(arg)
+        elif isinstance(arg, (str, int)):
+            ret = tuple([arg] * self.depth)
+        elif arg is None:
+            pass
+        else:
+            raise ValueError(
+                f"Given arg: {arg} should be None, str, int or a tuple of ints or strs."
+            )
+
+        return ret
+
+    def _block_tuple(
+        self,
+        arg: Union[int, None, Tuple[Tuple[int, ...], ...]],
+        n_layers: Tuple[int, ...],
+    ) -> Union[None, Tuple[Tuple[int, ...], ...]]:
+        """Return a nested tuple or None for block-related arguments."""
+        ret = None
+        if isinstance(arg, (list, tuple)):
+            if not all([isinstance(a, (tuple, list)) for a in arg]):
+                raise ValueError(f"Given arg: {arg} should be a nested sequence.")
+            ret = tuple(arg)
+        elif isinstance(arg, int):
+            if n_layers is not None:
+                ret = tuple([tuple([arg] * i) for i in n_layers])
+            else:
+                ret = None
+        elif arg is None:
+            pass
+        else:
+            raise ValueError(f"Given arg: {arg} should be None, int or a nested tuple.")
+
+        return ret
+
     def forward_features(
         self, features: Tuple[torch.Tensor], style: torch.Tensor = None
     ) -> List[torch.Tensor]:
         """Forward pass of the decoder. Returns all the decoder stage feats."""
         head = features[0]
         skips = features[1:]
-        extra_skips = [head] if self.long_skip == "unet3p" else []
+        extra_skips = [head] if self.long_skips and self.long_skips[0] == "unet3p" else []
         ret_feats = []

         x = head
-        for decoder_stage in self.values():
+        for i, decoder_stage in enumerate(self.values()):
             x, extra = decoder_stage(
                 x, skips=skips, extra_skips=extra_skips, style=style
             )

-            if self.long_skip == "unetpp":
+            if self.long_skips and self.long_skips[i] == "unetpp":
                 extra_skips = extra
-            elif self.long_skip == "unet3p":
+            elif self.long_skips and self.long_skips[i] == "unet3p":
                 extra_skips.append(x)

             ret_feats.append(x)