@@ -60,21 +60,25 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
6060@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
6161 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
6262 argsizes = broadcast_sizes (as... )
63- destsize = combine_sizes (argsizes)
64- _broadcast (f, destsize, argsizes, as... )
63+ ax = axes (B)
64+ if ax isa Tuple{Vararg{SOneTo}}
65+ return _broadcast (f, Size (map (length, ax)), argsizes, as... )
66+ end
67+ return copy (convert (Broadcasted{DefaultArrayStyle{M}}, B))
6568end
6669# copyto! overloads
6770@inline Base. copyto! (dest, B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
6871@inline Base. copyto! (dest:: AbstractArray , B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
6972@inline function _copyto! (dest, B:: Broadcasted{StaticArrayStyle{M}} ) where M
7073 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
7174 argsizes = broadcast_sizes (as... )
72- destsize = combine_sizes (( Size (dest), argsizes ... ) )
73- if Length (destsize) === Length {Dynamic()} ()
74- # destination dimension cannot be determined statically; fall back to generic broadcast!
75- return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B) )
75+ ax = axes (B )
76+ if ax isa Tuple{Vararg{SOneTo}}
77+ @boundscheck axes (dest) == ax || Broadcast . throwdm ( axes (dest), ax)
78+ return _broadcast! (f, Size ( map (length, ax)), dest, argsizes, as ... )
7679 end
77- _broadcast! (f, destsize, dest, argsizes, as... )
80+ # destination dimension cannot be determined statically; fall back to generic broadcast!
81+ return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B))
7882end
7983
8084# Resolving priority between dynamic and static axes
@@ -101,45 +105,13 @@ broadcast_indices(A::StaticArray) = indices(A)
101105@inline broadcast_size (a:: AbstractArray ) = Size (a)
102106@inline broadcast_size (a:: Tuple ) = Size (length (a))
103107
104- function broadcasted_index (oldsize, newindex)
105- index = ones (Int, length (oldsize))
106- for i = 1 : length (oldsize)
107- if oldsize[i] != 1
108- index[i] = newindex[i]
109- end
110- end
111- return LinearIndices (oldsize)[index... ]
112- end
113-
114- # similar to Base.Broadcast.combine_indices:
115- @generated function combine_sizes (s:: Tuple{Vararg{Size}} )
116- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
117- ndims = 0
118- for i = 1 : length (sizes)
119- ndims = max (ndims, length (sizes[i]))
120- end
121- newsize = StaticDimension[Dynamic () for _ = 1 : ndims]
122- for i = 1 : length (sizes)
123- s = sizes[i]
124- for j = 1 : length (s)
125- if s[j] isa Dynamic
126- continue
127- elseif newsize[j] isa Dynamic || newsize[j] == 1
128- newsize[j] = s[j]
129- elseif newsize[j] ≠ s[j] && s[j] ≠ 1
130- throw (DimensionMismatch (" Tried to broadcast on inputs sized $sizes " ))
131- end
132- end
133- end
134- quote
135- @_inline_meta
136- Size ($ (tuple (newsize... )))
137- end
108+ broadcast_getindex (:: Tuple{} , i:: Int , I:: CartesianIndex ) = return :(_broadcast_getindex (a[$ i], $ I))
109+ function broadcast_getindex (oldsize:: Tuple , i:: Int , newindex:: CartesianIndex )
110+ li = LinearIndices (oldsize)
111+ ind = _broadcast_getindex (li, newindex)
112+ return :(a[$ i][$ ind])
138113end
139114
140- scalar_getindex (x) = x
141- scalar_getindex (x:: Ref ) = x[]
142-
143115isstatic (:: StaticArray ) = true
144116isstatic (:: Transpose{<:Any, <:StaticArray} ) = true
145117isstatic (:: Adjoint{<:Any, <:StaticArray} ) = true
@@ -163,13 +135,11 @@ end
163135
164136@generated function __broadcast (f, :: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
165137 sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
138+
166139 indices = CartesianIndices (newsize)
167140 exprs = similar (indices, Expr)
168141 for (j, current_ind) ∈ enumerate (indices)
169- exprs_vals = [
170- (! (a[i] <: AbstractArray || a[i] <: Tuple ) ? :(scalar_getindex (a[$ i])) : :(a[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
171- for i = 1 : length (sizes)
172- ]
142+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
173143 exprs[j] = :(f ($ (exprs_vals... )))
174144 end
175145
@@ -183,27 +153,18 @@ end
183153# # Internal broadcast! machinery for StaticArrays ##
184154# ###################################################
185155
186- @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , as... ) where {newsize}
187- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
188- sizes = tuple (sizes... )
189-
190- # TODO : this could also be done outside the generated function:
191- sizematch (Size {newsize} (), Size (dest)) ||
192- throw (DimensionMismatch (" Tried to broadcast to destination sized $newsize from inputs sized $sizes " ))
156+ @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , a... ) where {newsize}
157+ sizes = [sz. parameters[1 ] for sz in s. parameters]
193158
194159 indices = CartesianIndices (newsize)
195160 exprs = similar (indices, Expr)
196161 for (j, current_ind) ∈ enumerate (indices)
197- exprs_vals = [
198- (! (as[i] <: AbstractArray || as[i] <: Tuple ) ? :(as[$ i][]) : :(as[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
199- for i = 1 : length (sizes)
200- ]
162+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
201163 exprs[j] = :(dest[$ j] = f ($ (exprs_vals... )))
202164 end
203165
204166 return quote
205- @_propagate_inbounds_meta
206- @boundscheck sizematch ($ (Size {newsize} ()), dest) || throw (DimensionMismatch (" array could not be broadcast to match destination" ))
167+ @_inline_meta
207168 @inbounds $ (Expr (:block , exprs... ))
208169 return dest
209170 end
0 commit comments