|
| 1 | +abstract type BuildTargets end |
| 2 | +struct JuliaTarget <: BuildTargets end |
| 3 | +struct StanTarget <: BuildTargets end |
| 4 | +struct CTarget <: BuildTargets end |
| 5 | +struct MATLABTarget <: BuildTargets end |
| 6 | + |
| 7 | +function build_function(args...;target = JuliaTarget(),kwargs...) |
| 8 | + _build_function(target,args...;kwargs...) |
| 9 | +end |
| 10 | + |
| 11 | +function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (), |
| 12 | + conv = simplified_expr, expression = Val{true}; |
| 13 | + checkbounds = false, constructor=nothing, |
| 14 | + linenumbers = true) |
| 15 | + _vs = map(x-> x isa Operation ? x.op : x, vs) |
| 16 | + _ps = map(x-> x isa Operation ? x.op : x, ps) |
| 17 | + var_pairs = [(u.name, :(u[$i])) for (i, u) ∈ enumerate(_vs)] |
| 18 | + param_pairs = [(p.name, :(p[$i])) for (i, p) ∈ enumerate(_ps)] |
| 19 | + (ls, rs) = zip(var_pairs..., param_pairs...) |
| 20 | + |
| 21 | + var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs)) |
| 22 | + |
| 23 | + fname = gensym(:ModelingToolkitFunction) |
| 24 | + |
| 25 | + X = gensym(:MTIIPVar) |
| 26 | + if rhss isa SparseMatrixCSC |
| 27 | + ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) ∈ enumerate(rhss.nzval)] |
| 28 | + else |
| 29 | + ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) ∈ enumerate(rhss)] |
| 30 | + end |
| 31 | + |
| 32 | + ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs)) |
| 33 | + |
| 34 | + tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs ∈ rhss]) |
| 35 | + |
| 36 | + if rhss isa Matrix |
| 37 | + arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs ∈ rhss[i,:]]) for i in 1:size(rhss,1)]) |
| 38 | + # : x because ??? what to do in the general case? |
| 39 | + _constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)) : constructor |
| 40 | + elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector) |
| 41 | + vector_form = build_expr(:vect, [conv(rhs) for rhs ∈ rhss]) |
| 42 | + arr_sys_expr = :(reshape($vector_form,$(size(rhss)...))) |
| 43 | + _constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)) : constructor |
| 44 | + elseif rhss isa SparseMatrixCSC |
| 45 | + vector_form = build_expr(:vect, [conv(rhs) for rhs ∈ nonzeros(rhss)]) |
| 46 | + arr_sys_expr = :(SparseMatrixCSC{eltype(u),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form)) |
| 47 | + # Static and sparse? Probably not a combo that will actually be hit, but give a default anyways |
| 48 | + _constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor |
| 49 | + else # Vector |
| 50 | + arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs ∈ rhss]) |
| 51 | + # Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays |
| 52 | + _constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->convert(typeof(u),x)) : constructor |
| 53 | + end |
| 54 | + |
| 55 | + let_expr = Expr(:let, var_eqs, tuple_sys_expr) |
| 56 | + arr_let_expr = Expr(:let, var_eqs, arr_sys_expr) |
| 57 | + bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end) |
| 58 | + arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end) |
| 59 | + ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end) |
| 60 | + |
| 61 | + fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...)) |
| 62 | + |
| 63 | + oop_ex = :( |
| 64 | + ($(fargs.args...),) -> begin |
| 65 | + # If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways |
| 66 | + if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC)) |
| 67 | + return $arr_bounds_block |
| 68 | + else |
| 69 | + X = $bounds_block |
| 70 | + construct = $_constructor |
| 71 | + return construct(X) |
| 72 | + end |
| 73 | + end |
| 74 | + ) |
| 75 | + |
| 76 | + iip_ex = :( |
| 77 | + ($X,$(fargs.args...)) -> begin |
| 78 | + $ip_bounds_block |
| 79 | + nothing |
| 80 | + end |
| 81 | + ) |
| 82 | + |
| 83 | + if !linenumbers |
| 84 | + oop_ex = striplines(oop_ex) |
| 85 | + iip_ex = striplines(iip_ex) |
| 86 | + end |
| 87 | + |
| 88 | + if expression == Val{true} |
| 89 | + return oop_ex, iip_ex |
| 90 | + else |
| 91 | + return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex), GeneralizedGenerated.mk_function(@__MODULE__,iip_ex) |
| 92 | + end |
| 93 | +end |
| 94 | + |
| 95 | +function numbered_expr(O::Equation,args...;kwargs...) |
| 96 | + :($(numbered_expr(O.lhs,args...;kwargs...)) = $(numbered_expr(O.rhs,args...;kwargs...))) |
| 97 | +end |
| 98 | + |
| 99 | +function numbered_expr(O::Operation,vars,parameters; |
| 100 | + derivname=:du, |
| 101 | + varname=:u,paramname=:p) |
| 102 | + if isa(O.op, ModelingToolkit.Differential) |
| 103 | + varop = O.args[1] |
| 104 | + i = findfirst(x->isequal(x,varop),vars) |
| 105 | + return :($derivname[$i]) |
| 106 | + elseif isa(O.op, ModelingToolkit.Variable) |
| 107 | + i = findfirst(x->isequal(x,O),vars) |
| 108 | + if i == nothing |
| 109 | + i = findfirst(x->isequal(x,O),parameters) |
| 110 | + return :($paramname[$i]) |
| 111 | + else |
| 112 | + return :($varname[$i]) |
| 113 | + end |
| 114 | + end |
| 115 | + return Expr(:call, Symbol(O.op), |
| 116 | + [numbered_expr(x,vars,parameters;derivname=derivname, |
| 117 | + varname=varname,paramname=paramname) for x in O.args]...) |
| 118 | +end |
| 119 | + |
| 120 | +numbered_expr(c::ModelingToolkit.Constant,args...;kwargs...) = c.value |
| 121 | + |
| 122 | +function _build_function(target::StanTarget, eqs, vs, ps, iv, |
| 123 | + conv = simplified_expr, expression = Val{true}; |
| 124 | + fname = :diffeqf, derivname=:internal_var___du, |
| 125 | + varname=:internal_var___u,paramname=:internal_var___p) |
| 126 | + differential_equation = string(join([numbered_expr(eq,vs,ps,derivname=derivname, |
| 127 | + varname=varname,paramname=paramname) for |
| 128 | + (i, eq) ∈ enumerate(eqs)],";\n "),";") |
| 129 | + """ |
| 130 | + real[] $fname(real $iv,real[] $varname,real[] $paramname,real[] x_r,int[] x_i) { |
| 131 | + real $derivname[$(length(eqs))]; |
| 132 | + $differential_equation |
| 133 | + return $derivname; |
| 134 | + } |
| 135 | + """ |
| 136 | +end |
| 137 | + |
| 138 | +function _build_function(target::CTarget, eqs, vs, ps, iv, |
| 139 | + conv = simplified_expr, expression = Val{true}; |
| 140 | + fname = :diffeqf, derivname=:internal_var___du, |
| 141 | + varname=:internal_var___u,paramname=:internal_var___p) |
| 142 | + differential_equation = string(join([numbered_expr(eq,vs,ps,derivname=derivname, |
| 143 | + varname=varname,paramname=paramname) for |
| 144 | + (i, eq) ∈ enumerate(eqs)],";\n "),";") |
| 145 | + """ |
| 146 | + void $fname(double* $derivname, double* $varname, double* $paramname, $iv) { |
| 147 | + $differential_equation |
| 148 | + } |
| 149 | + """ |
| 150 | +end |
| 151 | + |
| 152 | +function _build_function(target::MATLABTarget, eqs, vs, ps, iv, |
| 153 | + conv = simplified_expr, expression = Val{true}; |
| 154 | + fname = :diffeqf, derivname=:internal_var___du, |
| 155 | + varname=:internal_var___u,paramname=:internal_var___p) |
| 156 | + matstr = join([numbered_expr(eq.rhs,vs,ps,derivname=derivname, |
| 157 | + varname=varname,paramname=paramname) for |
| 158 | + (i, eq) ∈ enumerate(eqs)],"; ") |
| 159 | + |
| 160 | + matstr = replace(matstr,"["=>"(") |
| 161 | + matstr = replace(matstr,"]"=>")") |
| 162 | + matstr = "$fname = @(t,$varname) ["*matstr*"];" |
| 163 | + matstr |
| 164 | +end |
0 commit comments