Skip to content

Commit 732985d

Browse files
tmigotabelsiqueira
authored andcommitted
externalize dense hessian
1 parent 2e55e5d commit 732985d

File tree

7 files changed

+153
-118
lines changed

7 files changed

+153
-118
lines changed

src/ad.jl

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,49 @@
11
abstract type ADBackend end
2-
struct ForwardDiffAD <: ADBackend end
3-
struct ZygoteAD <: ADBackend end
4-
struct ReverseDiffAD <: ADBackend end
2+
struct ForwardDiffAD <: ADBackend
3+
nnzh::Int
4+
nnzj::Int
5+
end
6+
function ForwardDiffAD(f, c, x0::AbstractVector, ncon::Integer)
7+
nvar = length(x0)
8+
nnzh = nvar * (nvar + 1) / 2
9+
nnzj = nvar * ncon
10+
return ForwardDiffAD(nnzh, nnzj)
11+
end
12+
function ForwardDiffAD(f, x0::AbstractVector)
13+
nvar = length(x0)
14+
nnzh = nvar * (nvar + 1) / 2
15+
return ForwardDiffAD(nnzh, 0)
16+
end
17+
struct ZygoteAD <: ADBackend
18+
nnzh::Int
19+
nnzj::Int
20+
end
21+
function ZygoteAD(f, c, x0::AbstractVector, ncon::Integer)
22+
nvar = length(x0)
23+
nnzh = nvar * (nvar + 1) / 2
24+
nnzj = nvar * ncon
25+
return ZygoteAD(nnzh, nnzj)
26+
end
27+
function ZygoteAD(f, x0::AbstractVector)
28+
nvar = length(x0)
29+
nnzh = nvar * (nvar + 1) / 2
30+
return ZygoteAD(nnzh, 0)
31+
end
32+
struct ReverseDiffAD <: ADBackend
33+
nnzh::Int
34+
nnzj::Int
35+
end
36+
function ReverseDiffAD(f, c, x0::AbstractVector, ncon::Integer)
37+
nvar = length(x0)
38+
nnzh = nvar * (nvar + 1) / 2
39+
nnzj = nvar * ncon
40+
return ReverseDiffAD(nnzh, nnzj)
41+
end
42+
function ReverseDiffAD(f, x0::AbstractVector)
43+
nvar = length(x0)
44+
nnzh = nvar * (nvar + 1) / 2
45+
return ReverseDiffAD(nnzh, 0)
46+
end
547

648
throw_error(b) =
749
throw(ArgumentError("The AD backend $b is not loaded. Please load the corresponding AD package."))
@@ -11,6 +53,46 @@ jacobian(b::ADBackend, ::Any, ::Any) = throw_error(b)
1153
hessian(b::ADBackend, ::Any, ::Any) = throw_error(b)
1254
Jprod(b::ADBackend, ::Any, ::Any, ::Any) = throw_error(b)
1355
Jtprod(b::ADBackend, ::Any, ::Any, ::Any) = throw_error(b)
56+
function hess_structure!(
57+
b::ADBackend,
58+
nlp,
59+
rows::AbstractVector{<:Integer},
60+
cols::AbstractVector{<:Integer},
61+
)
62+
n = nlp.meta.nvar
63+
I = ((i, j) for i = 1:n, j = 1:n if i j)
64+
rows .= getindex.(I, 1)
65+
cols .= getindex.(I, 2)
66+
return rows, cols
67+
end
68+
function hess_coord!(b::ADBackend, nlp, x::AbstractVector, ℓ::Function, vals::AbstractVector)
69+
Hx = hessian(b, ℓ, x)
70+
k = 1
71+
for j = 1:(nlp.meta.nvar)
72+
for i = j:(nlp.meta.nvar)
73+
vals[k] = Hx[i, j]
74+
k += 1
75+
end
76+
end
77+
return vals
78+
end
79+
function jac_structure!(
80+
b::ADBackend,
81+
nlp,
82+
rows::AbstractVector{<:Integer},
83+
cols::AbstractVector{<:Integer},
84+
)
85+
m, n = nlp.meta.ncon, nlp.meta.nvar
86+
I = ((i, j) for i = 1:m, j = 1:n)
87+
rows .= getindex.(I, 1)[:]
88+
cols .= getindex.(I, 2)[:]
89+
return rows, cols
90+
end
91+
function jac_coord!(b::ADBackend, nlp, x::AbstractVector, vals::AbstractVector)
92+
Jx = jacobian(b, nlp.c, x)
93+
vals .= Jx[:]
94+
return vals
95+
end
1496
function directional_second_derivative(::ADBackend, f, x, v, w)
1597
return ForwardDiff.derivative(t -> ForwardDiff.derivative(s -> f(x + s * w + t * v), 0), 0)
1698
end
@@ -45,7 +127,7 @@ end
45127
return Zygote.jacobian(f, x)[1]
46128
end
47129
function hessian(b::ZygoteAD, f, x)
48-
return jacobian(ForwardDiffAD(), x -> gradient(b, f, x), x)
130+
return jacobian(ForwardDiffAD(f, x), x -> gradient(b, f, x), x)
49131
end
50132
function Jprod(::ZygoteAD, f, x, v)
51133
return vec(Zygote.jacobian(t -> f(x + t * v), 0)[1])

src/nlp.jl

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function ADNLPModel(
3939
f,
4040
x0::AbstractVector{T};
4141
name::String = "Generic",
42-
adbackend = ForwardDiffAD(),
42+
adbackend = ForwardDiffAD(f, x0),
4343
) where {T}
4444
nvar = length(x0)
4545
@lencheck nvar x0
@@ -57,7 +57,7 @@ function ADNLPModel(
5757
lvar::AbstractVector,
5858
uvar::AbstractVector;
5959
name::String = "Generic",
60-
adbackend = ForwardDiffAD(),
60+
adbackend = ForwardDiffAD(f, x0),
6161
) where {T}
6262
nvar = length(x0)
6363
@lencheck nvar x0 lvar uvar
@@ -87,7 +87,7 @@ function ADNLPModel(
8787
y0::AbstractVector = fill!(similar(lcon), zero(T)),
8888
name::String = "Generic",
8989
lin::AbstractVector{<:Integer} = Int[],
90-
adbackend = ForwardDiffAD(),
90+
adbackend = ForwardDiffAD(f, c, x0, length(lcon)),
9191
) where {T}
9292
nvar = length(x0)
9393
ncon = length(lcon)
@@ -129,7 +129,7 @@ function ADNLPModel(
129129
y0::AbstractVector = fill!(similar(lcon), zero(T)),
130130
name::String = "Generic",
131131
lin::AbstractVector{<:Integer} = Int[],
132-
adbackend = ForwardDiffAD(),
132+
adbackend = ForwardDiffAD(f, c, x0, length(lcon)),
133133
) where {T}
134134
nvar = length(x0)
135135
ncon = length(lcon)
@@ -195,20 +195,14 @@ function NLPModels.jac_structure!(
195195
cols::AbstractVector{<:Integer},
196196
)
197197
@lencheck nlp.meta.nnzj rows cols
198-
m, n = nlp.meta.ncon, nlp.meta.nvar
199-
I = ((i, j) for i = 1:m, j = 1:n)
200-
rows .= getindex.(I, 1)[:]
201-
cols .= getindex.(I, 2)[:]
202-
return rows, cols
198+
return jac_structure!(nlp.adbackend, nlp, rows, cols)
203199
end
204200

205201
function NLPModels.jac_coord!(nlp::ADNLPModel, x::AbstractVector, vals::AbstractVector)
206202
@lencheck nlp.meta.nvar x
207203
@lencheck nlp.meta.nnzj vals
208204
increment!(nlp, :neval_jac)
209-
Jx = jacobian(nlp.adbackend, nlp.c, x)
210-
vals .= Jx[:]
211-
return vals
205+
return jac_coord!(nlp.adbackend, nlp, x, vals)
212206
end
213207

214208
function NLPModels.jprod!(nlp::ADNLPModel, x::AbstractVector, v::AbstractVector, Jv::AbstractVector)
@@ -259,12 +253,8 @@ function NLPModels.hess_structure!(
259253
rows::AbstractVector{<:Integer},
260254
cols::AbstractVector{<:Integer},
261255
)
262-
n = nlp.meta.nvar
263256
@lencheck nlp.meta.nnzh rows cols
264-
I = ((i, j) for i = 1:n, j = 1:n if i j)
265-
rows .= getindex.(I, 1)
266-
cols .= getindex.(I, 2)
267-
return rows, cols
257+
return hess_structure!(nlp.adbackend, nlp, rows, cols)
268258
end
269259

270260
function NLPModels.hess_coord!(
@@ -277,15 +267,7 @@ function NLPModels.hess_coord!(
277267
@lencheck nlp.meta.nnzh vals
278268
increment!(nlp, :neval_hess)
279269
(x) = obj_weight * nlp.f(x)
280-
Hx = hessian(nlp.adbackend, ℓ, x)
281-
k = 1
282-
for j = 1:(nlp.meta.nvar)
283-
for i = j:(nlp.meta.nvar)
284-
vals[k] = Hx[i, j]
285-
k += 1
286-
end
287-
end
288-
return vals
270+
return hess_coord!(nlp.adbackend, nlp, x, ℓ, vals)
289271
end
290272

291273
function NLPModels.hess_coord!(
@@ -300,15 +282,7 @@ function NLPModels.hess_coord!(
300282
@lencheck nlp.meta.nnzh vals
301283
increment!(nlp, :neval_hess)
302284
(x) = obj_weight * nlp.f(x) + dot(nlp.c(x), y)
303-
Hx = hessian(nlp.adbackend, ℓ, x)
304-
k = 1
305-
for j = 1:(nlp.meta.nvar)
306-
for i = j:(nlp.meta.nvar)
307-
vals[k] = Hx[i, j]
308-
k += 1
309-
end
310-
end
311-
return vals
285+
return hess_coord!(nlp.adbackend, nlp, x, ℓ, vals)
312286
end
313287

314288
function NLPModels.hprod!(

src/nls.jl

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function ADNLSModel(
4646
nequ::Integer;
4747
linequ::AbstractVector{<:Integer} = Int[],
4848
name::String = "Generic",
49-
adbackend = ForwardDiffAD(),
49+
adbackend = ForwardDiffAD(F, x0),
5050
) where {T}
5151
nvar = length(x0)
5252

@@ -72,7 +72,7 @@ function ADNLSModel(
7272
uvar::AbstractVector;
7373
linequ::AbstractVector{<:Integer} = Int[],
7474
name::String = "Generic",
75-
adbackend = ForwardDiffAD(),
75+
adbackend = ForwardDiffAD(F, x0),
7676
) where {T}
7777
nvar = length(x0)
7878
@lencheck nvar lvar uvar
@@ -102,7 +102,7 @@ function ADNLSModel(
102102
lin::AbstractVector{<:Integer} = Int[],
103103
linequ::AbstractVector{<:Integer} = Int[],
104104
name::String = "Generic",
105-
adbackend = ForwardDiffAD(),
105+
adbackend = ForwardDiffAD(F, c, x0, length(lcon)),
106106
) where {T}
107107
nvar = length(x0)
108108
ncon = length(lcon)
@@ -148,7 +148,7 @@ function ADNLSModel(
148148
lin::AbstractVector{<:Integer} = Int[],
149149
linequ::AbstractVector{<:Integer} = Int[],
150150
name::String = "Generic",
151-
adbackend = ForwardDiffAD(),
151+
adbackend = ForwardDiffAD(F, c, x0, length(lcon)),
152152
) where {T}
153153
nvar = length(x0)
154154
ncon = length(lcon)
@@ -327,19 +327,13 @@ function NLPModels.jac_structure!(
327327
cols::AbstractVector{<:Integer},
328328
)
329329
@lencheck nls.meta.nnzj rows cols
330-
m, n = nls.meta.ncon, nls.meta.nvar
331-
I = ((i, j) for i = 1:m, j = 1:n)
332-
rows .= getindex.(I, 1)[:]
333-
cols .= getindex.(I, 2)[:]
334-
return rows, cols
330+
return jac_structure!(nls.adbackend, nls, rows, cols)
335331
end
336332

337333
function NLPModels.jac_coord!(nls::ADNLSModel, x::AbstractVector, vals::AbstractVector)
338334
@lencheck nls.meta.nvar x
339335
@lencheck nls.meta.nnzj vals
340-
Jx = jacobian(nls.adbackend, nls.c, x)
341-
vals .= Jx[:]
342-
return vals
336+
return jac_coord!(nls.adbackend, nls, x, vals)
343337
end
344338

345339
function NLPModels.jprod!(nls::ADNLSModel, x::AbstractVector, v::AbstractVector, Jv::AbstractVector)
@@ -391,11 +385,7 @@ function NLPModels.hess_structure!(
391385
cols::AbstractVector{<:Integer},
392386
)
393387
@lencheck nls.meta.nnzh rows cols
394-
n = nls.meta.nvar
395-
I = ((i, j) for i = 1:n, j = 1:n if i j)
396-
rows .= getindex.(I, 1)
397-
cols .= getindex.(I, 2)
398-
return rows, cols
388+
return hess_structure!(nls.adbackend, nls, rows, cols)
399389
end
400390

401391
function NLPModels.hess_coord!(
@@ -408,15 +398,7 @@ function NLPModels.hess_coord!(
408398
@lencheck nls.meta.nnzh vals
409399
increment!(nls, :neval_hess)
410400
(x) = obj_weight * sum(nls.F(x) .^ 2) / 2
411-
Hx = hessian(nls.adbackend, ℓ, x)
412-
k = 1
413-
for j = 1:(nls.meta.nvar)
414-
for i = j:(nls.meta.nvar)
415-
vals[k] = Hx[i, j]
416-
k += 1
417-
end
418-
end
419-
return vals
401+
return hess_coord!(nls.adbackend, nls, x, ℓ, vals)
420402
end
421403

422404
function NLPModels.hess_coord!(
@@ -431,15 +413,7 @@ function NLPModels.hess_coord!(
431413
@lencheck nls.meta.nnzh vals
432414
increment!(nls, :neval_hess)
433415
(x) = obj_weight * sum(nls.F(x) .^ 2) / 2 + dot(y, nls.c(x))
434-
Hx = hessian(nls.adbackend, ℓ, x)
435-
k = 1
436-
for j = 1:(nls.meta.nvar)
437-
for i = j:(nls.meta.nvar)
438-
vals[k] = Hx[i, j]
439-
k += 1
440-
end
441-
end
442-
return vals
416+
return hess_coord!(nls.adbackend, nls, x, ℓ, vals)
443417
end
444418

445419
function NLPModels.hprod!(

0 commit comments

Comments
 (0)