11needs_concrete_A (alg:: DefaultLinearSolver ) = true
22mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3- T13, T14, T15, T16, T17, T18}
3+ T13, T14, T15, T16, T17, T18, T19 }
44 LUFactorization:: T1
55 QRFactorization:: T2
66 DiagonalFactorization:: T3
@@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
1919 NormalCholeskyFactorization:: T16
2020 AppleAccelerateLUFactorization:: T17
2121 MKLLUFactorization:: T18
22+ QRFactorizationPivoted:: T19
2223end
2324
2425# Legacy fallback
@@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
168169 (A === nothing ? eltype (b) <: Union{Float32, Float64} :
169170 eltype (A) <: Union{Float32, Float64} )
170171 DefaultAlgorithmChoice. RFLUFactorization
171- # elseif A === nothing || A isa Matrix
172- # alg = FastLUFactorization()
172+ # elseif A === nothing || A isa Matrix
173+ # alg = FastLUFactorization()
173174 elseif usemkl && (A === nothing ? eltype (b) <: Union{Float32, Float64} :
174175 eltype (A) <: Union{Float32, Float64} )
175176 DefaultAlgorithmChoice. MKLLUFactorization
@@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
199200 elseif assump. condition === OperatorCondition. WellConditioned
200201 DefaultAlgorithmChoice. NormalCholeskyFactorization
201202 elseif assump. condition === OperatorCondition. IllConditioned
202- DefaultAlgorithmChoice. QRFactorization
203+ if is_underdetermined (A)
204+ # Underdetermined
205+ DefaultAlgorithmChoice. QRFactorizationPivoted
206+ else
207+ DefaultAlgorithmChoice. QRFactorization
208+ end
203209 elseif assump. condition === OperatorCondition. VeryIllConditioned
204- DefaultAlgorithmChoice. QRFactorization
210+ if is_underdetermined (A)
211+ # Underdetermined
212+ DefaultAlgorithmChoice. QRFactorizationPivoted
213+ else
214+ DefaultAlgorithmChoice. QRFactorization
215+ end
205216 elseif assump. condition === OperatorCondition. SuperIllConditioned
206217 DefaultAlgorithmChoice. SVDFactorization
207218 else
@@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol)
247258 NormalCholeskyFactorization ()
248259 elseif alg === :AppleAccelerateLUFactorization
249260 AppleAccelerateLUFactorization ()
261+ elseif alg === :QRFactorizationPivoted
262+ @static if VERSION ≥ v " 1.7beta"
263+ QRFactorization (ColumnNorm ())
264+ else
265+ QRFactorization (Val (true ))
266+ end
250267 else
251268 error (" Algorithm choice symbol $alg not allowed in the default" )
252269 end
@@ -311,6 +328,12 @@ function defaultalg_symbol(::Type{T}) where {T}
311328end
312329defaultalg_symbol (:: Type{<:GenericFactorization{typeof(ldlt!)}} ) = :LDLtFactorization
313330
331+ @static if VERSION >= v " 1.7"
332+ defaultalg_symbol (:: Type{<:QRFactorization{ColumnNorm}} ) = :QRFactorizationPivoted
333+ else
334+ defaultalg_symbol (:: Type{<:QRFactorization{Val{true}}} ) = :QRFactorizationPivoted
335+ end
336+
314337"""
315338if alg.alg === DefaultAlgorithmChoice.LUFactorization
316339 SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...))
339362 end
340363 ex = Expr (:if , ex. args... )
341364end
365+
366+ """
367+ ```
368+ elseif DefaultAlgorithmChoice.LUFactorization === cache.alg
369+ (cache.cacheval.LUFactorization)' \\ dy
370+ else
371+ ...
372+ end
373+ ```
374+ """
375+ @generated function defaultalg_adjoint_eval (cache:: LinearCache , dy)
376+ ex = :()
377+ for alg in first .(EnumX. symbol_map (DefaultAlgorithmChoice. T))
378+ newex = if alg in Symbol .((DefaultAlgorithmChoice. MKLLUFactorization,
379+ DefaultAlgorithmChoice. AppleAccelerateLUFactorization,
380+ DefaultAlgorithmChoice. RFLUFactorization))
381+ quote
382+ getproperty (cache. cacheval,$ (Meta. quot (alg)))[1 ]' \ dy
383+ end
384+ elseif alg in Symbol .((DefaultAlgorithmChoice. LUFactorization,
385+ DefaultAlgorithmChoice. QRFactorization,
386+ DefaultAlgorithmChoice. KLUFactorization,
387+ DefaultAlgorithmChoice. UMFPACKFactorization,
388+ DefaultAlgorithmChoice. LDLtFactorization,
389+ DefaultAlgorithmChoice. SparspakFactorization,
390+ DefaultAlgorithmChoice. BunchKaufmanFactorization,
391+ DefaultAlgorithmChoice. CHOLMODFactorization,
392+ DefaultAlgorithmChoice. SVDFactorization,
393+ DefaultAlgorithmChoice. CholeskyFactorization,
394+ DefaultAlgorithmChoice. NormalCholeskyFactorization,
395+ DefaultAlgorithmChoice. QRFactorizationPivoted,
396+ DefaultAlgorithmChoice. GenericLUFactorization))
397+ quote
398+ getproperty (cache. cacheval,$ (Meta. quot (alg)))' \ dy
399+ end
400+ elseif alg in Symbol .((DefaultAlgorithmChoice. KrylovJL_GMRES,))
401+ quote
402+ invprob = LinearSolve. LinearProblem (transpose (cache. A), dy)
403+ solve (invprob, cache. alg;
404+ abstol = cache. val. abstol,
405+ reltol = cache. val. reltol,
406+ verbose = cache. val. verbose)
407+ end
408+ else
409+ quote
410+ error (" Default linear solver with algorithm $(alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling" )
411+ end
412+ end
413+
414+ ex = if ex == :()
415+ Expr (:elseif , :(getproperty (DefaultAlgorithmChoice, $ (Meta. quot (alg))) === cache. alg. alg), newex,
416+ :(error (" Algorithm Choice not Allowed" )))
417+ else
418+ Expr (:elseif , :(getproperty (DefaultAlgorithmChoice, $ (Meta. quot (alg))) === cache. alg. alg), newex, ex)
419+ end
420+ end
421+ ex = Expr (:if , ex. args... )
422+ end
0 commit comments