diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index cacba0112..31bfab6b2 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -200,6 +200,11 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactoriza return __init(prob, alg, args...; kwargs...) end +# Opt out for SparspakFactorization +function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SparspakFactorization, args...; kwargs...) + return __init(prob, alg, args...; kwargs...) +end + function SciMLBase.init(prob::DualAbstractLinearProblem, alg::DefaultLinearSolver, args...; kwargs...) if alg.alg === DefaultAlgorithmChoice.GenericLUFactorization return __init(prob, alg, args...; kwargs...) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 46116e7a3..4dd936873 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -3,6 +3,7 @@ using ForwardDiff using Test using SparseArrays using ComponentArrays +using Sparspak function h(p) (A = [p[1] p[2]+1 p[2]^3; @@ -188,7 +189,6 @@ backslash_x_p = A \ b @test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) - # Test that GenericLU doesn't create a DualLinearCache A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) @@ -197,6 +197,12 @@ prob = LinearProblem(A, b) @test init(prob) isa LinearSolve.LinearCache +# Test that SparspakFactorization doesn't create a DualLinearCache +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(sparse(A), b) +@test init(prob, SparspakFactorization()) isa LinearSolve.LinearCache + # Test ComponentArray with ForwardDiff (Issue SciML/DifferentialEquations.jl#1110) # This tests that ArrayInterface.restructure preserves ComponentArray structure @@ -234,4 +240,4 @@ grad = ForwardDiff.gradient(component_linsolve, p_test) @test grad isa Vector @test length(grad) == 2 @test !any(isnan, grad) -@test !any(isinf, grad) \ No newline at end of file +@test !any(isinf, grad)