Skip to content

Commit 229ab20

Browse files
Merge pull request #298 from SciML/build_function
Generalize build_function
2 parents a699b71 + fc03508 commit 229ab20

File tree

10 files changed

+161
-79
lines changed

10 files changed

+161
-79
lines changed

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,18 +84,18 @@ connected = ODESystem(connections,t,[α],[γ],systems=[lorenz1,lorenz2])
8484
u0 = [lorenz1.x => 1.0,
8585
lorenz1.y => 0.0,
8686
lorenz1.z => 0.0,
87-
lorenz2.x => 0.0,
88-
lorenz2.y => 1.0,
89-
lorenz2.z => 0.0,
90-
α => 2.0]
87+
lorenz2.x => 0.0,
88+
lorenz2.y => 1.0,
89+
lorenz2.z => 0.0,
90+
α => 2.0]
9191

9292
p = [lorenz1.σ => 10.0,
9393
lorenz1.ρ => 28.0,
9494
lorenz1.β => 8/3,
95-
lorenz2.σ => 10.0,
96-
lorenz2.ρ => 28.0,
97-
lorenz2.β => 8/3,
98-
γ => 2.0]
95+
lorenz2.σ => 10.0,
96+
lorenz2.ρ => 28.0,
97+
lorenz2.β => 8/3,
98+
γ => 2.0]
9999

100100
tspan = (0.0,100.0)
101101
prob = ODEProblem(connected,u0,tspan,p)

src/build_function.jl

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ i.e. f(u,p,args...) for the out-of-place and scalar functions and
1616
`f!(du,u,p,args..)` for the in-place version.
1717
1818
```julia
19-
build_function(ex vs, ps = (), args = (),
20-
conv = simplified_expr, expression = Val{true};
19+
build_function(ex, args...;
20+
conv = simplified_expr, expression = Val{true},
2121
checkbounds = false, constructor=nothing,
22-
linenumbers = true, target = JuliaTarget())
22+
linenumbers = false, target = JuliaTarget())
2323
```
2424
2525
Arguments:
@@ -57,23 +57,23 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
5757
end
5858

5959
# Scalar output
60-
function _build_function(target::JuliaTarget, op::Operation, vs, ps = (), args = (),
61-
conv = simplified_expr, expression = Val{true};
60+
function _build_function(target::JuliaTarget, op::Operation, args...;
61+
conv = simplified_expr, expression = Val{true},
6262
checkbounds = false, constructor=nothing,
6363
linenumbers = true)
64-
_vs = convert.(Variable,vs)
65-
_ps = convert.(Variable,ps)
66-
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
67-
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
68-
(ls, rs) = zip(var_pairs..., param_pairs...)
69-
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
64+
65+
argnames = [gensym(:MTKArg) for i in 1:length(args)]
66+
arg_pairs = map(vars_to_pairs,zip(argnames,args))
67+
ls = reduce(vcat,first.(arg_pairs))
68+
rs = reduce(vcat,last.(arg_pairs))
69+
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, rs))
7070

7171
fname = gensym(:ModelingToolkitFunction)
7272
out_expr = conv(op)
7373
let_expr = Expr(:let, var_eqs, out_expr)
7474
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
7575

76-
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
76+
fargs = Expr(:tuple,argnames...)
7777

7878
oop_ex = :(
7979
($(fargs.args...),) -> begin
@@ -92,19 +92,19 @@ function _build_function(target::JuliaTarget, op::Operation, vs, ps = (), args =
9292
end
9393
end
9494

95-
function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
96-
conv = simplified_expr, expression = Val{true};
95+
function _build_function(target::JuliaTarget, rhss, args...;
96+
conv = simplified_expr, expression = Val{true},
9797
checkbounds = false, constructor=nothing,
98-
linenumbers = true, multithread=false)
99-
_vs = convert.(Variable,vs)
100-
_ps = convert.(Variable,ps)
101-
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
102-
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
103-
(ls, rs) = zip(var_pairs..., param_pairs...)
98+
linenumbers = false, multithread=false)
10499

105-
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
100+
argnames = [gensym(:MTKArg) for i in 1:length(args)]
101+
arg_pairs = map(vars_to_pairs,zip(argnames,args))
102+
ls = reduce(vcat,first.(arg_pairs))
103+
rs = reduce(vcat,last.(arg_pairs))
104+
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, rs))
106105

107106
fname = gensym(:ModelingToolkitFunction)
107+
fargs = Expr(:tuple,argnames...)
108108

109109
X = gensym(:MTIIPVar)
110110
if rhss isa SparseMatrixCSC
@@ -135,20 +135,20 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
135135
if rhss isa Matrix
136136
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,1)])
137137
# : x because ??? what to do in the general case?
138-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)) : constructor
138+
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->(out = similar(typeof($(fargs.args[1])),$(size(rhss)...)); out .= x)) : constructor
139139
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
140140
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
141141
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
142-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)) : constructor
142+
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->(out = similar(typeof($(fargs.args[1])),$(size(rhss)...)); out .= x)) : constructor
143143
elseif rhss isa SparseMatrixCSC
144144
vector_form = build_expr(:vect, [conv(rhs) for rhs nonzeros(rhss)])
145-
arr_sys_expr = :(SparseMatrixCSC{eltype(u),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
145+
arr_sys_expr = :(SparseMatrixCSC{eltype($(first(argnames))),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
146146
# Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
147-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor
147+
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor
148148
else # Vector
149149
arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
150150
# Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
151-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->convert(typeof(u),x)) : constructor
151+
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof($(fargs.args[1])), eltype(X)) : x->convert(typeof($(fargs.args[1])),x)) : constructor
152152
end
153153

154154
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
@@ -157,8 +157,6 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
157157
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
158158
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
159159

160-
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
161-
162160
oop_ex = :(
163161
($(fargs.args...),) -> begin
164162
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
@@ -191,6 +189,21 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
191189
end
192190
end
193191

192+
vars_to_pairs(args) = vars_to_pairs(args[1],args[2])
193+
function vars_to_pairs(name,vs::AbstractArray)
194+
_vs = convert.(Variable,vs)
195+
names = [Symbol(u) for u _vs]
196+
exs = [:($name[$i]) for (i, u) enumerate(_vs)]
197+
names,exs
198+
end
199+
200+
function vars_to_pairs(name,vs)
201+
_vs = convert(Variable,vs)
202+
names = [Symbol(_vs)]
203+
exs = [name]
204+
names,exs
205+
end
206+
194207
get_varnumber(varop::Operation,vars::Vector{Operation}) = findfirst(x->isequal(x,varop),vars)
195208
get_varnumber(varop::Operation,vars::Vector{<:Variable}) = findfirst(x->isequal(x,varop.op),vars)
196209

@@ -251,8 +264,8 @@ function _build_function(target::StanTarget, eqs, vs, ps, iv,
251264
"""
252265
end
253266

254-
function _build_function(target::CTarget, eqs, vs, ps, iv,
255-
conv = simplified_expr, expression = Val{true};
267+
function _build_function(target::CTarget, eqs, vs, ps, iv;
268+
conv = simplified_expr, expression = Val{true},
256269
fname = :diffeqf, derivname=:internal_var___du,
257270
varname=:internal_var___u,paramname=:internal_var___p)
258271
differential_equation = string(join([numbered_expr(eq,vs,ps,derivname=derivname,
@@ -265,8 +278,8 @@ function _build_function(target::CTarget, eqs, vs, ps, iv,
265278
"""
266279
end
267280

268-
function _build_function(target::MATLABTarget, eqs, vs, ps, iv,
269-
conv = simplified_expr, expression = Val{true};
281+
function _build_function(target::MATLABTarget, eqs, vs, ps, iv;
282+
conv = simplified_expr, expression = Val{true},
270283
fname = :diffeqf, derivname=:internal_var___du,
271284
varname=:internal_var___u,paramname=:internal_var___p)
272285
matstr = join([numbered_expr(eq.rhs,vs,ps,derivname=derivname,

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,27 @@ function (f::ODEToExpr)(O::Operation)
3636
end
3737
(f::ODEToExpr)(x) = convert(Expr, x)
3838

39-
function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
39+
function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
4040
tgrad = calculate_tgrad(sys)
41-
return build_function(tgrad, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
41+
return build_function(tgrad, dvs, ps, sys.iv;
42+
conv = ODEToExpr(sys), kwargs...)
4243
end
4344

44-
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; sparse = false, kwargs...)
45+
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); sparse = false, kwargs...)
4546
jac = calculate_jacobian(sys)
4647
if sparse
4748
jac = SparseArrays.sparse(jac)
4849
end
49-
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
50+
return build_function(jac, dvs, ps, sys.iv;
51+
conv = ODEToExpr(sys), kwargs...)
5052
end
5153

52-
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
54+
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
5355
rhss = [deq.rhs for deq equations(sys)]
5456
dvs′ = convert.(Variable,dvs)
5557
ps′ = convert.(Variable,ps)
56-
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
58+
return build_function(rhss, dvs′, ps′, sys.iv;
59+
conv = ODEToExpr(sys),kwargs...)
5760
end
5861

5962
function calculate_factorized_W(sys::AbstractODESystem, simplify=true)
@@ -88,8 +91,10 @@ function generate_factorized_W(sys::AbstractODESystem, vs = states(sys), ps = pa
8891
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
8992
end)
9093

91-
Wfact_func = build_function(Wfact , vs, ps, (:__MTKWgamma,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
92-
Wfact_t_func = build_function(Wfact_t, vs, ps, (:__MTKWgamma,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
94+
Wfact_func = build_function(Wfact , vs, ps, Variable(:__MTKWgamma), sys.iv;
95+
conv = ODEToExpr(sys), expression = expression, constructor=constructor,kwargs...)
96+
Wfact_t_func = build_function(Wfact_t, vs, ps, Variable(:__MTKWgamma), sys.iv;
97+
conv = ODEToExpr(sys), expression = expression, constructor=constructor,kwargs...)
9398

9499
return (Wfact_func, Wfact_t_func)
95100
end
@@ -136,29 +141,29 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
136141
jac = false, Wfact = false,
137142
sparse = false,
138143
kwargs...) where {iip}
139-
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false}; kwargs...)
144+
f_oop,f_iip = generate_function(sys, dvs, ps; expression=Val{false}, kwargs...)
140145

141146
f(u,p,t) = f_oop(u,p,t)
142147
f(du,u,p,t) = f_iip(du,u,p,t)
143148

144149
if tgrad
145-
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false}; kwargs...)
150+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps; expression=Val{false}, kwargs...)
146151
_tgrad(u,p,t) = tgrad_oop(u,p,t)
147152
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
148153
else
149154
_tgrad = nothing
150155
end
151156

152157
if jac
153-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; sparse = sparse, kwargs...)
158+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{false}, kwargs...)
154159
_jac(u,p,t) = jac_oop(u,p,t)
155160
_jac(J,u,p,t) = jac_iip(J,u,p,t)
156161
else
157162
_jac = nothing
158163
end
159164

160165
if Wfact
161-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false}; kwargs...)
166+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{false}, kwargs...)
162167
Wfact_oop, Wfact_iip = tmp_Wfact
163168
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
164169
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)

src/systems/diffeqs/sdesystem.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
8080
SDESystem(deqs, neqs, iv′, dvs′, ps′, tgrad, jac, Wfact, Wfact_t, name, systems)
8181
end
8282

83-
function generate_diffusion_function(sys::SDESystem, dvs = sys.states, ps = sys.ps, expression = Val{true}; kwargs...)
83+
function generate_diffusion_function(sys::SDESystem, dvs = sys.states, ps = sys.ps; kwargs...)
8484
dvs′ = convert.(Variable,dvs)
8585
ps′ = convert.(Variable,ps)
86-
return build_function(sys.noiseeqs, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
86+
return build_function(sys.noiseeqs, dvs′, ps′, sys.iv;
87+
conv = ODEToExpr(sys),kwargs...)
8788
end
8889

8990
"""
@@ -100,32 +101,32 @@ respectively.
100101
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;
101102
version = nothing, tgrad=false, sparse = false,
102103
jac = false, Wfact = false, kwargs...) where {iip}
103-
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false}; kwargs...)
104-
g_oop,g_iip = generate_diffusion_function(sys, dvs, ps, Val{false}; kwargs...)
104+
f_oop,f_iip = generate_function(sys, dvs, ps; expression=Val{false}, kwargs...)
105+
g_oop,g_iip = generate_diffusion_function(sys, dvs, ps; expression=Val{false}, kwargs...)
105106

106107
f(u,p,t) = f_oop(u,p,t)
107108
f(du,u,p,t) = f_iip(du,u,p,t)
108109
g(u,p,t) = g_oop(u,p,t)
109110
g(du,u,p,t) = g_iip(du,u,p,t)
110111

111112
if tgrad
112-
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false}; kwargs...)
113+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps; expression=Val{false}, kwargs...)
113114
_tgrad(u,p,t) = tgrad_oop(u,p,t)
114115
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
115116
else
116117
_tgrad = nothing
117118
end
118119

119120
if jac
120-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; sparse=sparse, kwargs...)
121+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps; expression=Val{false}, sparse=sparse, kwargs...)
121122
_jac(u,p,t) = jac_oop(u,p,t)
122123
_jac(J,u,p,t) = jac_iip(J,u,p,t)
123124
else
124125
_jac = nothing
125126
end
126127

127128
if Wfact
128-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false}; kwargs...)
129+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true; expression=Val{false}, kwargs...)
129130
Wfact_oop, Wfact_iip = tmp_Wfact
130131
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
131132
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,22 @@ function calculate_jacobian(sys::NonlinearSystem)
4747
return jac
4848
end
4949

50-
function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys), expression = Val{true};
50+
function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys);
5151
sparse = false, kwargs...)
5252
jac = calculate_jacobian(sys)
5353
if sparse
5454
jac = SparseArrays.sparse(jac)
5555
end
56-
return build_function(jac, convert.(Variable,vs), convert.(Variable,ps), (), AbstractSysToExpr(sys))
56+
return build_function(jac, convert.(Variable,vs), convert.(Variable,ps),
57+
conv = AbstractSysToExpr(sys))
5758
end
5859

59-
function generate_function(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
60+
function generate_function(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys); kwargs...)
6061
rhss = [eq.rhs for eq sys.eqs]
6162
vs′ = convert.(Variable,vs)
6263
ps′ = convert.(Variable,ps)
63-
return build_function(rhss, vs′, ps′, (), AbstractSysToExpr(sys), expression; kwargs...)
64+
return build_function(rhss, vs′, ps′;
65+
conv = AbstractSysToExpr(sys), kwargs...)
6466
end
6567

6668
"""
@@ -86,7 +88,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
8688
ps = parameters(sys)
8789

8890
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
89-
multithread=multithread,sparse=sparse)
91+
multithread=multithread,sparse=sparse,expression=Val{false})
9092
u0 = varmap_to_vars(u0map,dvs)
9193
p = varmap_to_vars(parammap,ps)
9294
NonlinearProblem(f,u0,tspan,p;kwargs...)

0 commit comments

Comments
 (0)