Skip to content

Commit f783443

Browse files
authored
Disable gradient and Hessian backends for NLSModels (#357)
1 parent bfe3f7c commit f783443

File tree

3 files changed

+45
-40
lines changed

3 files changed

+45
-40
lines changed

src/ad.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,14 @@ function ADModelNLSBackend(
218218
backend::Symbol = :default,
219219
matrix_free::Bool = false,
220220
show_time::Bool = false,
221-
gradient_backend = get_default_backend(:gradient_backend, backend),
222-
hprod_backend = get_default_backend(:hprod_backend, backend),
223-
hessian_backend = get_default_backend(:hessian_backend, backend, matrix_free),
224-
hprod_residual_backend = get_default_backend(:hprod_residual_backend, backend),
221+
gradient_backend = EmptyADbackend(),
222+
hprod_backend = EmptyADbackend(),
223+
hessian_backend = EmptyADbackend(),
224+
hprod_residual_backend = EmptyADbackend(),
225225
jprod_residual_backend = get_default_backend(:jprod_residual_backend, backend),
226226
jtprod_residual_backend = get_default_backend(:jtprod_residual_backend, backend),
227227
jacobian_residual_backend = get_default_backend(:jacobian_residual_backend, backend, matrix_free),
228-
hessian_residual_backend = get_default_backend(:hessian_residual_backend, backend, matrix_free),
228+
hessian_residual_backend = EmptyADbackend(),
229229
kwargs...,
230230
)
231231
function F(x; nequ = nequ)
@@ -343,18 +343,18 @@ function ADModelNLSBackend(
343343
backend::Symbol = :default,
344344
matrix_free::Bool = false,
345345
show_time::Bool = false,
346-
gradient_backend = get_default_backend(:gradient_backend, backend),
347-
hprod_backend = get_default_backend(:hprod_backend, backend),
346+
gradient_backend = EmptyADbackend(),
347+
hprod_backend = EmptyADbackend(),
348348
jprod_backend = get_default_backend(:jprod_backend, backend),
349349
jtprod_backend = get_default_backend(:jtprod_backend, backend),
350350
jacobian_backend = get_default_backend(:jacobian_backend, backend, matrix_free),
351-
hessian_backend = get_default_backend(:hessian_backend, backend, matrix_free),
352-
ghjvprod_backend = get_default_backend(:ghjvprod_backend, backend),
353-
hprod_residual_backend = get_default_backend(:hprod_residual_backend, backend),
351+
hessian_backend = EmptyADbackend(),
352+
ghjvprod_backend = EmptyADbackend(),
353+
hprod_residual_backend = EmptyADbackend(),
354354
jprod_residual_backend = get_default_backend(:jprod_residual_backend, backend),
355355
jtprod_residual_backend = get_default_backend(:jtprod_residual_backend, backend),
356356
jacobian_residual_backend = get_default_backend(:jacobian_residual_backend, backend, matrix_free),
357-
hessian_residual_backend = get_default_backend(:hessian_residual_backend, backend, matrix_free),
357+
hessian_residual_backend = EmptyADbackend(),
358358
kwargs...,
359359
)
360360
function F(x; nequ = nequ)

test/manual.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,13 @@ function test_nlp_consistency(nlp, model; counters = true)
44
v = 2 * ones(nvar)
55
y = ones(ncon)
66

7-
@test grad(nlp, x) == grad(model, x)
8-
@test !counters || (neval_grad(model) == 2)
9-
@test hess_coord(nlp, x) == hess_coord(model, x)
10-
@test !counters || (neval_hess(model) == 2)
11-
@test hprod(nlp, x, v) == hprod(model, x, v)
12-
@test !counters || (neval_hprod(model) == 2)
7+
# TODO: only test the backends that are defined
138
if model.meta.nnln > 0
149
@test jac(nlp, x) == jac(model, x)
1510
@test !counters || (neval_jac_nln(model) == 2)
1611
@test jprod(nlp, x, v) == jprod(model, x, v)
1712
@test !counters || (neval_jprod_nln(model) == 2)
1813
@test jtprod(nlp, x, y) == jtprod(model, x, y)
19-
@test hess_coord(nlp, x, y) == hess_coord(model, x, y)
20-
@test !counters || (neval_hess(model) == 4)
21-
@test hprod(nlp, x, y, v) == hprod(model, x, y, v)
22-
@test !counters || (neval_hprod(model) == 4)
23-
@test ghjvprod(nlp, x, x, v) == ghjvprod(model, x, x, v)
24-
@test !counters || (neval_hprod(model) == 6)
25-
for j in model.meta.nln
26-
@test jth_hess(nlp, x, j) == jth_hess(model, x, j)
27-
@test jth_hprod(nlp, x, v, j) == jth_hprod(model, x, v, j)
28-
end
2914
end
3015

3116
if (nlp isa AbstractNLSModel) && (model isa AbstractNLSModel)
@@ -41,6 +26,25 @@ function test_nlp_consistency(nlp, model; counters = true)
4126
#for i=1:nequ
4227
# @test hprod_residual(nlp, x, i, v) == hprod_residual(model, x, i, v)
4328
#end
29+
else
30+
@test grad(nlp, x) == grad(model, x)
31+
@test !counters || (neval_grad(model) == 2)
32+
@test hess_coord(nlp, x) == hess_coord(model, x)
33+
@test !counters || (neval_hess(model) == 2)
34+
@test hprod(nlp, x, v) == hprod(model, x, v)
35+
@test !counters || (neval_hprod(model) == 2)
36+
if model.meta.nnln > 0
37+
@test hess_coord(nlp, x, y) == hess_coord(model, x, y)
38+
@test !counters || (neval_hess(model) == 4)
39+
@test hprod(nlp, x, y, v) == hprod(model, x, y, v)
40+
@test !counters || (neval_hprod(model) == 4)
41+
@test ghjvprod(nlp, x, x, v) == ghjvprod(model, x, x, v)
42+
@test !counters || (neval_hprod(model) == 6)
43+
for j in model.meta.nln
44+
@test jth_hess(nlp, x, j) == jth_hess(model, x, j)
45+
@test jth_hprod(nlp, x, v, j) == jth_hprod(model, x, v, j)
46+
end
47+
end
4448
end
4549
end
4650

@@ -250,18 +254,12 @@ end
250254
c!,
251255
lcon,
252256
ucon,
253-
gradient_backend = adbackend.gradient_backend,
254-
hprod_backend = adbackend.hprod_backend,
255-
hessian_backend = adbackend.hessian_backend,
256257
jprod_backend = adbackend.jprod_backend,
257258
jtprod_backend = adbackend.jtprod_backend,
258259
jacobian_backend = adbackend.jacobian_backend,
259-
ghjvprod_backend = adbackend.ghjvprod_backend,
260-
hprod_residual_backend = adbackend.hprod_residual_backend,
261260
jprod_residual_backend = adbackend.jprod_residual_backend,
262261
jtprod_residual_backend = adbackend.jtprod_residual_backend,
263262
jacobian_residual_backend = adbackend.jacobian_residual_backend,
264-
hessian_residual_backend = adbackend.hessian_residual_backend,
265263
)
266264
test_nlp_consistency(nlp, nlp; counters = false)
267265
end

test/nls/nlpmodelstest.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,20 @@ function nls_nlpmodelstest(backend)
1313
push!(nlss, eval(Meta.parse(spc))())
1414
end
1515

16-
exclude = if problem == "LLS"
17-
[hess_coord, hess]
18-
elseif problem == "MGH01"
19-
[hess_coord, hess, ghjvprod]
20-
else
21-
[]
22-
end
16+
# TODO: test backends that have been defined
17+
exclude = [
18+
grad,
19+
hess,
20+
hess_coord,
21+
hprod,
22+
jth_hess,
23+
jth_hess_coord,
24+
jth_hprod,
25+
ghjvprod,
26+
hess_residual,
27+
jth_hess_residual,
28+
hprod_residual,
29+
]
2330

2431
for nls in nlss
2532
show(IOBuffer(), nls)

0 commit comments

Comments
 (0)