Skip to content

Commit 387830f

Browse files
Merge pull request #444 from isaacsas/speed-up-reaction-jump-sys
Speed up JumpSystem generation from ReactionSystems
2 parents b328c1b + d35620c commit 387830f

File tree

3 files changed

+113
-51
lines changed

3 files changed

+113
-51
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ end
4545

4646
function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
4747
name = gensym(:JumpSystem))
48-
4948
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
5049
for eq in eqs
5150
if eq isa MassActionJump
@@ -62,6 +61,9 @@ function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
6261
JumpSystem{typeof(ap)}(ap, convert(Variable,iv), convert.(Variable, states), convert.(Variable, ps), name, systems)
6362
end
6463

64+
JumpSystem(eqs::ArrayPartition, iv, states, ps; systems = JumpSystem[], name = gensym(:JumpSystem)) =
65+
JumpSystem{typeof(eqs)}(eqs, convert(Variable,iv), convert.(Variable, states), convert.(Variable, ps), name, systems)
66+
6567

6668
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
6769
independent_variable(js),
@@ -114,45 +116,64 @@ function assemble_crj_expr(js, crj, statetoid)
114116
end
115117
end
116118

117-
function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2,W2}}},
118-
statetoid, subber, invttype) where {U,V,W,V2,W2}
119-
sr = maj.scaled_rates
120-
if sr isa Operation
121-
pval = subber(sr).value
122-
elseif sr isa Variable
123-
pval = subber(sr()).value
119+
function numericrate(rate, subber)
120+
if rate isa Operation
121+
rval = subber(rate).value
122+
elseif rate isa Variable
123+
rval = subber(rate()).value
124124
else
125-
pval = maj.scaled_rates
125+
rval = rate
126126
end
127+
rval
128+
end
127129

130+
function numericrstoich(mtrs::Vector{Pair{V,W}}, statetoid) where {V,W}
128131
rs = Vector{Pair{Int,W}}()
129-
for (spec,stoich) in maj.reactant_stoch
130-
if iszero(spec)
132+
for (spec,stoich) in mtrs
133+
if !(spec isa Operation) && iszero(spec)
131134
push!(rs, 0 => stoich)
132135
else
133136
push!(rs, statetoid[convert(Variable,spec)] => stoich)
134137
end
135138
end
136139
sort!(rs)
140+
rs
141+
end
137142

138-
ns = Vector{Pair{Int,W2}}()
139-
for (spec,stoich) in maj.net_stoch
140-
iszero(spec) && error("Net stoichiometry can not have a species labelled 0.")
143+
function numericnstoich(mtrs::Vector{Pair{V,W}}, statetoid) where {V,W}
144+
ns = Vector{Pair{Int,W}}()
145+
for (spec,stoich) in mtrs
146+
!(spec isa Operation) && iszero(spec) && error("Net stoichiometry can not have a species labelled 0.")
141147
push!(ns, statetoid[convert(Variable,spec)] => stoich)
142148
end
143149
sort!(ns)
150+
end
144151

145-
maj = MassActionJump(convert(invttype, pval), rs, ns, scale_rates = false)
146-
return maj
152+
# assemble a numeric MassActionJump from a MT MassActionJump representing one rx.
153+
function assemble_maj(maj::MassActionJump, statetoid, subber, invttype)
154+
rval = numericrate(maj.scaled_rates, subber)
155+
rs = numericrstoich(maj.reactant_stoch, statetoid)
156+
ns = numericnstoich(maj.net_stoch, statetoid)
157+
maj = MassActionJump(convert(invttype, rval), rs, ns, scale_rates = false)
158+
maj
147159
end
148160

161+
# For MassActionJumps that contain many reactions
162+
# function assemble_maj(maj::MassActionJump{U,V,W}, statetoid, subber,
163+
# invttype) where {U <: AbstractVector,V,W}
164+
# rval = [convert(invttype,numericrate(sr, subber)) for sr in maj.scaled_rates]
165+
# rs = [numericrstoich(rs, statetoid) for rs in maj.reactant_stoch]
166+
# ns = [numericnstoich(ns, statetoid) for ns in maj.net_stoch]
167+
# maj = MassActionJump(rval, rs, ns, scale_rates = false)
168+
# maj
169+
# end
149170
"""
150171
```julia
151172
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
152173
parammap=DiffEqBase.NullParameters; kwargs...)
153174
```
154175
155-
Generates a blank DiscreteProblem for a pure jump JumpSystem to utilize as
176+
Generates a blank DiscreteProblem for a pure jump JumpSystem to utilize as
156177
its `prob.prob`. This is used in the case where there are no ODEs
157178
and no SDEs associated with the system.
158179
@@ -167,9 +188,17 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
167188
"""
168189
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Tuple,
169190
parammap=DiffEqBase.NullParameters(); kwargs...)
170-
u0 = varmap_to_vars(u0map, states(sys))
171-
p = varmap_to_vars(parammap, parameters(sys))
172-
# identity function to make syms works
191+
192+
(u0map isa AbstractVector) || error("For DiscreteProblems u0map must be an AbstractVector.")
193+
u0d = Dict( convert(Variable,u[1]) => u[2] for u in u0map)
194+
u0 = [u0d[u] for u in states(sys)]
195+
if parammap != DiffEqBase.NullParameters()
196+
(parammap isa AbstractVector) || error("For DiscreteProblems parammap must be an AbstractVector.")
197+
pd = Dict( convert(Variable,u[1]) => u[2] for u in parammap)
198+
p = [pd[u] for u in parameters(sys)]
199+
else
200+
p = parammap
201+
end
173202
# EvalFunc because we know that the jump functions are generated via eval
174203
f = DiffEqBase.EvalFunc(DiffEqBase.DISCRETE_INPLACE_DEFAULT)
175204
df = DiscreteFunction(f, syms=Symbol.(states(sys)))
@@ -235,7 +264,7 @@ function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
235264
parammap = map((x,y)->Pair(x(),y), parameters(js), p)
236265
subber = substituter(parammap)
237266

238-
majs = MassActionJump[assemble_maj(js, j, statetoid, subber, invttype) for j in eqs.x[1]]
267+
majs = MassActionJump[assemble_maj(j, statetoid, subber, invttype) for j in eqs.x[1]]
239268
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
240269
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
241270
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")

src/systems/reaction/reactionsystem.jl

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,12 @@ function Reaction(rate, subs, prods, substoich, prodstoich;
6969
(isnothing(prodstoich)&&isnothing(substoich)) && error("Both substrate and product stochiometry inputs cannot be nothing.")
7070
if isnothing(subs)
7171
subs = Vector{Operation}()
72-
(substoich!=nothing) && error("If substrates are nothing, substrate stiocihometries have to be so too.")
72+
!isnothing(substoich) && error("If substrates are nothing, substrate stiocihometries have to be so too.")
7373
substoich = typeof(prodstoich)()
7474
end
7575
if isnothing(prods)
7676
prods = Vector{Operation}()
77-
(prodstoich!=nothing) && error("If products are nothing, product stiocihometries have to be so too.")
77+
!isnothing(prodstoich) && error("If products are nothing, product stiocihometries have to be so too.")
7878
prodstoich = typeof(substoich)()
7979
end
8080
ns = isnothing(netstoich) ? get_netstoich(subs, prods, substoich, prodstoich) : netstoich
@@ -140,8 +140,9 @@ function ReactionSystem(eqs, iv, species, params; systems = ReactionSystem[],
140140

141141

142142
isempty(species) && error("ReactionSystems require at least one species.")
143-
paramvars = isempty(params) ? Variable[] : convert.(Variable, params)
144-
ReactionSystem(eqs, iv, convert.(Variable,species), paramvars, name, systems)
143+
paramvars = map(v -> convert(Variable,v), params)
144+
specvars = map(s -> convert(Variable,s), species)
145+
ReactionSystem(eqs, convert(Variable,iv), specvars, paramvars, name, systems)
145146
end
146147

147148
# Calculate the ODE rate law
@@ -220,7 +221,8 @@ end
220221
"""
221222
```julia
222223
ismassaction(rx, rs; rxvars = get_variables(rx.rate),
223-
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars))
224+
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars),
225+
stateset = Set(states(rs)))
224226
```
225227
226228
True if a given reaction is of mass action form, i.e. `rx.rate` does not depend
@@ -232,43 +234,71 @@ explicitly on the independent variable (usually time).
232234
- `rs`, a [`ReactionSystem`](@ref) containing the reaction.
233235
- Optional: `rxvars`, `Variable`s which are not in `rxvars` are ignored as possible dependencies.
234236
- Optional: `haveivdep`, `true` if the [`Reaction`](@ref) `rate` field explicitly depends on the independent variable.
237+
- Optional: `stateset`, set of states which if the rxvars are within mean rx is non-mass action.
235238
"""
236239
function ismassaction(rx, rs; rxvars = get_variables(rx.rate),
237-
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars))
240+
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars),
241+
stateset = Set(states(rs)))
238242
# if no dependencies must be zero order
239-
if isempty(rxvars)
240-
return true
241-
else
242-
return !(haveivdep || rx.only_use_rate || any(convert(Variable,rxv) in states(rs) for rxv in rxvars))
243+
(length(rxvars)==0) && return true
244+
(haveivdep || rx.only_use_rate) && return false
245+
@inbounds for i = 1:length(rxvars)
246+
(rxvars[i].op in stateset) && return false
243247
end
248+
return true
249+
end
250+
251+
@inline function makemajump(rx)
252+
@unpack rate, substrates, substoich, netstoich = rx
253+
havesubstoich = (length(substoich) == 0)
254+
reactant_stoch = Vector{Pair{Operation,eltype(substoich)}}(undef, length(substoich))
255+
@inbounds for i = 1:length(reactant_stoch)
256+
reactant_stoch[i] = var2op(substrates[i].op) => substoich[i]
257+
end
258+
#push!(rstoich, reactant_stoch)
259+
coef = havesubstoich ? one(eltype(substoich)) : prod(stoich -> factorial(stoich), substoich)
260+
rate = isone(coef) ? rate : rate/coef
261+
#push!(rates, rate)
262+
net_stoch = [Pair(var2op(p[1]),p[2]) for p in netstoich]
263+
#push!(nstoich, net_stoch)
264+
MassActionJump(rate, reactant_stoch, net_stoch, scale_rates=false, useiszero=false)
244265
end
245266

246267
function assemble_jumps(rs)
247-
eqs = Vector{Union{ConstantRateJump, MassActionJump, VariableRateJump}}()
268+
meqs = MassActionJump[]; ceqs = ConstantRateJump[]; veqs = VariableRateJump[]
269+
stateset = Set(states(rs))
270+
#rates = []; rstoich = []; nstoich = []
271+
rxvars = Operation[]
272+
ivname = rs.iv.name
248273

274+
isempty(equations(rs)) && error("Must give at least one reaction before constructing a JumpSystem.")
249275
for rx in equations(rs)
250-
rxvars = (rx.rate isa Operation) ? get_variables(rx.rate) : Operation[]
251-
haveivdep = any(var -> isequal(rs.iv,convert(Variable,var)), rxvars)
252-
if ismassaction(rx, rs; rxvars=rxvars, haveivdep=haveivdep)
253-
reactant_stoch = isempty(rx.substoich) ? [0 => 1] : [var2op(sub.op) => stoich for (sub,stoich) in zip(rx.substrates,rx.substoich)]
254-
coef = isempty(rx.substoich) ? one(eltype(rx.substoich)) : prod(stoich -> factorial(stoich), rx.substoich)
255-
rate = isone(coef) ? rx.rate : rx.rate/coef
256-
net_stoch = [Pair(var2op(p[1]),p[2]) for p in rx.netstoich]
257-
push!(eqs, MassActionJump(rate, reactant_stoch, net_stoch, scale_rates=false))
276+
empty!(rxvars)
277+
(rx.rate isa Operation) && get_variables!(rxvars, rx.rate)
278+
haveivdep = false
279+
@inbounds for i = 1:length(rxvars)
280+
if rxvars[i].op.name == ivname
281+
haveivdep = true
282+
break
283+
end
284+
end
285+
if ismassaction(rx, rs; rxvars=rxvars, haveivdep=haveivdep, stateset=stateset)
286+
push!(meqs, makemajump(rx))
258287
else
259288
rl = jumpratelaw(rx, rxvars=rxvars)
260289
affect = Vector{Equation}()
261290
for (spec,stoich) in rx.netstoich
262291
push!(affect, var2op(spec) ~ var2op(spec) + stoich)
263292
end
264293
if haveivdep
265-
push!(eqs, VariableRateJump(rl,affect))
294+
push!(veqs, VariableRateJump(rl,affect))
266295
else
267-
push!(eqs, ConstantRateJump(rl,affect))
296+
push!(ceqs, ConstantRateJump(rl,affect))
268297
end
269298
end
270299
end
271-
eqs
300+
#eqs[1] = MassActionJump(rates, rstoich, nstoich, scale_rates=false, useiszero=false)
301+
ArrayPartition(meqs,ceqs,veqs)
272302
end
273303

274304
"""

test/reactionsystem.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,15 @@ G2 = sf.g(u,p,t)
8484
# test with JumpSystem
8585
js = convert(JumpSystem, rs)
8686

87-
@test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.MassActionJump, 1:14))
88-
@test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.ConstantRateJump, 15:18))
89-
@test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.VariableRateJump, 19:20))
87+
midxs = 1:14
88+
cidxs = 15:18
89+
vidxs = 19:20
90+
@test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.MassActionJump, midxs))
91+
@test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.ConstantRateJump, cidxs))
92+
@test all(map(i -> typeof(js.eqs[i]) <: DiffEqJump.VariableRateJump, vidxs))
9093

9194
pars = rand(length(k)); u0 = rand(1:10,4); time = rand();
92-
jumps = Vector{Union{ConstantRateJump, MassActionJump, VariableRateJump}}(undef,length(js.eqs))
95+
jumps = Vector{Union{ConstantRateJump, MassActionJump, VariableRateJump}}(undef,length(rxs))
9396

9497
jumps[1] = MassActionJump(pars[1], Vector{Pair{Int,Int}}(), [1 => 1]);
9598
jumps[2] = MassActionJump(pars[2], [2 => 1], [2 => -1]);
@@ -116,20 +119,20 @@ jumps[20] = VariableRateJump((u,p,t) -> p[20]*t*u[1]*binomial(u[2],2)*u[3], inte
116119

117120
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
118121
parammap = map((x,y)->Pair(x(),y),parameters(js),pars)
119-
for i = 1:14
120-
maj = MT.assemble_maj(js, js.eqs[i], statetoid, ModelingToolkit.substituter(parammap),eltype(pars))
122+
for i in midxs
123+
maj = MT.assemble_maj(js.eqs[i], statetoid, ModelingToolkit.substituter(parammap),eltype(pars))
121124
@test abs(jumps[i].scaled_rates - maj.scaled_rates) < 100*eps()
122125
@test jumps[i].reactant_stoch == maj.reactant_stoch
123126
@test jumps[i].net_stoch == maj.net_stoch
124127
end
125-
for i = 15:18
128+
for i in cidxs
126129
crj = MT.assemble_crj(js, js.eqs[i], statetoid)
127130
@test isapprox(crj.rate(u0,p,time), jumps[i].rate(u0,p,time))
128131
fake_integrator1 = (u=zeros(4),p=p,t=0); fake_integrator2 = deepcopy(fake_integrator1);
129132
crj.affect!(fake_integrator1); jumps[i].affect!(fake_integrator2);
130133
@test fake_integrator1 == fake_integrator2
131134
end
132-
for i = 19:20
135+
for i in vidxs
133136
crj = MT.assemble_vrj(js, js.eqs[i], statetoid)
134137
@test isapprox(crj.rate(u0,p,time), jumps[i].rate(u0,p,time))
135138
fake_integrator1 = (u=zeros(4),p=p,t=0.); fake_integrator2 = deepcopy(fake_integrator1);

0 commit comments

Comments
 (0)