Skip to content

Commit a9b45a9

Browse files
Merge pull request #71 from SciML/dataarg
Clean up data arg checking
2 parents 754f452 + 6803112 commit a9b45a9

File tree

5 files changed

+50
-13
lines changed

5 files changed

+50
-13
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,14 @@ julia = "1.3"
4545
[extras]
4646
BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
4747
CMAEvolutionStrategy = "8d3b24bd-414e-49e0-94fb-163cc3a3e411"
48+
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
4849
Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
50+
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
51+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
4952
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
53+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
5054
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5155
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5256

5357
[targets]
54-
test = ["BlackBoxOptim", "Evolutionary", "NLopt", "CMAEvolutionStrategy", "SafeTestsets", "Test"]
58+
test = ["BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Plots" ,"SafeTestsets", "Test"]

src/solve.jl

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ macro withprogress(progress, exprs...)
6262
end |> esc
6363
end
6464

65-
function __solve(prob::OptimizationProblem, opt, _data = DEFAULT_DATA;cb = (args...) -> (false), maxiters::Number = 1000, progress = true, save_best = true, kwargs...)
65+
function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args...) -> (false), maxiters::Number = 1000, progress = true, save_best = true, kwargs...)
6666
if maxiters <= 0.0
6767
error("The number of maxiters has to be a non-negative and non-zero number.")
6868
else
@@ -74,14 +74,10 @@ function __solve(prob::OptimizationProblem, opt, _data = DEFAULT_DATA;cb = (args
7474
θ = copy(prob.u0)
7575
ps = Flux.params(θ)
7676

77-
if _data == DEFAULT_DATA && maxiters == typemax(Int)
78-
error("For Flux optimizers, either a data iterator must be provided or the `maxiters` keyword argument must be set.")
79-
elseif _data == DEFAULT_DATA && maxiters != typemax(Int)
80-
data = Iterators.repeated((), maxiters)
81-
elseif maxiters != typemax(Int)
82-
data = take(_data, maxiters)
83-
else
84-
data = _data
77+
if data != DEFAULT_DATA
78+
maxiters = length(data)
79+
else
80+
data = take(data, maxiters)
8581
end
8682

8783
t0 = time()
@@ -158,6 +154,11 @@ decompose_trace(trace) = trace
158154

159155
function __solve(prob::OptimizationProblem, opt::Optim.AbstractOptimizer, data = DEFAULT_DATA;cb = (args...) -> (false), maxiters::Number = 1000, kwargs...)
160156
local x, cur, state
157+
158+
if data != DEFAULT_DATA
159+
maxiters = length(data)
160+
end
161+
161162
cur, state = iterate(data)
162163

163164
function _cb(trace)
@@ -202,6 +203,11 @@ end
202203

203204
function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN}, data = DEFAULT_DATA;cb = (args...) -> (false), maxiters::Number = 1000, kwargs...)
204205
local x, cur, state
206+
207+
if data != DEFAULT_DATA
208+
maxiters = length(data)
209+
end
210+
205211
cur, state = iterate(data)
206212

207213
function _cb(trace)
@@ -242,6 +248,11 @@ end
242248

243249
function __solve(prob::OptimizationProblem, opt::Optim.ConstrainedOptimizer, data = DEFAULT_DATA;cb = (args...) -> (false), maxiters::Number = 1000, kwargs...)
244250
local x, cur, state
251+
252+
if data != DEFAULT_DATA
253+
maxiters = length(data)
254+
end
255+
245256
cur, state = iterate(data)
246257

247258
function _cb(trace)
@@ -309,6 +320,11 @@ function __init__()
309320

310321
function __solve(prob::OptimizationProblem, opt::BBO, data = DEFAULT_DATA; cb = (args...) -> (false), maxiters::Number = 1000, kwargs...)
311322
local x, cur, state
323+
324+
if data != DEFAULT_DATA
325+
maxiters = length(data)
326+
end
327+
312328
cur, state = iterate(data)
313329

314330
function _cb(trace)
@@ -563,6 +579,11 @@ function __init__()
563579

564580
function __solve(prob::OptimizationProblem, opt::Evolutionary.AbstractOptimizer, data = DEFAULT_DATA; cb = (args...) -> (false), maxiters::Number = 1000, kwargs...)
565581
local x, cur, state
582+
583+
if data != DEFAULT_DATA
584+
maxiters = length(data)
585+
end
586+
566587
cur, state = iterate(data)
567588

568589
function _cb(trace)
@@ -594,6 +615,11 @@ function __init__()
594615

595616
function __solve(prob::OptimizationProblem, opt::CMAEvolutionStrategyOpt, data = DEFAULT_DATA; cb = (args...) -> (false), maxiters::Number = 1000, kwargs...)
596617
local x, cur, state
618+
619+
if data != DEFAULT_DATA
620+
maxiters = length(data)
621+
end
622+
597623
cur, state = iterate(data)
598624

599625
function _cb(trace)

test/minibatch.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DifferentialEquations, Flux, Optim, DiffEqFlux, Plots, GalacticOptim
1+
using DiffEqFlux, Plots, GalacticOptim, OrdinaryDiffEq
22

33
function newtons_cooling(du, u, p, t)
44
temp = u[1]
@@ -52,9 +52,10 @@ k = 10
5252
train_loader = Flux.Data.DataLoader(ode_data, t, batchsize = k)
5353

5454
numEpochs = 300
55+
l1 = loss_adjoint(pp, train_loader.data[1], train_loader.data[2])[1]
5556

5657
optfun = OptimizationFunction((θ, p, batch, time_batch) -> loss_adjoint(θ, batch, time_batch), GalacticOptim.AutoZygote())
5758
optprob = OptimizationProblem(optfun, pp)
5859
using IterTools: ncycle
5960
res1 = GalacticOptim.solve(optprob, ADAM(0.05), ncycle(train_loader, numEpochs), cb = cb, maxiters = numEpochs)
60-
cb(res1.minimizer, loss_adjoint(res1.minimizer, ode_data, t)...; doplot=true)
61+
@test 10res1.minimum < l1

test/rosenbrock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using GalacticOptim, Optim, Flux, Test
1+
using GalacticOptim, Optim, Test
22

33
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
44
x0 = zeros(2)

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
using SafeTestsets
22

3+
println("Rosenbrock Tests")
34
@safetestset "Rosenbrock" begin include("rosenbrock.jl") end
5+
println("AD Tests")
46
@safetestset "AD Tests" begin include("ADtests.jl") end
7+
println("Mini batching Tests")
8+
@safetestset "Mini batching" begin include("minibatch.jl") end
9+
println("DiffEqFlux Tests")
10+
@safetestset "DiffEqFlux" begin include("diffeqfluxtests.jl") end

0 commit comments

Comments
 (0)