|
| 1 | +--- |
| 2 | +author: "Vaibhav Dixit" |
| 3 | +title: "Bayesian Inference on a Pendulum using DiffEqBayes.jl" |
| 4 | +--- |
| 5 | + |
| 6 | + |
| 7 | +### Set up simple pendulum problem |
| 8 | + |
| 9 | +```julia |
| 10 | +using DiffEqBayes, OrdinaryDiffEq, RecursiveArrayTools, Distributions, Plots, StatsPlots, BenchmarkTools, TransformVariables, CmdStan, DynamicHMC |
| 11 | +``` |
| 12 | + |
| 13 | + |
| 14 | + |
| 15 | + |
| 16 | +Let's define our simple pendulum problem. Here our pendulum has a drag term `ω` |
| 17 | +and a length `L`. |
| 18 | + |
| 19 | + |
| 20 | + |
| 21 | +We get first order equations by defining the first term as the velocity and the |
| 22 | +second term as the position, getting: |
| 23 | + |
| 24 | +```julia |
| 25 | +function pendulum(du,u,p,t) |
| 26 | + ω,L = p |
| 27 | + x,y = u |
| 28 | + du[1] = y |
| 29 | + du[2] = - ω*y -(9.8/L)*sin(x) |
| 30 | +end |
| 31 | + |
| 32 | +u0 = [1.0,0.1] |
| 33 | +tspan = (0.0,10.0) |
| 34 | +prob1 = ODEProblem(pendulum,u0,tspan,[1.0,2.5]) |
| 35 | +``` |
| 36 | + |
| 37 | +``` |
| 38 | +ODEProblem with uType Array{Float64,1} and tType Float64. In-place: true |
| 39 | +timespan: (0.0, 10.0) |
| 40 | +u0: [1.0, 0.1] |
| 41 | +``` |
| 42 | + |
| 43 | + |
| 44 | + |
| 45 | + |
| 46 | + |
| 47 | +### Solve the model and plot |
| 48 | + |
| 49 | +To understand the model and generate data, let's solve and visualize the solution |
| 50 | +with the known parameters: |
| 51 | + |
| 52 | +```julia |
| 53 | +sol = solve(prob1,Tsit5()) |
| 54 | +plot(sol) |
| 55 | +``` |
| 56 | + |
| 57 | + |
| 58 | + |
| 59 | + |
| 60 | + |
| 61 | +It's the pendulum, so you know what it looks like. It's periodic, but since we |
| 62 | +have not made a small angle assumption it's not exactly `sin` or `cos`. Because |
| 63 | +the true dampening parameter `ω` is 1, the solution does not decay over time, |
| 64 | +nor does it increase. The length `L` determines the period. |
| 65 | + |
| 66 | +### Create some dummy data to use for estimation |
| 67 | + |
| 68 | +We now generate some dummy data to use for estimation |
| 69 | + |
| 70 | +```julia |
| 71 | +t = collect(range(1,stop=10,length=10)) |
| 72 | +randomized = VectorOfArray([(sol(t[i]) + .01randn(2)) for i in 1:length(t)]) |
| 73 | +data = convert(Array,randomized) |
| 74 | +``` |
| 75 | + |
| 76 | +``` |
| 77 | +2×10 Array{Float64,2}: |
| 78 | + 0.0669231 -0.377851 0.119404 0.0795968 … -0.01553 0.00535298 |
| 79 | + -1.21411 0.344681 0.323712 -0.253243 0.0164092 -0.00897403 |
| 80 | +``` |
| 81 | + |
| 82 | + |
| 83 | + |
| 84 | + |
| 85 | + |
| 86 | +Let's see what our data looks like on top of the real solution |
| 87 | + |
| 88 | +```julia |
| 89 | +scatter!(data') |
| 90 | +``` |
| 91 | + |
| 92 | + |
| 93 | + |
| 94 | + |
| 95 | + |
| 96 | +This data captures the non-dampening effect and the true period, making it |
| 97 | +perfect to attempting a Bayesian inference. |
| 98 | + |
| 99 | +### Perform Bayesian Estimation |
| 100 | + |
| 101 | +Now let's fit the pendulum to the data. Since we know our model is correct, |
| 102 | +this should give us back the parameters that we used to generate the data! |
| 103 | +Define priors on our parameters. In this case, let's assume we don't have much |
| 104 | +information, but have a prior belief that ω is between 0.1 and 3.0, while the |
| 105 | +length of the pendulum L is probably around 3.0: |
| 106 | + |
| 107 | +```julia |
| 108 | +priors = [Uniform(0.1,3.0), Normal(3.0,1.0)] |
| 109 | +``` |
| 110 | + |
| 111 | +``` |
| 112 | +2-element Array{Distributions.Distribution{Distributions.Univariate,Distrib |
| 113 | +utions.Continuous},1}: |
| 114 | + Distributions.Uniform{Float64}(a=0.1, b=3.0) |
| 115 | + Distributions.Normal{Float64}(μ=3.0, σ=1.0) |
| 116 | +``` |
| 117 | + |
| 118 | + |
| 119 | + |
| 120 | + |
| 121 | + |
| 122 | +Finally let's run the estimation routine from DiffEqBayes.jl with the Turing.jl backend to check if we indeed recover the parameters! |
| 123 | + |
| 124 | +```julia |
| 125 | +bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=10_000, |
| 126 | + syms = [:omega,:L]) |
| 127 | +``` |
| 128 | + |
| 129 | +``` |
| 130 | +Chains MCMC chain (9000×15×1 Array{Float64,3}): |
| 131 | +
|
| 132 | +Iterations = 1:9000 |
| 133 | +Thinning interval = 1 |
| 134 | +Chains = 1 |
| 135 | +Samples per chain = 9000 |
| 136 | +parameters = L, omega, σ[1] |
| 137 | +internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy |
| 138 | +_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, |
| 139 | +nom_step_size, numerical_error, step_size, tree_depth |
| 140 | +
|
| 141 | +Summary Statistics |
| 142 | + parameters mean std naive_se mcse ess rhat |
| 143 | + |
| 144 | + Symbol Float64 Float64 Float64 Float64 Float64 Float64 |
| 145 | + |
| 146 | + |
| 147 | + |
| 148 | + L 2.5036 0.2148 0.0023 0.0035 3703.0703 1.0000 |
| 149 | + |
| 150 | + omega 1.0777 0.2217 0.0023 0.0048 2000.0506 1.0008 |
| 151 | + |
| 152 | + σ[1] 0.1603 0.0390 0.0004 0.0007 3326.8139 0.9999 |
| 153 | + |
| 154 | +
|
| 155 | +Quantiles |
| 156 | + parameters 2.5% 25.0% 50.0% 75.0% 97.5% |
| 157 | + Symbol Float64 Float64 Float64 Float64 Float64 |
| 158 | + |
| 159 | + L 2.0761 2.3766 2.5024 2.6302 2.9287 |
| 160 | + omega 0.7670 0.9384 1.0395 1.1706 1.6059 |
| 161 | + σ[1] 0.1018 0.1325 0.1540 0.1812 0.2529 |
| 162 | +``` |
| 163 | + |
| 164 | + |
| 165 | + |
| 166 | + |
| 167 | + |
| 168 | +Notice that while our guesses had the wrong means, the learned parameters converged |
| 169 | +to the correct means, meaning that it learned good posterior distributions for the |
| 170 | +parameters. To look at these posterior distributions on the parameters, we can |
| 171 | +examine the chains: |
| 172 | + |
| 173 | +```julia |
| 174 | +plot(bayesian_result) |
| 175 | +``` |
| 176 | + |
| 177 | + |
| 178 | + |
| 179 | + |
| 180 | + |
| 181 | +As a diagnostic, we will also check the parameter chains. The chain is the MCMC |
| 182 | +sampling process. The chain should explore parameter space and converge reasonably |
| 183 | +well, and we should be taking a lot of samples after it converges (it is these |
| 184 | +samples that form the posterior distribution!) |
| 185 | + |
| 186 | +```julia |
| 187 | +plot(bayesian_result, colordim = :parameter) |
| 188 | +``` |
| 189 | + |
| 190 | + |
| 191 | + |
| 192 | + |
| 193 | + |
| 194 | +Notice that after awhile these chains converge to a "fuzzy line", meaning it |
| 195 | +found the area with the most likelihood and then starts to sample around there, |
| 196 | +which builds a posterior distribution around the true mean. |
| 197 | + |
| 198 | +DiffEqBayes.jl allows the choice of using Stan.jl, Turing.jl and DynamicHMC.jl for MCMC, you can also use ApproxBayes.jl for Approximate Bayesian computation algorithms. |
| 199 | +Let's compare the timings across the different MCMC backends. We'll stick with the default arguments and 10,000 samples in each since there is a lot of room for micro-optimization |
| 200 | +specific to each package and algorithm combinations, you might want to do your own experiments for specific problems to get better understanding of the performance. |
| 201 | + |
| 202 | +```julia |
| 203 | +@btime bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;syms = [:omega,:L],num_samples=10_000) |
| 204 | +``` |
| 205 | + |
| 206 | +``` |
| 207 | +2.710 s (23598867 allocations: 1.50 GiB) |
| 208 | +Chains MCMC chain (9000×15×1 Array{Float64,3}): |
| 209 | +
|
| 210 | +Iterations = 1:9000 |
| 211 | +Thinning interval = 1 |
| 212 | +Chains = 1 |
| 213 | +Samples per chain = 9000 |
| 214 | +parameters = L, omega, σ[1] |
| 215 | +internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy |
| 216 | +_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, |
| 217 | +nom_step_size, numerical_error, step_size, tree_depth |
| 218 | +
|
| 219 | +Summary Statistics |
| 220 | + parameters mean std naive_se mcse ess rhat |
| 221 | + |
| 222 | + Symbol Float64 Float64 Float64 Float64 Float64 Float64 |
| 223 | + |
| 224 | + |
| 225 | + |
| 226 | + L 2.5019 0.2081 0.0022 0.0034 3767.3721 1.0000 |
| 227 | + |
| 228 | + omega 1.0773 0.2137 0.0023 0.0040 2973.1493 1.0001 |
| 229 | + |
| 230 | + σ[1] 0.1593 0.0371 0.0004 0.0006 4173.1326 1.0004 |
| 231 | + |
| 232 | +
|
| 233 | +Quantiles |
| 234 | + parameters 2.5% 25.0% 50.0% 75.0% 97.5% |
| 235 | + Symbol Float64 Float64 Float64 Float64 Float64 |
| 236 | + |
| 237 | + L 2.0844 2.3770 2.5029 2.6269 2.9178 |
| 238 | + omega 0.7660 0.9383 1.0424 1.1743 1.6056 |
| 239 | + σ[1] 0.1032 0.1325 0.1538 0.1793 0.2468 |
| 240 | +``` |
| 241 | + |
| 242 | + |
| 243 | + |
| 244 | +```julia |
| 245 | +@btime bayesian_result = stan_inference(prob1,t,data,priors;num_samples=10_000,printsummary=false) |
| 246 | +``` |
| 247 | + |
| 248 | +``` |
| 249 | +Error: MethodError: no method matching iterate(::ModelingToolkit.ODESystem) |
| 250 | +Closest candidates are: |
| 251 | + iterate(!Matched::Core.SimpleVector) at essentials.jl:603 |
| 252 | + iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603 |
| 253 | + iterate(!Matched::ExponentialBackOff) at error.jl:253 |
| 254 | + ... |
| 255 | +``` |
| 256 | + |
| 257 | + |
| 258 | + |
| 259 | +```julia |
| 260 | +@btime bayesian_result = dynamichmc_inference(prob1,Tsit5(),t,data,priors;num_samples = 10_000) |
| 261 | +``` |
| 262 | + |
| 263 | +``` |
| 264 | +6.027 s (40540072 allocations: 3.52 GiB) |
| 265 | +(posterior = NamedTuple{(:parameters, :σ),Tuple{Array{Float64,1},Array{Floa |
| 266 | +t64,1}}}[(parameters = [0.9925322562437633, 2.499846186491921], σ = [0.0059 |
| 267 | +66804814045917, 0.008177933301622841]), (parameters = [0.9963837443808898, |
| 268 | +2.502334158254934], σ = [0.006778656235910425, 0.009222937077381753]), (par |
| 269 | +ameters = [1.0036593578298718, 2.5036585671312954], σ = [0.0052096129327429 |
| 270 | +53, 0.009389702547257326]), (parameters = [1.012182569162702, 2.49418373759 |
| 271 | +8965], σ = [0.009352843917833122, 0.006784007433840137]), (parameters = [0. |
| 272 | +9776075654162109, 2.506917713604719], σ = [0.00819375792385741, 0.008509449 |
| 273 | +758223278]), (parameters = [0.9711538245294002, 2.52311064587977], σ = [0.0 |
| 274 | +08075286197741678, 0.008380689150424841]), (parameters = [1.030127379473760 |
| 275 | +4, 2.4900769004417103], σ = [0.005671628576862689, 0.009404135319949877]), |
| 276 | +(parameters = [1.027662532372297, 2.47639876182845], σ = [0.005769289829077 |
| 277 | +621, 0.010107438192968452]), (parameters = [1.02159465396289, 2.47247053938 |
| 278 | +56765], σ = [0.005902386622399507, 0.009452950393124696]), (parameters = [1 |
| 279 | +.022343560476168, 2.4776243412273726], σ = [0.0058799594731072675, 0.009630 |
| 280 | +482440656476]) … (parameters = [0.99920993073967, 2.507790860320399], σ = |
| 281 | + [0.011498727162582899, 0.009257041097492019]), (parameters = [1.0033967451 |
| 282 | +562795, 2.5151332738328414], σ = [0.009519097562008954, 0.00920901544856593 |
| 283 | +2]), (parameters = [0.9976601690281769, 2.507783954703131], σ = [0.00947878 |
| 284 | +0454357206, 0.008575044597729732]), (parameters = [0.9944432965906622, 2.50 |
| 285 | +63137783766347], σ = [0.006264533060570234, 0.00805770806170044]), (paramet |
| 286 | +ers = [0.995789590567554, 2.507352981860877], σ = [0.005952805608239749, 0. |
| 287 | +008806071695367526]), (parameters = [1.0082385919204935, 2.497953403941628] |
| 288 | +, σ = [0.005182367418104671, 0.008522531747426976]), (parameters = [0.99962 |
| 289 | +84152075758, 2.5189801260651614], σ = [0.009705763622756209, 0.010744838454 |
| 290 | +615106]), (parameters = [1.0119656370109682, 2.494765752066562], σ = [0.004 |
| 291 | +729975947126395, 0.007155123027207663]), (parameters = [1.0046646904759642, |
| 292 | + 2.4934723249796953], σ = [0.005248078221127134, 0.008421618259762318]), (p |
| 293 | +arameters = [1.005881733899965, 2.5009460269940513], σ = [0.005481409918559 |
| 294 | +183, 0.008388170512741225])], chain = [[-0.007495766955188785, 0.9162292045 |
| 295 | +781582, -5.121543701879515, -4.8063158129028745], [-0.0036228100779651584, |
| 296 | +0.9172239595836579, -4.993976391966563, -4.686061737110957], [0.00365267866 |
| 297 | +93412507, 0.917753088961152, -5.257249719116282, -4.66814166387209], [0.012 |
| 298 | +108958905343204, 0.913961516396238, -4.6720748195944335, -4.993187284703021 |
| 299 | +], [-0.022646951830213737, 0.9190539959828018, -4.804382643391568, -4.76657 |
| 300 | +7996745424], [-0.029270404566869655, 0.9254925235842713, -4.818946968056696 |
| 301 | +, -4.781825130348601], [0.029682463987115947, 0.9123135937112545, -5.172278 |
| 302 | +975524333, -4.666605758791736], [0.02728683723620924, 0.9068053926743036, - |
| 303 | +5.155206285919641, -4.5944836714353015], [0.021364792718281282, 0.905217869 |
| 304 | +1295568, -5.132398497592392, -4.6614283752915044], [0.022097600140248842, 0 |
| 305 | +.9073001741472317, -5.13620540942405, -4.64282195674715] … [-0.0007903815 |
| 306 | +295350864, 0.9194022302666618, -4.465518931253114, -4.682370817294514], [0. |
| 307 | +0033909892480041516, 0.9223257937313328, -4.654455228578712, -4.68757233469 |
| 308 | +34475], [-0.0023425726538614036, 0.9193994765973328, -4.65869961504571, -4. |
| 309 | +7588990850422315], [-0.0055721993164947236, 0.9188130594802799, -5.07285122 |
| 310 | +4845033, -4.821126122491971], [-0.004219298165150283, 0.9192276077735035, - |
| 311 | +5.123892639763714, -4.732313830031988], [0.008204839974368702, 0.9154717581 |
| 312 | +833844, -5.262493296595247, -4.765041828828386], [-0.0003716538471601385, 0 |
| 313 | +.9238541077264724, -4.63503538201953, -4.533329783471124], [0.0118946147660 |
| 314 | +13212, 0.914194837848557, -5.353835161665349, -4.939926671543258], [0.00465 |
| 315 | +384452294309, 0.9136762470772759, -5.249893322511043, -4.776953276823668], |
| 316 | +[0.0058645040311212536, 0.9166690710924666, -5.20639292670556, -4.780932837 |
| 317 | +9729665]], tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeS |
| 318 | +tatisticsNUTS(50.56734359829101, 3, turning at positions -2:5, 0.9755240439 |
| 319 | +296025, 7, DynamicHMC.Directions(0x60136a7d)), DynamicHMC.TreeStatisticsNUT |
| 320 | +S(53.21141797289948, 4, turning at positions -15:-30, 0.9303977362255577, 3 |
| 321 | +1, DynamicHMC.Directions(0xf1231261)), DynamicHMC.TreeStatisticsNUTS(54.360 |
| 322 | +60487933965, 4, turning at positions -1:14, 0.9893416328240263, 15, Dynamic |
| 323 | +HMC.Directions(0xe900c27e)), DynamicHMC.TreeStatisticsNUTS(52.0483617779477 |
| 324 | +5, 5, turning at positions -21:10, 0.9309817421978657, 31, DynamicHMC.Direc |
| 325 | +tions(0xe9e3438a)), DynamicHMC.TreeStatisticsNUTS(49.87055132316347, 3, tur |
| 326 | +ning at positions 11:14, 0.8572610926819243, 15, DynamicHMC.Directions(0x29 |
| 327 | +4e978e)), DynamicHMC.TreeStatisticsNUTS(49.19262263342683, 2, turning at po |
| 328 | +sitions -2:1, 0.9759539124756529, 3, DynamicHMC.Directions(0xdbea4a09)), Dy |
| 329 | +namicHMC.TreeStatisticsNUTS(48.43198128487201, 3, turning at positions -7:0 |
| 330 | +, 0.981239057221168, 7, DynamicHMC.Directions(0x6fcc3a70)), DynamicHMC.Tree |
| 331 | +StatisticsNUTS(47.63166261616587, 2, turning at positions -3:0, 0.995085735 |
| 332 | +1314513, 3, DynamicHMC.Directions(0xe54734a8)), DynamicHMC.TreeStatisticsNU |
| 333 | +TS(46.56863924159915, 3, turning at positions -5:-12, 0.905437436880739, 15 |
| 334 | +, DynamicHMC.Directions(0xea3a9733)), DynamicHMC.TreeStatisticsNUTS(49.6833 |
| 335 | +97905049375, 1, turning at positions -1:0, 1.0, 1, DynamicHMC.Directions(0x |
| 336 | +1fab498e)) … DynamicHMC.TreeStatisticsNUTS(51.65316234412242, 4, turning |
| 337 | +at positions 3:18, 0.9415872806438115, 31, DynamicHMC.Directions(0xa9468a12 |
| 338 | +)), DynamicHMC.TreeStatisticsNUTS(51.711606238318794, 3, turning at positio |
| 339 | +ns -1:6, 0.8623986191786585, 7, DynamicHMC.Directions(0x51644e36)), Dynamic |
| 340 | +HMC.TreeStatisticsNUTS(51.919656641254015, 4, turning at positions -8:-23, |
| 341 | +0.9084259271209352, 31, DynamicHMC.Directions(0x3319ca68)), DynamicHMC.Tree |
| 342 | +StatisticsNUTS(53.202536402977834, 3, turning at positions 6:13, 0.96705305 |
| 343 | +01687453, 15, DynamicHMC.Directions(0x5012713d)), DynamicHMC.TreeStatistics |
| 344 | +NUTS(52.73931670212959, 3, turning at positions -4:-11, 0.9929665212170776, |
| 345 | + 15, DynamicHMC.Directions(0x6329d9e4)), DynamicHMC.TreeStatisticsNUTS(54.1 |
| 346 | +31036067725944, 3, turning at positions -4:-11, 0.9888188927167221, 15, Dyn |
| 347 | +amicHMC.Directions(0x8b23b0b4)), DynamicHMC.TreeStatisticsNUTS(51.488325981 |
| 348 | +24062, 4, turning at positions 0:15, 0.6147719493975838, 15, DynamicHMC.Dir |
| 349 | +ections(0xfc3d19af)), DynamicHMC.TreeStatisticsNUTS(49.47067467051124, 4, t |
| 350 | +urning at positions -13:2, 0.9105644910181027, 15, DynamicHMC.Directions(0x |
| 351 | +cd45ce02)), DynamicHMC.TreeStatisticsNUTS(48.9855447707631, 3, turning at p |
| 352 | +ositions -2:5, 0.5761761677703705, 7, DynamicHMC.Directions(0xd968714d)), D |
| 353 | +ynamicHMC.TreeStatisticsNUTS(54.41136057897444, 3, turning at positions -3: |
| 354 | +4, 0.9945976434248669, 7, DynamicHMC.Directions(0x77185434))], κ = Gaussian |
| 355 | + kinetic energy (Diagonal), √diag(M⁻¹): [0.022077917058954552, 0.0202796088 |
| 356 | +43069858, 0.2918638996850573, 0.24966858752333757], ϵ = 0.17395099027012478 |
| 357 | +) |
| 358 | +``` |
| 359 | + |
| 360 | + |
0 commit comments