@@ -69,6 +69,16 @@ struct ODESystem <: AbstractSystem
6969 [`calculate_jacobian`](@ref) is called on the system.
7070 """
7171 jac:: RefValue{Matrix{Expression}}
72+ """
73+ Wfact matrix. Note: this field will not be defined until
74+ [`generate_factorized_W`](@ref) is called on the system.
75+ """
76+ Wfact:: RefValue{Matrix{Expression}}
77+ """
78+ Wfact_t matrix. Note: this field will not be defined until
79+ [`generate_factorized_W`](@ref) is called on the system.
80+ """
81+ Wfact_t:: RefValue{Matrix{Expression}}
7282end
7383
7484function ODESystem (eqs)
@@ -89,7 +99,9 @@ function ODESystem(eqs)
8999end
90100function ODESystem (deqs, iv, dvs, ps)
91101 jac = RefValue (Matrix {Expression} (undef, 0 , 0 ))
92- ODESystem (deqs, iv, dvs, ps, jac)
102+ Wfact = RefValue (Matrix {Expression} (undef, 0 , 0 ))
103+ Wfact_t = RefValue (Matrix {Expression} (undef, 0 , 0 ))
104+ ODESystem (deqs, iv, dvs, ps, jac, Wfact, Wfact_t)
93105end
94106
95107function _eq_unordered (a, b)
@@ -152,10 +164,10 @@ function generate_function(sys::ODESystem, dvs, ps; version::FunctionVersion = A
152164 return build_function (rhss, dvs′, ps′, (sys. iv. name,), ODEToExpr (sys); version = version)
153165end
154166
167+ function calculate_factorized_W (sys:: ODESystem , simplify= true )
168+ isempty (sys. Wfact[]) || return (sys. Wfact[],sys. Wfact_t[])
155169
156- function generate_factorized_W (sys:: ODESystem , simplify= true ; version:: FunctionVersion = ArrayFunction)
157170 jac = calculate_jacobian (sys)
158-
159171 gam = Variable (:gam ; known = true )()
160172
161173 W = - LinearAlgebra. I + gam* jac
@@ -170,6 +182,14 @@ function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionV
170182 if simplify
171183 Wfact_t = simplify_constants .(Wfact_t)
172184 end
185+ sys. Wfact[] = Wfact
186+ sys. Wfact_t[] = Wfact_t
187+
188+ (Wfact,Wfact_t)
189+ end
190+
191+ function generate_factorized_W (sys:: ODESystem , simplify= true ; version:: FunctionVersion = ArrayFunction)
192+ (Wfact,Wfact_t) = calculate_factorized_W (sys,simplify)
173193
174194 if version === SArrayFunction
175195 siz = size (Wfact)
@@ -198,9 +218,12 @@ respectively.
198218function DiffEqBase. ODEFunction (sys:: ODESystem , dvs, ps; version:: FunctionVersion = ArrayFunction)
199219 expr = eval (generate_function (sys, dvs, ps; version = version))
200220 jac_expr = isempty (sys. jac[]) ? nothing : eval (generate_jacobian (sys))
221+ Wfact_expr,Wfact_t_expr = isempty (sys. Wfact[]) ? (nothing ,nothing ) : eval .(calculate_factorized_W (sys))
201222 if version === ArrayFunction
202- ODEFunction {true} (expr,jac= jac_expr)
223+ ODEFunction {true} (eval (expr),jac= jac_expr,
224+ Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
203225 elseif version === SArrayFunction
204- ODEFunction {false} (expr,jac= jac_expr)
226+ ODEFunction {false} (eval (expr),jac= jac_expr,
227+ Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
205228 end
206229end
0 commit comments