@@ -102,13 +102,14 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
 
 class MixedConv2d(nn.Module):
     """ Mixed Grouped Convolution
-
     Based on MDConv and GroupedConv in MixNet impl:
     https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
+
+    NOTE: This does not currently work with torch.jit.script
     """
 
     def __init__(self, in_channels, out_channels, kernel_size=3,
-                 stride=1, padding='', dilation=1, mixed_dilated=False, depthwise=False, **kwargs):
+                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
         super(MixedConv2d, self).__init__()
 
         kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
@@ -118,17 +119,13 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
         self.in_channels = sum(in_splits)
         self.out_channels = sum(out_splits)
         for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
-            d = dilation
-            # FIXME make compat with non-square kernel/dilations/strides
-            if stride == 1 and mixed_dilated:
-                d, k = (k - 1) // 2, 3
             conv_groups = out_ch if depthwise else 1
             # use add_module to keep key space clean
             self.add_module(
                 str(idx),
                 create_conv2d_pad(
                     in_ch, out_ch, k, stride=stride,
-                    padding=padding, dilation=d, groups=conv_groups, **kwargs)
+                    padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
             )
         self.splits = in_splits
 
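
For reference, the kernel-size list is turned into per-branch channel splits before this loop. A minimal sketch of a splitter that matches the in_splits/out_splits usage above; the helper itself is not part of this diff, so its name and exact rounding are assumptions:

def split_channels_sketch(num_chan, num_groups):
    # divide channels as evenly as possible, giving any remainder to the first group
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split

# e.g. MixedConv2d(32, 64, kernel_size=[3, 5, 7]) would build three branches with
# in_splits = [12, 10, 10] and out_splits = [22, 21, 21] under this scheme,
# each branch convolving its own channel slice with its own kernel size.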
@@ -154,12 +151,12 @@ def condconv_initializer(weight):
 
 class CondConv2d(nn.Module):
     """ Conditional Convolution
-
     Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
 
     Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
     https://github.com/pytorch/pytorch/issues/17983
     """
+    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
 
     def __init__(self, in_channels, out_channels, kernel_size=3,
                  stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
@@ -171,13 +168,10 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
         self.stride = _pair(stride)
         padding_val, is_padding_dynamic = get_padding_value(
             padding, kernel_size, stride=stride, dilation=dilation)
-        self.conv_fn = conv2d_same if is_padding_dynamic else F.conv2d
+        self.dynamic_padding = is_padding_dynamic  # if in forward to work with torchscript
         self.padding = _pair(padding_val)
         self.dilation = _pair(dilation)
-        self.transposed = False
-        self.output_padding = _pair(0)
         self.groups = groups
-        self.padding_mode = 'zero'
         self.num_experts = num_experts
 
         self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
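
The new dynamic_padding flag replaces the conv_fn indirection so the padding branch can live inside forward(), which torchscript can trace. For context, a rough sketch of the TF-style 'SAME' padding that conv2d_same is expected to apply when the flag is set; conv2d_same itself is not shown in this diff, so this is an assumption about its behaviour:

import math
import torch.nn.functional as F

def _same_pad(size, k, s, d=1):
    # total padding along one spatial dim so the output size is ceil(size / s)
    return max((math.ceil(size / s) - 1) * s + (k - 1) * d + 1 - size, 0)

def conv2d_same_sketch(x, weight, bias=None, stride=(1, 1), dilation=(1, 1), groups=1):
    ih, iw = x.shape[-2:]
    kh, kw = weight.shape[-2:]
    pad_h = _same_pad(ih, kh, stride[0], dilation[0])
    pad_w = _same_pad(iw, kw, stride[1], dilation[1])
    # pad asymmetrically (extra pixel on the bottom/right when the total is odd)
    x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
    return F.conv2d(x, weight, bias, stride, 0, dilation, groups)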
@@ -186,60 +180,63 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
             weight_num_param *= wd
         self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
 
-        # FIXME I haven't tested bias yet
         if bias:
             self.bias_shape = (self.out_channels,)
-            condconv_bias_shape = (self.num_experts, self.out_channels)
-            self.bias = torch.nn.Parameter(torch.Tensor(condconv_bias_shape))
+            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
         else:
             self.register_parameter('bias', None)
 
         self.reset_parameters()
-        # FIXME once I'm satisfied this works, remove the looping path?
-        self._use_groups = True  # use groups for parallel per-batch-element kernel convolution
 
     def reset_parameters(self):
         init_weight = get_condconv_initializer(
             partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
         init_weight(self.weight)
         if self.bias is not None:
-            # FIXME bias not tested
             fan_in = np.prod(self.weight_shape[1:])
             bound = 1 / math.sqrt(fan_in)
             init_bias = get_condconv_initializer(
                 partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
             init_bias(self.bias)
 
     def forward(self, x, routing_weights):
-        weight = torch.matmul(routing_weights, self.weight)
-        bias = torch.matmul(routing_weights, self.bias) if self.bias is not None else None
         B, C, H, W = x.shape
-        if self._use_groups:
-            new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
-            weight = weight.view(new_weight_shape)
-            # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
-            x = x.view(1, B * C, H, W)
-            out = self.conv_fn(
+        weight = torch.matmul(routing_weights, self.weight)
+        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
+        weight = weight.view(new_weight_shape)
+        bias = None
+        if self.bias is not None:
+            bias = torch.matmul(routing_weights, self.bias)
+            bias = bias.view(B * self.out_channels)
+        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
+        x = x.view(1, B * C, H, W)
+        if self.dynamic_padding:
+            out = conv2d_same(
                 x, weight, bias, stride=self.stride, padding=self.padding,
                 dilation=self.dilation, groups=self.groups * B)
-            out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
         else:
-            x = torch.split(x, 1, 0)
-            weight = torch.split(weight, 1, 0)
-            if self.bias is not None:
-                bias = torch.matmul(routing_weights, self.bias)
-                bias = torch.split(bias, 1, 0)
-            else:
-                bias = [None] * B
-            out = []
-            for xi, wi, bi in zip(x, weight, bias):
-                wi = wi.view(*self.weight_shape)
-                if bi is not None:
-                    bi = bi.view(*self.bias_shape)
-                out.append(self.conv_fn(
-                    xi, wi, bi, stride=self.stride, padding=self.padding,
-                    dilation=self.dilation, groups=self.groups))
-            out = torch.cat(out, 0)
+            out = F.conv2d(
+                x, weight, bias, stride=self.stride, padding=self.padding,
+                dilation=self.dilation, groups=self.groups * B)
+        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
+
+        # Literal port (from TF definition)
+        # x = torch.split(x, 1, 0)
+        # weight = torch.split(weight, 1, 0)
+        # if self.bias is not None:
+        #     bias = torch.matmul(routing_weights, self.bias)
+        #     bias = torch.split(bias, 1, 0)
+        # else:
+        #     bias = [None] * B
+        # out = []
+        # for xi, wi, bi in zip(x, weight, bias):
+        #     wi = wi.view(*self.weight_shape)
+        #     if bi is not None:
+        #         bi = bi.view(*self.bias_shape)
+        #     out.append(self.conv_fn(
+        #         xi, wi, bi, stride=self.stride, padding=self.padding,
+        #         dilation=self.dilation, groups=self.groups))
+        # out = torch.cat(out, 0)
         return out
 
 
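With the looping path gone, forward() mixes the expert kernels per sample (routing_weights @ self.weight), folds the batch into the channel dimension, and runs a single grouped convolution with groups multiplied by B. A hedged usage sketch; the gating that produces routing_weights (mean pool plus a linear layer here) is illustrative only and not defined in this file:

import torch
import torch.nn as nn

num_experts, B = 4, 8
conv = CondConv2d(32, 64, kernel_size=3, num_experts=num_experts)  # class defined above
gate = nn.Linear(32, num_experts)  # hypothetical routing head

x = torch.randn(B, 32, 56, 56)
routing_weights = torch.sigmoid(gate(x.mean(dim=(2, 3))))  # (B, num_experts), one expert mix per sample
out = conv(x, routing_weights)  # (B, 64, 56, 56) with the default padding
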
@@ -250,13 +247,14 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
         assert 'num_experts' not in kwargs  # MixNet + CondConv combo not supported currently
         # We're going to use only lists for defining the MixedConv2d kernel groups,
         # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
-        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
+        m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
     else:
         depthwise = kwargs.pop('depthwise', False)
         groups = out_chs if depthwise else 1
         if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
-            create_fn = CondConv2d
+            m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
         else:
-            create_fn = create_conv2d_pad
-        return create_fn(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+            m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+    return m
+
 
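The single-exit rewrite of select_conv2d keeps the three construction paths but always binds the module to m before returning. Roughly how each path is reached; these calls are illustrative and the argument values are not from this diff:

# kernel_size given as a list -> MixedConv2d (one branch per kernel size)
mixed = select_conv2d(64, 64, kernel_size=[3, 5, 7], depthwise=True)

# num_experts > 0 -> CondConv2d (expects routing_weights at forward time)
cond = select_conv2d(32, 64, kernel_size=3, num_experts=4)

# anything else -> plain padded conv via create_conv2d_pad
plain = select_conv2d(32, 64, kernel_size=3, stride=2)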