Skip to content

Commit 7f051bb

Browse files
Merge pull request #268 from SciML/auto_detect
Make variables typed and use types for auto-detection
2 parents 416f5e8 + f089bbd commit 7f051bb

File tree

10 files changed

+68
-61
lines changed

10 files changed

+68
-61
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ Get the set of parameters variables for the given system.
7878
function parameters end
7979

8080
include("variables.jl")
81+
include("context_dsl.jl")
8182
include("operations.jl")
8283
include("differentials.jl")
8384
include("equations.jl")

src/build_function.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
9393
end
9494

9595
get_varnumber(varop::Operation,vars::Vector{Operation}) = findfirst(x->isequal(x,varop),vars)
96-
get_varnumber(varop::Operation,vars::Vector{Variable}) = findfirst(x->isequal(x,varop.op),vars)
96+
get_varnumber(varop::Operation,vars::Vector{<:Variable}) = findfirst(x->isequal(x,varop.op),vars)
9797

9898
function numbered_expr(O::Equation,args...;kwargs...)
9999
:($(numbered_expr(O.lhs,args...;kwargs...)) = $(numbered_expr(O.rhs,args...;kwargs...)))
@@ -120,7 +120,7 @@ function numbered_expr(O::Operation,vars,parameters;
120120
varname=varname,paramname=paramname) for x in O.args]...)
121121
end
122122

123-
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{Variable},parameters;
123+
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{<:Variable},parameters;
124124
derivname=:du,varname=:u,paramname=:p)
125125
i = findfirst(x->isequal(x.name,var_from_nested_derivative(de.lhs)[1].name),vars)
126126
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;

src/context_dsl.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
struct Parameter{T} end
2+
isparameter(::Variable) = false
3+
isparameter(::Variable{<:Parameter}) = true
4+
5+
"""
6+
$(SIGNATURES)
7+
8+
Define one or more known variables.
9+
"""
10+
macro parameters(xs...)
11+
esc(_parse_vars(:parameters, Parameter{Number}, xs))
12+
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function calculate_factorized_W(sys::AbstractODESystem, simplify=true)
5555
isempty(sys.Wfact[]) || return (sys.Wfact[],sys.Wfact_t[])
5656

5757
jac = calculate_jacobian(sys)
58-
gam = Variable(:gam; known = true)()
58+
gam = Variable(:__MTKWgamma)()
5959

6060
W = - LinearAlgebra.I + gam*jac
6161
Wfact = lu(W, Val(false), check=false).factors
@@ -76,15 +76,15 @@ function calculate_factorized_W(sys::AbstractODESystem, simplify=true)
7676
end
7777

7878
function generate_factorized_W(sys::AbstractODESystem, vs = states(sys), ps = parameters(sys), simplify=true, expression = Val{true}; kwargs...)
79-
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
79+
Wfact,Wfact_t = calculate_factorized_W(sys,simplify)
8080
siz = size(Wfact)
8181
constructor = :(x -> begin
8282
A = SMatrix{$siz...}(x)
8383
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
8484
end)
8585

86-
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
87-
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
86+
Wfact_func = build_function(Wfact , vs, ps, (:__MTKWgamma,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
87+
Wfact_t_func = build_function(Wfact_t, vs, ps, (:__MTKWgamma,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
8888

8989
return (Wfact_func, Wfact_t_func)
9090
end

src/systems/diffeqs/first_order_transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function lower_varname(var::Variable, idv, order)
22
order == 0 && return var
33
name = Symbol(var.name, :_, string(idv.name)^order)
4-
return Variable(name; known = var.known)
4+
return Variable{vartype(var)}(name)
55
end
66

77
function flatten_differential(O::Operation)

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
99
@parameters t
1010
vars = [Variable(:x, i)(t) for i in eachindex(prob.u0)]
1111
params = prob.p isa DiffEqBase.NullParameters ? [] :
12-
[Variable(,i; known = true)() for i in eachindex(prob.p)]
12+
[Variable(,i)() for i in eachindex(prob.p)]
1313
@derivatives D'~t
1414

1515
rhs = [D(var) for var in vars]
@@ -22,7 +22,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
2222
end
2323

2424
eqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
25-
de = ODESystem(eqs)
25+
de = ODESystem(eqs,t,vars,params)
2626

2727
de, vars, params
2828
end

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ function ODESystem(eqs; kwargs...)
8585

8686
dvs = unique(var_from_nested_derivative(eq.lhs)[1] for eq eqs)
8787
ps = filter(vars(eq.rhs for eq eqs)) do x
88-
x.known & !isequal(x, iv)
88+
isparameter(x) & !isequal(x, iv)
8989
end |> collect
9090
ODESystem(eqs, iv, dvs, ps; kwargs...)
9191
end

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ end
3535

3636
function detime_dvs(op::Operation)
3737
if op.op isa Variable
38-
Operation(Variable(op.op.name,known=op.op.known),Expression[])
38+
Operation(Variable{vartype(op.op)}(op.op.name),Expression[])
3939
else
4040
Operation(op.op,detime_dvs.(op.args))
4141
end
@@ -44,7 +44,7 @@ detime_dvs(op::Constant) = op
4444

4545
function retime_dvs(op::Operation,dvs,iv)
4646
if op.op isa Variable && op.op dvs
47-
Operation(Variable(op.op.name),Expression[iv])
47+
Operation(Variable{vartype(op.op)}(op.op.name),Expression[iv])
4848
else
4949
Operation(op.op,retime_dvs.(op.args,(dvs,),iv))
5050
end

src/variables.jl

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,35 @@ end
1717
"""
1818
$(TYPEDEF)
1919
20-
A named variable which represents a numerical value. The variable's value may
21-
be known (parameters, independent variables) or unknown (dependent variables).
20+
A named variable which represents a numerical value.
2221
2322
# Fields
2423
$(FIELDS)
2524
"""
26-
struct Variable <: Function
25+
struct Variable{T} <: Function
2726
"""The variable's unique name."""
2827
name::Symbol
29-
"""
30-
Whether the variable's value is known.
31-
"""
32-
known::Bool
33-
Variable(name; known = false) = new(name, known)
28+
Variable(name) = new{Number}(name)
29+
Variable{T}(name) where T = new{T}(name)
30+
function Variable{T}(name, indices...) where T
31+
var_name = Symbol("$(name)$(join(map_subscripts.(indices), "ˏ"))")
32+
Variable{T}(var_name)
33+
end
3434
end
35-
function Variable(name, indices...; known = false)
35+
36+
function Variable(name, indices...)
3637
var_name = Symbol("$(name)$(join(map_subscripts.(indices), "ˏ"))")
37-
Variable(var_name; known=known)
38+
Variable(var_name)
3839
end
3940

41+
vartype(::Variable{T}) where T = T
4042
(x::Variable)(args...) = Operation(x, collect(Expression, args))
4143

42-
Base.isequal(x::Variable, y::Variable) = (x.name, x.known) == (y.name, y.known)
44+
Base.isequal(x::Variable, y::Variable) = x.name == y.name
4345
Base.print(io::IO, x::Variable) = show(io, x)
4446
Base.show(io::IO, x::Variable) = print(io, x.name)
4547
function Base.show(io::IO, ::MIME"text/plain", x::Variable)
46-
known = x.known ? "known" : "unknown"
47-
print(io, x.name, " (callable ", known, " variable)")
48+
print(io, x.name)
4849
end
4950

5051

@@ -79,7 +80,7 @@ Base.convert(::Type{Expr}, c::Constant) = c.value
7980

8081

8182
# Build variables more easily
82-
function _parse_vars(macroname, known, x)
83+
function _parse_vars(macroname, type, x)
8384
ex = Expr(:block)
8485
var_names = Symbol[]
8586
# if parsing things in the form of
@@ -97,9 +98,9 @@ function _parse_vars(macroname, known, x)
9798
@assert iscall || isarray || issym "@$macroname expects a tuple of expressions or an expression of a tuple (`@$macroname x y z(t) v[1:3] w[1:2,1:4]` or `@$macroname x, y, z(t) v[1:3] w[1:2,1:4]`)"
9899

99100
if iscall
100-
var_name, expr = _construct_vars(_var.args[1], known, _var.args[2:end])
101+
var_name, expr = _construct_vars(_var.args[1], type, _var.args[2:end])
101102
else
102-
var_name, expr = _construct_vars(_var, known, nothing)
103+
var_name, expr = _construct_vars(_var, type, nothing)
103104
end
104105
push!(var_names, var_name)
105106
push!(ex.args, expr)
@@ -108,45 +109,45 @@ function _parse_vars(macroname, known, x)
108109
return ex
109110
end
110111

111-
function _construct_vars(_var, known, call_args)
112+
function _construct_vars(_var, type, call_args)
112113
issym = _var isa Symbol
113114
isarray = isa(_var, Expr) && _var.head == :ref
114115
if isarray
115116
var_name = _var.args[1]
116117
indices = _var.args[2:end]
117-
expr = _construct_array_vars(var_name, known, call_args, indices...)
118+
expr = _construct_array_vars(var_name, type, call_args, indices...)
118119
else
119120
# Implicit 0-args call
120121
var_name = _var
121-
expr = _construct_var(var_name, known, call_args)
122+
expr = _construct_var(var_name, type, call_args)
122123
end
123124
var_name, :($var_name = $expr)
124125
end
125126

126-
function _construct_var(var_name, known, call_args)
127+
function _construct_var(var_name, type, call_args)
127128
if call_args === nothing
128-
:(Variable($(Meta.quot(var_name)); known = $known)())
129+
:(Variable{$type}($(Meta.quot(var_name)))())
129130
elseif !isempty(call_args) && call_args[end] == :..
130-
:(Variable($(Meta.quot(var_name)); known = $known))
131+
:(Variable{$type}($(Meta.quot(var_name))))
131132
else
132-
:(Variable($(Meta.quot(var_name)); known = $known)($(call_args...)))
133+
:(Variable{$type}($(Meta.quot(var_name)))($(call_args...)))
133134
end
134135
end
135136

136-
function _construct_var(var_name, known, call_args, ind)
137+
function _construct_var(var_name, type, call_args, ind)
137138
if call_args === nothing
138-
:(Variable($(Meta.quot(var_name)), $ind...; known = $known)())
139+
:(Variable{$type}($(Meta.quot(var_name)), $ind...)())
139140
elseif !isempty(call_args) && call_args[end] == :..
140-
:(Variable($(Meta.quot(var_name)), $ind...; known = $known))
141+
:(Variable{$type}($(Meta.quot(var_name)), $ind...))
141142
else
142-
:(Variable($(Meta.quot(var_name)), $ind...; known = $known)($(call_args...)))
143+
:(Variable{$type}($(Meta.quot(var_name)), $ind...)($(call_args...)))
143144
end
144145
end
145146

146147

147-
function _construct_array_vars(var_name, known, call_args, indices...)
148+
function _construct_array_vars(var_name, type, call_args, indices...)
148149
:(map(Iterators.product($(indices...))) do ind
149-
$(_construct_var(var_name, known, call_args, :ind))
150+
$(_construct_var(var_name, type, call_args, :ind))
150151
end)
151152
end
152153

@@ -157,20 +158,11 @@ $(SIGNATURES)
157158
Define one or more unknown variables.
158159
"""
159160
macro variables(xs...)
160-
esc(_parse_vars(:variables, false, xs))
161-
end
162-
163-
"""
164-
$(SIGNATURES)
165-
166-
Define one or more known variables.
167-
"""
168-
macro parameters(xs...)
169-
esc(_parse_vars(:parameters, true, xs))
161+
esc(_parse_vars(:variables, Number, xs))
170162
end
171163

172164
function rename(x::Variable,name::Symbol)
173-
Variable(name,known=x.known)
165+
Variable{vartype(x)}(name)
174166
end
175167

176168
TreeViews.hastreeview(x::Variable) = true

test/variable_parsing.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using ModelingToolkit
22
using Test
33

4+
const PType = ModelingToolkit.Parameter{Number}
5+
46
@parameters t
57
@variables x(t) y(t) # test multi-arg
68
@variables z(t) # test single-arg
@@ -17,9 +19,9 @@ z1 = Variable(:z)(t)
1719
end
1820
@parameters σ(..)
1921

20-
t1 = Variable(:t; known = true)()
21-
s1 = Variable(:s; known = true)()
22-
σ1 = Variable(; known = true)
22+
t1 = Variable{PType}(:t)()
23+
s1 = Variable{PType}(:s)()
24+
σ1 = Variable()
2325
@test isequal(t1, t)
2426
@test isequal(s1, s)
2527
@test isequal(σ1, σ)
@@ -42,12 +44,12 @@ convert(Expression, :($x == 0 ? $y : $x))
4244
end
4345
@parameters σ[1:2](..)
4446

45-
t1 = [Variable(:t, 1; known = true)(),
46-
Variable(:t, 2; known = true)()]
47-
s1 = [Variable(:s, 1, 1; known = true)() Variable(:s, 1, 2; known = true)()
48-
Variable(:s, 3, 1; known = true)() Variable(:s, 3, 2; known = true)()]
49-
σ1 = [Variable(, 1; known = true),
50-
Variable(, 2; known = true)]
47+
t1 = [Variable{PType}(:t, 1)(),
48+
Variable{PType}(:t, 2)()]
49+
s1 = [Variable{PType}(:s, 1, 1)() Variable{PType}(:s, 1, 2)()
50+
Variable{PType}(:s, 3, 1)() Variable{PType}(:s, 3, 2)()]
51+
σ1 = [Variable(, 1),
52+
Variable(, 2)]
5153
@test isequal(t1, t)
5254
@test isequal(s1, s)
5355
@test isequal(σ1, σ)

0 commit comments

Comments
 (0)