Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit b8a75e6

Browse files
committed
Fix Halley's method
1 parent 79280f6 commit b8a75e6

File tree

4 files changed

+186
-209
lines changed

4 files changed

+186
-209
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ include("nlsolve/broyden.jl")
3030
include("nlsolve/lbroyden.jl")
3131
include("nlsolve/klement.jl")
3232
include("nlsolve/trustRegion.jl")
33-
# include("nlsolve/halley.jl")
34-
# include("nlsolve/dfsane.jl")
33+
include("nlsolve/halley.jl")
34+
include("nlsolve/dfsane.jl")
3535

3636
## Interval Nonlinear Solvers
3737
include("bracketing/bisection.jl")
@@ -64,7 +64,7 @@ end
6464
prob_no_brack_oop = NonlinearProblem{false}((u, p) -> u .* u .- p,
6565
T.([1.0, 1.0, 1.0]), T(2))
6666

67-
algs = [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(),
67+
algs = [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(), SimpleDFSane(),
6868
SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2)]
6969

7070
@compile_workload begin
@@ -86,9 +86,8 @@ end
8686
end
8787
end
8888

89-
export SimpleBroyden, SimpleGaussNewton, SimpleKlement, SimpleLimitedMemoryBroyden,
90-
SimpleNewtonRaphson, SimpleTrustRegion
91-
# SimpleDFSane, SimpleHalley
89+
export SimpleBroyden, SimpleDFSane, SimpleGaussNewton, SimpleHalley, SimpleKlement,
90+
SimpleLimitedMemoryBroyden, SimpleNewtonRaphson, SimpleTrustRegion
9291
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
9392

9493
end # module

src/nlsolve/dfsane.jl

Lines changed: 86 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -53,117 +53,91 @@ end
5353
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, args...;
5454
abstol = nothing, reltol = nothing, maxiters = 1000,
5555
termination_condition = nothing, kwargs...)
56+
x = float(copy(prob.u0))
57+
fx = _get_fx(prob, x)
58+
T = eltype(x)
5659

57-
# f = isinplace(prob) ? (du, u) -> prob.f(du, u, prob.p) : u -> prob.f(u, prob.p)
58-
59-
# x = float(prob.u0)
60-
# fx = _get_fx(prob, x)
61-
# T = eltype(x)
62-
63-
# σ_min = T(alg.σ_min)
64-
# σ_max = T(alg.σ_max)
65-
# σ_k = T(alg.σ_1)
66-
67-
# M = alg.M
68-
# γ = T(alg.γ)
69-
# τ_min = T(alg.τ_min)
70-
# τ_max = T(alg.τ_max)
71-
# nexp = alg.nexp
72-
# η_strategy = alg.η_strategy
73-
74-
# abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
75-
# termination_condition)
76-
77-
# ff = if isinplace(prob)
78-
# function (_fx, x)
79-
# f(_fx, x)
80-
# f_k = norm(_fx)^nexp
81-
# return f_k, _fx
82-
# end
83-
# else
84-
# function (x)
85-
# _fx = f(x)
86-
# f_k = norm(_fx)^nexp
87-
# return f_k, _fx
88-
# end
89-
# end
90-
91-
# generate_history(f_k, M) = fill(f_k, M)
92-
93-
# f_k, F_k = isinplace(prob) ? ff(fx, x) : ff(x)
94-
# F_k = __copy(F_k)
95-
# α_1 = one(T)
96-
# f_1 = f_k
97-
# history_f_k = generate_history(f_k, M)
98-
99-
# # Generate the cache
100-
# d, xo, x_cache, δx, δf = __copy(x), __copy(x), __copy(x), __copy(x), __copy(x)
101-
# α_tp, α_tm = __copy(x), __copy(x)
102-
103-
# for k in 1:maxiters
104-
# # Spectral parameter range check
105-
# σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
106-
107-
# # Line search direction
108-
# d = __broadcast!!(d, *, -σ_k, F_k)
109-
110-
# η = η_strategy(f_1, k, x, F_k)
111-
# f̄ = maximum(history_f_k)
112-
# α_p = α_1
113-
# α_m = α_1
114-
115-
# x_cache = __broadcast!!(x_cache, *, α_p, d)
116-
# x = __broadcast!!(x, +, x_cache)
117-
118-
# f_new, F_new = isinplace(prob) ? ff(fx, x) : ff(x)
119-
120-
# # FIXME: This part is not correctly implemented
121-
# while true
122-
# criteria = f̄ + η - γ * α_p^2 * f_k
123-
# f_new ≤ criteria && break
124-
125-
# if ArrayInterface.can_setindex(α_tp) && !(x isa Number)
126-
# @. α_tp = α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k)
127-
# else
128-
# α_tp = @. α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k)
129-
# end
130-
# x_cache = __broadcast!!(x_cache, *, α_m, d)
131-
# x = __broadcast!!(x, -, x_cache)
132-
# f_new, F_new = isinplace(prob) ? ff(fx, x) : ff(x)
133-
134-
# f_new ≤ criteria && break
135-
136-
# if ArrayInterface.can_setindex(α_tm) && !(x isa Number)
137-
# @. α_tm = α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k)
138-
# @. α_p = clamp(α_tp, τ_min * α_p, τ_max * α_p)
139-
# @. α_m = clamp(α_tm, τ_min * α_m, τ_max * α_m)
140-
# else
141-
# α_tm = @. α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k)
142-
# α_p = @. clamp(α_tp, τ_min * α_p, τ_max * α_p)
143-
# α_m = @. clamp(α_tm, τ_min * α_m, τ_max * α_m)
144-
# end
145-
# x_cache = __broadcast!!(x_cache, *, α_p, d)
146-
# x = __broadcast!!(x, +, x_cache)
147-
# f_new, F_new = isinplace(prob) ? ff(fx, x) : ff(x)
148-
# end
149-
150-
# tc_sol = check_termination(tc_cache, f_new, x, xo, prob, alg)
151-
# tc_sol !== nothing && return tc_sol
152-
153-
# # Update spectral parameter
154-
# δx = __broadcast!!(δx, -, x, xo)
155-
# δf = __broadcast!!(δf, -, F_new, F_k)
156-
157-
# σ_k = dot(δx, δx) / dot(δx, δf)
158-
159-
# # Take step
160-
# xo = __copyto!!(xo, x)
161-
# F_k = __copyto!!(F_k, F_new)
162-
# f_k = f_new
163-
164-
# # Store function value
165-
# history_f_k[k % M + 1] = f_new
166-
# end
167-
168-
# return build_solution(prob, alg, x, F_k; retcode = ReturnCode.MaxIters)
60+
σ_min = T(alg.σ_min)
61+
σ_max = T(alg.σ_max)
62+
σ_k = T(alg.σ_1)
63+
64+
(; M, nexp, η_strategy) = alg
65+
γ = T(alg.γ)
66+
τ_min = T(alg.τ_min)
67+
τ_max = T(alg.τ_max)
68+
69+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
70+
termination_condition)
71+
72+
fx_norm = norm(fx)^nexp
73+
α_1 = one(T)
74+
f_1 = fx_norm
75+
history_f_k = fill(fx_norm, M)
76+
77+
# Generate the cache
78+
@bb d = copy(x)
79+
@bb xo = copy(x)
80+
@bb x_cache = copy(x)
81+
@bb δx = copy(x)
82+
@bb fxo = copy(fx)
83+
@bb δf = copy(fx)
84+
85+
k = 0
86+
while k < maxiters
87+
# Spectral parameter range check
88+
σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
89+
90+
# Line search direction
91+
@bb @. d = -σ_k * fx
92+
93+
η = η_strategy(f_1, k, x, fx)
94+
f_bar = maximum(history_f_k)
95+
α_p = α_1
96+
α_m = α_1
97+
98+
@bb @. x += α_p * d
99+
100+
fx = __eval_f(prob, fx, x)
101+
fx_norm_new = norm(fx)^nexp
102+
103+
while k < maxiters
104+
fx_norm_new ≤ (f_bar + η - γ * α_p^2 * fx_norm) && break
105+
106+
α_p = α_p^2 * fx_norm / (fx_norm_new + (T(2) * α_p - T(1)) * fx_norm)
107+
@bb @. x -= α_m * d
108+
109+
fx = __eval_f(prob, fx, x)
110+
fx_norm_new = norm(fx)^nexp
111+
112+
fx_norm_new ≤ (f_bar + η - γ * α_m^2 * fx_norm) && break
113+
114+
α_tm = α_m^2 * fx_norm / (fx_norm_new + (T(2) * α_m - T(1)) * fx_norm)
115+
α_p = clamp(α_p, τ_min * α_p, τ_max * α_p)
116+
α_m = clamp(α_tm, τ_min * α_m, τ_max * α_m)
117+
@bb @. x += α_p * d
118+
119+
fx = __eval_f(prob, fx, x)
120+
fx_norm_new = norm(fx)^nexp
121+
end
122+
123+
tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)
124+
tc_sol !== nothing && return tc_sol
125+
126+
# Update spectral parameter
127+
@bb @. δx = x - xo
128+
@bb @. δf = fx - fxo
129+
130+
σ_k = dot(δx, δx) / dot(δx, δf)
131+
132+
# Take step
133+
@bb copyto!(xo, x)
134+
@bb copyto!(fxo, fx)
135+
fx_norm = fx_norm_new
136+
137+
# Store function value
138+
history_f_k[mod1(k, M)] = fx_norm_new
139+
k += 1
140+
end
141+
142+
return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
169143
end

0 commit comments

Comments
 (0)