Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit a7c5a89

Browse files
enzyme reverse mode in constraint jacobian
1 parent 4ba2a4c commit a7c5a89

File tree

1 file changed

+25
-26
lines changed

1 file changed

+25
-26
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
204204
end
205205

206206
if cons !== nothing && cons_j == true && f.cons_j === nothing
207-
if num_cons > length(x)
208-
seeds = Enzyme.onehot(x)
209-
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
210-
else
211-
seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
212-
Jaccache = Tuple(zero(x) for i in 1:num_cons)
213-
end
207+
# if num_cons > length(x)
208+
seeds = Enzyme.onehot(x)
209+
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
210+
# else
211+
# seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
212+
# Jaccache = Tuple(zero(x) for i in 1:num_cons)
213+
# end
214214

215215
y = zeros(eltype(x), num_cons)
216216

@@ -219,27 +219,26 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
219219
Enzyme.make_zero!(Jaccache[i])
220220
end
221221
Enzyme.make_zero!(y)
222-
if num_cons > length(θ)
223-
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
224-
BatchDuplicated(θ, seeds), Const(p))
225-
for i in eachindex(θ)
226-
if J isa Vector
227-
J[i] = Jaccache[i][1]
228-
else
229-
copyto!(@view(J[:, i]), Jaccache[i])
230-
end
231-
end
232-
else
233-
Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
234-
BatchDuplicated(θ, Jaccache), Const(p))
235-
for i in 1:num_cons
236-
if J isa Vector
237-
J .= Jaccache[1]
238-
else
239-
copyto!(@view(J[i, :]), Jaccache[i])
240-
end
222+
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
223+
BatchDuplicated(θ, seeds), Const(p))
224+
for i in eachindex(θ)
225+
if J isa Vector
226+
J[i] = Jaccache[i][1]
227+
else
228+
copyto!(@view(J[:, i]), Jaccache[i])
241229
end
242230
end
231+
# else
232+
# Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
233+
# BatchDuplicated(θ, Jaccache), Const(p))
234+
# for i in 1:num_cons
235+
# if J isa Vector
236+
# J .= Jaccache[1]
237+
# else
238+
# J[i, :] = Jaccache[i]
239+
# end
240+
# end
241+
# end
243242
end
244243
elseif cons_j == true && cons !== nothing
245244
cons_j! = (J, θ) -> f.cons_j(J, θ, p)

0 commit comments

Comments
 (0)