@@ -59,16 +59,16 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
5959 end
6060 return Ω, bias_act!_fastback
6161
62- # Slower path: can't overwrite x, but can use derivatives_given_output
63- # This case is WRONG and tests fail, but not sure why
64- elseif isconcretetype (Core. Compiler. _return_type (only_derivative, Tuple{T, F, T}))
65- Ω2 = fast_act (σ, x).(x) .+ b
66- @show σ b
67- function bias_act!_back2 (Δ)
68- dx = only_derivative .(Ω2, σ, x .+ b) .* unthunk (Δ)
69- return (NoTangent (), NoTangent (), dx, biasgrad (dx))
70- end
71- return Ω2, bias_act!_back2
62+ # # Slower path: can't overwrite x, but can use derivatives_given_output
63+ # # This case is WRONG and tests fail, but not sure why
64+ # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
65+ # Ω2 = fast_act(σ, x).(x) .+ b
66+ # @show σ b
67+ # function bias_act!_back2(Δ)
68+ # dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
69+ # return (NoTangent(), NoTangent(), dx, biasgrad(dx))
70+ # end
71+ # return Ω2, bias_act!_back2
7272
7373 # Fallback path: let AD handle the broadcast
7474 else
@@ -96,154 +96,3 @@ function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArr
9696 return x, bias_act!_trivial
9797end
9898
99-
100-
101- # """
102- # add_act(σ, x, y...)
103- # add_act!(σ, x, y, z...)
104-
105- # Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
106- # """
107- # add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...)) # fused
108-
109-
110- # function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
111- # if isconcretetype(Core.Compiler._return_type(
112- # derivatives_given_output, Tuple{T, F, NotaNumber}))
113-
114- # end
115-
116-
117- # bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
118- # # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
119- # (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))
120-
121-
122- # using NNlib, BenchmarkTools
123-
124- #=
125-
126- ## M1 mac, 1.10
127-
128- julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);
129-
130- julia> @btime bias_act!(relu, $w, $b);
131- min 19.500 μs, mean 21.375 μs (0 allocations)
132-
133- julia> @btime relu.($w .+ $b);
134- min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)
135-
136- julia> @btime bias_act!(tanh, $w, $b);
137- min 63.792 μs, mean 65.052 μs (0 allocations)
138-
139- julia> @btime tanh_fast.($w .+ $b);
140- min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)
141-
142- julia> using Zygote
143-
144- julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
145- min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)
146-
147- julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
148- min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)
149-
150- julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
151- min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)
152-
153- julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
154- min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)
155-
156-
157-
158- ## Cyclops
159-
160- julia> using CUDA # 10x bigger
161-
162- julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);
163-
164- julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
165- 22.546 μs (27 allocations: 1.45 KiB)
166-
167- julia> @btime CUDA.@sync relu.($cw .+ $cb); # faster, that's odd?
168- 31.282 μs (38 allocations: 1.81 KiB)
169-
170- julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
171- 27.030 μs (27 allocations: 1.45 KiB)
172-
173- julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
174- 36.421 μs (38 allocations: 1.81 KiB)
175-
176- julia> using Zygote
177-
178- julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
179- 204.507 μs (382 allocations: 18.15 KiB)
180-
181- julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
182- 204.458 μs (409 allocations: 19.19 KiB)
183-
184- julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
185- 224.545 μs (382 allocations: 18.15 KiB)
186-
187- julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
188- 204.793 μs (411 allocations: 19.30 KiB)
189-
190-
191- =#
192-
193- #=
194-
195- (jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23
196-
197- julia> using NNlib, Zygote, BenchmarkTools
198-
199- julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);
200-
201- julia> @btime bias_act!(relu, $w * $x, $b);
202- min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)
203-
204- julia> @btime relu.($w * $x .+ $b);
205- min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)
206-
207- julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
208- min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)
209-
210- julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
211- min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)
212-
213- julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
214- min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)
215-
216- julia> @btime gradient(x -> sum(abs2, x), $x);
217- min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)
218-
219-
220- # Cyclops
221-
222- julia> @btime bias_act!(relu, $w * $x, $b);
223- 24.786 μs (2 allocations: 19.61 KiB)
224-
225- julia> @btime relu.($w * $x .+ $b);
226- 25.501 μs (4 allocations: 39.22 KiB)
227-
228- julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
229- 91.847 μs (43 allocations: 89.83 KiB)
230-
231- julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
232- 98.054 μs (41 allocations: 128.91 KiB)
233-
234- julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
235- 80.464 μs (28 allocations: 69.41 KiB)
236-
237- julia> @btime gradient(x -> sum(abs2, x), $x);
238- 4.604 μs (2 allocations: 19.61 KiB)
239-
240- julia> @time using CUDA; @time cu(ones(3)) .+ 1;
241-
242- julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);
243-
244-
245-
246- =#
247-
248-
249-
0 commit comments