@@ -8,10 +8,10 @@ using Statistics: mean, std
88using Random
99# using StatsBase
1010
11- gradtest (f, xs:: AbstractArray... ) = gradcheck ((xs... ) -> sum (sin .(f (xs... ))), xs... )
12- gradtest (f, dims... ) = gradtest (f, rand .(Float64, dims)... )
11+ gradtest (f, xs:: AbstractArray... ; kw ... ) = gradcheck ((xs... ) -> sum (sin .(f (xs... ))), xs... ; kw ... )
12+ gradtest (f, dims... ; kw ... ) = gradtest (f, rand .(Float64, dims)... ; kw ... )
1313
14- @testset " Tracker " begin # overall testset, rest of the file
14+ @testset " gradtests 1 " begin
1515
1616@test gradtest ((x, W, b) -> σ .(W* x .+ b), 5 , (2 ,5 ), 2 )
1717@test gradtest ((x, W) -> σ .(W* x), 5 , (2 ,5 ))
4545@test gradtest (logdet, map ((x) -> x* x' , (rand (4 , 4 ),))[1 ])
4646@test gradtest ((x) -> logabsdet (x)[1 ], (4 , 4 ))
4747
48+ end # @testset gradtests
49+
4850@testset " indexing & slicing" begin
49- gradtest (x-> view (x, 1 : 2 , 1 : 2 ), rand (4 , 4 ))
51+ @test gradtest (x-> view (x, 1 : 2 , 1 : 2 ), rand (4 , 4 ))
5052end
5153
5254function promotiontest (f, A, B, C)
5355 r0 = f (A, B, C)
5456 r1 = f (param (A), B, C)
5557 r2 = f (A, param (B), C)
56- r3 = f (A, B, param (C))
58+ # r3 = f(A, B, param(C)) # no longer cater to tracked array in 3rd position
5759 r4 = f (param (A), param (B), param (C))
5860
5961 @test ! isa (r0, TrackedArray)
60- @test all (isa .([r1,r2,r3,r4], TrackedArray))
61- @test r1 == r2 == r3 == r4
62+ # @test all(isa.([r1,r2,r3,r4], TrackedArray))
63+ # @test r1 == r2 == r3 == r4
64+ @test all (isa .([r1,r2,r4], TrackedArray))
65+ @test r1 == r2 == r4
6266 @test r0 == Tracker. data (r4)
6367end
6468
6872 rvcat (x... ) = reduce (vcat, x)
6973 rhcat (x... ) = reduce (hcat, x)
7074
71- @testset for vcatf in [vcat, cat1, rvcat]
75+ @testset " 2-arg $vcatf " for vcatf in [vcat, cat1, rvcat]
7276 @test gradtest (vcatf, rand (5 ), rand (3 ))
7377 @test gradtest (vcatf, rand (5 ), rand (3 ), rand (8 ))
7478 @test gradtest (vcatf, rand (5 )' , rand (5 )' )
7983 end
8084
8185
82- @testset for hcatf in [hcat, cat2, rhcat]
86+ @testset " 2-arg $hcatf " for hcatf in [hcat, cat2, rhcat]
8387 @test gradtest (hcatf, rand (5 ), rand (5 ))
8488 @test gradtest (hcatf, rand (5 )' , rand (5 )' )
8589 @test gradtest (hcatf, rand (2 ,5 ), rand (2 ,3 ), rand (2 ,8 ))
8993 @test gradtest (hcatf, rand (5 ), rand (5 ,2 ))
9094end
9195
92- @testset for catf in [vcat, cat1, rvcat, hcat, cat2, rhcat, (x... ) -> cat (x... , dims = 3 ), (x... ) -> cat (x... , dims = (1 ,2 ))]
96+ @testset " 1-arg $catf " for catf in [vcat, cat1, rvcat, hcat, cat2, rhcat, (x... ) -> cat (x... , dims = 3 ), (x... ) -> cat (x... , dims = (1 ,2 ))]
9397 @test gradtest (catf, rand (5 ))
9498 @test gradtest (catf, rand (5 )' )
9599 @test gradtest (catf, rand (2 ,5 ))
133137 @test hcat (1 , param ([1 2 3 ;])) isa TrackedArray
134138 @test vcat (param (1 ), 2 ) isa TrackedArray
135139 end
140+
141+ @testset " ambiguities" begin
142+ @test vcat (param ([1 , 2 , 3 ]), [2 ,3 ]) isa TrackedArray
143+ @test vcat (param ([1 , 2 , 3 ]), [2.0 , 3.0 ]) isa TrackedArray
144+ @test hcat (param ([1 2 3 ]), [2 , 3 ]' ) isa TrackedArray
145+ @test hcat (param ([1 2 3 ]), [2.0 , 3.0 ]' ) isa TrackedArray
146+ end
136147
137148end
138149
141152 @test gradtest (x-> x[z], randn (MersenneTwister (123456 ), 3 ))
142153end
143154
155+ @testset " gradtests 2" begin
156+
144157@test gradtest (x -> permutedims (x, [3 ,1 ,2 ]), rand (4 ,5 ,6 ))
145158@test gradtest (x -> PermutedDimsArray (x, [3 ,1 ,2 ]), rand (4 ,5 ,6 ))
146159
159172@test gradtest (kron, rand (5 ,2 ), rand (3 ,2 ), rand (8 ,2 ))
160173
161174@test gradtest (x -> diagm (0 => x), rand (3 ))
175+ @test gradtest (x -> Matrix (Diagonal (x)), rand (3 ))
162176
163177@test gradtest (W -> inv (log .(W * W)), (5 ,5 ))
164178@test gradtest ((A, B) -> A / B , (1 ,5 ), (5 ,5 ))
178192 gradtest (A -> log .(A * A) \ exp .(B * B), (5 , 5 ))
179193end
180194
195+ end # @testset "gradtests 2"
196+
181197@testset " mean" begin
182198 @test gradtest (mean, rand (2 , 3 ))
183199
208224 @test gradtest (x -> minimum (x, dims= [1 , 2 ]), rand (2 , 3 , 4 ))
209225end
210226
227+ @testset " gradtests 3" begin
228+
211229@test gradtest (x -> std (x), rand (5 ,5 ))
212230@test gradtest (x -> std (x, dims = 1 ), rand (5 ,5 ))
213231@test gradtest (x -> std (x, dims = 1 , corrected = false ), rand (5 ,5 ))
224242 2 y + x
225243end
226244
245+ end # @testset "gradtests 3"
246+
227247@testset " transpose" begin
228248 w = Tracker. TrackedArray (rand (5 ,5 ))
229249 x = Tracker. TrackedArray (rand (5 ,5 ))
@@ -299,17 +319,15 @@ end
299319 @test transpose (w)* transpose (x) isa TrackedArray
300320end
301321
302- @testset " conv" begin
303- for spatial_rank in (1 , 2 , 3 )
322+ @testset " conv, $(spatial_rank) d" for spatial_rank in (1 , 2 , 3 )
304323 x = rand (repeat ([10 ], spatial_rank)... , 3 , 2 )
305324 w = rand (repeat ([3 ], spatial_rank)... , 3 , 3 )
306325 cdims = DenseConvDims (x, w)
307326 @test gradtest ((x, w) -> conv (x, w, cdims), x, w)
308327 y = conv (x, w, cdims)
309328 @test gradtest ((y, w) -> ∇conv_data (y, w, cdims), y, w)
310329 dcdims = DepthwiseConvDims (x, w)
311- @test gradtest ((x, w) -> depthwiseconv (x, w, dcdims), x, w)
312- end
330+ @test_skip gradtest ((x, w) -> depthwiseconv (x, w, dcdims), x, w)
313331end
314332
315333@testset " pooling" begin
321339 end
322340end
323341
324-
325342@test gradtest (x -> Float64 .(x), 5 )
326343
327344@testset " equality & order" begin
480497 @test size (y) == (5 , 3 )
481498end
482499
483- end # overall testset
0 commit comments