44
55import Base. Broadcast:
66BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!
7+ import Base. Broadcast: combine_axes, instantiate, _broadcast_getindex, broadcast_shape, Style
78import Base. Broadcast: _bcs1 # for SOneTo axis information
89using Base. Broadcast: _bcsm
910# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
@@ -19,6 +20,42 @@ BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
1920 DefaultArrayStyle (Val (max (M, N)))
2021BroadcastStyle (:: StaticArrayStyle{M} , :: DefaultArrayStyle{0} ) where {M} =
2122 StaticArrayStyle {M} ()
23+
24+ # combine_axes overload (for Tuple)
25+ @inline static_combine_axes (A, B... ) = broadcast_shape (static_axes (A), static_combine_axes (B... ))
26+ static_combine_axes (A) = static_axes (A)
27+ static_axes (A) = axes (A)
28+ static_axes (x:: Tuple ) = (SOneTo {length(x)} (),)
29+ static_axes (bc:: Broadcasted{Style{Tuple}} ) = static_combine_axes (bc. args... )
30+ Broadcast. _axes (bc:: Broadcasted{<:StaticArrayStyle} , :: Nothing ) = static_combine_axes (bc. args... )
31+
32+ # instantiate overload
33+ @inline function instantiate (B:: Broadcasted{StaticArrayStyle{M}} ) where M
34+ if B. axes isa Tuple{Vararg{SOneTo}} || B. axes isa Tuple && length (B. axes) > M
35+ return invoke (instantiate, Tuple{Broadcasted}, B)
36+ elseif B. axes isa Nothing
37+ ax = static_combine_axes (B. args... )
38+ return Broadcasted {StaticArrayStyle{M}} (B. f, B. args, ax)
39+ else
40+ # We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`.
41+ ax = static_check_broadcast_shape (B. axes, static_combine_axes (B. args... ))
42+ return Broadcasted {StaticArrayStyle{M}} (B. f, B. args, ax)
43+ end
44+ end
45+ @inline function static_check_broadcast_shape (shp:: Tuple , Ashp:: Tuple{Vararg{SOneTo}} )
46+ ax1 = if length (Ashp[1 ]) == 1
47+ shp[1 ]
48+ elseif Ashp[1 ] == shp[1 ]
49+ Ashp[1 ]
50+ else
51+ throw (DimensionMismatch (" array could not be broadcast to match destination" ))
52+ end
53+ (ax1, static_check_broadcast_shape (Base. tail (shp), Base. tail (Ashp))... )
54+ end
55+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{SOneTo,Vararg{SOneTo}} ) =
56+ throw (DimensionMismatch (" cannot broadcast array to have fewer non-singleton dimensions" ))
57+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{SOneTo{1},Vararg{SOneTo{1}}} ) = ()
58+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{} ) = ()
2259# copy overload
2360@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
2461 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
4279
4380# Resolving priority between dynamic and static axes
4481_bcs1 (a:: SOneTo , b:: SOneTo ) = _bcsm (b, a) ? b : (_bcsm (a, b) ? a : throw (DimensionMismatch (" arrays could not be broadcast to a common size" )))
45- _bcs1 (a:: SOneTo , b:: Base.OneTo ) = _bcs1 (Base. OneTo (a), b)
46- _bcs1 (a:: Base.OneTo , b:: SOneTo ) = _bcs1 (a, Base. OneTo (b))
82+ function _bcs1 (a:: SOneTo , b:: Base.OneTo )
83+ length (a) == 1 && return b
84+ if length (b) != length (a) && length (b) != 1
85+ throw (DimensionMismatch (" arrays could not be broadcast to a common size" ))
86+ end
87+ return a
88+ end
89+ _bcs1 (a:: Base.OneTo , b:: SOneTo ) = _bcs1 (b, a)
4790
4891# ##################################################
4992# # Internal broadcast machinery for StaticArrays ##
0 commit comments