Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 39657a4

Browse files
committed
Fix Limited Memory Broyden
1 parent c7d01d0 commit 39657a4

File tree

6 files changed

+167
-187
lines changed

6 files changed

+167
-187
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,5 @@ For more details on the bracketing methods, refer to the [Tutorials](https://doc
5050
- `Broyden` and `Klement` have been renamed to `SimpleBroyden` and `SimpleKlement` to
5151
avoid conflicts with `NonlinearSolve.jl`'s `GeneralBroyden` and `GeneralKlement`, which
5252
will be renamed to `Broyden` and `Klement` in the future.
53+
- `LBroyden` has been renamed to `SimpleLimitedMemoryBroyden` to make it consistent with
54+
`NonlinearSolve.jl`'s `LimitedMemoryBroyden`.

src/SimpleNonlinearSolve.jl

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
1313
import ForwardDiff: Dual
1414
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
1515
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace
16-
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray
16+
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size
1717
end
1818

1919
@reexport using ADTypes, SciMLBase
@@ -24,16 +24,16 @@ abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm e
2424

2525
include("utils.jl")
2626

27-
# Nonlinear Solvera
27+
## Nonlinear Solvers
2828
include("nlsolve/raphson.jl")
2929
include("nlsolve/broyden.jl")
30-
# include("nlsolve/lbroyden.jl")
30+
include("nlsolve/lbroyden.jl")
3131
include("nlsolve/klement.jl")
3232
include("nlsolve/trustRegion.jl")
3333
# include("nlsolve/halley.jl")
3434
# include("nlsolve/dfsane.jl")
3535

36-
# Interval Nonlinear Solvers
36+
## Interval Nonlinear Solvers
3737
include("bracketing/bisection.jl")
3838
include("bracketing/falsi.jl")
3939
include("bracketing/ridder.jl")
@@ -42,7 +42,7 @@ include("bracketing/alefeld.jl")
4242
include("bracketing/itp.jl")
4343

4444
# AD
45-
# include("ad.jl")
45+
include("ad.jl")
4646

4747
## Default algorithm
4848

@@ -58,34 +58,22 @@ end
5858

5959
@setup_workload begin
6060
for T in (Float32, Float64)
61-
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
62-
algs = [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(),
63-
SimpleTrustRegion()]
64-
65-
@compile_workload begin
66-
for alg in algs
67-
solve(prob_no_brack, alg, abstol = T(1e-2))
68-
end
69-
end
61+
prob_no_brack_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
62+
prob_no_brack_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p,
63+
T.([1.0, 1.0, 1.0]), T(2))
64+
prob_no_brack_oop = NonlinearProblem{false}((u, p) -> u .* u .- p,
65+
T.([1.0, 1.0, 1.0]), T(2))
7066

71-
prob_no_brack = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p,
72-
T.([1.0, 1.0]), T(2))
67+
algs = [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(),
68+
SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2)]
7369

7470
@compile_workload begin
7571
for alg in algs
76-
solve(prob_no_brack, alg, abstol = T(1e-2))
77-
end
78-
end
79-
80-
#=
81-
for alg in (SimpleNewtonRaphson,)
82-
for u0 in ([1., 1.], StaticArraysCore.SA[1.0, 1.0])
83-
u0 = T.(.1)
84-
probN = NonlinearProblem{false}((u,p) -> u .* u .- p, u0, T(2))
85-
solve(probN, alg(), tol = T(1e-2))
72+
solve(prob_no_brack_scalar, alg, abstol = T(1e-2))
73+
solve(prob_no_brack_iip, alg, abstol = T(1e-2))
74+
solve(prob_no_brack_oop, alg, abstol = T(1e-2))
8675
end
8776
end
88-
=#
8977

9078
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p,
9179
T.((0.0, 2.0)), T(2))
@@ -98,9 +86,9 @@ end
9886
end
9987
end
10088

101-
export SimpleBroyden,
102-
SimpleGaussNewton, SimpleKlement, SimpleNewtonRaphson, SimpleTrustRegion
103-
# SimpleDFSane, SimpleHalley, LBroyden
89+
export SimpleBroyden, SimpleGaussNewton, SimpleKlement, SimpleLimitedMemoryBroyden,
90+
SimpleNewtonRaphson, SimpleTrustRegion
91+
# SimpleDFSane, SimpleHalley
10492
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
10593

10694
end # module

src/ad.jl

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
22
f = prob.f
33
p = value(prob.p)
4-
4+
u0 = value(prob.u0)
55
if prob isa IntervalNonlinearProblem
66
tspan = value(prob.tspan)
77
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
@@ -13,66 +13,57 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
1313
sol = solve(newprob, alg, args...; kwargs...)
1414

1515
uu = sol.u
16-
if p isa Number
17-
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
18-
else
19-
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
20-
end
16+
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
17+
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
18+
19+
z_arr = -inv(f_x) * f_p
2120

22-
f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
2321
pp = prob.p
24-
sumfun = let f_x′ = -f_x
25-
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
22+
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
23+
if uu isa Number
24+
partials = sum(sumfun, zip(z_arr, pp))
25+
elseif p isa Number
26+
partials = sumfun((z_arr, pp))
27+
else
28+
partials = sum(sumfun, zip(eachcol(z_arr), pp))
2629
end
27-
partials = sum(sumfun, zip(f_p, pp))
30+
2831
return sol, partials
2932
end
3033

31-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
32-
iip,
33-
<:Dual{T, V, P}},
34-
alg::AbstractSimpleNonlinearSolveAlgorithm,
35-
args...; kwargs...) where {iip, T, V, P}
34+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
35+
false, <:Dual{T, V, P}}, alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
36+
kwargs...) where {T, V, P}
3637
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
37-
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
38-
retcode = sol.retcode)
38+
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
39+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
3940
end
40-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
41-
iip,
42-
<:AbstractArray{<:Dual{T, V, P}}},
43-
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
44-
kwargs...) where {iip, T, V, P}
41+
42+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
43+
false, <:AbstractArray{<:Dual{T, V, P}}},
44+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P}
4545
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
46-
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
47-
retcode = sol.retcode)
46+
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
47+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
4848
end
4949

5050
# avoid ambiguities
5151
for Alg in [Bisection]
5252
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
53-
<:Dual{T, V, P}},
54-
alg::$Alg, args...;
55-
kwargs...) where {uType, iip, T, V, P}
53+
<:Dual{T, V, P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
5654
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
57-
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
58-
sol.resid; retcode = sol.retcode,
55+
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
56+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
5957
left = Dual{T, V, P}(sol.left, partials),
6058
right = Dual{T, V, P}(sol.right, partials))
61-
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
6259
end
6360
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
64-
<:AbstractArray{
65-
<:Dual{T,
66-
V,
67-
P},
68-
}},
69-
alg::$Alg, args...;
61+
<:AbstractArray{<:Dual{T, V, P}}}, alg::$Alg, args...;
7062
kwargs...) where {uType, iip, T, V, P}
7163
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
72-
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
73-
sol.resid; retcode = sol.retcode,
64+
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
65+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
7466
left = Dual{T, V, P}(sol.left, partials),
7567
right = Dual{T, V, P}(sol.right, partials))
76-
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
7768
end
7869
end

0 commit comments

Comments
 (0)