@@ -43,7 +43,7 @@ function hv_f2_alloc(x, f, p)
4343 Enzyme. autodiff_deferred (Enzyme. Reverse,
4444 firstapply,
4545 Active,
46- f ,
46+ Const (f) ,
4747 Enzyme. Duplicated (x, dx),
4848 Const (p)
4949 )
@@ -105,7 +105,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
105105 )
106106 end
107107 elseif g == true
108- grad = (G, θ) -> f. grad (G, θ, p)
108+ grad = (G, θ, p = p ) -> f. grad (G, θ, p)
109109 else
110110 grad = nothing
111111 end
@@ -123,7 +123,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
123123 return y
124124 end
125125 elseif fg == true
126- fg! = (res, θ) -> f. fg (res, θ, p)
126+ fg! = (res, θ, p = p ) -> f. fg (res, θ, p)
127127 else
128128 fg! = nothing
129129 end
@@ -139,7 +139,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
139139 vdbθ = Tuple ((copy (r) for r in eachrow (f. hess_prototype)))
140140 end
141141
142- function hess (res, θ)
142+ function hess (res, θ, p = p )
143143 Enzyme. make_zero! (bθ)
144144 Enzyme. make_zero! .(vdbθ)
145145
@@ -156,13 +156,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
156156 end
157157 end
158158 elseif h == true
159- hess = (H, θ) -> f. hess (H, θ, p)
159+ hess = (H, θ, p = p ) -> f. hess (H, θ, p)
160160 else
161161 hess = nothing
162162 end
163163
164164 if fgh == true && f. fgh === nothing
165- function fgh! (G, H, θ)
165+ function fgh! (G, H, θ, p = p )
166166 vdθ = Tuple ((Array (r) for r in eachrow (I (length (θ)) * one (eltype (θ)))))
167167 vdbθ = Tuple (zeros (eltype (θ), length (θ)) for i in eachindex (θ))
168168
@@ -179,20 +179,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
179179 end
180180 end
181181 elseif fgh == true
182- fgh! = (G, H, θ) -> f. fgh (G, H, θ, p)
182+ fgh! = (G, H, θ, p = p ) -> f. fgh (G, H, θ, p)
183183 else
184184 fgh! = nothing
185185 end
186186
187187 if hv == true && f. hv === nothing
188- function hv! (H, θ, v)
188+ function hv! (H, θ, v, p = p )
189189 H .= Enzyme. autodiff (
190190 Enzyme. Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated (θ, v),
191- Const (_f), Const ( f. f), Const (p)
191+ Const (f. f), Const (p)
192192 )[1 ]
193193 end
194194 elseif hv == true
195- hv! = (H, θ, v) -> f. hv (H, θ, v, p)
195+ hv! = (H, θ, v, p = p ) -> f. hv (H, θ, v, p)
196196 else
197197 hv! = nothing
198198 end
@@ -247,7 +247,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
247247 cons_j! = nothing
248248 end
249249
250- if cons != = nothing && cons_vjp == true && f. cons_vjp == true
250+ if cons != = nothing && cons_vjp == true && f. cons_vjp === nothing
251251 cons_res = zeros (eltype (x), num_cons)
252252 function cons_vjp! (res, θ, v)
253253 Enzyme. make_zero! (res)
@@ -267,7 +267,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
267267 cons_vjp! = nothing
268268 end
269269
270- if cons != = nothing && cons_jvp == true && f. cons_jvp == true
270+ if cons != = nothing && cons_jvp == true && f. cons_jvp === nothing
271271 cons_res = zeros (eltype (x), num_cons)
272272
273273 function cons_jvp! (res, θ, v)
@@ -327,7 +327,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
327327 lag_vdbθ = Tuple ((copy (r) for r in eachrow (f. hess_prototype)))
328328 end
329329
330- function lag_h! (h, θ, σ, μ)
330+ function lag_h! (h, θ, σ, μ, p = p )
331331 Enzyme. make_zero! (lag_bθ)
332332 Enzyme. make_zero! .(lag_vdbθ)
333333
@@ -350,8 +350,30 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
350350 k += i
351351 end
352352 end
353+
354+ function lag_h! (H:: AbstractMatrix , θ, σ, μ, p = p)
355+ Enzyme. make_zero! (H)
356+ Enzyme. make_zero! (lag_bθ)
357+ Enzyme. make_zero! .(lag_vdbθ)
358+
359+ Enzyme. autodiff (Enzyme. Forward,
360+ lag_grad,
361+ Enzyme. BatchDuplicated (θ, lag_vdθ),
362+ Enzyme. BatchDuplicatedNoNeed (lag_bθ, lag_vdbθ),
363+ Const (lagrangian),
364+ Const (f. f),
365+ Const (f. cons),
366+ Const (p),
367+ Const (σ),
368+ Const (μ)
369+ )
370+
371+ for i in eachindex (θ)
372+ H[i, :] .= lag_vdbθ[i]
373+ end
374+ end
353375 elseif lag_h == true && cons != = nothing
354- lag_h! = (θ, σ, μ) -> f. lag_h (θ, σ, μ, p)
376+ lag_h! = (θ, σ, μ, p = p ) -> f. lag_h (θ, σ, μ, p)
355377 else
356378 lag_h! = nothing
357379 end
@@ -389,7 +411,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
389411 lag_h = false )
390412 if g == true && f. grad === nothing
391413 res = zeros (eltype (x), size (x))
392- function grad (θ)
414+ function grad (θ, p = p )
393415 Enzyme. make_zero! (res)
394416 Enzyme. autodiff (Enzyme. Reverse,
395417 Const (firstapply),
@@ -401,14 +423,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
401423 return res
402424 end
403425 elseif fg == true
404- grad = (θ) -> f. grad (θ, p)
426+ grad = (θ, p = p ) -> f. grad (θ, p)
405427 else
406428 grad = nothing
407429 end
408430
409431 if fg == true && f. fg === nothing
410432 res_fg = zeros (eltype (x), size (x))
411- function fg! (θ)
433+ function fg! (θ, p = p )
412434 Enzyme. make_zero! (res_fg)
413435 y = Enzyme. autodiff (Enzyme. ReverseWithPrimal,
414436 Const (firstapply),
@@ -420,7 +442,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
420442 return y, res
421443 end
422444 elseif fg == true
423- fg! = (θ) -> f. fg (θ, p)
445+ fg! = (θ, p = p ) -> f. fg (θ, p)
424446 else
425447 fg! = nothing
426448 end
@@ -430,7 +452,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
430452 bθ = zeros (eltype (x), length (x))
431453 vdbθ = Tuple (zeros (eltype (x), length (x)) for i in eachindex (x))
432454
433- function hess (θ)
455+ function hess (θ, p = p )
434456 Enzyme. make_zero! (bθ)
435457 Enzyme. make_zero! .(vdbθ)
436458
@@ -446,7 +468,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
446468 vcat, [reshape (vdbθ[i], (1 , length (vdbθ[i]))) for i in eachindex (θ)])
447469 end
448470 elseif h == true
449- hess = (θ) -> f. hess (θ, p)
471+ hess = (θ, p = p ) -> f. hess (θ, p)
450472 else
451473 hess = nothing
452474 end
@@ -457,7 +479,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
457479 G_fgh = zeros (eltype (x), length (x))
458480 H_fgh = zeros (eltype (x), length (x), length (x))
459481
460- function fgh! (θ)
482+ function fgh! (θ, p = p )
461483 Enzyme. make_zero! (G_fgh)
462484 Enzyme. make_zero! (H_fgh)
463485 Enzyme. make_zero! .(vdbθ_fgh)
@@ -476,20 +498,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
476498 return G_fgh, H_fgh
477499 end
478500 elseif fgh == true
479- fgh! = (θ) -> f. fgh (θ, p)
501+ fgh! = (θ, p = p ) -> f. fgh (θ, p)
480502 else
481503 fgh! = nothing
482504 end
483505
484506 if hv == true && f. hv === nothing
485- function hv! (θ, v)
507+ function hv! (θ, v, p = p )
486508 return Enzyme. autodiff (
487509 Enzyme. Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated (θ, v),
488510 Const (_f), Const (f. f), Const (p)
489511 )[1 ]
490512 end
491513 elseif hv == true
492- hv! = (θ, v) -> f. hv (θ, v, p)
514+ hv! = (θ, v, p = p ) -> f. hv (θ, v, p)
493515 else
494516 hv! = f. hv
495517 end
@@ -604,7 +626,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
604626 lag_vdbθ = Tuple ((copy (r) for r in eachrow (f. hess_prototype)))
605627 end
606628
607- function lag_h! (θ, σ, μ)
629+ function lag_h! (θ, σ, μ, p = p )
608630 Enzyme. make_zero! (lag_bθ)
609631 Enzyme. make_zero! .(lag_vdbθ)
610632
@@ -630,7 +652,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
630652 return res
631653 end
632654 elseif lag_h == true && cons != = nothing
633- lag_h! = (θ, σ, μ) -> f. lag_h (θ, σ, μ, p)
655+ lag_h! = (θ, σ, μ, p = p ) -> f. lag_h (θ, σ, μ, p)
634656 else
635657 lag_h! = nothing
636658 end
0 commit comments