|
359 | 359 | end |
360 | 360 | end # cumprod |
361 | 361 |
|
362 | | - @testset "accumulate(f, ::Array)" begin |
| 362 | + @testset "accumulate(f, ::Vector)" begin |
363 | 363 | # `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`. |
364 | 364 | # The rule is now attached there, as this is the simplest way to handle `init` keyword. |
365 | | - @eval using Base: _accumulate! |
366 | 365 |
|
367 | 366 | # Simple |
368 | 367 | y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1)) |
|
372 | 371 | @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}} |
373 | 372 | @test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented |
374 | 373 |
|
375 | | - y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) |
376 | | - @test y2 ≈ accumulate(/, [1 2; 3 4]) |
377 | | - @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 |
| 374 | + # y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing) |
| 375 | + # @test y2 ≈ accumulate(/, [1 2; 3 4.0]) |
| 376 | + # @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 |
378 | 377 |
|
379 | 378 | # Test execution order |
380 | 379 | c3 = Counter() |
@@ -404,35 +403,11 @@ end |
404 | 403 | # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string |
405 | 404 |
|
406 | 405 | # Finite differencing |
407 | | - test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) |
408 | | - test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
409 | | - test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
| 406 | + # test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) |
| 407 | + test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(rand())) |
| 408 | + # test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
| 409 | + test_rrule(_accumulate!, /, randn(4) ⊢ NoTangent(), 1 .+ rand(4), nothing, nothing) |
| 410 | + # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
| 411 | + test_rrule(_accumulate!, ^, randn(6) ⊢ NoTangent(), 1 .+ rand(6), nothing, Some(rand())) |
410 | 412 | end |
411 | | - @testset "accumulate(f, ::Tuple)" begin |
412 | | - # Simple |
413 | | - y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) |
414 | | - @test y1 == (1, 2, 6, 24) |
415 | | - @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) |
416 | | - |
417 | | - # Finite differencing |
418 | | - test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) |
419 | | - test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) |
420 | | - |
421 | | - test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing) |
422 | | - test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand())) |
423 | | - # if VERSION >= v"1.5" |
424 | | - # test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
425 | | - # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
426 | | - # end |
427 | | - end |
428 | | - # VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin |
429 | | - # # Simple |
430 | | - # y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) |
431 | | - # @test y1 == (1, 2, 6, 24) |
432 | | - # @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) |
433 | | - |
434 | | - # # Finite differencing |
435 | | - # test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) |
436 | | - # test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) |
437 | | - # end |
438 | 413 | end |
0 commit comments