@@ -39,9 +39,12 @@ bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
 function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
-    if eltype(B) !== Bool
-        b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
-        size_b = size(b)
+    biasgrad = if eltype(B) !== Bool
+        # Summing over ndims(x)+1 is a no-op; the trick makes `dims` type-stable
+        dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
+        _biasgrad(dx) = reshape(sum(dx; dims), size(b))
+    else
+        Returns(NoTangent())
     end
 
     # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
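An aside on the `biasgrad` branch above: Julia's reducers accept dimensions larger than `ndims(x)` and simply ignore them, so dimensions along which `b` is not broadcast map to the dummy `N+1` and the result is always an `NTuple{N,Int}`, which is what keeps the sum type-stable. A quick REPL sketch of the idea (illustrative sizes, not from this commit):

julia> x = randn(Float32, 100, 1000); b = randn(Float32, 100); N = ndims(x);

julia> dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)   # size(b,1) == 100, so dim 1 maps to N+1
(3, 2)

julia> size(reshape(sum(x; dims), size(b)))   # sums over dim 2; dim 3 is a no-op
(100,)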
@@ -52,50 +55,195 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
             # TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340
             # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
             dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
-            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-            return (NoTangent(), NoTangent(), dx, db)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
         end
         return Ω, bias_act!_fastback
 
-    # # Slower path: can't overwrite x, but can use derivatives_given_output
-    # # This case is WRONG and tests fail, but not sure why
-    # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
-    #     Ω2 = fast_act(σ, x).(x) .+ b
-    #     @show σ b
-    #     function bias_act!_back2(Δ)
-    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
-    #         db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-    #         return (NoTangent(), NoTangent(), dx, db)
-    #     end
-    #     return Ω2, bias_act!_back2
+    # Slower path: can't overwrite x, but can use derivatives_given_output
+    # This case is WRONG and tests fail, but not sure why
+    elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
+        Ω2 = fast_act(σ, x).(x) .+ b
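+        # NB: as written this computes σ.(x) .+ b, not σ.(x .+ b), which may be why the tests fail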
+        @show σ b
+        function bias_act!_back2(Δ)
+            dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
+        end
+        return Ω2, bias_act!_back2
 
     # Fallback path: let AD handle the broadcast
     else
         Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
         @inline function bias_act!_slowback(Δ)
             _, _, dx = back(Δ)
-            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-            return (NoTangent(), NoTangent(), dx, db)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
         end
         return Ω3, bias_act!_slowback
     end
 end
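Why the fast path above may overwrite `x`: for activations such as `relu` and `tanh`, the derivative can be recovered from the output Ω alone, which is what probing `only_derivative` with `NotaNumber()` detects at compile time. A conceptual check in plain Julia (a sketch, separate from the rrule machinery):

julia> Ω = tanh(0.7f0);

julia> 1 - Ω^2 ≈ 1/cosh(0.7f0)^2   # tanh′(x) == 1 - tanh(x)^2, so x itself is not needed
true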
-# Two easy cases
+# Two easy cases with identity
 function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
-    b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
-    size_b = size(b)
+    dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
+    biasgrad(dx) = reshape(sum(dx; dims), size(b))
     function bias_act!_idback(Δ)
         dx = unthunk(Δ)
-        db = reshape(sum(dx; dims = b_dims), size_b)
-        return (NoTangent(), NoTangent(), dx, db)
+        return (NoTangent(), NoTangent(), dx, biasgrad(dx))
     end
     return bias_act!(identity, x, b), bias_act!_idback
 end
-
 function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
     bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
     return x, bias_act!_trivial
 end
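A quick sanity check of the dense-bias identity method (a hypothetical session assuming Zygote, not output from this commit):

julia> using NNlib, Zygote

julia> x, b = rand(Float32, 3, 4), rand(Float32, 3);

julia> gradient((x,b) -> sum(bias_act!(identity, x, b)), x, b)[2] == [4f0, 4f0, 4f0]   # each bias entry feeds 4 columns
true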
+
+# """
+#     add_act(σ, x, y...)
+#     add_act!(σ, x, y, z...)
+
+# Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!` overwrites `x`.
+# """
+# add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...))  # fused
+
+
+# function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
+#     if isconcretetype(Core.Compiler._return_type(
+#             derivatives_given_output, Tuple{T, F, NotaNumber}))
+
+# end
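The point of the commented-out `add_act` draft is the fused n-ary broadcast: `σ.(.+(x, yz...))` applies the sum and the activation in a single kernel. In plain Julia (illustrative only):

julia> x, y, z = ones(2), fill(2.0, 2), fill(3.0, 2);

julia> identity.(.+(x, y, z)) == x .+ y .+ z
true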
+
+
+# bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
+# #     b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
+#     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))
+
+
+# using NNlib, BenchmarkTools
+
124+ #=
125+
126+ ## M1 mac, 1.10
127+
128+ julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);
129+
130+ julia> @btime bias_act!(relu, $w, $b);
131+ min 19.500 μs, mean 21.375 μs (0 allocations)
132+
133+ julia> @btime relu.($w .+ $b);
134+ min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)
135+
136+ julia> @btime bias_act!(tanh, $w, $b);
137+ min 63.792 μs, mean 65.052 μs (0 allocations)
138+
139+ julia> @btime tanh_fast.($w .+ $b);
140+ min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)
141+
142+ julia> using Zygote
143+
144+ julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
145+ min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)
146+
147+ julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
148+ min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)
149+
150+ julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
151+ min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)
152+
153+ julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
154+ min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)
155+
156+
157+
158+ ## Cyclops
159+
160+ julia> using CUDA # 10x bigger
161+
162+ julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);
163+
164+ julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
165+ 22.546 μs (27 allocations: 1.45 KiB)
166+
167+ julia> @btime CUDA.@sync relu.($cw .+ $cb); # faster, that's odd?
168+ 31.282 μs (38 allocations: 1.81 KiB)
169+
170+ julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
171+ 27.030 μs (27 allocations: 1.45 KiB)
172+
173+ julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
174+ 36.421 μs (38 allocations: 1.81 KiB)
175+
176+ julia> using Zygote
177+
178+ julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
179+ 204.507 μs (382 allocations: 18.15 KiB)
180+
181+ julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
182+ 204.458 μs (409 allocations: 19.19 KiB)
183+
184+ julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
185+ 224.545 μs (382 allocations: 18.15 KiB)
186+
187+ julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
188+ 204.793 μs (411 allocations: 19.30 KiB)
189+
190+
191+ =#
+
+#=
+
+(jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23
+
+julia> using NNlib, Zygote, BenchmarkTools
+
+julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);
+
+julia> @btime bias_act!(relu, $w * $x, $b);
+  min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)
+
+julia> @btime relu.($w * $x .+ $b);
+  min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
+  min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
+  min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)
+
+julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
+  min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)
+
+julia> @btime gradient(x -> sum(abs2, x), $x);
+  min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)
+
+
+# Cyclops
+
+julia> @btime bias_act!(relu, $w * $x, $b);
+  24.786 μs (2 allocations: 19.61 KiB)
+
+julia> @btime relu.($w * $x .+ $b);
+  25.501 μs (4 allocations: 39.22 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
+  91.847 μs (43 allocations: 89.83 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
+  98.054 μs (41 allocations: 128.91 KiB)
+
+julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
+  80.464 μs (28 allocations: 69.41 KiB)
+
+julia> @btime gradient(x -> sum(abs2, x), $x);
+  4.604 μs (2 allocations: 19.61 KiB)
+
+julia> @time using CUDA; @time cu(ones(3)) .+ 1;
+
+julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);
+
+=#