@@ -24,7 +24,7 @@ i.e., f(u,p,args...) for the out-of-place and scalar functions and
2424```julia
2525build_function(ex, args...;
2626 conv = simplified_expr, expression = Val{true},
27- checkbounds = false, constructor=nothing,
27+ checkbounds = false,
2828 linenumbers = false, target = JuliaTarget())
2929```
3030
@@ -46,8 +46,6 @@ Keyword Arguments:
4646
4747- `checkbounds`: For whether to enable bounds checking inside of the generated
4848 function. Defaults to false, meaning that `@inbounds` is applied.
49- - `constructor`: Allows for an arbitrary constructor function to be passed in
50- for handling expressions of "weird" types. Defaults to nothing.
5149- `linenumbers`: Determines whether the generated function expression retains
5250 the line numbers. Defaults to true.
5351- `target`: The output target of the compilation process. Possible options are:
104102# Scalar output
105103function _build_function (target:: JuliaTarget , op:: Operation , args... ;
106104 conv = simplified_expr, expression = Val{true },
107- checkbounds = false , constructor = nothing ,
105+ checkbounds = false ,
108106 linenumbers = true , headerfun= addheader)
109107
110108 argnames = [gensym (:MTKArg ) for i in 1 : length (args)]
165163
166164function _build_function (target:: JuliaTarget , rhss, args... ;
167165 conv = simplified_expr, expression = Val{true },
168- checkbounds = false , constructor = nothing ,
166+ checkbounds = false ,
169167 linenumbers = false , multithread= nothing ,
170168 headerfun= addheader, outputidxs= nothing ,
171169 skipzeros = false , parallel= SerialForm ())
@@ -323,41 +321,39 @@ function _build_function(target::JuliaTarget, rhss, args...;
323321
324322 if rhss isa Matrix
325323 arr_sys_expr = build_expr (:vcat , [build_expr (:row ,[conv (rhs) for rhs ∈ rhss[i,:]]) for i in 1 : size (rhss,1 )])
326- # : x because ??? what to do in the general case?
327- _constructor = constructor === nothing ? :($ (first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SMatrix{$ (size (rhss)... )} : x-> (out = similar (typeof ($ (fargs. args[1 ])),$ (size (rhss)... )); out .= x)) : constructor
328324 elseif typeof (rhss) <: Array && ! (typeof (rhss) <: Vector )
329325 vector_form = build_expr (:vect , [conv (rhs) for rhs ∈ rhss])
330326 arr_sys_expr = :(reshape ($ vector_form,$ (size (rhss)... )))
331- _constructor = constructor === nothing ? :($ (first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SArray{$ (size (rhss)... )} : x-> (out = similar (typeof ($ (fargs. args[1 ])),$ (size (rhss)... )); out .= x)) : constructor
332327 elseif rhss isa SparseMatrixCSC
333328 vector_form = build_expr (:vect , [conv (rhs) for rhs ∈ nonzeros (rhss)])
334329 arr_sys_expr = :(SparseMatrixCSC {eltype($(first(argnames))),Int} ($ (size (rhss)... ), $ (rhss. colptr), $ (rhss. rowval), $ vector_form))
335- # Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
336- _constructor = constructor === nothing ? :($ (first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SMatrix{$ (size (rhss)... )} : x-> x) : constructor
337330 else # Vector
338331 arr_sys_expr = build_expr (:vect , [conv (rhs) for rhs ∈ rhss])
339- # Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
340- _constructor = constructor === nothing ? :($ (first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. similar_type (typeof ($ (fargs. args[1 ])), eltype (X)) : x-> convert (typeof ($ (fargs. args[1 ])),x)) : constructor
341332 end
342333
334+ xname = gensym (:MTK )
335+
336+ arr_sys_expr = (typeof (rhss) <: Vector || typeof (rhss) <: Matrix ) && ! (eltype (rhss) <: AbstractArray ) ? quote
337+ if typeof ($ (fargs. args[1 ])) <: Union{ModelingToolkit.StaticArrays.SArray,ModelingToolkit.LabelledArrays.SLArray}
338+ $ xname = ModelingToolkit. StaticArrays. @SArray $ arr_sys_expr
339+ convert (typeof ($ (fargs. args[1 ])),$ xname)
340+ else
341+ $ xname = $ arr_sys_expr
342+ if ! (typeof ($ (fargs. args[1 ])) <: Array )
343+ convert (typeof ($ (fargs. args[1 ])),$ xname)
344+ else
345+ $ xname
346+ end
347+ end
348+ end : arr_sys_expr
349+
343350 let_expr = Expr (:let , var_eqs, tuple_sys_expr)
344351 arr_let_expr = Expr (:let , var_eqs, arr_sys_expr)
345352 bounds_block = checkbounds ? let_expr : :(@inbounds begin $ let_expr end )
346- arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $ arr_let_expr end )
353+ oop_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $ arr_let_expr end )
347354 ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ ip_let_expr end )
348355
349- oop_body_block = :(
350- # If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
351- if $ (fargs. args[1 ]) isa Array || (! (typeof ($ (fargs. args[1 ])) <: StaticArray ) && $ (rhss isa SparseMatrixCSC))
352- return $ arr_bounds_block
353- else
354- X = $ bounds_block
355- construct = $ _constructor
356- return construct (X)
357- end
358- )
359-
360- oop_ex = headerfun (oop_body_block, fargs, false )
356+ oop_ex = headerfun (oop_bounds_block, fargs, false )
361357 iip_ex = headerfun (ip_bounds_block, fargs, true ; X= X)
362358
363359 if ! linenumbers
0 commit comments