@@ -77,6 +77,93 @@ def forward(self, x1, x2):
7777 x = torch .cat ([x2 , x1 ], dim = 1 )
7878 return self .conv (x )
7979
80+ class Encoder (nn .Module ):
81+ """
82+ Encoder of 3D UNet.
83+
84+ Parameters are given in the `params` dictionary, and should include the
85+ following fields:
86+
87+ :param in_chns: (int) Input channel number.
88+ :param feature_chns: (list) Feature channel for each resolution level.
89+ The length should be 4 or 5, such as [16, 32, 64, 128, 256].
90+ :param dropout: (list) The dropout ratio for each resolution level.
91+ The length should be the same as that of `feature_chns`.
92+ """
93+ def __init__ (self , params ):
94+ super (Encoder , self ).__init__ ()
95+ self .params = params
96+ self .in_chns = self .params ['in_chns' ]
97+ self .ft_chns = self .params ['feature_chns' ]
98+ self .dropout = self .params ['dropout' ]
99+ assert (len (self .ft_chns ) == 5 or len (self .ft_chns ) == 4 )
100+
101+ self .in_conv = ConvBlock (self .in_chns , self .ft_chns [0 ], self .dropout [0 ])
102+ self .down1 = DownBlock (self .ft_chns [0 ], self .ft_chns [1 ], self .dropout [1 ])
103+ self .down2 = DownBlock (self .ft_chns [1 ], self .ft_chns [2 ], self .dropout [2 ])
104+ self .down3 = DownBlock (self .ft_chns [2 ], self .ft_chns [3 ], self .dropout [3 ])
105+ if (len (self .ft_chns ) == 5 ):
106+ self .down4 = DownBlock (self .ft_chns [3 ], self .ft_chns [4 ], self .dropout [4 ])
107+
108+ def forward (self , x ):
109+ x0 = self .in_conv (x )
110+ x1 = self .down1 (x0 )
111+ x2 = self .down2 (x1 )
112+ x3 = self .down3 (x2 )
113+ output = [x0 , x1 , x2 , x3 ]
114+ if (len (self .ft_chns ) == 5 ):
115+ x4 = self .down4 (x3 )
116+ output .append (x4 )
117+ return output
118+
119+ class Decoder (nn .Module ):
120+ """
121+ Decoder of 3D UNet.
122+
123+ Parameters are given in the `params` dictionary, and should include the
124+ following fields:
125+
126+ :param in_chns: (int) Input channel number.
127+ :param feature_chns: (list) Feature channel for each resolution level.
128+ The length should be 4 or 5, such as [16, 32, 64, 128, 256].
129+ :param dropout: (list) The dropout ratio for each resolution level.
130+ The length should be the same as that of `feature_chns`.
131+ :param class_num: (int) The class number for segmentation task.
132+ :param trilinear: (bool) Using bilinear for up-sampling or not.
133+ If False, deconvolution will be used for up-sampling.
134+ """
135+ def __init__ (self , params ):
136+ super (Decoder , self ).__init__ ()
137+ self .params = params
138+ self .in_chns = self .params ['in_chns' ]
139+ self .ft_chns = self .params ['feature_chns' ]
140+ self .dropout = self .params ['dropout' ]
141+ self .n_class = self .params ['class_num' ]
142+ self .trilinear = self .params ['trilinear' ]
143+
144+ assert (len (self .ft_chns ) == 5 or len (self .ft_chns ) == 4 )
145+
146+ if (len (self .ft_chns ) == 5 ):
147+ self .up1 = UpBlock (self .ft_chns [4 ], self .ft_chns [3 ], self .ft_chns [3 ], self .dropout [3 ], self .bilinear )
148+ self .up2 = UpBlock (self .ft_chns [3 ], self .ft_chns [2 ], self .ft_chns [2 ], self .dropout [2 ], self .bilinear )
149+ self .up3 = UpBlock (self .ft_chns [2 ], self .ft_chns [1 ], self .ft_chns [1 ], self .dropout [1 ], self .bilinear )
150+ self .up4 = UpBlock (self .ft_chns [1 ], self .ft_chns [0 ], self .ft_chns [0 ], self .dropout [0 ], self .bilinear )
151+ self .out_conv = nn .Conv3d (self .ft_chns [0 ], self .n_class , kernel_size = 1 )
152+
153+ def forward (self , x ):
154+ if (len (self .ft_chns ) == 5 ):
155+ assert (len (x ) == 5 )
156+ x0 , x1 , x2 , x3 , x4 = x
157+ x_d3 = self .up1 (x4 , x3 )
158+ else :
159+ assert (len (x ) == 4 )
160+ x0 , x1 , x2 , x3 = x
161+ x_d3 = x3
162+ x_d2 = self .up2 (x_d3 , x2 )
163+ x_d1 = self .up3 (x_d2 , x1 )
164+ x_d0 = self .up4 (x_d1 , x0 )
165+ output = self .out_conv (x_d0 )
166+ return output
80167
81168class UNet3D (nn .Module ):
82169 """
0 commit comments