11module LinearSolveForwardDiffExt
22
33using LinearSolve
4- using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver, DefaultAlgorithmChoice, defaultalg
4+ using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
5+ DefaultAlgorithmChoice, defaultalg
56using LinearAlgebra
67using ForwardDiff
78using ForwardDiff: Dual, Partials
89using SciMLBase
910using RecursiveArrayTools
1011using SciMLLogging
12+ using ArrayInterface
1113
1214const DualLinearProblem = LinearProblem{
1315 <: Union{Number, <:AbstractArray, Nothing} , iip,
6365function linearsolve_forwarddiff_solve! (cache:: DualLinearCache , alg, args... ; kwargs... )
6466 # Solve the primal problem
6567 cache. dual_u0_cache .= cache. linear_cache. u
66- sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
68+ sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
6769
6870 cache. primal_u_cache .= cache. linear_cache. u
6971 cache. primal_b_cache .= cache. linear_cache. b
@@ -165,26 +167,27 @@ function linearsolve_dual_solution(
165167end
166168
167169function linearsolve_dual_solution (u:: AbstractArray , partials,
168- cache:: DualLinearCache{DT} ) where {T, V, N, DT <: Dual{T,V, N} }
170+ cache:: DualLinearCache{DT} ) where {T, V, N, DT <: Dual{T, V, N} }
169171 # Optimized in-place version that reuses cache.dual_u
170172 linearsolve_dual_solution! (getfield (cache, :dual_u ), u, partials)
171173 return getfield (cache, :dual_u )
172174end
173175
174- function linearsolve_dual_solution! (dual_u:: AbstractArray{DT} , u:: AbstractArray , partials) where {T, V, N, DT <: Dual{T,V,N} }
176+ function linearsolve_dual_solution! (dual_u:: AbstractArray{DT} , u:: AbstractArray ,
177+ partials) where {T, V, N, DT <: Dual{T, V, N} }
175178 # Direct in-place construction of dual numbers without temporary allocations
176179 n_partials = length (partials)
177-
180+
178181 for i in eachindex (u, dual_u)
179182 # Extract partials for this element directly
180183 partial_vals = ntuple (Val (N)) do j
181184 V (partials[j][i])
182185 end
183-
186+
184187 # Construct dual number in-place
185- dual_u[i] = DT (u[i], Partials {N,V} (partial_vals))
188+ dual_u[i] = DT (u[i], Partials {N, V} (partial_vals))
186189 end
187-
190+
188191 return dual_u
189192end
190193
@@ -279,7 +282,7 @@ function __dual_init(
279282 true , # Cache is initially valid
280283 A,
281284 b,
282- zeros (dual_type, length (b))
285+ ArrayInterface . restructure (b, zeros (dual_type, length (b) ))
283286 )
284287end
285288
@@ -288,18 +291,20 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
288291end
289292
290293function SciMLBase. solve! (
291- cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT <: ForwardDiff.Dual }
294+ cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT < :
295+ ForwardDiff. Dual}
292296 primal_sol = linearsolve_forwarddiff_solve! (
293297 cache:: DualLinearCache , getfield (cache, :linear_cache ). alg, args... ; kwargs... )
294- dual_sol = linearsolve_dual_solution (getfield (cache,:linear_cache ). u, getfield (cache, :rhs_list ), cache)
298+ dual_sol = linearsolve_dual_solution (getfield (cache, :linear_cache ). u, getfield (cache, :rhs_list ), cache)
295299
296300 # For scalars, we still need to assign since cache.dual_u might not be pre-allocated
297301 if ! (getfield (cache, :dual_u ) isa AbstractArray)
298302 setfield! (cache, :dual_u , dual_sol)
299303 end
300304
301305 return SciMLBase. build_linear_solution (
302- getfield (cache, :linear_cache ). alg, getfield (cache, :dual_u ), primal_sol. resid, cache; primal_sol. retcode, primal_sol. iters, primal_sol. stats
306+ getfield (cache, :linear_cache ). alg, getfield (cache, :dual_u ), primal_sol. resid, cache;
307+ primal_sol. retcode, primal_sol. iters, primal_sol. stats
303308 )
304309end
305310
0 commit comments