Skip to content

Commit ae0900b

Browse files
Merge pull request #280 from SciML/minibatch_docs
Fix minibatch example
2 parents e1c5b04 + b2a1738 commit ae0900b

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

docs/src/tutorials/minibatch.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
!!! note
44

5-
This example uses the OptimizationOptimJL.jl package. See the [Optim.jl page](@ref optim)
5+
This example uses the OptimizationOptimisers.jl package. See the [Optimisers.jl page](@ref optimisers)
66
for details on the installation and usage.
77

88
```julia
9-
using DiffEqFlux, Optimization, OptimizationOptimJL, OrdinaryDiffEq
9+
using Flux, Optimization, OptimizationOptimisers, OrdinaryDiffEq
1010

1111
function newtons_cooling(du, u, p, t)
1212
temp = u[1]
@@ -19,8 +19,11 @@ function true_sol(du, u, p, t)
1919
newtons_cooling(du, u, true_p, t)
2020
end
2121

22+
ann = Chain(FastDense(1,8,tanh), FastDense(8,1,tanh))
23+
pp,re = Flux.destructure(ann)
24+
2225
function dudt_(u,p,t)
23-
ann(u, p).* u
26+
re(p)(u) .* u
2427
end
2528

2629
callback = function (p,l,pred;doplot=false) #callback function to observe training
@@ -42,8 +45,6 @@ t = range(tspan[1], tspan[2], length=datasize)
4245
true_prob = ODEProblem(true_sol, u0, tspan)
4346
ode_data = Array(solve(true_prob, Tsit5(), saveat=t))
4447

45-
ann = FastChain(FastDense(1,8,tanh), FastDense(8,1,tanh))
46-
pp = initial_params(ann)
4748
prob = ODEProblem{false}(dudt_, u0, tspan, pp)
4849

4950
function predict_adjoint(fullp, time_batch)
@@ -65,6 +66,6 @@ l1 = loss_adjoint(pp, train_loader.data[1], train_loader.data[2])[1]
6566
optfun = OptimizationFunction((θ, p, batch, time_batch) -> loss_adjoint(θ, batch, time_batch), Optimization.AutoZygote())
6667
optprob = OptimizationProblem(optfun, pp)
6768
using IterTools: ncycle
68-
res1 = Optimization.solve(optprob, ADAM(0.05), ncycle(train_loader, numEpochs), callback = callback)
69+
res1 = Optimization.solve(optprob, Optimisers.ADAM(0.05), ncycle(train_loader, numEpochs), callback = callback)
6970
@test 10res1.minimum < l1
7071
```

0 commit comments

Comments
 (0)