@@ -2,7 +2,7 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu
        logsigmoid, logcosh, mish, tanhshrink, softshrink, thresholdrelu, trelu, lisht
 
 ## Activation functions
-# 
+#
 # Some of the activation functions have wrapper functions for GPU in CuArrays.jl.
 # https://github.com/JuliaGPU/CuArrays.jl/issues/614
 
@@ -12,15 +12,11 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu
 Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
 function.
 """
-σ(x::Real) = one(x) / (one(x) + exp(-x))
-const sigmoid = σ
-
-# ForwardDiff numerical stability hack
-σ_stable(x::Real) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
-σ(x::Float32) = σ_stable(x)
-@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
-    σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
+function σ(x::Real)
+    t = exp(-abs(x))
+    ifelse(x ≥ 0, inv(one(t) + t), t / (one(t) + t))
 end
+const sigmoid = σ
 
 """
     hardσ(x, a=0.2) = max(0, min(1.0, a * x + 0.5))
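A rough sketch of what the new definition buys, assuming Float32 inputs (the snippet below is illustrative only, not part of the patch): the removed formula evaluates `exp(-x)`, which overflows for large negative `x`, while the replacement only ever exponentiates `-abs(x)`, so its intermediate stays finite.

x = -100f0
exp(-x)                       # Inf32: the intermediate of the removed one(x) / (one(x) + exp(-x))
one(x) / (one(x) + exp(-x))   # 0.0f0, via an Inf intermediate (what the removed ForwardDiff hack guarded against)
t = exp(-abs(x))              # a tiny subnormal Float32, but finite
t / (one(t) + t)              # ≈ t, the value of σ(-100f0), computed without overflow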
@@ -159,17 +155,17 @@ function selu(x::Real)
 end
 
 """
-    celu(x, α=1) = 
+    celu(x, α=1) =
         (x ≥ 0 ? x : α * (exp(x/α) - 1))
 
 Continuously Differentiable Exponential Linear Units
 See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf).
 """
-celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x))) 
+celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))
 
 
 """
-    trelu(x, theta = 1.0) = x > theta ? x : 0 
+    trelu(x, theta = 1.0) = x > theta ? x : 0
 
 Threshold Gated Rectified Linear.
 See [ThresholdRelu](https://arxiv.org/pdf/1402.3337.pdf)
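A small worked sketch of the `celu` definition (the calls below are illustrative, using the default `α` unless stated): the positive branch `x / one(x)` keeps both branches floating point, so an `Int` argument comes back as a `Float64`, and the negative branch matches the docstring formula.

celu(2)           # 2.0           (x ≥ 0 branch: 2 / one(2) == 2.0)
celu(-2)          # ≈ -0.8646647  (== exp(-2) - 1 with the default α = 1)
celu(-2.0, 0.5)   # ≈ -0.4908422  (== 0.5 * (exp(-2 / 0.5) - 1))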
@@ -218,15 +214,15 @@ See [Tanhshrink Activation Function](https://www.gabormelli.com/RKB/Tanhshrink_A
 tanhshrink(x::Real) = x - tanh(x)
 
 """
-    softshrink(x, λ=0.5) = 
+    softshrink(x, λ=0.5) =
         (x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0))
 
 See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_Activation_Function).
 """
 softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)
 
 # Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
+for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
     @eval $(f)(x::AbstractArray, args...) =
         error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
 end
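A quick, illustrative check (assuming the definitions above are in scope): the min/max implementation of `softshrink` reproduces the piecewise formula in its docstring, and the generated `AbstractArray` methods steer callers toward broadcasting.

softshrink(0.7)    # ≈ 0.2   (x ≥ λ branch: x - λ, with the default λ = 0.5)
softshrink(-0.7)   # ≈ -0.2  (-λ ≥ x branch: x + λ)
softshrink(0.3)    # 0.0     (|x| < λ branch)

softshrink.([0.7, -0.7, 0.3])   # broadcasting applies it elementwise
softshrink([0.7, -0.7, 0.3])    # throws the "Use broadcasting" error defined above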