|
1 | | -struct DIADGradient{B, E} <: ADBackend |
2 | | - backend::B |
3 | | - extras::E |
4 | | -end |
| 1 | +for (ADGradient, fbackend) in ((:EnzymeADGradient, :AutoEnzyme), |
| 2 | + (:ZygoteADGradient, :AutoZygote)) |
| 3 | + @eval begin |
5 | 4 |
|
6 | | -function DIADGradient( |
7 | | - nvar::Integer, |
8 | | - f, |
9 | | - ncon::Integer = 0, |
10 | | - c::Function = (args...) -> []; |
11 | | - x0::AbstractVector = rand(nvar), |
12 | | - backend::AbstractADType = AutoForwardDiff(), |
13 | | - kwargs..., |
14 | | -) |
15 | | - extras = DifferentiationInterface.prepare_gradient(f, backend, x0) |
16 | | - return DIADGradient(backend, extras) |
17 | | -end |
| 5 | + struct $ADGradient{B, E} <: ADBackend |
| 6 | + backend::B |
| 7 | + extras::E |
| 8 | + end |
18 | 9 |
|
19 | | -function gradient(b::DIADGradient, f, x) |
20 | | - g = DifferentiationInterface.gradient(f, b.backend, x, b.extras) |
21 | | - return g |
22 | | -end |
| 10 | + function $ADGradient( |
| 11 | + nvar::Integer, |
| 12 | + f, |
| 13 | + ncon::Integer = 0, |
| 14 | + c::Function = (args...) -> []; |
| 15 | + x0::AbstractVector = rand(nvar), |
| 16 | + kwargs..., |
| 17 | + ) |
| 18 | + backend = $fbackend() |
| 19 | + extras = DifferentiationInterface.prepare_gradient(f, backend, x0) |
| 20 | + return $ADGradient(backend, extras) |
| 21 | + end |
23 | 22 |
|
24 | | -function gradient!(b::DIADGradient, g, f, x) |
25 | | - DifferentiationInterface.gradient!(f, g, b.backend, x, b.extras) |
26 | | - return g |
27 | | -end |
| 23 | + function gradient(b::$ADGradient, f, x) |
| 24 | + g = DifferentiationInterface.gradient(f, b.extras, b.backend, x) |
| 25 | + return g |
| 26 | + end |
28 | 27 |
|
29 | | -struct DIADJprod{B, E} <: ADBackend |
30 | | - backend::B |
31 | | - extras::E |
32 | | -end |
| 28 | + function gradient!(b::$ADGradient, g, f, x) |
| 29 | + DifferentiationInterface.gradient!(f, g, b.extras, b.backend, x) |
| 30 | + return g |
| 31 | + end |
33 | 32 |
|
34 | | -function DIADJprod( |
35 | | - nvar::Integer, |
36 | | - f, |
37 | | - ncon::Integer = 0, |
38 | | - c::Function = (args...) -> []; |
39 | | - x0::AbstractVector = rand(nvar), |
40 | | - backend::AbstractADType = AutoForwardDiff(), |
41 | | - kwargs..., |
42 | | -) |
43 | | - dx = similar(x0, nvar) |
44 | | - extras = DifferentiationInterface.prepare_pushforward(f, backend, x0, dx) |
45 | | - return DIADJprod(backend, extras) |
| 33 | + end |
46 | 34 | end |
47 | 35 |
|
48 | | -function Jprod!(b::DIADJprod, Jv, f, x, v, ::Val) |
49 | | - DifferentiationInterface.pushforward!(f, Jv, b.backend, x, v, b.extras) |
50 | | - return Jv |
51 | | -end |
| 36 | +for (ADJprod, fbackend) in ((:EnzymeADJprod, :AutoEnzyme), |
| 37 | + (:ZygoteADJprod, :AutoZygote)) |
| 38 | + @eval begin |
| 39 | + |
| 40 | + struct $ADJprod{B, E} <: ADBackend |
| 41 | + backend::B |
| 42 | + extras::E |
| 43 | + end |
| 44 | + |
| 45 | + function $ADJprod( |
| 46 | + nvar::Integer, |
| 47 | + f, |
| 48 | + ncon::Integer = 0, |
| 49 | + c::Function = (args...) -> []; |
| 50 | + x0::AbstractVector = rand(nvar), |
| 51 | + kwargs..., |
| 52 | + ) |
| 53 | + backend = $fbackend() |
| 54 | + dx = similar(x0, nvar) |
| 55 | + extras = DifferentiationInterface.prepare_pushforward(f, backend, x0, dx) |
| 56 | + return $ADJprod(backend, extras) |
| 57 | + end |
52 | 58 |
|
53 | | -struct DIADJtprod{B, E} <: ADBackend |
54 | | - backend::B |
55 | | - extras::E |
| 59 | + function Jprod!(b::$ADJprod, Jv, f, x, v, ::Val) |
| 60 | + DifferentiationInterface.pushforward!(f, Jv, b.extras, b.backend, x, v) |
| 61 | + return Jv |
| 62 | + end |
| 63 | + |
| 64 | + end |
56 | 65 | end |
57 | 66 |
|
58 | | -function DIADJtprod( |
59 | | - nvar::Integer, |
60 | | - f, |
61 | | - ncon::Integer = 0, |
62 | | - c::Function = (args...) -> []; |
63 | | - x0::AbstractVector = rand(nvar), |
64 | | - backend::AbstractADType = AutoForwardDiff(), |
65 | | - kwargs..., |
66 | | -) |
67 | | - dy = similar(x0, ncon) |
68 | | - extras = DifferentiationInterface.prepare_pullback(f, backend, x0, dy) |
69 | | - return DIADJtprod(backend, extras) |
| 67 | +for (ADJtprod, fbackend) in ((:EnzymeADJtprod, :AutoEnzyme), |
| 68 | + (:ZygoteADJtprod, :AutoZygote)) |
| 69 | + @eval begin |
| 70 | + |
| 71 | + struct $ADJtprod{B, E} <: ADBackend |
| 72 | + backend::B |
| 73 | + extras::E |
| 74 | + end |
| 75 | + |
| 76 | + function $ADJtprod( |
| 77 | + nvar::Integer, |
| 78 | + f, |
| 79 | + ncon::Integer = 0, |
| 80 | + c::Function = (args...) -> []; |
| 81 | + x0::AbstractVector = rand(nvar), |
| 82 | + kwargs..., |
| 83 | + ) |
| 84 | + backend = $fbackend() |
| 85 | + dy = similar(x0, ncon) |
| 86 | + extras = DifferentiationInterface.prepare_pullback(f, backend, x0, dy) |
| 87 | + return $ADJtprod(backend, extras) |
| 88 | + end |
| 89 | + |
| 90 | + function Jtprod!(b::$ADJtprod, Jtv, f, x, v, ::Val) |
| 91 | + DifferentiationInterface.pullback!(f, Jtv, b.extras, b.backend, x, v) |
| 92 | + return Jtv |
| 93 | + end |
| 94 | + |
| 95 | + end |
70 | 96 | end |
71 | 97 |
|
72 | | -function Jtprod!(b::DIADJtprod, Jtv, f, x, v, ::Val) |
73 | | - DifferentiationInterface.pullback!(f, Jtv, b.backend, x, v, b.extras) |
74 | | - return Jtv |
| 98 | +for (ADJacobian, fbackend) in ((:EnzymeADJacobian, :AutoEnzyme), |
| 99 | + (:ZygoteADJacobian, :AutoZygote)) |
| 100 | + @eval begin |
| 101 | + |
| 102 | + struct $ADJacobian{B, E} <: ADBackend |
| 103 | + backend::B |
| 104 | + extras::E |
| 105 | + end |
| 106 | + |
| 107 | + function $ADJacobian( |
| 108 | + nvar::Integer, |
| 109 | + f, |
| 110 | + ncon::Integer = 0, |
| 111 | + c::Function = (args...) -> []; |
| 112 | + x0::AbstractVector = rand(nvar), |
| 113 | + kwargs..., |
| 114 | + ) |
| 115 | + backend = $fbackend() |
| 116 | + y = similar(x0, ncon) |
| 117 | + extras = DifferentiationInterface.prepare_jacobian(f, y, backend, x0) |
| 118 | + return $ADJacobian(backend, extras) |
| 119 | + end |
| 120 | + |
| 121 | + function jacobian(b::$ADJacobian, f, x) |
| 122 | + return DifferentiationInterface.jacobian(f, b.extras, b.backend, x) |
| 123 | + end |
| 124 | + |
| 125 | + end |
75 | 126 | end |
| 127 | + |
| 128 | +# for (ADHessian, fbackend) in ((:EnzymeADHessian, :AutoEnzyme), |
| 129 | +# (:ZygoteADHessian, :AutoZygote)) |
| 130 | +# @eval begin |
| 131 | +# |
| 132 | +# struct $ADHessian{B, E} <: ADBackend |
| 133 | +# backend::B |
| 134 | +# extras::E |
| 135 | +# end |
| 136 | +# |
| 137 | +# function $ADHessian( |
| 138 | +# nvar::Integer, |
| 139 | +# f, |
| 140 | +# ncon::Integer = 0, |
| 141 | +# c::Function = (args...) -> []; |
| 142 | +# x0::AbstractVector = rand(nvar), |
| 143 | +# kwargs..., |
| 144 | +# ) |
| 145 | +# # We don't support constraints yet! |
| 146 | +# ( c(x0) |> isempty ) || error("Constrained problems are not supported.") |
| 147 | +# backend = $fbackend() |
| 148 | +# extras = DifferentiationInterface.prepare_hessian(f, backend, x0) |
| 149 | +# return $ADHessian(backend, extras) |
| 150 | +# end |
| 151 | +# |
| 152 | +# function Hessian(b::$ADHessian, f, x) |
| 153 | +# return DifferentiationInterface.hessian(f, b.extras, b.backend, x) |
| 154 | +# end |
| 155 | +# |
| 156 | +# end |
| 157 | +# end |
| 158 | + |
| 159 | +# for (ADHvprod, fbackend) in ((:EnzymeADHvprod, :AutoEnzyme), |
| 160 | +# (:ZygoteADHvprod, :AutoZygote)) |
| 161 | +# @eval begin |
| 162 | +# |
| 163 | +# struct $ADHvprod{B, E} <: ADBackend |
| 164 | +# backend::B |
| 165 | +# extras::E |
| 166 | +# end |
| 167 | +# |
| 168 | +# function $ADHvprod( |
| 169 | +# nvar::Integer, |
| 170 | +# f, |
| 171 | +# ncon::Integer = 0, |
| 172 | +# c::Function = (args...) -> []; |
| 173 | +# x0::AbstractVector = rand(nvar), |
| 174 | +# kwargs..., |
| 175 | +# ) |
| 176 | +# backend = $fbackend() |
| 177 | +# ( c(x0) |> isempty ) || error("Constrained problems are not supported.") |
| 178 | +# backend = $fbackend() |
| 179 | +# extras = DifferentiationInterface.prepare_hvp(f, backend, x0, tx) |
| 180 | +# return $ADHprod(backend, extras) |
| 181 | +# end |
| 182 | +# |
| 183 | +# function Hvprod!(b::$ADHprod, Hv, f, x, v, ::Val) |
| 184 | +# DifferentiationInterface.hvp!(f, Hv, b.extras, b.backend, x, v) |
| 185 | +# return Hv |
| 186 | +# end |
| 187 | +# |
| 188 | +# end |
| 189 | +# end |
0 commit comments