Skip to content

Commit 6b8fbc3

Browse files
finish optimizationsystem
1 parent 205a264 commit 6b8fbc3

File tree

13 files changed

+142
-96
lines changed

13 files changed

+142
-96
lines changed

src/ModelingToolkit.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ include("systems/diffeqs/validation.jl")
9999

100100
include("systems/nonlinear/nonlinearsystem.jl")
101101

102-
#include("systems/optimization/optimizationsystem.jl")
102+
include("systems/optimization/optimizationsystem.jl")
103103

104104
include("systems/pde/pdesystem.jl")
105105

@@ -110,7 +110,7 @@ include("build_function.jl")
110110

111111
export ODESystem, ODEFunction
112112
export SDESystem, SDEFunction
113-
export NonlinearSystem
113+
export NonlinearSystem, OptimizationSystem
114114
export ode_order_lowering
115115
export PDESystem
116116
export Reaction, ReactionSystem
@@ -119,12 +119,14 @@ export IntervalDomain, ProductDomain, ⊗, CircleDomain
119119
export Equation, ConstrainedEquation
120120
export simplify_constants
121121

122-
export Operation, Expression
122+
export Operation, Expression, Variable
123123
export calculate_jacobian, generate_jacobian, generate_function
124+
export calculate_tgrad, generate_tgrad
125+
export calculate_hessian, generate_hessian
124126
export calculate_massmatrix, generate_diffusion_function
125127
export independent_variable, states, parameters, equations
126-
export simplified_expr, eval_function
127-
export @register, @I
128+
export simplified_expr
129+
export @register
128130
export modelingtoolkitize
129-
export Variable, @variables, @parameters
131+
export @variables, @parameters
130132
end # module

src/build_function.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,42 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
88
_build_function(target,args...;kwargs...)
99
end
1010

11+
# Scalar output
12+
function _build_function(target::JuliaTarget, op::Operation, vs, ps = (), args = (),
13+
conv = simplified_expr, expression = Val{true};
14+
checkbounds = false, constructor=nothing,
15+
linenumbers = true)
16+
_vs = map(x-> x isa Operation ? x.op : x, vs)
17+
_ps = map(x-> x isa Operation ? x.op : x, ps)
18+
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
19+
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
20+
(ls, rs) = zip(var_pairs..., param_pairs...)
21+
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
22+
23+
fname = gensym(:ModelingToolkitFunction)
24+
out_expr = conv(op)
25+
let_expr = Expr(:let, var_eqs, out_expr)
26+
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
27+
28+
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
29+
30+
oop_ex = :(
31+
($(fargs.args...),) -> begin
32+
$bounds_block
33+
end
34+
)
35+
36+
if !linenumbers
37+
oop_ex = striplines(oop_ex)
38+
end
39+
40+
if expression == Val{true}
41+
return oop_ex
42+
else
43+
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex)
44+
end
45+
end
46+
1147
function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
1248
conv = simplified_expr, expression = Val{true};
1349
checkbounds = false, constructor=nothing,

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol)
1010
i = findfirst(x->x.name==name,sys.states)
1111
if i !== nothing
1212
x = rename(sys.states[i],renamespace(sys.name,name))
13-
if iv fieldnames(typeof(sys))
13+
if :iv fieldnames(typeof(sys))
1414
return x(getfield(sys,:iv)())
1515
else
1616
return x()
@@ -22,7 +22,7 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol)
2222
return rename(sys.ps[i],renamespace(sys.name,name))()
2323
end
2424
end
25-
throw(error("Variable name does not exist"))
25+
throw(error("Variable $name does not exist"))
2626
end
2727

2828
renamespace(namespace,name) = Symbol(string(namespace)*""*string(name))

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ end
4646

4747
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
4848
rhss = [deq.rhs for deq equations(sys)]
49-
dvs′ = convert.(Variable,states)
49+
dvs′ = convert.(Variable,dvs)
5050
ps′ = convert.(Variable,ps)
5151
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
5252
end

src/systems/diffeqs/odesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ end
6464
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
6565
systems = ODESystem[],
6666
name=gensym(:ODESystem))
67-
iv′ = clean(iv)
68-
dvs′ = [clean(dv) for dv dvs]
69-
ps′ = [clean(p) for p ps]
67+
iv′ = convert(Variable,iv)
68+
dvs′ = convert.(Variable,dvs)
69+
ps′ = convert.(Variable,ps)
7070
tgrad = RefValue(Vector{Expression}(undef, 0))
7171
jac = RefValue(Matrix{Expression}(undef, 0, 0))
7272
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
@@ -92,7 +92,7 @@ end
9292

9393
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
9494
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
95-
_eq_unordered(sys1.dvs, sys2.dvs) && _eq_unordered(sys1.ps, sys2.ps)
95+
_eq_unordered(sys1.states, sys2.states) && _eq_unordered(sys1.ps, sys2.ps)
9696
# NOTE: equality does not check cached Jacobian
9797

9898
function rename(sys::ODESystem,name)

src/systems/diffeqs/sdesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ end
4242
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
4343
systems = SDESystem[],
4444
name = gensym(:SDESystem))
45-
dvs= [clean(dv) for dv dvs]
46-
ps= [clean(p) for p ps]
47-
iv= clean(iv)
45+
iv= convert(Variable,iv)
46+
dvs= convert.(Variable,dvs)
47+
ps= convert.(Variable,ps)
4848
tgrad = RefValue(Vector{Expression}(undef, 0))
4949
jac = RefValue(Matrix{Expression}(undef, 0, 0))
5050
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
@@ -53,8 +53,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
5353
end
5454

5555
function generate_diffusion_function(sys::SDESystem, dvs = sys.states, ps = sys.ps, expression = Val{true}; kwargs...)
56-
dvs′ = [clean(dv) for dv dvs]
57-
ps′ = [clean(p) for p ps]
56+
dvs′ = convert.(Variable,dvs)
57+
ps′ = convert.(Variable,ps)
5858
return build_function(sys.noiseeqs, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
5959
end
6060

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct NonlinearSystem <: AbstractSystem
2222
"""Vector of equations defining the system."""
2323
eqs::Vector{Equation}
2424
"""Unknown variables."""
25-
states::Vector{Expression}
25+
states::Vector{Variable}
2626
"""Parameters."""
2727
ps::Vector{Variable}
2828
"""
@@ -38,16 +38,16 @@ end
3838
function NonlinearSystem(eqs, states, ps;
3939
name = gensym(:NonlinearSystem),
4040
systems = NonlinearSystem[])
41-
NonlinearSystem(eqs, states, convert.(Variable,ps), name, systems)
41+
NonlinearSystem(eqs, convert.(Variable,states), convert.(Variable,ps), name, systems)
4242
end
4343

4444
function calculate_jacobian(sys::NonlinearSystem)
4545
rhs = [eq.rhs for eq in sys.eqs]
46-
jac = expand_derivatives.(calculate_jacobian(rhs, sys.states))
46+
jac = expand_derivatives.(calculate_jacobian(rhs, [dv() for dv in states(sys)]))
4747
return jac
4848
end
4949

50-
function generate_jacobian(sys::NonlinearSystem, vs = sys.states, ps = sys.ps, expression = Val{true}; kwargs...)
50+
function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
5151
jac = calculate_jacobian(sys)
5252
return build_function(jac, convert.(Variable,vs), convert.(Variable,ps), (), NLSysToExpr(sys))
5353
end
@@ -65,7 +65,7 @@ function (f::NLSysToExpr)(O::Operation)
6565
end
6666
(f::NLSysToExpr)(x) = convert(Expr, x)
6767

68-
function generate_function(sys::NonlinearSystem, vs = sys.states, ps = sys.ps, expression = Val{true}; kwargs...)
68+
function generate_function(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
6969
rhss = [eq.rhs for eq sys.eqs]
7070
vs′ = convert.(Variable,vs)
7171
ps′ = convert.(Variable,ps)

src/systems/optimization/optimizationsystem.jl

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct OptimizationSystem <: AbstractSystem
2020
"""Vector of equations defining the system."""
2121
op::Operation
2222
"""Unknown variables."""
23-
states::Vector{Expression}
23+
states::Vector{Variable}
2424
"""Parameters."""
2525
ps::Vector{Variable}
2626
"""
@@ -30,42 +30,29 @@ struct OptimizationSystem <: AbstractSystem
3030
"""
3131
systems: The internal systems
3232
"""
33-
systems::Vector{NonlinearSystem}
33+
systems::Vector{OptimizationSystem}
3434
end
3535

36-
function NonlinearSystem(eqs, states, ps;
37-
name = gensym(:NonlinearSystem),
38-
systems = NonlinearSystem[])
39-
NonlinearSystem(eqs, states, convert.(Variable,ps), name, systems)
36+
function OptimizationSystem(op, states, ps;
37+
name = gensym(:OptimizationSystem),
38+
systems = OptimizationSystem[])
39+
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), name, systems)
4040
end
4141

42-
function calculate_jacobian(sys::NonlinearSystem)
43-
rhs = [eq.rhs for eq in sys.eqs]
44-
jac = expand_derivatives.(calculate_jacobian(rhs, sys.states))
45-
return jac
42+
function calculate_hessian(sys::OptimizationSystem)
43+
expand_derivatives.(hessian(equations(sys), [dv() for dv in states(sys)]))
4644
end
4745

48-
function generate_jacobian(sys::NonlinearSystem, vs = sys.states, ps = sys.ps, expression = Val{true}; kwargs...)
49-
jac = calculate_jacobian(sys)
50-
return build_function(jac, convert.(Variable,vs), convert.(Variable,ps), (), NLSysToExpr(sys))
46+
function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
47+
hes = calculate_hessian(sys)
48+
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x))
5149
end
5250

53-
struct NLSysToExpr
54-
sys::NonlinearSystem
55-
end
56-
function (f::NLSysToExpr)(O::Operation)
57-
any(isequal(O), f.sys.states) && return O.op.name # variables
58-
if isa(O.op, Variable)
59-
isempty(O.args) && return O.op.name # 0-ary parameters
60-
return build_expr(:call, Any[O.op.name; f.(O.args)])
61-
end
62-
return build_expr(:call, Any[O.op; f.(O.args)])
63-
end
64-
(f::NLSysToExpr)(x) = convert(Expr, x)
65-
66-
function generate_function(sys::NonlinearSystem, vs = sys.states, ps = sys.ps, expression = Val{true}; kwargs...)
67-
rhss = [eq.rhs for eq sys.eqs]
51+
function generate_function(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
6852
vs′ = convert.(Variable,vs)
6953
ps′ = convert.(Variable,ps)
70-
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys), expression; kwargs...)
54+
return build_function(equations(sys), vs′, ps′, (), x->convert(Expr, x), expression; kwargs...)
7155
end
56+
57+
equations(sys::OptimizationSystem) = isempty(sys.systems) ? sys.op : sys.op + reduce(+,namespace_operation.(sys.systems))
58+
namespace_operation(sys::OptimizationSystem) = namespace_operation(sys.op,sys.name,nothing)

test/build_targets.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit, Test
44
@derivatives D'~t
55
eqs = [D(x) ~ a*x - x*y,
66
D(y) ~ -3y + x*y]
7-
@test ModelingToolkit.build_function(eqs,ModelingToolkit.clean.([x,y]),ModelingToolkit.clean.([a]),t,target = ModelingToolkit.StanTarget()) ==
7+
@test ModelingToolkit.build_function(eqs,convert.(Variable,[x,y]),convert.(Variable,[a]),t,target = ModelingToolkit.StanTarget()) ==
88
"""
99
real[] diffeqf(real t,real[] internal_var___u,real[] internal_var___p,real[] x_r,int[] x_i) {
1010
real internal_var___du[2];
@@ -38,10 +38,10 @@ sys = ODESystem(eqs,t,[x,y],[a])
3838
ModelingToolkit.build_function(sys.eqs,[x,y],[a],t,target = ModelingToolkit.MATLABTarget())
3939

4040
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget()) ==
41-
ModelingToolkit.build_function(sys.eqs,sys.dvs,sys.ps,sys.iv,target = ModelingToolkit.CTarget())
41+
ModelingToolkit.build_function(sys.eqs,sys.states,sys.ps,sys.iv,target = ModelingToolkit.CTarget())
4242

4343
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget()) ==
44-
ModelingToolkit.build_function(sys.eqs,sys.dvs,sys.ps,sys.iv,target = ModelingToolkit.StanTarget())
44+
ModelingToolkit.build_function(sys.eqs,sys.states,sys.ps,sys.iv,target = ModelingToolkit.StanTarget())
4545

4646
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.MATLABTarget()) ==
47-
ModelingToolkit.build_function(sys.eqs,sys.dvs,sys.ps,sys.iv,target = ModelingToolkit.MATLABTarget())
47+
ModelingToolkit.build_function(sys.eqs,sys.states,sys.ps,sys.iv,target = ModelingToolkit.MATLABTarget())

test/derivatives.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33

44
# Derivatives
55
@parameters t σ ρ β
6-
@variables x(t) y(t) z(t)
6+
@variables x y z
77
@derivatives D'~t D2''~t Dx'~x
88

99
@test @macroexpand(@derivatives D'~t D2''~t) == @macroexpand(@derivatives (D'~t), (D2''~t))
@@ -36,7 +36,7 @@ d2 = D(sin(t)*cos(t))
3636
eqs = [0 ~ σ*(y-x),
3737
0 ~ x*-z)-y,
3838
0 ~ x*y - β*z]
39-
sys = NonlinearSystem(eqs, [x,y,z])
39+
sys = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
4040
jac = calculate_jacobian(sys)
4141
@test isequal(jac[1,1], σ*-1)
4242
@test isequal(jac[1,2], σ)
@@ -54,6 +54,8 @@ jac = calculate_jacobian(sys)
5454
@test isequal(expand_derivatives(D(t)), 1)
5555
@test isequal(expand_derivatives(Dx(x)), 1)
5656

57+
@variables x(t) y(t) z(t)
58+
5759
@test isequal(expand_derivatives(D(x * y)), simplify_constants(y*D(x) + x*D(y)))
5860
@test_broken isequal(expand_derivatives(D(x * y)), simplify_constants(D(x)*y + x*D(y)))
5961

0 commit comments

Comments
 (0)