1- import StaticArrays: StaticArray, FieldArray, tuple_prod
1+ using StaticArrays: StaticArrays, StaticArray, FieldArray, tuple_prod, StaticArrayStyle
22
33"""
44 StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
@@ -26,4 +26,62 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
2626 invoke (StructArrays. staticschema, Tuple{Type{<: Any }}, T)
2727end
2828StructArrays. component (s:: FieldArray , i) = invoke (StructArrays. component, Tuple{Any, Any}, s, i)
29- StructArrays. createinstance (T:: Type{<:FieldArray} , args... ) = invoke (createinstance, Tuple{Type{<: Any }, Vararg}, T, args... )
29+ StructArrays. createinstance (T:: Type{<:FieldArray} , args... ) = invoke (createinstance, Tuple{Type{<: Any }, Vararg}, T, args... )
30+
31+ # Broadcast overload
32+ import StaticArrays: Size, isstatic, similar_type
33+ using StaticArrays: first_statictype, broadcast_sizes, SOneTo
34+ import Base. Broadcast: instantiate
35+ StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
36+ function instantiate (bc:: Broadcasted{StructStaticArrayStyle{M}} ) where {M}
37+ bc′ = instantiate (convert (Broadcasted{StaticArrayStyle{M}}, bc))
38+ return convert (Broadcasted{StructStaticArrayStyle{M}}, bc′)
39+ end
40+ function Broadcast. _axes (bc:: Broadcasted{<:StructStaticArrayStyle} , :: Nothing )
41+ return StaticArrays. static_combine_axes (bc. args... )
42+ end
43+
44+ # StaticArrayStyle has no similar defined.
45+ # Overload `Base.copy` instead.
46+ @inline function Base. copy (B:: Broadcasted{<:StructStaticArrayStyle} )
47+ flat = Broadcast. flatten (B); as = flat. args; f = flat. f
48+ argsizes = broadcast_sizes (as... )
49+ ax = axes (B)
50+ ax isa Tuple{Vararg{SOneTo}} || error (" Dimension is not static. Please file a bug." )
51+ return _broadcast (f, Size (map (length, ax)), argsizes, as... )
52+ end
53+ @inline function _broadcast (f, sz:: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
54+ AT = first_statictype (a... )
55+ if prod (newsize) == 0
56+ # Use inference to get eltype in empty case (see also comments in _map)
57+ eltys = Tuple{map (eltype, a)... }
58+ T = Core. Compiler. return_type (f, eltys)
59+ return _struct_static_similar (T, AT, sz, ())
60+ end
61+ elements = StaticArrays. __broadcast (f, sz, s, a... )
62+ return _struct_static_similar (eltype (elements), AT, sz, elements)
63+ end
64+ function _struct_static_similar (:: Type{ET} , :: Type{AT} , sz, elements:: Tuple ) where {ET, AT}
65+ if isnonemptystructtype (ET)
66+ arrs = ntuple (Val (fieldcount (ET))) do i
67+ similar_type (AT, fieldtype (ET, i), sz)(_getfields (elements, i))
68+ end
69+ return StructArray {ET} (arrs)
70+ else
71+ return similar_type (AT, ET, sz)(elements)
72+ end
73+ end
74+
75+ @inline function _getfields (x:: Tuple , i:: Int )
76+ if @generated
77+ return Expr (:tuple , (:(getfield (x[$ j], i)) for j in 1 : fieldcount (x)). .. )
78+ else
79+ return map (Base. Fix2 (getfield, i), x)
80+ end
81+ end
82+
83+ Size (:: Type{SA} ) where {SA<: StructArray } = Size (fieldtype (array_types (SA), 1 ))
84+ isstatic (x:: StructArray ) = isstatic (component (x, 1 ))
85+ function similar_type (:: Type{SA} , :: Type{T} , s:: Size{S} ) where {SA<: StructArray , T, S}
86+ return similar_type (fieldtype (array_types (SA), 1 ), T, s)
87+ end
0 commit comments