@@ -20,12 +20,15 @@ _copywitheltype(::Type{T}, As...) where {T} = map(A -> copyto!(similar(A, T), A)
2020
2121# matrix division
2222
# Strided GPU matrices, bare or wrapped in a lazy `Adjoint`/`Transpose`.
# Note that the bare `StridedCuMatrix` member is not restricted to element type `T`;
# only the wrapped members are.
const CuMatOrAdj{T} = Union{
    StridedCuMatrix,
    LinearAlgebra.Adjoint{T,<:StridedCuMatrix{T}},
    LinearAlgebra.Transpose{T,<:StridedCuMatrix{T}},
}

# Strided GPU vectors or matrices, bare or wrapped in a lazy `Adjoint`/`Transpose`.
# As above, the bare members are not restricted to element type `T`.
const CuOrAdj{T} = Union{
    StridedCuVector,
    StridedCuMatrix,
    LinearAlgebra.Adjoint{T,<:StridedCuVector{T}},
    LinearAlgebra.Adjoint{T,<:StridedCuMatrix{T}},
    LinearAlgebra.Transpose{T,<:StridedCuVector{T}},
    LinearAlgebra.Transpose{T,<:StridedCuMatrix{T}},
}
2932
3033function Base.:\ (_A:: CuMatOrAdj , _B:: CuOrAdj )
3134 A, B = copy_cublasfloat (_A, _B)
@@ -101,31 +104,34 @@ using LinearAlgebra: Factorization, AbstractQ, QRCompactWY, QRCompactWYQ, QRPack
101104
102105if VERSION >= v " 1.8-"
103106
107+
108+
# In-place QR factorization of a strided GPU matrix: whatever `geqrf!` returns is
# splatted into the `QR` constructor.
function LinearAlgebra.qr!(A::StridedCuMatrix{T}) where T
    packed = geqrf!(A)
    return QR(packed...)
end
105110
111+
# conversions
#
# A factorization is first turned into an `AbstractArray` and then wrapped by `CuArray`;
# the `CuArray(F)` entry points simply defer to the `CuMatrix(F)` ones.
function CuMatrix(F::Union{QR,QRCompactWY})
    return CuArray(AbstractArray(F))
end

function CuArray(F::Union{QR,QRCompactWY})
    return CuMatrix(F)
end

function CuMatrix(F::QRPivoted)
    return CuArray(AbstractArray(F))
end

function CuArray(F::QRPivoted)
    return CuMatrix(F)
end
111117
# Solve with a `QR` factorization and a GPU vector right-hand side.
#
# The triangular solve `R \ (Q' * b)[1:n]` is written into the first `n` entries of `b`.
# The value returned is produced by range indexing (`b[1:n]`), not `b` itself.
function LinearAlgebra.ldiv!(_qr::QR, b::StridedCuVector)
    m, n = size(_qr)
    rfactor = UpperTriangular(_qr.R[1:min(m, n), 1:n])
    rhs = (_qr.Q' * b)[1:n]
    solution = rfactor \ rhs
    b[1:n] .= solution
    # the temporary is no longer needed once copied into `b`
    unsafe_free!(solution)
    return b[1:n]
end
119125
# Solve with a `QR` factorization and a GPU matrix right-hand side.
#
# The triangular solve is written into the first `n` rows of `B` (all columns).
# The value returned is produced by range indexing, not `B` itself.
function LinearAlgebra.ldiv!(_qr::QR, B::StridedCuMatrix)
    m, n = size(_qr)
    rfactor = UpperTriangular(_qr.R[1:min(m, n), 1:n])
    rhs = (_qr.Q' * B)[1:n, 1:size(B, 2)]
    solution = rfactor \ rhs
    B[1:n, 1:size(B, 2)] .= solution
    # the temporary is no longer needed once copied into `B`
    unsafe_free!(solution)
    return B[1:n, 1:size(B, 2)]
end
127133
128- function LinearAlgebra. ldiv! (x:: CuArray , _qr:: QR , b:: CuArray )
134+ function LinearAlgebra. ldiv! (x:: StridedCuArray , _qr:: QR , b:: StridedCuArray )
129135 _x = ldiv! (_qr, b)
130136 x .= vec (_x)
131137 unsafe_free! (_x)
@@ -146,71 +152,74 @@ CuMatrix{T}(Q::QRCompactWYQ) where {T} = error("QRCompactWY format is not suppor
# `Matrix{T}` of a GPU-backed Q: build the `CuMatrix{T}` first, then convert it with `Array`.
function Matrix{T}(Q::QRPackedQ{S,<:CuArray,<:CuArray}) where {T,S}
    return Array(CuMatrix{T}(Q))
end

function Matrix{T}(Q::QRCompactWYQ{S,<:CuArray,<:CuArray}) where {T,S}
    return Array(CuMatrix{T}(Q))
end
148154
155+
156+
# extracting the full matrix can be done with `collect` (which defaults to `Array`)
#
# A buffer obtained from `similar(src)` is filled with the identity, multiplied by `src`
# in place via `lmul!`, and finally collected.
function Base.collect(src::Union{QRPackedQ{<:Any,<:StridedCuArray,<:StridedCuArray},
                                 QRCompactWYQ{<:Any,<:StridedCuArray,<:StridedCuArray}})
    full = similar(src)
    copyto!(full, I)
    lmul!(src, full)
    return collect(full)
end
157165
# avoid the generic similar fallback that returns a CPU array:
# always allocate an uninitialized `CuArray{T,N}` of the requested dimensions
function Base.similar(Q::Union{QRPackedQ{<:Any,<:StridedCuArray,<:StridedCuArray},
                               QRCompactWYQ{<:Any,<:StridedCuArray,<:StridedCuArray}},
                      ::Type{T}, dims::Dims{N}) where {T,N}
    return CuArray{T,N}(undef, dims)
end
163171
# Column `j` of Q: multiply Q (via `lmul!`) into a vector that is 1 at index `j`
# and zero everywhere else, and return the result of that multiplication.
function Base.getindex(Q::QRPackedQ{<:Any,<:StridedCuArray}, ::Colon, j::Int)
    unit = CUDA.zeros(eltype(Q), size(Q, 2))
    unit[j] = 1
    return lmul!(Q, unit)
end
169177
178+
# multiplication by Q
#
# Every method forwards to `ormqr!` with 'L' as its first flag; the second flag depends on
# the wrapper: 'N' for plain Q, 'T' for a real adjoint or any transpose, 'C' for a complex
# adjoint.
function LinearAlgebra.lmul!(A::QRPackedQ{T,<:StridedCuArray,<:StridedCuArray},
                             B::CuVecOrMat{T}) where {T<:BlasFloat}
    return ormqr!('L', 'N', A.factors, A.τ, B)
end

function LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
                             B::CuVecOrMat{T}) where {T<:BlasReal}
    Q = parent(adjA)
    return ormqr!('L', 'T', Q.factors, Q.τ, B)
end

function LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
                             B::CuVecOrMat{T}) where {T<:BlasComplex}
    Q = parent(adjA)
    return ormqr!('L', 'C', Q.factors, Q.τ, B)
end

function LinearAlgebra.lmul!(trA::Transpose{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
                             B::CuVecOrMat{T}) where {T<:BlasFloat}
    Q = parent(trA)
    return ormqr!('L', 'T', Q.factors, Q.τ, B)
end
183192
# Multiplication by Q from the right.
#
# Every method forwards to `ormqr!` with 'R' as its first flag; the second flag depends on
# the wrapper: 'N' for plain Q, 'T' for a real adjoint or any transpose, 'C' for a complex
# adjoint. `A` is the array handed to `ormqr!` as its last argument.
LinearAlgebra.rmul!(A::CuVecOrMat{T},
                    B::QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}) where {T<:BlasFloat} =
    ormqr!('R', 'N', B.factors, B.τ, A)
LinearAlgebra.rmul!(A::CuVecOrMat{T},
                    adjB::Adjoint{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasReal} =
    ormqr!('R', 'T', parent(adjB).factors, parent(adjB).τ, A)
LinearAlgebra.rmul!(A::CuVecOrMat{T},
                    adjB::Adjoint{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasComplex} =
    ormqr!('R', 'C', parent(adjB).factors, parent(adjB).τ, A)
# Both `factors` and `τ` must be read from `trA`: this method has no `adjB` binding, so
# referring to it would raise `UndefVarError` as soon as the method is called.
LinearAlgebra.rmul!(A::CuVecOrMat{T},
                    trA::Transpose{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasFloat} =
    ormqr!('R', 'T', parent(trA).factors, parent(trA).τ, A)
196205
197206else
198207
# QR factorization kept as a `factors` matrix plus a `τ` vector, both on the GPU.
#
# `factors` carries the element type `T` explicitly: the only constructor already requires
# `StridedCuMatrix{T}`, and this keeps the field in line with `CuQRPackedQ.factors`.
struct CuQR{T} <: Factorization{T}
    factors::StridedCuMatrix{T}
    τ::StridedCuVector{T}

    function CuQR{T}(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T}
        return new(factors, τ)
    end
end
204213
# The Q factor of a QR factorization, kept as a `factors` matrix plus a `τ` vector,
# both with element type `T`.
struct CuQRPackedQ{T} <: AbstractQ{T}
    factors::StridedCuMatrix{T}
    τ::StridedCuVector{T}

    function CuQRPackedQ{T}(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T}
        return new(factors, τ)
    end
end
210219
# Convenience constructors that take the element type `T` from the arguments.
function CuQR(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T}
    return CuQR{T}(factors, τ)
end

function CuQRPackedQ(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T}
    return CuQRPackedQ{T}(factors, τ)
end
215224
216225# AbstractQ's `size` is the size of the full matrix,
@@ -245,7 +254,7 @@ Base.Matrix(A::CuQRPackedQ) = Matrix(CuMatrix(A))
245254function Base. getproperty (A:: CuQR , d:: Symbol )
246255 m, n = size (getfield (A, :factors ))
247256 if d == :R
248- return triu! (A. factors[ 1 : min (m, n), 1 : n] )
257+ return triu! (view ( A. factors, 1 : min (m, n), 1 : n) )
249258 elseif d == :Q
250259 return CuQRPackedQ (A. factors, A. τ)
251260 else
@@ -259,25 +268,25 @@ Base.iterate(S::CuQR, ::Val{:R}) = (S.R, Val(:done))
259268Base. iterate (S:: CuQR , :: Val{:done} ) = nothing
260269
# Apply changes Q from the left
#
# Every method forwards to `ormqr!` with 'L' as its first flag; the second flag is 'N' for
# plain Q, 'T' for a real adjoint or any transpose, and 'C' for a complex adjoint.
function LinearAlgebra.lmul!(A::CuQRPackedQ{T},
                             B::StridedCuVecOrMat{T}) where {T<:BlasFloat}
    return ormqr!('L', 'N', A.factors, A.τ, B)
end

function LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T}},
                             B::StridedCuVecOrMat{T}) where {T<:BlasReal}
    Q = parent(adjA)
    return ormqr!('L', 'T', Q.factors, Q.τ, B)
end

function LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T}},
                             B::StridedCuVecOrMat{T}) where {T<:BlasComplex}
    Q = parent(adjA)
    return ormqr!('L', 'C', Q.factors, Q.τ, B)
end

function LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T}},
                             B::StridedCuVecOrMat{T}) where {T<:BlasFloat}
    Q = parent(trA)
    return ormqr!('L', 'T', Q.factors, Q.τ, B)
end
270279
# Apply changes Q from the right
#
# Every method forwards to `ormqr!` with 'R' as its first flag; the second flag is 'N' for
# plain Q, 'T' for a real adjoint or any transpose, and 'C' for a complex adjoint. `A` is
# the array handed to `ormqr!` as its last argument.
LinearAlgebra.rmul!(A::StridedCuVecOrMat{T}, B::CuQRPackedQ{T}) where {T<:BlasFloat} =
    ormqr!('R', 'N', B.factors, B.τ, A)
LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
                    adjB::Adjoint{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasReal} =
    ormqr!('R', 'T', parent(adjB).factors, parent(adjB).τ, A)
LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
                    adjB::Adjoint{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasComplex} =
    ormqr!('R', 'C', parent(adjB).factors, parent(adjB).τ, A)
# Both `factors` and `τ` must be read from `trA`: this method has no `adjB` binding, so
# referring to it would raise `UndefVarError` as soon as the method is called.
LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
                    trA::Transpose{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasFloat} =
    ormqr!('R', 'T', parent(trA).factors, parent(trA).τ, A)
283292
@@ -300,23 +309,23 @@ end
# Determinant of Q for real element types: -1 when the number of nonzero entries of `τ`
# is odd, otherwise 1.
function LinearAlgebra.det(Q::CuQRPackedQ{<:Real})
    nonzeros = count(!iszero, Q.τ)
    return isodd(nonzeros) ? -1 : 1
end

# Determinant of Q in the general case: the product over `τ` of `one(t)` for zero entries
# and `-sign(t)^2` for nonzero entries.
function LinearAlgebra.det(Q::CuQRPackedQ)
    factor(t) = iszero(t) ? one(t) : -sign(t)^2
    return prod(factor, Q.τ)
end
302311
# Solve with a `CuQR` factorization and a GPU vector right-hand side.
#
# The triangular solve `R \ (Q' * b)[1:n]` is written into the first `n` entries of `b`.
# The value returned is produced by range indexing (`b[1:n]`), not `b` itself.
function LinearAlgebra.ldiv!(_qr::CuQR, b::StridedCuVector)
    m, n = size(_qr)
    rfactor = UpperTriangular(_qr.R[1:min(m, n), 1:n])
    rhs = (_qr.Q' * b)[1:n]
    solution = rfactor \ rhs
    b[1:n] .= solution
    # the temporary is no longer needed once copied into `b`
    unsafe_free!(solution)
    return b[1:n]
end
310319
# Solve with a `CuQR` factorization and a GPU matrix right-hand side.
#
# The triangular solve is written into the first `n` rows of `B` (all columns).
# The value returned is produced by range indexing, not `B` itself.
function LinearAlgebra.ldiv!(_qr::CuQR, B::StridedCuMatrix)
    m, n = size(_qr)
    rfactor = UpperTriangular(_qr.R[1:min(m, n), 1:n])
    rhs = (_qr.Q' * B)[1:n, 1:size(B, 2)]
    solution = rfactor \ rhs
    B[1:n, 1:size(B, 2)] .= solution
    # the temporary is no longer needed once copied into `B`
    unsafe_free!(solution)
    return B[1:n, 1:size(B, 2)]
end
318327
319- function LinearAlgebra. ldiv! (x:: CuArray ,_qr:: CuQR , b:: CuArray )
328+ function LinearAlgebra. ldiv! (x:: StridedCuArray ,_qr:: CuQR , b:: StridedCuArray )
320329 _x = ldiv! (_qr, b)
321330 x .= vec (_x)
322331 unsafe_free! (_x)
0 commit comments