@@ -200,3 +200,44 @@ function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
200200 cube_pullback (dy) = (NoTangent (), NoTangent (), ProjectTo (x)(3 * x2 * dy), NoTangent ())
201201 return x2 * x, cube_pullback
202202end
203+
204+ # ####
205+ # #### `map`
206+ # ####
207+
208+ # Ideally reverse mode should always iterate in reverse order. For `map` and broadcasting
209+ # this may matter with a stateful `f`, but in general their order isn't guaranteed anyway,
210+ # so it's unclear how much effort should be spent on that. But `map` on Tuples normally
211+ # gets unrolled, so perhaps it does guarantee order, and reversing it should be cheap too.
212+
213+ function rrule (config:: RuleConfig{>:HasReverseMode} , :: typeof (map), f:: F , xs:: Tuple... ) where {F}
214+ length_y = minimum (length, xs)
215+ hobbits = ntuple (length_y) do i
216+ args = getindex .(xs, i)
217+ rrule_via_ad (config, f, args... )
218+ end
219+ y = map (first, hobbits)
220+ num_xs = Val (length (xs))
221+ paddings = map (x -> ntuple (Returns (NoTangent ()), (length (x) - length_y)), xs)
222+ all (isempty, paddings) || @error """ map(f, xs::Tuple...) does not allow mistmatched lengths!
223+ But its `rrule` does; when JuliaLang/julia #42216 is fixed this warning should be removed."""
224+ function map_pullback (dy_raw)
225+ dy = unthunk (dy_raw)
226+ # We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass:
227+ backevals = ntuple (length_y) do i
228+ rev_i = length_y - i + 1
229+ last (hobbits[rev_i])(dy[rev_i])
230+ end |> reverse
231+ # This df doesn't infer, could test Base.issingletontype(F), but it's not the only inference problem.
232+ df = ProjectTo (f)(sum (first, backevals))
233+ # Now unzip that. Because `map` like `zip` should when any `x` stops, some `dx`s may need padding.
234+ # Although in fact, `map(+, (1,2), (3,4,5))` is an error... https://github.com/JuliaLang/julia/issues/42216
235+ dxs = ntuple (num_xs) do k
236+ dx_short = map (bv -> bv[k+ 1 ], backevals)
237+ ProjectTo (xs[k])((dx_short... , paddings[k]. .. )) # ProjectTo makes the Tangent for us
238+ end
239+ return (NoTangent (), df, dxs... )
240+ end
241+ map_back (dy:: AbstractZero ) = (NoTangent (), NoTangent (), ntuple (Returns (NoTangent ()), num_xs)... )
242+ return y, map_pullback
243+ end
0 commit comments