Skip to content

Commit 037b533

Browse files
committed
Merge branch 'master' into speed-up-reaction-jump-sys
2 parents e268db2 + b328c1b commit 037b533

File tree

9 files changed

+133
-153
lines changed

9 files changed

+133
-153
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "3.11.1"
4+
version = "3.12.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -12,6 +12,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1212
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1313
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1414
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
15+
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1516
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
1617
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -36,6 +37,7 @@ DiffEqJump = "6.7.5"
3637
DiffRules = "0.1, 1.0"
3738
DocStringExtensions = "0.7, 0.8"
3839
GeneralizedGenerated = "0.1.4, 0.2"
40+
LabelledArrays = "1.2"
3941
Latexify = "0.11, 0.12, 0.13"
4042
LightGraphs = "1.3"
4143
MacroTools = "0.5"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ModelingToolkit
22

33
using DiffEqBase, Distributed
4-
using StaticArrays, LinearAlgebra, SparseArrays
4+
using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
55
using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack

src/build_function.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ i.e., f(u,p,args...) for the out-of-place and scalar functions and
2424
```julia
2525
build_function(ex, args...;
2626
conv = simplified_expr, expression = Val{true},
27-
checkbounds = false, constructor=nothing,
27+
checkbounds = false,
2828
linenumbers = false, target = JuliaTarget())
2929
```
3030
@@ -46,8 +46,6 @@ Keyword Arguments:
4646
4747
- `checkbounds`: For whether to enable bounds checking inside of the generated
4848
function. Defaults to false, meaning that `@inbounds` is applied.
49-
- `constructor`: Allows for an arbitrary constructor function to be passed in
50-
for handling expressions of "weird" types. Defaults to nothing.
5149
- `linenumbers`: Determines whether the generated function expression retains
5250
the line numbers. Defaults to true.
5351
- `target`: The output target of the compilation process. Possible options are:
@@ -104,7 +102,7 @@ end
104102
# Scalar output
105103
function _build_function(target::JuliaTarget, op::Operation, args...;
106104
conv = simplified_expr, expression = Val{true},
107-
checkbounds = false, constructor=nothing,
105+
checkbounds = false,
108106
linenumbers = true, headerfun=addheader)
109107

110108
argnames = [gensym(:MTKArg) for i in 1:length(args)]
@@ -165,7 +163,7 @@ end
165163

166164
function _build_function(target::JuliaTarget, rhss, args...;
167165
conv = simplified_expr, expression = Val{true},
168-
checkbounds = false, constructor=nothing,
166+
checkbounds = false,
169167
linenumbers = false, multithread=nothing,
170168
headerfun=addheader, outputidxs=nothing,
171169
skipzeros = false, parallel=SerialForm())
@@ -323,41 +321,43 @@ function _build_function(target::JuliaTarget, rhss, args...;
323321

324322
if rhss isa Matrix
325323
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,1)])
326-
# : x because ??? what to do in the general case?
327-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->(out = similar(typeof($(fargs.args[1])),$(size(rhss)...)); out .= x)) : constructor
328324
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
329325
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
330326
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
331-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->(out = similar(typeof($(fargs.args[1])),$(size(rhss)...)); out .= x)) : constructor
332327
elseif rhss isa SparseMatrixCSC
333328
vector_form = build_expr(:vect, [conv(rhs) for rhs nonzeros(rhss)])
334329
arr_sys_expr = :(SparseMatrixCSC{eltype($(first(argnames))),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
335-
# Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
336-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor
337330
else # Vector
338331
arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
339-
# Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
340-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof($(fargs.args[1])), eltype(X)) : x->convert(typeof($(fargs.args[1])),x)) : constructor
341332
end
342333

334+
xname = gensym(:MTK)
335+
336+
arr_sys_expr = (typeof(rhss) <: Vector || typeof(rhss) <: Matrix) && !(eltype(rhss) <: AbstractArray) ? quote
337+
if typeof($(fargs.args[1])) <: Union{ModelingToolkit.StaticArrays.SArray,ModelingToolkit.LabelledArrays.SLArray}
338+
$xname = ModelingToolkit.StaticArrays.@SArray $arr_sys_expr
339+
if $(typeof(rhss) <: Vector) # Only try converting if it should match `u`
340+
convert(typeof($(fargs.args[1])),$xname)
341+
else
342+
$xname
343+
end
344+
else
345+
$xname = $arr_sys_expr
346+
if !(typeof($(fargs.args[1])) <: Array) && $(typeof(rhss) <: Vector)
347+
convert(typeof($(fargs.args[1])),$xname)
348+
else
349+
$xname
350+
end
351+
end
352+
end : arr_sys_expr
353+
343354
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
344355
arr_let_expr = Expr(:let, var_eqs, arr_sys_expr)
345356
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
346-
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
357+
oop_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
347358
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
348359

349-
oop_body_block = :(
350-
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
351-
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
352-
return $arr_bounds_block
353-
else
354-
X = $bounds_block
355-
construct = $_constructor
356-
return construct(X)
357-
end
358-
)
359-
360-
oop_ex = headerfun(oop_body_block, fargs, false)
360+
oop_ex = headerfun(oop_bounds_block, fargs, false)
361361
iip_ex = headerfun(ip_bounds_block, fargs, true; X=X)
362362

363363
if !linenumbers

0 commit comments

Comments
 (0)