@@ -417,17 +417,73 @@ end
417417end
418418
419419# ####
420- # #### `foldl`
420+ # ####
421+ # #### `foldl(f, ::Tuple)`
421422# ####
422423
423424# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when
424- # this `f` is stateful, in which case the gradient must be calculated in the reverse order.
425+ # this `f` is stateful, in which case the gradient must be calculated in the reverse order.
426+
427+ # The rule is attached to `Base.mapfoldl_impl` because this gets the `init` keyword as an argument,
428+ # which is handled below. For tuples, `reduce` also comes here.
429+
430+ function rrule (
431+ config:: RuleConfig{>:HasReverseMode} ,
432+ :: typeof (Base. mapfoldl_impl),
433+ :: typeof (identity),
434+ op:: G ,
435+ init:: Base._InitialValue ,
436+ x:: Tuple ;
437+ ) where {G}
438+ hobbits = accumulate (Base. tail (x); init= (first (x), nothing )) do (a, _), b
439+ # Here `a` is what we would normally cary forward, and `_` ignores
440+ # the previous iteration's pullback function (needed later),
441+ # while `b` is the fresh input from `list` as usual.
442+ c, back = rrule_via_ad (config, op, a, b)
443+ # We don't really need to store every `c`, last one is `foldl` output.
444+ # (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
445+ end
446+ y = first (last (hobbits))
447+ project = ProjectTo (x)
448+ function foldl_pullback_tuple (dy)
449+ trio = accumulate (_reverse1 (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
450+ ds, da, db = back (dc)
451+ # Don't need to store every `da`, need one for the next iteration + the last.
452+ end
453+ dop = sum (first, trio)
454+ dx = (trio[end ][2 ], reverse (map (last, trio))... )
455+ return (NoTangent (), NoTangent (), ProjectTo (op)(dop), NoTangent (), project (dx))
456+ end
457+ return y, foldl_pullback_tuple
458+ end
459+
460+ function rrule (
461+ config:: RuleConfig{>:HasReverseMode} ,
462+ :: typeof (Base. mapfoldl_impl),
463+ :: typeof (identity),
464+ op:: G ,
465+ init,
466+ x:: Tuple ;
467+ ) where {G}
468+ # Treat `init` by simply appending it to the `x`:
469+ y, back = rrule (config, Base. mapfoldl_impl, identity, op, Base. _InitialValue (), (init, x... ))
470+ project_x = ProjectTo (x)
471+ project_in = ProjectTo (init)
472+ function foldl_pullback_tuple_init (dy)
473+ _, _, dop, _, dxplus = back (dy)
474+ return (NoTangent (), NoTangent (), dop, project_in (first (dxplus)), project_x (Base. tail (dxplus)))
475+ end
476+ return y, foldl_pullback_tuple_init
477+ end
425478
426- # The implementation aims to be efficient for both tuples and arrays, although using accumulate
427- # to carry intermediate results along creates arrays of tuples which could be avoided; using a
428- # loop can be a few times faster. Note also that it does not return a gradient for `init`.
479+ # ####
480+ # #### `foldl(f, ::Array)`
481+ # ####
429482
430- # Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?
483+ # The implementation was originally for both tuples and arrays, although using accumulate
484+ # to carry intermediate results along creates arrays of tuples which could be avoided.
485+ # Using a loop can be a few times faster, this should be replaced.
486+ # Note also that it does not return a gradient for `init`.
431487
432488function rrule (
433489 config:: RuleConfig{>:HasReverseMode} , :: typeof (Base. mapfoldl_impl), :: typeof (identity), op:: G , init, x:: Union{AbstractArray, Tuple} ;
@@ -486,8 +542,7 @@ _reverse1(x::Tuple) = reverse(x)
486542_drop1 (x:: Tuple ) = Base. tail (x)
487543_zip2 (x:: Tuple{Vararg{Any,N}} , y:: Tuple{Vararg{Any,N}} ) where N = ntuple (i -> (x[i],y[i]), N)
488544
489- # struct _InitialValue end # Old versions don't have `Base._InitialValue`
490- const _INIT = VERSION >= v " 1.5" ? Base. _InitialValue () : NamedTuple ()
545+ const _INIT = Base. _InitialValue ()
491546
492547_vcat1 (x, ys:: AbstractVector ) = vcat (x, ys)
493548_vcat1 (x:: AbstractArray , ys:: AbstractVector ) = vcat ([x], ys)
0 commit comments