Skip to content

Commit f5cb8bf

Browse files
Merge pull request #806 from ChrisRackauckas-Claude/fix-componentarray-forwarddiff-restructure
Fix ComponentArray type preservation in ForwardDiff specialization
2 parents 013925b + a3b6184 commit f5cb8bf

File tree

3 files changed

+61
-14
lines changed

3 files changed

+61
-14
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ CUSOLVERRF = "0.2.6"
9191
ChainRulesCore = "1.25"
9292
CliqueTrees = "1.11.0"
9393
ConcreteStructs = "0.2.3"
94+
ComponentArrays = "0.15.29"
9495
DocStringExtensions = "0.9.3"
9596
EnumX = "1.0.4"
9697
EnzymeCore = "0.8.5"
@@ -148,6 +149,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
148149
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
149150
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
150151
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
152+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
151153
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
152154
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
153155
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
@@ -176,4 +178,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
176178
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
177179

178180
[targets]
179-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "CliqueTrees", "FastLapackInterface", "SparseArrays", "ExplicitImports"]
181+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "CliqueTrees", "FastLapackInterface", "SparseArrays", "ExplicitImports", "ComponentArrays"]

ext/LinearSolveForwardDiffExt.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4-
using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver, DefaultAlgorithmChoice, defaultalg
4+
using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
5+
DefaultAlgorithmChoice, defaultalg
56
using LinearAlgebra
67
using ForwardDiff
78
using ForwardDiff: Dual, Partials
89
using SciMLBase
910
using RecursiveArrayTools
1011
using SciMLLogging
12+
using ArrayInterface
1113

1214
const DualLinearProblem = LinearProblem{
1315
<:Union{Number, <:AbstractArray, Nothing}, iip,
@@ -63,7 +65,7 @@ end
6365
function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kwargs...)
6466
# Solve the primal problem
6567
cache.dual_u0_cache .= cache.linear_cache.u
66-
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
68+
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
6769

6870
cache.primal_u_cache .= cache.linear_cache.u
6971
cache.primal_b_cache .= cache.linear_cache.b
@@ -165,26 +167,27 @@ function linearsolve_dual_solution(
165167
end
166168

167169
function linearsolve_dual_solution(u::AbstractArray, partials,
168-
cache::DualLinearCache{DT}) where {T, V, N, DT <: Dual{T,V,N}}
170+
cache::DualLinearCache{DT}) where {T, V, N, DT <: Dual{T, V, N}}
169171
# Optimized in-place version that reuses cache.dual_u
170172
linearsolve_dual_solution!(getfield(cache, :dual_u), u, partials)
171173
return getfield(cache, :dual_u)
172174
end
173175

174-
function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray, partials) where {T, V, N, DT <: Dual{T,V,N}}
176+
function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray,
177+
partials) where {T, V, N, DT <: Dual{T, V, N}}
175178
# Direct in-place construction of dual numbers without temporary allocations
176179
n_partials = length(partials)
177-
180+
178181
for i in eachindex(u, dual_u)
179182
# Extract partials for this element directly
180183
partial_vals = ntuple(Val(N)) do j
181184
V(partials[j][i])
182185
end
183-
186+
184187
# Construct dual number in-place
185-
dual_u[i] = DT(u[i], Partials{N,V}(partial_vals))
188+
dual_u[i] = DT(u[i], Partials{N, V}(partial_vals))
186189
end
187-
190+
188191
return dual_u
189192
end
190193

@@ -279,7 +282,7 @@ function __dual_init(
279282
true, # Cache is initially valid
280283
A,
281284
b,
282-
zeros(dual_type, length(b))
285+
ArrayInterface.restructure(b, zeros(dual_type, length(b)))
283286
)
284287
end
285288

@@ -288,18 +291,20 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
288291
end
289292

290293
function SciMLBase.solve!(
291-
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual}
294+
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <:
295+
ForwardDiff.Dual}
292296
primal_sol = linearsolve_forwarddiff_solve!(
293297
cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)
294-
dual_sol = linearsolve_dual_solution(getfield(cache,:linear_cache).u, getfield(cache, :rhs_list), cache)
298+
dual_sol = linearsolve_dual_solution(getfield(cache, :linear_cache).u, getfield(cache, :rhs_list), cache)
295299

296300
# For scalars, we still need to assign since cache.dual_u might not be pre-allocated
297301
if !(getfield(cache, :dual_u) isa AbstractArray)
298302
setfield!(cache, :dual_u, dual_sol)
299303
end
300304

301305
return SciMLBase.build_linear_solution(
302-
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), primal_sol.resid, cache; primal_sol.retcode, primal_sol.iters, primal_sol.stats
306+
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), primal_sol.resid, cache;
307+
primal_sol.retcode, primal_sol.iters, primal_sol.stats
303308
)
304309
end
305310

test/forwarddiff_overloads.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using LinearSolve
22
using ForwardDiff
33
using Test
44
using SparseArrays
5+
using ComponentArrays
56

67
function h(p)
78
(A = [p[1] p[2]+1 p[2]^3;
@@ -194,4 +195,43 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
194195
prob = LinearProblem(A, b)
195196
@test init(prob, GenericLUFactorization()) isa LinearSolve.LinearCache
196197

197-
@test init(prob) isa LinearSolve.LinearCache
198+
@test init(prob) isa LinearSolve.LinearCache
199+
200+
# Test ComponentArray with ForwardDiff (Issue SciML/DifferentialEquations.jl#1110)
201+
# This tests that ArrayInterface.restructure preserves ComponentArray structure
202+
203+
# Direct test: ComponentVector with Dual elements should preserve structure
204+
ca_dual = ComponentArray(
205+
a = ForwardDiff.Dual(1.0, 1.0, 0.0),
206+
b = ForwardDiff.Dual(2.0, 0.0, 1.0)
207+
)
208+
A_dual = [ca_dual.a 1.0; 1.0 ca_dual.b]
209+
b_dual = ComponentArray(x = ca_dual.a + 1, y = ca_dual.b * 2)
210+
211+
prob_dual = LinearProblem(A_dual, b_dual)
212+
sol_dual = solve(prob_dual)
213+
214+
# The solution should preserve ComponentArray type
215+
@test sol_dual.u isa ComponentVector
216+
@test hasproperty(sol_dual.u, :x)
217+
@test hasproperty(sol_dual.u, :y)
218+
219+
# Test gradient computation with ComponentArray inside ForwardDiff
220+
function component_linsolve(p)
221+
# Create a matrix that depends on p
222+
A = [p[1] p[2]; p[2] p[1] + 5]
223+
# Create a ComponentArray RHS that depends on p
224+
b_vec = ComponentArray(x = p[1] + 1, y = p[2] * 2)
225+
prob = LinearProblem(A, b_vec)
226+
sol = solve(prob)
227+
# Return sum of solution
228+
return sum(sol.u)
229+
end
230+
231+
p_test = [2.0, 3.0]
232+
# This will internally create Dual numbers and ComponentArrays with Dual elements
233+
grad = ForwardDiff.gradient(component_linsolve, p_test)
234+
@test grad isa Vector
235+
@test length(grad) == 2
236+
@test !any(isnan, grad)
237+
@test !any(isinf, grad)

0 commit comments

Comments
 (0)