!!! note

-    This example uses the OptimizationOptimJL.jl package. See the [Optim.jl page](@ref optim)
+    This example uses the OptimizationOptimisers.jl package. See the [Optimisers.jl page](@ref optimisers)
     for details on the installation and usage.

```julia
-using DiffEqFlux, Optimization, OptimizationOptimJL, OrdinaryDiffEq
+using Flux, Optimization, OptimizationOptimisers, OrdinaryDiffEq

function newtons_cooling(du, u, p, t)
    temp = u[1]
@@ -19,8 +19,11 @@ function true_sol(du, u, p, t)
    newtons_cooling(du, u, true_p, t)
end

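+# Flux.destructure splits a model into a flat parameter vector and a
+# reconstructor: `re(p)` rebuilds the same architecture from the vector `p`.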
+ann = Chain(Dense(1, 8, tanh), Dense(8, 1, tanh))
+pp, re = Flux.destructure(ann)
+
function dudt_(u, p, t)
-    ann(u, p) .* u
+    re(p)(u) .* u
end

callback = function (p, l, pred; doplot = false) # callback function to observe training
@@ -42,8 +45,6 @@ t = range(tspan[1], tspan[2], length=datasize)
true_prob = ODEProblem(true_sol, u0, tspan)
ode_data = Array(solve(true_prob, Tsit5(), saveat = t))

-ann = FastChain(FastDense(1, 8, tanh), FastDense(8, 1, tanh))
-pp = initial_params(ann)
prob = ODEProblem{false}(dudt_, u0, tspan, pp)

function predict_adjoint(fullp, time_batch)
@@ -65,6 +66,6 @@ l1 = loss_adjoint(pp, train_loader.data[1], train_loader.data[2])[1]
optfun = OptimizationFunction((θ, p, batch, time_batch) -> loss_adjoint(θ, batch, time_batch), Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, pp)
using IterTools: ncycle
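# ncycle repeats the DataLoader numEpochs times, so the optimizer consumes
# one (batch, time_batch) pair per step across numEpochs passes over the data.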
-res1 = Optimization.solve(optprob, ADAM(0.05), ncycle(train_loader, numEpochs), callback = callback)
+res1 = Optimization.solve(optprob, Optimisers.ADAM(0.05), ncycle(train_loader, numEpochs), callback = callback)
@test 10res1.minimum < l1
```
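For readers adapting their own models, here is a minimal, self-contained sketch of the `Flux.destructure` pattern this commit switches to (layer sizes are arbitrary, chosen to match the tutorial):

```julia
using Flux

# Chain/Dense replace the removed FastChain/FastDense layers.
ann = Chain(Dense(1, 8, tanh), Dense(8, 1, tanh))

# destructure returns the flattened parameters and a reconstructor.
pp, re = Flux.destructure(ann)

# re(p) rebuilds the network from any vector of the same length, so an
# ODE right-hand side can treat the weights as a plain parameter vector.
x = rand(Float32, 1)
@assert re(pp)(x) ≈ ann(x)
```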