@@ -180,60 +180,71 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
180180 end # prod
181181
182182 @testset " foldl(f, ::Array)" begin
183+ # `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
184+ # now attached there, as this is the simplest way to handle `init` keyword.
185+ @eval using Base: mapfoldl_impl
186+ @eval _INIT = VERSION >= v " 1.5" ? Base. _InitialValue () : NamedTuple ()
187+
183188 # Simple
184- y1, b1 = rrule (CFG, foldl, * , [1 , 2 , 3 ]; init = 1 )
189+ y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , [1 , 2 , 3 ])
185190 @test y1 == 6
186- b1 (7 ) == (NoTangent (), NoTangent (), [42 , 21 , 14 ])
191+ @test b1 (7 )[1 : 3 ] == (NoTangent (), NoTangent (), NoTangent ())
192+ @test b1 (7 )[4 ] isa ChainRulesCore. NotImplemented
193+ @test b1 (7 )[5 ] == [42 , 21 , 14 ]
187194
188- y2, b2 = rrule (CFG, foldl, * , [1 2 ; 0 4 ]) # without init, needs vcat
195+ y2, b2 = rrule (CFG, mapfoldl_impl, identity, * , _INIT , [1 2 ; 0 4 ]) # without init, needs vcat
189196 @test y2 == 0
190- b2 (8 ) == ( NoTangent (), NoTangent (), [0 0 ; 64 0 ]) # matrix, needs reshape
197+ @test b2 (8 )[ 5 ] == [0 0 ; 64 0 ] # matrix, needs reshape
191198
192199 # Test execution order
193200 c5 = Counter ()
194- y5, b5 = rrule (CFG, foldl, c5 , [5 , 7 , 11 ])
201+ y5, b5 = rrule (CFG, mapfoldl_impl, identity, c5, _INIT , [5 , 7 , 11 ])
195202 @test c5 == Counter (2 )
196203 @test y5 == ((5 + 7 )* 1 + 11 )* 2 == foldl (Counter (), [5 , 7 , 11 ])
197- @test b5 (1 ) == ( NoTangent (), NoTangent (), [12 * 32 , 12 * 42 , 22 ])
204+ @test b5 (1 )[ 5 ] == [12 * 32 , 12 * 42 , 22 ]
198205 @test c5 == Counter (42 )
199206
200207 c6 = Counter ()
201- y6, b6 = rrule (CFG, foldl, c6, [5 , 7 , 11 ], init = 3 )
208+ y6, b6 = rrule (CFG, mapfoldl_impl, identity, c6, 3 , [5 , 7 , 11 ])
202209 @test c6 == Counter (3 )
203210 @test y6 == (((3 + 5 )* 1 + 7 )* 2 + 11 )* 3 == foldl (Counter (), [5 , 7 , 11 ], init= 3 )
204- @test b6 (1 ) == ( NoTangent (), NoTangent (), [63 * 33 * 13 , 43 * 13 , 23 ])
211+ @test b6 (1 )[ 5 ] == [63 * 33 * 13 , 43 * 13 , 23 ]
205212 @test c6 == Counter (63 )
206213
207214 # Test gradient of function
208- y7, b7 = rrule (CFG, foldl, Multiplier (3 ), [5 , 7 , 11 ])
215+ y7, b7 = rrule (CFG, mapfoldl_impl, identity, Multiplier (3 ), _INIT , [5 , 7 , 11 ])
209216 @test y7 == foldl ((x,y)-> x* y* 3 , [5 , 7 , 11 ])
210- @test b7 (1 ) == (NoTangent (), Tangent {Multiplier{Int}} (x = 2310 ,), [693 , 495 , 315 ])
217+ b7_1 = b7 (1 )
218+ @test b7_1[3 ] == Tangent {Multiplier{Int}} (x = 2310 ,)
219+ @test b7_1[5 ] == [693 , 495 , 315 ]
211220
212- y8, b8 = rrule (CFG, foldl, Multiplier (13 ), [5 , 7 , 11 ], init = 3 )
221+ y8, b8 = rrule (CFG, mapfoldl_impl, identity, Multiplier (13 ), 3 , [5 , 7 , 11 ])
213222 @test y8 == 2_537_535 == foldl ((x,y)-> x* y* 13 , [5 , 7 , 11 ], init= 3 )
214- @test b8 (1 ) == (NoTangent (), Tangent {Multiplier{Int}} (x = 585585 ,), [507507 , 362505 , 230685 ])
223+ b8_1 = b8 (1 )
224+ @test b8_1[3 ] == Tangent {Multiplier{Int}} (x = 585585 ,)
225+ @test b8_1[5 ] == [507507 , 362505 , 230685 ]
215226 # To find these numbers:
216227 # ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
217228 # ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string
218229
219230 # Finite differencing
220- test_rrule (foldl, / , 1 .+ rand (3 ,4 ))
221- test_rrule (foldl, * , rand (ComplexF64, 3 , 4 ); fkwargs = (; init = rand (ComplexF64) ))
222- test_rrule (foldl, + , rand (ComplexF64, 7 ); fkwargs = (; init = rand (ComplexF64) ))
223- test_rrule (foldl, max, rand (3 ); fkwargs = (; init = 999 ))
231+ test_rrule (mapfoldl_impl, identity, / , _INIT , 1 .+ rand (3 ,4 ))
232+ test_rrule (mapfoldl_impl, identity, * , rand (ComplexF64), rand (ComplexF64, 3 , 4 ))
233+ test_rrule (mapfoldl_impl, identity, + , rand (ComplexF64), rand (ComplexF64, 7 ))
234+ test_rrule (mapfoldl_impl, identity, max, 999 , rand (3 ))
224235 end
225236 VERSION >= v " 1.5" && @testset " foldl(f, ::Tuple)" begin
226- y1, b1 = rrule (CFG, foldl, * , (1 ,2 ,3 ); init = 1 )
237+ y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , (1 ,2 ,3 ))
227238 @test y1 == 6
228- b1 (7 ) == ( NoTangent (), NoTangent (), Tangent {NTuple{3,Int}} (42 , 21 , 14 ) )
239+ @test b1 (7 )[ 5 ] == Tangent {NTuple{3,Int}} (42 , 21 , 14 )
229240
230- y2, b2 = rrule (CFG, foldl, * , (1 , 2 , 0 , 4 ))
241+ y2, b2 = rrule (CFG, mapfoldl_impl, identity, * , _INIT , (1 , 2 , 0 , 4 ))
231242 @test y2 == 0
232- b2 (8 ) == ( NoTangent (), NoTangent (), Tangent {NTuple{4,Int}} (0 , 0 , 64 , 0 ) )
243+ @test b2 (8 )[ 5 ] == Tangent {NTuple{4,Int}} (0 , 0 , 64 , 0 )
233244
234245 # Finite differencing
235- test_rrule (foldl, / , Tuple (1 .+ rand (5 )))
236- test_rrule (foldl, * , Tuple (rand (ComplexF64, 5 )))
246+ test_rrule (mapfoldl_impl, identity, / , _INIT , Tuple (1 .+ rand (5 )))
247+ test_rrule (mapfoldl_impl, identity, * , _INIT , Tuple (rand (ComplexF64, 5 )))
237248 end
238249end
239250
0 commit comments