@@ -16,10 +16,10 @@ i.e. f(u,p,args...) for the out-of-place and scalar functions and
1616`f!(du,u,p,args..)` for the in-place version.
1717
1818```julia
19- build_function(ex vs, ps = (), args = (),
20- conv = simplified_expr, expression = Val{true};
19+ build_function(ex, args...;
20+ conv = simplified_expr, expression = Val{true},
2121 checkbounds = false, constructor=nothing,
22- linenumbers = true , target = JuliaTarget())
22+ linenumbers = false , target = JuliaTarget())
2323```
2424
2525Arguments:
@@ -57,23 +57,23 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
5757end
5858
5959# Scalar output
60- function _build_function (target:: JuliaTarget , op:: Operation , vs, ps = (), args = (),
61- conv = simplified_expr, expression = Val{true };
60+ function _build_function (target:: JuliaTarget , op:: Operation , args... ;
61+ conv = simplified_expr, expression = Val{true },
6262 checkbounds = false , constructor= nothing ,
6363 linenumbers = true )
64- _vs = convert .(Variable,vs)
65- _ps = convert .(Variable,ps)
66- var_pairs = [(u . name, :(u[ $ i])) for (i, u) ∈ enumerate (_vs)]
67- param_pairs = [(p . name, :(p[ $ i])) for (i, p) ∈ enumerate (_ps)]
68- (ls, rs) = zip (var_pairs ... , param_pairs ... )
69- var_eqs = Expr (:(= ), build_expr (:tuple , ls), build_expr (:tuple , rs))
64+
65+ argnames = [ gensym ( :MTKArg ) for i in 1 : length (args)]
66+ arg_pairs = map (vars_to_pairs, zip (argnames,args))
67+ ls = reduce (vcat, first .(arg_pairs))
68+ rs = reduce (vcat, last .(arg_pairs) )
69+ var_eqs = Expr (:(= ), ModelingToolkit . build_expr (:tuple , ls), ModelingToolkit . build_expr (:tuple , rs))
7070
7171 fname = gensym (:ModelingToolkitFunction )
7272 out_expr = conv (op)
7373 let_expr = Expr (:let , var_eqs, out_expr)
7474 bounds_block = checkbounds ? let_expr : :(@inbounds begin $ let_expr end )
7575
76- fargs = ps == () ? :(u, $ (args ... )) : :(u,p, $ (args ... ) )
76+ fargs = Expr ( :tuple ,argnames ... )
7777
7878 oop_ex = :(
7979 ($ (fargs. args... ),) -> begin
@@ -92,19 +92,19 @@ function _build_function(target::JuliaTarget, op::Operation, vs, ps = (), args =
9292 end
9393end
9494
95- function _build_function (target:: JuliaTarget , rhss, vs, ps = (), args = (),
96- conv = simplified_expr, expression = Val{true };
95+ function _build_function (target:: JuliaTarget , rhss, args... ;
96+ conv = simplified_expr, expression = Val{true },
9797 checkbounds = false , constructor= nothing ,
98- linenumbers = true , multithread= false )
99- _vs = convert .(Variable,vs)
100- _ps = convert .(Variable,ps)
101- var_pairs = [(u. name, :(u[$ i])) for (i, u) ∈ enumerate (_vs)]
102- param_pairs = [(p. name, :(p[$ i])) for (i, p) ∈ enumerate (_ps)]
103- (ls, rs) = zip (var_pairs... , param_pairs... )
98+ linenumbers = false , multithread= false )
10499
105- var_eqs = Expr (:(= ), build_expr (:tuple , ls), build_expr (:tuple , rs))
100+ argnames = [gensym (:MTKArg ) for i in 1 : length (args)]
101+ arg_pairs = map (vars_to_pairs,zip (argnames,args))
102+ ls = reduce (vcat,first .(arg_pairs))
103+ rs = reduce (vcat,last .(arg_pairs))
104+ var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , rs))
106105
107106 fname = gensym (:ModelingToolkitFunction )
107+ fargs = Expr (:tuple ,argnames... )
108108
109109 X = gensym (:MTIIPVar )
110110 if rhss isa SparseMatrixCSC
@@ -135,20 +135,20 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
135135 if rhss isa Matrix
136136 arr_sys_expr = build_expr (:vcat , [build_expr (:row ,[conv (rhs) for rhs ∈ rhss[i,:]]) for i in 1 : size (rhss,1 )])
137137 # : x because ??? what to do in the general case?
138- _constructor = constructor === nothing ? :(u isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SMatrix{$ (size (rhss)... )} : x-> (out = similar (typeof (u ),$ (size (rhss)... )); out .= x)) : constructor
138+ _constructor = constructor === nothing ? :($ ( first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SMatrix{$ (size (rhss)... )} : x-> (out = similar (typeof ($ (fargs . args[ 1 ]) ),$ (size (rhss)... )); out .= x)) : constructor
139139 elseif typeof (rhss) <: Array && ! (typeof (rhss) <: Vector )
140140 vector_form = build_expr (:vect , [conv (rhs) for rhs ∈ rhss])
141141 arr_sys_expr = :(reshape ($ vector_form,$ (size (rhss)... )))
142- _constructor = constructor === nothing ? :(u isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SArray{$ (size (rhss)... )} : x-> (out = similar (typeof (u ),$ (size (rhss)... )); out .= x)) : constructor
142+ _constructor = constructor === nothing ? :($ ( first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SArray{$ (size (rhss)... )} : x-> (out = similar (typeof ($ (fargs . args[ 1 ]) ),$ (size (rhss)... )); out .= x)) : constructor
143143 elseif rhss isa SparseMatrixCSC
144144 vector_form = build_expr (:vect , [conv (rhs) for rhs ∈ nonzeros (rhss)])
145- arr_sys_expr = :(SparseMatrixCSC {eltype(u ),Int} ($ (size (rhss)... ), $ (rhss. colptr), $ (rhss. rowval), $ vector_form))
145+ arr_sys_expr = :(SparseMatrixCSC {eltype($(first(argnames)) ),Int} ($ (size (rhss)... ), $ (rhss. colptr), $ (rhss. rowval), $ vector_form))
146146 # Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
147- _constructor = constructor === nothing ? :(u isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SMatrix{$ (size (rhss)... )} : x-> x) : constructor
147+ _constructor = constructor === nothing ? :($ ( first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SMatrix{$ (size (rhss)... )} : x-> x) : constructor
148148 else # Vector
149149 arr_sys_expr = build_expr (:vect , [conv (rhs) for rhs ∈ rhss])
150150 # Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
151- _constructor = constructor === nothing ? :(u isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. similar_type (typeof (u) , eltype (X)) : x-> convert (typeof (u ),x)) : constructor
151+ _constructor = constructor === nothing ? :($ ( first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. similar_type (typeof ($ (fargs . args[ 1 ])) , eltype (X)) : x-> convert (typeof ($ (fargs . args[ 1 ]) ),x)) : constructor
152152 end
153153
154154 let_expr = Expr (:let , var_eqs, tuple_sys_expr)
@@ -157,8 +157,6 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
157157 arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $ arr_let_expr end )
158158 ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ ip_let_expr end )
159159
160- fargs = ps == () ? :(u,$ (args... )) : :(u,p,$ (args... ))
161-
162160 oop_ex = :(
163161 ($ (fargs. args... ),) -> begin
164162 # If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
@@ -191,6 +189,21 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
191189 end
192190end
193191
192+ vars_to_pairs (args) = vars_to_pairs (args[1 ],args[2 ])
193+ function vars_to_pairs (name,vs:: AbstractArray )
194+ _vs = convert .(Variable,vs)
195+ names = [Symbol (u) for u ∈ _vs]
196+ exs = [:($ name[$ i]) for (i, u) ∈ enumerate (_vs)]
197+ names,exs
198+ end
199+
200+ function vars_to_pairs (name,vs)
201+ _vs = convert (Variable,vs)
202+ names = [Symbol (_vs)]
203+ exs = [name]
204+ names,exs
205+ end
206+
194207get_varnumber (varop:: Operation ,vars:: Vector{Operation} ) = findfirst (x-> isequal (x,varop),vars)
195208get_varnumber (varop:: Operation ,vars:: Vector{<:Variable} ) = findfirst (x-> isequal (x,varop. op),vars)
196209
@@ -251,8 +264,8 @@ function _build_function(target::StanTarget, eqs, vs, ps, iv,
251264 """
252265end
253266
254- function _build_function (target:: CTarget , eqs, vs, ps, iv,
255- conv = simplified_expr, expression = Val{true };
267+ function _build_function (target:: CTarget , eqs, vs, ps, iv;
268+ conv = simplified_expr, expression = Val{true },
256269 fname = :diffeqf , derivname= :internal_var___du ,
257270 varname= :internal_var___u ,paramname= :internal_var___p )
258271 differential_equation = string (join ([numbered_expr (eq,vs,ps,derivname= derivname,
@@ -265,8 +278,8 @@ function _build_function(target::CTarget, eqs, vs, ps, iv,
265278 """
266279end
267280
268- function _build_function (target:: MATLABTarget , eqs, vs, ps, iv,
269- conv = simplified_expr, expression = Val{true };
281+ function _build_function (target:: MATLABTarget , eqs, vs, ps, iv;
282+ conv = simplified_expr, expression = Val{true },
270283 fname = :diffeqf , derivname= :internal_var___du ,
271284 varname= :internal_var___u ,paramname= :internal_var___p )
272285 matstr = join ([numbered_expr (eq. rhs,vs,ps,derivname= derivname,
0 commit comments