
Commit 419e586

Add three backends based on DifferentiationInterface.jl
1 parent 41a615f

File tree: 4 files changed, 82 additions, 4 deletions

  Project.toml
  src/ADNLPModels.jl
  src/di.jl
  src/predefined_backend.jl

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ version = "0.8.7"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
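
For anyone reproducing this change in a local checkout, the new dependency can be added through Pkg, which resolves the UUID shown above automatically; a minimal sketch:

using Pkg
Pkg.add("DifferentiationInterface")   # records the dep in the active Project.toml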

src/ADNLPModels.jl

Lines changed: 3 additions & 1 deletion
@@ -4,8 +4,9 @@ module ADNLPModels
 using LinearAlgebra, SparseArrays
 
 # external
-using ADTypes: ADTypes, AbstractColoringAlgorithm, AbstractSparsityDetector
+using ADTypes: ADTypes, AbstractADType, AbstractColoringAlgorithm, AbstractSparsityDetector, AutoForwardDiff
 using SparseConnectivityTracer: TracerSparsityDetector
+import DifferentiationInterface
 using SparseMatrixColorings
 using ForwardDiff, ReverseDiff
 
@@ -25,6 +26,7 @@ include("sparsity_pattern.jl")
 include("sparse_jacobian.jl")
 include("sparse_hessian.jl")
 
+include("di.jl")
 include("forward.jl")
 include("reverse.jl")
 include("enzyme.jl")

src/di.jl

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+struct DIADGradient{B, E} <: ADBackend
+  backend::B
+  extras::E
+end
+
+function DIADGradient(
+  nvar::Integer,
+  f,
+  ncon::Integer = 0,
+  c::Function = (args...) -> [];
+  x0::AbstractVector = rand(nvar),
+  backend::AbstractADType = AutoForwardDiff(),
+  kwargs...,
+)
+  extras = DifferentiationInterface.prepare_gradient(f, backend, x0)
+  return DIADGradient(backend, extras)
+end
+
+function gradient(b::DIADGradient, f, x)
+  g = DifferentiationInterface.gradient(f, b.backend, x, b.extras)
+  return g
+end
+
+function gradient!(b::DIADGradient, g, f, x)
+  DifferentiationInterface.gradient!(f, g, b.backend, x, b.extras)
+  return g
+end
+
+struct DIADJprod{B, E} <: ADBackend
+  backend::B
+  extras::E
+end
+
+function DIADJprod(
+  nvar::Integer,
+  f,
+  ncon::Integer = 0,
+  c::Function = (args...) -> [];
+  x0::AbstractVector = rand(nvar),
+  backend::AbstractADType = AutoForwardDiff(),
+  kwargs...,
+)
+  dx = similar(x0, nvar)
+  extras = DifferentiationInterface.prepare_pushforward(f, backend, x0, dx)
+  return DIADJprod(backend, extras)
+end
+
+function Jprod!(b::DIADJprod, Jv, f, x, v, ::Val)
+  DifferentiationInterface.pushforward!(f, Jv, b.backend, x, v, b.extras)
+  return Jv
+end
+
+struct DIADJtprod{B, E} <: ADBackend
+  backend::B
+  extras::E
+end
+
+function DIADJtprod(
+  nvar::Integer,
+  f,
+  ncon::Integer = 0,
+  c::Function = (args...) -> [];
+  x0::AbstractVector = rand(nvar),
+  backend::AbstractADType = AutoForwardDiff(),
+  kwargs...,
+)
+  dy = similar(x0, ncon)
+  extras = DifferentiationInterface.prepare_pullback(f, backend, x0, dy)
+  return DIADJtprod(backend, extras)
+end
+
+function Jtprod!(b::DIADJtprod, Jtv, f, x, v, ::Val)
+  DifferentiationInterface.pullback!(f, Jtv, b.backend, x, v, b.extras)
+  return Jtv
+end
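
A rough usage sketch of the new gradient backend, calling the internal (unexported) constructor and gradient functions directly; f and x0 are illustrative placeholders, not part of the commit:

using ADNLPModels

f(x) = sum(x .^ 2)                     # sample objective
x0 = [1.0, 2.0, 3.0]

b = ADNLPModels.DIADGradient(3, f)     # prepares a DI gradient with AutoForwardDiff()
g = ADNLPModels.gradient(b, f, x0)     # [2.0, 4.0, 6.0]

gout = similar(x0)
ADNLPModels.gradient!(b, gout, f, x0)  # in-place variant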

src/predefined_backend.jl

Lines changed: 3 additions & 3 deletions
@@ -1,8 +1,8 @@
 default_backend = Dict(
-  :gradient_backend => ForwardDiffADGradient,
+  :gradient_backend => DIADGradient,
   :hprod_backend => ForwardDiffADHvprod,
-  :jprod_backend => ForwardDiffADJprod,
-  :jtprod_backend => ForwardDiffADJtprod,
+  :jprod_backend => DIADJprod,
+  :jtprod_backend => DIADJtprod,
   :jacobian_backend => SparseADJacobian,
   :hessian_backend => SparseADHessian,
   :ghjvprod_backend => ForwardDiffADGHjvprod,
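
With these defaults, a plain ADNLPModel now routes gradients and Jacobian-vector products through DifferentiationInterface without any extra arguments; a small sketch, where f and x0 are placeholders:

using ADNLPModels, NLPModels

f(x) = (x[1] - 1)^2 + 100 * (x[2] - x[1]^2)^2  # Rosenbrock
x0 = [-1.2, 1.0]

nlp = ADNLPModel(f, x0)  # gradient_backend defaults to DIADGradient
grad(nlp, x0)            # evaluated via DifferentiationInterface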
