Skip to content

Commit c26cde7

Browse files
Merge pull request #400 from SciML/io
[WIP] Causal Input/Output variables, explicit algebraic variables, aliases, and open connections
2 parents 0783668 + fad0dd8 commit c26cde7

File tree

10 files changed

+137
-20
lines changed

10 files changed

+137
-20
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ export Differential, expand_derivatives, @derivatives
132132
export IntervalDomain, ProductDomain, , CircleDomain
133133
export Equation, ConstrainedEquation
134134
export Operation, Expression, Variable
135-
export independent_variable, states, parameters, equations
135+
export independent_variable, states, parameters, equations, pins, observed
136136

137137
export calculate_jacobian, generate_jacobian, generate_function
138138
export calculate_tgrad, generate_tgrad

src/symutils.jl

Whitespace-only changes.

src/systems/abstractsystem.jl

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ Generate a function to evaluate the system's equations.
118118
function generate_function end
119119

120120
function Base.getproperty(sys::AbstractSystem, name::Symbol)
121+
121122
if name fieldnames(typeof(sys))
122123
return getfield(sys,name)
123124
elseif !isempty(sys.systems)
@@ -126,6 +127,7 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol)
126127
return rename(sys.systems[i],renamespace(sys.name,name))
127128
end
128129
end
130+
129131
i = findfirst(x->x.name==name,sys.states)
130132
if i !== nothing
131133
x = rename(sys.states[i],renamespace(sys.name,name))
@@ -135,12 +137,21 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol)
135137
return x()
136138
end
137139
end
140+
138141
if :ps fieldnames(typeof(sys))
139142
i = findfirst(x->x.name==name,sys.ps)
140143
if i !== nothing
141144
return rename(sys.ps[i],renamespace(sys.name,name))()
142145
end
143146
end
147+
148+
if :observed fieldnames(typeof(sys))
149+
i = findfirst(x->convert(Variable,x.lhs).name==name,sys.observed)
150+
if i !== nothing
151+
return rename(convert(Variable,sys.observed[i].lhs),renamespace(sys.name,name))(getfield(sys,:iv)())
152+
end
153+
end
154+
144155
throw(error("Variable $name does not exist"))
145156
end
146157

@@ -154,8 +165,13 @@ function namespace_parameters(sys::AbstractSystem)
154165
[rename(x,renamespace(sys.name,x.name)) for x in parameters(sys)]
155166
end
156167

168+
function namespace_pins(sys::AbstractSystem)
169+
[rename(x,renamespace(sys.name,x.name)) for x in pins(sys)]
170+
end
171+
157172
namespace_equations(sys::AbstractSystem) = namespace_equation.(equations(sys),sys.name,sys.iv.name)
158173

174+
159175
function namespace_equation(eq::Equation,name,ivname)
160176
_lhs = namespace_operation(eq.lhs,name,ivname)
161177
_rhs = namespace_operation(eq.rhs,name,ivname)
@@ -172,11 +188,14 @@ end
172188
namespace_operation(O::Constant,name,ivname) = O
173189

174190
independent_variable(sys::AbstractSystem) = sys.iv
175-
states(sys::AbstractSystem) = isempty(sys.systems) ? sys.states : [sys.states;reduce(vcat,namespace_variables.(sys.systems))]
191+
states(sys::AbstractSystem) = isempty(sys.systems) ? setdiff(sys.states, convert.(Variable,sys.pins)) : [sys.states;reduce(vcat,namespace_variables.(sys.systems))]
176192
parameters(sys::AbstractSystem) = isempty(sys.systems) ? sys.ps : [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]
177-
178-
function equations(sys::AbstractSystem)
179-
isempty(sys.systems) ? sys.eqs : [sys.eqs;reduce(vcat,namespace_equations.(sys.systems))]
193+
pins(sys::AbstractSystem) = isempty(sys.systems) ? sys.pins : [sys.pins;reduce(vcat,namespace_pins.(sys.systems))]
194+
function observed(sys::AbstractSystem)
195+
[sys.observed;
196+
reduce(vcat,
197+
(namespace_equation.(s.observed, s.name, s.iv.name) for s in sys.systems),
198+
init=Equation[])]
180199
end
181200

182201
function states(sys::AbstractSystem,name::Symbol)
@@ -189,6 +208,34 @@ function parameters(sys::AbstractSystem,name::Symbol)
189208
rename(x,renamespace(sys.name,x.name))()
190209
end
191210

211+
function pins(sys::AbstractSystem,name::Symbol)
212+
x = sys.pins[findfirst(x->x.name==name,sys.ps)]
213+
rename(x,renamespace(sys.name,x.name))(sys.iv())
214+
end
215+
216+
lhss(xs) = map(x->x.lhs, xs)
217+
rhss(xs) = map(x->x.rhs, xs)
218+
219+
function equations(sys::ModelingToolkit.AbstractSystem; remove_aliases = true)
220+
if isempty(sys.systems)
221+
return sys.eqs
222+
else
223+
eqs = [sys.eqs;
224+
reduce(vcat,
225+
namespace_equations.(sys.systems);
226+
init=Equation[])]
227+
228+
if !remove_aliases
229+
return eqs
230+
end
231+
aliases = observed(sys)
232+
dict = Dict(lhss(aliases) .=> rhss(aliases))
233+
234+
# Substitute aliases
235+
return Equation.(lhss(eqs), Rewriters.Fixpoint(x->substitute(x, dict)).(rhss(eqs)))
236+
end
237+
end
238+
192239
function states(sys::AbstractSystem,args...)
193240
name = last(args)
194241
extra_names = reduce(Symbol,[Symbol(:₊,x.name) for x in args[1:end-1]])
@@ -212,6 +259,13 @@ function islinear(sys::AbstractSystem)
212259
all(islinear(r, dvs) for r in rhs)
213260
end
214261

262+
function pins(sys::AbstractSystem,args...)
263+
name = last(args)
264+
extra_names = reduce(Symbol,[Symbol(:₊,x.name) for x in args[1:end-1]])
265+
newname = renamespace(extra_names,name)
266+
rename(x,renamespace(sys.name,newname))(sys.iv())
267+
end
268+
215269
struct AbstractSysToExpr
216270
sys::AbstractSystem
217271
states::Vector{Variable}

src/systems/diffeqs/odesystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ struct ODESystem <: AbstractODESystem
3131
states::Vector{Variable}
3232
"""Parameter variables."""
3333
ps::Vector{Variable}
34+
pins::Vector{Variable}
35+
observed::Vector{Equation}
3436
"""
3537
Time-derivative matrix. Note: this field will not be defined until
3638
[`calculate_tgrad`](@ref) is called on the system.
@@ -62,6 +64,8 @@ struct ODESystem <: AbstractODESystem
6264
end
6365

6466
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
67+
pins = Variable[],
68+
observed = Operation[],
6569
systems = ODESystem[],
6670
name=gensym(:ODESystem))
6771
iv′ = convert(Variable,iv)
@@ -71,7 +75,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
7175
jac = RefValue{Any}(Matrix{Expression}(undef, 0, 0))
7276
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
7377
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
74-
ODESystem(deqs, iv′, dvs′, ps′, tgrad, jac, Wfact, Wfact_t, name, systems)
78+
ODESystem(deqs, iv′, dvs′, ps′, pins, observed, tgrad, jac, Wfact, Wfact_t, name, systems)
7579
end
7680

7781
var_from_nested_derivative(x::Constant) = (missing, missing)

src/systems/diffeqs/sdesystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ struct SDESystem <: AbstractODESystem
3737
states::Vector{Variable}
3838
"""Parameter variables."""
3939
ps::Vector{Variable}
40+
pins::Vector{Variable}
41+
observed::Vector{Equation}
4042
"""
4143
Time-derivative matrix. Note: this field will not be defined until
4244
[`calculate_tgrad`](@ref) is called on the system.
@@ -68,6 +70,8 @@ struct SDESystem <: AbstractODESystem
6870
end
6971

7072
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
73+
pins = Variable[],
74+
observed = Operation[],
7175
systems = SDESystem[],
7276
name = gensym(:SDESystem))
7377
iv′ = convert(Variable,iv)
@@ -77,7 +81,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
7781
jac = RefValue(Matrix{Expression}(undef, 0, 0))
7882
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
7983
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
80-
SDESystem(deqs, neqs, iv′, dvs′, ps′, tgrad, jac, Wfact, Wfact_t, name, systems)
84+
SDESystem(deqs, neqs, iv′, dvs′, ps′, pins, observed, tgrad, jac, Wfact, Wfact_t, name, systems)
8185
end
8286

8387
function generate_diffusion_function(sys::SDESystem, dvs = sys.states, ps = sys.ps; kwargs...)

src/systems/jumps/jumpsystem.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,20 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
3737
states::Vector{Variable}
3838
"""The parameters of the system."""
3939
ps::Vector{Variable}
40+
pins::Vector{Variable}
41+
observed::Vector{Equation}
4042
"""The name of the system."""
4143
name::Symbol
4244
"""The internal systems."""
4345
systems::Vector{JumpSystem}
4446
end
4547

46-
function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
47-
name = gensym(:JumpSystem))
48+
function JumpSystem(eqs, iv, states, ps;
49+
pins = Variable[],
50+
observed = Equation[],
51+
systems = JumpSystem[],
52+
name = gensym(:JumpSystem))
53+
4854
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
4955
for eq in eqs
5056
if eq isa MassActionJump
@@ -58,13 +64,9 @@ function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
5864
end
5965
end
6066

61-
JumpSystem{typeof(ap)}(ap, convert(Variable,iv), convert.(Variable, states), convert.(Variable, ps), name, systems)
67+
JumpSystem{typeof(ap)}(ap, convert(Variable,iv), convert.(Variable, states), convert.(Variable, ps), pins, observed, name, systems)
6268
end
6369

64-
JumpSystem(eqs::ArrayPartition, iv, states, ps; systems = JumpSystem[], name = gensym(:JumpSystem)) =
65-
JumpSystem{typeof(eqs)}(eqs, convert(Variable,iv), convert.(Variable, states), convert.(Variable, ps), name, systems)
66-
67-
6870
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
6971
independent_variable(js),
7072
expression=Val{true})

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ struct NonlinearSystem <: AbstractSystem
2525
states::Vector{Variable}
2626
"""Parameters."""
2727
ps::Vector{Variable}
28+
pins::Vector{Variable}
29+
observed::Vector{Equation}
2830
"""
2931
Name: the name of the system
3032
"""
@@ -36,9 +38,11 @@ struct NonlinearSystem <: AbstractSystem
3638
end
3739

3840
function NonlinearSystem(eqs, states, ps;
41+
pins = Variable[],
42+
observed = Operation[],
3943
name = gensym(:NonlinearSystem),
4044
systems = NonlinearSystem[])
41-
NonlinearSystem(eqs, convert.(Variable,states), convert.(Variable,ps), name, systems)
45+
NonlinearSystem(eqs, convert.(Variable,states), convert.(Variable,ps), pins, observed, name, systems)
4246
end
4347

4448
function calculate_jacobian(sys::NonlinearSystem;sparse=false,simplify=true)

src/systems/optimization/optimizationsystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ struct OptimizationSystem <: AbstractSystem
2323
states::Vector{Variable}
2424
"""Parameters."""
2525
ps::Vector{Variable}
26+
pins::Vector{Variable}
27+
observed::Vector{Equation}
2628
"""
2729
Name: the name of the system
2830
"""
@@ -34,9 +36,11 @@ struct OptimizationSystem <: AbstractSystem
3436
end
3537

3638
function OptimizationSystem(op, states, ps;
39+
pins = Variable[],
40+
observed = Operation[],
3741
name = gensym(:OptimizationSystem),
3842
systems = OptimizationSystem[])
39-
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), name, systems)
43+
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), pins, observed, name, systems)
4044
end
4145

4246
function calculate_gradient(sys::OptimizationSystem)

src/systems/reaction/reactionsystem.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,20 +129,25 @@ struct ReactionSystem <: AbstractSystem
129129
states::Vector{Variable}
130130
"""Parameter variables."""
131131
ps::Vector{Variable}
132+
pins::Vector{Variable}
133+
observed::Vector{Equation}
132134
"""The name of the system"""
133135
name::Symbol
134136
"""systems: The internal systems"""
135137
systems::Vector{ReactionSystem}
136138
end
137139

138-
function ReactionSystem(eqs, iv, species, params; systems = ReactionSystem[],
139-
name = gensym(:ReactionSystem))
140-
140+
function ReactionSystem(eqs, iv, species, params;
141+
pins = Variable[],
142+
observed = Operation[],
143+
systems = ReactionSystem[],
144+
name = gensym(:ReactionSystem))
141145

142146
isempty(species) && error("ReactionSystems require at least one species.")
143147
paramvars = map(v -> convert(Variable,v), params)
144148
specvars = map(s -> convert(Variable,s), species)
145-
ReactionSystem(eqs, convert(Variable,iv), specvars, paramvars, name, systems)
149+
ReactionSystem(eqs, convert(Variable,iv), specvars, paramvars,
150+
pins, observed, name, systems)
146151
end
147152

148153
"""
@@ -294,6 +299,7 @@ explicitly on the independent variable (usually time).
294299
function ismassaction(rx, rs; rxvars = get_variables(rx.rate),
295300
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars),
296301
stateset = Set(states(rs)))
302+
return !(haveivdep || rx.only_use_rate || any(convert(Variable,rxv) in states(rs) for rxv in rxvars))
297303
# if no dependencies must be zero order
298304
(length(rxvars)==0) && return true
299305
(haveivdep || rx.only_use_rate) && return false

test/inputoutput.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using ModelingToolkit, OrdinaryDiffEq, Test
2+
3+
@parameters t σ ρ β
4+
@variables x(t) y(t) z(t) F(t) u(t)
5+
@derivatives D'~t
6+
7+
eqs = [D(x) ~ σ*(y-x) + F,
8+
D(y) ~ x*-z)-y,
9+
D(z) ~ x*y - β*z]
10+
11+
aliases = [u ~ x + y - z]
12+
lorenz1 = ODESystem(eqs,pins=[F],observed=aliases,name=:lorenz1)
13+
lorenz2 = ODESystem(eqs,pins=[F],observed=aliases,name=:lorenz2)
14+
15+
connections = [lorenz1.F ~ lorenz2.u,
16+
lorenz2.F ~ lorenz1.u]
17+
connected = ODESystem(Equation[],t,[],[],observed=connections,systems=[lorenz1,lorenz2])
18+
19+
sys = connected
20+
21+
@variables lorenz1₊F lorenz2₊F
22+
@test pins(connected) == Variable[lorenz1₊F, lorenz2₊F]
23+
@test isequal(observed(connected),
24+
[connections...,
25+
lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z,
26+
lorenz2.u ~ lorenz2.x + lorenz2.y - lorenz2.z])
27+
28+
collapsed_eqs = [D(lorenz1.x) ~ (lorenz1.σ * (lorenz1.y - lorenz1.x) +
29+
(lorenz2.x + lorenz2.y - lorenz2.z)),
30+
D(lorenz1.y) ~ lorenz1.x * (lorenz1.ρ - lorenz1.z) - lorenz1.y,
31+
D(lorenz1.z) ~ lorenz1.x * lorenz1.y - (lorenz1.β * lorenz1.z),
32+
D(lorenz2.x) ~ (lorenz2.σ * (lorenz2.y - lorenz2.x) +
33+
(lorenz1.x + lorenz1.y - lorenz1.z)),
34+
D(lorenz2.y) ~ lorenz2.x * (lorenz2.ρ - lorenz2.z) - lorenz2.y,
35+
D(lorenz2.z) ~ lorenz2.x * lorenz2.y - (lorenz2.β * lorenz2.z)]
36+
37+
simplifyeqs(eqs) = Equation.((x->x.lhs).(eqs), simplify.((x->x.rhs).(eqs)))
38+
39+
@test isequal(simplifyeqs(equations(connected)), simplifyeqs(collapsed_eqs))

0 commit comments

Comments
 (0)