Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 9f31a04

Browse files
committed
Add tests for the nonlinear solvers
1 parent b8a75e6 commit 9f31a04

File tree

8 files changed

+445
-599
lines changed

8 files changed

+445
-599
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,19 @@ end
6767
algs = [SimpleNewtonRaphson(), SimpleBroyden(), SimpleKlement(), SimpleDFSane(),
6868
SimpleTrustRegion(), SimpleLimitedMemoryBroyden(; threshold = 2)]
6969

70+
algs_no_iip = [SimpleHalley()]
71+
7072
@compile_workload begin
7173
for alg in algs
7274
solve(prob_no_brack_scalar, alg, abstol = T(1e-2))
7375
solve(prob_no_brack_iip, alg, abstol = T(1e-2))
7476
solve(prob_no_brack_oop, alg, abstol = T(1e-2))
7577
end
78+
79+
for alg in algs_no_iip
80+
solve(prob_no_brack_scalar, alg, abstol = T(1e-2))
81+
solve(prob_no_brack_oop, alg, abstol = T(1e-2))
82+
end
7683
end
7784

7885
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p,

src/ad.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,27 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:Abstr
4747
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
4848
end
4949

50+
function scalar_nlsolve_∂f_∂p(f, u, p)
51+
ff = p isa Number ? ForwardDiff.derivative :
52+
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
53+
return ff(Base.Fix1(f, u), p)
54+
end
55+
56+
function scalar_nlsolve_∂f_∂u(f, u, p)
57+
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
58+
return ff(Base.Fix2(f, p), u)
59+
end
60+
61+
function scalar_nlsolve_dual_soln(u::Number, partials,
62+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
63+
return Dual{T, V, P}(u, partials)
64+
end
65+
66+
function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
67+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
68+
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
69+
end
70+
5071
# avoid ambiguities
5172
for Alg in [Bisection]
5273
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ function __init_identity_jacobian(u::StaticArray, fu)
207207
J = SMatrix{S1, S2, eltype(u)}(I)
208208
return J
209209
end
210-
function __init_identity_jacobian!!(J::StaticArray{S1, S2}) where {S1, S2}
211-
return SMMatrix{S1, S2, eltype(J)}(I)
210+
function __init_identity_jacobian!!(J::SMatrix{S1, S2}) where {S1, S2}
211+
return SMatrix{S1, S2, eltype(J)}(I)
212212
end
213213

214214
function __init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2},

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
33
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
44
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
55
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6-
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
6+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
77
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
910
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)