Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 17fdeee

Browse files
Merge pull request #28 from SciML/ChrisRackauckas-patch-1
TrustRegion -> SimpleTrustRegion and specialize the number types
2 parents 2b4cfa5 + 632910b commit 17fdeee

File tree

3 files changed

+53
-50
lines changed

3 files changed

+53
-50
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
3232
solve(prob_no_brack, alg(), abstol = T(1e-2))
3333
end
3434

35-
for alg in (TrustRegion(10.0),)
35+
for alg in (SimpleTrustRegion(10.0),)
3636
solve(prob_no_brack, alg, abstol = T(1e-2))
3737
end
3838

@@ -53,6 +53,6 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
5353
end end
5454

5555
# DiffEq styled algorithms
56-
export Bisection, Broyden, Falsi, Klement, SimpleNewtonRaphson, TrustRegion
56+
export Bisection, Broyden, Falsi, Klement, SimpleNewtonRaphson, SimpleTrustRegion
5757

5858
end # module

src/trustRegion.jl

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
```julia
3-
TrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
3+
SimpleTrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
44
autodiff = Val{true}(), diff_type = Val{:forward})
55
```
66
@@ -49,36 +49,39 @@ solver
4949
- `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
5050
row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
5151
"""
52-
struct TrustRegion{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
53-
max_trust_radius::Number
54-
initial_trust_radius::Number
55-
step_threshold::Number
56-
shrink_threshold::Number
57-
expand_threshold::Number
58-
shrink_factor::Number
59-
expand_factor::Number
52+
struct SimpleTrustRegion{T, CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
53+
max_trust_radius::T
54+
initial_trust_radius::T
55+
step_threshold::T
56+
shrink_threshold::T
57+
expand_threshold::T
58+
shrink_factor::T
59+
expand_factor::T
6060
max_shrink_times::Int
61-
function TrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
62-
autodiff = Val{true}(),
63-
diff_type = Val{:forward},
64-
initial_trust_radius::Number = max_trust_radius / 11,
65-
step_threshold::Number = 0.1,
66-
shrink_threshold::Number = 0.25,
67-
expand_threshold::Number = 0.75,
68-
shrink_factor::Number = 0.25,
69-
expand_factor::Number = 2.0,
70-
max_shrink_times::Int = 32)
71-
new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff),
72-
SciMLBase._unwrap_val(diff_type)}(max_trust_radius, initial_trust_radius,
73-
step_threshold,
74-
shrink_threshold, expand_threshold,
75-
shrink_factor,
76-
expand_factor, max_shrink_times)
61+
function SimpleTrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
62+
autodiff = Val{true}(),
63+
diff_type = Val{:forward},
64+
initial_trust_radius::Number = max_trust_radius / 11,
65+
step_threshold::Number = 0.1,
66+
shrink_threshold::Number = 0.25,
67+
expand_threshold::Number = 0.75,
68+
shrink_factor::Number = 0.25,
69+
expand_factor::Number = 2.0,
70+
max_shrink_times::Int = 32)
71+
new{typeof(initial_trust_radius), SciMLBase._unwrap_val(chunk_size),
72+
SciMLBase._unwrap_val(autodiff), SciMLBase._unwrap_val(diff_type)}(max_trust_radius,
73+
initial_trust_radius,
74+
step_threshold,
75+
shrink_threshold,
76+
expand_threshold,
77+
shrink_factor,
78+
expand_factor,
79+
max_shrink_times)
7780
end
7881
end
7982

8083
function SciMLBase.__solve(prob::NonlinearProblem,
81-
alg::TrustRegion, args...; abstol = nothing,
84+
alg::SimpleTrustRegion, args...; abstol = nothing,
8285
reltol = nothing,
8386
maxiters = 1000, kwargs...)
8487
f = Base.Fix2(prob.f, prob.p)
@@ -94,7 +97,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
9497
max_shrink_times = alg.max_shrink_times
9598

9699
if SciMLBase.isinplace(prob)
97-
error("TrustRegion currently only supports out-of-place nonlinear problems")
100+
error("SimpleTrustRegion currently only supports out-of-place nonlinear problems")
98101
end
99102

100103
atol = abstol !== nothing ? abstol :

test/basictests.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ if VERSION >= v"1.7"
5252
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
5353
end
5454

55-
# TrustRegion
55+
# SimpleTrustRegion
5656
function benchmark_scalar(f, u0)
5757
probN = NonlinearProblem{false}(f, u0)
58-
sol = (solve(probN, TrustRegion(10.0)))
58+
sol = (solve(probN, SimpleTrustRegion(10.0)))
5959
end
6060

6161
sol = benchmark_scalar(sf, csu0)
@@ -69,7 +69,7 @@ using ForwardDiff
6969
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
7070

7171
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
72-
TrustRegion(10.0)]
72+
SimpleTrustRegion(10.0)]
7373
g = function (p)
7474
probN = NonlinearProblem{false}(f, csu0, p)
7575
sol = solve(probN, alg, abstol = 1e-9)
@@ -85,7 +85,7 @@ end
8585
# Scalar
8686
f, u0 = (u, p) -> u * u - p, 1.0
8787
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
88-
TrustRegion(10.0)]
88+
SimpleTrustRegion(10.0)]
8989
g = function (p)
9090
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
9191
sol = solve(probN, alg)
@@ -127,7 +127,7 @@ for alg in [Bisection(), Falsi()]
127127
end
128128

129129
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
130-
TrustRegion(10.0)]
130+
SimpleTrustRegion(10.0)]
131131
global g, p
132132
g = function (p)
133133
probN = NonlinearProblem{false}(f, 0.5, p)
@@ -144,8 +144,8 @@ probN = NonlinearProblem(f, u0)
144144

145145
@test solve(probN, SimpleNewtonRaphson()).u[end] sqrt(2.0)
146146
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)
147-
@test solve(probN, TrustRegion(10.0)).u[end] sqrt(2.0)
148-
@test solve(probN, TrustRegion(10.0; autodiff = false)).u[end] sqrt(2.0)
147+
@test solve(probN, SimpleTrustRegion(10.0)).u[end] sqrt(2.0)
148+
@test solve(probN, SimpleTrustRegion(10.0; autodiff = false)).u[end] sqrt(2.0)
149149
@test solve(probN, Broyden()).u[end] sqrt(2.0)
150150
@test solve(probN, Klement()).u[end] sqrt(2.0)
151151

@@ -159,9 +159,9 @@ for u0 in [1.0, [1, 1.0]]
159159
@test solve(probN, SimpleNewtonRaphson()).u sol
160160
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u sol
161161

162-
@test solve(probN, TrustRegion(10.0)).u sol
163-
@test solve(probN, TrustRegion(10.0)).u sol
164-
@test solve(probN, TrustRegion(10.0; autodiff = false)).u sol
162+
@test solve(probN, SimpleTrustRegion(10.0)).u sol
163+
@test solve(probN, SimpleTrustRegion(10.0)).u sol
164+
@test solve(probN, SimpleTrustRegion(10.0; autodiff = false)).u sol
165165

166166
@test solve(probN, Broyden()).u sol
167167

@@ -205,7 +205,7 @@ sol = solve(probB, Bisection(; exact_left = true, exact_right = true); immutable
205205
@test f(sol.right, nothing) >= 0.0
206206
@test f(prevfloat(sol.right), nothing) <= 0.0
207207

208-
# Test that `TrustRegion` passes a test that `SimpleNewtonRaphson` fails on.
208+
# Test that `SimpleTrustRegion` passes a test that `SimpleNewtonRaphson` fails on.
209209
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
210210
global g, f
211211
f = (u, p) -> 0.010000000000000002 .+
@@ -219,15 +219,15 @@ f = (u, p) -> 0.010000000000000002 .+
219219
.-p
220220
g = function (p)
221221
probN = NonlinearProblem{false}(f, u0, p)
222-
sol = solve(probN, TrustRegion(100.0))
222+
sol = solve(probN, SimpleTrustRegion(100.0))
223223
return sol.u
224224
end
225225
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
226226
u = g(p)
227227
f(u, p)
228228
@test all(f(u, p) .< 1e-10)
229229

230-
# Test kwars in `TrustRegion`
230+
# Test kwars in `SimpleTrustRegion`
231231
max_trust_radius = [10.0, 100.0, 1000.0]
232232
initial_trust_radius = [10.0, 1.0, 0.1]
233233
step_threshold = [0.0, 0.01, 0.25]
@@ -242,14 +242,14 @@ list_of_options = zip(max_trust_radius, initial_trust_radius, step_threshold,
242242
expand_factor, max_shrink_times)
243243
for options in list_of_options
244244
local probN, sol, alg
245-
alg = TrustRegion(options[1];
246-
initial_trust_radius = options[2],
247-
step_threshold = options[3],
248-
shrink_threshold = options[4],
249-
expand_threshold = options[5],
250-
shrink_factor = options[6],
251-
expand_factor = options[7],
252-
max_shrink_times = options[8])
245+
alg = SimpleTrustRegion(options[1];
246+
initial_trust_radius = options[2],
247+
step_threshold = options[3],
248+
shrink_threshold = options[4],
249+
expand_threshold = options[5],
250+
shrink_factor = options[6],
251+
expand_factor = options[7],
252+
max_shrink_times = options[8])
253253

254254
probN = NonlinearProblem(f, u0, p)
255255
sol = solve(probN, alg)

0 commit comments

Comments
 (0)