@@ -5,8 +5,17 @@ function SciMLBase.solve(
55 sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
66 dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
77 return SciMLBase. build_solution (
8- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
9- sol. original)
8+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
9+ end
10+
11+ function SciMLBase. solve (
12+ prob:: NonlinearLeastSquaresProblem {<: AbstractArray ,
13+ iip, <: Union{<:AbstractArray{<:Dual{T, V, P}}} },
14+ alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P, iip}
15+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
16+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
17+ return SciMLBase. build_solution (
18+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1019end
1120
1221for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -24,7 +33,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
2433 end
2534end
2635
27- function __nlsolve_ad (prob, alg, args... ; kwargs... )
36+ function __nlsolve_ad (
37+ prob:: Union{IntervalNonlinearProblem, NonlinearProblem} , alg, args... ; kwargs... )
2838 p = value (prob. p)
2939 if prob isa IntervalNonlinearProblem
3040 tspan = value .(prob. tspan)
@@ -55,6 +65,96 @@ function __nlsolve_ad(prob, alg, args...; kwargs...)
5565 return sol, partials
5666end
5767
68+ function __nlsolve_ad (prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs... )
69+ p = value (prob. p)
70+ u0 = value (prob. u0)
71+ newprob = NonlinearLeastSquaresProblem (prob. f, u0, p; prob. kwargs... )
72+
73+ sol = solve (newprob, alg, args... ; kwargs... )
74+
75+ uu = sol. u
76+
77+ # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
78+ # nested autodiff as the last resort
79+ if SciMLBase. has_vjp (prob. f)
80+ if isinplace (prob)
81+ _F = @closure (du, u, p) -> begin
82+ resid = similar (du, length (sol. resid))
83+ prob. f (resid, u, p)
84+ prob. f. vjp (du, resid, u, p)
85+ du .*= 2
86+ return nothing
87+ end
88+ else
89+ _F = @closure (u, p) -> begin
90+ resid = prob. f (u, p)
91+ return reshape (2 .* prob. f. vjp (resid, u, p), size (u))
92+ end
93+ end
94+ elseif SciMLBase. has_jac (prob. f)
95+ if isinplace (prob)
96+ _F = @closure (du, u, p) -> begin
97+ J = similar (du, length (sol. resid), length (u))
98+ prob. f. jac (J, u, p)
99+ resid = similar (du, length (sol. resid))
100+ prob. f (resid, u, p)
101+ mul! (reshape (du, 1 , :), vec (resid)' , J, 2 , false )
102+ return nothing
103+ end
104+ else
105+ _F = @closure (u, p) -> begin
106+ return reshape (2 .* vec (prob. f (u, p))' * prob. f. jac (u, p), size (u))
107+ end
108+ end
109+ else
110+ if isinplace (prob)
111+ _F = @closure (du, u, p) -> begin
112+ resid = similar (du, length (sol. resid))
113+ res = DiffResults. DiffResult (
114+ resid, similar (du, length (sol. resid), length (u)))
115+ _f = @closure (du, u) -> prob. f (du, u, p)
116+ ForwardDiff. jacobian! (res, _f, resid, u)
117+ mul! (reshape (du, 1 , :), vec (DiffResults. value (res))' ,
118+ DiffResults. jacobian (res), 2 , false )
119+ return nothing
120+ end
121+ else
122+ # For small problems, nesting ForwardDiff is actually quite fast
123+ if __is_extension_loaded (Val (:Zygote )) && (length (uu) + length (sol. resid) ≥ 50 )
124+ _F = @closure (u, p) -> __zygote_compute_nlls_vjp (prob. f, u, p)
125+ else
126+ _F = @closure (u, p) -> begin
127+ T = promote_type (eltype (u), eltype (p))
128+ res = DiffResults. DiffResult (
129+ similar (u, T, size (sol. resid)), similar (
130+ u, T, length (sol. resid), length (u)))
131+ ForwardDiff. jacobian! (res, Base. Fix2 (prob. f, p), u)
132+ return reshape (
133+ 2 .* vec (DiffResults. value (res))' * DiffResults. jacobian (res),
134+ size (u))
135+ end
136+ end
137+ end
138+ end
139+
140+ f_p = __nlsolve_∂f_∂p (prob, _F, uu, p)
141+ f_x = __nlsolve_∂f_∂u (prob, _F, uu, p)
142+
143+ z_arr = - f_x \ f_p
144+
145+ pp = prob. p
146+ sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
147+ if uu isa Number
148+ partials = sum (sumfun, zip (z_arr, pp))
149+ elseif p isa Number
150+ partials = sumfun ((z_arr, pp))
151+ else
152+ partials = sum (sumfun, zip (eachcol (z_arr), pp))
153+ end
154+
155+ return sol, partials
156+ end
157+
58158@inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
59159 if isinplace (prob)
60160 __f = p -> begin
0 commit comments