# ============================================================================
# NOTE(review): this region is a whitespace-mangled `git format-patch` email
# ("batch_* API & dumb VectorBatchNLPModel") reconstructed into source form.
# The patch additionally appends to src/NLPModels.jl:
#     include("nlp/batch/api.jl")
#     include("nlp/batch/vector.jl")
# ----------------------------------------------------------------------------
# src/nlp/batch/api.jl
# ============================================================================

# A "batch" of inputs: one vector (point, direction, multipliers, ...) per model.
const VV = Vector{<:AbstractVector}

export AbstractBatchNLPModel
export batch_obj, batch_grad, batch_grad!, batch_objgrad, batch_objgrad!, batch_objcons, batch_objcons!
export batch_cons, batch_cons!, batch_cons_lin, batch_cons_lin!, batch_cons_nln, batch_cons_nln!
export batch_jth_con, batch_jth_congrad, batch_jth_congrad!, batch_jth_sparse_congrad
export batch_jac_structure!, batch_jac_structure, batch_jac_coord!, batch_jac_coord
export batch_jac, batch_jprod, batch_jprod!, batch_jtprod, batch_jtprod!, batch_jac_op, batch_jac_op!
export batch_jac_lin_structure!, batch_jac_lin_structure, batch_jac_lin_coord!, batch_jac_lin_coord
export batch_jac_lin, batch_jprod_lin, batch_jprod_lin!, batch_jtprod_lin, batch_jtprod_lin!, batch_jac_lin_op, batch_jac_lin_op!
export batch_jac_nln_structure!, batch_jac_nln_structure, batch_jac_nln_coord!, batch_jac_nln_coord
export batch_jac_nln, batch_jprod_nln, batch_jprod_nln!, batch_jtprod_nln, batch_jtprod_nln!, batch_jac_nln_op, batch_jac_nln_op!
export batch_jth_hess_coord, batch_jth_hess_coord!, batch_jth_hess
export batch_jth_hprod, batch_jth_hprod!, batch_ghjvprod, batch_ghjvprod!
export batch_hess_structure!, batch_hess_structure, batch_hess_coord!, batch_hess_coord
export batch_hess, batch_hprod, batch_hprod!, batch_hess_op, batch_hess_op!
export batch_varscale, batch_lagscale, batch_conscale

"""
    AbstractBatchNLPModel{T, S}

Abstract parent type for a batch of NLP models evaluated together.
`T` and `S` mirror the element and storage types of `AbstractNLPModel{T, S}`.
Concrete subtypes are indexable/iterable collections of models and provide
the `_batch_map*` mapping helpers used by the `batch_*` API below.
"""
abstract type AbstractBatchNLPModel{T, S} end

## base api
# Sparsity structures are assumed identical across the batch, so they are
# queried on the first model only — presumably enforced by the batch
# constructor (TODO in vector.jl); verify for new batch types.
batch_jac_structure(bnlp::AbstractBatchNLPModel) = jac_structure(first(bnlp))
batch_jac_lin_structure(bnlp::AbstractBatchNLPModel) = jac_lin_structure(first(bnlp))
batch_jac_nln_structure(bnlp::AbstractBatchNLPModel) = jac_nln_structure(first(bnlp))
batch_hess_structure(bnlp::AbstractBatchNLPModel) = hess_structure(first(bnlp))
batch_jac_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = jac_structure!(first(bnlp), rows, cols)
batch_jac_lin_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = jac_lin_structure!(first(bnlp), rows, cols)
batch_jac_nln_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = jac_nln_structure!(first(bnlp), rows, cols)
batch_hess_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = hess_structure!(first(bnlp), rows, cols)

# Out-of-place evaluations: apply the single-model function to each
# (model, inputs...) pair and collect one result per model.
batch_obj(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(obj, bnlp, xs)
batch_grad(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(grad, bnlp, xs)
batch_cons(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(cons, bnlp, xs)
batch_cons_lin(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(cons_lin, bnlp, xs)
batch_cons_nln(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(cons_nln, bnlp, xs)
batch_jac(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(jac, bnlp, xs)
# Linear pieces are constant, hence no `xs` argument.
batch_jac_lin(bnlp::AbstractBatchNLPModel) = _batch_map(jac_lin, bnlp)
batch_jac_nln(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(jac_nln, bnlp, xs)
batch_jac_lin_coord(bnlp::AbstractBatchNLPModel) = _batch_map(jac_lin_coord, bnlp)
batch_jac_coord(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(jac_coord, bnlp, xs)
batch_jac_nln_coord(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(jac_nln_coord, bnlp, xs)
batch_varscale(bnlp::AbstractBatchNLPModel) = _batch_map(varscale, bnlp)
batch_lagscale(bnlp::AbstractBatchNLPModel) = _batch_map(lagscale, bnlp)
batch_conscale(bnlp::AbstractBatchNLPModel) = _batch_map(conscale, bnlp)
batch_jprod(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV) = _batch_map(jprod, bnlp, xs, vs)
batch_jtprod(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV) = _batch_map(jtprod, bnlp, xs, vs)
batch_jprod_nln(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV) = _batch_map(jprod_nln, bnlp, xs, vs)
batch_jtprod_nln(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV) = _batch_map(jtprod_nln, bnlp, xs, vs)
batch_jprod_lin(bnlp::AbstractBatchNLPModel, vs::VV) = _batch_map(jprod_lin, bnlp, vs)
batch_jtprod_lin(bnlp::AbstractBatchNLPModel, vs::VV) = _batch_map(jtprod_lin, bnlp, vs)
batch_ghjvprod(bnlp::AbstractBatchNLPModel, xs::VV, gs::VV, vs::VV) = _batch_map(ghjvprod, bnlp, xs, gs, vs)

# In-place evaluations: `_batch_map!` expects the preallocated output right
# after the model, so each closure reorders arguments into the single-model
# NLPModels convention (output last).
batch_grad!(bnlp::AbstractBatchNLPModel, xs::VV, gs::Vector) =
  _batch_map!((m, g, x) -> grad!(m, x, g), bnlp, gs, xs)
batch_cons!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) =
  _batch_map!((m, c, x) -> cons!(m, x, c), bnlp, cs, xs)
batch_cons_lin!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) =
  _batch_map!((m, c, x) -> cons_lin!(m, x, c), bnlp, cs, xs)
batch_cons_nln!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) =
  _batch_map!((m, c, x) -> cons_nln!(m, x, c), bnlp, cs, xs)
batch_jac_lin_coord!(bnlp::AbstractBatchNLPModel, valss::Vector) =
  _batch_map!((m, vals) -> jac_lin_coord!(m, vals), bnlp, valss)
batch_jac_coord!(bnlp::AbstractBatchNLPModel, xs::VV, valss::Vector) =
  _batch_map!((m, vals, x) -> jac_coord!(m, x, vals), bnlp, valss, xs)
batch_jac_nln_coord!(bnlp::AbstractBatchNLPModel, xs::VV, valss::Vector) =
  _batch_map!((m, vals, x) -> jac_nln_coord!(m, x, vals), bnlp, valss, xs)
batch_jprod!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, Jvs::Vector) =
  _batch_map!((m, Jv, x, v) -> jprod!(m, x, v, Jv), bnlp, Jvs, xs, vs)
batch_jtprod!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, Jtvs::Vector) =
  _batch_map!((m, Jtv, x, v) -> jtprod!(m, x, v, Jtv), bnlp, Jtvs, xs, vs)
batch_jprod_nln!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, Jvs::Vector) =
  _batch_map!((m, Jv, x, v) -> jprod_nln!(m, x, v, Jv), bnlp, Jvs, xs, vs)
batch_jtprod_nln!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, Jtvs::Vector) =
  _batch_map!((m, Jtv, x, v) -> jtprod_nln!(m, x, v, Jtv), bnlp, Jtvs, xs, vs)
batch_jprod_lin!(bnlp::AbstractBatchNLPModel, vs::VV, Jvs::Vector) =
  _batch_map!((m, Jv, v) -> jprod_lin!(m, v, Jv), bnlp, Jvs, vs)
batch_jtprod_lin!(bnlp::AbstractBatchNLPModel, vs::VV, Jtvs::Vector) =
  _batch_map!((m, Jtv, v) -> jtprod_lin!(m, v, Jtv), bnlp, Jtvs, vs)
batch_ghjvprod!(bnlp::AbstractBatchNLPModel, xs::VV, gs::VV, vs::VV, gHvs::Vector) =
  _batch_map!((m, gHv, x, g, v) -> ghjvprod!(m, x, g, v, gHv), bnlp, gHvs, xs, gs, vs)

## jth
# Constraint index `j` is a scalar shared by the whole batch; it is captured
# in the closure rather than batched.
batch_jth_con(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) =
  _batch_map((m, x) -> jth_con(m, x, j), bnlp, xs)
batch_jth_congrad(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) =
  _batch_map((m, x) -> jth_congrad(m, x, j), bnlp, xs)
batch_jth_sparse_congrad(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) =
  _batch_map((m, x) -> jth_sparse_congrad(m, x, j), bnlp, xs)
batch_jth_hess_coord(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) =
  _batch_map((m, x) -> jth_hess_coord(m, x, j), bnlp, xs)
batch_jth_hess(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) =
  _batch_map((m, x) -> jth_hess(m, x, j), bnlp, xs)
batch_jth_hprod(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, j::Integer) =
  _batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs)

batch_jth_congrad!(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer, outputs::Vector) =
  _batch_map!((m, out, x) -> jth_congrad!(m, x, j, out), bnlp, outputs, xs)
batch_jth_hess_coord!(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer, outputs::Vector) =
  _batch_map!((m, out, x) -> jth_hess_coord!(m, x, j, out), bnlp, outputs, xs)
batch_jth_hprod!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, j::Integer, outputs::Vector) =
  _batch_map!((m, out, x, v) -> jth_hprod!(m, x, v, j, out), bnlp, outputs, xs, vs)

# hess family (need to treat obj_weight): each model may carry its own
# objective weight, so `obj_weights` (default: all ones) is forwarded
# per-model through the dedicated `_batch_map_weight(!)` helpers.
batch_hprod(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, vs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x, v; obj_weight) -> hprod(m, x, v; obj_weight = obj_weight), bnlp, obj_weights, xs, vs)
batch_hprod(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV, vs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x, y, v; obj_weight) -> hprod(m, x, y, v; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, vs)
batch_hess_coord(bnlp::AbstractBatchNLPModel{T, S}, xs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x; obj_weight) -> hess_coord(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs)
batch_hess_coord(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x, y; obj_weight) -> hess_coord(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys)
batch_hess_op(bnlp::AbstractBatchNLPModel{T, S}, xs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x; obj_weight) -> hess_op(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs)
batch_hess_op(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x, y; obj_weight) -> hess_op(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys)

batch_hprod!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, vs::VV, outputs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight!((m, Hv, x, v; obj_weight) -> hprod!(m, x, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, vs)
batch_hprod!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV, vs::VV, outputs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight!((m, Hv, x, y, v; obj_weight) -> hprod!(m, x, y, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys, vs)
batch_hess_coord!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, outputs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight!((m, vals, x; obj_weight) -> hess_coord!(m, x, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs)
batch_hess_coord!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV, outputs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight!((m, vals, x, y; obj_weight) -> hess_coord!(m, x, y, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys)
# hess_op! takes the preallocated `Hv` buffer as a trailing *positional*
# argument and returns an operator, so it maps with the non-mutating helper.
batch_hess_op!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, Hvs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x, Hv; obj_weight) -> hess_op!(m, x, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, Hvs)
batch_hess_op!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV, Hvs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x, y, Hv; obj_weight) -> hess_op!(m, x, y, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, Hvs)

batch_hess(bnlp::AbstractBatchNLPModel{T, S}, xs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x; obj_weight) -> hess(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs)
batch_hess(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} =
  _batch_map_weight((m, x, y; obj_weight) -> hess(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys)

## operators
batch_jac_op(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(jac_op, bnlp, xs)
batch_jac_lin_op(bnlp::AbstractBatchNLPModel) = _batch_map(jac_lin_op, bnlp)
batch_jac_nln_op(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map(jac_nln_op, bnlp, xs)

# The operator `!` variants take preallocated Jv/Jtv buffers positionally and
# *return* operators, so they also use the non-mutating mapper.
batch_jac_op!(bnlp::AbstractBatchNLPModel, xs::VV, Jvs::Vector, Jtvs::Vector) =
  _batch_map((m, x, Jv, Jtv) -> jac_op!(m, x, Jv, Jtv), bnlp, xs, Jvs, Jtvs)
batch_jac_lin_op!(bnlp::AbstractBatchNLPModel, Jvs::Vector, Jtvs::Vector) =
  _batch_map((m, Jv, Jtv) -> jac_lin_op!(m, Jv, Jtv), bnlp, Jvs, Jtvs)
batch_jac_nln_op!(bnlp::AbstractBatchNLPModel, xs::VV, Jvs::Vector, Jtvs::Vector) =
  _batch_map((m, x, Jv, Jtv) -> jac_nln_op!(m, x, Jv, Jtv), bnlp, xs, Jvs, Jtvs)

## tuple functions — results split into two parallel vectors, e.g.
## batch_objgrad -> (objective values, gradients)
batch_objgrad(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map_tuple(objgrad, bnlp, xs)
batch_objcons(bnlp::AbstractBatchNLPModel, xs::VV) = _batch_map_tuple(objcons, bnlp, xs)

batch_objgrad!(bnlp::AbstractBatchNLPModel, xs::VV, gs::Vector) = _batch_map_tuple!(objgrad!, bnlp, gs, xs)
batch_objcons!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) = _batch_map_tuple!(objcons!, bnlp, cs, xs)

# ============================================================================
# src/nlp/batch/vector.jl — "dumb" batch model: a plain vector of models
# ============================================================================
export VectorBatchNLPModel

"""
    VectorBatchNLPModel(models)

Naive `AbstractBatchNLPModel` that stores a `Vector` of `AbstractNLPModel`s
and evaluates them one by one in a loop.  The shared `meta` is taken from the
first model; all models are assumed to have identical dimensions and sparsity
structures.

Throws an error when `models` is empty.
"""
struct VectorBatchNLPModel{T, S, M <: AbstractNLPModel{T, S}} <: AbstractBatchNLPModel{T, S}
  models::Vector{M}
  meta::NLPModelMeta{T, S}
end

function VectorBatchNLPModel(models::Vector{M}) where {M <: AbstractNLPModel}
  isempty(models) && error("Cannot create VectorBatchNLPModel from empty vector")
  # TODO:
check all metas the same, all structures same, etc. + meta = first(models).meta + VectorBatchNLPModel{eltype(meta.x0), typeof(meta.x0), M}(models, meta) +end +Base.length(vnlp::VectorBatchNLPModel) = length(vnlp.models) +Base.getindex(vnlp::VectorBatchNLPModel, i::Integer) = vnlp.models[i] +Base.iterate(vnlp::VectorBatchNLPModel, state::Integer = 1) = iterate(vnlp.models, state) + +function _batch_map(f, bnlp::VectorBatchNLPModel, xs::VV...) + n = length(bnlp) + results = Vector{Any}(undef, n) + for i = 1:n + args_i = (x[i] for x in xs) + results[i] = f(bnlp[i], args_i...) + end + return results +end + +function _batch_map!(f, bnlp::VectorBatchNLPModel, outputs::Vector, xs::VV...) + n = length(bnlp) + for i = 1:n + args_i = (x[i] for x in xs) + f(bnlp[i], outputs[i], args_i...) + end + return outputs +end + +function _batch_map_weight(f, bnlp::VectorBatchNLPModel, obj_weights::Vector, xs::VV...) + n = length(bnlp) + results = Vector{Any}(undef, n) + for i = 1:n + args_i = (x[i] for x in xs) + results[i] = f(bnlp[i], args_i...; obj_weight = obj_weights[i]) + end + return results +end + +function _batch_map_weight!( + f, + bnlp::VectorBatchNLPModel, + outputs::Vector, + obj_weights::Vector, + xs::VV..., +) + n = length(bnlp) + for i = 1:n + args_i = (x[i] for x in xs) + f(bnlp[i], outputs[i], args_i...; obj_weight = obj_weights[i]) + end + return outputs +end + +function _batch_map_tuple(f, bnlp::VectorBatchNLPModel, xs::VV...) + n = length(bnlp) + results = _batch_map(f, bnlp, xs...) + # Get types from first result + first_result = results[1] + T1 = typeof(first_result[1]) + T2 = typeof(first_result[2]) + vec1 = Vector{T1}(undef, n) + vec2 = Vector{T2}(undef, n) + for i = 1:n + vec1[i], vec2[i] = results[i] + end + return vec1, vec2 +end + +function _batch_map_tuple!(f, bnlp::VectorBatchNLPModel, outputs::Vector, xs::VV...) 
+ n = length(bnlp) + firsts = Vector{eltype(bnlp.meta.x0)}(undef, n) + for i = 1:n + args_i = (x[i] for x in xs) + firsts[i], _ = f(bnlp[i], args_i..., outputs[i]) + end + return firsts, outputs +end \ No newline at end of file diff --git a/test/nlp/batch_api.jl b/test/nlp/batch_api.jl new file mode 100644 index 00000000..4043dd4b --- /dev/null +++ b/test/nlp/batch_api.jl @@ -0,0 +1,567 @@ +@testset "Batch API" begin + # Generate models + n_models = 5 + models = [SimpleNLPModel() for _ = 1:n_models] + n, m = models[1].meta.nvar, models[1].meta.ncon + xs = [randn(n) for _ = 1:n_models] + ys = [randn(m) for _ = 1:n_models] + vs = [randn(n) for _ = 1:n_models] + ws = [randn(m) for _ = 1:n_models] + gs = [zeros(n) for _ = 1:n_models] + cs = [zeros(m) for _ = 1:n_models] + obj_weights = rand(n_models) + for batch_model in [VectorBatchNLPModel] + @testset "$batch_model consistency" begin + bnlp = batch_model(models) + + # Test batch_obj + batch_fs = batch_obj(bnlp, xs) + manual_fs = [obj(models[i], xs[i]) for i = 1:n_models] + @test batch_fs ≈ manual_fs + + # Test batch_grad + batch_gs = batch_grad(bnlp, xs) + manual_gs = [grad(models[i], xs[i]) for i = 1:n_models] + @test batch_gs ≈ manual_gs + + # Test batch_grad! + batch_grad!(bnlp, xs, gs) + manual_gs = [grad!(models[i], xs[i], zeros(n)) for i = 1:n_models] + @test gs ≈ manual_gs + + # Test batch_objgrad + batch_fs, batch_gs = batch_objgrad(bnlp, xs) + manual_fs = [obj(models[i], xs[i]) for i = 1:n_models] + manual_gs = [grad(models[i], xs[i]) for i = 1:n_models] + @test batch_fs ≈ manual_fs + @test batch_gs ≈ manual_gs + + # Test batch_objgrad! 
+ batch_fs, batch_gs = batch_objgrad!(bnlp, xs, gs) + manual_fs = [obj(models[i], xs[i]) for i = 1:n_models] + manual_gs = [grad!(models[i], xs[i], zeros(n)) for i = 1:n_models] + @test batch_fs ≈ manual_fs + @test batch_gs ≈ manual_gs + + # Test batch_cons + batch_cs = batch_cons(bnlp, xs) + manual_cs = [cons(models[i], xs[i]) for i = 1:n_models] + @test batch_cs ≈ manual_cs + + # Test batch_cons! + batch_cons!(bnlp, xs, cs) + manual_cs = [cons!(models[i], xs[i], zeros(m)) for i = 1:n_models] + @test cs ≈ manual_cs + + # Test batch_cons_lin + batch_cs_lin = batch_cons_lin(bnlp, xs) + manual_cs_lin = [cons_lin(models[i], xs[i]) for i = 1:n_models] + @test batch_cs_lin ≈ manual_cs_lin + + # Test batch_cons_lin! + cs_lin = [zeros(bnlp.meta.nlin) for _ = 1:n_models] + batch_cons_lin!(bnlp, xs, cs_lin) + manual_cs_lin = [cons_lin!(models[i], xs[i], zeros(bnlp.meta.nlin)) for i = 1:n_models] + @test cs_lin ≈ manual_cs_lin + + # Test batch_cons_nln + batch_cs_nln = batch_cons_nln(bnlp, xs) + manual_cs_nln = [cons_nln(models[i], xs[i]) for i = 1:n_models] + @test batch_cs_nln ≈ manual_cs_nln + + # Test batch_cons_nln! + cs_nln = [zeros(bnlp.meta.nnln) for _ = 1:n_models] + batch_cons_nln!(bnlp, xs, cs_nln) + manual_cs_nln = [cons_nln!(models[i], xs[i], zeros(bnlp.meta.nnln)) for i = 1:n_models] + @test cs_nln ≈ manual_cs_nln + + # Test batch_objcons + batch_fs, batch_cs = batch_objcons(bnlp, xs) + manual_fs = [obj(models[i], xs[i]) for i = 1:n_models] + manual_cs = [cons(models[i], xs[i]) for i = 1:n_models] + @test batch_fs ≈ manual_fs + @test batch_cs ≈ manual_cs + + # Test batch_objcons! 
+ batch_fs, batch_cs = batch_objcons!(bnlp, xs, cs) + manual_fs = [obj(models[i], xs[i]) for i = 1:n_models] + manual_cs = [cons!(models[i], xs[i], zeros(m)) for i = 1:n_models] + @test batch_fs ≈ manual_fs + @test batch_cs ≈ manual_cs + + # Test batch_jac + batch_jacs = batch_jac(bnlp, xs) + manual_jacs = [jac(models[i], xs[i]) for i = 1:n_models] + @test batch_jacs ≈ manual_jacs + + # Test batch_jac_coord + batch_jac_coords = batch_jac_coord(bnlp, xs) + manual_jac_coords = [jac_coord(models[i], xs[i]) for i = 1:n_models] + @test batch_jac_coords ≈ manual_jac_coords + + # Test batch_jac_coord! + jac_coords = [zeros(bnlp.meta.nnzj) for _ = 1:n_models] + batch_jac_coord!(bnlp, xs, jac_coords) + manual_jac_coords = [jac_coord!(models[i], xs[i], zeros(bnlp.meta.nnzj)) for i = 1:n_models] + @test jac_coords ≈ manual_jac_coords + + # Test batch_jac_lin + batch_jac_lins = batch_jac_lin(bnlp) + manual_jac_lins = [jac_lin(models[i]) for i = 1:n_models] + @test batch_jac_lins ≈ manual_jac_lins + + # Test batch_jac_lin_coord + batch_jac_lin_coords = batch_jac_lin_coord(bnlp) + manual_jac_lin_coords = [jac_lin_coord(models[i]) for i = 1:n_models] + @test batch_jac_lin_coords ≈ manual_jac_lin_coords + + # Test batch_jac_lin_coord! + jac_lin_coords = [zeros(bnlp.meta.lin_nnzj) for _ = 1:n_models] + batch_jac_lin_coord!(bnlp, jac_lin_coords) + manual_jac_lin_coords = + [jac_lin_coord!(models[i], zeros(bnlp.meta.lin_nnzj)) for i = 1:n_models] + @test jac_lin_coords ≈ manual_jac_lin_coords + + # Test batch_jac_nln + batch_jac_nlns = batch_jac_nln(bnlp, xs) + manual_jac_nlns = [jac_nln(models[i], xs[i]) for i = 1:n_models] + @test batch_jac_nlns ≈ manual_jac_nlns + + # Test batch_jac_nln_coord + batch_jac_nln_coords = batch_jac_nln_coord(bnlp, xs) + manual_jac_nln_coords = [jac_nln_coord(models[i], xs[i]) for i = 1:n_models] + @test batch_jac_nln_coords ≈ manual_jac_nln_coords + + # Test batch_jac_nln_coord! 
+ jac_nln_coords = [zeros(bnlp.meta.nln_nnzj) for _ = 1:n_models] + batch_jac_nln_coord!(bnlp, xs, jac_nln_coords) + manual_jac_nln_coords = + [jac_nln_coord!(models[i], xs[i], zeros(bnlp.meta.nln_nnzj)) for i = 1:n_models] + @test jac_nln_coords ≈ manual_jac_nln_coords + + # Test batch_jprod + batch_jprods = batch_jprod(bnlp, xs, vs) + manual_jprods = [jprod(models[i], xs[i], vs[i]) for i = 1:n_models] + @test batch_jprods ≈ manual_jprods + + # Test batch_jprod! + jprods = [zeros(m) for _ = 1:n_models] + batch_jprod!(bnlp, xs, vs, jprods) + manual_jprods = [jprod!(models[i], xs[i], vs[i], zeros(m)) for i = 1:n_models] + @test jprods ≈ manual_jprods + + # Test batch_jtprod + batch_jtprods = batch_jtprod(bnlp, xs, ws) + manual_jtprods = [jtprod(models[i], xs[i], ws[i]) for i = 1:n_models] + @test batch_jtprods ≈ manual_jtprods + + # Test batch_jtprod! + jtprods = [zeros(n) for _ = 1:n_models] + batch_jtprod!(bnlp, xs, ws, jtprods) + manual_jtprods = [jtprod!(models[i], xs[i], ws[i], zeros(n)) for i = 1:n_models] + @test jtprods ≈ manual_jtprods + + # Test batch_jprod_lin + batch_jprod_lins = batch_jprod_lin(bnlp, vs) + manual_jprod_lins = [jprod_lin(models[i], vs[i]) for i = 1:n_models] + @test batch_jprod_lins ≈ manual_jprod_lins + + # Test batch_jprod_lin! + jprod_lins = [zeros(bnlp.meta.nlin) for _ = 1:n_models] + batch_jprod_lin!(bnlp, vs, jprod_lins) + manual_jprod_lins = [jprod_lin!(models[i], vs[i], zeros(bnlp.meta.nlin)) for i = 1:n_models] + @test jprod_lins ≈ manual_jprod_lins + + # Test batch_jtprod_lin + ws_lin = [ws[i][1:(bnlp.meta.nlin)] for i = 1:n_models] + batch_jtprod_lins = batch_jtprod_lin(bnlp, ws_lin) + manual_jtprod_lins = [jtprod_lin(models[i], ws_lin[i]) for i = 1:n_models] + @test batch_jtprod_lins ≈ manual_jtprod_lins + + # Test batch_jtprod_lin! 
+ jtprod_lins = [zeros(n) for _ = 1:n_models] + batch_jtprod_lin!(bnlp, ws_lin, jtprod_lins) + manual_jtprod_lins = [jtprod_lin!(models[i], ws_lin[i], zeros(n)) for i = 1:n_models] + @test jtprod_lins ≈ manual_jtprod_lins + + # Test batch_jprod_nln + batch_jprod_nlns = batch_jprod_nln(bnlp, xs, vs) + manual_jprod_nlns = [jprod_nln(models[i], xs[i], vs[i]) for i = 1:n_models] + @test batch_jprod_nlns ≈ manual_jprod_nlns + + # Test batch_jprod_nln! + jprod_nlns = [zeros(bnlp.meta.nnln) for _ = 1:n_models] + batch_jprod_nln!(bnlp, xs, vs, jprod_nlns) + manual_jprod_nlns = + [jprod_nln!(models[i], xs[i], vs[i], zeros(bnlp.meta.nnln)) for i = 1:n_models] + @test jprod_nlns ≈ manual_jprod_nlns + + # Test batch_jtprod_nln + ws_nln = [ws[i][(bnlp.meta.nlin + 1):end] for i = 1:n_models] + batch_jtprod_nlns = batch_jtprod_nln(bnlp, xs, ws_nln) + manual_jtprod_nlns = [jtprod_nln(models[i], xs[i], ws_nln[i]) for i = 1:n_models] + @test batch_jtprod_nlns ≈ manual_jtprod_nlns + + # Test batch_jtprod_nln! 
+ jtprod_nlns = [zeros(n) for _ = 1:n_models] + batch_jtprod_nln!(bnlp, xs, ws_nln, jtprod_nlns) + manual_jtprod_nlns = [jtprod_nln!(models[i], xs[i], ws_nln[i], zeros(n)) for i = 1:n_models] + @test jtprod_nlns ≈ manual_jtprod_nlns + + # Test batch_hess (without y) + batch_hesses = batch_hess(bnlp, xs) + manual_hesses = [hess(models[i], xs[i]) for i = 1:n_models] + @test batch_hesses ≈ manual_hesses + + # Test batch_hess (with y) + batch_hesses = batch_hess(bnlp, xs, ys) + manual_hesses = [hess(models[i], xs[i], ys[i]) for i = 1:n_models] + @test batch_hesses ≈ manual_hesses + + # Test batch_hess with obj_weights (without y) + batch_hesses = batch_hess(bnlp, xs; obj_weights = obj_weights) + manual_hesses = [hess(models[i], xs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] + @test batch_hesses ≈ manual_hesses + + # Test batch_hess with obj_weights (with y) + batch_hesses = batch_hess(bnlp, xs, ys; obj_weights = obj_weights) + manual_hesses = + [hess(models[i], xs[i], ys[i]; obj_weight = obj_weights[i]) for i = 1:n_models] + @test batch_hesses ≈ manual_hesses + + # Test batch_hess_coord (without y) + batch_hess_coords = batch_hess_coord(bnlp, xs) + manual_hess_coords = [hess_coord(models[i], xs[i]) for i = 1:n_models] + @test batch_hess_coords ≈ manual_hess_coords + + # Test batch_hess_coord (with y) + batch_hess_coords = batch_hess_coord(bnlp, xs, ys) + manual_hess_coords = [hess_coord(models[i], xs[i], ys[i]) for i = 1:n_models] + @test batch_hess_coords ≈ manual_hess_coords + + # Test batch_hess_coord with obj_weights (without y) + batch_hess_coords = batch_hess_coord(bnlp, xs; obj_weights = obj_weights) + manual_hess_coords = + [hess_coord(models[i], xs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] + @test batch_hess_coords ≈ manual_hess_coords + + # Test batch_hess_coord with obj_weights (with y) + batch_hess_coords = batch_hess_coord(bnlp, xs, ys; obj_weights = obj_weights) + manual_hess_coords = + [hess_coord(models[i], xs[i], ys[i]; obj_weight = 
obj_weights[i]) for i = 1:n_models] + @test batch_hess_coords ≈ manual_hess_coords + + # Test batch_hess_coord! (without y) + hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] + batch_hess_coord!(bnlp, xs, hess_coords) + manual_hess_coords = [hess_coord!(models[i], xs[i], zeros(bnlp.meta.nnzh)) for i = 1:n_models] + @test hess_coords ≈ manual_hess_coords + + # Test batch_hess_coord! (with y) + hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] + batch_hess_coord!(bnlp, xs, ys, hess_coords) + manual_hess_coords = + [hess_coord!(models[i], xs[i], ys[i], zeros(bnlp.meta.nnzh)) for i = 1:n_models] + @test hess_coords ≈ manual_hess_coords + + # Test batch_hess_coord! with obj_weights (without y) + hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] + batch_hess_coord!(bnlp, xs, hess_coords; obj_weights = obj_weights) + manual_hess_coords = [ + hess_coord!(models[i], xs[i], zeros(bnlp.meta.nnzh); obj_weight = obj_weights[i]) for + i = 1:n_models + ] + @test hess_coords ≈ manual_hess_coords + + # Test batch_hess_coord! 
with obj_weights (with y) + hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] + batch_hess_coord!(bnlp, xs, ys, hess_coords; obj_weights = obj_weights) + manual_hess_coords = [ + hess_coord!(models[i], xs[i], ys[i], zeros(bnlp.meta.nnzh); obj_weight = obj_weights[i]) + for i = 1:n_models + ] + @test hess_coords ≈ manual_hess_coords + + # Test batch_hprod (without y) + batch_hprods = batch_hprod(bnlp, xs, vs) + manual_hprods = [hprod(models[i], xs[i], vs[i]) for i = 1:n_models] + @test batch_hprods ≈ manual_hprods + + # Test batch_hprod (with y) + batch_hprods = batch_hprod(bnlp, xs, ys, vs) + manual_hprods = [hprod(models[i], xs[i], ys[i], vs[i]) for i = 1:n_models] + @test batch_hprods ≈ manual_hprods + + # Test batch_hprod with obj_weights (without y) + batch_hprods = batch_hprod(bnlp, xs, vs; obj_weights = obj_weights) + manual_hprods = + [hprod(models[i], xs[i], vs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] + @test batch_hprods ≈ manual_hprods + + # Test batch_hprod with obj_weights (with y) + batch_hprods = batch_hprod(bnlp, xs, ys, vs; obj_weights = obj_weights) + manual_hprods = + [hprod(models[i], xs[i], ys[i], vs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] + @test batch_hprods ≈ manual_hprods + + # Test batch_hprod! (without y) + hprods = [zeros(n) for _ = 1:n_models] + batch_hprod!(bnlp, xs, vs, hprods) + manual_hprods = [hprod!(models[i], xs[i], vs[i], zeros(n)) for i = 1:n_models] + @test hprods ≈ manual_hprods + + # Test batch_hprod! (with y) + hprods = [zeros(n) for _ = 1:n_models] + batch_hprod!(bnlp, xs, ys, vs, hprods) + manual_hprods = [hprod!(models[i], xs[i], ys[i], vs[i], zeros(n)) for i = 1:n_models] + @test hprods ≈ manual_hprods + + # Test batch_hprod! 
with obj_weights (without y) + hprods = [zeros(n) for _ = 1:n_models] + batch_hprod!(bnlp, xs, vs, hprods; obj_weights = obj_weights) + manual_hprods = + [hprod!(models[i], xs[i], vs[i], zeros(n); obj_weight = obj_weights[i]) for i = 1:n_models] + @test hprods ≈ manual_hprods + + # Test batch_hprod! with obj_weights (with y) + hprods = [zeros(n) for _ = 1:n_models] + batch_hprod!(bnlp, xs, ys, vs, hprods; obj_weights = obj_weights) + manual_hprods = [ + hprod!(models[i], xs[i], ys[i], vs[i], zeros(n); obj_weight = obj_weights[i]) for + i = 1:n_models + ] + @test hprods ≈ manual_hprods + + # Test batch_hess_op (without y) + batch_hess_ops = batch_hess_op(bnlp, xs) + manual_hess_ops = [hess_op(models[i], xs[i]) for i = 1:n_models] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op (with y) + batch_hess_ops = batch_hess_op(bnlp, xs, ys) + manual_hess_ops = [hess_op(models[i], xs[i], ys[i]) for i = 1:n_models] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op with obj_weights (without y) + batch_hess_ops = batch_hess_op(bnlp, xs; obj_weights = obj_weights) + manual_hess_ops = [hess_op(models[i], xs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op with obj_weights (with y) + batch_hess_ops = batch_hess_op(bnlp, xs, ys; obj_weights = obj_weights) + manual_hess_ops = + [hess_op(models[i], xs[i], ys[i]; obj_weight = obj_weights[i]) for i = 1:n_models] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op! 
(without y) + hvs = [zeros(n) for _ = 1:n_models] + batch_hess_ops = batch_hess_op!(bnlp, xs, hvs) + manual_hess_ops = [hess_op!(models[i], xs[i], zeros(n)) for i = 1:n_models] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op! (with y) + hvs = [zeros(n) for _ = 1:n_models] + batch_hess_ops = batch_hess_op!(bnlp, xs, ys, hvs) + manual_hess_ops = [hess_op!(models[i], xs[i], ys[i], zeros(n)) for i = 1:n_models] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op! with obj_weights (without y) + hvs = [zeros(n) for _ = 1:n_models] + batch_hess_ops = batch_hess_op!(bnlp, xs, hvs; obj_weights = obj_weights) + manual_hess_ops = + [hess_op!(models[i], xs[i], zeros(n); obj_weight = obj_weights[i]) for i = 1:n_models] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op! with obj_weights (with y) + hvs = [zeros(n) for _ = 1:n_models] + batch_hess_ops = batch_hess_op!(bnlp, xs, ys, hvs; obj_weights = obj_weights) + manual_hess_ops = [ + hess_op!(models[i], xs[i], ys[i], zeros(n); obj_weight = obj_weights[i]) for i = 1:n_models + ] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_jth_con + j = 1 + batch_jth_cons = batch_jth_con(bnlp, xs, j) + manual_jth_cons = [jth_con(models[i], xs[i], j) for i = 1:n_models] + @test batch_jth_cons ≈ manual_jth_cons + + # Test batch_jth_congrad + batch_jth_congrads = batch_jth_congrad(bnlp, xs, j) + manual_jth_congrads = [jth_congrad(models[i], xs[i], j) for i = 1:n_models] + @test batch_jth_congrads ≈ manual_jth_congrads + + # Test batch_jth_congrad! 
+ jth_congrads = [zeros(n) for _ = 1:n_models] + batch_jth_congrad!(bnlp, xs, j, jth_congrads) + manual_jth_congrads = [jth_congrad!(models[i], xs[i], j, zeros(n)) for i = 1:n_models] + @test jth_congrads ≈ manual_jth_congrads + + # Test batch_jth_sparse_congrad + batch_jth_sparse_congrads = batch_jth_sparse_congrad(bnlp, xs, j) + manual_jth_sparse_congrads = [jth_sparse_congrad(models[i], xs[i], j) for i = 1:n_models] + @test batch_jth_sparse_congrads ≈ manual_jth_sparse_congrads + + # Test batch_jth_hess_coord + batch_jth_hess_coords = batch_jth_hess_coord(bnlp, xs, j) + manual_jth_hess_coords = [jth_hess_coord(models[i], xs[i], j) for i = 1:n_models] + @test batch_jth_hess_coords ≈ manual_jth_hess_coords + + # Test batch_jth_hess_coord! + jth_hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] + batch_jth_hess_coord!(bnlp, xs, j, jth_hess_coords) + manual_jth_hess_coords = + [jth_hess_coord!(models[i], xs[i], j, zeros(bnlp.meta.nnzh)) for i = 1:n_models] + @test jth_hess_coords ≈ manual_jth_hess_coords + + # Test batch_jth_hess + batch_jth_hesses = batch_jth_hess(bnlp, xs, j) + manual_jth_hesses = [jth_hess(models[i], xs[i], j) for i = 1:n_models] + @test batch_jth_hesses ≈ manual_jth_hesses + + # Test batch_jth_hprod + batch_jth_hprods = batch_jth_hprod(bnlp, xs, vs, j) + manual_jth_hprods = [jth_hprod(models[i], xs[i], vs[i], j) for i = 1:n_models] + @test batch_jth_hprods ≈ manual_jth_hprods + + # Test batch_jth_hprod! + jth_hprods = [zeros(n) for _ = 1:n_models] + batch_jth_hprod!(bnlp, xs, vs, j, jth_hprods) + manual_jth_hprods = [jth_hprod!(models[i], xs[i], vs[i], j, zeros(n)) for i = 1:n_models] + @test jth_hprods ≈ manual_jth_hprods + + # Test batch_ghjvprod + batch_ghjvprods = batch_ghjvprod(bnlp, xs, gs, vs) + manual_ghjvprods = [ghjvprod(models[i], xs[i], gs[i], vs[i]) for i = 1:n_models] + @test batch_ghjvprods ≈ manual_ghjvprods + + # Test batch_ghjvprod! 
+ ghjvprods = [zeros(m) for _ = 1:n_models] + batch_ghjvprod!(bnlp, xs, gs, vs, ghjvprods) + manual_ghjvprods = [ghjvprod!(models[i], xs[i], gs[i], vs[i], zeros(m)) for i = 1:n_models] + @test ghjvprods ≈ manual_ghjvprods + + # Test batch_jac_op + batch_jac_ops = batch_jac_op(bnlp, xs) + manual_jac_ops = [jac_op(models[i], xs[i]) for i = 1:n_models] + for i = 1:n_models + @test batch_jac_ops[i] * vs[i] ≈ manual_jac_ops[i] * vs[i] + @test batch_jac_ops[i]' * ws[i] ≈ manual_jac_ops[i]' * ws[i] + end + + # Test batch_jac_op! + jvs = [zeros(m) for _ = 1:n_models] + jtvs = [zeros(n) for _ = 1:n_models] + batch_jac_ops = batch_jac_op!(bnlp, xs, jvs, jtvs) + manual_jac_ops = [jac_op!(models[i], xs[i], zeros(m), zeros(n)) for i = 1:n_models] + for i = 1:n_models + @test batch_jac_ops[i] * vs[i] ≈ manual_jac_ops[i] * vs[i] + @test batch_jac_ops[i]' * ws[i] ≈ manual_jac_ops[i]' * ws[i] + end + + # Test batch_jac_lin_op + batch_jac_lin_ops = batch_jac_lin_op(bnlp) + manual_jac_lin_ops = [jac_lin_op(models[i]) for i = 1:n_models] + ws_lin_vec = ws[1][1:(bnlp.meta.nlin)] + for i = 1:n_models + @test batch_jac_lin_ops[i] * vs[i] ≈ manual_jac_lin_ops[i] * vs[i] + @test batch_jac_lin_ops[i]' * ws_lin_vec ≈ manual_jac_lin_ops[i]' * ws_lin_vec + end + + # Test batch_jac_lin_op! 
+ jvs_lin = [zeros(bnlp.meta.nlin) for _ = 1:n_models] + jtvs_lin = [zeros(n) for _ = 1:n_models] + batch_jac_lin_ops = batch_jac_lin_op!(bnlp, jvs_lin, jtvs_lin) + manual_jac_lin_ops = + [jac_lin_op!(models[i], zeros(bnlp.meta.nlin), zeros(n)) for i = 1:n_models] + for i = 1:n_models + @test batch_jac_lin_ops[i] * vs[i] ≈ manual_jac_lin_ops[i] * vs[i] + @test batch_jac_lin_ops[i]' * ws_lin_vec ≈ manual_jac_lin_ops[i]' * ws_lin_vec + end + + # Test batch_jac_nln_op + batch_jac_nln_ops = batch_jac_nln_op(bnlp, xs) + manual_jac_nln_ops = [jac_nln_op(models[i], xs[i]) for i = 1:n_models] + ws_nln_vec = ws[1][(bnlp.meta.nlin + 1):end] + for i = 1:n_models + @test batch_jac_nln_ops[i] * vs[i] ≈ manual_jac_nln_ops[i] * vs[i] + @test batch_jac_nln_ops[i]' * ws_nln_vec ≈ manual_jac_nln_ops[i]' * ws_nln_vec + end + + # Test batch_jac_nln_op! + jvs_nln = [zeros(bnlp.meta.nnln) for _ = 1:n_models] + jtvs_nln = [zeros(n) for _ = 1:n_models] + batch_jac_nln_ops = batch_jac_nln_op!(bnlp, xs, jvs_nln, jtvs_nln) + manual_jac_nln_ops = + [jac_nln_op!(models[i], xs[i], zeros(bnlp.meta.nnln), zeros(n)) for i = 1:n_models] + for i = 1:n_models + @test batch_jac_nln_ops[i] * vs[i] ≈ manual_jac_nln_ops[i] * vs[i] + @test batch_jac_nln_ops[i]' * ws_nln_vec ≈ manual_jac_nln_ops[i]' * ws_nln_vec + end + + # Test batch_varscale, batch_lagscale, batch_conscale + batch_varscales = batch_varscale(bnlp) + manual_varscales = [varscale(models[i]) for i = 1:n_models] + @test batch_varscales ≈ manual_varscales + + batch_lagscales = batch_lagscale(bnlp) + manual_lagscales = [lagscale(models[i]) for i = 1:n_models] + @test batch_lagscales ≈ manual_lagscales + + batch_conscales = batch_conscale(bnlp) + manual_conscales = [conscale(models[i]) for i = 1:n_models] + @test batch_conscales ≈ manual_conscales + + # Test structure functions + first_model = first(models) + @test batch_jac_structure(bnlp) == jac_structure(first_model) + @test batch_jac_lin_structure(bnlp) == jac_lin_structure(first_model) + 
@test batch_jac_nln_structure(bnlp) == jac_nln_structure(first_model) + @test batch_hess_structure(bnlp) == hess_structure(first_model) + + rows, cols = jac_structure(first_model) + fill!(rows, 0) + fill!(cols, 0) + batch_jac_structure!(bnlp, rows, cols) + @test rows == jac_structure(first_model)[1] + @test cols == jac_structure(first_model)[2] + + rows, cols = jac_lin_structure(first_model) + fill!(rows, 0) + fill!(cols, 0) + batch_jac_lin_structure!(bnlp, rows, cols) + @test rows == jac_lin_structure(first_model)[1] + @test cols == jac_lin_structure(first_model)[2] + + rows, cols = jac_nln_structure(first_model) + fill!(rows, 0) + fill!(cols, 0) + batch_jac_nln_structure!(bnlp, rows, cols) + @test rows == jac_nln_structure(first_model)[1] + @test cols == jac_nln_structure(first_model)[2] + + rows, cols = hess_structure(first_model) + fill!(rows, 0) + fill!(cols, 0) + batch_hess_structure!(bnlp, rows, cols) + @test rows == hess_structure(first_model)[1] + @test cols == hess_structure(first_model)[2] + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 06639cf7..3b3ad64b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ include("nlp/simple-model.jl") include("nlp/dummy-model.jl") include("nlp/api.jl") +include("nlp/batch_api.jl") include("nlp/counters.jl") include("nlp/meta.jl") include("nlp/show.jl") From 7382d2188c0a74b7060f0595a61add967544445a Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 17 Nov 2025 13:42:26 -0500 Subject: [PATCH 02/13] add missing SimpleNLPModel methods --- test/nlp/simple-model.jl | 52 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/test/nlp/simple-model.jl b/test/nlp/simple-model.jl index d8fa9b84..12071267 100644 --- a/test/nlp/simple-model.jl +++ b/test/nlp/simple-model.jl @@ -236,3 +236,55 @@ function NLPModels.ghjvprod!( gHv .= [T(0); -g[1] * v[1] / 2 - 2 * g[2] * v[2]] return gHv end + +function NLPModels.jth_con(nlp::SimpleNLPModel, 
x::AbstractVector{T}, j::Integer) where {T} + @lencheck 2 x + @rangecheck 1 nlp.meta.ncon j + increment!(nlp, :neval_jcon) + if j == 1 + return x[1] - 2 * x[2] + 1 + elseif j == 2 + return -x[1]^2 / 4 - x[2]^2 + 1 + end +end + +function NLPModels.jth_congrad!( + nlp::SimpleNLPModel, + x::AbstractVector{T}, + j::Integer, + g::AbstractVector{T}, +) where {T} + @lencheck 2 x g + @rangecheck 1 nlp.meta.ncon j + increment!(nlp, :neval_jgrad) + if j == 1 + g .= [T(1); T(-2)] + elseif j == 2 + g .= [-x[1] / 2; -2 * x[2]] + end + return g +end + +function NLPModels.jth_sparse_congrad(nlp::SimpleNLPModel, x::AbstractVector{T}, j::Integer) where {T} + @lencheck 2 x + @rangecheck 1 nlp.meta.ncon j + increment!(nlp, :neval_jgrad) + if j == 1 + vals = [T(1); T(-2)] + elseif j == 2 + vals = [-x[1] / 2; -2 * x[2]] + end + return sparse([1, 1], [1, 2], vals, 1, nlp.meta.nvar) +end + +function NLPModels.varscale(nlp::SimpleNLPModel{T}) where {T} + return ones(T, nlp.meta.nvar) +end + +function NLPModels.lagscale(nlp::SimpleNLPModel{T}) where {T} + return ones(T, nlp.meta.ncon) +end + +function NLPModels.conscale(nlp::SimpleNLPModel{T}) where {T} + return ones(T, nlp.meta.ncon) +end From d6ca0adc399fcd7caec5d71df34714a72f6aa7b3 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 17 Nov 2025 17:26:56 -0500 Subject: [PATCH 03/13] counters --- src/nlp/batch/api.jl | 7 ++++++- src/nlp/batch/vector.jl | 10 +++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/nlp/batch/api.jl b/src/nlp/batch/api.jl index ada0e537..5c808d17 100644 --- a/src/nlp/batch/api.jl +++ b/src/nlp/batch/api.jl @@ -184,4 +184,9 @@ batch_objcons(bnlp::AbstractBatchNLPModel, xs::VV) = batch_objgrad!(bnlp::AbstractBatchNLPModel, xs::VV, gs::Vector) = _batch_map_tuple!(objgrad!, bnlp, gs, xs) batch_objcons!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) = - _batch_map_tuple!(objcons!, bnlp, cs, xs) \ No newline at end of file + _batch_map_tuple!(objcons!, bnlp, cs, xs) + +function 
NLPModels.increment!(bnlp::AbstractBatchNLPModel, fun::Symbol) + NLPModels.increment!(bnlp, Val(fun)) +end + \ No newline at end of file diff --git a/src/nlp/batch/vector.jl b/src/nlp/batch/vector.jl index 8b4e7c30..b939a663 100644 --- a/src/nlp/batch/vector.jl +++ b/src/nlp/batch/vector.jl @@ -1,13 +1,14 @@ export VectorBatchNLPModel struct VectorBatchNLPModel{T, S, M <: AbstractNLPModel{T, S}} <: AbstractBatchNLPModel{T, S} models::Vector{M} + counters::Counters meta::NLPModelMeta{T, S} end function VectorBatchNLPModel(models::Vector{M}) where {M <: AbstractNLPModel} isempty(models) && error("Cannot create VectorBatchNLPModel from empty vector") # TODO: check all metas the same, all structures same, etc. meta = first(models).meta - VectorBatchNLPModel{eltype(meta.x0), typeof(meta.x0), M}(models, meta) + VectorBatchNLPModel{eltype(meta.x0), typeof(meta.x0), M}(models, Counters(), meta) end Base.length(vnlp::VectorBatchNLPModel) = length(vnlp.models) Base.getindex(vnlp::VectorBatchNLPModel, i::Integer) = vnlp.models[i] @@ -80,4 +81,11 @@ function _batch_map_tuple!(f, bnlp::VectorBatchNLPModel, outputs::Vector, xs::VV firsts[i], _ = f(bnlp[i], args_i..., outputs[i]) end return firsts, outputs +end + +for fun in fieldnames(Counters) + @eval function NLPModels.increment!(bnlp::VectorBatchNLPModel, ::Val{$(Meta.quot(fun))}) + # sub-model counters are already incremented since we call their methods + bnlp.counters.$fun += 1 + end end \ No newline at end of file From b4b32eb5e2b98c2e9f232cfca25b4d5feda04c7d Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Tue, 18 Nov 2025 15:14:48 -0500 Subject: [PATCH 04/13] cleanup --- src/nlp/batch/api.jl | 132 +++++++++++++++++++++---------------------- 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/src/nlp/batch/api.jl b/src/nlp/batch/api.jl index 5c808d17..b8429e29 100644 --- a/src/nlp/batch/api.jl +++ b/src/nlp/batch/api.jl @@ -35,27 +35,27 @@ batch_jac_nln_structure!(bnlp::AbstractBatchNLPModel, rows, cols) 
= jac_nln_structure!(first(bnlp), rows, cols) batch_hess_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = hess_structure!(first(bnlp), rows, cols) -batch_obj(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_obj(bnlp::AbstractBatchNLPModel, xs) = _batch_map(obj, bnlp, xs) -batch_grad(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_grad(bnlp::AbstractBatchNLPModel, xs) = _batch_map(grad, bnlp, xs) -batch_cons(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_cons(bnlp::AbstractBatchNLPModel, xs) = _batch_map(cons, bnlp, xs) -batch_cons_lin(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_cons_lin(bnlp::AbstractBatchNLPModel, xs) = _batch_map(cons_lin, bnlp, xs) -batch_cons_nln(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_cons_nln(bnlp::AbstractBatchNLPModel, xs) = _batch_map(cons_nln, bnlp, xs) -batch_jac(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_jac(bnlp::AbstractBatchNLPModel, xs) = _batch_map(jac, bnlp, xs) batch_jac_lin(bnlp::AbstractBatchNLPModel) = _batch_map(jac_lin, bnlp) -batch_jac_nln(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_jac_nln(bnlp::AbstractBatchNLPModel, xs) = _batch_map(jac_nln, bnlp, xs) batch_jac_lin_coord(bnlp::AbstractBatchNLPModel) = _batch_map(jac_lin_coord, bnlp) -batch_jac_coord(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_jac_coord(bnlp::AbstractBatchNLPModel, xs) = _batch_map(jac_coord, bnlp, xs) -batch_jac_nln_coord(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_jac_nln_coord(bnlp::AbstractBatchNLPModel, xs) = _batch_map(jac_nln_coord, bnlp, xs) batch_varscale(bnlp::AbstractBatchNLPModel) = _batch_map(varscale, bnlp) @@ -63,127 +63,127 @@ batch_lagscale(bnlp::AbstractBatchNLPModel) = _batch_map(lagscale, bnlp) batch_conscale(bnlp::AbstractBatchNLPModel) = _batch_map(conscale, bnlp) -batch_jprod(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV) = +batch_jprod(bnlp::AbstractBatchNLPModel, xs, vs) = _batch_map(jprod, bnlp, xs, vs) -batch_jtprod(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV) = +batch_jtprod(bnlp::AbstractBatchNLPModel, xs, vs) = 
_batch_map(jtprod, bnlp, xs, vs) -batch_jprod_nln(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV) = +batch_jprod_nln(bnlp::AbstractBatchNLPModel, xs, vs) = _batch_map(jprod_nln, bnlp, xs, vs) -batch_jtprod_nln(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV) = +batch_jtprod_nln(bnlp::AbstractBatchNLPModel, xs, vs) = _batch_map(jtprod_nln, bnlp, xs, vs) -batch_jprod_lin(bnlp::AbstractBatchNLPModel, vs::VV) = +batch_jprod_lin(bnlp::AbstractBatchNLPModel, vs) = _batch_map(jprod_lin, bnlp, vs) -batch_jtprod_lin(bnlp::AbstractBatchNLPModel, vs::VV) = +batch_jtprod_lin(bnlp::AbstractBatchNLPModel, vs) = _batch_map(jtprod_lin, bnlp, vs) -batch_ghjvprod(bnlp::AbstractBatchNLPModel, xs::VV, gs::VV, vs::VV) = +batch_ghjvprod(bnlp::AbstractBatchNLPModel, xs, gs, vs) = _batch_map(ghjvprod, bnlp, xs, gs, vs) -batch_grad!(bnlp::AbstractBatchNLPModel, xs::VV, gs::Vector) = +batch_jac_lin_coord!(bnlp::AbstractBatchNLPModel, valss) = + _batch_map!(jac_lin_coord!, bnlp, valss) +batch_grad!(bnlp::AbstractBatchNLPModel, xs, gs) = _batch_map!((m, g, x) -> grad!(m, x, g), bnlp, gs, xs) -batch_cons!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) = +batch_cons!(bnlp::AbstractBatchNLPModel, xs, cs) = _batch_map!((m, c, x) -> cons!(m, x, c), bnlp, cs, xs) -batch_cons_lin!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) = +batch_cons_lin!(bnlp::AbstractBatchNLPModel, xs, cs) = _batch_map!((m, c, x) -> cons_lin!(m, x, c), bnlp, cs, xs) -batch_cons_nln!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) = +batch_cons_nln!(bnlp::AbstractBatchNLPModel, xs, cs) = _batch_map!((m, c, x) -> cons_nln!(m, x, c), bnlp, cs, xs) -batch_jac_lin_coord!(bnlp::AbstractBatchNLPModel, valss::Vector) = - _batch_map!((m, vals) -> jac_lin_coord!(m, vals), bnlp, valss) -batch_jac_coord!(bnlp::AbstractBatchNLPModel, xs::VV, valss::Vector) = +batch_jac_coord!(bnlp::AbstractBatchNLPModel, xs, valss) = _batch_map!((m, vals, x) -> jac_coord!(m, x, vals), bnlp, valss, xs) -batch_jac_nln_coord!(bnlp::AbstractBatchNLPModel, 
xs::VV, valss::Vector) = +batch_jac_nln_coord!(bnlp::AbstractBatchNLPModel, xs, valss) = _batch_map!((m, vals, x) -> jac_nln_coord!(m, x, vals), bnlp, valss, xs) -batch_jprod!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, Jvs::Vector) = +batch_jprod!(bnlp::AbstractBatchNLPModel, xs, vs, Jvs) = _batch_map!((m, Jv, x, v) -> jprod!(m, x, v, Jv), bnlp, Jvs, xs, vs) -batch_jtprod!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, Jtvs::Vector) = +batch_jtprod!(bnlp::AbstractBatchNLPModel, xs, vs, Jtvs) = _batch_map!((m, Jtv, x, v) -> jtprod!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) -batch_jprod_nln!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, Jvs::Vector) = +batch_jprod_nln!(bnlp::AbstractBatchNLPModel, xs, vs, Jvs) = _batch_map!((m, Jv, x, v) -> jprod_nln!(m, x, v, Jv), bnlp, Jvs, xs, vs) -batch_jtprod_nln!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, Jtvs::Vector) = +batch_jtprod_nln!(bnlp::AbstractBatchNLPModel, xs, vs, Jtvs) = _batch_map!((m, Jtv, x, v) -> jtprod_nln!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) -batch_jprod_lin!(bnlp::AbstractBatchNLPModel, vs::VV, Jvs::Vector) = +batch_jprod_lin!(bnlp::AbstractBatchNLPModel, vs, Jvs) = _batch_map!((m, Jv, v) -> jprod_lin!(m, v, Jv), bnlp, Jvs, vs) -batch_jtprod_lin!(bnlp::AbstractBatchNLPModel, vs::VV, Jtvs::Vector) = +batch_jtprod_lin!(bnlp::AbstractBatchNLPModel, vs, Jtvs) = _batch_map!((m, Jtv, v) -> jtprod_lin!(m, v, Jtv), bnlp, Jtvs, vs) -batch_ghjvprod!(bnlp::AbstractBatchNLPModel, xs::VV, gs::VV, vs::VV, gHvs::Vector) = +batch_ghjvprod!(bnlp::AbstractBatchNLPModel, xs, gs, vs, gHvs) = _batch_map!((m, gHv, x, g, v) -> ghjvprod!(m, x, g, v, gHv), bnlp, gHvs, xs, gs, vs) ## jth -batch_jth_con(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) = +batch_jth_con(bnlp::AbstractBatchNLPModel, xs, j::Integer) = _batch_map((m, x) -> jth_con(m, x, j), bnlp, xs) -batch_jth_congrad(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) = +batch_jth_congrad(bnlp::AbstractBatchNLPModel, xs, j::Integer) = _batch_map((m, x) -> jth_congrad(m, x, j), 
bnlp, xs) -batch_jth_sparse_congrad(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) = +batch_jth_sparse_congrad(bnlp::AbstractBatchNLPModel, xs, j::Integer) = _batch_map((m, x) -> jth_sparse_congrad(m, x, j), bnlp, xs) -batch_jth_hess_coord(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) = +batch_jth_hess_coord(bnlp::AbstractBatchNLPModel, xs, j::Integer) = _batch_map((m, x) -> jth_hess_coord(m, x, j), bnlp, xs) -batch_jth_hess(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer) = +batch_jth_hess(bnlp::AbstractBatchNLPModel, xs, j::Integer) = _batch_map((m, x) -> jth_hess(m, x, j), bnlp, xs) -batch_jth_hprod(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, j::Integer) = +batch_jth_hprod(bnlp::AbstractBatchNLPModel, xs, vs, j::Integer) = _batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs) -batch_jth_congrad!(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer, outputs::Vector) = +batch_jth_congrad!(bnlp::AbstractBatchNLPModel, xs, j::Integer, outputs) = _batch_map!((m, out, x) -> jth_congrad!(m, x, j, out), bnlp, outputs, xs) -batch_jth_hess_coord!(bnlp::AbstractBatchNLPModel, xs::VV, j::Integer, outputs::Vector) = +batch_jth_hess_coord!(bnlp::AbstractBatchNLPModel, xs, j::Integer, outputs) = _batch_map!((m, out, x) -> jth_hess_coord!(m, x, j, out), bnlp, outputs, xs) -batch_jth_hprod!(bnlp::AbstractBatchNLPModel, xs::VV, vs::VV, j::Integer, outputs::Vector) = +batch_jth_hprod!(bnlp::AbstractBatchNLPModel, xs, vs, j::Integer, outputs) = _batch_map!((m, out, x, v) -> jth_hprod!(m, x, v, j, out), bnlp, outputs, xs, vs) # hess (need to treat obj_weight) -batch_hprod(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, vs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hprod(bnlp::AbstractBatchNLPModel{T, S}, xs, vs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x, v; obj_weight) -> hprod(m, x, v; obj_weight = obj_weight), bnlp, obj_weights, xs, vs) -batch_hprod(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV, vs::VV; 
obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hprod(bnlp::AbstractBatchNLPModel{T, S}, xs, ys, vs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x, y, v; obj_weight) -> hprod(m, x, y, v; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, vs) -batch_hess_coord(bnlp::AbstractBatchNLPModel{T, S}, xs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess_coord(bnlp::AbstractBatchNLPModel{T, S}, xs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x; obj_weight) -> hess_coord(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) -batch_hess_coord(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess_coord(bnlp::AbstractBatchNLPModel{T, S}, xs, ys; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x, y; obj_weight) -> hess_coord(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) -batch_hess_op(bnlp::AbstractBatchNLPModel{T, S}, xs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess_op(bnlp::AbstractBatchNLPModel{T, S}, xs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x; obj_weight) -> hess_op(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) -batch_hess_op(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess_op(bnlp::AbstractBatchNLPModel{T, S}, xs, ys; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x, y; obj_weight) -> hess_op(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) -batch_hprod!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, vs::VV, outputs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hprod!(bnlp::AbstractBatchNLPModel{T, S}, xs, vs, outputs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight!((m, Hv, 
x, v; obj_weight) -> hprod!(m, x, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, vs) -batch_hprod!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV, vs::VV, outputs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hprod!(bnlp::AbstractBatchNLPModel{T, S}, xs, ys, vs, outputs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight!((m, Hv, x, y, v; obj_weight) -> hprod!(m, x, y, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys, vs) -batch_hess_coord!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, outputs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess_coord!(bnlp::AbstractBatchNLPModel{T, S}, xs, outputs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight!((m, vals, x; obj_weight) -> hess_coord!(m, x, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs) -batch_hess_coord!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV, outputs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess_coord!(bnlp::AbstractBatchNLPModel{T, S}, xs, ys, outputs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight!((m, vals, x, y; obj_weight) -> hess_coord!(m, x, y, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys) -batch_hess_op!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, Hvs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess_op!(bnlp::AbstractBatchNLPModel{T, S}, xs, Hvs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x, Hv; obj_weight) -> hess_op!(m, x, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, Hvs) -batch_hess_op!(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV, Hvs::Vector; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess_op!(bnlp::AbstractBatchNLPModel{T, S}, xs, ys, Hvs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x, y, Hv; 
obj_weight) -> hess_op!(m, x, y, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, Hvs) -batch_hess(bnlp::AbstractBatchNLPModel{T, S}, xs::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess(bnlp::AbstractBatchNLPModel{T, S}, xs; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x; obj_weight) -> hess(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) -batch_hess(bnlp::AbstractBatchNLPModel{T, S}, xs::VV, ys::VV; obj_weights::Vector{<:Real} = ones(T, length(bnlp))) where {T, S} = +batch_hess(bnlp::AbstractBatchNLPModel{T, S}, xs, ys; obj_weights = ones(T, length(bnlp))) where {T, S} = _batch_map_weight((m, x, y; obj_weight) -> hess(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) ## operators -batch_jac_op(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_jac_op(bnlp::AbstractBatchNLPModel, xs) = _batch_map(jac_op, bnlp, xs) batch_jac_lin_op(bnlp::AbstractBatchNLPModel) = _batch_map(jac_lin_op, bnlp) -batch_jac_nln_op(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_jac_nln_op(bnlp::AbstractBatchNLPModel, xs) = _batch_map(jac_nln_op, bnlp, xs) -batch_jac_op!(bnlp::AbstractBatchNLPModel, xs::VV, Jvs::Vector, Jtvs::Vector) = - _batch_map((m, x, Jv, Jtv) -> jac_op!(m, x, Jv, Jtv), bnlp, xs, Jvs, Jtvs) -batch_jac_lin_op!(bnlp::AbstractBatchNLPModel, Jvs::Vector, Jtvs::Vector) = - _batch_map((m, Jv, Jtv) -> jac_lin_op!(m, Jv, Jtv), bnlp, Jvs, Jtvs) -batch_jac_nln_op!(bnlp::AbstractBatchNLPModel, xs::VV, Jvs::Vector, Jtvs::Vector) = - _batch_map((m, x, Jv, Jtv) -> jac_nln_op!(m, x, Jv, Jtv), bnlp, xs, Jvs, Jtvs) +batch_jac_op!(bnlp::AbstractBatchNLPModel, xs, Jvs, Jtvs) = + _batch_map(jac_op!, bnlp, xs, Jvs, Jtvs) +batch_jac_lin_op!(bnlp::AbstractBatchNLPModel, Jvs, Jtvs) = + _batch_map(jac_lin_op!, bnlp, Jvs, Jtvs) +batch_jac_nln_op!(bnlp::AbstractBatchNLPModel, xs, Jvs, Jtvs) = + _batch_map(jac_nln_op!, bnlp, xs, Jvs, Jtvs) ## tuple functions -batch_objgrad(bnlp::AbstractBatchNLPModel, xs::VV) = 
+batch_objgrad(bnlp::AbstractBatchNLPModel, xs) = _batch_map_tuple(objgrad, bnlp, xs) -batch_objcons(bnlp::AbstractBatchNLPModel, xs::VV) = +batch_objcons(bnlp::AbstractBatchNLPModel, xs) = _batch_map_tuple(objcons, bnlp, xs) -batch_objgrad!(bnlp::AbstractBatchNLPModel, xs::VV, gs::Vector) = +batch_objgrad!(bnlp::AbstractBatchNLPModel, xs, gs) = _batch_map_tuple!(objgrad!, bnlp, gs, xs) -batch_objcons!(bnlp::AbstractBatchNLPModel, xs::VV, cs::Vector) = +batch_objcons!(bnlp::AbstractBatchNLPModel, xs, cs) = _batch_map_tuple!(objcons!, bnlp, cs, xs) function NLPModels.increment!(bnlp::AbstractBatchNLPModel, fun::Symbol) From a4165c97d32ddbe8370b26e126d08eeb81b702e1 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 20 Nov 2025 09:43:01 -0500 Subject: [PATCH 05/13] update --- src/NLPModels.jl | 2 +- src/nlp/batch/api.jl | 173 +----------------------- src/nlp/batch/foreach.jl | 277 +++++++++++++++++++++++++++++++++++++++ src/nlp/batch/vector.jl | 91 ------------- src/nlp/utils.jl | 20 ++- test/nlp/batch_api.jl | 235 ++++++++++++--------------------- test/nlp/dummy-model.jl | 2 +- test/nlp/utils.jl | 16 ++- 8 files changed, 401 insertions(+), 415 deletions(-) create mode 100644 src/nlp/batch/foreach.jl delete mode 100644 src/nlp/batch/vector.jl diff --git a/src/NLPModels.jl b/src/NLPModels.jl index a5e2d117..6aeeb086 100644 --- a/src/NLPModels.jl +++ b/src/NLPModels.jl @@ -43,6 +43,6 @@ for f in ["utils", "api", "counters", "meta", "show", "tools"] end include("nlp/batch/api.jl") -include("nlp/batch/vector.jl") +include("nlp/batch/foreach.jl") end # module diff --git a/src/nlp/batch/api.jl b/src/nlp/batch/api.jl index b8429e29..4ca10bdf 100644 --- a/src/nlp/batch/api.jl +++ b/src/nlp/batch/api.jl @@ -1,5 +1,3 @@ -const VV = Vector{<:AbstractVector} - export AbstractBatchNLPModel export batch_obj, batch_grad, batch_grad!, batch_objgrad, batch_objgrad!, batch_objcons, batch_objcons! 
export batch_cons, batch_cons!, batch_cons_lin, batch_cons_lin!, batch_cons_nln, batch_cons_nln! @@ -16,177 +14,8 @@ export batch_hess_structure!, batch_hess_structure, batch_hess_coord!, batch_hes export batch_hess, batch_hprod, batch_hprod!, batch_hess_op, batch_hess_op! export batch_varscale, batch_lagscale, batch_conscale -abstract type AbstractBatchNLPModel{T, S} end - -## base api -batch_jac_structure(bnlp::AbstractBatchNLPModel) = - jac_structure(first(bnlp)) -batch_jac_lin_structure(bnlp::AbstractBatchNLPModel) = - jac_lin_structure(first(bnlp)) -batch_jac_nln_structure(bnlp::AbstractBatchNLPModel) = - jac_nln_structure(first(bnlp)) -batch_hess_structure(bnlp::AbstractBatchNLPModel) = - hess_structure(first(bnlp)) -batch_jac_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = - jac_structure!(first(bnlp), rows, cols) -batch_jac_lin_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = - jac_lin_structure!(first(bnlp), rows, cols) -batch_jac_nln_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = - jac_nln_structure!(first(bnlp), rows, cols) -batch_hess_structure!(bnlp::AbstractBatchNLPModel, rows, cols) = - hess_structure!(first(bnlp), rows, cols) -batch_obj(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(obj, bnlp, xs) -batch_grad(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(grad, bnlp, xs) -batch_cons(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(cons, bnlp, xs) -batch_cons_lin(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(cons_lin, bnlp, xs) -batch_cons_nln(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(cons_nln, bnlp, xs) -batch_jac(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(jac, bnlp, xs) -batch_jac_lin(bnlp::AbstractBatchNLPModel) = - _batch_map(jac_lin, bnlp) -batch_jac_nln(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(jac_nln, bnlp, xs) -batch_jac_lin_coord(bnlp::AbstractBatchNLPModel) = - _batch_map(jac_lin_coord, bnlp) -batch_jac_coord(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(jac_coord, bnlp, xs) 
-batch_jac_nln_coord(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(jac_nln_coord, bnlp, xs) -batch_varscale(bnlp::AbstractBatchNLPModel) = - _batch_map(varscale, bnlp) -batch_lagscale(bnlp::AbstractBatchNLPModel) = - _batch_map(lagscale, bnlp) -batch_conscale(bnlp::AbstractBatchNLPModel) = - _batch_map(conscale, bnlp) -batch_jprod(bnlp::AbstractBatchNLPModel, xs, vs) = - _batch_map(jprod, bnlp, xs, vs) -batch_jtprod(bnlp::AbstractBatchNLPModel, xs, vs) = - _batch_map(jtprod, bnlp, xs, vs) -batch_jprod_nln(bnlp::AbstractBatchNLPModel, xs, vs) = - _batch_map(jprod_nln, bnlp, xs, vs) -batch_jtprod_nln(bnlp::AbstractBatchNLPModel, xs, vs) = - _batch_map(jtprod_nln, bnlp, xs, vs) -batch_jprod_lin(bnlp::AbstractBatchNLPModel, vs) = - _batch_map(jprod_lin, bnlp, vs) -batch_jtprod_lin(bnlp::AbstractBatchNLPModel, vs) = - _batch_map(jtprod_lin, bnlp, vs) -batch_ghjvprod(bnlp::AbstractBatchNLPModel, xs, gs, vs) = - _batch_map(ghjvprod, bnlp, xs, gs, vs) - -batch_jac_lin_coord!(bnlp::AbstractBatchNLPModel, valss) = - _batch_map!(jac_lin_coord!, bnlp, valss) -batch_grad!(bnlp::AbstractBatchNLPModel, xs, gs) = - _batch_map!((m, g, x) -> grad!(m, x, g), bnlp, gs, xs) -batch_cons!(bnlp::AbstractBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons!(m, x, c), bnlp, cs, xs) -batch_cons_lin!(bnlp::AbstractBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons_lin!(m, x, c), bnlp, cs, xs) -batch_cons_nln!(bnlp::AbstractBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons_nln!(m, x, c), bnlp, cs, xs) -batch_jac_coord!(bnlp::AbstractBatchNLPModel, xs, valss) = - _batch_map!((m, vals, x) -> jac_coord!(m, x, vals), bnlp, valss, xs) -batch_jac_nln_coord!(bnlp::AbstractBatchNLPModel, xs, valss) = - _batch_map!((m, vals, x) -> jac_nln_coord!(m, x, vals), bnlp, valss, xs) -batch_jprod!(bnlp::AbstractBatchNLPModel, xs, vs, Jvs) = - _batch_map!((m, Jv, x, v) -> jprod!(m, x, v, Jv), bnlp, Jvs, xs, vs) -batch_jtprod!(bnlp::AbstractBatchNLPModel, xs, vs, Jtvs) = - _batch_map!((m, Jtv, 
x, v) -> jtprod!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) -batch_jprod_nln!(bnlp::AbstractBatchNLPModel, xs, vs, Jvs) = - _batch_map!((m, Jv, x, v) -> jprod_nln!(m, x, v, Jv), bnlp, Jvs, xs, vs) -batch_jtprod_nln!(bnlp::AbstractBatchNLPModel, xs, vs, Jtvs) = - _batch_map!((m, Jtv, x, v) -> jtprod_nln!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) -batch_jprod_lin!(bnlp::AbstractBatchNLPModel, vs, Jvs) = - _batch_map!((m, Jv, v) -> jprod_lin!(m, v, Jv), bnlp, Jvs, vs) -batch_jtprod_lin!(bnlp::AbstractBatchNLPModel, vs, Jtvs) = - _batch_map!((m, Jtv, v) -> jtprod_lin!(m, v, Jtv), bnlp, Jtvs, vs) -batch_ghjvprod!(bnlp::AbstractBatchNLPModel, xs, gs, vs, gHvs) = - _batch_map!((m, gHv, x, g, v) -> ghjvprod!(m, x, g, v, gHv), bnlp, gHvs, xs, gs, vs) - -## jth -batch_jth_con(bnlp::AbstractBatchNLPModel, xs, j::Integer) = - _batch_map((m, x) -> jth_con(m, x, j), bnlp, xs) -batch_jth_congrad(bnlp::AbstractBatchNLPModel, xs, j::Integer) = - _batch_map((m, x) -> jth_congrad(m, x, j), bnlp, xs) -batch_jth_sparse_congrad(bnlp::AbstractBatchNLPModel, xs, j::Integer) = - _batch_map((m, x) -> jth_sparse_congrad(m, x, j), bnlp, xs) -batch_jth_hess_coord(bnlp::AbstractBatchNLPModel, xs, j::Integer) = - _batch_map((m, x) -> jth_hess_coord(m, x, j), bnlp, xs) -batch_jth_hess(bnlp::AbstractBatchNLPModel, xs, j::Integer) = - _batch_map((m, x) -> jth_hess(m, x, j), bnlp, xs) -batch_jth_hprod(bnlp::AbstractBatchNLPModel, xs, vs, j::Integer) = - _batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs) - -batch_jth_congrad!(bnlp::AbstractBatchNLPModel, xs, j::Integer, outputs) = - _batch_map!((m, out, x) -> jth_congrad!(m, x, j, out), bnlp, outputs, xs) -batch_jth_hess_coord!(bnlp::AbstractBatchNLPModel, xs, j::Integer, outputs) = - _batch_map!((m, out, x) -> jth_hess_coord!(m, x, j, out), bnlp, outputs, xs) -batch_jth_hprod!(bnlp::AbstractBatchNLPModel, xs, vs, j::Integer, outputs) = - _batch_map!((m, out, x, v) -> jth_hprod!(m, x, v, j, out), bnlp, outputs, xs, vs) - -# hess (need to treat obj_weight) 
-batch_hprod(bnlp::AbstractBatchNLPModel{T, S}, xs, vs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x, v; obj_weight) -> hprod(m, x, v; obj_weight = obj_weight), bnlp, obj_weights, xs, vs) -batch_hprod(bnlp::AbstractBatchNLPModel{T, S}, xs, ys, vs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x, y, v; obj_weight) -> hprod(m, x, y, v; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, vs) -batch_hess_coord(bnlp::AbstractBatchNLPModel{T, S}, xs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x; obj_weight) -> hess_coord(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) -batch_hess_coord(bnlp::AbstractBatchNLPModel{T, S}, xs, ys; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x, y; obj_weight) -> hess_coord(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) -batch_hess_op(bnlp::AbstractBatchNLPModel{T, S}, xs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x; obj_weight) -> hess_op(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) -batch_hess_op(bnlp::AbstractBatchNLPModel{T, S}, xs, ys; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x, y; obj_weight) -> hess_op(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) - -batch_hprod!(bnlp::AbstractBatchNLPModel{T, S}, xs, vs, outputs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight!((m, Hv, x, v; obj_weight) -> hprod!(m, x, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, vs) -batch_hprod!(bnlp::AbstractBatchNLPModel{T, S}, xs, ys, vs, outputs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight!((m, Hv, x, y, v; obj_weight) -> hprod!(m, x, y, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys, vs) -batch_hess_coord!(bnlp::AbstractBatchNLPModel{T, S}, xs, outputs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight!((m, 
vals, x; obj_weight) -> hess_coord!(m, x, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs) -batch_hess_coord!(bnlp::AbstractBatchNLPModel{T, S}, xs, ys, outputs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight!((m, vals, x, y; obj_weight) -> hess_coord!(m, x, y, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys) -batch_hess_op!(bnlp::AbstractBatchNLPModel{T, S}, xs, Hvs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x, Hv; obj_weight) -> hess_op!(m, x, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, Hvs) -batch_hess_op!(bnlp::AbstractBatchNLPModel{T, S}, xs, ys, Hvs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x, y, Hv; obj_weight) -> hess_op!(m, x, y, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, Hvs) - -batch_hess(bnlp::AbstractBatchNLPModel{T, S}, xs; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x; obj_weight) -> hess(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) -batch_hess(bnlp::AbstractBatchNLPModel{T, S}, xs, ys; obj_weights = ones(T, length(bnlp))) where {T, S} = - _batch_map_weight((m, x, y; obj_weight) -> hess(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) - -## operators -batch_jac_op(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(jac_op, bnlp, xs) -batch_jac_lin_op(bnlp::AbstractBatchNLPModel) = - _batch_map(jac_lin_op, bnlp) -batch_jac_nln_op(bnlp::AbstractBatchNLPModel, xs) = - _batch_map(jac_nln_op, bnlp, xs) - -batch_jac_op!(bnlp::AbstractBatchNLPModel, xs, Jvs, Jtvs) = - _batch_map(jac_op!, bnlp, xs, Jvs, Jtvs) -batch_jac_lin_op!(bnlp::AbstractBatchNLPModel, Jvs, Jtvs) = - _batch_map(jac_lin_op!, bnlp, Jvs, Jtvs) -batch_jac_nln_op!(bnlp::AbstractBatchNLPModel, xs, Jvs, Jtvs) = - _batch_map(jac_nln_op!, bnlp, xs, Jvs, Jtvs) - -## tuple functions -batch_objgrad(bnlp::AbstractBatchNLPModel, xs) = - _batch_map_tuple(objgrad, bnlp, xs) 
-batch_objcons(bnlp::AbstractBatchNLPModel, xs) = - _batch_map_tuple(objcons, bnlp, xs) - -batch_objgrad!(bnlp::AbstractBatchNLPModel, xs, gs) = - _batch_map_tuple!(objgrad!, bnlp, gs, xs) -batch_objcons!(bnlp::AbstractBatchNLPModel, xs, cs) = - _batch_map_tuple!(objcons!, bnlp, cs, xs) +abstract type AbstractBatchNLPModel end function NLPModels.increment!(bnlp::AbstractBatchNLPModel, fun::Symbol) NLPModels.increment!(bnlp, Val(fun)) end - \ No newline at end of file diff --git a/src/nlp/batch/foreach.jl b/src/nlp/batch/foreach.jl new file mode 100644 index 00000000..ee5c89f0 --- /dev/null +++ b/src/nlp/batch/foreach.jl @@ -0,0 +1,277 @@ +export ForEachBatchNLPModel +struct ForEachBatchNLPModel{M} <: AbstractBatchNLPModel + models::M + counters::Counters + batch_size::Int +end +function ForEachBatchNLPModel(models::M) where {M} + isempty(models) && error("Cannot create ForEachBatchNLPModel from empty collection.") + ForEachBatchNLPModel{M}(models, Counters(), length(models)) +end +Base.length(vnlp::ForEachBatchNLPModel) = length(vnlp.models) +Base.getindex(vnlp::ForEachBatchNLPModel, i::Integer) = vnlp.models[i] +Base.iterate(vnlp::ForEachBatchNLPModel, state::Integer = 1) = iterate(vnlp.models, state) + + +function _batch_map(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + results = [] + resize!(results, n) + for i = 1:n + args_i = (x[i] for x in xs) + results[i] = f(bnlp[i], args_i...) + end + return results +end + +function _batch_map!(f::F, bnlp::ForEachBatchNLPModel, outputs, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + @lencheck n outputs + for i = 1:n + args_i = (x[i] for x in xs) + f(bnlp[i], outputs[i], args_i...) 
+ end + return outputs +end + +function _batch_map_weight(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + @lencheck n obj_weights + results = [] + resize!(results, n) + for i = 1:n + args_i = (x[i] for x in xs) + results[i] = f(bnlp[i], args_i...; obj_weight = obj_weights[i]) + end + return results +end + +function _batch_map_weight!(f::F, bnlp::ForEachBatchNLPModel, outputs, obj_weights, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + @lencheck n outputs obj_weights + for i = 1:n + args_i = (x[i] for x in xs) + f(bnlp[i], outputs[i], args_i...; obj_weight = obj_weights[i]) + end + return outputs +end + +function _batch_map_tuple(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + results = _batch_map(f, bnlp, xs...) + + first_result = first(results) + T1, T2 = typeof(first_result[1]), typeof(first_result[2]) + vec1, vec2 = Vector{T1}(undef, n), Vector{T2}(undef, n) + for i = 1:n + vec1[i], vec2[i] = results[i] + end + return vec1, vec2 +end + +function _batch_map_tuple!(f::F, bnlp::ForEachBatchNLPModel, outputs, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + @lencheck n outputs + firsts = [] + resize!(firsts, n) + for i = 1:n + args_i = (x[i] for x in xs) + firsts[i], _ = f(bnlp[i], args_i..., outputs[i]) + end + return firsts, outputs +end + +for fun in fieldnames(Counters) + @eval function NLPModels.increment!(bnlp::ForEachBatchNLPModel, ::Val{$(Meta.quot(fun))}) + # sub-model counters are already incremented since we call their methods + bnlp.counters.$fun += 1 + end +end + +# There are two options for defining "special cases": +# 1. define batch_func(::MyBatchModel) +# 2. define _batch_map(f::func, ::MyBatchModel, ...) +# in most cases, using the first option is preferable. 
+# however, when overriding several functions at a time, +# for example if you know all the jac/hess structures are identical, one can write +# +# function NLPModels._batch_map( +# f::F, +# bnlp::MyBatchModel +# ) where {F<:Union{jac_structure,jac_lin_structure,jac_nln_structure,hess_structure}} +# +# return f(first(bnlp)) +# end + +batch_jac_structure(bnlp::ForEachBatchNLPModel) = + _batch_map(jac_structure, bnlp) +batch_jac_lin_structure(bnlp::ForEachBatchNLPModel) = + _batch_map(jac_lin_structure, bnlp) +batch_jac_nln_structure(bnlp::ForEachBatchNLPModel) = + _batch_map(jac_nln_structure, bnlp) +batch_hess_structure(bnlp::ForEachBatchNLPModel) = + _batch_map(hess_structure, bnlp) +batch_obj(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(obj, bnlp, xs) +batch_grad(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(grad, bnlp, xs) +batch_cons(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(cons, bnlp, xs) +batch_cons_lin(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(cons_lin, bnlp, xs) +batch_cons_nln(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(cons_nln, bnlp, xs) +batch_jac(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(jac, bnlp, xs) +batch_jac_lin(bnlp::ForEachBatchNLPModel) = + _batch_map(jac_lin, bnlp) +batch_jac_nln(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(jac_nln, bnlp, xs) +batch_jac_lin_coord(bnlp::ForEachBatchNLPModel) = + _batch_map(jac_lin_coord, bnlp) +batch_jac_coord(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(jac_coord, bnlp, xs) +batch_jac_nln_coord(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(jac_nln_coord, bnlp, xs) +batch_varscale(bnlp::ForEachBatchNLPModel) = + _batch_map(varscale, bnlp) +batch_lagscale(bnlp::ForEachBatchNLPModel) = + _batch_map(lagscale, bnlp) +batch_conscale(bnlp::ForEachBatchNLPModel) = + _batch_map(conscale, bnlp) +batch_jprod(bnlp::ForEachBatchNLPModel, xs, vs) = + _batch_map(jprod, bnlp, xs, vs) +batch_jtprod(bnlp::ForEachBatchNLPModel, xs, vs) = + _batch_map(jtprod, bnlp, xs, vs) 
+batch_jprod_nln(bnlp::ForEachBatchNLPModel, xs, vs) = + _batch_map(jprod_nln, bnlp, xs, vs) +batch_jtprod_nln(bnlp::ForEachBatchNLPModel, xs, vs) = + _batch_map(jtprod_nln, bnlp, xs, vs) +batch_jprod_lin(bnlp::ForEachBatchNLPModel, vs) = + _batch_map(jprod_lin, bnlp, vs) +batch_jtprod_lin(bnlp::ForEachBatchNLPModel, vs) = + _batch_map(jtprod_lin, bnlp, vs) +batch_ghjvprod(bnlp::ForEachBatchNLPModel, xs, gs, vs) = + _batch_map(ghjvprod, bnlp, xs, gs, vs) + +batch_jac_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) = + _batch_map!(jac_structure!, bnlp, rowss, colss) +batch_jac_lin_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) = + _batch_map!(jac_lin_structure!, bnlp, rowss, colss) +batch_jac_nln_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) = + _batch_map!(jac_nln_structure!, bnlp, rowss, colss) +batch_hess_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) = + _batch_map!(hess_structure!, bnlp, rowss, colss) +batch_jac_lin_coord!(bnlp::ForEachBatchNLPModel, valss) = + _batch_map!(jac_lin_coord!, bnlp, valss) +batch_grad!(bnlp::ForEachBatchNLPModel, xs, gs) = + _batch_map!((m, g, x) -> grad!(m, x, g), bnlp, gs, xs) +batch_cons!(bnlp::ForEachBatchNLPModel, xs, cs) = + _batch_map!((m, c, x) -> cons!(m, x, c), bnlp, cs, xs) +batch_cons_lin!(bnlp::ForEachBatchNLPModel, xs, cs) = + _batch_map!((m, c, x) -> cons_lin!(m, x, c), bnlp, cs, xs) +batch_cons_nln!(bnlp::ForEachBatchNLPModel, xs, cs) = + _batch_map!((m, c, x) -> cons_nln!(m, x, c), bnlp, cs, xs) +batch_jac_coord!(bnlp::ForEachBatchNLPModel, xs, valss) = + _batch_map!((m, vals, x) -> jac_coord!(m, x, vals), bnlp, valss, xs) +batch_jac_nln_coord!(bnlp::ForEachBatchNLPModel, xs, valss) = + _batch_map!((m, vals, x) -> jac_nln_coord!(m, x, vals), bnlp, valss, xs) +batch_jprod!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) = + _batch_map!((m, Jv, x, v) -> jprod!(m, x, v, Jv), bnlp, Jvs, xs, vs) +batch_jtprod!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) = + _batch_map!((m, Jtv, x, v) -> jtprod!(m, x, v, 
Jtv), bnlp, Jtvs, xs, vs) +batch_jprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) = + _batch_map!((m, Jv, x, v) -> jprod_nln!(m, x, v, Jv), bnlp, Jvs, xs, vs) +batch_jtprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) = + _batch_map!((m, Jtv, x, v) -> jtprod_nln!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) +batch_jprod_lin!(bnlp::ForEachBatchNLPModel, vs, Jvs) = + _batch_map!((m, Jv, v) -> jprod_lin!(m, v, Jv), bnlp, Jvs, vs) +batch_jtprod_lin!(bnlp::ForEachBatchNLPModel, vs, Jtvs) = + _batch_map!((m, Jtv, v) -> jtprod_lin!(m, v, Jtv), bnlp, Jtvs, vs) +batch_ghjvprod!(bnlp::ForEachBatchNLPModel, xs, gs, vs, gHvs) = + _batch_map!((m, gHv, x, g, v) -> ghjvprod!(m, x, g, v, gHv), bnlp, gHvs, xs, gs, vs) + +## jth +batch_jth_con(bnlp::ForEachBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_con(m, x, j), bnlp, xs) +batch_jth_congrad(bnlp::ForEachBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_congrad(m, x, j), bnlp, xs) +batch_jth_sparse_congrad(bnlp::ForEachBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_sparse_congrad(m, x, j), bnlp, xs) +batch_jth_hess_coord(bnlp::ForEachBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_hess_coord(m, x, j), bnlp, xs) +batch_jth_hess(bnlp::ForEachBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_hess(m, x, j), bnlp, xs) +batch_jth_hprod(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer) = + _batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs) + +batch_jth_congrad!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) = + _batch_map!((m, out, x) -> jth_congrad!(m, x, j, out), bnlp, outputs, xs) +batch_jth_hess_coord!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) = + _batch_map!((m, out, x) -> jth_hess_coord!(m, x, j, out), bnlp, outputs, xs) +batch_jth_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer, outputs) = + _batch_map!((m, out, x, v) -> jth_hprod!(m, x, v, j, out), bnlp, outputs, xs, vs) + +# hess (need to treat obj_weight) FIXME: container type.. 
+batch_hprod(bnlp::ForEachBatchNLPModel, xs, vs; obj_weights) = + _batch_map_weight((m, x, v; obj_weight) -> hprod(m, x, v; obj_weight = obj_weight), bnlp, obj_weights, xs, vs) +batch_hprod(bnlp::ForEachBatchNLPModel, xs, ys, vs; obj_weights) = + _batch_map_weight((m, x, y, v; obj_weight) -> hprod(m, x, y, v; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, vs) +batch_hess_coord(bnlp::ForEachBatchNLPModel, xs; obj_weights) = + _batch_map_weight((m, x; obj_weight) -> hess_coord(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) +batch_hess_coord(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) = + _batch_map_weight((m, x, y; obj_weight) -> hess_coord(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) +batch_hess_op(bnlp::ForEachBatchNLPModel, xs; obj_weights) = + _batch_map_weight((m, x; obj_weight) -> hess_op(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) +batch_hess_op(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) = + _batch_map_weight((m, x, y; obj_weight) -> hess_op(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + +batch_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, outputs; obj_weights) = + _batch_map_weight!((m, Hv, x, v; obj_weight) -> hprod!(m, x, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, vs) +batch_hprod!(bnlp::ForEachBatchNLPModel, xs, ys, vs, outputs; obj_weights) = + _batch_map_weight!((m, Hv, x, y, v; obj_weight) -> hprod!(m, x, y, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys, vs) +batch_hess_coord!(bnlp::ForEachBatchNLPModel, xs, outputs; obj_weights) = + _batch_map_weight!((m, vals, x; obj_weight) -> hess_coord!(m, x, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs) +batch_hess_coord!(bnlp::ForEachBatchNLPModel, xs, ys, outputs; obj_weights) = + _batch_map_weight!((m, vals, x, y; obj_weight) -> hess_coord!(m, x, y, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys) +batch_hess_op!(bnlp::ForEachBatchNLPModel, xs, Hvs; obj_weights) = + 
_batch_map_weight((m, x, Hv; obj_weight) -> hess_op!(m, x, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, Hvs) +batch_hess_op!(bnlp::ForEachBatchNLPModel, xs, ys, Hvs; obj_weights) = + _batch_map_weight((m, x, y, Hv; obj_weight) -> hess_op!(m, x, y, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, Hvs) + +batch_hess(bnlp::ForEachBatchNLPModel, xs; obj_weights) = + _batch_map_weight((m, x; obj_weight) -> hess(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) +batch_hess(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) = + _batch_map_weight((m, x, y; obj_weight) -> hess(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + +## operators +batch_jac_op(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(jac_op, bnlp, xs) +batch_jac_lin_op(bnlp::ForEachBatchNLPModel) = + _batch_map(jac_lin_op, bnlp) +batch_jac_nln_op(bnlp::ForEachBatchNLPModel, xs) = + _batch_map(jac_nln_op, bnlp, xs) + +batch_jac_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) = + _batch_map(jac_op!, bnlp, xs, Jvs, Jtvs) +batch_jac_lin_op!(bnlp::ForEachBatchNLPModel, Jvs, Jtvs) = + _batch_map(jac_lin_op!, bnlp, Jvs, Jtvs) +batch_jac_nln_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) = + _batch_map(jac_nln_op!, bnlp, xs, Jvs, Jtvs) + +## tuple functions +batch_objgrad(bnlp::ForEachBatchNLPModel, xs) = + _batch_map_tuple(objgrad, bnlp, xs) +batch_objcons(bnlp::ForEachBatchNLPModel, xs) = + _batch_map_tuple(objcons, bnlp, xs) + +batch_objgrad!(bnlp::ForEachBatchNLPModel, xs, gs) = + _batch_map_tuple!(objgrad!, bnlp, gs, xs) +batch_objcons!(bnlp::ForEachBatchNLPModel, xs, cs) = + _batch_map_tuple!(objcons!, bnlp, cs, xs) diff --git a/src/nlp/batch/vector.jl b/src/nlp/batch/vector.jl deleted file mode 100644 index b939a663..00000000 --- a/src/nlp/batch/vector.jl +++ /dev/null @@ -1,91 +0,0 @@ -export VectorBatchNLPModel -struct VectorBatchNLPModel{T, S, M <: AbstractNLPModel{T, S}} <: AbstractBatchNLPModel{T, S} - models::Vector{M} - counters::Counters - meta::NLPModelMeta{T, S} -end 
-function VectorBatchNLPModel(models::Vector{M}) where {M <: AbstractNLPModel} - isempty(models) && error("Cannot create VectorBatchNLPModel from empty vector") - # TODO: check all metas the same, all structures same, etc. - meta = first(models).meta - VectorBatchNLPModel{eltype(meta.x0), typeof(meta.x0), M}(models, Counters(), meta) -end -Base.length(vnlp::VectorBatchNLPModel) = length(vnlp.models) -Base.getindex(vnlp::VectorBatchNLPModel, i::Integer) = vnlp.models[i] -Base.iterate(vnlp::VectorBatchNLPModel, state::Integer = 1) = iterate(vnlp.models, state) - -function _batch_map(f, bnlp::VectorBatchNLPModel, xs::VV...) - n = length(bnlp) - results = Vector{Any}(undef, n) - for i = 1:n - args_i = (x[i] for x in xs) - results[i] = f(bnlp[i], args_i...) - end - return results -end - -function _batch_map!(f, bnlp::VectorBatchNLPModel, outputs::Vector, xs::VV...) - n = length(bnlp) - for i = 1:n - args_i = (x[i] for x in xs) - f(bnlp[i], outputs[i], args_i...) - end - return outputs -end - -function _batch_map_weight(f, bnlp::VectorBatchNLPModel, obj_weights::Vector, xs::VV...) - n = length(bnlp) - results = Vector{Any}(undef, n) - for i = 1:n - args_i = (x[i] for x in xs) - results[i] = f(bnlp[i], args_i...; obj_weight = obj_weights[i]) - end - return results -end - -function _batch_map_weight!( - f, - bnlp::VectorBatchNLPModel, - outputs::Vector, - obj_weights::Vector, - xs::VV..., -) - n = length(bnlp) - for i = 1:n - args_i = (x[i] for x in xs) - f(bnlp[i], outputs[i], args_i...; obj_weight = obj_weights[i]) - end - return outputs -end - -function _batch_map_tuple(f, bnlp::VectorBatchNLPModel, xs::VV...) - n = length(bnlp) - results = _batch_map(f, bnlp, xs...) 
- # Get types from first result - first_result = results[1] - T1 = typeof(first_result[1]) - T2 = typeof(first_result[2]) - vec1 = Vector{T1}(undef, n) - vec2 = Vector{T2}(undef, n) - for i = 1:n - vec1[i], vec2[i] = results[i] - end - return vec1, vec2 -end - -function _batch_map_tuple!(f, bnlp::VectorBatchNLPModel, outputs::Vector, xs::VV...) - n = length(bnlp) - firsts = Vector{eltype(bnlp.meta.x0)}(undef, n) - for i = 1:n - args_i = (x[i] for x in xs) - firsts[i], _ = f(bnlp[i], args_i..., outputs[i]) - end - return firsts, outputs -end - -for fun in fieldnames(Counters) - @eval function NLPModels.increment!(bnlp::VectorBatchNLPModel, ::Val{$(Meta.quot(fun))}) - # sub-model counters are already incremented since we call their methods - bnlp.counters.$fun += 1 - end -end \ No newline at end of file diff --git a/src/nlp/utils.jl b/src/nlp/utils.jl index 8f2b4a1b..795b7be5 100644 --- a/src/nlp/utils.jl +++ b/src/nlp/utils.jl @@ -1,6 +1,6 @@ export coo_prod!, coo_sym_prod! export @default_counters -export DimensionError, @lencheck, @rangecheck +export DimensionError, @lencheck, @lencheck_tup, @rangecheck """ DimensionError <: Exception @@ -41,6 +41,24 @@ macro lencheck(l, vars...) Expr(:block, exprs...) end +""" + @lencheck_tup n xs + +Check that the entries contained in `xs` all have length `n`. 
+""" +macro lencheck_tup(l, tup) + tupname = string(tup) + quote + _expected_len = $(esc(l)) + _vars = $(esc(tup)) + for (_idx, _var) in enumerate(_vars) + if length(_var) != _expected_len + throw(DimensionError(string($tupname, "[", _idx, "]"), _expected_len, length(_var))) + end + end + end +end + """ @rangecheck ℓ u i j k … diff --git a/test/nlp/batch_api.jl b/test/nlp/batch_api.jl index 4043dd4b..5691ddb2 100644 --- a/test/nlp/batch_api.jl +++ b/test/nlp/batch_api.jl @@ -1,8 +1,10 @@ @testset "Batch API" begin # Generate models + # TODO: non-identical models n_models = 5 models = [SimpleNLPModel() for _ = 1:n_models] - n, m = models[1].meta.nvar, models[1].meta.ncon + meta = models[1].meta + n, m = meta.nvar, meta.ncon xs = [randn(n) for _ = 1:n_models] ys = [randn(m) for _ = 1:n_models] vs = [randn(n) for _ = 1:n_models] @@ -10,7 +12,7 @@ gs = [zeros(n) for _ = 1:n_models] cs = [zeros(m) for _ = 1:n_models] obj_weights = rand(n_models) - for batch_model in [VectorBatchNLPModel] + for batch_model in [ForEachBatchNLPModel] @testset "$batch_model consistency" begin bnlp = batch_model(models) @@ -59,9 +61,9 @@ @test batch_cs_lin ≈ manual_cs_lin # Test batch_cons_lin! - cs_lin = [zeros(bnlp.meta.nlin) for _ = 1:n_models] + cs_lin = [zeros(meta.nlin) for _ = 1:n_models] batch_cons_lin!(bnlp, xs, cs_lin) - manual_cs_lin = [cons_lin!(models[i], xs[i], zeros(bnlp.meta.nlin)) for i = 1:n_models] + manual_cs_lin = [cons_lin!(models[i], xs[i], zeros(meta.nlin)) for i = 1:n_models] @test cs_lin ≈ manual_cs_lin # Test batch_cons_nln @@ -70,9 +72,9 @@ @test batch_cs_nln ≈ manual_cs_nln # Test batch_cons_nln! 
- cs_nln = [zeros(bnlp.meta.nnln) for _ = 1:n_models] + cs_nln = [zeros(meta.nnln) for _ = 1:n_models] batch_cons_nln!(bnlp, xs, cs_nln) - manual_cs_nln = [cons_nln!(models[i], xs[i], zeros(bnlp.meta.nnln)) for i = 1:n_models] + manual_cs_nln = [cons_nln!(models[i], xs[i], zeros(meta.nnln)) for i = 1:n_models] @test cs_nln ≈ manual_cs_nln # Test batch_objcons @@ -100,9 +102,9 @@ @test batch_jac_coords ≈ manual_jac_coords # Test batch_jac_coord! - jac_coords = [zeros(bnlp.meta.nnzj) for _ = 1:n_models] + jac_coords = [zeros(meta.nnzj) for _ = 1:n_models] batch_jac_coord!(bnlp, xs, jac_coords) - manual_jac_coords = [jac_coord!(models[i], xs[i], zeros(bnlp.meta.nnzj)) for i = 1:n_models] + manual_jac_coords = [jac_coord!(models[i], xs[i], zeros(meta.nnzj)) for i = 1:n_models] @test jac_coords ≈ manual_jac_coords # Test batch_jac_lin @@ -116,10 +118,10 @@ @test batch_jac_lin_coords ≈ manual_jac_lin_coords # Test batch_jac_lin_coord! - jac_lin_coords = [zeros(bnlp.meta.lin_nnzj) for _ = 1:n_models] + jac_lin_coords = [zeros(meta.lin_nnzj) for _ = 1:n_models] batch_jac_lin_coord!(bnlp, jac_lin_coords) manual_jac_lin_coords = - [jac_lin_coord!(models[i], zeros(bnlp.meta.lin_nnzj)) for i = 1:n_models] + [jac_lin_coord!(models[i], zeros(meta.lin_nnzj)) for i = 1:n_models] @test jac_lin_coords ≈ manual_jac_lin_coords # Test batch_jac_nln @@ -133,10 +135,10 @@ @test batch_jac_nln_coords ≈ manual_jac_nln_coords # Test batch_jac_nln_coord! - jac_nln_coords = [zeros(bnlp.meta.nln_nnzj) for _ = 1:n_models] + jac_nln_coords = [zeros(meta.nln_nnzj) for _ = 1:n_models] batch_jac_nln_coord!(bnlp, xs, jac_nln_coords) manual_jac_nln_coords = - [jac_nln_coord!(models[i], xs[i], zeros(bnlp.meta.nln_nnzj)) for i = 1:n_models] + [jac_nln_coord!(models[i], xs[i], zeros(meta.nln_nnzj)) for i = 1:n_models] @test jac_nln_coords ≈ manual_jac_nln_coords # Test batch_jprod @@ -167,13 +169,13 @@ @test batch_jprod_lins ≈ manual_jprod_lins # Test batch_jprod_lin! 
- jprod_lins = [zeros(bnlp.meta.nlin) for _ = 1:n_models] + jprod_lins = [zeros(meta.nlin) for _ = 1:n_models] batch_jprod_lin!(bnlp, vs, jprod_lins) - manual_jprod_lins = [jprod_lin!(models[i], vs[i], zeros(bnlp.meta.nlin)) for i = 1:n_models] + manual_jprod_lins = [jprod_lin!(models[i], vs[i], zeros(meta.nlin)) for i = 1:n_models] @test jprod_lins ≈ manual_jprod_lins # Test batch_jtprod_lin - ws_lin = [ws[i][1:(bnlp.meta.nlin)] for i = 1:n_models] + ws_lin = [ws[i][1:(meta.nlin)] for i = 1:n_models] batch_jtprod_lins = batch_jtprod_lin(bnlp, ws_lin) manual_jtprod_lins = [jtprod_lin(models[i], ws_lin[i]) for i = 1:n_models] @test batch_jtprod_lins ≈ manual_jtprod_lins @@ -190,14 +192,14 @@ @test batch_jprod_nlns ≈ manual_jprod_nlns # Test batch_jprod_nln! - jprod_nlns = [zeros(bnlp.meta.nnln) for _ = 1:n_models] + jprod_nlns = [zeros(meta.nnln) for _ = 1:n_models] batch_jprod_nln!(bnlp, xs, vs, jprod_nlns) manual_jprod_nlns = - [jprod_nln!(models[i], xs[i], vs[i], zeros(bnlp.meta.nnln)) for i = 1:n_models] + [jprod_nln!(models[i], xs[i], vs[i], zeros(meta.nnln)) for i = 1:n_models] @test jprod_nlns ≈ manual_jprod_nlns # Test batch_jtprod_nln - ws_nln = [ws[i][(bnlp.meta.nlin + 1):end] for i = 1:n_models] + ws_nln = [ws[i][(meta.nlin + 1):end] for i = 1:n_models] batch_jtprod_nlns = batch_jtprod_nln(bnlp, xs, ws_nln) manual_jtprod_nlns = [jtprod_nln(models[i], xs[i], ws_nln[i]) for i = 1:n_models] @test batch_jtprod_nlns ≈ manual_jtprod_nlns @@ -208,16 +210,6 @@ manual_jtprod_nlns = [jtprod_nln!(models[i], xs[i], ws_nln[i], zeros(n)) for i = 1:n_models] @test jtprod_nlns ≈ manual_jtprod_nlns - # Test batch_hess (without y) - batch_hesses = batch_hess(bnlp, xs) - manual_hesses = [hess(models[i], xs[i]) for i = 1:n_models] - @test batch_hesses ≈ manual_hesses - - # Test batch_hess (with y) - batch_hesses = batch_hess(bnlp, xs, ys) - manual_hesses = [hess(models[i], xs[i], ys[i]) for i = 1:n_models] - @test batch_hesses ≈ manual_hesses - # Test batch_hess with 
obj_weights (without y) batch_hesses = batch_hess(bnlp, xs; obj_weights = obj_weights) manual_hesses = [hess(models[i], xs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] @@ -229,16 +221,6 @@ [hess(models[i], xs[i], ys[i]; obj_weight = obj_weights[i]) for i = 1:n_models] @test batch_hesses ≈ manual_hesses - # Test batch_hess_coord (without y) - batch_hess_coords = batch_hess_coord(bnlp, xs) - manual_hess_coords = [hess_coord(models[i], xs[i]) for i = 1:n_models] - @test batch_hess_coords ≈ manual_hess_coords - - # Test batch_hess_coord (with y) - batch_hess_coords = batch_hess_coord(bnlp, xs, ys) - manual_hess_coords = [hess_coord(models[i], xs[i], ys[i]) for i = 1:n_models] - @test batch_hess_coords ≈ manual_hess_coords - # Test batch_hess_coord with obj_weights (without y) batch_hess_coords = batch_hess_coord(bnlp, xs; obj_weights = obj_weights) manual_hess_coords = @@ -251,47 +233,24 @@ [hess_coord(models[i], xs[i], ys[i]; obj_weight = obj_weights[i]) for i = 1:n_models] @test batch_hess_coords ≈ manual_hess_coords - # Test batch_hess_coord! (without y) - hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] - batch_hess_coord!(bnlp, xs, hess_coords) - manual_hess_coords = [hess_coord!(models[i], xs[i], zeros(bnlp.meta.nnzh)) for i = 1:n_models] - @test hess_coords ≈ manual_hess_coords - - # Test batch_hess_coord! (with y) - hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] - batch_hess_coord!(bnlp, xs, ys, hess_coords) - manual_hess_coords = - [hess_coord!(models[i], xs[i], ys[i], zeros(bnlp.meta.nnzh)) for i = 1:n_models] - @test hess_coords ≈ manual_hess_coords - # Test batch_hess_coord! 
with obj_weights (without y) - hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] + hess_coords = [zeros(meta.nnzh) for _ = 1:n_models] batch_hess_coord!(bnlp, xs, hess_coords; obj_weights = obj_weights) manual_hess_coords = [ - hess_coord!(models[i], xs[i], zeros(bnlp.meta.nnzh); obj_weight = obj_weights[i]) for + hess_coord!(models[i], xs[i], zeros(meta.nnzh); obj_weight = obj_weights[i]) for i = 1:n_models ] @test hess_coords ≈ manual_hess_coords # Test batch_hess_coord! with obj_weights (with y) - hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] + hess_coords = [zeros(meta.nnzh) for _ = 1:n_models] batch_hess_coord!(bnlp, xs, ys, hess_coords; obj_weights = obj_weights) manual_hess_coords = [ - hess_coord!(models[i], xs[i], ys[i], zeros(bnlp.meta.nnzh); obj_weight = obj_weights[i]) + hess_coord!(models[i], xs[i], ys[i], zeros(meta.nnzh); obj_weight = obj_weights[i]) for i = 1:n_models ] @test hess_coords ≈ manual_hess_coords - # Test batch_hprod (without y) - batch_hprods = batch_hprod(bnlp, xs, vs) - manual_hprods = [hprod(models[i], xs[i], vs[i]) for i = 1:n_models] - @test batch_hprods ≈ manual_hprods - - # Test batch_hprod (with y) - batch_hprods = batch_hprod(bnlp, xs, ys, vs) - manual_hprods = [hprod(models[i], xs[i], ys[i], vs[i]) for i = 1:n_models] - @test batch_hprods ≈ manual_hprods - # Test batch_hprod with obj_weights (without y) batch_hprods = batch_hprod(bnlp, xs, vs; obj_weights = obj_weights) manual_hprods = @@ -304,18 +263,6 @@ [hprod(models[i], xs[i], ys[i], vs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] @test batch_hprods ≈ manual_hprods - # Test batch_hprod! (without y) - hprods = [zeros(n) for _ = 1:n_models] - batch_hprod!(bnlp, xs, vs, hprods) - manual_hprods = [hprod!(models[i], xs[i], vs[i], zeros(n)) for i = 1:n_models] - @test hprods ≈ manual_hprods - - # Test batch_hprod! 
(with y) - hprods = [zeros(n) for _ = 1:n_models] - batch_hprod!(bnlp, xs, ys, vs, hprods) - manual_hprods = [hprod!(models[i], xs[i], ys[i], vs[i], zeros(n)) for i = 1:n_models] - @test hprods ≈ manual_hprods - # Test batch_hprod! with obj_weights (without y) hprods = [zeros(n) for _ = 1:n_models] batch_hprod!(bnlp, xs, vs, hprods; obj_weights = obj_weights) @@ -332,20 +279,6 @@ ] @test hprods ≈ manual_hprods - # Test batch_hess_op (without y) - batch_hess_ops = batch_hess_op(bnlp, xs) - manual_hess_ops = [hess_op(models[i], xs[i]) for i = 1:n_models] - for i = 1:n_models - @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] - end - - # Test batch_hess_op (with y) - batch_hess_ops = batch_hess_op(bnlp, xs, ys) - manual_hess_ops = [hess_op(models[i], xs[i], ys[i]) for i = 1:n_models] - for i = 1:n_models - @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] - end - # Test batch_hess_op with obj_weights (without y) batch_hess_ops = batch_hess_op(bnlp, xs; obj_weights = obj_weights) manual_hess_ops = [hess_op(models[i], xs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] @@ -361,22 +294,6 @@ @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] end - # Test batch_hess_op! (without y) - hvs = [zeros(n) for _ = 1:n_models] - batch_hess_ops = batch_hess_op!(bnlp, xs, hvs) - manual_hess_ops = [hess_op!(models[i], xs[i], zeros(n)) for i = 1:n_models] - for i = 1:n_models - @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] - end - - # Test batch_hess_op! (with y) - hvs = [zeros(n) for _ = 1:n_models] - batch_hess_ops = batch_hess_op!(bnlp, xs, ys, hvs) - manual_hess_ops = [hess_op!(models[i], xs[i], ys[i], zeros(n)) for i = 1:n_models] - for i = 1:n_models - @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] - end - # Test batch_hess_op! 
with obj_weights (without y) hvs = [zeros(n) for _ = 1:n_models] batch_hess_ops = batch_hess_op!(bnlp, xs, hvs; obj_weights = obj_weights) @@ -424,10 +341,10 @@ @test batch_jth_hess_coords ≈ manual_jth_hess_coords # Test batch_jth_hess_coord! - jth_hess_coords = [zeros(bnlp.meta.nnzh) for _ = 1:n_models] + jth_hess_coords = [zeros(meta.nnzh) for _ = 1:n_models] batch_jth_hess_coord!(bnlp, xs, j, jth_hess_coords) manual_jth_hess_coords = - [jth_hess_coord!(models[i], xs[i], j, zeros(bnlp.meta.nnzh)) for i = 1:n_models] + [jth_hess_coord!(models[i], xs[i], j, zeros(meta.nnzh)) for i = 1:n_models] @test jth_hess_coords ≈ manual_jth_hess_coords # Test batch_jth_hess @@ -478,18 +395,18 @@ # Test batch_jac_lin_op batch_jac_lin_ops = batch_jac_lin_op(bnlp) manual_jac_lin_ops = [jac_lin_op(models[i]) for i = 1:n_models] - ws_lin_vec = ws[1][1:(bnlp.meta.nlin)] + ws_lin_vec = ws[1][1:(meta.nlin)] for i = 1:n_models @test batch_jac_lin_ops[i] * vs[i] ≈ manual_jac_lin_ops[i] * vs[i] @test batch_jac_lin_ops[i]' * ws_lin_vec ≈ manual_jac_lin_ops[i]' * ws_lin_vec end # Test batch_jac_lin_op! 
- jvs_lin = [zeros(bnlp.meta.nlin) for _ = 1:n_models] + jvs_lin = [zeros(meta.nlin) for _ = 1:n_models] jtvs_lin = [zeros(n) for _ = 1:n_models] batch_jac_lin_ops = batch_jac_lin_op!(bnlp, jvs_lin, jtvs_lin) manual_jac_lin_ops = - [jac_lin_op!(models[i], zeros(bnlp.meta.nlin), zeros(n)) for i = 1:n_models] + [jac_lin_op!(models[i], zeros(meta.nlin), zeros(n)) for i = 1:n_models] for i = 1:n_models @test batch_jac_lin_ops[i] * vs[i] ≈ manual_jac_lin_ops[i] * vs[i] @test batch_jac_lin_ops[i]' * ws_lin_vec ≈ manual_jac_lin_ops[i]' * ws_lin_vec @@ -498,18 +415,18 @@ # Test batch_jac_nln_op batch_jac_nln_ops = batch_jac_nln_op(bnlp, xs) manual_jac_nln_ops = [jac_nln_op(models[i], xs[i]) for i = 1:n_models] - ws_nln_vec = ws[1][(bnlp.meta.nlin + 1):end] + ws_nln_vec = ws[1][(meta.nlin + 1):end] for i = 1:n_models @test batch_jac_nln_ops[i] * vs[i] ≈ manual_jac_nln_ops[i] * vs[i] @test batch_jac_nln_ops[i]' * ws_nln_vec ≈ manual_jac_nln_ops[i]' * ws_nln_vec end # Test batch_jac_nln_op! - jvs_nln = [zeros(bnlp.meta.nnln) for _ = 1:n_models] + jvs_nln = [zeros(meta.nnln) for _ = 1:n_models] jtvs_nln = [zeros(n) for _ = 1:n_models] batch_jac_nln_ops = batch_jac_nln_op!(bnlp, xs, jvs_nln, jtvs_nln) manual_jac_nln_ops = - [jac_nln_op!(models[i], xs[i], zeros(bnlp.meta.nnln), zeros(n)) for i = 1:n_models] + [jac_nln_op!(models[i], xs[i], zeros(meta.nnln), zeros(n)) for i = 1:n_models] for i = 1:n_models @test batch_jac_nln_ops[i] * vs[i] ≈ manual_jac_nln_ops[i] * vs[i] @test batch_jac_nln_ops[i]' * ws_nln_vec ≈ manual_jac_nln_ops[i]' * ws_nln_vec @@ -529,39 +446,61 @@ @test batch_conscales ≈ manual_conscales # Test structure functions - first_model = first(models) - @test batch_jac_structure(bnlp) == jac_structure(first_model) - @test batch_jac_lin_structure(bnlp) == jac_lin_structure(first_model) - @test batch_jac_nln_structure(bnlp) == jac_nln_structure(first_model) - @test batch_hess_structure(bnlp) == hess_structure(first_model) - - rows, cols = jac_structure(first_model) 
- fill!(rows, 0) - fill!(cols, 0) - batch_jac_structure!(bnlp, rows, cols) - @test rows == jac_structure(first_model)[1] - @test cols == jac_structure(first_model)[2] - - rows, cols = jac_lin_structure(first_model) - fill!(rows, 0) - fill!(cols, 0) - batch_jac_lin_structure!(bnlp, rows, cols) - @test rows == jac_lin_structure(first_model)[1] - @test cols == jac_lin_structure(first_model)[2] - - rows, cols = jac_nln_structure(first_model) - fill!(rows, 0) - fill!(cols, 0) - batch_jac_nln_structure!(bnlp, rows, cols) - @test rows == jac_nln_structure(first_model)[1] - @test cols == jac_nln_structure(first_model)[2] - - rows, cols = hess_structure(first_model) - fill!(rows, 0) - fill!(cols, 0) - batch_hess_structure!(bnlp, rows, cols) - @test rows == hess_structure(first_model)[1] - @test cols == hess_structure(first_model)[2] + jac_structures = batch_jac_structure(bnlp) + manual_jac_structures = [jac_structure(models[i]) for i = 1:n_models] + @test jac_structures == manual_jac_structures + + jac_lin_structures = batch_jac_lin_structure(bnlp) + manual_jac_lin_structures = [jac_lin_structure(models[i]) for i = 1:n_models] + @test jac_lin_structures == manual_jac_lin_structures + + jac_nln_structures = batch_jac_nln_structure(bnlp) + manual_jac_nln_structures = [jac_nln_structure(models[i]) for i = 1:n_models] + @test jac_nln_structures == manual_jac_nln_structures + + hess_structures = batch_hess_structure(bnlp) + manual_hess_structures = [hess_structure(models[i]) for i = 1:n_models] + @test hess_structures == manual_hess_structures + + rowss = [copy(manual_jac_structures[i][1]) for i = 1:n_models] + colss = [copy(manual_jac_structures[i][2]) for i = 1:n_models] + foreach(r -> fill!(r, 0), rowss) + foreach(c -> fill!(c, 0), colss) + batch_jac_structure!(bnlp, rowss, colss) + for i = 1:n_models + @test rowss[i] == manual_jac_structures[i][1] + @test colss[i] == manual_jac_structures[i][2] + end + + rowss = [copy(manual_jac_lin_structures[i][1]) for i = 1:n_models] + 
colss = [copy(manual_jac_lin_structures[i][2]) for i = 1:n_models] + foreach(r -> fill!(r, 0), rowss) + foreach(c -> fill!(c, 0), colss) + batch_jac_lin_structure!(bnlp, rowss, colss) + for i = 1:n_models + @test rowss[i] == manual_jac_lin_structures[i][1] + @test colss[i] == manual_jac_lin_structures[i][2] + end + + rowss = [copy(manual_jac_nln_structures[i][1]) for i = 1:n_models] + colss = [copy(manual_jac_nln_structures[i][2]) for i = 1:n_models] + foreach(r -> fill!(r, 0), rowss) + foreach(c -> fill!(c, 0), colss) + batch_jac_nln_structure!(bnlp, rowss, colss) + for i = 1:n_models + @test rowss[i] == manual_jac_nln_structures[i][1] + @test colss[i] == manual_jac_nln_structures[i][2] + end + + rowss = [copy(manual_hess_structures[i][1]) for i = 1:n_models] + colss = [copy(manual_hess_structures[i][2]) for i = 1:n_models] + foreach(r -> fill!(r, 0), rowss) + foreach(c -> fill!(c, 0), colss) + batch_hess_structure!(bnlp, rowss, colss) + for i = 1:n_models + @test rowss[i] == manual_hess_structures[i][1] + @test colss[i] == manual_hess_structures[i][2] + end end end end diff --git a/test/nlp/dummy-model.jl b/test/nlp/dummy-model.jl index 4dd2f9b5..d22c2f39 100644 --- a/test/nlp/dummy-model.jl +++ b/test/nlp/dummy-model.jl @@ -13,7 +13,7 @@ end @test_throws(MethodError, grad!(model, [0.0], [1.0])) @test_throws(MethodError, cons_lin!(model, [0.0], [1.0])) @test_throws(MethodError, cons_nln!(model, [0.0], [1.0])) - @test_throws(MethodError, jac_lin_coord!(model, [0.0], [1.0])) + @test_throws(MethodError, jac_lin_coord!(model, [1.0])) @test_throws(MethodError, jac_nln_coord!(model, [0.0], [1.0])) @test_throws(MethodError, jth_con(model, [0.0], 1)) @test_throws(MethodError, jth_congrad(model, [0.0], 1)) diff --git a/test/nlp/utils.jl b/test/nlp/utils.jl index 7d8b5ef8..88fd4ac2 100644 --- a/test/nlp/utils.jl +++ b/test/nlp/utils.jl @@ -2,12 +2,26 @@ mutable struct SuperNLPModel{T, S} <: AbstractNLPModel{T, S} model end -@testset "Testing @lencheck e @rangecheck" begin 
+@testset "Testing @lencheck, @lencheck_tup, @rangecheck" begin x = zeros(2) @lencheck 2 x @test_throws DimensionError @lencheck 1 x @test_throws DimensionError @lencheck 3 x + xs = (zeros(2), ones(2)) + @lencheck_tup 2 xs + xs_bad = (zeros(2), ones(3)) + err = try + @lencheck_tup 2 xs_bad + nothing + catch e + e + end + @test isa(err, DimensionError) + @test err.name == "xs_bad[2]" + @test err.dim_expected == 2 + @test err.dim_found == 3 + @rangecheck 1 3 2 @test_throws ErrorException @rangecheck 1 3 0 @test_throws ErrorException @rangecheck 1 3 4 From 544220a059ad9c1aed5b793921566d4763d8cb27 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 20 Nov 2025 13:00:33 -0500 Subject: [PATCH 06/13] stubs --- src/nlp/batch/api.jl | 70 ++++++++++++++++++++++++++++++++++++++++ src/nlp/batch/foreach.jl | 7 ++-- 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/src/nlp/batch/api.jl b/src/nlp/batch/api.jl index 4ca10bdf..85246f28 100644 --- a/src/nlp/batch/api.jl +++ b/src/nlp/batch/api.jl @@ -19,3 +19,73 @@ abstract type AbstractBatchNLPModel end function NLPModels.increment!(bnlp::AbstractBatchNLPModel, fun::Symbol) NLPModels.increment!(bnlp, Val(fun)) end + +function batch_obj end +function batch_grad end +function batch_grad! end +function batch_objgrad end +function batch_objgrad! end +function batch_objcons end +function batch_objcons! end +function batch_cons end +function batch_cons! end +function batch_cons_lin end +function batch_cons_lin! end +function batch_cons_nln end +function batch_cons_nln! end +function batch_jth_con end +function batch_jth_congrad end +function batch_jth_congrad! end +function batch_jth_sparse_congrad end +function batch_jac_structure! end +function batch_jac_structure end +function batch_jac_coord! end +function batch_jac_coord end +function batch_jac end +function batch_jprod end +function batch_jprod! end +function batch_jtprod end +function batch_jtprod! end +function batch_jac_op end +function batch_jac_op! 
end +function batch_jac_lin_structure! end +function batch_jac_lin_structure end +function batch_jac_lin_coord! end +function batch_jac_lin_coord end +function batch_jac_lin end +function batch_jprod_lin end +function batch_jprod_lin! end +function batch_jtprod_lin end +function batch_jtprod_lin! end +function batch_jac_lin_op end +function batch_jac_lin_op! end +function batch_jac_nln_structure! end +function batch_jac_nln_structure end +function batch_jac_nln_coord! end +function batch_jac_nln_coord end +function batch_jac_nln end +function batch_jprod_nln end +function batch_jprod_nln! end +function batch_jtprod_nln end +function batch_jtprod_nln! end +function batch_jac_nln_op end +function batch_jac_nln_op! end +function batch_jth_hess_coord end +function batch_jth_hess_coord! end +function batch_jth_hess end +function batch_jth_hprod end +function batch_jth_hprod! end +function batch_ghjvprod end +function batch_ghjvprod! end +function batch_hess_structure! end +function batch_hess_structure end +function batch_hess_coord! end +function batch_hess_coord end +function batch_hess end +function batch_hprod end +function batch_hprod! end +function batch_hess_op end +function batch_hess_op! end +function batch_varscale end +function batch_lagscale end +function batch_conscale end diff --git a/src/nlp/batch/foreach.jl b/src/nlp/batch/foreach.jl index ee5c89f0..530c8ec1 100644 --- a/src/nlp/batch/foreach.jl +++ b/src/nlp/batch/foreach.jl @@ -99,12 +99,13 @@ end # 2. define _batch_map(f::func, ::MyBatchModel, ...) # in most cases, using the first option is preferable. 
# however, when overriding several functions at a time, -# for example if you know all the jac/hess structures are identical, one can write +# for example if you know the hess structures are identical, +# one can write something like # # function NLPModels._batch_map( -# f::F, +# f::hess_structure, # bnlp::MyBatchModel -# ) where {F<:Union{jac_structure,jac_lin_structure,jac_nln_structure,hess_structure}} +# ) # # return f(first(bnlp)) # end From fb1937d50876a88c9342b2b8e161ed4cdcc01770 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 20 Nov 2025 15:42:59 -0500 Subject: [PATCH 07/13] rm comment --- src/nlp/batch/foreach.jl | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/nlp/batch/foreach.jl b/src/nlp/batch/foreach.jl index 530c8ec1..e536d08b 100644 --- a/src/nlp/batch/foreach.jl +++ b/src/nlp/batch/foreach.jl @@ -94,21 +94,6 @@ for fun in fieldnames(Counters) end end -# There are two options for defining "special cases": -# 1. define batch_func(::MyBatchModel) -# 2. define _batch_map(f::func, ::MyBatchModel, ...) -# in most cases, using the first option is preferable. 
-# however, when overriding several functions at a time, -# for example if you know the hess structures are identical, -# one can write something like -# -# function NLPModels._batch_map( -# f::hess_structure, -# bnlp::MyBatchModel -# ) -# -# return f(first(bnlp)) -# end batch_jac_structure(bnlp::ForEachBatchNLPModel) = _batch_map(jac_structure, bnlp) From cb48068f0e6a3b5da16baf9520851b802921257e Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 20 Nov 2025 16:12:05 -0500 Subject: [PATCH 08/13] inplace --- src/NLPModels.jl | 1 + src/nlp/batch/foreach.jl | 4 +- src/nlp/batch/inplace.jl | 268 +++++++++++++++++++++++++++++++++++++++ test/nlp/batch_api.jl | 24 +++- 4 files changed, 292 insertions(+), 5 deletions(-) create mode 100644 src/nlp/batch/inplace.jl diff --git a/src/NLPModels.jl b/src/NLPModels.jl index 6aeeb086..903d48a9 100644 --- a/src/NLPModels.jl +++ b/src/NLPModels.jl @@ -44,5 +44,6 @@ end include("nlp/batch/api.jl") include("nlp/batch/foreach.jl") +include("nlp/batch/inplace.jl") end # module diff --git a/src/nlp/batch/foreach.jl b/src/nlp/batch/foreach.jl index e536d08b..09002b09 100644 --- a/src/nlp/batch/foreach.jl +++ b/src/nlp/batch/foreach.jl @@ -8,7 +8,7 @@ function ForEachBatchNLPModel(models::M) where {M} isempty(models) && error("Cannot create ForEachBatchNLPModel from empty collection.") ForEachBatchNLPModel{M}(models, Counters(), length(models)) end -Base.length(vnlp::ForEachBatchNLPModel) = length(vnlp.models) +Base.length(vnlp::ForEachBatchNLPModel) = vnlp.batch_size Base.getindex(vnlp::ForEachBatchNLPModel, i::Integer) = vnlp.models[i] Base.iterate(vnlp::ForEachBatchNLPModel, state::Integer = 1) = iterate(vnlp.models, state) @@ -204,7 +204,7 @@ batch_jth_hess_coord!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) = batch_jth_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer, outputs) = _batch_map!((m, out, x, v) -> jth_hprod!(m, x, v, j, out), bnlp, outputs, xs, vs) -# hess (need to treat obj_weight) FIXME: container 
type.. +# hess (need to treat obj_weight) FIXME: obj_weights is required in batch API batch_hprod(bnlp::ForEachBatchNLPModel, xs, vs; obj_weights) = _batch_map_weight((m, x, v; obj_weight) -> hprod(m, x, v; obj_weight = obj_weight), bnlp, obj_weights, xs, vs) batch_hprod(bnlp::ForEachBatchNLPModel, xs, ys, vs; obj_weights) = diff --git a/src/nlp/batch/inplace.jl b/src/nlp/batch/inplace.jl new file mode 100644 index 00000000..3190b2e3 --- /dev/null +++ b/src/nlp/batch/inplace.jl @@ -0,0 +1,268 @@ +export InplaceBatchNLPModel +struct InplaceBatchNLPModel{M} <: AbstractBatchNLPModel + base_model::M + updates + counters::Counters + batch_size::Int +end +function InplaceBatchNLPModel(base_model::M, updates) where {M} + isempty(updates) && error("Cannot create InplaceBatchNLPModel from empty collection.") + InplaceBatchNLPModel{M}(base_model, updates, Counters(), length(updates)) +end +# TODO: counters? +Base.length(vnlp::InplaceBatchNLPModel) = vnlp.batch_size + + +function _batch_map(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + results = [] + resize!(results, n) + for i = 1:n + args_i = (x[i] for x in xs) + bnlp.updates[i](bnlp.base_model) # call update function + results[i] = f(bnlp.base_model, args_i...) + end + return results +end + +function _batch_map!(f::F, bnlp::InplaceBatchNLPModel, outputs, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + @lencheck n outputs + for i = 1:n + args_i = (x[i] for x in xs) + bnlp.updates[i](bnlp.base_model) # call update function + f(bnlp.base_model, outputs[i], args_i...) 
+ end + return outputs +end + +function _batch_map_weight(f::F, bnlp::InplaceBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + @lencheck n obj_weights + results = [] + resize!(results, n) + for i = 1:n + args_i = (x[i] for x in xs) + bnlp.updates[i](bnlp.base_model) # call update function + results[i] = f(bnlp.base_model, args_i...; obj_weight = obj_weights[i]) + end + return results +end + +function _batch_map_weight!(f::F, bnlp::InplaceBatchNLPModel, outputs, obj_weights, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + @lencheck n outputs obj_weights + for i = 1:n + args_i = (x[i] for x in xs) + bnlp.updates[i](bnlp.base_model) # call update function + f(bnlp.base_model, outputs[i], args_i...; obj_weight = obj_weights[i]) + end + return outputs +end + +function _batch_map_tuple(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + results = _batch_map(f, bnlp, xs...) 
+ + first_result = first(results) + T1, T2 = typeof(first_result[1]), typeof(first_result[2]) + vec1, vec2 = Vector{T1}(undef, n), Vector{T2}(undef, n) + for i = 1:n + vec1[i], vec2[i] = results[i] + end + return vec1, vec2 +end + +function _batch_map_tuple!(f::F, bnlp::InplaceBatchNLPModel, outputs, xs::Vararg{T,N}) where {F,T,N} + n = bnlp.batch_size + @lencheck_tup n xs + @lencheck n outputs + firsts = [] + resize!(firsts, n) + for i = 1:n + args_i = (x[i] for x in xs) + bnlp.updates[i](bnlp.base_model) # call update function + firsts[i], _ = f(bnlp.base_model, args_i..., outputs[i]) + end + return firsts, outputs +end + +for fun in fieldnames(Counters) + @eval function NLPModels.increment!(bnlp::InplaceBatchNLPModel, ::Val{$(Meta.quot(fun))}) + # sub-model counters are already incremented since we call their methods + bnlp.counters.$fun += 1 + end +end + + +batch_jac_structure(bnlp::InplaceBatchNLPModel) = + _batch_map(jac_structure, bnlp) +batch_jac_lin_structure(bnlp::InplaceBatchNLPModel) = + _batch_map(jac_lin_structure, bnlp) +batch_jac_nln_structure(bnlp::InplaceBatchNLPModel) = + _batch_map(jac_nln_structure, bnlp) +batch_hess_structure(bnlp::InplaceBatchNLPModel) = + _batch_map(hess_structure, bnlp) +batch_obj(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(obj, bnlp, xs) +batch_grad(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(grad, bnlp, xs) +batch_cons(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(cons, bnlp, xs) +batch_cons_lin(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(cons_lin, bnlp, xs) +batch_cons_nln(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(cons_nln, bnlp, xs) +batch_jac(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(jac, bnlp, xs) +batch_jac_lin(bnlp::InplaceBatchNLPModel) = + _batch_map(jac_lin, bnlp) +batch_jac_nln(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(jac_nln, bnlp, xs) +batch_jac_lin_coord(bnlp::InplaceBatchNLPModel) = + _batch_map(jac_lin_coord, bnlp) +batch_jac_coord(bnlp::InplaceBatchNLPModel, xs) = + 
_batch_map(jac_coord, bnlp, xs) +batch_jac_nln_coord(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(jac_nln_coord, bnlp, xs) +batch_varscale(bnlp::InplaceBatchNLPModel) = + _batch_map(varscale, bnlp) +batch_lagscale(bnlp::InplaceBatchNLPModel) = + _batch_map(lagscale, bnlp) +batch_conscale(bnlp::InplaceBatchNLPModel) = + _batch_map(conscale, bnlp) +batch_jprod(bnlp::InplaceBatchNLPModel, xs, vs) = + _batch_map(jprod, bnlp, xs, vs) +batch_jtprod(bnlp::InplaceBatchNLPModel, xs, vs) = + _batch_map(jtprod, bnlp, xs, vs) +batch_jprod_nln(bnlp::InplaceBatchNLPModel, xs, vs) = + _batch_map(jprod_nln, bnlp, xs, vs) +batch_jtprod_nln(bnlp::InplaceBatchNLPModel, xs, vs) = + _batch_map(jtprod_nln, bnlp, xs, vs) +batch_jprod_lin(bnlp::InplaceBatchNLPModel, vs) = + _batch_map(jprod_lin, bnlp, vs) +batch_jtprod_lin(bnlp::InplaceBatchNLPModel, vs) = + _batch_map(jtprod_lin, bnlp, vs) +batch_ghjvprod(bnlp::InplaceBatchNLPModel, xs, gs, vs) = + _batch_map(ghjvprod, bnlp, xs, gs, vs) + +batch_jac_structure!(bnlp::InplaceBatchNLPModel, rowss, colss) = + _batch_map!(jac_structure!, bnlp, rowss, colss) +batch_jac_lin_structure!(bnlp::InplaceBatchNLPModel, rowss, colss) = + _batch_map!(jac_lin_structure!, bnlp, rowss, colss) +batch_jac_nln_structure!(bnlp::InplaceBatchNLPModel, rowss, colss) = + _batch_map!(jac_nln_structure!, bnlp, rowss, colss) +batch_hess_structure!(bnlp::InplaceBatchNLPModel, rowss, colss) = + _batch_map!(hess_structure!, bnlp, rowss, colss) +batch_jac_lin_coord!(bnlp::InplaceBatchNLPModel, valss) = + _batch_map!(jac_lin_coord!, bnlp, valss) +batch_grad!(bnlp::InplaceBatchNLPModel, xs, gs) = + _batch_map!((m, g, x) -> grad!(m, x, g), bnlp, gs, xs) +batch_cons!(bnlp::InplaceBatchNLPModel, xs, cs) = + _batch_map!((m, c, x) -> cons!(m, x, c), bnlp, cs, xs) +batch_cons_lin!(bnlp::InplaceBatchNLPModel, xs, cs) = + _batch_map!((m, c, x) -> cons_lin!(m, x, c), bnlp, cs, xs) +batch_cons_nln!(bnlp::InplaceBatchNLPModel, xs, cs) = + _batch_map!((m, c, x) -> cons_nln!(m, x, 
c), bnlp, cs, xs) +batch_jac_coord!(bnlp::InplaceBatchNLPModel, xs, valss) = + _batch_map!((m, vals, x) -> jac_coord!(m, x, vals), bnlp, valss, xs) +batch_jac_nln_coord!(bnlp::InplaceBatchNLPModel, xs, valss) = + _batch_map!((m, vals, x) -> jac_nln_coord!(m, x, vals), bnlp, valss, xs) +batch_jprod!(bnlp::InplaceBatchNLPModel, xs, vs, Jvs) = + _batch_map!((m, Jv, x, v) -> jprod!(m, x, v, Jv), bnlp, Jvs, xs, vs) +batch_jtprod!(bnlp::InplaceBatchNLPModel, xs, vs, Jtvs) = + _batch_map!((m, Jtv, x, v) -> jtprod!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) +batch_jprod_nln!(bnlp::InplaceBatchNLPModel, xs, vs, Jvs) = + _batch_map!((m, Jv, x, v) -> jprod_nln!(m, x, v, Jv), bnlp, Jvs, xs, vs) +batch_jtprod_nln!(bnlp::InplaceBatchNLPModel, xs, vs, Jtvs) = + _batch_map!((m, Jtv, x, v) -> jtprod_nln!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) +batch_jprod_lin!(bnlp::InplaceBatchNLPModel, vs, Jvs) = + _batch_map!((m, Jv, v) -> jprod_lin!(m, v, Jv), bnlp, Jvs, vs) +batch_jtprod_lin!(bnlp::InplaceBatchNLPModel, vs, Jtvs) = + _batch_map!((m, Jtv, v) -> jtprod_lin!(m, v, Jtv), bnlp, Jtvs, vs) +batch_ghjvprod!(bnlp::InplaceBatchNLPModel, xs, gs, vs, gHvs) = + _batch_map!((m, gHv, x, g, v) -> ghjvprod!(m, x, g, v, gHv), bnlp, gHvs, xs, gs, vs) + +## jth +batch_jth_con(bnlp::InplaceBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_con(m, x, j), bnlp, xs) +batch_jth_congrad(bnlp::InplaceBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_congrad(m, x, j), bnlp, xs) +batch_jth_sparse_congrad(bnlp::InplaceBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_sparse_congrad(m, x, j), bnlp, xs) +batch_jth_hess_coord(bnlp::InplaceBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_hess_coord(m, x, j), bnlp, xs) +batch_jth_hess(bnlp::InplaceBatchNLPModel, xs, j::Integer) = + _batch_map((m, x) -> jth_hess(m, x, j), bnlp, xs) +batch_jth_hprod(bnlp::InplaceBatchNLPModel, xs, vs, j::Integer) = + _batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs) + 
+batch_jth_congrad!(bnlp::InplaceBatchNLPModel, xs, j::Integer, outputs) = + _batch_map!((m, out, x) -> jth_congrad!(m, x, j, out), bnlp, outputs, xs) +batch_jth_hess_coord!(bnlp::InplaceBatchNLPModel, xs, j::Integer, outputs) = + _batch_map!((m, out, x) -> jth_hess_coord!(m, x, j, out), bnlp, outputs, xs) +batch_jth_hprod!(bnlp::InplaceBatchNLPModel, xs, vs, j::Integer, outputs) = + _batch_map!((m, out, x, v) -> jth_hprod!(m, x, v, j, out), bnlp, outputs, xs, vs) + +# hess (need to treat obj_weight) FIXME: obj_weights is required in batch API +batch_hprod(bnlp::InplaceBatchNLPModel, xs, vs; obj_weights) = + _batch_map_weight((m, x, v; obj_weight) -> hprod(m, x, v; obj_weight = obj_weight), bnlp, obj_weights, xs, vs) +batch_hprod(bnlp::InplaceBatchNLPModel, xs, ys, vs; obj_weights) = + _batch_map_weight((m, x, y, v; obj_weight) -> hprod(m, x, y, v; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, vs) +batch_hess_coord(bnlp::InplaceBatchNLPModel, xs; obj_weights) = + _batch_map_weight((m, x; obj_weight) -> hess_coord(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) +batch_hess_coord(bnlp::InplaceBatchNLPModel, xs, ys; obj_weights) = + _batch_map_weight((m, x, y; obj_weight) -> hess_coord(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) +batch_hess_op(bnlp::InplaceBatchNLPModel, xs; obj_weights) = + _batch_map_weight((m, x; obj_weight) -> hess_op(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) +batch_hess_op(bnlp::InplaceBatchNLPModel, xs, ys; obj_weights) = + _batch_map_weight((m, x, y; obj_weight) -> hess_op(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + +batch_hprod!(bnlp::InplaceBatchNLPModel, xs, vs, outputs; obj_weights) = + _batch_map_weight!((m, Hv, x, v; obj_weight) -> hprod!(m, x, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, vs) +batch_hprod!(bnlp::InplaceBatchNLPModel, xs, ys, vs, outputs; obj_weights) = + _batch_map_weight!((m, Hv, x, y, v; obj_weight) -> hprod!(m, x, y, v, Hv; obj_weight 
= obj_weight), bnlp, outputs, obj_weights, xs, ys, vs) +batch_hess_coord!(bnlp::InplaceBatchNLPModel, xs, outputs; obj_weights) = + _batch_map_weight!((m, vals, x; obj_weight) -> hess_coord!(m, x, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs) +batch_hess_coord!(bnlp::InplaceBatchNLPModel, xs, ys, outputs; obj_weights) = + _batch_map_weight!((m, vals, x, y; obj_weight) -> hess_coord!(m, x, y, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys) +batch_hess_op!(bnlp::InplaceBatchNLPModel, xs, Hvs; obj_weights) = + _batch_map_weight((m, x, Hv; obj_weight) -> hess_op!(m, x, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, Hvs) +batch_hess_op!(bnlp::InplaceBatchNLPModel, xs, ys, Hvs; obj_weights) = + _batch_map_weight((m, x, y, Hv; obj_weight) -> hess_op!(m, x, y, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, Hvs) + +batch_hess(bnlp::InplaceBatchNLPModel, xs; obj_weights) = + _batch_map_weight((m, x; obj_weight) -> hess(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) +batch_hess(bnlp::InplaceBatchNLPModel, xs, ys; obj_weights) = + _batch_map_weight((m, x, y; obj_weight) -> hess(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + +## operators +batch_jac_op(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(jac_op, bnlp, xs) +batch_jac_lin_op(bnlp::InplaceBatchNLPModel) = + _batch_map(jac_lin_op, bnlp) +batch_jac_nln_op(bnlp::InplaceBatchNLPModel, xs) = + _batch_map(jac_nln_op, bnlp, xs) + +batch_jac_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = + _batch_map(jac_op!, bnlp, xs, Jvs, Jtvs) +batch_jac_lin_op!(bnlp::InplaceBatchNLPModel, Jvs, Jtvs) = + _batch_map(jac_lin_op!, bnlp, Jvs, Jtvs) +batch_jac_nln_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = + _batch_map(jac_nln_op!, bnlp, xs, Jvs, Jtvs) + +## tuple functions +batch_objgrad(bnlp::InplaceBatchNLPModel, xs) = + _batch_map_tuple(objgrad, bnlp, xs) +batch_objcons(bnlp::InplaceBatchNLPModel, xs) = + _batch_map_tuple(objcons, bnlp, xs) + 
+batch_objgrad!(bnlp::InplaceBatchNLPModel, xs, gs) = + _batch_map_tuple!(objgrad!, bnlp, gs, xs) +batch_objcons!(bnlp::InplaceBatchNLPModel, xs, cs) = + _batch_map_tuple!(objcons!, bnlp, cs, xs) diff --git a/test/nlp/batch_api.jl b/test/nlp/batch_api.jl index 5691ddb2..0b43fb16 100644 --- a/test/nlp/batch_api.jl +++ b/test/nlp/batch_api.jl @@ -5,6 +5,11 @@ models = [SimpleNLPModel() for _ = 1:n_models] meta = models[1].meta n, m = meta.nvar, meta.ncon + T = eltype(meta.lcon) + lcon_values = [[T(-i / 2), T((i - 1) / 2)] for i = 1:n_models] + for i = 1:n_models + models[i].meta.lcon .= lcon_values[i] + end xs = [randn(n) for _ = 1:n_models] ys = [randn(m) for _ = 1:n_models] vs = [randn(n) for _ = 1:n_models] @@ -12,9 +17,22 @@ gs = [zeros(n) for _ = 1:n_models] cs = [zeros(m) for _ = 1:n_models] obj_weights = rand(n_models) - for batch_model in [ForEachBatchNLPModel] - @testset "$batch_model consistency" begin - bnlp = batch_model(models) + function make_inplace_batch_model() + base_model = SimpleNLPModel() + updates = [nlp -> copyto!(get_lcon(nlp), lcons) for lcons in lcon_values] + return InplaceBatchNLPModel(base_model, updates) + end + + @test_throws ErrorException InplaceBatchNLPModel(SimpleNLPModel(), []) + + batch_model_builders = [ + ("ForEachBatchNLPModel", () -> ForEachBatchNLPModel(models)), + ("InplaceBatchNLPModel", () -> make_inplace_batch_model()), + ] + + for (batch_model_name, build_batch_model) in batch_model_builders + @testset "$batch_model_name consistency" begin + bnlp = build_batch_model() # Test batch_obj batch_fs = batch_obj(bnlp, xs) From 64de06b9c71f141d1b71bf8ffc15d16f64120d93 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 20 Nov 2025 17:26:22 -0500 Subject: [PATCH 09/13] reduce lambdas --- src/nlp/batch/foreach.jl | 81 ++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/src/nlp/batch/foreach.jl b/src/nlp/batch/foreach.jl index 09002b09..701f5a48 100644 --- 
a/src/nlp/batch/foreach.jl +++ b/src/nlp/batch/foreach.jl @@ -25,13 +25,16 @@ function _batch_map(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F, return results end -function _batch_map!(f::F, bnlp::ForEachBatchNLPModel, outputs, xs::Vararg{T,N}) where {F,T,N} +function _batch_map!(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{Any,N}) where {F,N} n = bnlp.batch_size + length(xs) == 0 && error("Cannot call _batch_map! without providing arguments.") @lencheck_tup n xs + outputs = xs[end] + inputs = length(xs) == 1 ? () : Base.ntuple(i -> xs[i], length(xs) - 1) @lencheck n outputs for i = 1:n - args_i = (x[i] for x in xs) - f(bnlp[i], outputs[i], args_i...) + args_i = (x[i] for x in inputs) + f(bnlp[i], args_i..., outputs[i]) end return outputs end @@ -49,13 +52,17 @@ function _batch_map_weight(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Va return results end -function _batch_map_weight!(f::F, bnlp::ForEachBatchNLPModel, outputs, obj_weights, xs::Vararg{T,N}) where {F,T,N} +function _batch_map_weight!(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{Any,N}) where {F,N} n = bnlp.batch_size + length(xs) == 0 && error("Cannot call _batch_map_weight! without providing arguments.") @lencheck_tup n xs - @lencheck n outputs obj_weights + @lencheck n obj_weights + outputs = xs[end] + inputs = length(xs) == 1 ? 
() : Base.ntuple(i -> xs[i], length(xs) - 1) + @lencheck n outputs for i = 1:n - args_i = (x[i] for x in xs) - f(bnlp[i], outputs[i], args_i...; obj_weight = obj_weights[i]) + args_i = (x[i] for x in inputs) + f(bnlp[i], args_i..., outputs[i]; obj_weight = obj_weights[i]) end return outputs end @@ -157,31 +164,31 @@ batch_hess_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) = batch_jac_lin_coord!(bnlp::ForEachBatchNLPModel, valss) = _batch_map!(jac_lin_coord!, bnlp, valss) batch_grad!(bnlp::ForEachBatchNLPModel, xs, gs) = - _batch_map!((m, g, x) -> grad!(m, x, g), bnlp, gs, xs) + _batch_map!(grad!, bnlp, xs, gs) batch_cons!(bnlp::ForEachBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons!(m, x, c), bnlp, cs, xs) + _batch_map!(cons!, bnlp, xs, cs) batch_cons_lin!(bnlp::ForEachBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons_lin!(m, x, c), bnlp, cs, xs) + _batch_map!(cons_lin!, bnlp, xs, cs) batch_cons_nln!(bnlp::ForEachBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons_nln!(m, x, c), bnlp, cs, xs) + _batch_map!(cons_nln!, bnlp, xs, cs) batch_jac_coord!(bnlp::ForEachBatchNLPModel, xs, valss) = - _batch_map!((m, vals, x) -> jac_coord!(m, x, vals), bnlp, valss, xs) + _batch_map!(jac_coord!, bnlp, xs, valss) batch_jac_nln_coord!(bnlp::ForEachBatchNLPModel, xs, valss) = - _batch_map!((m, vals, x) -> jac_nln_coord!(m, x, vals), bnlp, valss, xs) + _batch_map!(jac_nln_coord!, bnlp, xs, valss) batch_jprod!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) = - _batch_map!((m, Jv, x, v) -> jprod!(m, x, v, Jv), bnlp, Jvs, xs, vs) + _batch_map!(jprod!, bnlp, xs, vs, Jvs) batch_jtprod!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) = - _batch_map!((m, Jtv, x, v) -> jtprod!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) + _batch_map!(jtprod!, bnlp, xs, vs, Jtvs) batch_jprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) = - _batch_map!((m, Jv, x, v) -> jprod_nln!(m, x, v, Jv), bnlp, Jvs, xs, vs) + _batch_map!(jprod_nln!, bnlp, xs, vs, Jvs) 
batch_jtprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) = - _batch_map!((m, Jtv, x, v) -> jtprod_nln!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) + _batch_map!(jtprod_nln!, bnlp, xs, vs, Jtvs) batch_jprod_lin!(bnlp::ForEachBatchNLPModel, vs, Jvs) = - _batch_map!((m, Jv, v) -> jprod_lin!(m, v, Jv), bnlp, Jvs, vs) + _batch_map!(jprod_lin!, bnlp, vs, Jvs) batch_jtprod_lin!(bnlp::ForEachBatchNLPModel, vs, Jtvs) = - _batch_map!((m, Jtv, v) -> jtprod_lin!(m, v, Jtv), bnlp, Jtvs, vs) + _batch_map!(jtprod_lin!, bnlp, vs, Jtvs) batch_ghjvprod!(bnlp::ForEachBatchNLPModel, xs, gs, vs, gHvs) = - _batch_map!((m, gHv, x, g, v) -> ghjvprod!(m, x, g, v, gHv), bnlp, gHvs, xs, gs, vs) + _batch_map!(ghjvprod!, bnlp, xs, gs, vs, gHvs) ## jth batch_jth_con(bnlp::ForEachBatchNLPModel, xs, j::Integer) = @@ -198,43 +205,43 @@ batch_jth_hprod(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer) = _batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs) batch_jth_congrad!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) = - _batch_map!((m, out, x) -> jth_congrad!(m, x, j, out), bnlp, outputs, xs) + _batch_map!((m, x, out) -> jth_congrad!(m, x, j, out), bnlp, xs, outputs) batch_jth_hess_coord!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) = - _batch_map!((m, out, x) -> jth_hess_coord!(m, x, j, out), bnlp, outputs, xs) + _batch_map!((m, x, out) -> jth_hess_coord!(m, x, j, out), bnlp, xs, outputs) batch_jth_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer, outputs) = - _batch_map!((m, out, x, v) -> jth_hprod!(m, x, v, j, out), bnlp, outputs, xs, vs) + _batch_map!((m, x, v, out) -> jth_hprod!(m, x, v, j, out), bnlp, xs, vs, outputs) # hess (need to treat obj_weight) FIXME: obj_weights is required in batch API batch_hprod(bnlp::ForEachBatchNLPModel, xs, vs; obj_weights) = - _batch_map_weight((m, x, v; obj_weight) -> hprod(m, x, v; obj_weight = obj_weight), bnlp, obj_weights, xs, vs) + _batch_map_weight(hprod, bnlp, obj_weights, xs, vs) batch_hprod(bnlp::ForEachBatchNLPModel, xs, ys, 
vs; obj_weights) = - _batch_map_weight((m, x, y, v; obj_weight) -> hprod(m, x, y, v; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, vs) + _batch_map_weight(hprod, bnlp, obj_weights, xs, ys, vs) batch_hess_coord(bnlp::ForEachBatchNLPModel, xs; obj_weights) = - _batch_map_weight((m, x; obj_weight) -> hess_coord(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) + _batch_map_weight(hess_coord, bnlp, obj_weights, xs) batch_hess_coord(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) = - _batch_map_weight((m, x, y; obj_weight) -> hess_coord(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + _batch_map_weight(hess_coord, bnlp, obj_weights, xs, ys) batch_hess_op(bnlp::ForEachBatchNLPModel, xs; obj_weights) = - _batch_map_weight((m, x; obj_weight) -> hess_op(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) + _batch_map_weight(hess_op, bnlp, obj_weights, xs) batch_hess_op(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) = - _batch_map_weight((m, x, y; obj_weight) -> hess_op(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + _batch_map_weight(hess_op, bnlp, obj_weights, xs, ys) batch_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, outputs; obj_weights) = - _batch_map_weight!((m, Hv, x, v; obj_weight) -> hprod!(m, x, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, vs) + _batch_map_weight!(hprod!, bnlp, obj_weights, xs, vs, outputs) batch_hprod!(bnlp::ForEachBatchNLPModel, xs, ys, vs, outputs; obj_weights) = - _batch_map_weight!((m, Hv, x, y, v; obj_weight) -> hprod!(m, x, y, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys, vs) + _batch_map_weight!(hprod!, bnlp, obj_weights, xs, ys, vs, outputs) batch_hess_coord!(bnlp::ForEachBatchNLPModel, xs, outputs; obj_weights) = - _batch_map_weight!((m, vals, x; obj_weight) -> hess_coord!(m, x, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs) + _batch_map_weight!(hess_coord!, bnlp, obj_weights, xs, outputs) 
batch_hess_coord!(bnlp::ForEachBatchNLPModel, xs, ys, outputs; obj_weights) = - _batch_map_weight!((m, vals, x, y; obj_weight) -> hess_coord!(m, x, y, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys) + _batch_map_weight!(hess_coord!, bnlp, obj_weights, xs, ys, outputs) batch_hess_op!(bnlp::ForEachBatchNLPModel, xs, Hvs; obj_weights) = - _batch_map_weight((m, x, Hv; obj_weight) -> hess_op!(m, x, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, Hvs) + _batch_map_weight(hess_op!, bnlp, obj_weights, xs, Hvs) batch_hess_op!(bnlp::ForEachBatchNLPModel, xs, ys, Hvs; obj_weights) = - _batch_map_weight((m, x, y, Hv; obj_weight) -> hess_op!(m, x, y, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, Hvs) + _batch_map_weight(hess_op!, bnlp, obj_weights, xs, ys, Hvs) batch_hess(bnlp::ForEachBatchNLPModel, xs; obj_weights) = - _batch_map_weight((m, x; obj_weight) -> hess(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) + _batch_map_weight(hess, bnlp, obj_weights, xs) batch_hess(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) = - _batch_map_weight((m, x, y; obj_weight) -> hess(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + _batch_map_weight(hess, bnlp, obj_weights, xs, ys) ## operators batch_jac_op(bnlp::ForEachBatchNLPModel, xs) = From e9fe7724c23226ceca83c846413dc65613c5e808 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 20 Nov 2025 17:28:00 -0500 Subject: [PATCH 10/13] no inplace ops, test with parametric simple model --- src/nlp/batch/inplace.jl | 108 ++++++++++--------- test/nlp/batch_api.jl | 224 +++++++++++++++++++++++---------------- test/nlp/simple-model.jl | 31 +++--- 3 files changed, 202 insertions(+), 161 deletions(-) diff --git a/src/nlp/batch/inplace.jl b/src/nlp/batch/inplace.jl index 3190b2e3..76cd60ea 100644 --- a/src/nlp/batch/inplace.jl +++ b/src/nlp/batch/inplace.jl @@ -9,6 +9,11 @@ function InplaceBatchNLPModel(base_model::M, updates) where {M} isempty(updates) && error("Cannot create 
InplaceBatchNLPModel from empty collection.") InplaceBatchNLPModel{M}(base_model, updates, Counters(), length(updates)) end + +const INPLACE_OPERATOR_ERROR = + "InplaceBatchNLPModel cannot return reusable linear operators because the base model is mutated per batch entry. Use ForEachBatchNLPModel instead." +_inplace_operator_error() = error(INPLACE_OPERATOR_ERROR) + # TODO: counters? Base.length(vnlp::InplaceBatchNLPModel) = vnlp.batch_size @@ -26,14 +31,17 @@ function _batch_map(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{T,N}) where {F, return results end -function _batch_map!(f::F, bnlp::InplaceBatchNLPModel, outputs, xs::Vararg{T,N}) where {F,T,N} +function _batch_map!(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{Any,N}) where {F,N} n = bnlp.batch_size + length(xs) == 0 && error("Cannot call _batch_map! without providing arguments.") @lencheck_tup n xs + outputs = xs[end] + inputs = length(xs) == 1 ? () : Base.ntuple(i -> xs[i], length(xs) - 1) @lencheck n outputs for i = 1:n - args_i = (x[i] for x in xs) + args_i = (x[i] for x in inputs) bnlp.updates[i](bnlp.base_model) # call update function - f(bnlp.base_model, outputs[i], args_i...) + f(bnlp.base_model, args_i..., outputs[i]) end return outputs end @@ -52,14 +60,18 @@ function _batch_map_weight(f::F, bnlp::InplaceBatchNLPModel, obj_weights, xs::Va return results end -function _batch_map_weight!(f::F, bnlp::InplaceBatchNLPModel, outputs, obj_weights, xs::Vararg{T,N}) where {F,T,N} +function _batch_map_weight!(f::F, bnlp::InplaceBatchNLPModel, obj_weights, xs::Vararg{Any,N}) where {F,N} n = bnlp.batch_size + length(xs) == 0 && error("_batch_map_weight! with zero args") @lencheck_tup n xs - @lencheck n outputs obj_weights + @lencheck n obj_weights + outputs = xs[end] + inputs = length(xs) == 1 ? 
() : Base.ntuple(i -> xs[i], length(xs) - 1) + @lencheck n outputs for i = 1:n - args_i = (x[i] for x in xs) + args_i = (x[i] for x in inputs) bnlp.updates[i](bnlp.base_model) # call update function - f(bnlp.base_model, outputs[i], args_i...; obj_weight = obj_weights[i]) + f(bnlp.base_model, args_i..., outputs[i]; obj_weight = obj_weights[i]) end return outputs end @@ -162,31 +174,31 @@ batch_hess_structure!(bnlp::InplaceBatchNLPModel, rowss, colss) = batch_jac_lin_coord!(bnlp::InplaceBatchNLPModel, valss) = _batch_map!(jac_lin_coord!, bnlp, valss) batch_grad!(bnlp::InplaceBatchNLPModel, xs, gs) = - _batch_map!((m, g, x) -> grad!(m, x, g), bnlp, gs, xs) + _batch_map!(grad!, bnlp, xs, gs) batch_cons!(bnlp::InplaceBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons!(m, x, c), bnlp, cs, xs) + _batch_map!(cons!, bnlp, xs, cs) batch_cons_lin!(bnlp::InplaceBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons_lin!(m, x, c), bnlp, cs, xs) + _batch_map!(cons_lin!, bnlp, xs, cs) batch_cons_nln!(bnlp::InplaceBatchNLPModel, xs, cs) = - _batch_map!((m, c, x) -> cons_nln!(m, x, c), bnlp, cs, xs) + _batch_map!(cons_nln!, bnlp, xs, cs) batch_jac_coord!(bnlp::InplaceBatchNLPModel, xs, valss) = - _batch_map!((m, vals, x) -> jac_coord!(m, x, vals), bnlp, valss, xs) + _batch_map!(jac_coord!, bnlp, xs, valss) batch_jac_nln_coord!(bnlp::InplaceBatchNLPModel, xs, valss) = - _batch_map!((m, vals, x) -> jac_nln_coord!(m, x, vals), bnlp, valss, xs) + _batch_map!(jac_nln_coord!, bnlp, xs, valss) batch_jprod!(bnlp::InplaceBatchNLPModel, xs, vs, Jvs) = - _batch_map!((m, Jv, x, v) -> jprod!(m, x, v, Jv), bnlp, Jvs, xs, vs) + _batch_map!(jprod!, bnlp, xs, vs, Jvs) batch_jtprod!(bnlp::InplaceBatchNLPModel, xs, vs, Jtvs) = - _batch_map!((m, Jtv, x, v) -> jtprod!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) + _batch_map!(jtprod!, bnlp, xs, vs, Jtvs) batch_jprod_nln!(bnlp::InplaceBatchNLPModel, xs, vs, Jvs) = - _batch_map!((m, Jv, x, v) -> jprod_nln!(m, x, v, Jv), bnlp, Jvs, xs, vs) + 
_batch_map!(jprod_nln!, bnlp, xs, vs, Jvs) batch_jtprod_nln!(bnlp::InplaceBatchNLPModel, xs, vs, Jtvs) = - _batch_map!((m, Jtv, x, v) -> jtprod_nln!(m, x, v, Jtv), bnlp, Jtvs, xs, vs) + _batch_map!(jtprod_nln!, bnlp, xs, vs, Jtvs) batch_jprod_lin!(bnlp::InplaceBatchNLPModel, vs, Jvs) = - _batch_map!((m, Jv, v) -> jprod_lin!(m, v, Jv), bnlp, Jvs, vs) + _batch_map!(jprod_lin!, bnlp, vs, Jvs) batch_jtprod_lin!(bnlp::InplaceBatchNLPModel, vs, Jtvs) = - _batch_map!((m, Jtv, v) -> jtprod_lin!(m, v, Jtv), bnlp, Jtvs, vs) + _batch_map!(jtprod_lin!, bnlp, vs, Jtvs) batch_ghjvprod!(bnlp::InplaceBatchNLPModel, xs, gs, vs, gHvs) = - _batch_map!((m, gHv, x, g, v) -> ghjvprod!(m, x, g, v, gHv), bnlp, gHvs, xs, gs, vs) + _batch_map!(ghjvprod!, bnlp, xs, gs, vs, gHvs) ## jth batch_jth_con(bnlp::InplaceBatchNLPModel, xs, j::Integer) = @@ -203,58 +215,48 @@ batch_jth_hprod(bnlp::InplaceBatchNLPModel, xs, vs, j::Integer) = _batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs) batch_jth_congrad!(bnlp::InplaceBatchNLPModel, xs, j::Integer, outputs) = - _batch_map!((m, out, x) -> jth_congrad!(m, x, j, out), bnlp, outputs, xs) + _batch_map!((m, x, out) -> jth_congrad!(m, x, j, out), bnlp, xs, outputs) batch_jth_hess_coord!(bnlp::InplaceBatchNLPModel, xs, j::Integer, outputs) = - _batch_map!((m, out, x) -> jth_hess_coord!(m, x, j, out), bnlp, outputs, xs) + _batch_map!((m, x, out) -> jth_hess_coord!(m, x, j, out), bnlp, xs, outputs) batch_jth_hprod!(bnlp::InplaceBatchNLPModel, xs, vs, j::Integer, outputs) = - _batch_map!((m, out, x, v) -> jth_hprod!(m, x, v, j, out), bnlp, outputs, xs, vs) + _batch_map!((m, x, v, out) -> jth_hprod!(m, x, v, j, out), bnlp, xs, vs, outputs) # hess (need to treat obj_weight) FIXME: obj_weights is required in batch API batch_hprod(bnlp::InplaceBatchNLPModel, xs, vs; obj_weights) = - _batch_map_weight((m, x, v; obj_weight) -> hprod(m, x, v; obj_weight = obj_weight), bnlp, obj_weights, xs, vs) + _batch_map_weight(hprod, bnlp, obj_weights, xs, vs) 
batch_hprod(bnlp::InplaceBatchNLPModel, xs, ys, vs; obj_weights) = - _batch_map_weight((m, x, y, v; obj_weight) -> hprod(m, x, y, v; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, vs) + _batch_map_weight(hprod, bnlp, obj_weights, xs, ys, vs) batch_hess_coord(bnlp::InplaceBatchNLPModel, xs; obj_weights) = - _batch_map_weight((m, x; obj_weight) -> hess_coord(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) + _batch_map_weight(hess_coord, bnlp, obj_weights, xs) batch_hess_coord(bnlp::InplaceBatchNLPModel, xs, ys; obj_weights) = - _batch_map_weight((m, x, y; obj_weight) -> hess_coord(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) -batch_hess_op(bnlp::InplaceBatchNLPModel, xs; obj_weights) = - _batch_map_weight((m, x; obj_weight) -> hess_op(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) -batch_hess_op(bnlp::InplaceBatchNLPModel, xs, ys; obj_weights) = - _batch_map_weight((m, x, y; obj_weight) -> hess_op(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + _batch_map_weight(hess_coord, bnlp, obj_weights, xs, ys) +batch_hess_op(bnlp::InplaceBatchNLPModel, xs; obj_weights) = _inplace_operator_error() +batch_hess_op(bnlp::InplaceBatchNLPModel, xs, ys; obj_weights) = _inplace_operator_error() batch_hprod!(bnlp::InplaceBatchNLPModel, xs, vs, outputs; obj_weights) = - _batch_map_weight!((m, Hv, x, v; obj_weight) -> hprod!(m, x, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, vs) + _batch_map_weight!(hprod!, bnlp, obj_weights, xs, vs, outputs) batch_hprod!(bnlp::InplaceBatchNLPModel, xs, ys, vs, outputs; obj_weights) = - _batch_map_weight!((m, Hv, x, y, v; obj_weight) -> hprod!(m, x, y, v, Hv; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys, vs) + _batch_map_weight!(hprod!, bnlp, obj_weights, xs, ys, vs, outputs) batch_hess_coord!(bnlp::InplaceBatchNLPModel, xs, outputs; obj_weights) = - _batch_map_weight!((m, vals, x; obj_weight) -> hess_coord!(m, x, vals; obj_weight = obj_weight), bnlp, outputs, 
obj_weights, xs) + _batch_map_weight!(hess_coord!, bnlp, obj_weights, xs, outputs) batch_hess_coord!(bnlp::InplaceBatchNLPModel, xs, ys, outputs; obj_weights) = - _batch_map_weight!((m, vals, x, y; obj_weight) -> hess_coord!(m, x, y, vals; obj_weight = obj_weight), bnlp, outputs, obj_weights, xs, ys) -batch_hess_op!(bnlp::InplaceBatchNLPModel, xs, Hvs; obj_weights) = - _batch_map_weight((m, x, Hv; obj_weight) -> hess_op!(m, x, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, Hvs) -batch_hess_op!(bnlp::InplaceBatchNLPModel, xs, ys, Hvs; obj_weights) = - _batch_map_weight((m, x, y, Hv; obj_weight) -> hess_op!(m, x, y, Hv; obj_weight = obj_weight), bnlp, obj_weights, xs, ys, Hvs) + _batch_map_weight!(hess_coord!, bnlp, obj_weights, xs, ys, outputs) +batch_hess_op!(bnlp::InplaceBatchNLPModel, xs, Hvs; obj_weights) = _inplace_operator_error() +batch_hess_op!(bnlp::InplaceBatchNLPModel, xs, ys, Hvs; obj_weights) = _inplace_operator_error() batch_hess(bnlp::InplaceBatchNLPModel, xs; obj_weights) = - _batch_map_weight((m, x; obj_weight) -> hess(m, x; obj_weight = obj_weight), bnlp, obj_weights, xs) + _batch_map_weight(hess, bnlp, obj_weights, xs) batch_hess(bnlp::InplaceBatchNLPModel, xs, ys; obj_weights) = - _batch_map_weight((m, x, y; obj_weight) -> hess(m, x, y; obj_weight = obj_weight), bnlp, obj_weights, xs, ys) + _batch_map_weight(hess, bnlp, obj_weights, xs, ys) ## operators -batch_jac_op(bnlp::InplaceBatchNLPModel, xs) = - _batch_map(jac_op, bnlp, xs) -batch_jac_lin_op(bnlp::InplaceBatchNLPModel) = - _batch_map(jac_lin_op, bnlp) -batch_jac_nln_op(bnlp::InplaceBatchNLPModel, xs) = - _batch_map(jac_nln_op, bnlp, xs) +batch_jac_op(bnlp::InplaceBatchNLPModel, xs) = _inplace_operator_error() +batch_jac_lin_op(bnlp::InplaceBatchNLPModel) = _inplace_operator_error() +batch_jac_nln_op(bnlp::InplaceBatchNLPModel, xs) = _inplace_operator_error() -batch_jac_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = - _batch_map(jac_op!, bnlp, xs, Jvs, Jtvs) 
-batch_jac_lin_op!(bnlp::InplaceBatchNLPModel, Jvs, Jtvs) = - _batch_map(jac_lin_op!, bnlp, Jvs, Jtvs) -batch_jac_nln_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = - _batch_map(jac_nln_op!, bnlp, xs, Jvs, Jtvs) +batch_jac_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = _inplace_operator_error() +batch_jac_lin_op!(bnlp::InplaceBatchNLPModel, Jvs, Jtvs) = _inplace_operator_error() +batch_jac_nln_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = _inplace_operator_error() ## tuple functions batch_objgrad(bnlp::InplaceBatchNLPModel, xs) = diff --git a/test/nlp/batch_api.jl b/test/nlp/batch_api.jl index 0b43fb16..0d44d246 100644 --- a/test/nlp/batch_api.jl +++ b/test/nlp/batch_api.jl @@ -1,14 +1,13 @@ @testset "Batch API" begin - # Generate models - # TODO: non-identical models + # Generate models with varying curvature parameter n_models = 5 models = [SimpleNLPModel() for _ = 1:n_models] meta = models[1].meta n, m = meta.nvar, meta.ncon T = eltype(meta.lcon) - lcon_values = [[T(-i / 2), T((i - 1) / 2)] for i = 1:n_models] + p_values = [T(2 + i) for i = 1:n_models] for i = 1:n_models - models[i].meta.lcon .= lcon_values[i] + models[i].p = p_values[i] end xs = [randn(n) for _ = 1:n_models] ys = [randn(m) for _ = 1:n_models] @@ -19,7 +18,12 @@ obj_weights = rand(n_models) function make_inplace_batch_model() base_model = SimpleNLPModel() - updates = [nlp -> copyto!(get_lcon(nlp), lcons) for lcons in lcon_values] + updates = [ + begin + param = p + nlp -> (nlp.p = param) + end for p in p_values + ] return InplaceBatchNLPModel(base_model, updates) end @@ -297,38 +301,57 @@ ] @test hprods ≈ manual_hprods - # Test batch_hess_op with obj_weights (without y) - batch_hess_ops = batch_hess_op(bnlp, xs; obj_weights = obj_weights) - manual_hess_ops = [hess_op(models[i], xs[i]; obj_weight = obj_weights[i]) for i = 1:n_models] - for i = 1:n_models - @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] - end - - # Test batch_hess_op with obj_weights (with y) - batch_hess_ops 
= batch_hess_op(bnlp, xs, ys; obj_weights = obj_weights) - manual_hess_ops = - [hess_op(models[i], xs[i], ys[i]; obj_weight = obj_weights[i]) for i = 1:n_models] - for i = 1:n_models - @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] - end - - # Test batch_hess_op! with obj_weights (without y) - hvs = [zeros(n) for _ = 1:n_models] - batch_hess_ops = batch_hess_op!(bnlp, xs, hvs; obj_weights = obj_weights) - manual_hess_ops = - [hess_op!(models[i], xs[i], zeros(n); obj_weight = obj_weights[i]) for i = 1:n_models] - for i = 1:n_models - @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] - end - - # Test batch_hess_op! with obj_weights (with y) - hvs = [zeros(n) for _ = 1:n_models] - batch_hess_ops = batch_hess_op!(bnlp, xs, ys, hvs; obj_weights = obj_weights) - manual_hess_ops = [ - hess_op!(models[i], xs[i], ys[i], zeros(n); obj_weight = obj_weights[i]) for i = 1:n_models - ] - for i = 1:n_models - @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + if isa(bnlp, ForEachBatchNLPModel) # NOTE: excluding InplaceBatchNLPModel + # Test batch_hess_op with obj_weights (without y) + batch_hess_ops = batch_hess_op(bnlp, xs; obj_weights = obj_weights) + manual_hess_ops = [ + hess_op(models[i], xs[i]; obj_weight = obj_weights[i]) for i = 1:n_models + ] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op with obj_weights (with y) + batch_hess_ops = batch_hess_op(bnlp, xs, ys; obj_weights = obj_weights) + manual_hess_ops = [ + hess_op(models[i], xs[i], ys[i]; obj_weight = obj_weights[i]) for i = 1:n_models + ] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op! 
with obj_weights (without y) + hvs = [zeros(n) for _ = 1:n_models] + batch_hess_ops = batch_hess_op!(bnlp, xs, hvs; obj_weights = obj_weights) + manual_hess_ops = [ + hess_op!(models[i], xs[i], zeros(n); obj_weight = obj_weights[i]) for i = 1:n_models + ] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + + # Test batch_hess_op! with obj_weights (with y) + hvs = [zeros(n) for _ = 1:n_models] + batch_hess_ops = batch_hess_op!(bnlp, xs, ys, hvs; obj_weights = obj_weights) + manual_hess_ops = [ + hess_op!(models[i], xs[i], ys[i], zeros(n); obj_weight = obj_weights[i]) for + i = 1:n_models + ] + for i = 1:n_models + @test batch_hess_ops[i] * vs[i] ≈ manual_hess_ops[i] * vs[i] + end + else + @test_throws ErrorException batch_hess_op(bnlp, xs; obj_weights = obj_weights) + @test_throws ErrorException batch_hess_op(bnlp, xs, ys; obj_weights = obj_weights) + @test_throws ErrorException batch_hess_op!(bnlp, xs, [zeros(n) for _ = 1:n_models]; + obj_weights = obj_weights) + @test_throws ErrorException batch_hess_op!( + bnlp, + xs, + ys, + [zeros(n) for _ = 1:n_models]; + obj_weights = obj_weights, + ) end # Test batch_jth_con @@ -392,62 +415,77 @@ manual_ghjvprods = [ghjvprod!(models[i], xs[i], gs[i], vs[i], zeros(m)) for i = 1:n_models] @test ghjvprods ≈ manual_ghjvprods - # Test batch_jac_op - batch_jac_ops = batch_jac_op(bnlp, xs) - manual_jac_ops = [jac_op(models[i], xs[i]) for i = 1:n_models] - for i = 1:n_models - @test batch_jac_ops[i] * vs[i] ≈ manual_jac_ops[i] * vs[i] - @test batch_jac_ops[i]' * ws[i] ≈ manual_jac_ops[i]' * ws[i] - end - - # Test batch_jac_op! 
- jvs = [zeros(m) for _ = 1:n_models] - jtvs = [zeros(n) for _ = 1:n_models] - batch_jac_ops = batch_jac_op!(bnlp, xs, jvs, jtvs) - manual_jac_ops = [jac_op!(models[i], xs[i], zeros(m), zeros(n)) for i = 1:n_models] - for i = 1:n_models - @test batch_jac_ops[i] * vs[i] ≈ manual_jac_ops[i] * vs[i] - @test batch_jac_ops[i]' * ws[i] ≈ manual_jac_ops[i]' * ws[i] - end - - # Test batch_jac_lin_op - batch_jac_lin_ops = batch_jac_lin_op(bnlp) - manual_jac_lin_ops = [jac_lin_op(models[i]) for i = 1:n_models] - ws_lin_vec = ws[1][1:(meta.nlin)] - for i = 1:n_models - @test batch_jac_lin_ops[i] * vs[i] ≈ manual_jac_lin_ops[i] * vs[i] - @test batch_jac_lin_ops[i]' * ws_lin_vec ≈ manual_jac_lin_ops[i]' * ws_lin_vec - end - - # Test batch_jac_lin_op! - jvs_lin = [zeros(meta.nlin) for _ = 1:n_models] - jtvs_lin = [zeros(n) for _ = 1:n_models] - batch_jac_lin_ops = batch_jac_lin_op!(bnlp, jvs_lin, jtvs_lin) - manual_jac_lin_ops = - [jac_lin_op!(models[i], zeros(meta.nlin), zeros(n)) for i = 1:n_models] - for i = 1:n_models - @test batch_jac_lin_ops[i] * vs[i] ≈ manual_jac_lin_ops[i] * vs[i] - @test batch_jac_lin_ops[i]' * ws_lin_vec ≈ manual_jac_lin_ops[i]' * ws_lin_vec - end - - # Test batch_jac_nln_op - batch_jac_nln_ops = batch_jac_nln_op(bnlp, xs) - manual_jac_nln_ops = [jac_nln_op(models[i], xs[i]) for i = 1:n_models] - ws_nln_vec = ws[1][(meta.nlin + 1):end] - for i = 1:n_models - @test batch_jac_nln_ops[i] * vs[i] ≈ manual_jac_nln_ops[i] * vs[i] - @test batch_jac_nln_ops[i]' * ws_nln_vec ≈ manual_jac_nln_ops[i]' * ws_nln_vec - end - - # Test batch_jac_nln_op! 
- jvs_nln = [zeros(meta.nnln) for _ = 1:n_models] - jtvs_nln = [zeros(n) for _ = 1:n_models] - batch_jac_nln_ops = batch_jac_nln_op!(bnlp, xs, jvs_nln, jtvs_nln) - manual_jac_nln_ops = - [jac_nln_op!(models[i], xs[i], zeros(meta.nnln), zeros(n)) for i = 1:n_models] - for i = 1:n_models - @test batch_jac_nln_ops[i] * vs[i] ≈ manual_jac_nln_ops[i] * vs[i] - @test batch_jac_nln_ops[i]' * ws_nln_vec ≈ manual_jac_nln_ops[i]' * ws_nln_vec + if isa(bnlp, ForEachBatchNLPModel) + # Test batch_jac_op + batch_jac_ops = batch_jac_op(bnlp, xs) + manual_jac_ops = [jac_op(models[i], xs[i]) for i = 1:n_models] + for i = 1:n_models + @test batch_jac_ops[i] * vs[i] ≈ manual_jac_ops[i] * vs[i] + @test batch_jac_ops[i]' * ws[i] ≈ manual_jac_ops[i]' * ws[i] + end + + # Test batch_jac_op! + jvs = [zeros(m) for _ = 1:n_models] + jtvs = [zeros(n) for _ = 1:n_models] + batch_jac_ops = batch_jac_op!(bnlp, xs, jvs, jtvs) + manual_jac_ops = [jac_op!(models[i], xs[i], zeros(m), zeros(n)) for i = 1:n_models] + for i = 1:n_models + @test batch_jac_ops[i] * vs[i] ≈ manual_jac_ops[i] * vs[i] + @test batch_jac_ops[i]' * ws[i] ≈ manual_jac_ops[i]' * ws[i] + end + + # Test batch_jac_lin_op + batch_jac_lin_ops = batch_jac_lin_op(bnlp) + manual_jac_lin_ops = [jac_lin_op(models[i]) for i = 1:n_models] + ws_lin_vec = ws[1][1:(meta.nlin)] + for i = 1:n_models + @test batch_jac_lin_ops[i] * vs[i] ≈ manual_jac_lin_ops[i] * vs[i] + @test batch_jac_lin_ops[i]' * ws_lin_vec ≈ manual_jac_lin_ops[i]' * ws_lin_vec + end + + # Test batch_jac_lin_op! 
+ jvs_lin = [zeros(meta.nlin) for _ = 1:n_models] + jtvs_lin = [zeros(n) for _ = 1:n_models] + batch_jac_lin_ops = batch_jac_lin_op!(bnlp, jvs_lin, jtvs_lin) + manual_jac_lin_ops = + [jac_lin_op!(models[i], zeros(meta.nlin), zeros(n)) for i = 1:n_models] + for i = 1:n_models + @test batch_jac_lin_ops[i] * vs[i] ≈ manual_jac_lin_ops[i] * vs[i] + @test batch_jac_lin_ops[i]' * ws_lin_vec ≈ manual_jac_lin_ops[i]' * ws_lin_vec + end + + # Test batch_jac_nln_op + batch_jac_nln_ops = batch_jac_nln_op(bnlp, xs) + manual_jac_nln_ops = [jac_nln_op(models[i], xs[i]) for i = 1:n_models] + ws_nln_vec = ws[1][(meta.nlin + 1):end] + for i = 1:n_models + @test batch_jac_nln_ops[i] * vs[i] ≈ manual_jac_nln_ops[i] * vs[i] + @test batch_jac_nln_ops[i]' * ws_nln_vec ≈ manual_jac_nln_ops[i]' * ws_nln_vec + end + + # Test batch_jac_nln_op! + jvs_nln = [zeros(meta.nnln) for _ = 1:n_models] + jtvs_nln = [zeros(n) for _ = 1:n_models] + batch_jac_nln_ops = batch_jac_nln_op!(bnlp, xs, jvs_nln, jtvs_nln) + manual_jac_nln_ops = + [jac_nln_op!(models[i], xs[i], zeros(meta.nnln), zeros(n)) for i = 1:n_models] + for i = 1:n_models + @test batch_jac_nln_ops[i] * vs[i] ≈ manual_jac_nln_ops[i] * vs[i] + @test batch_jac_nln_ops[i]' * ws_nln_vec ≈ manual_jac_nln_ops[i]' * ws_nln_vec + end + else + @test_throws ErrorException batch_jac_op(bnlp, xs) + @test_throws ErrorException batch_jac_op!(bnlp, xs, [zeros(m) for _ = 1:n_models], + [zeros(n) for _ = 1:n_models]) + @test_throws ErrorException batch_jac_lin_op(bnlp) + @test_throws ErrorException batch_jac_lin_op!(bnlp, + [zeros(meta.nlin) for _ = 1:n_models], + [zeros(n) for _ = 1:n_models]) + @test_throws ErrorException batch_jac_nln_op(bnlp, xs) + @test_throws ErrorException batch_jac_nln_op!(bnlp, + xs, + [zeros(meta.nnln) for _ = 1:n_models], + [zeros(n) for _ = 1:n_models]) end # Test batch_varscale, batch_lagscale, batch_conscale diff --git a/test/nlp/simple-model.jl b/test/nlp/simple-model.jl index 12071267..eca9a465 100644 --- 
a/test/nlp/simple-model.jl +++ b/test/nlp/simple-model.jl @@ -14,9 +14,10 @@ x₀ = [2.0, 2.0]. mutable struct SimpleNLPModel{T, S} <: AbstractNLPModel{T, S} meta::NLPModelMeta{T, S} counters::Counters + p::T end -function SimpleNLPModel(::Type{T}) where {T} +function SimpleNLPModel(::Type{T}; p = T(4)) where {T} meta = NLPModelMeta( 2, nnzh = 2, @@ -32,10 +33,10 @@ function SimpleNLPModel(::Type{T}) where {T} nln_nnzj = 2, ) - return SimpleNLPModel(meta, Counters()) + return SimpleNLPModel(meta, Counters(), T(p)) end -SimpleNLPModel() = SimpleNLPModel(Float64) +SimpleNLPModel(; p = 4.0) = SimpleNLPModel(Float64; p = p) function NLPModels.obj(nlp::SimpleNLPModel, x::AbstractVector) @lencheck 2 x @@ -73,7 +74,7 @@ function NLPModels.hess_coord!( @lencheck 2 x y vals increment!(nlp, :neval_hess) vals .= 2obj_weight - vals[1] -= y[2] / 2 + vals[1] -= 2y[2] / nlp.p vals[2] -= 2y[2] return vals end @@ -89,7 +90,7 @@ function NLPModels.hprod!( @lencheck 2 x y v Hv increment!(nlp, :neval_hprod) Hv .= 2obj_weight * v - Hv[1] -= y[2] * v[1] / 2 + Hv[1] -= (2y[2] / nlp.p) * v[1] Hv[2] -= 2y[2] * v[2] return Hv end @@ -98,7 +99,7 @@ function NLPModels.cons_nln!(nlp::SimpleNLPModel, x::AbstractVector, cx::Abstrac @lencheck 2 x @lencheck 1 cx increment!(nlp, :neval_cons_nln) - cx .= [-x[1]^2 / 4 - x[2]^2 + 1] + cx .= [-x[1]^2 / nlp.p - x[2]^2 + 1] return cx end @@ -135,7 +136,7 @@ end function NLPModels.jac_nln_coord!(nlp::SimpleNLPModel, x::AbstractVector, vals::AbstractVector) @lencheck 2 x vals increment!(nlp, :neval_jac_nln) - vals .= [-x[1] / 2, -2 * x[2]] + vals .= [-2 * x[1] / nlp.p, -2 * x[2]] return vals end @@ -155,7 +156,7 @@ function NLPModels.jprod_nln!( @lencheck 2 x v @lencheck 1 Jv increment!(nlp, :neval_jprod_nln) - Jv .= [-x[1] * v[1] / 2 - 2 * x[2] * v[2]] + Jv .= [-(2 * x[1] / nlp.p) * v[1] - 2 * x[2] * v[2]] return Jv end @@ -176,7 +177,7 @@ function NLPModels.jtprod_nln!( @lencheck 2 x Jtv @lencheck 1 v increment!(nlp, :neval_jtprod_nln) - Jtv .= [-x[1] * 
v[1] / 2; -2 * x[2] * v[1]] + Jtv .= [-(2 * x[1] / nlp.p) * v[1]; -2 * x[2] * v[1]] return Jtv end @@ -199,7 +200,7 @@ function NLPModels.jth_hess_coord!( if j == 1 vals .= 0 elseif j == 2 - vals[1] = -1 / 2 + vals[1] = -2 / nlp.p vals[2] = -2 end return vals @@ -217,7 +218,7 @@ function NLPModels.jth_hprod!( if j == 1 Hv .= 0 elseif j == 2 - Hv[1] = -v[1] / 2 + Hv[1] = -(2 / nlp.p) * v[1] Hv[2] = -2v[2] end return Hv @@ -233,7 +234,7 @@ function NLPModels.ghjvprod!( @lencheck nlp.meta.nvar x g v @lencheck nlp.meta.ncon gHv increment!(nlp, :neval_hprod) - gHv .= [T(0); -g[1] * v[1] / 2 - 2 * g[2] * v[2]] + gHv .= [T(0); -(2 * g[1] / nlp.p) * v[1] - 2 * g[2] * v[2]] return gHv end @@ -244,7 +245,7 @@ function NLPModels.jth_con(nlp::SimpleNLPModel, x::AbstractVector{T}, j::Integer if j == 1 return x[1] - 2 * x[2] + 1 elseif j == 2 - return -x[1]^2 / 4 - x[2]^2 + 1 + return -x[1]^2 / nlp.p - x[2]^2 + 1 end end @@ -260,7 +261,7 @@ function NLPModels.jth_congrad!( if j == 1 g .= [T(1); T(-2)] elseif j == 2 - g .= [-x[1] / 2; -2 * x[2]] + g .= [-2 * x[1] / nlp.p; -2 * x[2]] end return g end @@ -272,7 +273,7 @@ function NLPModels.jth_sparse_congrad(nlp::SimpleNLPModel, x::AbstractVector{T}, if j == 1 vals = [T(1); T(-2)] elseif j == 2 - vals = [-x[1] / 2; -2 * x[2]] + vals = [-2 * x[1] / nlp.p; -2 * x[2]] end return sparse([1, 1], [1, 2], vals, 1, nlp.meta.nvar) end From 88b8db6ef5728abc12be7c2578b5acf4ccd6931a Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 20 Nov 2025 17:35:06 -0500 Subject: [PATCH 11/13] Vararg{Any} -> Vararg{T} --- src/nlp/batch/foreach.jl | 4 ++-- src/nlp/batch/inplace.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/nlp/batch/foreach.jl b/src/nlp/batch/foreach.jl index 701f5a48..7bd89762 100644 --- a/src/nlp/batch/foreach.jl +++ b/src/nlp/batch/foreach.jl @@ -25,7 +25,7 @@ function _batch_map(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F, return results end -function _batch_map!(f::F, 
bnlp::ForEachBatchNLPModel, xs::Vararg{Any,N}) where {F,N} +function _batch_map!(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N} n = bnlp.batch_size length(xs) == 0 && error("Cannot call _batch_map! without providing arguments.") @lencheck_tup n xs @@ -52,7 +52,7 @@ function _batch_map_weight(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Va return results end -function _batch_map_weight!(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{Any,N}) where {F,N} +function _batch_map_weight!(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N} n = bnlp.batch_size length(xs) == 0 && error("Cannot call _batch_map_weight! without providing arguments.") @lencheck_tup n xs diff --git a/src/nlp/batch/inplace.jl b/src/nlp/batch/inplace.jl index 76cd60ea..ce9c035c 100644 --- a/src/nlp/batch/inplace.jl +++ b/src/nlp/batch/inplace.jl @@ -31,7 +31,7 @@ function _batch_map(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{T,N}) where {F, return results end -function _batch_map!(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{Any,N}) where {F,N} +function _batch_map!(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{T,N}) where {F,T,N} n = bnlp.batch_size length(xs) == 0 && error("Cannot call _batch_map! without providing arguments.") @lencheck_tup n xs @@ -60,7 +60,7 @@ function _batch_map_weight(f::F, bnlp::InplaceBatchNLPModel, obj_weights, xs::Va return results end -function _batch_map_weight!(f::F, bnlp::InplaceBatchNLPModel, obj_weights, xs::Vararg{Any,N}) where {F,N} +function _batch_map_weight!(f::F, bnlp::InplaceBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N} n = bnlp.batch_size length(xs) == 0 && error("_batch_map_weight! 
with zero args") @lencheck_tup n xs From fa1f3c83688915619d0e06ebab87eb591fd2af0f Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 20 Nov 2025 17:53:39 -0500 Subject: [PATCH 12/13] simplify inplace syntax --- test/nlp/batch_api.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/nlp/batch_api.jl b/test/nlp/batch_api.jl index 0d44d246..51a0a625 100644 --- a/test/nlp/batch_api.jl +++ b/test/nlp/batch_api.jl @@ -19,10 +19,7 @@ function make_inplace_batch_model() base_model = SimpleNLPModel() updates = [ - begin - param = p - nlp -> (nlp.p = param) - end for p in p_values + nlp -> (nlp.p = p) for p in p_values ] return InplaceBatchNLPModel(base_model, updates) end From d3588dd26c6d422606c650ece90295dc3a1f133a Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 24 Nov 2025 16:32:26 -0500 Subject: [PATCH 13/13] add todo --- src/nlp/batch/inplace.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nlp/batch/inplace.jl b/src/nlp/batch/inplace.jl index ce9c035c..a8ec84e3 100644 --- a/src/nlp/batch/inplace.jl +++ b/src/nlp/batch/inplace.jl @@ -200,7 +200,7 @@ batch_jtprod_lin!(bnlp::InplaceBatchNLPModel, vs, Jtvs) = batch_ghjvprod!(bnlp::InplaceBatchNLPModel, xs, gs, vs, gHvs) = _batch_map!(ghjvprod!, bnlp, xs, gs, vs, gHvs) -## jth +## jth FIXME: allow for vector of js batch_jth_con(bnlp::InplaceBatchNLPModel, xs, j::Integer) = _batch_map((m, x) -> jth_con(m, x, j), bnlp, xs) batch_jth_congrad(bnlp::InplaceBatchNLPModel, xs, j::Integer) =