Skip to content

Commit fc5a147

Browse files
Merge pull request #278 from SciML/abstract_system
Refactor to AbstractSystem interface and create OptimizationSystem
2 parents 2b59e8a + 6b8fbc3 commit fc5a147

18 files changed

+394
-284
lines changed

src/ModelingToolkit.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,21 +87,30 @@ include("simplify.jl")
8787
include("utils.jl")
8888
include("direct.jl")
8989
include("domains.jl")
90+
91+
include("systems/abstractsystem.jl")
92+
9093
include("systems/diffeqs/odesystem.jl")
9194
include("systems/diffeqs/sdesystem.jl")
9295
include("systems/diffeqs/abstractodesystem.jl")
9396
include("systems/diffeqs/first_order_transform.jl")
9497
include("systems/diffeqs/modelingtoolkitize.jl")
9598
include("systems/diffeqs/validation.jl")
96-
include("systems/nonlinear/nonlinear_system.jl")
99+
100+
include("systems/nonlinear/nonlinearsystem.jl")
101+
102+
include("systems/optimization/optimizationsystem.jl")
103+
97104
include("systems/pde/pdesystem.jl")
105+
98106
include("systems/reaction/reactionsystem.jl")
107+
99108
include("latexify_recipes.jl")
100109
include("build_function.jl")
101110

102111
export ODESystem, ODEFunction
103112
export SDESystem, SDEFunction
104-
export NonlinearSystem
113+
export NonlinearSystem, OptimizationSystem
105114
export ode_order_lowering
106115
export PDESystem
107116
export Reaction, ReactionSystem
@@ -110,12 +119,14 @@ export IntervalDomain, ProductDomain, ⊗, CircleDomain
110119
export Equation, ConstrainedEquation
111120
export simplify_constants
112121

113-
export Operation, Expression
122+
export Operation, Expression, Variable
114123
export calculate_jacobian, generate_jacobian, generate_function
124+
export calculate_tgrad, generate_tgrad
125+
export calculate_hessian, generate_hessian
115126
export calculate_massmatrix, generate_diffusion_function
116127
export independent_variable, states, parameters, equations
117-
export simplified_expr, eval_function
118-
export @register, @I
128+
export simplified_expr
129+
export @register
119130
export modelingtoolkitize
120-
export Variable, @variables, @parameters
131+
export @variables, @parameters
121132
end # module

src/build_function.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,42 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
88
_build_function(target,args...;kwargs...)
99
end
1010

11+
# Scalar output
12+
function _build_function(target::JuliaTarget, op::Operation, vs, ps = (), args = (),
13+
conv = simplified_expr, expression = Val{true};
14+
checkbounds = false, constructor=nothing,
15+
linenumbers = true)
16+
_vs = map(x-> x isa Operation ? x.op : x, vs)
17+
_ps = map(x-> x isa Operation ? x.op : x, ps)
18+
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
19+
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
20+
(ls, rs) = zip(var_pairs..., param_pairs...)
21+
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
22+
23+
fname = gensym(:ModelingToolkitFunction)
24+
out_expr = conv(op)
25+
let_expr = Expr(:let, var_eqs, out_expr)
26+
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
27+
28+
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
29+
30+
oop_ex = :(
31+
($(fargs.args...),) -> begin
32+
$bounds_block
33+
end
34+
)
35+
36+
if !linenumbers
37+
oop_ex = striplines(oop_ex)
38+
end
39+
40+
if expression == Val{true}
41+
return oop_ex
42+
else
43+
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex)
44+
end
45+
end
46+
1147
function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
1248
conv = simplified_expr, expression = Val{true};
1349
checkbounds = false, constructor=nothing,

src/equations.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,16 @@ end
4545

4646

4747
Base.Expr(op::Equation) = simplified_expr(op)
48+
49+
function _eq_unordered(a, b)
50+
length(a) === length(b) || return false
51+
n = length(a)
52+
idxs = Set(1:n)
53+
for x a
54+
idx = findfirst(isequal(x), b)
55+
idx === nothing && return false
56+
idx idxs || return false
57+
delete!(idxs, idx)
58+
end
59+
return true
60+
end

src/systems/abstractsystem.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
function Base.getproperty(sys::AbstractSystem, name::Symbol)
2+
if name fieldnames(typeof(sys))
3+
return getfield(sys,name)
4+
elseif !isempty(sys.systems)
5+
i = findfirst(x->x.name==name,sys.systems)
6+
if i !== nothing
7+
return rename(sys.systems[i],renamespace(sys.name,name))
8+
end
9+
end
10+
i = findfirst(x->x.name==name,sys.states)
11+
if i !== nothing
12+
x = rename(sys.states[i],renamespace(sys.name,name))
13+
if :iv fieldnames(typeof(sys))
14+
return x(getfield(sys,:iv)())
15+
else
16+
return x()
17+
end
18+
end
19+
if :ps fieldnames(typeof(sys))
20+
i = findfirst(x->x.name==name,sys.ps)
21+
if i !== nothing
22+
return rename(sys.ps[i],renamespace(sys.name,name))()
23+
end
24+
end
25+
throw(error("Variable $name does not exist"))
26+
end
27+
28+
renamespace(namespace,name) = Symbol(string(namespace)*""*string(name))
29+
30+
function namespace_variables(sys::AbstractSystem)
31+
[rename(x,renamespace(sys.name,x.name)) for x in states(sys)]
32+
end
33+
34+
function namespace_parameters(sys::AbstractSystem)
35+
[rename(x,renamespace(sys.name,x.name)) for x in parameters(sys)]
36+
end
37+
38+
namespace_equations(sys::AbstractSystem) = namespace_equation.(equations(sys),sys.name,sys.iv.name)
39+
40+
function namespace_equation(eq::Equation,name,ivname)
41+
_lhs = namespace_operation(eq.lhs,name,ivname)
42+
_rhs = namespace_operation(eq.rhs,name,ivname)
43+
_lhs ~ _rhs
44+
end
45+
46+
function namespace_operation(O::Operation,name,ivname)
47+
if O.op isa Variable && O.op.name != ivname
48+
Operation(rename(O.op,renamespace(name,O.op.name)),namespace_operation.(O.args,name,ivname))
49+
else
50+
Operation(O.op,namespace_operation.(O.args,name,ivname))
51+
end
52+
end
53+
namespace_operation(O::Constant,name,ivname) = O
54+
55+
independent_variable(sys::AbstractSystem) = sys.iv
56+
states(sys::AbstractSystem) = isempty(sys.systems) ? sys.states : [sys.states;reduce(vcat,namespace_variables.(sys.systems))]
57+
parameters(sys::AbstractSystem) = isempty(sys.systems) ? sys.ps : [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]
58+
59+
function equations(sys::AbstractSystem)
60+
isempty(sys.systems) ? sys.eqs : [sys.eqs;reduce(vcat,namespace_equations.(sys.systems))]
61+
end
62+
63+
function states(sys::AbstractSystem,name::Symbol)
64+
x = sys.states[findfirst(x->x.name==name,sys.states)]
65+
Variable(Symbol(string(sys.name)*""*string(x.name)))(sys.iv())
66+
end
67+
68+
function parameters(sys::AbstractSystem,name::Symbol)
69+
x = sys.ps[findfirst(x->x.name==name,sys.ps)]
70+
Variable(Symbol(string(sys.name)*""*string(x.name)))(sys.iv())
71+
end
72+
73+
function states(sys::AbstractSystem,args...)
74+
name = last(args)
75+
extra_names = reduce(*,["$(x.name)" for x in args[1:end-1]])
76+
Variable(Symbol(string(sys.name)*extra_names*""*string(name)))(sys.iv())
77+
end
78+
79+
function parameters(sys::AbstractSystem,args...)
80+
name = last(args)
81+
extra_names = reduce(*,["$(x.name)" for x in args[1:end-1]])
82+
Variable(Symbol(string(sys.name)*extra_names*""*string(name)))(sys.iv())
83+
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
function (f::ODEToExpr)(O::Operation)
2727
if isa(O.op, Variable)
2828
isequal(O.op, f.sys.iv) && return O.op.name # independent variable
29-
O.op f.sys.dvs && return O.op.name # dependent variables
29+
O.op f.sys.states && return O.op.name # dependent variables
3030
isempty(O.args) && return O.op.name # 0-ary parameters
3131
return build_expr(:call, Any[O.op.name; f.(O.args)])
3232
end
@@ -46,8 +46,8 @@ end
4646

4747
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
4848
rhss = [deq.rhs for deq equations(sys)]
49-
dvs′ = [clean(dv) for dv dvs]
50-
ps′ = [clean(p) for p ps]
49+
dvs′ = convert.(Variable,dvs)
50+
ps′ = convert.(Variable,ps)
5151
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
5252
end
5353

@@ -107,82 +107,10 @@ function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
107107
M == I ? I : M
108108
end
109109

110-
renamespace(namespace,name) = Symbol(string(namespace)*""*string(name))
111-
112110
function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
113111
ODEFunction{true}(sys, args...; kwargs...)
114112
end
115113

116-
function namespace_variables(sys::AbstractODESystem)
117-
[rename(x,renamespace(sys.name,x.name)) for x in states(sys)]
118-
end
119-
120-
function namespace_parameters(sys::AbstractODESystem)
121-
[rename(x,renamespace(sys.name,x.name)) for x in parameters(sys)]
122-
end
123-
124-
namespace_equations(sys::AbstractODESystem) = namespace_equation.(equations(sys),sys.name,sys.iv.name)
125-
126-
function namespace_equation(eq::Equation,name,ivname)
127-
_lhs = namespace_operation(eq.lhs,name,ivname)
128-
_rhs = namespace_operation(eq.rhs,name,ivname)
129-
_lhs ~ _rhs
130-
end
131-
132-
function namespace_operation(O::Operation,name,ivname)
133-
if O.op isa Variable && O.op.name != ivname
134-
Operation(rename(O.op,renamespace(name,O.op.name)),namespace_operation.(O.args,name,ivname))
135-
else
136-
Operation(O.op,namespace_operation.(O.args,name,ivname))
137-
end
138-
end
139-
namespace_operation(O::Constant,name,ivname) = O
140-
141-
142-
143-
independent_variable(sys::AbstractODESystem) = sys.iv
144-
states(sys::AbstractODESystem) = isempty(sys.systems) ? sys.dvs : [sys.dvs;reduce(vcat,namespace_variables.(sys.systems))]
145-
parameters(sys::AbstractODESystem) = isempty(sys.systems) ? sys.ps : [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]
146-
147-
function equations(sys::AbstractODESystem)
148-
isempty(sys.systems) ? sys.eqs : [sys.eqs;reduce(vcat,namespace_equations.(sys.systems))]
149-
end
150-
151-
function states(sys::AbstractODESystem,name::Symbol)
152-
x = sys.dvs[findfirst(x->x.name==name,sys.dvs)]
153-
Variable(Symbol(string(sys.name)*""*string(x.name)))(sys.iv())
154-
end
155-
156-
function parameters(sys::AbstractODESystem,name::Symbol)
157-
x = sys.ps[findfirst(x->x.name==name,sys.ps)]
158-
Variable(Symbol(string(sys.name)*""*string(x.name)))(sys.iv())
159-
end
160-
161-
function states(sys::AbstractODESystem,args...)
162-
name = last(args)
163-
extra_names = reduce(*,["$(x.name)" for x in args[1:end-1]])
164-
Variable(Symbol(string(sys.name)*extra_names*""*string(name)))(sys.iv())
165-
end
166-
167-
function parameters(sys::AbstractODESystem,args...)
168-
name = last(args)
169-
extra_names = reduce(*,["$(x.name)" for x in args[1:end-1]])
170-
Variable(Symbol(string(sys.name)*extra_names*""*string(name)))(sys.iv())
171-
end
172-
173-
function _eq_unordered(a, b)
174-
length(a) === length(b) || return false
175-
n = length(a)
176-
idxs = Set(1:n)
177-
for x a
178-
idx = findfirst(isequal(x), b)
179-
idx === nothing && return false
180-
idx idxs || return false
181-
delete!(idxs, idx)
182-
end
183-
return true
184-
end
185-
186114
"""
187115
$(SIGNATURES)
188116
@@ -234,25 +162,5 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
234162
Wfact = _Wfact,
235163
Wfact_t = _Wfact_t,
236164
mass_matrix = M,
237-
syms = Symbol.(sys.dvs))
238-
end
239-
240-
function Base.getproperty(sys::AbstractODESystem, name::Symbol)
241-
if name fieldnames(typeof(sys))
242-
return getfield(sys,name)
243-
elseif !isempty(sys.systems)
244-
i = findfirst(x->x.name==name,sys.systems)
245-
if i !== nothing
246-
return rename(sys.systems[i],renamespace(sys.name,name))
247-
end
248-
end
249-
i = findfirst(x->x.name==name,sys.dvs)
250-
if i !== nothing
251-
return rename(sys.dvs[i],renamespace(sys.name,name))(getfield(sys,:iv)())
252-
end
253-
i = findfirst(x->x.name==name,sys.ps)
254-
if i !== nothing
255-
return rename(sys.ps[i],renamespace(sys.name,name))()
256-
end
257-
throw(error("Variable name does not exist"))
165+
syms = Symbol.(sys.states))
258166
end

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
55
"""
66
function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
77
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
8-
return (prob.f.sys, prob.f.sys.dvs, prob.f.sys.ps)
8+
return (prob.f.sys, prob.f.sys.states, prob.f.sys.ps)
99
@parameters t
1010
vars = [Variable(:x, i)(t) for i in eachindex(prob.u0)]
1111
params = prob.p isa DiffEqBase.NullParameters ? [] :

src/systems/diffeqs/odesystem.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct ODESystem <: AbstractODESystem
2828
"""Independent variable."""
2929
iv::Variable
3030
"""Dependent (state) variables."""
31-
dvs::Vector{Variable}
31+
states::Vector{Variable}
3232
"""Parameter variables."""
3333
ps::Vector{Variable}
3434
"""
@@ -64,9 +64,9 @@ end
6464
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
6565
systems = ODESystem[],
6666
name=gensym(:ODESystem))
67-
iv′ = clean(iv)
68-
dvs′ = [clean(dv) for dv dvs]
69-
ps′ = [clean(p) for p ps]
67+
iv′ = convert(Variable,iv)
68+
dvs′ = convert.(Variable,dvs)
69+
ps′ = convert.(Variable,ps)
7070
tgrad = RefValue(Vector{Expression}(undef, 0))
7171
jac = RefValue(Matrix{Expression}(undef, 0, 0))
7272
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
@@ -92,9 +92,9 @@ end
9292

9393
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
9494
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
95-
_eq_unordered(sys1.dvs, sys2.dvs) && _eq_unordered(sys1.ps, sys2.ps)
95+
_eq_unordered(sys1.states, sys2.states) && _eq_unordered(sys1.ps, sys2.ps)
9696
# NOTE: equality does not check cached Jacobian
9797

9898
function rename(sys::ODESystem,name)
99-
ODESystem(sys.eqs, sys.iv, sys.dvs, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
99+
ODESystem(sys.eqs, sys.iv, sys.states, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
100100
end

0 commit comments

Comments
 (0)