@@ -186,25 +186,28 @@ end
186186# ####
187187
188188function frule ((_, xdot), :: typeof (cumsum), x:: AbstractArray ; dims:: Integer )
189- return cumsum (x; dims= dims ), cumsum (xdot; dims = dims)
189+ return cumsum (x; dims), cumsum (xdot; dims)
190190end
191191frule (tang, :: typeof (cumsum), x:: AbstractVector ) = frule (tang, cumsum, x; dims= 1 )
192192
193193function frule ((_, ydot, xdot), :: typeof (cumsum!), y:: AbstractArray , x:: AbstractArray ; dims:: Integer )
194- return cumsum! (y, x; dims= dims ), cumsum! (ydot, xdot; dims = dims)
194+ return cumsum! (y, x; dims), cumsum! (ydot, xdot; dims)
195195end
196196frule (t, :: typeof (cumsum!), y:: AbstractVector , x:: AbstractVector ) = frule (t, cumsum!, y, x; dims= 1 )
197197
198- function rrule (:: typeof (cumsum), x:: AbstractArray ; dims:: Integer )
198+ function rrule (:: typeof (cumsum), x:: AbstractArray{T,N} ; dims:: Integer ) where {T,N}
199199 project = ProjectTo (x)
200200 function cumsum_pullback (dy)
201+ if dims > N # trivial case, for which reverse fails
202+ return (NoTangent (), project (unthunk (dy)))
203+ end
201204 step1 = reverse (unthunk (dy); dims= dims)
202- if ChainRulesCore. is_inplaceable_destination (step1) && VERSION >= v " 1.6 "
203- step2 = cumsum! (step1, step1; dims= dims )
204- step3 = reverse! (step2; dims= dims )
205+ if ChainRulesCore. is_inplaceable_destination (step1)
206+ step2 = cumsum! (step1, step1; dims)
207+ step3 = reverse! (step2; dims)
205208 else
206- step2 = cumsum (step1; dims= dims )
207- step3 = reverse (step2; dims= dims )
209+ step2 = cumsum (step1; dims)
210+ step3 = reverse (step2; dims)
208211 end
209212 return (NoTangent (), project (step3))
210213 end
0 commit comments