@@ -26,6 +26,46 @@ function appleaccelerate_isavailable()
2626 return true
2727end
2828
29+ function aa_getrf! (A:: AbstractMatrix{<:ComplexF64} ;
30+ ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 ))),
31+ info = Ref {Cint} (),
32+ check = false )
33+ require_one_based_indexing (A)
34+ check && chkfinite (A)
35+ chkstride1 (A)
36+ m, n = size (A)
37+ lda = max (1 , stride (A, 2 ))
38+ if isempty (ipiv)
39+ ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 )))
40+ end
41+ ccall ((" zgetrf_" , libacc), Cvoid,
42+ (Ref{Cint}, Ref{Cint}, Ptr{ComplexF64},
43+ Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
44+ m, n, A, lda, ipiv, info)
45+ info[] < 0 && throw (ArgumentError (" Invalid arguments sent to LAPACK dgetrf_" ))
46+ A, ipiv, BlasInt (info[]), info # Error code is stored in LU factorization type
47+ end
48+
49+ function aa_getrf! (A:: AbstractMatrix{<:ComplexF32} ;
50+ ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 ))),
51+ info = Ref {Cint} (),
52+ check = false )
53+ require_one_based_indexing (A)
54+ check && chkfinite (A)
55+ chkstride1 (A)
56+ m, n = size (A)
57+ lda = max (1 , stride (A, 2 ))
58+ if isempty (ipiv)
59+ ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 )))
60+ end
61+ ccall ((" cgetrf_" , libacc), Cvoid,
62+ (Ref{Cint}, Ref{Cint}, Ptr{ComplexF32},
63+ Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
64+ m, n, A, lda, ipiv, info)
65+ info[] < 0 && throw (ArgumentError (" Invalid arguments sent to LAPACK dgetrf_" ))
66+ A, ipiv, BlasInt (info[]), info # Error code is stored in LU factorization type
67+ end
68+
2969function aa_getrf! (A:: AbstractMatrix{<:Float64} ;
3070 ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 ))),
3171 info = Ref {Cint} (),
@@ -67,6 +107,55 @@ function aa_getrf!(A::AbstractMatrix{<:Float32};
67107 A, ipiv, BlasInt (info[]), info # Error code is stored in LU factorization type
68108end
69109
110+ function aa_getrs! (trans:: AbstractChar ,
111+ A:: AbstractMatrix{<:ComplexF64} ,
112+ ipiv:: AbstractVector{Cint} ,
113+ B:: AbstractVecOrMat{<:ComplexF64} ;
114+ info = Ref {Cint} ())
115+ require_one_based_indexing (A, ipiv, B)
116+ LinearAlgebra. LAPACK. chktrans (trans)
117+ chkstride1 (A, B, ipiv)
118+ n = LinearAlgebra. checksquare (A)
119+ if n != size (B, 1 )
120+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
121+ end
122+ if n != length (ipiv)
123+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
124+ end
125+ nrhs = size (B, 2 )
126+ ccall ((" zgetrs_" , libacc), Cvoid,
127+ (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF64}, Ref{Cint},
128+ Ptr{Cint}, Ptr{ComplexF64}, Ref{Cint}, Ptr{Cint}, Clong),
129+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
130+ 1 )
131+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
132+ end
133+
134+ function aa_getrs! (trans:: AbstractChar ,
135+ A:: AbstractMatrix{<:ComplexF32} ,
136+ ipiv:: AbstractVector{Cint} ,
137+ B:: AbstractVecOrMat{<:ComplexF32} ;
138+ info = Ref {Cint} ())
139+ require_one_based_indexing (A, ipiv, B)
140+ LinearAlgebra. LAPACK. chktrans (trans)
141+ chkstride1 (A, B, ipiv)
142+ n = LinearAlgebra. checksquare (A)
143+ if n != size (B, 1 )
144+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
145+ end
146+ if n != length (ipiv)
147+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
148+ end
149+ nrhs = size (B, 2 )
150+ ccall ((" cgetrs_" , libacc), Cvoid,
151+ (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF32}, Ref{Cint},
152+ Ptr{Cint}, Ptr{ComplexF32}, Ref{Cint}, Ptr{Cint}, Clong),
153+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
154+ 1 )
155+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
156+ B
157+ end
158+
70159function aa_getrs! (trans:: AbstractChar ,
71160 A:: AbstractMatrix{<:Float64} ,
72161 ipiv:: AbstractVector{Cint} ,
@@ -128,12 +217,20 @@ else
128217 nothing
129218end
130219
131- function LinearSolve. init_cacheval (alg:: AppleAccelerateLUFactorization , A, b, u, Pl, Pr,
220+ function LinearSolve. init_cacheval (alg:: AppleAccelerateLUFactorization , A:: AbstractMatrix{<:Float64} , b:: AbstractArray{<:Float64} , u, Pl, Pr,
132221 maxiters:: Int , abstol, reltol, verbose:: Bool ,
133222 assumptions:: OperatorAssumptions )
134223 PREALLOCATED_APPLE_LU
135224end
136225
226+ function LinearSolve. init_cacheval (alg:: AppleAccelerateLUFactorization , A, b, u, Pl, Pr,
227+ maxiters:: Int , abstol, reltol, verbose:: Bool ,
228+ assumptions:: OperatorAssumptions )
229+ A = rand (eltype (A), 0 , 0 )
230+ luinst = ArrayInterface. lu_instance (A)
231+ LU (luinst. factors, similar (A, Cint, 0 ), luinst. info), Ref {Cint} ()
232+ end
233+
137234function SciMLBase. solve! (cache:: LinearCache , alg:: AppleAccelerateLUFactorization ;
138235 kwargs... )
139236 A = cache. A
0 commit comments