Skip to content

Commit 5a8fcb4

Browse files
committed
Automatic state detection and some performance optimization
1 parent fff4d5c commit 5a8fcb4

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack
88
using DiffEqJump
9+
using DataStructures: OrderedDict, OrderedSet
910

1011
using Base.Threads
1112
import MacroTools: splitdef, combinedef, postwalk, striplines

src/systems/diffeqs/first_order_transform.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using DataStructures: OrderedDict
21
function lower_varname(var::Variable, idv, order)
32
order == 0 && return var
43
name = Symbol(var.name, , string(idv.name)^order)

src/systems/diffeqs/odesystem.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,47 @@ end
7777
var_from_nested_derivative(x) = var_from_nested_derivative(x,0)
7878
var_from_nested_derivative(x::Constant) = (missing, missing)
7979
var_from_nested_derivative(x,i) = x.op isa Differential ? var_from_nested_derivative(x.args[1],i+1) : (x.op,i)
80+
81+
function extract_eqs_states_ps(eqs::AbstractArray{<:Equation}, iv)
82+
# NOTE: this assumes that the order of algebric equations doesn't matter
83+
diffvars = OrderedSet{Variable}()
84+
allstates = OrderedSet{Variable}()
85+
ps = OrderedSet{Variable}()
86+
# reorder equations such that it is in the form of `diffeq, algeeq`
87+
diffeq = Equation[]
88+
algeeq = Equation[]
89+
for eq in eqs
90+
for var in vars(eq.rhs for eq eqs)
91+
var isa Variable || continue
92+
if isparameter(var)
93+
isequal(var, iv) || push!(ps, var)
94+
else
95+
push!(allstates, var)
96+
end
97+
end
98+
if eq.lhs isa Constant
99+
push!(algeeq, eq)
100+
else
101+
diffvar = first(var_from_nested_derivative(eq.lhs))
102+
diffvar in diffvars && throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
103+
push!(diffvars, diffvar)
104+
push!(diffeq, eq)
105+
end
106+
end
107+
algevars = setdiff(allstates, diffvars)
108+
# the orders here are very important!
109+
return append!(diffeq, algeeq), vcat(collect(diffvars), collect(algevars)), ps
110+
end
111+
80112
iv_from_nested_derivative(x) = x.op isa Differential ? iv_from_nested_derivative(x.args[1]) : x.args[1].op
81113
iv_from_nested_derivative(x::Constant) = missing
82114

83115
function ODESystem(eqs; kwargs...)
84116
ivs = unique(skipmissing(iv_from_nested_derivative(eq.lhs) for eq eqs))
85-
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
117+
length(ivs) == 1 || throw(ArgumentError("An ODESystem can only have one independent variable."))
86118
iv = first(ivs)
87-
88-
dvs = unique(skipmissing(var_from_nested_derivative(eq.lhs)[1] for eq eqs))
89-
ps = filter(vars(eq.rhs for eq eqs)) do x
90-
isparameter(x) & !isequal(x, iv)
91-
end |> collect
92-
ODESystem(eqs, iv, dvs, ps; kwargs...)
119+
eqs, dvs, ps = extract_eqs_states_ps(eqs, iv)
120+
return ODESystem(eqs, iv, dvs, ps; kwargs...)
93121
end
94122

95123
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =

0 commit comments

Comments
 (0)