Skip to content

Commit 8099df3

Browse files
Merge pull request #530 from TorkelE/noise_scalling_in_reaction_sys_SDEs_2
Noise scaling in reaction sys sd es 2
2 parents bde8358 + 3c22516 commit 8099df3

File tree

2 files changed

+57
-22
lines changed

2 files changed

+57
-22
lines changed

src/systems/reaction/reactionsystem.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,19 @@ generated ODEs for the reaction. Note, for a reaction defined by
154154
`k*X*Y, X+Z --> 2X + Y`
155155
156156
the expression that is returned will be `k*X(t)^2*Y(t)*Z(t)`. For a reaction
157-
of the form
157+
of the form
158158
159159
`k, 2X+3Y --> Z`
160160
161-
the `Operation` that is returned will be `k * (X(t)^2/2) * (Y(t)^3/6)`.
161+
the `Operation` that is returned will be `k * (X(t)^2/2) * (Y(t)^3/6)`.
162162
163163
Notes:
164164
- Allocates
165165
- `combinatoric_ratelaw=true` uses factorial scaling factors in calculating the rate
166166
law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
167-
`combinatoric_ratelaw=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
167+
`combinatoric_ratelaw=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
168168
ignored.
169-
"""
169+
"""
170170
function oderatelaw(rx; combinatoric_ratelaw=true)
171171
@unpack rate, substrates, substoich, only_use_rate = rx
172172
rl = rate
@@ -203,12 +203,13 @@ function assemble_drift(rs; combinatoric_ratelaws=true)
203203
eqs
204204
end
205205

206-
function assemble_diffusion(rs; combinatoric_ratelaws=true)
206+
function assemble_diffusion(rs, noise_scaling; combinatoric_ratelaws=true)
207207
eqs = Expression[Constant(0) for x in rs.states, y in rs.eqs]
208208
species_to_idx = Dict((x => i for (i,x) in enumerate(rs.states)))
209209

210210
for (j,rx) in enumerate(rs.eqs)
211211
rlsqrt = sqrt(oderatelaw(rx; combinatoric_ratelaw=combinatoric_ratelaws))
212+
(noise_scaling!==nothing) && (rlsqrt *= noise_scaling[j])
212213
for (spec,stoich) in rx.netstoich
213214
i = species_to_idx[spec]
214215
signedrlsqrt = (stoich > zero(stoich)) ? rlsqrt : -rlsqrt
@@ -234,7 +235,7 @@ for a reaction defined by
234235
`k*X*Y, X+Z --> 2X + Y`
235236
236237
the expression that is returned will be `k*X^2*Y*Z`. For a reaction of
237-
the form
238+
the form
238239
239240
`k, 2X+3Y --> Z`
240241
@@ -247,8 +248,8 @@ Notes:
247248
- `combinatoric_ratelaw=true` uses binomials in calculating the rate law, i.e. for `2S ->
248249
0` at rate `k` the ratelaw would be `k*S*(S-1)/2`. If `combinatoric_ratelaw=false` then
249250
the ratelaw is `k*S*(S-1)`, i.e. the rate law is not normalized by the scaling
250-
factor.
251-
"""
251+
factor.
252+
"""
252253
function jumpratelaw(rx; rxvars=get_variables(rx.rate), combinatoric_ratelaw=true)
253254
@unpack rate, substrates, substoich, only_use_rate = rx
254255
rl = rate
@@ -364,7 +365,7 @@ Convert a [`ReactionSystem`](@ref) to an [`ODESystem`](@ref).
364365
Notes:
365366
- `combinatoric_ratelaws=true` uses factorial scaling factors in calculating the rate
366367
law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
367-
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
368+
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
368369
ignored.
369370
"""
370371
function Base.convert(::Type{<:ODESystem}, rs::ReactionSystem; combinatoric_ratelaws=true)
@@ -383,14 +384,15 @@ Convert a [`ReactionSystem`](@ref) to an [`SDESystem`](@ref).
383384
Notes:
384385
- `combinatoric_ratelaws=true` uses factorial scaling factors in calculating the rate
385386
law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
386-
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
387+
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
387388
ignored.
388389
"""
389-
function Base.convert(::Type{<:SDESystem},rs::ReactionSystem, combinatoric_ratelaws=true)
390+
function Base.convert(::Type{<:SDESystem},rs::ReactionSystem, combinatoric_ratelaws=true; noise_scaling=nothing::Union{Vector{Operation},Operation,Nothing})
391+
(typeof(noise_scaling) <: Vector{Operation}) && (length(noise_scaling)!=length(rs.eqs)) && error("The number of elements in 'noise_scaling' must be equal to the number of reactions in the reaction system.")
392+
(typeof(noise_scaling) <: Operation) && (noise_scaling = fill(noise_scaling,length(rs.eqs)))
390393
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws)
391-
noiseeqs = assemble_diffusion(rs; combinatoric_ratelaws=combinatoric_ratelaws)
392-
SDESystem(eqs,noiseeqs,rs.iv,rs.states,rs.ps,
393-
name=rs.name,systems=convert.(SDESystem,rs.systems))
394+
noiseeqs = assemble_diffusion(rs,noise_scaling; combinatoric_ratelaws=combinatoric_ratelaws)
395+
SDESystem(eqs,noiseeqs,rs.iv,rs.states,(noise_scaling===nothing) ? rs.ps : union(rs.ps,Variable{ModelingToolkit.Parameter{Number}}.(noise_scaling)),name=rs.name,systems=convert.(SDESystem,rs.systems))
394396
end
395397

396398
"""
@@ -404,7 +406,7 @@ Notes:
404406
- `combinatoric_ratelaws=true` uses binomials in calculating the rate law, i.e. for `2S ->
405407
0` at rate `k` the ratelaw would be `k*S*(S-1)/2`. If `combinatoric_ratelaws=false` then
406408
the ratelaw is `k*S*(S-1)`, i.e. the rate law is not normalized by the scaling
407-
factor.
409+
factor.
408410
"""
409411
function Base.convert(::Type{<:JumpSystem},rs::ReactionSystem; combinatoric_ratelaws=true)
410412
eqs = assemble_jumps(rs; combinatoric_ratelaws=combinatoric_ratelaws)
@@ -423,7 +425,7 @@ Convert a [`ReactionSystem`](@ref) to an [`NonlinearSystem`](@ref).
423425
Notes:
424426
- `combinatoric_ratelaws=true` uses factorial scaling factors in calculating the rate
425427
law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
426-
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
428+
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
427429
ignored.
428430
"""
429431
function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem; combinatoric_ratelaws=true)
@@ -456,11 +458,12 @@ function DiffEqBase.ODEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Numb
456458
end
457459

458460
# SDEProblem from AbstractReactionNetwork
459-
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p, args...; kwargs...)
461+
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p, args...; noise_scaling=nothing::Union{Operation,Nothing}, kwargs...)
462+
sde_sys = convert(SDESystem,rs,noise_scaling=noise_scaling)
460463
u0 = typeof(u0) <: Array{<:Pair} ? u0 : Pair.(rs.states,u0)
461-
p = typeof(p) <: Array{<:Pair} ? p : Pair.(rs.ps,p)
464+
p = typeof(p) <: Array{<:Pair} ? p : Pair.(sde_sys.ps,p)
462465
p_matrix = zeros(length(rs.states), length(rs.eqs))
463-
return SDEProblem(convert(SDESystem,rs),u0,tspan,p,args...; noise_rate_prototype=p_matrix,kwargs...)
466+
return SDEProblem(sde_sys,u0,tspan,p,args...; noise_rate_prototype=p_matrix,kwargs...)
464467
end
465468

466469
# DiscreteProblem from AbstractReactionNetwork
@@ -496,4 +499,4 @@ function modified_states!(mstates, rx::Reaction, sts)
496499
for (species,stoich) in rx.netstoich
497500
(species in sts) && push!(mstates, species())
498501
end
499-
end
502+
end

test/reactionsystem.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,38 @@ du2 = sf.f(u,p,t)
8181
G2 = sf.g(u,p,t)
8282
@test norm(G-G2) < 100*eps()
8383

84+
# tests the noise_scaling argument.
85+
p = rand(length(k)+1)
86+
u = rand(length(k))
87+
t = 0.
88+
G = p[21]*sdenoise(u,p,t)
89+
@variables η
90+
sdesys_noise_scaling = convert(SDESystem,rs;noise_scaling=η)
91+
sf = SDEFunction{false}(sdesys_noise_scaling, states(rs), parameters(sdesys_noise_scaling))
92+
G2 = sf.g(u,p,t)
93+
@test norm(G-G2) < 100*eps()
94+
95+
# tests the noise_scaling vector argument.
96+
p = rand(length(k)+3)
97+
u = rand(length(k))
98+
t = 0.
99+
G = vcat(fill(p[21],8),fill(p[22],3),fill(p[23],9))' .* sdenoise(u,p,t)
100+
@variables η[1:3]
101+
sdesys_noise_scaling = convert(SDESystem,rs;noise_scaling=vcat(fill(η[1],8),fill(η[2],3),fill(η[3],9)))
102+
sf = SDEFunction{false}(sdesys_noise_scaling, states(rs), parameters(sdesys_noise_scaling))
103+
G2 = sf.g(u,p,t)
104+
@test norm(G-G2) < 100*eps()
105+
106+
# tests using previous parameter for noise scaling
107+
p = rand(length(k))
108+
u = rand(length(k))
109+
t = 0.
110+
G = [p p p p]' .* sdenoise(u,p,t)
111+
sdesys_noise_scaling = convert(SDESystem,rs;noise_scaling=k)
112+
sf = SDEFunction{false}(sdesys_noise_scaling, states(rs), parameters(sdesys_noise_scaling))
113+
G2 = sf.g(u,p,t)
114+
@test norm(G-G2) < 100*eps()
115+
84116
# test with JumpSystem
85117
js = convert(JumpSystem, rs)
86118

@@ -142,7 +174,7 @@ end
142174

143175

144176
# test for https://github.com/SciML/ModelingToolkit.jl/issues/436
145-
@parameters t
177+
@parameters t
146178
@variables S I
147179
rxs = [Reaction(1,[S],[I]), Reaction(1.1,[S],[I])]
148180
rs = ReactionSystem(rxs, t, [S,I], [])
@@ -192,4 +224,4 @@ rs = ReactionSystem(rxs, t, [S,I,R], [k1,k2])
192224
js = convert(JumpSystem, rs)
193225
@test isequal2(js.eqs[1].scaled_rates, k1/12)
194226
js = convert(JumpSystem,rs; combinatoric_ratelaws=false)
195-
@test isequal2(js.eqs[1].scaled_rates, k1)
227+
@test isequal2(js.eqs[1].scaled_rates, k1)

0 commit comments

Comments
 (0)