11function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
22 f = prob. f
33 p = value (prob. p)
4-
4+ u0 = value (prob . u0)
55 if prob isa IntervalNonlinearProblem
66 tspan = value (prob. tspan)
77 newprob = IntervalNonlinearProblem (f, tspan, p; prob. kwargs... )
@@ -13,66 +13,57 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
1313 sol = solve (newprob, alg, args... ; kwargs... )
1414
1515 uu = sol. u
16- if p isa Number
17- f_p = ForwardDiff. derivative (Base. Fix1 (f, uu), p)
18- else
19- f_p = ForwardDiff. gradient (Base. Fix1 (f, uu), p)
20- end
16+ f_p = scalar_nlsolve_∂f_∂p (f, uu, p)
17+ f_x = scalar_nlsolve_∂f_∂u (f, uu, p)
18+
19+ z_arr = - inv (f_x) * f_p
2120
22- f_x = ForwardDiff. derivative (Base. Fix2 (f, p), uu)
2321 pp = prob. p
24- sumfun = let f_x′ = - f_x
25- ((fp, p),) -> (fp / f_x′) * ForwardDiff. partials (p)
22+ sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
23+ if uu isa Number
24+ partials = sum (sumfun, zip (z_arr, pp))
25+ elseif p isa Number
26+ partials = sumfun ((z_arr, pp))
27+ else
28+ partials = sum (sumfun, zip (eachcol (z_arr), pp))
2629 end
27- partials = sum (sumfun, zip (f_p, pp))
30+
2831 return sol, partials
2932end
3033
31- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, StaticArraysCore.SVector} ,
32- iip,
33- <: Dual{T, V, P} },
34- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
35- args... ; kwargs... ) where {iip, T, V, P}
34+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
35+ false , <: Dual{T, V, P} }, alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ;
36+ kwargs... ) where {T, V, P}
3637 sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
37- return SciMLBase . build_solution (prob, alg, Dual {T, V, P} (sol. u, partials), sol . resid;
38- retcode = sol. retcode)
38+ dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob . p)
39+ return SciMLBase . build_solution (prob, alg, dual_soln, sol . resid; sol. retcode)
3940end
40- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, StaticArraysCore.SVector} ,
41- iip,
42- <: AbstractArray{<:Dual{T, V, P}} },
43- alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ;
44- kwargs... ) where {iip, T, V, P}
41+
42+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
43+ false , <: AbstractArray{<:Dual{T, V, P}} },
44+ alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P}
4545 sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
46- return SciMLBase . build_solution (prob, alg, Dual {T, V, P} (sol. u, partials), sol . resid;
47- retcode = sol. retcode)
46+ dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob . p)
47+ return SciMLBase . build_solution (prob, alg, dual_soln, sol . resid; sol. retcode)
4848end
4949
5050# avoid ambiguities
5151for Alg in [Bisection]
5252 @eval function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
53- <: Dual{T, V, P} },
54- alg:: $Alg , args... ;
55- kwargs... ) where {uType, iip, T, V, P}
53+ <: Dual{T, V, P} }, alg:: $Alg , args... ; kwargs... ) where {uType, iip, T, V, P}
5654 sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
57- return SciMLBase . build_solution (prob, alg, Dual {T, V, P} (sol. u, partials),
58- sol. resid; retcode = sol. retcode,
55+ dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob . p)
56+ return SciMLBase . build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
5957 left = Dual {T, V, P} (sol. left, partials),
6058 right = Dual {T, V, P} (sol. right, partials))
61- # return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
6259 end
6360 @eval function SciMLBase. solve (prob:: IntervalNonlinearProblem {uType, iip,
64- <: AbstractArray {
65- <: Dual {T,
66- V,
67- P},
68- }},
69- alg:: $Alg , args... ;
61+ <: AbstractArray{<:Dual{T, V, P}} }, alg:: $Alg , args... ;
7062 kwargs... ) where {uType, iip, T, V, P}
7163 sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
72- return SciMLBase . build_solution (prob, alg, Dual {T, V, P} (sol. u, partials),
73- sol. resid; retcode = sol. retcode,
64+ dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob . p)
65+ return SciMLBase . build_solution (prob, alg, dual_soln, sol. resid; sol. retcode,
7466 left = Dual {T, V, P} (sol. left, partials),
7567 right = Dual {T, V, P} (sol. right, partials))
76- # return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
7768 end
7869end
0 commit comments