@@ -7,6 +7,9 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
77 A = A. A
88 end
99
10+ # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
11+ # it makes sense according to the benchmarks, which is dependent on
12+ # whether MKL or OpenBLAS is being used
1013 if A isa Matrix
1114 if ArrayInterface. can_setindex (cache. b) && (size (A,1 ) <= 100 ||
1215 (isopenblas () && size (A,1 ) <= 500 )
@@ -17,6 +20,9 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
1720 alg = LUFactorization ()
1821 SciMLBase. solve (cache, alg, args... ; kwargs... )
1922 end
23+
24+ # These few cases ensure the choice is optimal without the
25+ # dynamic dispatching of factorize
2026 elseif A isa Tridiagonal
2127 alg = GenericFactorization (;fact_alg= lu!)
2228 SciMLBase. solve (cache, alg, args... ; kwargs... )
@@ -26,14 +32,26 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
2632 elseif A isa SparseMatrixCSC
2733 alg = LUFactorization ()
2834 SciMLBase. solve (cache, alg, args... ; kwargs... )
35+
36+ # This catches the cases where a factorization overload could exist
37+ # For example, BlockBandedMatrix
2938 elseif ArrayInterface. isstructured (A)
3039 alg = GenericFactorization ()
3140 SciMLBase. solve (cache, alg, args... ; kwargs... )
41+
42+ # This catches the case where A is a CuMatrix
43+ # Which does not have LU fully defined
3244 elseif ! (A isa AbstractDiffEqOperator)
3345 alg = QRFactorization ()
3446 SciMLBase. solve (cache, alg, args... ; kwargs... )
35- else
47+
48+ # Not factorizable operator, default to only using A*x
49+ # IterativeSolvers is faster on CPU but not GPU-compatible
50+ elseif cache. u isa Array
3651 alg = IterativeSolversJL_GMRES ()
3752 SciMLBase. solve (cache, alg, args... ; kwargs... )
53+ else
54+ alg = KrylovJL_GMRES ()
55+ SciMLBase. solve (cache, alg, args... ; kwargs... )
3856 end
3957end
0 commit comments