Skip to content

Commit 2b59e8a

Browse files
Merge pull request #277 from SciML/dot
add dot syntax, variable searching in scope, and variable->operation
2 parents 87c8a35 + 2b1c0eb commit 2b59e8a

File tree

7 files changed

+133
-62
lines changed

7 files changed

+133
-62
lines changed

src/operations.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ Operation(x) = convert(Operation, x)
6666
#convert to Expr
6767
Base.Expr(op::Operation) = simplified_expr(op)
6868
Base.convert(::Type{Expr},x::Operation) = Expr(x)
69+
function Base.convert(::Type{Variable},x::Operation)
70+
x.op isa Variable ? x.op : throw(error("This Operation is not a Variable"))
71+
end
6972

7073
# promotion
7174
Base.promote_rule(::Type{<:Constant}, ::Type{<:Operation}) = Operation

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 78 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -107,60 +107,6 @@ function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
107107
M == I ? I : M
108108
end
109109

110-
"""
111-
$(SIGNATURES)
112-
113-
Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `ps`
114-
are used to set the order of the dependent variable and parameter vectors,
115-
respectively.
116-
"""
117-
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
118-
ps = parameters(sys);
119-
version = nothing, tgrad=false,
120-
jac = false, Wfact = false) where {iip}
121-
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
122-
123-
f(u,p,t) = f_oop(u,p,t)
124-
f(du,u,p,t) = f_iip(du,u,p,t)
125-
126-
if tgrad
127-
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
128-
_tgrad(u,p,t) = tgrad_oop(u,p,t)
129-
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
130-
else
131-
_tgrad = nothing
132-
end
133-
134-
if jac
135-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
136-
_jac(u,p,t) = jac_oop(u,p,t)
137-
_jac(J,u,p,t) = jac_iip(J,u,p,t)
138-
else
139-
_jac = nothing
140-
end
141-
142-
if Wfact
143-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false})
144-
Wfact_oop, Wfact_iip = tmp_Wfact
145-
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
146-
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
147-
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
148-
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)
149-
_Wfact_t(W,u,p,dtgamma,t) = Wfact_iip_t(W,u,p,dtgamma,t)
150-
else
151-
_Wfact,_Wfact_t = nothing,nothing
152-
end
153-
154-
M = calculate_massmatrix(sys)
155-
156-
ODEFunction{iip}(f,jac=_jac,
157-
tgrad = _tgrad,
158-
Wfact = _Wfact,
159-
Wfact_t = _Wfact_t,
160-
mass_matrix = M,
161-
syms = Symbol.(sys.dvs))
162-
end
163-
164110
renamespace(namespace,name) = Symbol(string(namespace)*""*string(name))
165111

166112
function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -192,6 +138,8 @@ function namespace_operation(O::Operation,name,ivname)
192138
end
193139
namespace_operation(O::Constant,name,ivname) = O
194140

141+
142+
195143
independent_variable(sys::AbstractODESystem) = sys.iv
196144
states(sys::AbstractODESystem) = isempty(sys.systems) ? sys.dvs : [sys.dvs;reduce(vcat,namespace_variables.(sys.systems))]
197145
parameters(sys::AbstractODESystem) = isempty(sys.systems) ? sys.ps : [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]
@@ -202,12 +150,12 @@ end
202150

203151
function states(sys::AbstractODESystem,name::Symbol)
204152
x = sys.dvs[findfirst(x->x.name==name,sys.dvs)]
205-
Variable(Symbol(string(sys.name)*""*string(x.name)),known=x.known)(sys.iv())
153+
Variable(Symbol(string(sys.name)*""*string(x.name)))(sys.iv())
206154
end
207155

208156
function parameters(sys::AbstractODESystem,name::Symbol)
209157
x = sys.ps[findfirst(x->x.name==name,sys.ps)]
210-
Variable(Symbol(string(sys.name)*""*string(x.name)),known=x.known)(sys.iv())
158+
Variable(Symbol(string(sys.name)*""*string(x.name)))(sys.iv())
211159
end
212160

213161
function states(sys::AbstractODESystem,args...)
@@ -234,3 +182,77 @@ function _eq_unordered(a, b)
234182
end
235183
return true
236184
end
185+
186+
"""
187+
$(SIGNATURES)
188+
189+
Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `ps`
190+
are used to set the order of the dependent variable and parameter vectors,
191+
respectively.
192+
"""
193+
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
194+
ps = parameters(sys);
195+
version = nothing, tgrad=false,
196+
jac = false, Wfact = false) where {iip}
197+
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
198+
199+
f(u,p,t) = f_oop(u,p,t)
200+
f(du,u,p,t) = f_iip(du,u,p,t)
201+
202+
if tgrad
203+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
204+
_tgrad(u,p,t) = tgrad_oop(u,p,t)
205+
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
206+
else
207+
_tgrad = nothing
208+
end
209+
210+
if jac
211+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
212+
_jac(u,p,t) = jac_oop(u,p,t)
213+
_jac(J,u,p,t) = jac_iip(J,u,p,t)
214+
else
215+
_jac = nothing
216+
end
217+
218+
if Wfact
219+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false})
220+
Wfact_oop, Wfact_iip = tmp_Wfact
221+
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
222+
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
223+
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
224+
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)
225+
_Wfact_t(W,u,p,dtgamma,t) = Wfact_iip_t(W,u,p,dtgamma,t)
226+
else
227+
_Wfact,_Wfact_t = nothing,nothing
228+
end
229+
230+
M = calculate_massmatrix(sys)
231+
232+
ODEFunction{iip}(f,jac=_jac,
233+
tgrad = _tgrad,
234+
Wfact = _Wfact,
235+
Wfact_t = _Wfact_t,
236+
mass_matrix = M,
237+
syms = Symbol.(sys.dvs))
238+
end
239+
240+
function Base.getproperty(sys::AbstractODESystem, name::Symbol)
241+
if name fieldnames(typeof(sys))
242+
return getfield(sys,name)
243+
elseif !isempty(sys.systems)
244+
i = findfirst(x->x.name==name,sys.systems)
245+
if i !== nothing
246+
return rename(sys.systems[i],renamespace(sys.name,name))
247+
end
248+
end
249+
i = findfirst(x->x.name==name,sys.dvs)
250+
if i !== nothing
251+
return rename(sys.dvs[i],renamespace(sys.name,name))(getfield(sys,:iv)())
252+
end
253+
i = findfirst(x->x.name==name,sys.ps)
254+
if i !== nothing
255+
return rename(sys.ps[i],renamespace(sys.name,name))()
256+
end
257+
throw(error("Variable name does not exist"))
258+
end

src/systems/diffeqs/first_order_transform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function ode_order_lowering(eqs, iv)
3030
any(isequal(var), vars) || push!(vars, var)
3131
end
3232
var′ = lower_varname(var, iv, maxorder - 1)
33-
rhs′ = rename(eq.rhs)
33+
rhs′ = rename_lower_order(eq.rhs)
3434
push!(new_eqs,Differential(iv())(var′(iv())) ~ rhs′)
3535
end
3636

@@ -50,11 +50,11 @@ function ode_order_lowering(eqs, iv)
5050
return (new_eqs, new_vars)
5151
end
5252

53-
function rename(O::Expression)
53+
function rename_lower_order(O::Expression)
5454
isa(O, Operation) || return O
5555
if is_derivative(O)
5656
(x, t, order) = flatten_differential(O)
5757
return lower_varname(x.op, t.op, order)(x.args...)
5858
end
59-
return Operation(O.op, rename.(O.args))
59+
return Operation(O.op, rename_lower_order.(O.args))
6060
end

src/systems/diffeqs/odesystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,7 @@ Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
9494
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
9595
_eq_unordered(sys1.dvs, sys2.dvs) && _eq_unordered(sys1.ps, sys2.ps)
9696
# NOTE: equality does not check cached Jacobian
97+
98+
function rename(sys::ODESystem,name)
99+
ODESystem(sys.eqs, sys.iv, sys.dvs, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
100+
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,7 @@ end
117117
function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)
118118
SDEFunction{true}(sys, args...; kwargs...)
119119
end
120+
121+
function rename(sys::SDESystem,name)
122+
ODESystem(sys.eqs, sys.noiseeqs, sys.iv, sys.dvs, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
123+
end

src/variables.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ end
4040

4141
vartype(::Variable{T}) where T = T
4242
(x::Variable)(args...) = Operation(x, collect(Expression, args))
43+
rename(x::Variable{T},name) where T = Variable{T}(name)
4344

4445
Base.isequal(x::Variable, y::Variable) = x.name == y.name
4546
Base.print(io::IO, x::Variable) = show(io, x)

test/components.jl

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,26 @@ lorenz2 = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],name=:lorenz2)
1616

1717
@parameters α
1818
@variables a(t)
19-
connnectedeqs = [D(a) ~ a*states(lorenz1,:x)]
19+
connnectedeqs = [D(a) ~ a*lorenz1.x]
2020

2121
connected1 = ODESystem(connnectedeqs,t,[a],[α],systems=[lorenz1,lorenz2],name=:connected1)
2222

23+
eqs_flat = [D(a) ~ a*lorenz1.x,
24+
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x),
25+
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-lorenz1.y,
26+
0 ~ lorenz1.x + lorenz1.y + lorenz1.β*lorenz1.z,
27+
D(lorenz2.x) ~ lorenz2.σ*(lorenz2.y-lorenz2.x),
28+
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.y,
29+
0 ~ lorenz2.x + lorenz2.y + lorenz2.β*lorenz2.z]
30+
31+
@test states(connected1) == convert.(Variable,[a,lorenz1.x,lorenz1.y,lorenz1.z,lorenz2.x,lorenz2.y,lorenz2.z])
32+
@test parameters(connected1) == convert.(Variable,[α,lorenz1.σ,lorenz1.ρ,lorenz1.β,lorenz2.σ,lorenz2.ρ,lorenz2.β])
33+
@test eqs_flat == equations(connected1)
34+
2335
@variables lorenz1′x(t) lorenz1′y(t) lorenz1′z(t) lorenz2′x(t) lorenz2′y(t) lorenz2′z(t)
2436
@parameters lorenz1′σ lorenz1′ρ lorenz1′β lorenz2′σ lorenz2′ρ lorenz2′β
2537

26-
eqs_flat = [D(a) ~ a*lorenz1′x,
38+
eqs_flat2 = [D(a) ~ a*lorenz1′x,
2739
D(lorenz1′x) ~ lorenz1′σ*(lorenz1′y-lorenz1′x),
2840
D(lorenz1′y) ~ lorenz1′x*(lorenz1′ρ-lorenz1′z)-lorenz1′y,
2941
0 ~ lorenz1′x + lorenz1′y + lorenz1′β*lorenz1′z,
@@ -39,9 +51,34 @@ connected2 = ODESystem(connnectedeqs,t,[a],[α],systems=[lorenz1,lorenz2],name=:
3951

4052
@parameters γ
4153
@variables g(t)
42-
connnectedeqs2 = [D(g) ~ g*states(connected1,lorenz1,:x)]
54+
connnectedeqs2 = [D(g) ~ g*connected1.lorenz1.x]
4355
doublelevel = ODESystem(connnectedeqs2,t,[g],[γ],systems=[connected1,connected2],name=:doublelevel)
4456

57+
@test states(doublelevel) == convert.(Variable,[g,connected1.a,connected1.lorenz1.x,connected1.lorenz1.y,connected1.lorenz1.z,connected1.lorenz2.x,connected1.lorenz2.y,connected1.lorenz2.z,
58+
connected2.a,connected2.lorenz1.x,connected2.lorenz1.y,connected2.lorenz1.z,connected2.lorenz2.x,connected2.lorenz2.y,connected2.lorenz2.z])
59+
60+
@test parameters(doublelevel) == convert.(Variable,[γ,
61+
connected1.α,connected1.lorenz1.σ,connected1.lorenz1.ρ,connected1.lorenz1.β,connected1.lorenz2.σ,connected1.lorenz2.ρ,connected1.lorenz2.β,
62+
connected2.α,connected2.lorenz1.σ,connected2.lorenz1.ρ,connected2.lorenz1.β,connected2.lorenz2.σ,connected2.lorenz2.ρ,connected2.lorenz2.β])
63+
64+
eqs_flat = [D(g) ~ g*connected1.lorenz1.x,
65+
D(connected1.a) ~ connected1.a*connected1.lorenz1.x,
66+
D(connected1.lorenz1.x) ~ connected1.lorenz1.σ*(connected1.lorenz1.y-connected1.lorenz1.x),
67+
D(connected1.lorenz1.y) ~ connected1.lorenz1.x*(connected1.lorenz1.ρ-connected1.lorenz1.z)-connected1.lorenz1.y,
68+
0 ~ connected1.lorenz1.x + connected1.lorenz1.y + connected1.lorenz1.β*connected1.lorenz1.z,
69+
D(connected1.lorenz2.x) ~ connected1.lorenz2.σ*(connected1.lorenz2.y-connected1.lorenz2.x),
70+
D(connected1.lorenz2.y) ~ connected1.lorenz2.x*(connected1.lorenz2.ρ-connected1.lorenz2.z)-connected1.lorenz2.y,
71+
0 ~ connected1.lorenz2.x + connected1.lorenz2.y + connected1.lorenz2.β*connected1.lorenz2.z,
72+
D(connected2.a) ~ connected2.a*connected2.lorenz1.x,
73+
D(connected2.lorenz1.x) ~ connected2.lorenz1.σ*(connected2.lorenz1.y-connected2.lorenz1.x),
74+
D(connected2.lorenz1.y) ~ connected2.lorenz1.x*(connected2.lorenz1.ρ-connected2.lorenz1.z)-connected2.lorenz1.y,
75+
0 ~ connected2.lorenz1.x + connected2.lorenz1.y + connected2.lorenz1.β*connected2.lorenz1.z,
76+
D(connected2.lorenz2.x) ~ connected2.lorenz2.σ*(connected2.lorenz2.y-connected2.lorenz2.x),
77+
D(connected2.lorenz2.y) ~ connected2.lorenz2.x*(connected2.lorenz2.ρ-connected2.lorenz2.z)-connected2.lorenz2.y,
78+
0 ~ connected2.lorenz2.x + connected2.lorenz2.y + connected2.lorenz2.β*connected2.lorenz2.z]
79+
80+
@test eqs_flat == equations(doublelevel)
81+
4582
@test [x.name for x in states(doublelevel)] == [:g,
4683
:connected1′a,:connected1′lorenz1′x,:connected1′lorenz1′y,:connected1′lorenz1′z,:connected1′lorenz2′x,:connected1′lorenz2′y,:connected1′lorenz2′z,
4784
:connected2′a,:connected2′lorenz1′x,:connected2′lorenz1′y,:connected2′lorenz1′z,:connected2′lorenz2′x,:connected2′lorenz2′y,:connected2′lorenz2′z]

0 commit comments

Comments
 (0)