Skip to content

Commit 6a529bb

Browse files
Merge pull request #280 from SciML/threads
Add and test automatic multithreading
2 parents fc5a147 + b64b1ef commit 6a529bb

File tree

6 files changed

+93
-18
lines changed

6 files changed

+93
-18
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1515
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
16+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1617
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1718
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1819
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using StaticArrays, LinearAlgebra, SparseArrays
55
using Latexify, Unitful
66
using MacroTools
77

8-
using MacroTools
8+
using Base.Threads
99
import MacroTools: splitdef, combinedef, postwalk, striplines
1010
import GeneralizedGenerated
1111
using DocStringExtensions

src/build_function.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747
function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
4848
conv = simplified_expr, expression = Val{true};
4949
checkbounds = false, constructor=nothing,
50-
linenumbers = true)
50+
linenumbers = true, multithread=false)
5151
_vs = map(x-> x isa Operation ? x.op : x, vs)
5252
_ps = map(x-> x isa Operation ? x.op : x, ps)
5353
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
@@ -67,6 +67,21 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
6767

6868
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
6969

70+
if multithread
71+
lens = Int(ceil(length(ip_let_expr.args[2].args)/Threads.nthreads()))
72+
threaded_exprs = vcat([quote
73+
Threads.@spawn begin
74+
$(ip_let_expr.args[2].args[((i-1)*lens+1):i*lens]...)
75+
end
76+
end for i in 1:Threads.nthreads()-1],
77+
quote
78+
Threads.@spawn begin
79+
$(ip_let_expr.args[2].args[((Threads.nthreads()-1)*lens+1):end]...)
80+
end
81+
end)
82+
ip_let_expr.args[2] = ModelingToolkit.build_expr(:block, threaded_exprs)
83+
end
84+
7085
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
7186

7287
if rhss isa Matrix

test/bigsystem.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using ModelingToolkit, LinearAlgebra, SparseArrays
2+
3+
# Define the constants for the PDE
4+
const α₂ = 1.0
5+
const α₃ = 1.0
6+
const β₁ = 1.0
7+
const β₂ = 1.0
8+
const β₃ = 1.0
9+
const r₁ = 1.0
10+
const r₂ = 1.0
11+
const _DD = 100.0
12+
const γ₁ = 0.1
13+
const γ₂ = 0.1
14+
const γ₃ = 0.1
15+
const N = 8
16+
const X = reshape([i for i in 1:N for j in 1:N],N,N)
17+
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
18+
const α₁ = 1.0.*(X.>=4*N/5)
19+
20+
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
21+
const My = copy(Mx)
22+
Mx[2,1] = 2.0
23+
Mx[end-1,end] = 2.0
24+
My[1,2] = 2.0
25+
My[end,end-1] = 2.0
26+
27+
# Define the initial condition as normal arrays
28+
@variables du[1:N,1:N,1:3] u[1:N,1:N,1:3] MyA[1:N,1:N] AMx[1:N,1:N] DA[1:N,1:N]
29+
30+
# Define the discretized PDE as an ODE function
31+
function f(du,u,p,t)
32+
A = @view u[:,:,1]
33+
B = @view u[:,:,2]
34+
C = @view u[:,:,3]
35+
dA = @view du[:,:,1]
36+
dB = @view du[:,:,2]
37+
dC = @view du[:,:,3]
38+
mul!(MyA,My,A)
39+
mul!(AMx,A,Mx)
40+
@. DA = _DD*(MyA + AMx)
41+
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
42+
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
43+
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
44+
end
45+
46+
f(du,u,nothing,0.0)
47+
48+
multithreadedf = eval(ModelingToolkit.build_function(du,u,multithread=true)[2])
49+
_du = rand(N,N,3)
50+
_u = rand(N,N,3)
51+
multithreadedf(_du,_u)
52+
53+
jac = sparse(ModelingToolkit.jacobian(vec(du),vec(u),simplify=false))
54+
multithreadedjac = eval(ModelingToolkit.build_function(vec(jac),u,multithread=true)[2])
55+
56+
#_jac = similar(jac,Float64)
57+
#multithreadedjac(_jac,_u)

test/nonlinearsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ jac_func = generate_jacobian(ns)
5252
f = @eval eval(nlsys_func)
5353

5454
# Intermediate calculations
55+
a = y - x
5556
# Define a nonlinear system
5657
eqs = [0 ~ σ*a,
5758
0 ~ x*-z)-y,

test/runtests.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
using ModelingToolkit, Test
1+
using SafeTestsets, Test
22

3-
@testset "Parsing Test" begin include("variable_parsing.jl") end
4-
@testset "Differentiation Test" begin include("derivatives.jl") end
5-
@testset "Simplify Test" begin include("simplify.jl") end
6-
@testset "Operation Overloads Test" begin include("operation_overloads.jl") end
7-
@testset "Direct Usage Test" begin include("direct.jl") end
8-
@testset "ODESystem Test" begin include("odesystem.jl") end
9-
@testset "Mass Matrix Test" begin include("mass_matrix.jl") end
10-
@testset "SDESystem Test" begin include("sdesystem.jl") end
11-
@testset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end
12-
@testset "OptimizationSystem Test" begin include("optimizationsystem.jl") end
13-
@testset "Build Targets Test" begin include("build_targets.jl") end
14-
@testset "Domain Test" begin include("domains.jl") end
15-
@testset "Constraints Test" begin include("constraints.jl") end
16-
@testset "PDE Construction Test" begin include("pde.jl") end
17-
@testset "Distributed Test" begin include("distributed.jl") end
3+
@safetestset "Parsing Test" begin include("variable_parsing.jl") end
4+
@safetestset "Differentiation Test" begin include("derivatives.jl") end
5+
@safetestset "Simplify Test" begin include("simplify.jl") end
6+
@safetestset "Operation Overloads Test" begin include("operation_overloads.jl") end
7+
@safetestset "Direct Usage Test" begin include("direct.jl") end
8+
@safetestset "ODESystem Test" begin include("odesystem.jl") end
9+
@safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end
10+
@safetestset "SDESystem Test" begin include("sdesystem.jl") end
11+
@safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end
12+
@safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end
13+
@safetestset "Build Targets Test" begin include("build_targets.jl") end
14+
@safetestset "Domain Test" begin include("domains.jl") end
15+
@safetestset "Constraints Test" begin include("constraints.jl") end
16+
@safetestset "PDE Construction Test" begin include("pde.jl") end
17+
@safetestset "Test Big System Usage" begin include("bigsystem.jl") end
1818
#@testset "Latexify recipes Test" begin include("latexify.jl") end
19+
@testset "Distributed Test" begin include("distributed.jl") end

0 commit comments

Comments
 (0)