55
66Return `x` reshaped into an array one dimensionality higher than `x`,
77where `dims` indicates in which dimension `x` is extended.
8+ `dims` can be an integer between 1 and `ndims(x)+1`.
89
910See also [`flatten`](@ref), [`stack`](@ref).
1011
@@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
3334 [1, 2] [3, 4] [5, 6]
3435```
3536"""
36- function unsqueeze (x:: AbstractArray ; dims:: Int )
37- sz = ntuple (i -> i < dims ? size (x, i) : i == dims ? 1 : size (x, i - 1 ), ndims (x) + 1 )
37+ function unsqueeze (x:: AbstractArray{T,N} ; dims:: Int ) where {T, N}
38+ # @assert 1 <= dims <= N + 1
39+ sz = ntuple (i -> i < dims ? size (x, i) : i == dims ? 1 : size (x, i - 1 ), N + 1 )
3840 return reshape (x, sz)
3941end
4042
@@ -59,9 +61,11 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io,
5961 stack(xs; dims)
6062
6163Concatenate the given array of arrays `xs` into a single array along the
62- given dimension `dims`.
64+ given dimension `dims`. All arrays need to be of the same size.
65+ The number of dimension in the final arrays is one more than the number
66+ of dimensions in the input arrays.
6367
64- See also [`stack `](@ref) and [`batch`](@ref).
68+ See also [`unsqueeze`](@ref), [`unstack `](@ref) and [`batch`](@ref).
6569
6670# Examples
6771
@@ -98,7 +102,28 @@ julia> stack(xs, dims=3)
98102 6
99103```
100104"""
101- stack (xs; dims:: Int ) = cat (unsqueeze .(xs; dims)... ; dims)
105+ function stack (xs; dims:: Int )
106+ N = ndims (xs[1 ])
107+ if dims <= N
108+ vs = unsqueeze .(xs; dims)
109+ else
110+ vs = xs
111+ end
112+ if dims == 1
113+ return reduce (vcat, vs)
114+ elseif dims === 2
115+ return reduce (hcat, vs)
116+ else
117+ return reduce ((x, y) -> cat (x, y; dims= dims), vs)
118+ end
119+ end
120+
121+ function rrule (:: typeof (stack), xs; dims:: Int )
122+ function stack_pullback (Δ)
123+ return (NoTangent (), unstack (unthunk (Δ); dims= dims))
124+ end
125+ return stack (xs; dims= dims), stack_pullback
126+ end
102127
103128"""
104129 unstack(xs; dims)
0 commit comments