Skip to content

Commit c105960

Browse files
Merge pull request #253 from ValentinKaisermayer/patch-opfunc
changes all OptimizationFunction constructors to outer one
2 parents 64ba70d + a7dd7cf commit c105960

File tree

8 files changed

+34
-20
lines changed

8 files changed

+34
-20
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@ TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
2020

2121
[compat]
2222
ArrayInterfaceCore = "0.1.1"
23+
ArrayInterface = "6"
2324
ConsoleProgressMonitor = "0.1"
2425
DiffResults = "1.0"
2526
DocStringExtensions = "0.8"
2627
LoggingExtras = "0.4"
2728
ProgressLogging = "0.1"
2829
Reexport = "0.2, 1.0"
2930
Requires = "1.0"
30-
SciMLBase = "1.32"
31+
SciMLBase = "1.34"
3132
TerminalLoggers = "0.1"
3233
julia = "1.6"
34+
35+
[extras]
36+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/function/finitediff.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,7 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
3232
hv = f.hv
3333
end
3434

35-
return OptimizationFunction{false,AutoFiniteDiff,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,adtype,grad,hess,hv,nothing,nothing,nothing)
35+
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
36+
cons=nothing, cons_j=nothing, cons_h=nothing,
37+
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
3638
end

src/function/forwarddiff.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function default_chunk_size(len)
1111
end
1212
end
1313

14-
function instantiate_function(f::OptimizationFunction{true}, x, ::AutoForwardDiff{_chunksize}, p, num_cons = 0) where _chunksize
14+
function instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoForwardDiff{_chunksize}, p, num_cons = 0) where _chunksize
1515

1616
chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize
1717

@@ -67,5 +67,7 @@ function instantiate_function(f::OptimizationFunction{true}, x, ::AutoForwardDif
6767
cons_h = f.cons_h
6868
end
6969

70-
return OptimizationFunction{true,AutoForwardDiff,typeof(f.f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h)}(f.f,AutoForwardDiff(),grad,hess,hv,cons,cons_j,cons_h)
70+
return OptimizationFunction{true}(f.f, adtype; grad=grad, hess=hess, hv=hv,
71+
cons=cons, cons_j=cons_j, cons_h=cons_h,
72+
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
7173
end

src/function/function.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ function instantiate_function(f, x, ::AbstractADType, p, num_cons = 0)
66
cons_j = f.cons_j === nothing ? nothing : (res,x)->f.cons_j(res,x,p)
77
cons_h = f.cons_h === nothing ? nothing : (res,x)->f.cons_h(res,x,p)
88

9-
OptimizationFunction{true,SciMLBase.NoAD,typeof(f.f),typeof(grad),
10-
typeof(hess),typeof(hv),typeof(cons),
11-
typeof(cons_j),typeof(cons_h)}(f.f,
12-
SciMLBase.NoAD(),grad,hess,hv,cons,
13-
cons_j,cons_h)
9+
return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad=grad, hess=hess, hv=hv,
10+
cons=cons, cons_j=cons_j, cons_h=cons_h,
11+
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
1412
end

src/function/mtk.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ end
55

66
AutoModelingToolkit() = AutoModelingToolkit(false, false)
77

8-
function instantiate_function(f, x, ad::AutoModelingToolkit, p, num_cons=0)
8+
function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons=0)
99
p = isnothing(p) ? SciMLBase.NullParameters() : p
1010
sys = ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p))
1111

@@ -17,7 +17,7 @@ function instantiate_function(f, x, ad::AutoModelingToolkit, p, num_cons=0)
1717
end
1818

1919
if f.hess === nothing
20-
hess_oop, hess_iip = ModelingToolkit.generate_hessian(sys, expression=Val{false}, sparse = ad.obj_sparse)
20+
hess_oop, hess_iip = ModelingToolkit.generate_hessian(sys, expression=Val{false}, sparse = adtype.obj_sparse)
2121
hess(H, u) = (hess_iip(H, u, p); H)
2222
else
2323
hess = f.hess
@@ -41,7 +41,7 @@ function instantiate_function(f, x, ad::AutoModelingToolkit, p, num_cons=0)
4141
end
4242

4343
if f.cons !== nothing && f.cons_j === nothing
44-
jac_oop, jac_iip = ModelingToolkit.generate_jacobian(cons_sys, expression=Val{false}, sparse=ad.cons_sparse)
44+
jac_oop, jac_iip = ModelingToolkit.generate_jacobian(cons_sys, expression=Val{false}, sparse=adtype.cons_sparse)
4545
cons_j = function (J, θ)
4646
jac_iip(J, θ, p)
4747
end
@@ -50,13 +50,15 @@ function instantiate_function(f, x, ad::AutoModelingToolkit, p, num_cons=0)
5050
end
5151

5252
if f.cons !== nothing && f.cons_h === nothing
53-
cons_hess_oop, cons_hess_iip = ModelingToolkit.generate_hessian(cons_sys, expression=Val{false}, sparse=ad.cons_sparse)
53+
cons_hess_oop, cons_hess_iip = ModelingToolkit.generate_hessian(cons_sys, expression=Val{false}, sparse=adtype.cons_sparse)
5454
cons_h = function (res, θ)
5555
cons_hess_iip(res, θ, p)
5656
end
5757
else
5858
cons_h = f.cons_h
5959
end
6060

61-
return OptimizationFunction{true,AutoModelingToolkit,typeof(f.f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h)}(f.f, AutoModelingToolkit(), grad, hess, hv, cons, cons_j, cons_h)
61+
return OptimizationFunction{true}(f.f, adtype; grad=grad, hess=hess, hv=hv,
62+
cons=cons, cons_j=cons_j, cons_h=cons_h,
63+
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
6264
end

src/function/reversediff.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
struct AutoReverseDiff <: AbstractADType end
22

3-
function instantiate_function(f, x, ::AutoReverseDiff, p=SciMLBase.NullParameters(), num_cons = 0)
3+
function instantiate_function(f, x, adtype::AutoReverseDiff, p=SciMLBase.NullParameters(), num_cons = 0)
44
num_cons != 0 && error("AutoReverseDiff does not currently support constraints")
55

66
_f = (θ, args...) -> first(f.f(θ,p, args...))
@@ -39,5 +39,7 @@ function instantiate_function(f, x, ::AutoReverseDiff, p=SciMLBase.NullParameter
3939
hv = f.hv
4040
end
4141

42-
return OptimizationFunction{false,AutoReverseDiff,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoReverseDiff(),grad,hess,hv,nothing,nothing,nothing)
42+
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
43+
cons=nothing, cons_j=nothing, cons_h=nothing,
44+
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
4345
end

src/function/tracker.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
struct AutoTracker <: AbstractADType end
22

3-
function instantiate_function(f, x, ::AutoTracker, p, num_cons = 0)
3+
function instantiate_function(f, x, adtype::AutoTracker, p, num_cons = 0)
44
num_cons != 0 && error("AutoTracker does not currently support constraints")
55
_f = (θ, args...) -> first(f.f(θ, p, args...))
66

@@ -23,5 +23,7 @@ function instantiate_function(f, x, ::AutoTracker, p, num_cons = 0)
2323
end
2424

2525

26-
return OptimizationFunction{false,AutoTracker,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoTracker(),grad,hess,hv,nothing,nothing,nothing)
26+
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
27+
cons=nothing, cons_j=nothing, cons_h=nothing,
28+
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
2729
end

src/function/zygote.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
struct AutoZygote <: AbstractADType end
22

3-
function instantiate_function(f, x, ::AutoZygote, p, num_cons = 0)
3+
function instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0)
44
num_cons != 0 && error("AutoZygote does not currently support constraints")
55

66
_f = (θ, args...) -> f(θ,p,args...)[1]
@@ -37,5 +37,7 @@ function instantiate_function(f, x, ::AutoZygote, p, num_cons = 0)
3737
hv = f.hv
3838
end
3939

40-
return OptimizationFunction{false,AutoZygote,typeof(f),typeof(grad),typeof(hess),typeof(hv),Nothing,Nothing,Nothing}(f,AutoZygote(),grad,hess,hv,nothing,nothing,nothing)
40+
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
41+
cons=nothing, cons_j=nothing, cons_h=nothing,
42+
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
4143
end

0 commit comments

Comments
 (0)