Skip to content

Commit 7484133

Browse files
Add test for ComponentArray type preservation with ForwardDiff
Added comprehensive tests to verify that ComponentArrays maintain their structure when used with ForwardDiff in LinearSolve: 1. Direct test with Dual-valued ComponentArrays to verify structure preservation 2. Test gradient computation with ComponentArray RHS inside ForwardDiff Tests cover the fix for issue SciML/DifferentialEquations.jl#1110 where ComponentArrays were being converted to plain Vectors, causing TypeErrors. Added ComponentArrays to test dependencies in Project.toml [extras] and [targets]. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 20b5792 commit 7484133

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
148148
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
149149
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
150150
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
151+
ComponentArrays = "b0b7db55-8e73-11e8-0d91-e798b35f94b1"
151152
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
152153
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
153154
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
@@ -176,4 +177,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
176177
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
177178

178179
[targets]
179-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "CliqueTrees", "FastLapackInterface", "SparseArrays", "ExplicitImports"]
180+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "CliqueTrees", "FastLapackInterface", "SparseArrays", "ExplicitImports", "ComponentArrays"]

test/forwarddiff_overloads.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using LinearSolve
22
using ForwardDiff
33
using Test
44
using SparseArrays
5+
using ComponentArrays
56

67
function h(p)
78
(A = [p[1] p[2]+1 p[2]^3;
@@ -194,4 +195,43 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
194195
prob = LinearProblem(A, b)
195196
@test init(prob, GenericLUFactorization()) isa LinearSolve.LinearCache
196197

197-
@test init(prob) isa LinearSolve.LinearCache
198+
@test init(prob) isa LinearSolve.LinearCache
199+
200+
# Test ComponentArray with ForwardDiff (Issue SciML/DifferentialEquations.jl#1110)
201+
# This tests that ArrayInterface.restructure preserves ComponentArray structure
202+
203+
# Direct test: ComponentVector with Dual elements should preserve structure
204+
ca_dual = ComponentArray(
205+
a = ForwardDiff.Dual(1.0, 1.0, 0.0),
206+
b = ForwardDiff.Dual(2.0, 0.0, 1.0)
207+
)
208+
A_dual = [ca_dual.a 1.0; 1.0 ca_dual.b]
209+
b_dual = ComponentArray(x = ca_dual.a + 1, y = ca_dual.b * 2)
210+
211+
prob_dual = LinearProblem(A_dual, b_dual)
212+
sol_dual = solve(prob_dual)
213+
214+
# The solution should preserve ComponentArray type
215+
@test sol_dual.u isa ComponentVector
216+
@test hasproperty(sol_dual.u, :x)
217+
@test hasproperty(sol_dual.u, :y)
218+
219+
# Test gradient computation with ComponentArray inside ForwardDiff
220+
function component_linsolve(p)
221+
# Create a matrix that depends on p
222+
A = [p[1] p[2]; p[2] p[1] + 5]
223+
# Create a ComponentArray RHS that depends on p
224+
b_vec = ComponentArray(x = p[1] + 1, y = p[2] * 2)
225+
prob = LinearProblem(A, b_vec)
226+
sol = solve(prob)
227+
# Return sum of solution
228+
return sum(sol.u)
229+
end
230+
231+
p_test = [2.0, 3.0]
232+
# This will internally create Dual numbers and ComponentArrays with Dual elements
233+
grad = ForwardDiff.gradient(component_linsolve, p_test)
234+
@test grad isa Vector
235+
@test length(grad) == 2
236+
@test !any(isnan, grad)
237+
@test !any(isinf, grad)

0 commit comments

Comments
 (0)