@@ -204,13 +204,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
204204 end
205205
206206 if cons != = nothing && cons_j == true && f. cons_j === nothing
207- if num_cons > length (x)
208- seeds = Enzyme. onehot (x)
209- Jaccache = Tuple (zeros (eltype (x), num_cons) for i in 1 : length (x))
210- else
211- seeds = Enzyme. onehot (zeros (eltype (x), num_cons))
212- Jaccache = Tuple (zero (x) for i in 1 : num_cons)
213- end
207+ # if num_cons > length(x)
208+ seeds = Enzyme. onehot (x)
209+ Jaccache = Tuple (zeros (eltype (x), num_cons) for i in 1 : length (x))
210+ # else
211+ # seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
212+ # Jaccache = Tuple(zero(x) for i in 1:num_cons)
213+ # end
214214
215215 y = zeros (eltype (x), num_cons)
216216
@@ -219,27 +219,26 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
219219 Enzyme. make_zero! (Jaccache[i])
220220 end
221221 Enzyme. make_zero! (y)
222- if num_cons > length (θ)
223- Enzyme. autodiff (Enzyme. Forward, f. cons, BatchDuplicated (y, Jaccache),
224- BatchDuplicated (θ, seeds), Const (p))
225- for i in eachindex (θ)
226- if J isa Vector
227- J[i] = Jaccache[i][1 ]
228- else
229- copyto! (@view (J[:, i]), Jaccache[i])
230- end
231- end
232- else
233- Enzyme. autodiff (Enzyme. Reverse, f. cons, BatchDuplicated (y, seeds),
234- BatchDuplicated (θ, Jaccache), Const (p))
235- for i in 1 : num_cons
236- if J isa Vector
237- J .= Jaccache[1 ]
238- else
239- copyto! (@view (J[i, :]), Jaccache[i])
240- end
222+ Enzyme. autodiff (Enzyme. Forward, f. cons, BatchDuplicated (y, Jaccache),
223+ BatchDuplicated (θ, seeds), Const (p))
224+ for i in eachindex (θ)
225+ if J isa Vector
226+ J[i] = Jaccache[i][1 ]
227+ else
228+ copyto! (@view (J[:, i]), Jaccache[i])
241229 end
242230 end
231+ # else
232+ # Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
233+ # BatchDuplicated(θ, Jaccache), Const(p))
234+ # for i in 1:num_cons
235+ # if J isa Vector
236+ # J .= Jaccache[1]
237+ # else
238+ # J[i, :] = Jaccache[i]
239+ # end
240+ # end
241+ # end
243242 end
244243 elseif cons_j == true && cons != = nothing
245244 cons_j! = (J, θ) -> f. cons_j (J, θ, p)
0 commit comments