@@ -80,21 +80,21 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
8080
8181 @testset " fused rules" begin
8282 @testset " arithmetic" begin
83- test_rrule (copy∘ broadcasted, + , rand (3 ), rand (3 ))
84- test_rrule (copy∘ broadcasted, + , rand (3 ), rand (4 )' )
85- test_rrule (copy∘ broadcasted, + , rand (3 ), rand (1 ), rand ())
86- test_rrule (copy∘ broadcasted, + , rand (3 ), 1.0 * im)
87- test_rrule (copy∘ broadcasted, + , rand (3 ), true )
88- test_rrule (copy∘ broadcasted, + , rand (3 ), Tuple (rand (3 )))
83+ @gpu test_rrule (copy∘ broadcasted, + , rand (3 ), rand (3 ))
84+ @gpu test_rrule (copy∘ broadcasted, + , rand (3 ), rand (4 )' )
85+ @gpu test_rrule (copy∘ broadcasted, + , rand (3 ), rand (1 ), rand ())
86+ @gpu test_rrule (copy∘ broadcasted, + , rand (3 ), 1.0 * im)
87+ @gpu test_rrule (copy∘ broadcasted, + , rand (3 ), true )
88+ @gpu_broken test_rrule (copy∘ broadcasted, + , rand (3 ), Tuple (rand (3 )))
8989
90- test_rrule (copy∘ broadcasted, - , rand (3 ), rand (3 ))
91- test_rrule (copy∘ broadcasted, - , rand (3 ), rand (4 )' )
92- test_rrule (copy∘ broadcasted, - , rand (3 ))
90+ @gpu test_rrule (copy∘ broadcasted, - , rand (3 ), rand (3 ))
91+ @gpu test_rrule (copy∘ broadcasted, - , rand (3 ), rand (4 )' )
92+ @gpu test_rrule (copy∘ broadcasted, - , rand (3 ))
9393 test_rrule (copy∘ broadcasted, - , Tuple (rand (3 )))
9494
95- test_rrule (copy∘ broadcasted, * , rand (3 ), rand (3 ))
96- test_rrule (copy∘ broadcasted, * , rand (3 ), rand ())
97- test_rrule (copy∘ broadcasted, * , rand (), rand (3 ))
95+ @gpu test_rrule (copy∘ broadcasted, * , rand (3 ), rand (3 ))
96+ @gpu test_rrule (copy∘ broadcasted, * , rand (3 ), rand ())
97+ @gpu test_rrule (copy∘ broadcasted, * , rand (), rand (3 ))
9898
9999 test_rrule (copy∘ broadcasted, * , rand (3 ) .+ im, rand (3 ) .+ 2im )
100100 test_rrule (copy∘ broadcasted, * , rand (3 ) .+ im, rand () + 3im )
@@ -107,14 +107,15 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
107107 @test unthunk (bk4 ([4 , 5im , 6 + 7im ])[4 ]) == [0 ,5 ,7 ]
108108
109109 # These two test vararg rrule * rule:
110- test_rrule (copy∘ broadcasted, * , rand (3 ), rand (3 ), rand (3 ), rand (3 ), rand (3 ))
111- test_rrule (copy∘ broadcasted, * , rand (), rand (), rand (3 ), rand (3 ) .+ im, rand (4 )' )
110+ @gpu test_rrule (copy∘ broadcasted, * , rand (3 ), rand (3 ), rand (3 ), rand (3 ), rand (3 ))
111+ @gpu_broken test_rrule (copy∘ broadcasted, * , rand (), rand (), rand (3 ), rand (3 ) .+ im, rand (4 )' )
112+ # GPU error from dot(x::JLArray{Float32, 1}, y::JLArray{ComplexF32, 2})
112113
113- test_rrule (copy∘ broadcasted, Base. literal_pow, ^ , rand (3 ), Val (2 ))
114- test_rrule (copy∘ broadcasted, Base. literal_pow, ^ , rand (3 ) .+ im, Val (2 ))
114+ @gpu test_rrule (copy∘ broadcasted, Base. literal_pow, ^ , rand (3 ), Val (2 ))
115+ @gpu test_rrule (copy∘ broadcasted, Base. literal_pow, ^ , rand (3 ) .+ im, Val (2 ))
115116
116- test_rrule (copy∘ broadcasted, / , rand (3 ), rand ())
117- test_rrule (copy∘ broadcasted, / , rand (3 ) .+ im, rand () + 3im )
117+ @gpu test_rrule (copy∘ broadcasted, / , rand (3 ), rand ())
118+ @gpu test_rrule (copy∘ broadcasted, / , rand (3 ) .+ im, rand () + 3im )
118119 end
119120 @testset " identity etc" begin
120121 test_rrule (copy∘ broadcasted, identity, rand (3 ))
0 commit comments