@@ -5,11 +5,9 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
55 DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
66 needs_concrete_A,
77 error_no_cudss_lu, init_cacheval, OperatorAssumptions,
8- CudaOffloadFactorization, CudaOffloadLUFactorization,
9- CudaOffloadQRFactorization,
8+ CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
109 CUDAOffload32MixedLUFactorization,
11- SparspakFactorization, KLUFactorization, UMFPACKFactorization,
12- LinearVerbosity
10+ SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
1311using LinearSolve. LinearAlgebra, LinearSolve. SciMLBase, LinearSolve. ArrayInterface
1412using SciMLBase: AbstractSciMLOperator
1513
@@ -25,16 +23,11 @@ function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
2523 if LinearSolve. cudss_loaded (A)
2624 LinearSolve. DefaultLinearSolver (LinearSolve. DefaultAlgorithmChoice. LUFactorization)
2725 else
28- error (" CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library." )
29- end
30- end
31-
32- function LinearSolve. defaultalg (A:: CUDA.CUSPARSE.CuSparseMatrixCSC{Tv, Ti} , b,
33- assump:: OperatorAssumptions{Bool} ) where {Tv, Ti}
34- if LinearSolve. cudss_loaded (A)
35- LinearSolve. DefaultLinearSolver (LinearSolve. DefaultAlgorithmChoice. LUFactorization)
36- else
37- error (" CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library." )
26+ if ! LinearSolve. ALREADY_WARNED_CUDSS[]
27+ @warn (" CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov" )
28+ LinearSolve. ALREADY_WARNED_CUDSS[] = true
29+ end
30+ LinearSolve. DefaultLinearSolver (LinearSolve. DefaultAlgorithmChoice. KrylovJL_GMRES)
3831 end
3932end
4033
@@ -45,13 +38,6 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
4538 nothing
4639end
4740
48- function LinearSolve. error_no_cudss_lu (A:: CUDA.CUSPARSE.CuSparseMatrixCSC )
49- if ! LinearSolve. cudss_loaded (A)
50- error (" CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library." )
51- end
52- nothing
53- end
54-
5541function SciMLBase. solve! (cache:: LinearSolve.LinearCache , alg:: CudaOffloadLUFactorization ;
5642 kwargs... )
5743 if cache. isfresh
@@ -66,15 +52,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
6652 SciMLBase. build_linear_solution (alg, y, nothing , cache)
6753end
6854
69- function LinearSolve. init_cacheval (
70- alg:: CudaOffloadLUFactorization , A:: AbstractArray , b, u, Pl, Pr,
55+ function LinearSolve. init_cacheval (alg:: CudaOffloadLUFactorization , A:: AbstractArray , b, u, Pl, Pr,
7156 maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} ,
7257 assumptions:: OperatorAssumptions )
7358 # Check if CUDA is functional before creating CUDA arrays
7459 if ! CUDA. functional ()
7560 return nothing
7661 end
77-
62+
7863 T = eltype (A)
7964 noUnitT = typeof (zero (T))
8065 luT = LinearAlgebra. lutype (noUnitT)
@@ -102,7 +87,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
10287 if ! CUDA. functional ()
10388 return nothing
10489 end
105-
90+
10691 qr (CUDA. CuArray (A))
10792end
10893
@@ -119,42 +104,35 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
119104 SciMLBase. build_linear_solution (alg, y, nothing , cache)
120105end
121106
122- function LinearSolve. init_cacheval (
123- alg:: CudaOffloadFactorization , A:: AbstractArray , b, u, Pl, Pr,
107+ function LinearSolve. init_cacheval (alg:: CudaOffloadFactorization , A:: AbstractArray , b, u, Pl, Pr,
124108 maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} ,
125109 assumptions:: OperatorAssumptions )
126110 qr (CUDA. CuArray (A))
127111end
128112
129113function LinearSolve. init_cacheval (
130114 :: SparspakFactorization , A:: CUDA.CUSPARSE.CuSparseMatrixCSR , b, u,
131- Pl, Pr, maxiters:: Int , abstol, reltol,
132- verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
115+ Pl, Pr, maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
133116 nothing
134117end
135118
136119function LinearSolve. init_cacheval (
137120 :: KLUFactorization , A:: CUDA.CUSPARSE.CuSparseMatrixCSR , b, u,
138- Pl, Pr, maxiters:: Int , abstol, reltol,
139- verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
121+ Pl, Pr, maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
140122 nothing
141123end
142124
143125function LinearSolve. init_cacheval (
144126 :: UMFPACKFactorization , A:: CUDA.CUSPARSE.CuSparseMatrixCSR , b, u,
145- Pl, Pr, maxiters:: Int , abstol, reltol,
146- verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
127+ Pl, Pr, maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
147128 nothing
148129end
149130
150131# Mixed precision CUDA LU implementation
151- function SciMLBase. solve! (
152- cache:: LinearSolve.LinearCache , alg:: CUDAOffload32MixedLUFactorization ;
132+ function SciMLBase. solve! (cache:: LinearSolve.LinearCache , alg:: CUDAOffload32MixedLUFactorization ;
153133 kwargs... )
154134 if cache. isfresh
155- fact, A_gpu_f32,
156- b_gpu_f32,
157- u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
135+ fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
158136 # Compute 32-bit type on demand and convert
159137 T32 = eltype (cache. A) <: Complex ? ComplexF32 : Float32
160138 A_f32 = T32 .(cache. A)
@@ -163,14 +141,12 @@ function SciMLBase.solve!(
163141 cache. cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
164142 cache. isfresh = false
165143 end
166- fact, A_gpu_f32,
167- b_gpu_f32,
168- u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
169-
144+ fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
145+
170146 # Compute types on demand for conversions
171147 T32 = eltype (cache. A) <: Complex ? ComplexF32 : Float32
172148 Torig = eltype (cache. u)
173-
149+
174150 # Convert b to Float32, solve, then convert back to original precision
175151 b_f32 = T32 .(cache. b)
176152 copyto! (b_gpu_f32, b_f32)
0 commit comments