@@ -302,49 +302,55 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
302302 nu, ny, nx̂, Hp, ng = model. nu, model. ny, mpc. estim. nx̂, mpc. Hp, length (con. i_g)
303303 # inspired from https://jump.dev/JuMP.jl/stable/tutorials/nonlinear/tips_and_tricks/#User-defined-operators-with-vector-outputs
304304 Jfunc, gfunc = let mpc= mpc, model= model, ng= ng, nΔŨ= nΔŨ, nŶ= Hp* ny, nx̂= nx̂, nu= nu, nU= Hp* nu
305+ Nc = nΔŨ + 3
305306 last_ΔŨtup_float, last_ΔŨtup_dual = nothing , nothing
306- Ŷ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), nΔŨ + 3 )
307- g_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, ng), nΔŨ + 3 )
308- x̂_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nx̂), nΔŨ + 3 )
309- u_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nu), nΔŨ + 3 )
310- Ȳ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), nΔŨ + 3 )
311- Ū_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), nΔŨ + 3 )
307+ Ŷ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Nc)
308+ U_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Nc)
309+ g_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, ng), Nc)
310+ x̂_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nx̂), Nc)
311+ x̂next_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nx̂), Nc)
312+ u_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nu), Nc)
313+ Ȳ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Nc)
314+ Ū_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Nc)
312315 function Jfunc (ΔŨtup:: JNT... )
313316 ΔŨ1 = ΔŨtup[begin ]
314317 Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
315318 ΔŨ = collect (ΔŨtup)
316319 if ΔŨtup != = last_ΔŨtup_float
317- x̂, u = get_tmp (x̂_cache, ΔŨ1), get_tmp (u_cache, ΔŨ1)
318- Ŷ, x̂end = predict! (Ŷ, x̂, u, mpc, model, ΔŨ)
320+ x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
321+ u = get_tmp (u_cache, ΔŨ1)
322+ Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, mpc, model, ΔŨ)
319323 g = get_tmp (g_cache, ΔŨ1)
320324 g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
321325 last_ΔŨtup_float = ΔŨtup
322326 end
323- Ȳ, Ū = get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
324- return obj_nonlinprog! (Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
327+ U, Ȳ, Ū = get_tmp (U_cache, ΔŨ1), get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
328+ return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
325329 end
326330 function Jfunc (ΔŨtup:: ForwardDiff.Dual... )
327331 ΔŨ1 = ΔŨtup[begin ]
328332 Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
329333 ΔŨ = collect (ΔŨtup)
330334 if ΔŨtup != = last_ΔŨtup_dual
331- x̂, u = get_tmp (x̂_cache, ΔŨ1), get_tmp (u_cache, ΔŨ1)
332- Ŷ, x̂end = predict! (Ŷ, x̂, u, mpc, model, ΔŨ)
335+ x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
336+ u = get_tmp (u_cache, ΔŨ1)
337+ Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, mpc, model, ΔŨ)
333338 g = get_tmp (g_cache, ΔŨ1)
334339 g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
335340 last_ΔŨtup_dual = ΔŨtup
336341 end
337- Ȳ, Ū = get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
338- return obj_nonlinprog! (Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
342+ U, Ȳ, Ū = get_tmp (U_cache, ΔŨ1), get_tmp (Ȳ_cache, ΔŨ1), get_tmp (Ū_cache, ΔŨ1)
343+ return obj_nonlinprog! (U, Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
339344 end
340345 function gfunc_i (i, ΔŨtup:: NTuple{N, JNT} ) where N
341346 ΔŨ1 = ΔŨtup[begin ]
342347 g = get_tmp (g_cache, ΔŨ1)
343348 if ΔŨtup != = last_ΔŨtup_float
344- x̂, u = get_tmp (x̂_cache, ΔŨ1), get_tmp (u_cache, ΔŨ1)
345349 Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
346350 ΔŨ = collect (ΔŨtup)
347- Ŷ, x̂end = predict! (Ŷ, x̂, u, mpc, model, ΔŨ)
351+ x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
352+ u = get_tmp (u_cache, ΔŨ1)
353+ Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, mpc, model, ΔŨ)
348354 g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
349355 last_ΔŨtup_float = ΔŨtup
350356 end
@@ -354,10 +360,11 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
354360 ΔŨ1 = ΔŨtup[begin ]
355361 g = get_tmp (g_cache, ΔŨ1)
356362 if ΔŨtup != = last_ΔŨtup_dual
357- x̂, u = get_tmp (x̂_cache, ΔŨ1), get_tmp (u_cache, ΔŨ1)
358363 Ŷ = get_tmp (Ŷ_cache, ΔŨ1)
359364 ΔŨ = collect (ΔŨtup)
360- Ŷ, x̂end = predict! (Ŷ, x̂, u, mpc, model, ΔŨ)
365+ x̂, x̂next = get_tmp (x̂_cache, ΔŨ1), get_tmp (x̂next_cache, ΔŨ1)
366+ u = get_tmp (u_cache, ΔŨ1)
367+ Ŷ, x̂end = predict! (Ŷ, x̂, x̂next, u, mpc, model, ΔŨ)
361368 g = con_nonlinprog! (g, mpc, model, x̂end , Ŷ, ΔŨ)
362369 last_ΔŨtup_dual = ΔŨtup
363370 end
0 commit comments