@@ -36,14 +36,22 @@ function DiffEqSystem(eqs, ivs;
3636 DiffEqSystem (eqs, ivs, dvs, vs, ps, ivs[1 ]. subtype, dv_name, p_name, Matrix {Expression} (undef,0 ,0 ))
3737end
3838
39- function generate_ode_function (sys:: DiffEqSystem )
39+ function generate_ode_function (sys:: DiffEqSystem ;version = ArrayFunction )
4040 var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in 1 : length (sys. dvs)]
4141 param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in 1 : length (sys. ps)]
4242 sys_exprs = build_equals_expr .(sys. eqs)
43- dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
44- exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
45- block = expr_arr_to_block (exprs)
46- :((du,u,p,t)-> $ (toexpr (block)))
43+ if version == ArrayFunction
44+ dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
45+ exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
46+ block = expr_arr_to_block (exprs)
47+ :((du,u,p,t)-> $ (toexpr (block)))
48+ elseif version == SArrayFunction
49+ dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
50+ svector_expr = :(typeof (u)($ (dvar_exprs... )))
51+ exprs = vcat (var_exprs,param_exprs,sys_exprs,svector_expr)
52+ block = expr_arr_to_block (exprs)
53+ :((u,p,t)-> $ (toexpr (block)))
54+ end
4755end
4856
4957isintermediate (eq) = eq. args[1 ]. diff == nothing
@@ -123,9 +131,13 @@ function generate_ode_iW(sys::DiffEqSystem,simplify=true)
123131 :((iW,u,p,gam,t)-> $ (block)),:((iW,u,p,gam,t)-> $ (block2))
124132end
125133
126- function DiffEqBase. ODEFunction (sys:: DiffEqSystem )
127- expr = generate_ode_function (sys)
128- ODEFunction {true} (eval (expr))
134+ function DiffEqBase. ODEFunction (sys:: DiffEqSystem ;version = ArrayFunction,kwargs... )
135+ expr = generate_ode_function (sys;kwargs... )
136+ if version == ArrayFunction
137+ ODEFunction {true} (eval (expr))
138+ elseif version == SArrayFunction
139+ ODEFunction {false} (eval (expr))
140+ end
129141end
130142
131143export DiffEqSystem, ODEFunction
0 commit comments