@@ -65,6 +65,11 @@ struct ODESystem <: AbstractSystem
6565 """ Parameter variables."""
6666 ps:: Vector{Variable}
6767 """
68+ Time-derivative matrix. Note: this field will not be defined until
69+ [`calculate_tgrad`](@ref) is called on the system.
70+ """
71+ tgrad:: RefValue{Vector{Expression}}
72+ """
6873 Jacobian matrix. Note: this field will not be defined until
6974 [`calculate_jacobian`](@ref) is called on the system.
7075 """
@@ -99,10 +104,11 @@ function ODESystem(eqs)
99104end
100105
101106function ODESystem (deqs:: AbstractVector{DiffEq} , iv, dvs, ps)
107+ tgrad = RefValue (Vector {Expression} (undef, 0 ))
102108 jac = RefValue (Matrix {Expression} (undef, 0 , 0 ))
103109 Wfact = RefValue (Matrix {Expression} (undef, 0 , 0 ))
104110 Wfact_t = RefValue (Matrix {Expression} (undef, 0 , 0 ))
105- ODESystem (deqs, iv, dvs, ps, jac, Wfact, Wfact_t)
111+ ODESystem (deqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t)
106112end
107113
108114function ODESystem (deqs:: AbstractVector{<:Equation} , iv, dvs, ps)
@@ -133,6 +139,15 @@ independent_variables(sys::ODESystem) = Set{Variable}([sys.iv])
133139dependent_variables (sys:: ODESystem ) = Set {Variable} (sys. dvs)
134140parameters (sys:: ODESystem ) = Set {Variable} (sys. ps)
135141
142+ function calculate_tgrad (sys:: ODESystem )
143+ isempty (sys. tgrad[]) || return sys. tgrad[] # use cached tgrad, if possible
144+ rhs = [detime_dvs (eq. rhs) for eq ∈ sys. eqs]
145+ iv = sys. iv ()
146+ notime_tgrad = [expand_derivatives (ModelingToolkit. Differential (iv)(r)) for r in rhs]
147+ tgrad = retime_dvs .(notime_tgrad,(sys. dvs,),iv)
148+ sys. tgrad[] = tgrad
149+ return tgrad
150+ end
136151
137152function calculate_jacobian (sys:: ODESystem )
138153 isempty (sys. jac[]) || return sys. jac[] # use cached Jacobian, if possible
@@ -160,6 +175,11 @@ function (f::ODEToExpr)(O::Operation)
160175end
161176(f:: ODEToExpr )(x) = convert (Expr, x)
162177
178+ function generate_tgrad (sys:: ODESystem , dvs = sys. dvs, ps = sys. ps, expression = Val{true }; kwargs... )
179+ tgrad = calculate_tgrad (sys)
180+ return build_function (tgrad, dvs, ps, (sys. iv. name,), ODEToExpr (sys), expression; kwargs... )
181+ end
182+
163183function generate_jacobian (sys:: ODESystem , dvs = sys. dvs, ps = sys. ps, expression = Val{true }; kwargs... )
164184 jac = calculate_jacobian (sys)
165185 return build_function (jac, dvs, ps, (sys. iv. name,), ODEToExpr (sys), expression; kwargs... )
@@ -218,13 +238,21 @@ are used to set the order of the dependent variable and parameter vectors,
218238respectively.
219239"""
220240function DiffEqBase. ODEFunction {iip} (sys:: ODESystem , dvs = sys. dvs, ps = sys. ps;
221- version = nothing ,
241+ version = nothing , tgrad = false ,
222242 jac = false , Wfact = false ) where {iip}
223243 f_oop,f_iip = generate_function (sys, dvs, ps, Val{false })
224244
225245 f (u,p,t) = f_oop (u,p,t)
226246 f (du,u,p,t) = f_iip (du,u,p,t)
227247
248+ if tgrad
249+ tgrad_oop,tgrad_iip = generate_tgrad (sys, dvs, ps, Val{false })
250+ _tgrad (u,p,t) = tgrad_oop (u,p,t)
251+ _tgrad (J,u,p,t) = tgrad_iip (J,u,p,t)
252+ else
253+ _tgrad = nothing
254+ end
255+
228256 if jac
229257 jac_oop,jac_iip = generate_jacobian (sys, dvs, ps, Val{false })
230258 _jac (u,p,t) = jac_oop (u,p,t)
@@ -246,6 +274,7 @@ function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs = sys.dvs, ps = sys.ps;
246274 end
247275
248276 ODEFunction {iip} (f,jac= _jac,
277+ tgrad = tgrad,
249278 Wfact = _Wfact,
250279 Wfact_t = _Wfact_t,
251280 syms = string .(sys. dvs))
0 commit comments