@@ -52,7 +52,8 @@ function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
5252 error (" both scale and bias must be provided or left as nothing" )
5353 end
5454 scale′, bias′ = _maybe_reshape (scale, affine_size), _maybe_reshape (bias, affine_size)
55- return _apply_scale_bias ((x .- μ) ./ sqrt .(σ² .+ ϵ), scale′, bias′)
55+ denom = inv .(sqrt .(σ² .+ ϵ))
56+ return _apply_scale_bias ((x .- μ) .* denom, scale′, bias′)
5657end
5758
5859"""
6162Contains running mean and variance estimates for stateful norm functions.
6263`momentum` controls the strength of the moving average update.
6364
64- If the parameters are mutable, they will be updated in-place.
65- Otherwise, they will be replaced wholesale.
65+ Parameters should be mutable and will be updated in-place.
6666
6767See also [`update_running_stats!`](@ref).
6868"""
69- mutable struct RunningStats{M <: AbstractArray , V <: AbstractArray , MT <: Real }
69+ struct RunningStats{M <: AbstractArray , V <: AbstractArray , MT <: Real }
7070 mean:: M
7171 variance:: V
7272 momentum:: MT
@@ -127,16 +127,9 @@ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Di
127127 correction = m / (m - one (V))
128128
129129 running_mean, running_var = stats. mean, stats. variance
130- if ChainRulesCore. is_inplaceable_destination (running_mean)
131- stats. mean .= res_mtm .* running_mean .+ momentum .* vec (μ)
132- else
133- stats. mean = res_mtm .* running_mean .+ momentum .* vec (μ)
134- end
135- if ChainRulesCore. is_inplaceable_destination (running_var)
136- stats. variance .= res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
137- else
138- stats. variance = res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
139- end
130+ stats. mean .= res_mtm .* running_mean .+ momentum .* vec (μ)
131+ stats. variance .= res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
132+ return
140133end
141134
142135# Convenience functions
@@ -175,7 +168,7 @@ function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias =
175168 throw (DimensionMismatch (" got $S reduction dims for $N -dimensional array" ))
176169 end
177170 μ, σ² = norm_stats (x, ntuple (identity, S))
178- return norm_helper (x, μ, σ², scale, bias, ϵ, size (x)[1 : S])
171+ return norm_helper (x, μ, σ², scale, bias, ϵ, size (x)[1 : S]:: Dims{S} )
179172end
180173
181174"""
0 commit comments