@@ -12,6 +12,7 @@ function DiffEqBase.solve(prob::OptimizationProblem, opt, args...;kwargs...)
1212 __solve (prob, opt, args... ; kwargs... )
1313end
1414
15+ #=
1516function update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
1617 x .-= x̄
1718end
3132function update!(opt, xs::Flux.Zygote.Params, gs)
3233 update!(opt, xs[1], gs)
3334end
35+ =#
3436
3537maybe_with_logger (f, logger) = logger === nothing ? f () : Logging. with_logger (f, logger)
3638
@@ -62,7 +64,10 @@ macro withprogress(progress, exprs...)
6264 end |> esc
6365 end
6466
65- function __solve (prob:: OptimizationProblem , opt, data = DEFAULT_DATA;cb = (args... ) -> (false ), maxiters:: Number = 1000 , progress = true , save_best = true , kwargs... )
67+ function __solve (prob:: OptimizationProblem , opt, data = DEFAULT_DATA;
68+ cb = (args... ) -> (false ), maxiters:: Number = 1000 ,
69+ progress = true , save_best = true , kwargs... )
70+
6671 if maxiters <= 0.0
6772 error (" The number of maxiters has to be a non-negative and non-zero number." )
6873 else
@@ -76,7 +81,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
7681
7782 if data != DEFAULT_DATA
7883 maxiters = length (data)
79- else
84+ else
8085 data = take (data, maxiters)
8186 end
8287
@@ -90,8 +95,10 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
9095
9196 @withprogress progress name= " Training" begin
9297 for (i,d) in enumerate (data)
93- gs = prob. f. adtype isa AutoFiniteDiff ? Array {Number} (undef,length (θ)) : DiffResults. GradientResult (θ)
94- f. grad (gs, θ, d... )
98+ gs = Flux. Zygote. gradient (ps) do
99+ x = prob. f (θ,prob. p, d... )
100+ first (x)
101+ end
95102 x = f. f (θ, prob. p, d... )
96103 cb_call = cb (θ, x... )
97104 if ! (typeof (cb_call) <: Bool )
@@ -101,7 +108,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
101108 end
102109 msg = @sprintf (" loss: %.3g" , x[1 ])
103110 progress && ProgressLogging. @logprogress msg i/ maxiters
104- update! (opt, ps, prob . f . adtype isa AutoFiniteDiff ? gs : DiffResults . gradient (gs) )
111+ Flux . update! (opt, ps, gs )
105112
106113 if save_best
107114 if first (x) < first (min_err) # found a better solution
@@ -215,7 +222,7 @@ function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN
215222 if ! (typeof (cb_call) <: Bool )
216223 error (" The callback should return a boolean `halt` for whether to stop the optimization process." )
217224 end
218- cur, state = iterate (data, state)
225+ cur, state = iterate (data, state)
219226 cb_call
220227 end
221228
0 commit comments