Commit 483716b

[WIP] Update parameters and implement grid search
1 parent 1571e27 commit 483716b

6 files changed: +154 -44 lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ version = "0.1.0"
 
 [deps]
 ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
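
Logging is a Julia standard library; it provides `with_logger` and `NullLogger`, which the new `src/grid-search-tuning.jl` below uses to silence solver output while the grid is evaluated. A minimal, self-contained illustration of that pattern:

```julia
using Logging

# Run a block with all log records discarded — the same pattern used in
# grid_search_tune to keep each solve quiet.
result = with_logger(NullLogger()) do
  @info "this record is discarded"
  42
end
result == 42  # the block's return value is still passed through
```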

solver-example.jl

Lines changed: 8 additions & 1 deletion
@@ -21,4 +21,11 @@ output = solve!(solver, nlp, δ = 1e-2)
 
 #%%
 include("test/dummy_solver.jl")
-output, solver = DummySolver(nlp)
+output, solver = DummySolver(nlp)
+
+#%%
+include("test/dummy_solver.jl")
+problems = (
+  ADNLPModel(x -> (x[1] - a)^2 + b^2 * (x[2] - x[1]^2)^2, [-1.2; 1.0], x -> [x[1]^2 + x[2]^2 - 1], [0.0], [0.0]) for a = 0:0.1:1, b = 2:3
+)
+o = grid_search_tune(DummySolver, problems)

src/SolverCore.jl

Lines changed: 2 additions & 1 deletion
@@ -1,12 +1,13 @@
 module SolverCore
 
 # stdlib
-using Printf
+using Logging, Printf
 
 # our packages
 using NLPModels
 
 include("solver.jl")
+include("grid-search-tuning.jl")
 include("logger.jl")
 include("stats.jl")

src/grid-search-tuning.jl

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+export grid_search_tune
+
+"""
+    solver, results = grid_search_tune(SolverType, problems; kwargs...)
+
+Simple tuning of solver `SolverType` by grid search, on `problems`, which should be iterable.
+The following keyword arguments are available:
+- `success`: A function to be applied on a solver output that returns whether the problem has terminated successfully. Defaults to `o -> o.status == :first_order`.
+- `costs`: A vector of cost functions and penalties. Each element is a tuple of two elements. The first is a function to be applied to the output of the solver, and the second is the cost when the solver fails (see `success` above) or throws an error. Defaults to
+```
+[
+  (o -> o.elapsed_time, 100.0),
+  (o -> o.counters.neval_obj + o.counters.neval_cons, 1000),
+  (o -> !success(o), 1),
+]
+```
+which represent the total elapsed time (with a penalty of 100.0 for failures); the number of objective and constraint evaluations (with a penalty of 1000 for failures); and the number of failures.
+- `grid_length`: The number of points in the ranges of the grid for continuous parameters.
+- `solver_kwargs`: Arguments to be passed to the solver. Note: use this to set the stopping parameters, but not the parameters being optimized.
+- Any parameter accepted by the `Solver`: a range to be used instead of the default range.
+
+The default ranges are based on the parameter types, and are as follows:
+- `:real`: linear range from `:min` to `:max` with `grid_length` points.
+- `:log`: logarithmic range from `:min` to `:max` with `grid_length` points, computed as the exponential of a linear range from `log(:min)` to `log(:max)`.
+- `:bool`: either `false` or `true`.
+- `:int`: integer range from `:min` to `:max`.
+"""
+function grid_search_tune(
+  ::Type{Solver},
+  problems;
+  success = o -> o.status == :first_order,
+  costs = [
+    (o -> o.elapsed_time, 100.0),
+    (o -> o.counters.neval_obj + o.counters.neval_cons, 1000),
+    (o -> !success(o), 1),
+  ],
+  grid_length = 10,
+  solver_kwargs = Dict(),
+  kwargs...
+) where Solver <: AbstractSolver
+
+  solver_params = parameters(Solver)
+  params = Dict()
+  for (k, v) in pairs(solver_params)
+    if v[:type] == :real
+      params[k] = LinRange(v[:min], v[:max], grid_length)
+    elseif v[:type] == :log
+      params[k] = exp.(LinRange(log(v[:min]), log(v[:max]), grid_length))
+    elseif v[:type] == :bool
+      params[k] = (false, true)
+    elseif v[:type] == :int
+      params[k] = v[:min]:v[:max]
+    end
+  end
+  for (k, v) in kwargs
+    params[k] = v
+  end
+
+  # Precompiling
+  nlp = first(problems)
+  try
+    solver = Solver(Val(:nosolve), nlp)
+    output = with_logger(NullLogger()) do
+      solve!(solver, nlp)
+    end
+  finally
+    finalize(nlp)
+  end
+
+  cost(θ) = begin
+    total_cost = [zero(x[2]) for x in costs]
+    for nlp in problems
+      reset!(nlp)
+      try
+        solver = Solver(Val(:nosolve), nlp)
+        output = with_logger(NullLogger()) do
+          solve!(solver, nlp; (k => θi for (k, θi) in zip(keys(solver_params), θ))...)
+        end
+        for (i, c) in enumerate(costs)
+          if success(output)
+            total_cost[i] += (c[1])(output)
+          else
+            total_cost[i] += c[2]
+          end
+        end
+      catch ex
+        for (i, c) in enumerate(costs)
+          total_cost[i] += c[2]
+        end
+        @show ex
+      finally
+        finalize(nlp)
+      end
+    end
+    total_cost
+  end
+
+  [θ => cost(θ) for θ in Iterators.product(values(params)...)]
+end
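
To make the docstring above concrete, here is a hedged sketch of how the returned grid of `θ => cost` pairs might be consumed; the `results`/`best` names, the `argmin` selection, and the overridden `α` range are illustrative assumptions rather than part of this commit.

```julia
# Sketch only: reuses `DummySolver` and the `problems` generator from solver-example.jl.
results = grid_search_tune(DummySolver, problems, grid_length = 5, α = [1e-3, 1e-2, 1e-1])

# Each entry pairs a parameter tuple θ with its accumulated cost vector
# [total_time, total_evaluations, number_of_failures]; keep the θ with the
# smallest total elapsed time (first cost component).
best = argmin(p -> p.second[1], vec(results))
println("best parameters: ", best.first, " with costs ", best.second)
```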

src/solver.jl

Lines changed: 6 additions & 30 deletions
@@ -7,6 +7,10 @@ Base type for JSO-compliant solvers.
 """
 abstract type AbstractSolver{T} end
 
+function Base.show(io :: IO, solver :: AbstractSolver)
+  show(io, "Solver $(typeof(solver))")
+end
+
 """
     output = solve!(solver, problem)
 
@@ -27,7 +31,7 @@ Each key of `named_tuple` is the name of a parameter, and its value is a NamedTuple
 - `default`: The default value of the parameter.
 - `type`: The type of the parameter, which can be any of:
   - `:real`: A continuous value within a range
-  - `:log`: A continuous value that should be explored logarithmically around its lower value (usually 0) to avoid the bound itself.
+  - `:log`: A positive continuous value that should be explored logarithmically (like 10⁻², 10⁻¹, 1, 10).
   - `:int`: Integer value.
   - `:bool`: Boolean value.
 - `min`: Minimum value (may not be included for some parameter types).
@@ -36,32 +40,4 @@ Each key of `named_tuple` is the name of a parameter, and its value is a NamedTuple
 function parameters(::Type{AbstractSolver{T}}) where T end
 
 parameters(::Type{S}) where S <: AbstractSolver = parameters(S{Float64})
-parameters(solver :: AbstractSolver) = parameters(typeof(solver))
-
-"""
-    nlp = parameter_problem(solver)
-
-Return the problem associated with the tuning of the parameters of `solver`.
-"""
-function parameter_problem(::AbstractSolver) end
-
-# parameter_problem(
-#   solver::DummySolver,
-#   problems,
-#   cost,
-#   cost_bad
-# ) = ADNLPModel(
-#   x -> begin
-#     total_cost = 0.0
-#     for nlp in problems
-#       try
-#         output = with_logger(NullLogger()) do
-#           output, _ = DummySolver(nlp)
-#         end
-#         total_cost += cost(output)
-#       catch
-#         total_cost +=
-#       end
-#     end
-#   end
-# )
+parameters(solver :: AbstractSolver) = parameters(typeof(solver))
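
For reference, the specification format documented above could be implemented for a hypothetical solver like this; `MySolver` and its parameters are invented purely to illustrate the NamedTuple layout and are not part of this commit.

```julia
# Hypothetical example of the documented parameter specification format.
function SolverCore.parameters(::Type{MySolver{T}}) where T
  (
    η = (default = T(0.1), type = :real, min = zero(T), max = one(T)),  # continuous, linear grid
    mem = (default = 5, type = :int, min = 1, max = 20),                # integer range
    scaling = (default = true, type = :bool),                           # boolean switch
  )
end
```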

test/dummy_solver.jl

Lines changed: 38 additions & 12 deletions
@@ -1,5 +1,6 @@
 mutable struct DummySolver{T} <: AbstractSolver{T}
   initialized :: Bool
+  params :: Dict
   x :: Vector{T}
   xt :: Vector{T}
   gx :: Vector{T}
@@ -9,9 +10,31 @@ mutable struct DummySolver{T} <: AbstractSolver{T}
   ct :: Vector{T}
 end
 
-function DummySolver(::Type{T}, meta :: AbstractNLPModelMeta) where T
+function SolverCore.parameters(::Type{DummySolver{T}}) where T
+  (
+    α = (default=T(1e-2), type=:log, min=√√eps(T), max=one(T) / 2),
+    δ = (default=eps(T), type=:log, min=eps(T), max=√√√eps(T)),
+    reboot_y = (default=false, type=:bool)
+  )
+end
+
+function DummySolver(::Type{T}, meta :: AbstractNLPModelMeta; kwargs...) where T
   nvar, ncon = meta.nvar, meta.ncon
-  DummySolver{T}(true, zeros(T, nvar), zeros(T, nvar), zeros(T, nvar), zeros(T, nvar), zeros(T, ncon), zeros(T, ncon), zeros(T, ncon))
+  params = parameters(DummySolver{T})
+  solver = DummySolver{T}(true,
+    Dict(k => v[:default] for (k,v) in pairs(params)),
+    zeros(T, nvar),
+    zeros(T, nvar),
+    zeros(T, nvar),
+    zeros(T, nvar),
+    zeros(T, ncon),
+    zeros(T, ncon),
+    zeros(T, ncon),
+  )
+  for (k,v) in kwargs
+    solver.params[k] = v
+  end
+  solver
 end
 
 function DummySolver(::Type{T}, ::Val{:nosolve}, nlp :: AbstractNLPModel) where T
@@ -36,11 +59,16 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
   rtol :: Real = sqrt(eps(T)),
   max_eval :: Int = 1000,
   max_time :: Float64 = 30.0,
-  α :: Float64 = 1e-2,
-  δ :: Float64 = 1e-8,
+  kwargs...
 ) where T
   solver.initialized || error("Solver not initialized.")
   nvar, ncon = nlp.meta.nvar, nlp.meta.ncon
+  for (k,v) in kwargs
+    solver.params[k] = v
+  end
+  α = solver.params[:α]
+  δ = solver.params[:δ]
+  reboot_y = solver.params[:reboot_y]
 
   start_time = time()
   elapsed_time = 0.0
@@ -95,11 +123,16 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
     end
 
     x .= xt
-    y .+= t * Δy
+
 
     fx = ft
     grad!(nlp, x, gx)
     Jx = ncon > 0 ? jac(nlp, x) : zeros(T, 0, nvar)
+    if reboot_y
+      y .= -Jx' \ gx
+    else
+      y .+= t * Δy
+    end
     cx .= ct
     dual .= gx .+ Jx' * y
     elapsed_time = time() - start_time
@@ -132,10 +165,3 @@ function SolverCore.solve!(solver::DummySolver{T}, nlp :: AbstractNLPModel;
     iter=iter
   )
 end
-
-function SolverCore.parameters(::Type{DummySolver{T}}) where T
-  (
-    α = (default=T(1e-2), type=:log, min=zero(T), max=one(T)),
-    δ = (default=eps(T), type=:log, min=zero(T), max=one(T)),
-  )
-end
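
A small usage sketch of the reworked keyword handling in `DummySolver`: parameters such as `α` and `reboot_y` are now routed through `solver.params` instead of being fixed keyword arguments of `solve!`. The constrained `ADNLPModel` below mirrors the one in solver-example.jl, and the `Val(:nosolve)` constructor call is assumed from its use inside `grid_search_tune`.

```julia
# Sketch only; values are illustrative.
nlp = ADNLPModel(x -> (x[1] - 1)^2 + 4 * (x[2] - x[1]^2)^2, [-1.2; 1.0],
                 x -> [x[1]^2 + x[2]^2 - 1], [0.0], [0.0])
solver = DummySolver(Val(:nosolve), nlp)                  # allocate without solving
output = solve!(solver, nlp, α = 1e-3, reboot_y = true)   # overrides stored in solver.params
```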
