Skip to content

Commit 27262db

Browse files
Merge pull request #496 from SciML/safer_oop_conversion
safer OOP conversion strategies
2 parents 387830f + d32181f commit 27262db

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ DiffEqJump = "6.7.5"
3737
DiffRules = "0.1, 1.0"
3838
DocStringExtensions = "0.7, 0.8"
3939
GeneralizedGenerated = "0.1.4, 0.2"
40-
LabelledArrays = "1.2"
40+
LabelledArrays = "1.3"
4141
Latexify = "0.11, 0.12, 0.13"
4242
LightGraphs = "1.3"
4343
MacroTools = "0.5"
@@ -55,10 +55,11 @@ julia = "1.2"
5555

5656
[extras]
5757
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
58+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5859
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5960
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
6061
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
6162
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6263

6364
[targets]
64-
test = ["Dagger", "OrdinaryDiffEq", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]
65+
test = ["Dagger", "ForwardDiff", "OrdinaryDiffEq", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]

src/build_function.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
166166
checkbounds = false,
167167
linenumbers = false, multithread=nothing,
168168
headerfun=addheader, outputidxs=nothing,
169+
convert_oop = true,
169170
skipzeros = false, parallel=SerialForm())
170171

171172
if multithread isa Bool
@@ -336,17 +337,25 @@ function _build_function(target::JuliaTarget, rhss, args...;
336337
arr_sys_expr = (typeof(rhss) <: Vector || typeof(rhss) <: Matrix) && !(eltype(rhss) <: AbstractArray) ? quote
337338
if typeof($(fargs.args[1])) <: Union{ModelingToolkit.StaticArrays.SArray,ModelingToolkit.LabelledArrays.SLArray}
338339
$xname = ModelingToolkit.StaticArrays.@SArray $arr_sys_expr
339-
if $(typeof(rhss) <: Vector) # Only try converting if it should match `u`
340-
convert(typeof($(fargs.args[1])),$xname)
340+
if $convert_oop && !(typeof($(fargs.args[1])) <: Number) && $(typeof(rhss) <: Vector) # Only try converting if it should match `u`
341+
return similar_type($(fargs.args[1]),eltype($xname))($xname)
341342
else
342-
$xname
343+
return $xname
343344
end
344345
else
345346
$xname = $arr_sys_expr
346-
if !(typeof($(fargs.args[1])) <: Array) && $(typeof(rhss) <: Vector)
347-
convert(typeof($(fargs.args[1])),$xname)
347+
if $convert_oop && $(typeof(rhss) <: Vector)
348+
if !(typeof($(fargs.args[1])) <: Array) && !(typeof($(fargs.args[1])) <: Number) && eltype($(fargs.args[1])) <: eltype($xname)
349+
# Last condition: avoid known error because this doesn't change eltypes!
350+
return convert(typeof($(fargs.args[1])),$xname)
351+
elseif typeof($(fargs.args[1])) <: ModelingToolkit.LabelledArrays.LArray
352+
# LArray just needs to add the names back!
353+
return ModelingToolkit.LabelledArrays.LArray{ModelingToolkit.LabelledArrays.symnames(typeof($(fargs.args[1])))}($xname)
354+
else
355+
return $xname
356+
end
348357
else
349-
$xname
358+
return $xname
350359
end
351360
end
352361
end : arr_sys_expr

test/labelledarrays.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ModelingToolkit, StaticArrays, LinearAlgebra, LabelledArrays
2-
using DiffEqBase
2+
using DiffEqBase, ForwardDiff
33
using Test
44

55
# Define some variables
@@ -9,7 +9,7 @@ using Test
99

1010
# Define a differential equation
1111
eqs = [D(x) ~ σ*(y-x),
12-
D(y) ~ x*-z)-y,
12+
D(y) ~ t*x*-z)-y,
1313
D(z) ~ x*y - β*z]
1414

1515
de = ODESystem(eqs)
@@ -30,3 +30,12 @@ p = SLVector(σ=10.0,ρ=26.0,β=8/3)
3030
@test ff.jac(c,p,0.0) isa Matrix
3131
@test ff.jac(a,p,0.0) == ff.jac(b,p,0.0)
3232
@test ff.jac(a,p,0.0) == ff.jac(c,p,0.0)
33+
34+
# Test similar_type
35+
@test ff(b,p,ForwardDiff.Dual(0.0,1.0)) isa SLArray
36+
d = LVector(x=1.0,y=2.0,z=3.0)
37+
@test ff(d,p,ForwardDiff.Dual(0.0,1.0)) isa LArray
38+
@test ff.jac(b,p,ForwardDiff.Dual(0.0,1.0)) isa SArray
39+
@test eltype(ff.jac(b,p,ForwardDiff.Dual(0.0,1.0))) <: ForwardDiff.Dual
40+
@test ff.jac(d,p,ForwardDiff.Dual(0.0,1.0)) isa Array
41+
@test eltype(ff.jac(d,p,ForwardDiff.Dual(0.0,1.0))) <: ForwardDiff.Dual

0 commit comments

Comments
 (0)