Skip to content

Commit c77b702

Browse files
Add alternative build targets
Build outputs for C, Stan, and "Octave"
1 parent b62bb85 commit c77b702

File tree

5 files changed

+193
-82
lines changed

5 files changed

+193
-82
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,6 @@ include("systems/diffeqs/first_order_transform.jl")
9696
include("systems/nonlinear/nonlinear_system.jl")
9797
include("systems/pde/pdesystem.jl")
9898
include("latexify_recipes.jl")
99+
include("build_function.jl")
99100

100101
end # module

src/build_function.jl

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
abstract type BuildTargets end
2+
struct JuliaTarget <: BuildTargets end
3+
struct StanTarget <: BuildTargets end
4+
struct CTarget <: BuildTargets end
5+
struct MATLABTarget <: BuildTargets end
6+
7+
function build_function(args...;target = JuliaTarget(),kwargs...)
8+
_build_function(target,args...;kwargs...)
9+
end
10+
11+
function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
12+
conv = simplified_expr, expression = Val{true};
13+
checkbounds = false, constructor=nothing,
14+
linenumbers = true)
15+
_vs = map(x-> x isa Operation ? x.op : x, vs)
16+
_ps = map(x-> x isa Operation ? x.op : x, ps)
17+
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
18+
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
19+
(ls, rs) = zip(var_pairs..., param_pairs...)
20+
21+
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
22+
23+
fname = gensym(:ModelingToolkitFunction)
24+
25+
X = gensym(:MTIIPVar)
26+
if rhss isa SparseMatrixCSC
27+
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss.nzval)]
28+
else
29+
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
30+
end
31+
32+
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
33+
34+
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
35+
36+
if rhss isa Matrix
37+
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,1)])
38+
# : x because ??? what to do in the general case?
39+
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)) : constructor
40+
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
41+
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
42+
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
43+
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)) : constructor
44+
elseif rhss isa SparseMatrixCSC
45+
vector_form = build_expr(:vect, [conv(rhs) for rhs nonzeros(rhss)])
46+
arr_sys_expr = :(SparseMatrixCSC{eltype(u),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
47+
# Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
48+
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor
49+
else # Vector
50+
arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
51+
# Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
52+
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->convert(typeof(u),x)) : constructor
53+
end
54+
55+
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
56+
arr_let_expr = Expr(:let, var_eqs, arr_sys_expr)
57+
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
58+
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
59+
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
60+
61+
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
62+
63+
oop_ex = :(
64+
($(fargs.args...),) -> begin
65+
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
66+
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
67+
return $arr_bounds_block
68+
else
69+
X = $bounds_block
70+
construct = $_constructor
71+
return construct(X)
72+
end
73+
end
74+
)
75+
76+
iip_ex = :(
77+
($X,$(fargs.args...)) -> begin
78+
$ip_bounds_block
79+
nothing
80+
end
81+
)
82+
83+
if !linenumbers
84+
oop_ex = striplines(oop_ex)
85+
iip_ex = striplines(iip_ex)
86+
end
87+
88+
if expression == Val{true}
89+
return oop_ex, iip_ex
90+
else
91+
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex), GeneralizedGenerated.mk_function(@__MODULE__,iip_ex)
92+
end
93+
end
94+
95+
function numbered_expr(O::Equation,args...;kwargs...)
96+
:($(numbered_expr(O.lhs,args...;kwargs...)) = $(numbered_expr(O.rhs,args...;kwargs...)))
97+
end
98+
99+
function numbered_expr(O::Operation,vars,parameters;
100+
derivname=:du,
101+
varname=:u,paramname=:p)
102+
if isa(O.op, ModelingToolkit.Differential)
103+
varop = O.args[1]
104+
i = findfirst(x->isequal(x,varop),vars)
105+
return :($derivname[$i])
106+
elseif isa(O.op, ModelingToolkit.Variable)
107+
i = findfirst(x->isequal(x,O),vars)
108+
if i == nothing
109+
i = findfirst(x->isequal(x,O),parameters)
110+
return :($paramname[$i])
111+
else
112+
return :($varname[$i])
113+
end
114+
end
115+
return Expr(:call, Symbol(O.op),
116+
[numbered_expr(x,vars,parameters;derivname=derivname,
117+
varname=varname,paramname=paramname) for x in O.args]...)
118+
end
119+
120+
numbered_expr(c::ModelingToolkit.Constant,args...;kwargs...) = c.value
121+
122+
function _build_function(target::StanTarget, eqs, vs, ps, iv,
123+
conv = simplified_expr, expression = Val{true};
124+
fname = :diffeqf, derivname=:internal_var___du,
125+
varname=:internal_var___u,paramname=:internal_var___p)
126+
differential_equation = string(join([numbered_expr(eq,vs,ps,derivname=derivname,
127+
varname=varname,paramname=paramname) for
128+
(i, eq) enumerate(eqs)],";\n "),";")
129+
"""
130+
real[] $fname(real $iv,real[] $varname,real[] $paramname,real[] x_r,int[] x_i) {
131+
real $derivname[$(length(eqs))];
132+
$differential_equation
133+
return $derivname;
134+
}
135+
"""
136+
end
137+
138+
function _build_function(target::CTarget, eqs, vs, ps, iv,
139+
conv = simplified_expr, expression = Val{true};
140+
fname = :diffeqf, derivname=:internal_var___du,
141+
varname=:internal_var___u,paramname=:internal_var___p)
142+
differential_equation = string(join([numbered_expr(eq,vs,ps,derivname=derivname,
143+
varname=varname,paramname=paramname) for
144+
(i, eq) enumerate(eqs)],";\n "),";")
145+
"""
146+
void $fname(double* $derivname, double* $varname, double* $paramname, $iv) {
147+
$differential_equation
148+
}
149+
"""
150+
end
151+
152+
function _build_function(target::MATLABTarget, eqs, vs, ps, iv,
153+
conv = simplified_expr, expression = Val{true};
154+
fname = :diffeqf, derivname=:internal_var___du,
155+
varname=:internal_var___u,paramname=:internal_var___p)
156+
matstr = join([numbered_expr(eq.rhs,vs,ps,derivname=derivname,
157+
varname=varname,paramname=paramname) for
158+
(i, eq) enumerate(eqs)],"; ")
159+
160+
matstr = replace(matstr,"["=>"(")
161+
matstr = replace(matstr,"]"=>")")
162+
matstr = "$fname = @(t,$varname) ["*matstr*"];"
163+
matstr
164+
end

src/utils.jl

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -53,88 +53,6 @@ function retime_dvs(op::Operation,dvs,iv)
5353
end
5454
retime_dvs(op::Constant,dvs,iv) = op
5555

56-
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true};
57-
checkbounds = false, constructor=nothing, linenumbers = true)
58-
_vs = map(x-> x isa Operation ? x.op : x, vs)
59-
_ps = map(x-> x isa Operation ? x.op : x, ps)
60-
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
61-
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
62-
(ls, rs) = zip(var_pairs..., param_pairs...)
63-
64-
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
65-
66-
fname = gensym(:ModelingToolkitFunction)
67-
68-
X = gensym(:MTIIPVar)
69-
if rhss isa SparseMatrixCSC
70-
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss.nzval)]
71-
else
72-
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
73-
end
74-
75-
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
76-
77-
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
78-
79-
if rhss isa Matrix
80-
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,1)])
81-
# : x because ??? what to do in the general case?
82-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)) : constructor
83-
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
84-
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
85-
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
86-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)) : constructor
87-
elseif rhss isa SparseMatrixCSC
88-
vector_form = build_expr(:vect, [conv(rhs) for rhs nonzeros(rhss)])
89-
arr_sys_expr = :(SparseMatrixCSC{eltype(u),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
90-
# Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
91-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor
92-
else # Vector
93-
arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
94-
# Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
95-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->convert(typeof(u),x)) : constructor
96-
end
97-
98-
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
99-
arr_let_expr = Expr(:let, var_eqs, arr_sys_expr)
100-
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
101-
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
102-
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
103-
104-
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
105-
106-
oop_ex = :(
107-
($(fargs.args...),) -> begin
108-
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
109-
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
110-
return $arr_bounds_block
111-
else
112-
X = $bounds_block
113-
construct = $_constructor
114-
return construct(X)
115-
end
116-
end
117-
)
118-
119-
iip_ex = :(
120-
($X,$(fargs.args...)) -> begin
121-
$ip_bounds_block
122-
nothing
123-
end
124-
)
125-
126-
if !linenumbers
127-
oop_ex = striplines(oop_ex)
128-
iip_ex = striplines(iip_ex)
129-
end
130-
131-
if expression == Val{true}
132-
return oop_ex, iip_ex
133-
else
134-
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex), GeneralizedGenerated.mk_function(@__MODULE__,iip_ex)
135-
end
136-
end
137-
13856
is_constant(::Constant) = true
13957
is_constant(::Any) = false
14058

test/build_targets.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using ModelingToolkit, Test
2+
@parameters t a
3+
@variables x(t) y(t)
4+
@derivatives D'~t
5+
eqs = [D(x) ~ a*x - x*y,
6+
D(y) ~ -3y + x*y]
7+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget()) ==
8+
"""
9+
real[] diffeqf(real t,real[] internal_var___u,real[] internal_var___p,real[] x_r,int[] x_i) {
10+
real internal_var___du[2];
11+
internal_var___du[1] = internal_var___p[1] * internal_var___u[1] - internal_var___u[1] * internal_var___u[2];
12+
internal_var___du[2] = -3 * internal_var___u[2] + internal_var___u[1] * internal_var___u[2];
13+
return internal_var___du;
14+
}
15+
"""
16+
17+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget()) ==
18+
"""
19+
void diffeqf(double* internal_var___du, double* internal_var___u, double* internal_var___p, t) {
20+
internal_var___du[1] = internal_var___p[1] * internal_var___u[1] - internal_var___u[1] * internal_var___u[2];
21+
internal_var___du[2] = -3 * internal_var___u[2] + internal_var___u[1] * internal_var___u[2];
22+
}
23+
"""
24+
25+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.MATLABTarget()) ==
26+
"""
27+
diffeqf = @(t,internal_var___u) [internal_var___p(1) * internal_var___u(1) - internal_var___u(1) * internal_var___u(2); -3 * internal_var___u(2) + internal_var___u(1) * internal_var___u(2)];"""

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using ModelingToolkit, Test
55
@testset "Simplify Test" begin include("simplify.jl") end
66
@testset "Direct Usage Test" begin include("direct.jl") end
77
@testset "System Construction Test" begin include("system_construction.jl") end
8+
@testset "Build Targets Test" begin include("build_targets.jl") end
89
@testset "Domain Test" begin include("domains.jl") end
910
@testset "Constraints Test" begin include("constraints.jl") end
1011
@testset "PDE Construction Test" begin include("pde.jl") end

0 commit comments

Comments
 (0)