Skip to content

Commit def6288

Browse files
Merge pull request #346 from SciML/rebuild/97c99f0a
Rebuild model_inference/01-pendulum_bayesian_inference.jmd
2 parents fb4f352 + 64d0b28 commit def6288

File tree

8 files changed

+1609
-0
lines changed

8 files changed

+1609
-0
lines changed

html/model_inference/01-pendulum_bayesian_inference.html

Lines changed: 1003 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
---
2+
author: "Vaibhav Dixit"
3+
title: "Bayesian Inference on a Pendulum using DiffEqBayes.jl"
4+
---
5+
6+
7+
### Set up simple pendulum problem
8+
9+
```julia
10+
using DiffEqBayes, OrdinaryDiffEq, RecursiveArrayTools, Distributions, Plots, StatsPlots, BenchmarkTools, TransformVariables, CmdStan, DynamicHMC
11+
```
12+
13+
14+
15+
16+
Let's define our simple pendulum problem. Here our pendulum has a drag term `ω`
17+
and a length `L`.
18+
19+
![pendulum](https://user-images.githubusercontent.com/1814174/59942945-059c1680-942f-11e9-991c-2025e6e4ccd3.jpg)
20+
21+
We get first order equations by defining the first term as the velocity and the
22+
second term as the position, getting:
23+
24+
```julia
25+
function pendulum(du,u,p,t)
26+
ω,L = p
27+
x,y = u
28+
du[1] = y
29+
du[2] = - ω*y -(9.8/L)*sin(x)
30+
end
31+
32+
u0 = [1.0,0.1]
33+
tspan = (0.0,10.0)
34+
prob1 = ODEProblem(pendulum,u0,tspan,[1.0,2.5])
35+
```
36+
37+
```
38+
ODEProblem with uType Array{Float64,1} and tType Float64. In-place: true
39+
timespan: (0.0, 10.0)
40+
u0: [1.0, 0.1]
41+
```
42+
43+
44+
45+
46+
47+
### Solve the model and plot
48+
49+
To understand the model and generate data, let's solve and visualize the solution
50+
with the known parameters:
51+
52+
```julia
53+
sol = solve(prob1,Tsit5())
54+
plot(sol)
55+
```
56+
57+
![](figures/01-pendulum_bayesian_inference_3_1.png)
58+
59+
60+
61+
It's the pendulum, so you know what it looks like. It's periodic, but since we
62+
have not made a small angle assumption it's not exactly `sin` or `cos`. Because
63+
the true dampening parameter `ω` is 1, the solution does not decay over time,
64+
nor does it increase. The length `L` determines the period.
65+
66+
### Create some dummy data to use for estimation
67+
68+
We now generate some dummy data to use for estimation
69+
70+
```julia
71+
t = collect(range(1,stop=10,length=10))
72+
randomized = VectorOfArray([(sol(t[i]) + .01randn(2)) for i in 1:length(t)])
73+
data = convert(Array,randomized)
74+
```
75+
76+
```
77+
2×10 Array{Float64,2}:
78+
0.0669231 -0.377851 0.119404 0.0795968 … -0.01553 0.00535298
79+
-1.21411 0.344681 0.323712 -0.253243 0.0164092 -0.00897403
80+
```
81+
82+
83+
84+
85+
86+
Let's see what our data looks like on top of the real solution
87+
88+
```julia
89+
scatter!(data')
90+
```
91+
92+
![](figures/01-pendulum_bayesian_inference_5_1.png)
93+
94+
95+
96+
This data captures the non-dampening effect and the true period, making it
97+
perfect to attempting a Bayesian inference.
98+
99+
### Perform Bayesian Estimation
100+
101+
Now let's fit the pendulum to the data. Since we know our model is correct,
102+
this should give us back the parameters that we used to generate the data!
103+
Define priors on our parameters. In this case, let's assume we don't have much
104+
information, but have a prior belief that ω is between 0.1 and 3.0, while the
105+
length of the pendulum L is probably around 3.0:
106+
107+
```julia
108+
priors = [Uniform(0.1,3.0), Normal(3.0,1.0)]
109+
```
110+
111+
```
112+
2-element Array{Distributions.Distribution{Distributions.Univariate,Distrib
113+
utions.Continuous},1}:
114+
Distributions.Uniform{Float64}(a=0.1, b=3.0)
115+
Distributions.Normal{Float64}(μ=3.0, σ=1.0)
116+
```
117+
118+
119+
120+
121+
122+
Finally let's run the estimation routine from DiffEqBayes.jl with the Turing.jl backend to check if we indeed recover the parameters!
123+
124+
```julia
125+
bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=10_000,
126+
syms = [:omega,:L])
127+
```
128+
129+
```
130+
Chains MCMC chain (9000×15×1 Array{Float64,3}):
131+
132+
Iterations = 1:9000
133+
Thinning interval = 1
134+
Chains = 1
135+
Samples per chain = 9000
136+
parameters = L, omega, σ[1]
137+
internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy
138+
_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps,
139+
nom_step_size, numerical_error, step_size, tree_depth
140+
141+
Summary Statistics
142+
parameters mean std naive_se mcse ess rhat
143+
144+
Symbol Float64 Float64 Float64 Float64 Float64 Float64
145+
146+
147+
148+
L 2.5036 0.2148 0.0023 0.0035 3703.0703 1.0000
149+
150+
omega 1.0777 0.2217 0.0023 0.0048 2000.0506 1.0008
151+
152+
σ[1] 0.1603 0.0390 0.0004 0.0007 3326.8139 0.9999
153+
154+
155+
Quantiles
156+
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
157+
Symbol Float64 Float64 Float64 Float64 Float64
158+
159+
L 2.0761 2.3766 2.5024 2.6302 2.9287
160+
omega 0.7670 0.9384 1.0395 1.1706 1.6059
161+
σ[1] 0.1018 0.1325 0.1540 0.1812 0.2529
162+
```
163+
164+
165+
166+
167+
168+
Notice that while our guesses had the wrong means, the learned parameters converged
169+
to the correct means, meaning that it learned good posterior distributions for the
170+
parameters. To look at these posterior distributions on the parameters, we can
171+
examine the chains:
172+
173+
```julia
174+
plot(bayesian_result)
175+
```
176+
177+
![](figures/01-pendulum_bayesian_inference_8_1.png)
178+
179+
180+
181+
As a diagnostic, we will also check the parameter chains. The chain is the MCMC
182+
sampling process. The chain should explore parameter space and converge reasonably
183+
well, and we should be taking a lot of samples after it converges (it is these
184+
samples that form the posterior distribution!)
185+
186+
```julia
187+
plot(bayesian_result, colordim = :parameter)
188+
```
189+
190+
![](figures/01-pendulum_bayesian_inference_9_1.png)
191+
192+
193+
194+
Notice that after awhile these chains converge to a "fuzzy line", meaning it
195+
found the area with the most likelihood and then starts to sample around there,
196+
which builds a posterior distribution around the true mean.
197+
198+
DiffEqBayes.jl allows the choice of using Stan.jl, Turing.jl and DynamicHMC.jl for MCMC, you can also use ApproxBayes.jl for Approximate Bayesian computation algorithms.
199+
Let's compare the timings across the different MCMC backends. We'll stick with the default arguments and 10,000 samples in each since there is a lot of room for micro-optimization
200+
specific to each package and algorithm combinations, you might want to do your own experiments for specific problems to get better understanding of the performance.
201+
202+
```julia
203+
@btime bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;syms = [:omega,:L],num_samples=10_000)
204+
```
205+
206+
```
207+
2.710 s (23598867 allocations: 1.50 GiB)
208+
Chains MCMC chain (9000×15×1 Array{Float64,3}):
209+
210+
Iterations = 1:9000
211+
Thinning interval = 1
212+
Chains = 1
213+
Samples per chain = 9000
214+
parameters = L, omega, σ[1]
215+
internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy
216+
_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps,
217+
nom_step_size, numerical_error, step_size, tree_depth
218+
219+
Summary Statistics
220+
parameters mean std naive_se mcse ess rhat
221+
222+
Symbol Float64 Float64 Float64 Float64 Float64 Float64
223+
224+
225+
226+
L 2.5019 0.2081 0.0022 0.0034 3767.3721 1.0000
227+
228+
omega 1.0773 0.2137 0.0023 0.0040 2973.1493 1.0001
229+
230+
σ[1] 0.1593 0.0371 0.0004 0.0006 4173.1326 1.0004
231+
232+
233+
Quantiles
234+
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
235+
Symbol Float64 Float64 Float64 Float64 Float64
236+
237+
L 2.0844 2.3770 2.5029 2.6269 2.9178
238+
omega 0.7660 0.9383 1.0424 1.1743 1.6056
239+
σ[1] 0.1032 0.1325 0.1538 0.1793 0.2468
240+
```
241+
242+
243+
244+
```julia
245+
@btime bayesian_result = stan_inference(prob1,t,data,priors;num_samples=10_000,printsummary=false)
246+
```
247+
248+
```
249+
Error: MethodError: no method matching iterate(::ModelingToolkit.ODESystem)
250+
Closest candidates are:
251+
iterate(!Matched::Core.SimpleVector) at essentials.jl:603
252+
iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
253+
iterate(!Matched::ExponentialBackOff) at error.jl:253
254+
...
255+
```
256+
257+
258+
259+
```julia
260+
@btime bayesian_result = dynamichmc_inference(prob1,Tsit5(),t,data,priors;num_samples = 10_000)
261+
```
262+
263+
```
264+
6.027 s (40540072 allocations: 3.52 GiB)
265+
(posterior = NamedTuple{(:parameters, :σ),Tuple{Array{Float64,1},Array{Floa
266+
t64,1}}}[(parameters = [0.9925322562437633, 2.499846186491921], σ = [0.0059
267+
66804814045917, 0.008177933301622841]), (parameters = [0.9963837443808898,
268+
2.502334158254934], σ = [0.006778656235910425, 0.009222937077381753]), (par
269+
ameters = [1.0036593578298718, 2.5036585671312954], σ = [0.0052096129327429
270+
53, 0.009389702547257326]), (parameters = [1.012182569162702, 2.49418373759
271+
8965], σ = [0.009352843917833122, 0.006784007433840137]), (parameters = [0.
272+
9776075654162109, 2.506917713604719], σ = [0.00819375792385741, 0.008509449
273+
758223278]), (parameters = [0.9711538245294002, 2.52311064587977], σ = [0.0
274+
08075286197741678, 0.008380689150424841]), (parameters = [1.030127379473760
275+
4, 2.4900769004417103], σ = [0.005671628576862689, 0.009404135319949877]),
276+
(parameters = [1.027662532372297, 2.47639876182845], σ = [0.005769289829077
277+
621, 0.010107438192968452]), (parameters = [1.02159465396289, 2.47247053938
278+
56765], σ = [0.005902386622399507, 0.009452950393124696]), (parameters = [1
279+
.022343560476168, 2.4776243412273726], σ = [0.0058799594731072675, 0.009630
280+
482440656476]) … (parameters = [0.99920993073967, 2.507790860320399], σ =
281+
[0.011498727162582899, 0.009257041097492019]), (parameters = [1.0033967451
282+
562795, 2.5151332738328414], σ = [0.009519097562008954, 0.00920901544856593
283+
2]), (parameters = [0.9976601690281769, 2.507783954703131], σ = [0.00947878
284+
0454357206, 0.008575044597729732]), (parameters = [0.9944432965906622, 2.50
285+
63137783766347], σ = [0.006264533060570234, 0.00805770806170044]), (paramet
286+
ers = [0.995789590567554, 2.507352981860877], σ = [0.005952805608239749, 0.
287+
008806071695367526]), (parameters = [1.0082385919204935, 2.497953403941628]
288+
, σ = [0.005182367418104671, 0.008522531747426976]), (parameters = [0.99962
289+
84152075758, 2.5189801260651614], σ = [0.009705763622756209, 0.010744838454
290+
615106]), (parameters = [1.0119656370109682, 2.494765752066562], σ = [0.004
291+
729975947126395, 0.007155123027207663]), (parameters = [1.0046646904759642,
292+
2.4934723249796953], σ = [0.005248078221127134, 0.008421618259762318]), (p
293+
arameters = [1.005881733899965, 2.5009460269940513], σ = [0.005481409918559
294+
183, 0.008388170512741225])], chain = [[-0.007495766955188785, 0.9162292045
295+
781582, -5.121543701879515, -4.8063158129028745], [-0.0036228100779651584,
296+
0.9172239595836579, -4.993976391966563, -4.686061737110957], [0.00365267866
297+
93412507, 0.917753088961152, -5.257249719116282, -4.66814166387209], [0.012
298+
108958905343204, 0.913961516396238, -4.6720748195944335, -4.993187284703021
299+
], [-0.022646951830213737, 0.9190539959828018, -4.804382643391568, -4.76657
300+
7996745424], [-0.029270404566869655, 0.9254925235842713, -4.818946968056696
301+
, -4.781825130348601], [0.029682463987115947, 0.9123135937112545, -5.172278
302+
975524333, -4.666605758791736], [0.02728683723620924, 0.9068053926743036, -
303+
5.155206285919641, -4.5944836714353015], [0.021364792718281282, 0.905217869
304+
1295568, -5.132398497592392, -4.6614283752915044], [0.022097600140248842, 0
305+
.9073001741472317, -5.13620540942405, -4.64282195674715] … [-0.0007903815
306+
295350864, 0.9194022302666618, -4.465518931253114, -4.682370817294514], [0.
307+
0033909892480041516, 0.9223257937313328, -4.654455228578712, -4.68757233469
308+
34475], [-0.0023425726538614036, 0.9193994765973328, -4.65869961504571, -4.
309+
7588990850422315], [-0.0055721993164947236, 0.9188130594802799, -5.07285122
310+
4845033, -4.821126122491971], [-0.004219298165150283, 0.9192276077735035, -
311+
5.123892639763714, -4.732313830031988], [0.008204839974368702, 0.9154717581
312+
833844, -5.262493296595247, -4.765041828828386], [-0.0003716538471601385, 0
313+
.9238541077264724, -4.63503538201953, -4.533329783471124], [0.0118946147660
314+
13212, 0.914194837848557, -5.353835161665349, -4.939926671543258], [0.00465
315+
384452294309, 0.9136762470772759, -5.249893322511043, -4.776953276823668],
316+
[0.0058645040311212536, 0.9166690710924666, -5.20639292670556, -4.780932837
317+
9729665]], tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeS
318+
tatisticsNUTS(50.56734359829101, 3, turning at positions -2:5, 0.9755240439
319+
296025, 7, DynamicHMC.Directions(0x60136a7d)), DynamicHMC.TreeStatisticsNUT
320+
S(53.21141797289948, 4, turning at positions -15:-30, 0.9303977362255577, 3
321+
1, DynamicHMC.Directions(0xf1231261)), DynamicHMC.TreeStatisticsNUTS(54.360
322+
60487933965, 4, turning at positions -1:14, 0.9893416328240263, 15, Dynamic
323+
HMC.Directions(0xe900c27e)), DynamicHMC.TreeStatisticsNUTS(52.0483617779477
324+
5, 5, turning at positions -21:10, 0.9309817421978657, 31, DynamicHMC.Direc
325+
tions(0xe9e3438a)), DynamicHMC.TreeStatisticsNUTS(49.87055132316347, 3, tur
326+
ning at positions 11:14, 0.8572610926819243, 15, DynamicHMC.Directions(0x29
327+
4e978e)), DynamicHMC.TreeStatisticsNUTS(49.19262263342683, 2, turning at po
328+
sitions -2:1, 0.9759539124756529, 3, DynamicHMC.Directions(0xdbea4a09)), Dy
329+
namicHMC.TreeStatisticsNUTS(48.43198128487201, 3, turning at positions -7:0
330+
, 0.981239057221168, 7, DynamicHMC.Directions(0x6fcc3a70)), DynamicHMC.Tree
331+
StatisticsNUTS(47.63166261616587, 2, turning at positions -3:0, 0.995085735
332+
1314513, 3, DynamicHMC.Directions(0xe54734a8)), DynamicHMC.TreeStatisticsNU
333+
TS(46.56863924159915, 3, turning at positions -5:-12, 0.905437436880739, 15
334+
, DynamicHMC.Directions(0xea3a9733)), DynamicHMC.TreeStatisticsNUTS(49.6833
335+
97905049375, 1, turning at positions -1:0, 1.0, 1, DynamicHMC.Directions(0x
336+
1fab498e)) … DynamicHMC.TreeStatisticsNUTS(51.65316234412242, 4, turning
337+
at positions 3:18, 0.9415872806438115, 31, DynamicHMC.Directions(0xa9468a12
338+
)), DynamicHMC.TreeStatisticsNUTS(51.711606238318794, 3, turning at positio
339+
ns -1:6, 0.8623986191786585, 7, DynamicHMC.Directions(0x51644e36)), Dynamic
340+
HMC.TreeStatisticsNUTS(51.919656641254015, 4, turning at positions -8:-23,
341+
0.9084259271209352, 31, DynamicHMC.Directions(0x3319ca68)), DynamicHMC.Tree
342+
StatisticsNUTS(53.202536402977834, 3, turning at positions 6:13, 0.96705305
343+
01687453, 15, DynamicHMC.Directions(0x5012713d)), DynamicHMC.TreeStatistics
344+
NUTS(52.73931670212959, 3, turning at positions -4:-11, 0.9929665212170776,
345+
15, DynamicHMC.Directions(0x6329d9e4)), DynamicHMC.TreeStatisticsNUTS(54.1
346+
31036067725944, 3, turning at positions -4:-11, 0.9888188927167221, 15, Dyn
347+
amicHMC.Directions(0x8b23b0b4)), DynamicHMC.TreeStatisticsNUTS(51.488325981
348+
24062, 4, turning at positions 0:15, 0.6147719493975838, 15, DynamicHMC.Dir
349+
ections(0xfc3d19af)), DynamicHMC.TreeStatisticsNUTS(49.47067467051124, 4, t
350+
urning at positions -13:2, 0.9105644910181027, 15, DynamicHMC.Directions(0x
351+
cd45ce02)), DynamicHMC.TreeStatisticsNUTS(48.9855447707631, 3, turning at p
352+
ositions -2:5, 0.5761761677703705, 7, DynamicHMC.Directions(0xd968714d)), D
353+
ynamicHMC.TreeStatisticsNUTS(54.41136057897444, 3, turning at positions -3:
354+
4, 0.9945976434248669, 7, DynamicHMC.Directions(0x77185434))], κ = Gaussian
355+
kinetic energy (Diagonal), √diag(M⁻¹): [0.022077917058954552, 0.0202796088
356+
43069858, 0.2918638996850573, 0.24966858752333757], ϵ = 0.17395099027012478
357+
)
358+
```
359+
360+
21.3 KB
Loading
24.5 KB
Loading
133 KB
Loading
66.8 KB
Loading

0 commit comments

Comments
 (0)