1+ # ####
2+ # #### broadcast
3+ # ####
14
25"""
36 unzip_broadcast(f, args...)
@@ -26,17 +29,18 @@ function unzip_broadcast(f::F, args...) where {F}
2629 T <: Tuple || throw (ArgumentError (""" unzip_broadcast(f, args) only works on functions returning a tuple,
2730 but f = $(sprint (show, f)) returns type T = $T """ ))
2831 end
29- # TODO allow GPU arrays, possibly just as a fallback unzip, but see also:
30- # https://github.com/JuliaArrays/StructArrays.jl/issues/150
31- # if any(a -> a isa CuArray, args)
32- # return unzip(broadcast(f, args...))
33- # end
3432 bc = Broadcast. instantiate (Broadcast. broadcasted (f, args... ))
35- if Broadcast. BroadcastStyle (typeof (bc)) isa Broadcast. AbstractArrayStyle
33+ bcs = Broadcast. BroadcastStyle (typeof (bc))
34+ if bcs isa AbstractGPUArrayStyle
35+ # This is a crude way to allow GPU arrays, not currently tested, TODO .
36+ # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
37+ return unzip (broadcast (f, args... ))
38+ elseif bcs isa Broadcast. AbstractArrayStyle
3639 return StructArrays. components (StructArray (bc))
3740 else
3841 return unzip (broadcast (f, args... )) # e.g. tuples
3942 end
43+ # TODO maybe this if-else can be replaced by methods of `unzip(:::Broadcast.Broadcasted)`?
4044end
4145
4246function ChainRulesCore. rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (unzip_broadcast), f:: F , args... ) where {F}
@@ -58,40 +62,17 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect∘unzip_broad
5862 return collect (y), back
5963end
6064
61- #=
65+ # ####
66+ # #### map
67+ # ####
6268
63- """
64- unzip_map(f, args...)
65-
66- For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
67- but performed using `StructArrays` for efficiency.
68-
69- Not in use at present, but see `unzip_broadcast`.
70- """
71- function unzip_map(f::F, args...) where {F}
72- T = Broadcast.combine_eltypes(f, args)
73- if isconcretetype(T)
74- T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple,
75- but f = $(sprint(show, f)) returns type T = $T"""))
76- end
77- # if any(a -> a isa CuArray, args)
78- # return unzip(map(f, args...))
79- # end
80- return StructArrays.components(StructArray(Iterators.map(f, args...)))
81- end
69+ # `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
70+ # will be useful for the gradient of `map` etc.
8271
83- function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F}
84- y, back = rrule_via_ad(cfg, map, f, xs...)
85- z = unzip(y)
86- function ununzip_map(dz)
87- # dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
88- dy = broadcast(tuple, map(unthunk, dz)...)
89- return back(dy)
90- end
91- return z, ununzip_map
92- end
9372
94- =#
73+ # ####
74+ # #### unzip
75+ # ####
9576
9677"""
9778 unzip(A)
0 commit comments