Skip to content

Commit 11cc884

Browse files
update to optim 1
1 parent ff0cc8d commit 11cc884

File tree

4 files changed

+49
-51
lines changed

4 files changed

+49
-51
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ FiniteDiff = "2.5"
2929
Flux = "0.11"
3030
ForwardDiff = "0.10"
3131
LoggingExtras = "0.4"
32-
Optim = "0.22"
32+
Optim = "0.22, 1"
3333
ProgressLogging = "0.1"
3434
Requires = "1.0"
3535
ReverseDiff = "1.4"
@@ -43,7 +43,8 @@ BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
4343
CMAEvolutionStrategy = "8d3b24bd-414e-49e0-94fb-163cc3a3e411"
4444
Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
4545
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
46+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4647
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4748

4849
[targets]
49-
test = ["BlackBoxOptim", "Evolutionary", "NLopt", "CMAEvolutionStrategy", "Test"]
50+
test = ["BlackBoxOptim", "Evolutionary", "NLopt", "CMAEvolutionStrategy", "SafeTestsets", "Test"]

src/solve.jl

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,19 @@ function __solve(prob::OptimizationProblem, opt;cb = (args...) -> (false), maxit
6565
# this is a Flux optimizer
6666
θ = copy(prob.x)
6767
ps = Flux.params(θ)
68-
68+
6969
t0 = time()
70-
70+
7171
local x, min_err, _loss
7272
min_err = typemax(eltype(prob.x)) #dummy variables
7373
min_opt = 1
74-
75-
76-
if prob.f isa OptimizationFunction
74+
75+
76+
if prob.f isa OptimizationFunction
7777
_loss = function(θ)
7878
x = prob.f.f(θ, prob.p)
7979
end
80-
else
80+
else
8181
_loss = function(θ)
8282
x = prob.f(θ, prob.p)
8383
end
@@ -98,7 +98,7 @@ function __solve(prob::OptimizationProblem, opt;cb = (args...) -> (false), maxit
9898
msg = @sprintf("loss: %.3g", x[1])
9999
progress && ProgressLogging.@logprogress msg i/maxiters
100100
update!(opt, ps, gs)
101-
101+
102102
if save_best
103103
if first(x) < first(min_err) #found a better solution
104104
min_opt = opt
@@ -111,9 +111,9 @@ function __solve(prob::OptimizationProblem, opt;cb = (args...) -> (false), maxit
111111
end
112112
end
113113
end
114-
114+
115115
_time = time()
116-
116+
117117
Optim.MultivariateOptimizationResults(opt,
118118
prob.x,# initial_x,
119119
θ, #pick_best_x(f_incr_pick, state),
@@ -142,7 +142,7 @@ function __solve(prob::OptimizationProblem, opt;cb = (args...) -> (false), maxit
142142
NaN,
143143
_time-t0)
144144
end
145-
145+
146146

147147
decompose_trace(trace::Optim.OptimizationTrace) = last(trace)
148148
decompose_trace(trace) = trace
@@ -157,7 +157,7 @@ function __solve(prob::OptimizationProblem, opt::Optim.AbstractOptimizer;cb = (a
157157
end
158158
cb_call
159159
end
160-
160+
161161
if prob.f isa OptimizationFunction
162162
_loss = function(θ)
163163
x = prob.f.f(θ, prob.p)
@@ -197,11 +197,11 @@ function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN
197197
end
198198
cb_call
199199
end
200-
200+
201201
if prob.f isa OptimizationFunction && !(opt isa Optim.SAMIN)
202202
_loss = function(θ)
203203
x = prob.f.f(θ, prob.p)
204-
return x[1]
204+
return x[1]
205205
end
206206
fg! = function (G,θ)
207207
if G !== nothing
@@ -212,14 +212,14 @@ function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN
212212
end
213213
optim_f = OnceDifferentiable(_loss, prob.f.grad, fg!, prob.x)
214214
else
215-
!(opt isa Optim.ZerothOrderOptimizer) && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
215+
!(opt isa Optim.ZerothOrderOptimizer || opt isa Optim.SAMIN) && error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
216216
_loss = function(θ)
217217
x = prob.f isa OptimizationFunction ? prob.f.f(θ, prob.p) : prob.f(θ, prob.p)
218-
return x[1]
218+
return x[1]
219219
end
220220
optim_f = _loss
221221
end
222-
222+
223223
Optim.optimize(optim_f, prob.lb, prob.ub, prob.x, opt, Optim.Options(;extended_trace = true, callback = _cb, iterations = maxiters, kwargs...))
224224
end
225225

@@ -228,14 +228,14 @@ function __init__()
228228
decompose_trace(opt::BlackBoxOptim.OptRunController) = BlackBoxOptim.best_candidate(opt)
229229

230230
struct BBO
231-
method::Symbol
231+
method::Symbol
232232
end
233233

234234
BBO() = BBO(:adaptive_de_rand_1_bin)
235235

236236
function __solve(prob::OptimizationProblem, opt::BBO; cb = (args...) -> (false), maxiters = 1000, kwargs...)
237237
local x, _loss
238-
238+
239239
function _cb(trace)
240240
cb_call = cb(decompose_trace(trace),x...)
241241
if !(typeof(cb_call) <: Bool)
@@ -247,20 +247,20 @@ function __init__()
247247
cb_call
248248
end
249249

250-
if prob.f isa OptimizationFunction
250+
if prob.f isa OptimizationFunction
251251
_loss = function(θ)
252252
x = prob.f.f(θ, prob.p)
253253
return x[1]
254254
end
255-
else
255+
else
256256
_loss = function(θ)
257257
x = prob.f(θ, prob.p)
258258
return x[1]
259259
end
260260
end
261261

262262
bboptre = BlackBoxOptim.bboptimize(_loss;Method = opt.method, SearchRange = [(prob.lb[i], prob.ub[i]) for i in 1:length(prob.lb)], MaxSteps = maxiters, CallbackFunction = _cb, CallbackInterval = 0.0, kwargs...)
263-
263+
264264
Optim.MultivariateOptimizationResults(opt.method,
265265
[NaN],# initial_x,
266266
BlackBoxOptim.best_candidate(bboptre), #pick_best_x(f_incr_pick, state),
@@ -292,10 +292,10 @@ function __init__()
292292
end
293293

294294
@require NLopt="76087f3c-5699-56af-9a33-bf431cd00edd" begin
295-
function __solve(prob::OptimizationProblem, opt::NLopt.Opt; maxiters = 1000, nstart = 1, local_method = nothing, kwargs...)
295+
function __solve(prob::OptimizationProblem, opt::NLopt.Opt; maxiters = 1000, nstart = 1, local_method = nothing, kwargs...)
296296
local x
297297

298-
if prob.f isa OptimizationFunction
298+
if prob.f isa OptimizationFunction
299299
_loss = function(θ)
300300
x = prob.f.f(θ, prob.p)
301301
return x[1]
@@ -304,11 +304,11 @@ function __init__()
304304
if length(G) > 0
305305
prob.f.grad(G, θ)
306306
end
307-
307+
308308
return _loss(θ)
309309
end
310310
NLopt.min_objective!(opt, fg!)
311-
else
311+
else
312312
_loss = function(θ,G)
313313
x = prob.f(θ, prob.p)
314314
return x[1]
@@ -317,7 +317,7 @@ function __init__()
317317
end
318318

319319
if prob.ub !== nothing
320-
NLopt.upper_bounds!(opt, prob.ub)
320+
NLopt.upper_bounds!(opt, prob.ub)
321321
end
322322
if prob.lb !== nothing
323323
NLopt.lower_bounds!(opt, prob.lb)
@@ -361,19 +361,19 @@ function __init__()
361361
ret,
362362
NaN,
363363
_time-t0,)
364-
end
364+
end
365365
end
366366

367367
@require MultistartOptimization = "3933049c-43be-478e-a8bb-6e0f7fd53575" begin
368368
function __solve(prob::OptimizationProblem, opt::MultistartOptimization.TikTak; local_method, local_maxiters = 1000, kwargs...)
369369
local x, _loss
370-
371-
if prob.f isa OptimizationFunction
370+
371+
if prob.f isa OptimizationFunction
372372
_loss = function(θ)
373373
x = prob.f.f(θ, prob.p)
374374
return x[1]
375375
end
376-
else
376+
else
377377
_loss = function(θ)
378378
x = prob.f(θ, prob.p)
379379
return x[1]
@@ -386,9 +386,9 @@ function __init__()
386386
multistart_method = opt
387387
local_method = MultistartOptimization.NLoptLocalMethod(local_method, maxeval = local_maxiters)
388388
p = MultistartOptimization.multistart_minimization(multistart_method, local_method, P)
389-
389+
390390
t1 = time()
391-
391+
392392
Optim.MultivariateOptimizationResults(opt,
393393
[NaN],# initial_x,
394394
p.location, #pick_best_x(f_incr_pick, state),
@@ -421,19 +421,19 @@ function __init__()
421421

422422
@require QuadDIRECT = "dae52e8d-d666-5120-a592-9e15c33b8d7a" begin
423423
export QuadDirect
424-
424+
425425
struct QuadDirect
426426
end
427427

428428
function __solve(prob::OptimizationProblem, opt::QuadDirect; splits, maxiters = 1000, kwargs...)
429429
local x, _loss
430-
431-
if prob.f isa OptimizationFunction
430+
431+
if prob.f isa OptimizationFunction
432432
_loss = function(θ)
433433
x = prob.f.f(θ, prob.p)
434434
return x[1]
435435
end
436-
else
436+
else
437437
_loss = function(θ)
438438
x = prob.f(θ, prob.p)
439439
return x[1]
@@ -485,7 +485,7 @@ function __init__()
485485

486486
function __solve(prob::OptimizationProblem, opt::Evolutionary.AbstractOptimizer; cb = (args...) -> (false), maxiters = 1000, kwargs...)
487487
local x, _loss
488-
488+
489489
function _cb(trace)
490490
cb_call = cb(decompose_trace(trace).metadata["x"],trace.value...)
491491
if !(typeof(cb_call) <: Bool)
@@ -494,12 +494,12 @@ function __init__()
494494
cb_call
495495
end
496496

497-
if prob.f isa OptimizationFunction
497+
if prob.f isa OptimizationFunction
498498
_loss = function(θ)
499499
x = prob.f.f(θ, prob.p)
500500
return x[1]
501501
end
502-
else
502+
else
503503
_loss = function(θ)
504504
x = prob.f(θ, prob.p)
505505
return x[1]
@@ -515,7 +515,7 @@ function __init__()
515515

516516
function __solve(prob::OptimizationProblem, opt::CMAEvolutionStrategyOpt; cb = (args...) -> (false), maxiters = 1000, kwargs...)
517517
local x, _loss
518-
518+
519519
function _cb(trace)
520520
cb_call = cb(decompose_trace(trace).metadata["x"],trace.value...)
521521
if !(typeof(cb_call) <: Bool)
@@ -524,12 +524,12 @@ function __init__()
524524
cb_call
525525
end
526526

527-
if prob.f isa OptimizationFunction
527+
if prob.f isa OptimizationFunction
528528
_loss = function(θ)
529529
x = prob.f.f(θ, prob.p)
530530
return x[1]
531531
end
532-
else
532+
else
533533
_loss = function(θ)
534534
x = prob.f(θ, prob.p)
535535
return x[1]

test/rosenbrock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ rosenbrock(x, p=nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
2121

2222
l1 = rosenbrock(x0)
2323
prob = OptimizationProblem(rosenbrock, x0)
24-
sol = solve(prob, NelderMead())
24+
sol = solve(prob, NelderMead())
2525
@test 10*sol.minimum < l1
2626

2727

test/runtests.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
using GalacticOptim
2-
using Test
1+
using SafeTestsets
32

4-
@testset "GalacticOptim.jl" begin
5-
include("rosenbrock.jl")
6-
include("ADtests.jl")
7-
end
3+
@safetestset "Rosenbrock" begin include("rosenbrock.jl") end
4+
@safetestset "AD Tests" begin include("ADtests.jl") end

0 commit comments

Comments
 (0)