@@ -29,16 +29,30 @@ StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{
2929StructArrays. createinstance (T:: Type{<:FieldArray} , args... ) = invoke (createinstance, Tuple{Type{<: Any }, Vararg}, T, args... )
3030
3131# Broadcast overload
32- using StaticArraysCore: StaticArrayStyle
33- import StaticArraysCore: Size, is_staticarray_like, similar_type
32+ using StaticArraysCore: StaticArrayStyle, similar_type
3433StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
3534function Broadcast. instantiate (bc:: Broadcasted{StructStaticArrayStyle{M}} ) where {M}
36- bc′ = Broadcast. instantiate (convert (Broadcasted{StaticArrayStyle{M}}, bc))
35+ bc′ = Broadcast. instantiate (replace_structarray ( bc))
3736 return convert (Broadcasted{StructStaticArrayStyle{M}}, bc′)
3837end
39- function Broadcast. _axes (bc:: Broadcasted{StructStaticArrayStyle{M}} , :: Nothing ) where {M}
40- return Broadcast. _axes (convert (Broadcasted{StaticArrayStyle{M}}, bc), nothing )
38+ # This looks costy, but compiler should be able to optimize them away
39+ Broadcast. _axes (bc:: Broadcasted{<:StructStaticArrayStyle} , :: Nothing ) = axes (replace_structarray (bc))
40+
41+ to_staticstyle (@nospecialize (x:: Type )) = x
42+ to_staticstyle (:: Type{StructStaticArrayStyle{N}} ) where {N} = StaticArrayStyle{N}
43+ function replace_structarray (bc:: Broadcasted{Style} ) where {Style}
44+ args = replace_structarray_args (bc. args)
45+ return Broadcasted {to_staticstyle(Style)} (bc. f, args, nothing )
46+ end
47+ function replace_structarray (A:: StructArray )
48+ f = createinstance (eltype (A))
49+ args = Tuple (components (A))
50+ return Broadcasted {StaticArrayStyle{ndims(A)}} (f, args, nothing )
4151end
52+ replace_structarray (@nospecialize (A)) = A
53+
54+ replace_structarray_args (args:: Tuple ) = (replace_structarray (args[1 ]), replace_structarray_args (Base. tail (args))... )
55+ replace_structarray_args (:: Tuple{} ) = ()
4256
4357# StaticArrayStyle has no similar defined.
4458# Overload `Base.copy` instead.
4862 isnonemptystructtype (ET) || return sa
4963 elements = Tuple (sa)
5064 arrs = ntuple (Val (fieldcount (ET))) do i
51- similar_type (sa, fieldtype (ET, i), Size (sa) )(_getfields (elements, i))
65+ similar_type (sa, fieldtype (ET, i))(_getfields (elements, i))
5266 end
5367 return StructArray {ET} (arrs)
5468end
6074 return map (Base. Fix2 (getfield, i), x)
6175 end
6276end
63-
64- Size (:: Type{SA} ) where {SA<: StructArray } = Size (fieldtype (array_types (SA), 1 ))
65- is_staticarray_like (x:: StructArray ) = any (is_staticarray_like, components (x))
66- function similar_type (:: Type{SA} , :: Type{T} , s:: Size{S} ) where {SA<: StructArray , T, S}
67- return similar_type (fieldtype (array_types (SA), 1 ), T, s)
68- end
0 commit comments