Skip to content

Commit 218dfee

Browse files
committed
Uses Operation.
Now handles noise already delcared noise parameters.
1 parent 6acffc7 commit 218dfee

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

src/systems/reaction/reactionsystem.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ function assemble_diffusion(rs, noise_scaling; combinatoric_ratelaws=true)
209209

210210
for (j,rx) in enumerate(rs.eqs)
211211
rlsqrt = sqrt(oderatelaw(rx; combinatoric_ratelaw=combinatoric_ratelaws))
212-
(noise_scaling!==nothing) && (rlsqrt *= var2op(Variable(noise_scaling[j])))
212+
(noise_scaling!==nothing) && (rlsqrt *= noise_scaling[j])
213213
for (spec,stoich) in rx.netstoich
214214
i = species_to_idx[spec]
215215
signedrlsqrt = (stoich > zero(stoich)) ? rlsqrt : -rlsqrt
@@ -387,12 +387,12 @@ law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
387387
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
388388
ignored.
389389
"""
390-
function Base.convert(::Type{<:SDESystem},rs::ReactionSystem, combinatoric_ratelaws=true; noise_scaling=nothing::Union{Vector{Symbol},Symbol,Nothing})
391-
(typeof(noise_scaling) <: Vector{Symbol}) && (length(noise_scaling)!=length(rs.eqs)) && error("The number of elements in 'noise_scaling' must be equal to the number of reactions in the reaction system.")
392-
(typeof(noise_scaling) <: Symbol) && (noise_scaling = fill(noise_scaling,length(rs.eqs)))
390+
function Base.convert(::Type{<:SDESystem},rs::ReactionSystem, combinatoric_ratelaws=true; noise_scaling=nothing::Union{Vector{Operation},Operation,Nothing})
391+
(typeof(noise_scaling) <: Vector{Operation}) && (length(noise_scaling)!=length(rs.eqs)) && error("The number of elements in 'noise_scaling' must be equal to the number of reactions in the reaction system.")
392+
(typeof(noise_scaling) <: Operation) && (noise_scaling = fill(noise_scaling,length(rs.eqs)))
393393
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws)
394394
noiseeqs = assemble_diffusion(rs,noise_scaling; combinatoric_ratelaws=combinatoric_ratelaws)
395-
SDESystem(eqs,noiseeqs,rs.iv,rs.states,(noise_scaling===nothing) ? rs.ps : vcat(rs.ps,Variable.(unique(noise_scaling))),name=rs.name,systems=convert.(SDESystem,rs.systems))
395+
SDESystem(eqs,noiseeqs,rs.iv,rs.states,(noise_scaling===nothing) ? rs.ps : union(rs.ps,Variable{ModelingToolkit.Parameter{Number}}.(noise_scaling)),name=rs.name,systems=convert.(SDESystem,rs.systems))
396396
end
397397

398398
"""
@@ -458,12 +458,12 @@ function DiffEqBase.ODEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Numb
458458
end
459459

460460
# SDEProblem from AbstractReactionNetwork
461-
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p, args...; noise_scaling=nothing::Union{Symbol,Nothing}, kwargs...)
461+
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p, args...; noise_scaling=nothing::Union{Operation,Nothing}, kwargs...)
462462
sde_sys = convert(SDESystem,rs,noise_scaling=noise_scaling)
463463
u0 = typeof(u0) <: Array{<:Pair} ? u0 : Pair.(rs.states,u0)
464-
p = typeof(p) <: Array{<:Pair} ? p : Pair.(rs.ps,p)
464+
p = typeof(p) <: Array{<:Pair} ? p : Pair.(sde_sys.ps,p)
465465
p_matrix = zeros(length(rs.states), length(rs.eqs))
466-
return SDEProblem(convert(SDESystem,rs),u0,tspan,p,args...; noise_rate_prototype=p_matrix,kwargs...)
466+
return SDEProblem(sde_sys,u0,tspan,p,args...; noise_rate_prototype=p_matrix,kwargs...)
467467
end
468468

469469
# DiscreteProblem from AbstractReactionNetwork

test/reactionsystem.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ p = rand(length(k)+1)
8686
u = rand(length(k))
8787
t = 0.
8888
G = p[21]*sdenoise(u,p,t)
89-
sdesys_noise_scaling = convert(SDESystem,rs;noise_scaling=)
89+
@variables η
90+
sdesys_noise_scaling = convert(SDESystem,rs;noise_scaling=η)
9091
sf = SDEFunction{false}(sdesys_noise_scaling, states(rs), parameters(sdesys_noise_scaling))
9192
G2 = sf.g(u,p,t)
9293
@test norm(G-G2) < 100*eps()
@@ -96,7 +97,18 @@ p = rand(length(k)+3)
9697
u = rand(length(k))
9798
t = 0.
9899
G = vcat(fill(p[21],8),fill(p[22],3),fill(p[23],9))' .* sdenoise(u,p,t)
99-
sdesys_noise_scaling = convert(SDESystem,rs;noise_scaling=vcat(fill(:η1,8),fill(:η2,3),fill(:η3,9)))
100+
@variables η[1:3]
101+
sdesys_noise_scaling = convert(SDESystem,rs;noise_scaling=vcat(fill(η[1],8),fill(η[2],3),fill(η[3],9)))
102+
sf = SDEFunction{false}(sdesys_noise_scaling, states(rs), parameters(sdesys_noise_scaling))
103+
G2 = sf.g(u,p,t)
104+
@test norm(G-G2) < 100*eps()
105+
106+
# tests using previous parameter for noise scaling
107+
p = rand(length(k)+3)
108+
u = rand(length(k))
109+
t = 0.
110+
G = [p p p p]' .* sdenoise(u,p,t)
111+
sdesys_noise_scaling = convert(SDESystem,rs;noise_scaling=k)
100112
sf = SDEFunction{false}(sdesys_noise_scaling, states(rs), parameters(sdesys_noise_scaling))
101113
G2 = sf.g(u,p,t)
102114
@test norm(G-G2) < 100*eps()

0 commit comments

Comments
 (0)