Skip to content

Commit 20b5792

Browse files
Fix ComponentArray type preservation in ForwardDiff specialization
Use ArrayInterface.restructure to maintain ComponentArray structure when initializing dual_u cache. Previously, zeros(dual_type, length(b)) would create a plain Vector regardless of the input type, causing a TypeError when trying to assign a ComponentVector to a Vector field. This fix ensures that if the input b is a ComponentArray, the dual_u cache will also be a ComponentArray with the same structure. Fixes SciML/DifferentialEquations.jl#1110 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 013925b commit 20b5792

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4-
using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver, DefaultAlgorithmChoice, defaultalg
4+
using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
5+
DefaultAlgorithmChoice, defaultalg
56
using LinearAlgebra
67
using ForwardDiff
78
using ForwardDiff: Dual, Partials
89
using SciMLBase
910
using RecursiveArrayTools
1011
using SciMLLogging
12+
using ArrayInterface
1113

1214
const DualLinearProblem = LinearProblem{
1315
<:Union{Number, <:AbstractArray, Nothing}, iip,
@@ -63,7 +65,7 @@ end
6365
function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kwargs...)
6466
# Solve the primal problem
6567
cache.dual_u0_cache .= cache.linear_cache.u
66-
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
68+
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
6769

6870
cache.primal_u_cache .= cache.linear_cache.u
6971
cache.primal_b_cache .= cache.linear_cache.b
@@ -165,26 +167,27 @@ function linearsolve_dual_solution(
165167
end
166168

167169
function linearsolve_dual_solution(u::AbstractArray, partials,
168-
cache::DualLinearCache{DT}) where {T, V, N, DT <: Dual{T,V,N}}
170+
cache::DualLinearCache{DT}) where {T, V, N, DT <: Dual{T, V, N}}
169171
# Optimized in-place version that reuses cache.dual_u
170172
linearsolve_dual_solution!(getfield(cache, :dual_u), u, partials)
171173
return getfield(cache, :dual_u)
172174
end
173175

174-
function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray, partials) where {T, V, N, DT <: Dual{T,V,N}}
176+
function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray,
177+
partials) where {T, V, N, DT <: Dual{T, V, N}}
175178
# Direct in-place construction of dual numbers without temporary allocations
176179
n_partials = length(partials)
177-
180+
178181
for i in eachindex(u, dual_u)
179182
# Extract partials for this element directly
180183
partial_vals = ntuple(Val(N)) do j
181184
V(partials[j][i])
182185
end
183-
186+
184187
# Construct dual number in-place
185-
dual_u[i] = DT(u[i], Partials{N,V}(partial_vals))
188+
dual_u[i] = DT(u[i], Partials{N, V}(partial_vals))
186189
end
187-
190+
188191
return dual_u
189192
end
190193

@@ -279,7 +282,7 @@ function __dual_init(
279282
true, # Cache is initially valid
280283
A,
281284
b,
282-
zeros(dual_type, length(b))
285+
ArrayInterface.restructure(b, zeros(dual_type, length(b)))
283286
)
284287
end
285288

@@ -288,18 +291,20 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
288291
end
289292

290293
function SciMLBase.solve!(
291-
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual}
294+
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <:
295+
ForwardDiff.Dual}
292296
primal_sol = linearsolve_forwarddiff_solve!(
293297
cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)
294-
dual_sol = linearsolve_dual_solution(getfield(cache,:linear_cache).u, getfield(cache, :rhs_list), cache)
298+
dual_sol = linearsolve_dual_solution(getfield(cache, :linear_cache).u, getfield(cache, :rhs_list), cache)
295299

296300
# For scalars, we still need to assign since cache.dual_u might not be pre-allocated
297301
if !(getfield(cache, :dual_u) isa AbstractArray)
298302
setfield!(cache, :dual_u, dual_sol)
299303
end
300304

301305
return SciMLBase.build_linear_solution(
302-
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), primal_sol.resid, cache; primal_sol.retcode, primal_sol.iters, primal_sol.stats
306+
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), primal_sol.resid, cache;
307+
primal_sol.retcode, primal_sol.iters, primal_sol.stats
303308
)
304309
end
305310

0 commit comments

Comments
 (0)