@@ -1008,72 +1008,52 @@ function init_optimization!(
10081008 Jfunc, gfunc = let estim= estim, model= model, nZ̃= nZ̃, nV̂= nV̂, nX̂= nX̂, ng= ng, nx̂= nx̂, nu= nu, nŷ= nŷ
10091009 Nc = nZ̃ + 3
10101010 last_Z̃tup_float, last_Z̃tup_dual = nothing , nothing
1011+ Z̃_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nZ̃), Nc)
10111012 V̂_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nV̂), Nc)
10121013 g_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, ng), Nc)
10131014 X̂_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nX̂), Nc)
10141015 x̄_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nx̂), Nc)
10151016 û_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nu), Nc)
10161017 ŷ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŷ), Nc)
1017- function Jfunc (Z̃tup:: JNT ... )
1018+ function Jfunc (Z̃tup:: T ... ):: T where {T <: Real }
10181019 Z̃1 = Z̃tup[begin ]
1019- V̂ = get_tmp (V̂_cache, Z̃1)
1020- Z̃ = collect (Z̃tup)
1021- if Z̃tup != = last_Z̃tup_float
1022- g = get_tmp (g_cache, Z̃1)
1023- X̂ = get_tmp (X̂_cache, Z̃1)
1024- û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1025- V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1026- g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1020+ if T == JNT
10271021 last_Z̃tup_float = Z̃tup
1028- end
1029- x̄ = get_tmp (x̄_cache, Z̃1)
1030- return obj_nonlinprog! (x̄, estim, model, V̂, Z̃)
1031- end
1032- function Jfunc (Z̃tup:: ForwardDiff.Dual... )
1033- Z̃1 = Z̃tup[begin ]
1034- V̂ = get_tmp (V̂_cache, Z̃1)
1035- Z̃ = collect (Z̃tup)
1036- if Z̃tup != = last_Z̃tup_dual
1037- g = get_tmp (g_cache, Z̃1)
1038- X̂ = get_tmp (X̂_cache, Z̃1)
1039- û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1040- V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1041- g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1022+ else
10421023 last_Z̃tup_dual = Z̃tup
10431024 end
1025+ Z̃, V̂ = get_tmp (Z̃_cache, Z̃1), get_tmp (V̂_cache, Z̃1)
1026+ X̂ = get_tmp (X̂_cache, Z̃1)
1027+ û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1028+ Z̃ .= Z̃tup
1029+ V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1030+ g = get_tmp (g_cache, Z̃1)
1031+ g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
10441032 x̄ = get_tmp (x̄_cache, Z̃1)
1045- return obj_nonlinprog! (x̄, estim, model, V̂, Z̃)
1033+ return obj_nonlinprog! (x̄, estim, model, V̂, Z̃):: T
10461034 end
1047- function gfunc_i (i, Z̃tup:: NTuple{N, JNT} ) where N
1035+ function gfunc_i (i, Z̃tup:: NTuple{N, T} ) :: T where {N, T <: Real }
10481036 Z̃1 = Z̃tup[begin ]
10491037 g = get_tmp (g_cache, Z̃1)
1050- if Z̃tup != = last_Z̃tup_float
1051- Z̃ = collect (Z̃tup)
1052- V̂ = get_tmp (V̂_cache, Z̃1)
1053- X̂ = get_tmp (X̂_cache, Z̃1)
1054- û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1055- V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
1056- g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1057- last_Z̃tup_float = Z̃tup
1038+ if T == JNT
1039+ isnewvalue = (Z̃tup != = last_Z̃tup_float)
1040+ isnewvalue && (last_Z̃tup_float = Z̃tup)
1041+ else
1042+ isnewvalue = (Z̃tup != = last_Z̃tup_dual)
1043+ isnewvalue && (last_Z̃tup_dual = Z̃tup)
10581044 end
1059- return g[i]
1060- end
1061- function gfunc_i (i, Z̃tup:: NTuple{N, ForwardDiff.Dual} ) where N
1062- Z̃1 = Z̃tup[begin ]
1063- g = get_tmp (g_cache, Z̃1)
1064- if Z̃tup != = last_Z̃tup_dual
1065- Z̃ = collect (Z̃tup)
1066- V̂ = get_tmp (V̂_cache, Z̃1)
1045+ if isnewvalue
1046+ Z̃, V̂ = get_tmp (Z̃_cache, Z̃1), get_tmp (V̂_cache, Z̃1)
10671047 X̂ = get_tmp (X̂_cache, Z̃1)
10681048 û, ŷ = get_tmp (û_cache, Z̃1), get_tmp (ŷ_cache, Z̃1)
1049+ Z̃ .= Z̃tup
10691050 V̂, X̂ = predict! (V̂, X̂, û, ŷ, estim, model, Z̃)
10701051 g = con_nonlinprog! (g, estim, model, X̂, V̂, Z̃)
1071- last_Z̃tup_dual = Z̃tup
10721052 end
10731053 return g[i]
10741054 end
10751055 gfunc = [(Z̃... ) -> gfunc_i (i, Z̃) for i in 1 : ng]
1076- Jfunc, gfunc
1056+ ( Jfunc, gfunc)
10771057 end
10781058 register (optim, :Jfunc , nZ̃, Jfunc, autodiff= true )
10791059 @NLobjective (optim, Min, Jfunc (Z̃var... ))
0 commit comments