@@ -87,6 +87,16 @@ function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
8787 wrappedex
8888end
8989
90+ function unflatten_long_ops (op, N= 4 )
91+ rule1 = @rule ((+ )((~~ x)) => length (~~ x) > N ?
92+ + (+ ((~~ x)[1 : N]. .. ) + (+ )((~~ x)[N+ 1 : end ]. .. )) : nothing )
93+ rule2 = @rule ((* )((~~ x)) => length (~~ x) > N ?
94+ * (* ((~~ x)[1 : N]. .. ) * (* )((~~ x)[N+ 1 : end ]. .. )) : nothing )
95+
96+ op = to_symbolic (op)
97+ Rewriters. Fixpoint (Rewriters. Postwalk (Rewriters. Chain ([rule1, rule2])))(op) |> to_mtk
98+ end
99+
90100# Scalar output
91101function _build_function (target:: JuliaTarget , op:: Operation , args... ;
92102 conv = simplified_expr, expression = Val{true },
@@ -97,9 +107,10 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
97107 arg_pairs = map (vars_to_pairs,zip (argnames,args))
98108 ls = reduce (vcat,first .(arg_pairs))
99109 rs = reduce (vcat,last .(arg_pairs))
100- var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , rs ))
110+ var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , unflatten_long_ops .(rs) ))
101111
102112 fname = gensym (:ModelingToolkitFunction )
113+ op = unflatten_long_ops (op)
103114 out_expr = conv (op)
104115 let_expr = Expr (:let , var_eqs, Expr (:block , out_expr))
105116 bounds_block = checkbounds ? let_expr : :(@inbounds begin $ let_expr end )
@@ -242,7 +253,13 @@ function _build_function(target::JuliaTarget, rhss, args...;
242253 oidx = isnothing (outputidxs) ? (i -> i) : (i -> outputidxs[i])
243254 X = gensym (:MTIIPVar )
244255
245- rhs_length = rhss isa SparseMatrixCSC ? length (rhss. nzval) : length (rhss)
256+ if rhss isa SparseMatrixCSC
257+ rhs_length = length (rhss. nzval)
258+ rhss = SparseMatrixCSC (rhss. m, rhss. m, rhss. colptr, rhss. rowval, map (unflatten_long_ops, rhss. nzval))
259+ else
260+ rhs_length = length (rhss)
261+ rhss = [unflatten_long_ops (r) for r in rhss]
262+ end
246263
247264 if parallel isa DistributedForm
248265 numworks = Distributed. nworkers ()
@@ -251,6 +268,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
251268 finalsize = rhs_length - (numworks- 1 )* lens
252269 _rhss = vcat (reduce (vcat,[[getindex (reducevars[i],j) for j in 1 : lens] for i in 1 : numworks- 1 ],init= Expr[]),
253270 [getindex (reducevars[end ],j) for j in 1 : finalsize])
271+
254272 elseif parallel isa DaggerForm
255273 computevars = [Variable (gensym (:MTComputeVar ))() for i in axes (rhss,1 )]
256274 reducevar = Variable (gensym (:MTReduceVar ))()
0 commit comments