@@ -136,18 +136,17 @@ end
136136(a:: ConvTranspose{<:Any,<:Any,W} )(x:: AbstractArray{<:Real} ) where {T <: Union{Float32,Float64} , W <: AbstractArray{T} } =
137137 a (T .(x))
138138"""
139- DepthwiseConv(size, in)
140- DepthwiseConv(size, in=>mul)
141- DepthwiseConv(size, in=>mul, relu)
139+ DepthwiseConv(size, in=>out)
140+ DepthwiseConv(size, in=>out, relu)
142141
143142Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
144- `in` and `mul ` specify the number of input channels and channel multiplier respectively.
145- In case the `mul` is not specified it is taken as 1 .
143+ `in` and `out ` specify the number of input and output channels respectively.
144+ Note that `out` must be an integer multiple of `in` .
146145
147146Data should be stored in WHCN order. In other words, a 100×100 RGB image would
148147be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
149148
150- Takes the keyword arguments `pad` and `stride `.
149+ Takes the keyword arguments `pad`, `stride` and `dilation `.
151150"""
152151struct DepthwiseConv{N,M,F,A,V}
153152 σ:: F
@@ -166,17 +165,18 @@ function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identit
166165 return DepthwiseConv (σ, w, b, stride, pad, dilation)
167166end
168167
169- DepthwiseConv (k:: NTuple{N,Integer} , ch:: Integer , σ = identity; init = glorot_uniform,
170- stride = 1 , pad = 0 , dilation = 1 ) where N =
171- DepthwiseConv (param (init (k... , 1 , ch)), param (zeros (ch)), σ,
172- stride = stride, pad = pad, dilation= dilation)
173-
174- DepthwiseConv (k:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} , σ = identity; init = glorot_uniform,
175- stride:: NTuple{N,Integer} = map (_-> 1 ,k),
176- pad:: NTuple{N,Integer} = map (_-> 0 ,2 .* k),
177- dilation:: NTuple{N,Integer} = map (_-> 1 ,k)) where N =
178- DepthwiseConv (param (init (k... , ch[2 ], ch[1 ])), param (zeros (ch[2 ]* ch[1 ])), σ,
179- stride = stride, pad = pad)
168+ function DepthwiseConv (k:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} , σ = identity;
169+ init = glorot_uniform, stride = 1 , pad = 0 , dilation = 1 ) where N
170+ @assert ch[2 ] % ch[1 ] == 0 " Output channels must be integer multiple of input channels"
171+ return DepthwiseConv (
172+ param (init (k... , div (ch[2 ], ch[1 ]), ch[1 ])),
173+ param (zeros (ch[2 ])),
174+ σ;
175+ stride = stride,
176+ pad = pad,
177+ dilation = dilation
178+ )
179+ end
180180
181181@treelike DepthwiseConv
182182
@@ -187,8 +187,8 @@ function (c::DepthwiseConv)(x)
187187end
188188
189189function Base. show (io:: IO , l:: DepthwiseConv )
190- print (io, " DepthwiseConv(" , size (l. weight)[1 : ndims (l . weight) - 2 ])
191- print (io, " , " , size (l. weight, ndims (l . weight)) , " =>" , size (l. weight, ndims (l . weight) - 1 ))
190+ print (io, " DepthwiseConv(" , size (l. weight)[1 : end - 2 ])
191+ print (io, " , " , size (l. weight)[ end ] , " =>" , prod ( size (l. weight)[ end - 1 : end ] ))
192192 l. σ == identity || print (io, " , " , l. σ)
193193 print (io, " )" )
194194end
0 commit comments