Skip to content

Commit ad56716

Browse files
committed
Add more backends for Zygote and Enzyme
1 parent 419e586 commit ad56716

File tree

8 files changed

+202
-156
lines changed

8 files changed

+202
-156
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
1616

1717
[compat]
1818
ADTypes = "1.2.1"
19+
DifferentiationInterface = "0.6.0"
1920
ForwardDiff = "0.9.0, 0.10.0"
2021
NLPModels = "0.18, 0.19, 0.20, 0.21"
2122
Requires = "1"

docs/src/backend.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ The functions used internally to define the NLPModel API and the possible backen
1010
| Functions | ForwardDiff backends | ReverseDiff backends | Zygote backends | Enzyme backend | Sparse backend |
1111
| --------- | ------------------- | -------------------- | --------------- | -------------- | -------------- |
1212
| `gradient` and `gradient!` | `ForwardDiffADGradient`/`GenericForwardDiffADGradient` | `ReverseDiffADGradient`/`GenericReverseDiffADGradient` | `ZygoteADGradient` | `EnzymeADGradient` | -- |
13-
| `jacobian` | `ForwardDiffADJacobian` | `ReverseDiffADJacobian` | `ZygoteADJacobian` | -- | `SparseADJacobian` |
14-
| `hessian` | `ForwardDiffADHessian` | `ReverseDiffADHessian` | `ZygoteADHessian` | -- | `SparseADHessian`/`SparseReverseADHessian` |
15-
| `Jprod` | `ForwardDiffADJprod`/`GenericForwardDiffADJprod` | `ReverseDiffADJprod`/`GenericReverseDiffADJprod` | `ZygoteADJprod` | -- | -- |
16-
| `Jtprod` | `ForwardDiffADJtprod`/`GenericForwardDiffADJtprod` | `ReverseDiffADJtprod`/`GenericReverseDiffADJtprod` | `ZygoteADJtprod` | -- | -- |
13+
| `jacobian` | `ForwardDiffADJacobian` | `ReverseDiffADJacobian` | `ZygoteADJacobian` | `EnzymeADJacobian` | `SparseADJacobian` |
14+
| `hessian` | `ForwardDiffADHessian` | `ReverseDiffADHessian` | -- | -- | `SparseADHessian`/`SparseReverseADHessian` |
15+
| `Jprod` | `ForwardDiffADJprod`/`GenericForwardDiffADJprod` | `ReverseDiffADJprod`/`GenericReverseDiffADJprod` | `ZygoteADJprod` | `EnzymeADJprod` | -- |
16+
| `Jtprod` | `ForwardDiffADJtprod`/`GenericForwardDiffADJtprod` | `ReverseDiffADJtprod`/`GenericReverseDiffADJtprod` | `ZygoteADJtprod` | `EnzymeADJtprod` | -- |
1717
| `Hvprod` | `ForwardDiffADHvprod`/`GenericForwardDiffADHvprod` | `ReverseDiffADHvprod`/`GenericReverseDiffADHvprod` | -- | -- | -- |
1818
| `directional_second_derivative` | `ForwardDiffADGHjvprod` | -- | -- | -- | -- |
1919

src/ADNLPModels.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ include("sparse_hessian.jl")
2929
include("di.jl")
3030
include("forward.jl")
3131
include("reverse.jl")
32-
include("enzyme.jl")
33-
include("zygote.jl")
3432
include("predefined_backend.jl")
3533
include("nlp.jl")
3634

src/di.jl

Lines changed: 176 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,189 @@
1-
struct DIADGradient{B, E} <: ADBackend
2-
backend::B
3-
extras::E
4-
end
1+
for (ADGradient, fbackend) in ((:EnzymeADGradient, :AutoEnzyme),
2+
(:ZygoteADGradient, :AutoZygote))
3+
@eval begin
54

6-
function DIADGradient(
7-
nvar::Integer,
8-
f,
9-
ncon::Integer = 0,
10-
c::Function = (args...) -> [];
11-
x0::AbstractVector = rand(nvar),
12-
backend::AbstractADType = AutoForwardDiff(),
13-
kwargs...,
14-
)
15-
extras = DifferentiationInterface.prepare_gradient(f, backend, x0)
16-
return DIADGradient(backend, extras)
17-
end
5+
struct $ADGradient{B, E} <: ADBackend
6+
backend::B
7+
extras::E
8+
end
189

19-
function gradient(b::DIADGradient, f, x)
20-
g = DifferentiationInterface.gradient(f, b.backend, x, b.extras)
21-
return g
22-
end
10+
function $ADGradient(
11+
nvar::Integer,
12+
f,
13+
ncon::Integer = 0,
14+
c::Function = (args...) -> [];
15+
x0::AbstractVector = rand(nvar),
16+
kwargs...,
17+
)
18+
backend = $fbackend()
19+
extras = DifferentiationInterface.prepare_gradient(f, backend, x0)
20+
return $ADGradient(backend, extras)
21+
end
2322

24-
function gradient!(b::DIADGradient, g, f, x)
25-
DifferentiationInterface.gradient!(f, g, b.backend, x, b.extras)
26-
return g
27-
end
23+
function gradient(b::$ADGradient, f, x)
24+
g = DifferentiationInterface.gradient(f, b.extras, b.backend, x)
25+
return g
26+
end
2827

29-
struct DIADJprod{B, E} <: ADBackend
30-
backend::B
31-
extras::E
32-
end
28+
function gradient!(b::$ADGradient, g, f, x)
29+
DifferentiationInterface.gradient!(f, g, b.extras, b.backend, x)
30+
return g
31+
end
3332

34-
function DIADJprod(
35-
nvar::Integer,
36-
f,
37-
ncon::Integer = 0,
38-
c::Function = (args...) -> [];
39-
x0::AbstractVector = rand(nvar),
40-
backend::AbstractADType = AutoForwardDiff(),
41-
kwargs...,
42-
)
43-
dx = similar(x0, nvar)
44-
extras = DifferentiationInterface.prepare_pushforward(f, backend, x0, dx)
45-
return DIADJprod(backend, extras)
33+
end
4634
end
4735

48-
function Jprod!(b::DIADJprod, Jv, f, x, v, ::Val)
49-
DifferentiationInterface.pushforward!(f, Jv, b.backend, x, v, b.extras)
50-
return Jv
51-
end
36+
for (ADJprod, fbackend) in ((:EnzymeADJprod, :AutoEnzyme),
37+
(:ZygoteADJprod, :AutoZygote))
38+
@eval begin
39+
40+
struct $ADJprod{B, E} <: ADBackend
41+
backend::B
42+
extras::E
43+
end
44+
45+
function $ADJprod(
46+
nvar::Integer,
47+
f,
48+
ncon::Integer = 0,
49+
c::Function = (args...) -> [];
50+
x0::AbstractVector = rand(nvar),
51+
kwargs...,
52+
)
53+
backend = $fbackend()
54+
dx = similar(x0, nvar)
55+
extras = DifferentiationInterface.prepare_pushforward(f, backend, x0, dx)
56+
return $ADJprod(backend, extras)
57+
end
5258

53-
struct DIADJtprod{B, E} <: ADBackend
54-
backend::B
55-
extras::E
59+
function Jprod!(b::$ADJprod, Jv, f, x, v, ::Val)
60+
DifferentiationInterface.pushforward!(f, Jv, b.extras, b.backend, x, v)
61+
return Jv
62+
end
63+
64+
end
5665
end
5766

58-
function DIADJtprod(
59-
nvar::Integer,
60-
f,
61-
ncon::Integer = 0,
62-
c::Function = (args...) -> [];
63-
x0::AbstractVector = rand(nvar),
64-
backend::AbstractADType = AutoForwardDiff(),
65-
kwargs...,
66-
)
67-
dy = similar(x0, ncon)
68-
extras = DifferentiationInterface.prepare_pullback(f, backend, x0, dy)
69-
return DIADJtprod(backend, extras)
67+
for (ADJtprod, fbackend) in ((:EnzymeADJtprod, :AutoEnzyme),
68+
(:ZygoteADJtprod, :AutoZygote))
69+
@eval begin
70+
71+
struct $ADJtprod{B, E} <: ADBackend
72+
backend::B
73+
extras::E
74+
end
75+
76+
function $ADJtprod(
77+
nvar::Integer,
78+
f,
79+
ncon::Integer = 0,
80+
c::Function = (args...) -> [];
81+
x0::AbstractVector = rand(nvar),
82+
kwargs...,
83+
)
84+
backend = $fbackend()
85+
dy = similar(x0, ncon)
86+
extras = DifferentiationInterface.prepare_pullback(f, backend, x0, dy)
87+
return $ADJtprod(backend, extras)
88+
end
89+
90+
function Jtprod!(b::$ADJtprod, Jtv, f, x, v, ::Val)
91+
DifferentiationInterface.pullback!(f, Jtv, b.extras, b.backend, x, v)
92+
return Jtv
93+
end
94+
95+
end
7096
end
7197

72-
function Jtprod!(b::DIADJtprod, Jtv, f, x, v, ::Val)
73-
DifferentiationInterface.pullback!(f, Jtv, b.backend, x, v, b.extras)
74-
return Jtv
98+
for (ADJacobian, fbackend) in ((:EnzymeADJacobian, :AutoEnzyme),
99+
(:ZygoteADJacobian, :AutoZygote))
100+
@eval begin
101+
102+
struct $ADJacobian{B, E} <: ADBackend
103+
backend::B
104+
extras::E
105+
end
106+
107+
function $ADJacobian(
108+
nvar::Integer,
109+
f,
110+
ncon::Integer = 0,
111+
c::Function = (args...) -> [];
112+
x0::AbstractVector = rand(nvar),
113+
kwargs...,
114+
)
115+
backend = $fbackend()
116+
y = similar(x0, ncon)
117+
extras = DifferentiationInterface.prepare_jacobian(f, y, backend, x0)
118+
return $ADJacobian(backend, extras)
119+
end
120+
121+
function jacobian(b::$ADJacobian, f, x)
122+
return DifferentiationInterface.jacobian(f, b.extras, b.backend, x)
123+
end
124+
125+
end
75126
end
127+
128+
# for (ADHessian, fbackend) in ((:EnzymeADHessian, :AutoEnzyme),
129+
# (:ZygoteADHessian, :AutoZygote))
130+
# @eval begin
131+
#
132+
# struct $ADHessian{B, E} <: ADBackend
133+
# backend::B
134+
# extras::E
135+
# end
136+
#
137+
# function $ADHessian(
138+
# nvar::Integer,
139+
# f,
140+
# ncon::Integer = 0,
141+
# c::Function = (args...) -> [];
142+
# x0::AbstractVector = rand(nvar),
143+
# kwargs...,
144+
# )
145+
# # We don't support constraints yet!
146+
# ( c(x0) |> isempty ) || error("Constrained problems are not supported.")
147+
# backend = $fbackend()
148+
# extras = DifferentiationInterface.prepare_hessian(f, backend, x0)
149+
# return $ADHessian(backend, extras)
150+
# end
151+
#
152+
# function hessian(b::$ADHessian, f, x)
153+
# return DifferentiationInterface.hessian(f, b.extras, b.backend, x)
154+
# end
155+
#
156+
# end
157+
# end
158+
159+
# for (ADHvprod, fbackend) in ((:EnzymeADHvprod, :AutoEnzyme),
160+
# (:ZygoteADHvprod, :AutoZygote))
161+
# @eval begin
162+
#
163+
# struct $ADHvprod{B, E} <: ADBackend
164+
# backend::B
165+
# extras::E
166+
# end
167+
#
168+
# function $ADHvprod(
169+
# nvar::Integer,
170+
# f,
171+
# ncon::Integer = 0,
172+
# c::Function = (args...) -> [];
173+
# x0::AbstractVector = rand(nvar),
174+
# kwargs...,
175+
# )
176+
# backend = $fbackend()
177+
# ( c(x0) |> isempty ) || error("Constrained problems are not supported.")
178+
# backend = $fbackend()
179+
# extras = DifferentiationInterface.prepare_hvp(f, backend, x0, tx)
180+
# return $ADHvprod(backend, extras)
181+
# end
182+
#
183+
# function Hvprod!(b::$ADHvprod, Hv, f, x, v, ::Val)
184+
# DifferentiationInterface.hvp!(f, Hv, b.extras, b.backend, x, v)
185+
# return Hv
186+
# end
187+
#
188+
# end
189+
# end

src/predefined_backend.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
default_backend = Dict(
2-
:gradient_backend => DIADGradient,
2+
:gradient_backend => ForwardDiffADGradient,
33
:hprod_backend => ForwardDiffADHvprod,
4-
:jprod_backend => DIADJprod,
5-
:jtprod_backend => DIADJtprod,
4+
:jprod_backend => ForwardDiffADJprod,
5+
:jtprod_backend => ForwardDiffADJtprod,
66
:jacobian_backend => SparseADJacobian,
77
:hessian_backend => SparseADHessian,
88
:ghjvprod_backend => ForwardDiffADGHjvprod,

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1515

1616
[compat]
1717
CUDA = "4, 5"
18-
Enzyme = "0.10, 0.11, 0.12"
18+
Enzyme = "0.12"
1919
ForwardDiff = "0.10"
2020
ManualNLPModels = "0.1"
2121
NLPModels = "0.21"

test/runtests.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using SparseMatrixColorings
33
using ADNLPModels, ManualNLPModels, NLPModels, NLPModelsModifiers, NLPModelsTest
44
using ADNLPModels:
55
gradient, gradient!, jacobian, hessian, Jprod!, Jtprod!, directional_second_derivative, Hvprod!
6+
import Enzyme, Zygote
67

78
@testset "Test sparsity pattern of Jacobian and Hessian" begin
89
f(x) = sum(x .^ 2)
@@ -49,13 +50,27 @@ push!(
4950
:jtprod_backend => ADNLPModels.ZygoteADJtprod,
5051
:hprod_backend => ADNLPModels.ForwardDiffADHvprod,
5152
:jacobian_backend => ADNLPModels.ZygoteADJacobian,
52-
:hessian_backend => ADNLPModels.ZygoteADHessian,
53+
:hessian_backend => ADNLPModels.ForwardDiffADHessian,
5354
:ghjvprod_backend => ADNLPModels.ForwardDiffADGHjvprod,
5455
:jprod_residual_backend => ADNLPModels.ZygoteADJprod,
5556
:jtprod_residual_backend => ADNLPModels.ZygoteADJtprod,
5657
:hprod_residual_backend => ADNLPModels.ForwardDiffADHvprod,
5758
:jacobian_residual_backend => ADNLPModels.ZygoteADJacobian,
58-
:hessian_residual_backend => ADNLPModels.ZygoteADHessian,
59+
:hessian_residual_backend => ADNLPModels.ForwardDiffADHessian,
60+
),
61+
:enzyme_backend => Dict(
62+
:gradient_backend => ADNLPModels.EnzymeADGradient,
63+
:jprod_backend => ADNLPModels.EnzymeADJprod,
64+
:jtprod_backend => ADNLPModels.EnzymeADJtprod,
65+
:hprod_backend => ADNLPModels.ForwardDiffADHvprod,
66+
:jacobian_backend => ADNLPModels.EnzymeADJacobian,
67+
:hessian_backend => ADNLPModels.ForwardDiffADHessian,
68+
:ghjvprod_backend => ADNLPModels.ForwardDiffADGHjvprod,
69+
:jprod_residual_backend => ADNLPModels.EnzymeADJprod,
70+
:jtprod_residual_backend => ADNLPModels.EnzymeADJtprod,
71+
:hprod_residual_backend => ADNLPModels.ForwardDiffADHvprod,
72+
:jacobian_residual_backend => ADNLPModels.EnzymeADJacobian,
73+
:hessian_residual_backend => ADNLPModels.ForwardDiffADHessian,
5974
),
6075
)
6176

@@ -140,9 +155,6 @@ end
140155
# Test the argument error without loading the packages
141156
test_autodiff_backend_error()
142157

143-
# Automatically loads the code for Zygote with Requires
144-
import Zygote
145-
146158
include("nlp/basic.jl")
147159
include("nls/basic.jl")
148160
include("nlp/nlpmodelstest.jl")

0 commit comments

Comments
 (0)