@@ -166,31 +166,41 @@ end
166166
167167# First, we will define mappings from the generic API names to our accelerated backend
168168# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
169- # im2col + GEMM. Do so in a loop, here:
169+ # im2col + GEMM.
170+ # But we always support a fallback, non-accelerated path, where we use the direct, but
171+ # slow, implementations. These should not typically be used, hence the `@warn`,
170172
171173# These are the GEMM types we will accelerate with `im2col`
172174const G = Union{[x[2 ] for x in gemm_datatype_mappings]. .. }
173175
174- for (front_name, backend) in (
175- # This maps from public, front-facing name, to internal backend name
176- :conv => :im2col ,
177- )
178-
176+ for (front_name, backend, signature) in (
177+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
178+ # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
179+ (:conv , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
180+ (:conv , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
181+ )
179182 # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
180183 @eval begin
181- # im2col-accelerated function forwarding definition
184+
182185 function $ (Symbol (" $(front_name) !" ))(
183- out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
184- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: $G , C <: ConvDims }
186+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
187+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
188+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
189+ cdims:: $ (signature[4 ]);
190+ kwargs... ) where {$ (signature[5 ]. .. )}
191+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
192+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
193+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
194+ end
185195
186196 x_cs = Iterators. partition (1 : size (in1, 4 ),
187- channels_in (cdims) ÷ groupcount (cdims))
197+ channels_in (cdims) ÷ groupcount (cdims))
188198 w_cs = Iterators. partition (1 : size (in2, 5 ),
189- channels_out (cdims) ÷ groupcount (cdims))
199+ channels_out (cdims) ÷ groupcount (cdims))
190200 cdims2 = basetype (C)(cdims,
191- G = 1 ,
192- C_in = channels_in (cdims) ÷ groupcount (cdims),
193- C_out = channels_out (cdims) ÷ groupcount (cdims))
201+ G = 1 ,
202+ C_in = channels_in (cdims) ÷ groupcount (cdims),
203+ C_out = channels_out (cdims) ÷ groupcount (cdims))
194204
195205 Threads. @sync for (xc, wc) in zip (x_cs, w_cs)
196206 x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
@@ -205,87 +215,119 @@ for (front_name, backend) in (
205215end
206216
207217# im2col-accelerated function forwarding definition
208- function ∇conv_data! (out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
209- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: G , C <: ConvDims }
210-
211- dx_cs = Iterators. partition (1 : size (out, 4 ),
212- channels_in (cdims) ÷ groupcount (cdims))
213- w_cs = Iterators. partition (1 : size (in2, 5 ),
214- channels_out (cdims) ÷ groupcount (cdims))
215- dy_cs = Iterators. partition (1 : size (in1, 4 ),
216- channels_out (cdims) ÷ groupcount (cdims))
217- cdims2 = basetype (C)(cdims,
218- G = 1 ,
219- C_in = channels_in (cdims) ÷ groupcount (cdims),
220- C_out = channels_out (cdims) ÷ groupcount (cdims))
221-
222- Threads. @sync for (xc, yc, wc) in zip (dx_cs, dy_cs, w_cs)
223- dxv = @view out[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
224- dyv = @view in1[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
225- wv = @view in2[ntuple (i -> i == 5 ? wc : Colon (), 5 )... ]
226- Threads. @spawn ∇conv_data_im2col! (dxv, dyv, wv, cdims2; kwargs... )
227- end
218+ for (front_name, backend, signature) in (
219+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
220+ # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
221+ (:∇conv_data , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
222+ (:∇conv_data , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
223+ )
224+ # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
225+ @eval begin
226+ function $ (Symbol (" $(front_name) !" ))(
227+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
228+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
229+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
230+ cdims:: $ (signature[4 ]);
231+ kwargs... ) where {$ (signature[5 ]. .. )}
232+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
233+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
234+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
235+ end
228236
229- return out
230- end
231237
232- function ∇conv_filter! (out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
233- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: G , C <: ConvDims }
234- dw_cs = Iterators. partition (1 : size (out, 5 ),
235- channels_out (cdims) ÷ groupcount (cdims))
236- dy_cs = Iterators. partition (1 : size (in2, 4 ),
237- channels_out (cdims) ÷ groupcount (cdims))
238- x_cs = Iterators. partition (1 : size (in1, 4 ),
239- channels_in (cdims) ÷ groupcount (cdims))
240- cdims2 = basetype (C)(cdims,
241- G = 1 ,
242- C_in = channels_in (cdims) ÷ groupcount (cdims),
243- C_out = channels_out (cdims) ÷ groupcount (cdims))
244-
245- Threads. @sync for (wc, xc, yc) in zip (dw_cs, x_cs, dy_cs)
246- x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
247- dy = @view in2[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
248- dw = @view out[ntuple (i -> i == 5 ? yc : Colon (), 5 )... ]
249- Threads. @spawn ∇conv_filter_im2col! (dw, x, dy, cdims2; kwargs... )
250- end
238+ dx_cs = Iterators. partition (1 : size (out, 4 ),
239+ channels_in (cdims) ÷ groupcount (cdims))
240+ w_cs = Iterators. partition (1 : size (in2, 5 ),
241+ channels_out (cdims) ÷ groupcount (cdims))
242+ dy_cs = Iterators. partition (1 : size (in1, 4 ),
243+ channels_out (cdims) ÷ groupcount (cdims))
244+ cdims2 = basetype (C)(cdims,
245+ G = 1 ,
246+ C_in = channels_in (cdims) ÷ groupcount (cdims),
247+ C_out = channels_out (cdims) ÷ groupcount (cdims))
248+
249+ Threads. @sync for (xc, yc, wc) in zip (dx_cs, dy_cs, w_cs)
250+ dxv = @view out[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
251+ dyv = @view in1[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
252+ wv = @view in2[ntuple (i -> i == 5 ? wc : Colon (), 5 )... ]
253+ Threads. @spawn $ (Symbol (" $(front_name) _$(backend) !" ))(dxv, dyv, wv, cdims2; kwargs... )
254+ end
251255
252- return out
256+ return out
257+ end
258+ end
253259end
254260
255-
256- for (front_name, backend) in (
257- # This maps from public, front-facing name, to internal backend name
258- :depthwiseconv => :im2col ,
259- :∇depthwiseconv_data => :im2col ,
260- :∇depthwiseconv_filter => :im2col ,
261- )
262-
261+ for (front_name, backend, signature) in (
262+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
263+ # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
264+ (:∇conv_filter , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
265+ (:∇conv_filter , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
266+ )
263267 # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
264268 @eval begin
265- # im2col-accelerated function forwarding definition
266269 function $ (Symbol (" $(front_name) !" ))(
267- out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
268- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: $G , C <: ConvDims }
269- $ (Symbol (" $(front_name) _$(backend) !" ))(out, in1, in2, cdims; kwargs... )
270+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
271+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
272+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
273+ cdims:: $ (signature[4 ]);
274+ kwargs... ) where {$ (signature[5 ]. .. )}
275+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
276+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
277+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
278+ end
279+
280+ dw_cs = Iterators. partition (1 : size (out, 5 ),
281+ channels_out (cdims) ÷ groupcount (cdims))
282+ dy_cs = Iterators. partition (1 : size (in2, 4 ),
283+ channels_out (cdims) ÷ groupcount (cdims))
284+ x_cs = Iterators. partition (1 : size (in1, 4 ),
285+ channels_in (cdims) ÷ groupcount (cdims))
286+ cdims2 = basetype (C)(cdims,
287+ G = 1 ,
288+ C_in = channels_in (cdims) ÷ groupcount (cdims),
289+ C_out = channels_out (cdims) ÷ groupcount (cdims))
290+
291+ Threads. @sync for (wc, xc, yc) in zip (dw_cs, x_cs, dy_cs)
292+ x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
293+ dy = @view in2[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
294+ dw = @view out[ntuple (i -> i == 5 ? yc : Colon (), 5 )... ]
295+ Threads. @spawn $ (Symbol (" $(front_name) _$(backend) !" ))(dw, x, dy, cdims2; kwargs... )
296+ end
297+
298+ return out
270299 end
271300 end
272301end
273302
274- # We always support a fallback, non-accelerated path, where we use the direct, but
275- # slow, implementations. These should not typically be used, hence the `@warn`,
276- # but let's go ahead and define them first:
277- for front_name in (:conv , :∇conv_data , :∇conv_filter ,
278- :depthwiseconv , :∇depthwiseconv_data , :∇depthwiseconv_filter )
303+
304+ for (front_name, backend, signature) in (
305+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
306+ # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
307+ (:depthwiseconv , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
308+ (:depthwiseconv , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
309+
310+ (:∇depthwiseconv_data , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
311+ (:∇depthwiseconv_data , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
312+
313+ (:∇depthwiseconv_filter , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
314+ (:∇depthwiseconv_filter , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
315+ )
316+
317+ # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
279318 @eval begin
319+ # im2col-accelerated function forwarding definition
280320 function $ (Symbol (" $(front_name) !" ))(
281- y:: AbstractArray{yT,N} , in1:: AbstractArray{T1,N} ,
282- in2:: AbstractArray{T2,N} , cdims:: ConvDims ;
283- kwargs... ) where {yT, T1, T2, N}
284- if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
321+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
322+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
323+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
324+ cdims:: $ (signature[4 ]);
325+ kwargs... ) where {$ (signature[5 ]. .. )}
326+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
285327 @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
286- " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
328+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
287329 end
288- $ (Symbol (" $(front_name) _direct !" ))(y , in1, in2, cdims; kwargs... )
330+ $ (Symbol (" $(front_name) _ $(backend) !" ))(out , in1, in2, cdims; kwargs... )
289331 end
290332 end
291333end
0 commit comments