@@ -101,10 +101,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
101101end
102102
103103function generate_diffusion_function (sys:: SDESystem , dvs = states (sys), ps = parameters (sys); kwargs... )
104- return build_function (sys. noiseeqs ,
104+ return build_function (get_noiseeqs ( sys) ,
105105 map (x-> time_varying_as_func (value (x), sys), dvs),
106106 map (x-> time_varying_as_func (value (x), sys), ps),
107- sys. iv ; kwargs... )
107+ independent_variable ( sys) ; kwargs... )
108108end
109109
110110"""
@@ -114,36 +114,34 @@ Choose correction_factor=-1//2 (1//2) to converte Ito -> Stratonovich (Stratonov
114114"""
115115function stochastic_integral_transform (sys:: SDESystem , correction_factor)
116116 # use the general interface
117- if typeof (sys. noiseeqs ) <: Vector
118- eqs = vcat ([sys. eqs [i]. lhs ~ sys. noiseeqs [i] for i in eachindex (sys . states)]. .. )
119- de = ODESystem (eqs,sys. iv,sys . states, sys. ps )
117+ if typeof (get_noiseeqs ( sys) ) <: Vector
118+ eqs = vcat ([equations ( sys) [i]. lhs ~ get_noiseeqs ( sys) [i] for i in eachindex (states (sys) )]. .. )
119+ de = ODESystem (eqs,get_iv ( sys), states (sys), parameters ( sys) )
120120
121121 jac = calculate_jacobian (de, sparse= false , simplify= false )
122- ∇σσ′ = simplify .(jac* sys. noiseeqs )
122+ ∇σσ′ = simplify .(jac* get_noiseeqs ( sys) )
123123
124- deqs = vcat ([sys. eqs [i]. lhs ~ sys. eqs [i]. rhs+ correction_factor* ∇σσ′[i] for i in eachindex (sys . states)]. .. )
124+ deqs = vcat ([equations ( sys) [i]. lhs ~ equations ( sys) [i]. rhs+ correction_factor* ∇σσ′[i] for i in eachindex (states (sys) )]. .. )
125125 else
126- dimstate, m = size (sys. noiseeqs )
127- eqs = vcat ([sys. eqs [i]. lhs ~ sys. noiseeqs [i] for i in eachindex (sys . states)]. .. )
128- de = ODESystem (eqs,sys. iv,sys . states, sys. ps )
126+ dimstate, m = size (get_noiseeqs ( sys) )
127+ eqs = vcat ([equations ( sys) [i]. lhs ~ get_noiseeqs ( sys) [i] for i in eachindex (states (sys) )]. .. )
128+ de = ODESystem (eqs,get_iv ( sys), states (sys), parameters ( sys) )
129129
130130 jac = calculate_jacobian (de, sparse= false , simplify= false )
131- ∇σσ′ = simplify .(jac* sys. noiseeqs [:,1 ])
131+ ∇σσ′ = simplify .(jac* get_noiseeqs ( sys) [:,1 ])
132132 for k = 2 : m
133- eqs = vcat ([sys. eqs [i]. lhs ~ sys. noiseeqs [Int (i+ (k- 1 )* dimstate)] for i in eachindex (sys . states)]. .. )
134- de = ODESystem (eqs,sys. iv,sys . states, sys. ps )
133+ eqs = vcat ([equations ( sys) [i]. lhs ~ get_noiseeqs ( sys) [Int (i+ (k- 1 )* dimstate)] for i in eachindex (states (sys) )]. .. )
134+ de = ODESystem (eqs,get_iv ( sys), states (sys), parameters ( sys) )
135135
136136 jac = calculate_jacobian (de, sparse= false , simplify= false )
137- ∇σσ′ = ∇σσ′ + simplify .(jac* sys. noiseeqs [:,k])
137+ ∇σσ′ = ∇σσ′ + simplify .(jac* get_noiseeqs ( sys) [:,k])
138138 end
139139
140- deqs = vcat ([sys. eqs [i]. lhs ~ sys. eqs [i]. rhs + correction_factor* ∇σσ′[i] for i in eachindex (sys . states)]. .. )
140+ deqs = vcat ([equations ( sys) [i]. lhs ~ equations ( sys) [i]. rhs + correction_factor* ∇σσ′[i] for i in eachindex (states (sys) )]. .. )
141141 end
142142
143143
144- de = SDESystem (deqs,sys. noiseeqs,sys. iv,sys. states,sys. ps)
145-
146- de
144+ SDESystem (deqs,get_noiseeqs (sys),get_iv (sys),states (sys),parameters (sys))
147145end
148146
149147
@@ -161,7 +159,7 @@ Create an `SDEFunction` from the [`SDESystem`](@ref). The arguments `dvs` and `p
161159are used to set the order of the dependent variable and parameter vectors,
162160respectively.
163161"""
164- function DiffEqBase. SDEFunction {iip} (sys:: SDESystem , dvs = sys . states, ps = sys. ps ,
162+ function DiffEqBase. SDEFunction {iip} (sys:: SDESystem , dvs = states (sys) , ps = parameters ( sys) ,
165163 u0 = nothing ;
166164 version = nothing , tgrad= false , sparse = false ,
167165 jac = false , Wfact = false , eval_expression = true , kwargs... ) where {iip}
@@ -215,7 +213,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
215213 Wfact = _Wfact === nothing ? nothing : _Wfact,
216214 Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
217215 mass_matrix = _M,
218- syms = Symbol .(sys . states))
216+ syms = Symbol .(states (sys) ))
219217end
220218
221219function DiffEqBase. SDEFunction (sys:: SDESystem , args... ; kwargs... )
@@ -287,7 +285,7 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = states(sys),
287285 Wfact = Wfact,
288286 Wfact_t = Wfact_t,
289287 mass_matrix = M,
290- syms = $ (Symbol .(states (sys))),kwargs ... )
288+ syms = $ (Symbol .(states (sys))))
291289 end
292290 ! linenumbers ? striplines (ex) : ex
293291end
@@ -322,13 +320,14 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
322320 f, u0, p = process_DEProblem (SDEFunction{iip}, sys, u0map, parammap; kwargs... )
323321 sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
324322
325- if typeof (sys. noiseeqs) <: AbstractVector
323+ noiseeqs = get_noiseeqs (sys)
324+ if noiseeqs isa AbstractVector
326325 noise_rate_prototype = nothing
327326 elseif sparsenoise
328- I,J,V = findnz (SparseArrays. sparse (sys . noiseeqs))
327+ I,J,V = findnz (SparseArrays. sparse (noiseeqs))
329328 noise_rate_prototype = SparseArrays. sparse (I,J,zero (eltype (u0)))
330329 else
331- noise_rate_prototype = zeros (eltype (u0),size (sys . noiseeqs))
330+ noise_rate_prototype = zeros (eltype (u0),size (noiseeqs))
332331 end
333332
334333 SDEProblem {iip} (f,f. g,u0,tspan,p;noise_rate_prototype= noise_rate_prototype,kwargs... )
@@ -363,13 +362,15 @@ function SDEProblemExpr{iip}(sys::SDESystem,u0map,tspan,
363362 linenumbers = get (kwargs, :linenumbers , true )
364363 sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
365364
366- if typeof (sys. noiseeqs) <: AbstractVector
365+ noiseeqs = get_noiseeqs (sys)
366+ if noiseeqs isa AbstractVector
367367 noise_rate_prototype = nothing
368368 elseif sparsenoise
369- I,J,V = findnz (SparseArrays. sparse (sys . noiseeqs))
369+ I,J,V = findnz (SparseArrays. sparse (noiseeqs))
370370 noise_rate_prototype = SparseArrays. sparse (I,J,zero (eltype (u0)))
371371 else
372- noise_rate_prototype = zeros (eltype (u0),size (sys. noiseeqs))
372+ T = u0 === nothing ? Float64 : eltype (u0)
373+ noise_rate_prototype = zeros (T,size (get_noiseeqs (sys)))
373374 end
374375 ex = quote
375376 f = $ f
0 commit comments