Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/NLPModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,8 @@ for f in ["utils", "api", "counters", "meta", "show", "tools"]
include("nls/$f.jl")
end

include("nlp/batch/api.jl")
include("nlp/batch/foreach.jl")
include("nlp/batch/inplace.jl")

end # module
91 changes: 91 additions & 0 deletions src/nlp/batch/api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
export AbstractBatchNLPModel
export batch_obj, batch_grad, batch_grad!, batch_objgrad, batch_objgrad!, batch_objcons, batch_objcons!
export batch_cons, batch_cons!, batch_cons_lin, batch_cons_lin!, batch_cons_nln, batch_cons_nln!
export batch_jth_con, batch_jth_congrad, batch_jth_congrad!, batch_jth_sparse_congrad
export batch_jac_structure!, batch_jac_structure, batch_jac_coord!, batch_jac_coord
export batch_jac, batch_jprod, batch_jprod!, batch_jtprod, batch_jtprod!, batch_jac_op, batch_jac_op!
export batch_jac_lin_structure!, batch_jac_lin_structure, batch_jac_lin_coord!, batch_jac_lin_coord
export batch_jac_lin, batch_jprod_lin, batch_jprod_lin!, batch_jtprod_lin, batch_jtprod_lin!, batch_jac_lin_op, batch_jac_lin_op!
export batch_jac_nln_structure!, batch_jac_nln_structure, batch_jac_nln_coord!, batch_jac_nln_coord
export batch_jac_nln, batch_jprod_nln, batch_jprod_nln!, batch_jtprod_nln, batch_jtprod_nln!, batch_jac_nln_op, batch_jac_nln_op!
export batch_jth_hess_coord, batch_jth_hess_coord!, batch_jth_hess
export batch_jth_hprod, batch_jth_hprod!, batch_ghjvprod, batch_ghjvprod!
export batch_hess_structure!, batch_hess_structure, batch_hess_coord!, batch_hess_coord
export batch_hess, batch_hprod, batch_hprod!, batch_hess_op, batch_hess_op!
export batch_varscale, batch_lagscale, batch_conscale

abstract type AbstractBatchNLPModel end

function NLPModels.increment!(bnlp::AbstractBatchNLPModel, fun::Symbol)
NLPModels.increment!(bnlp, Val(fun))
end

function batch_obj end
function batch_grad end
function batch_grad! end
function batch_objgrad end
function batch_objgrad! end
function batch_objcons end
function batch_objcons! end
function batch_cons end
function batch_cons! end
function batch_cons_lin end
function batch_cons_lin! end
function batch_cons_nln end
function batch_cons_nln! end
function batch_jth_con end
function batch_jth_congrad end
function batch_jth_congrad! end
function batch_jth_sparse_congrad end
function batch_jac_structure! end
function batch_jac_structure end
function batch_jac_coord! end
function batch_jac_coord end
function batch_jac end
function batch_jprod end
function batch_jprod! end
function batch_jtprod end
function batch_jtprod! end
function batch_jac_op end
function batch_jac_op! end
function batch_jac_lin_structure! end
function batch_jac_lin_structure end
function batch_jac_lin_coord! end
function batch_jac_lin_coord end
function batch_jac_lin end
function batch_jprod_lin end
function batch_jprod_lin! end
function batch_jtprod_lin end
function batch_jtprod_lin! end
function batch_jac_lin_op end
function batch_jac_lin_op! end
function batch_jac_nln_structure! end
function batch_jac_nln_structure end
function batch_jac_nln_coord! end
function batch_jac_nln_coord end
function batch_jac_nln end
function batch_jprod_nln end
function batch_jprod_nln! end
function batch_jtprod_nln end
function batch_jtprod_nln! end
function batch_jac_nln_op end
function batch_jac_nln_op! end
function batch_jth_hess_coord end
function batch_jth_hess_coord! end
function batch_jth_hess end
function batch_jth_hprod end
function batch_jth_hprod! end
function batch_ghjvprod end
function batch_ghjvprod! end
function batch_hess_structure! end
function batch_hess_structure end
function batch_hess_coord! end
function batch_hess_coord end
function batch_hess end
function batch_hprod end
function batch_hprod! end
function batch_hess_op end
function batch_hess_op! end
function batch_varscale end
function batch_lagscale end
function batch_conscale end
270 changes: 270 additions & 0 deletions src/nlp/batch/foreach.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
export ForEachBatchNLPModel
struct ForEachBatchNLPModel{M} <: AbstractBatchNLPModel
models::M
counters::Counters
batch_size::Int
end
function ForEachBatchNLPModel(models::M) where {M}
isempty(models) && error("Cannot create ForEachBatchNLPModel from empty collection.")
ForEachBatchNLPModel{M}(models, Counters(), length(models))
end
Base.length(vnlp::ForEachBatchNLPModel) = vnlp.batch_size
Base.getindex(vnlp::ForEachBatchNLPModel, i::Integer) = vnlp.models[i]
Base.iterate(vnlp::ForEachBatchNLPModel, state::Integer = 1) = iterate(vnlp.models, state)


function _batch_map(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N}
n = bnlp.batch_size
@lencheck_tup n xs
results = []
resize!(results, n)
for i = 1:n
args_i = (x[i] for x in xs)
results[i] = f(bnlp[i], args_i...)
end
return results
end

function _batch_map!(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N}
n = bnlp.batch_size
length(xs) == 0 && error("Cannot call _batch_map! without providing arguments.")
@lencheck_tup n xs
outputs = xs[end]
inputs = length(xs) == 1 ? () : Base.ntuple(i -> xs[i], length(xs) - 1)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am hoping that making Julia specialize on the N in Vararg{Any,N} will make this compile down to something reasonable

@lencheck n outputs
for i = 1:n
args_i = (x[i] for x in inputs)
f(bnlp[i], args_i..., outputs[i])
end
return outputs
end

function _batch_map_weight(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N}
n = bnlp.batch_size
@lencheck_tup n xs
@lencheck n obj_weights
results = []
resize!(results, n)
for i = 1:n
args_i = (x[i] for x in xs)
results[i] = f(bnlp[i], args_i...; obj_weight = obj_weights[i])
end
return results
end

function _batch_map_weight!(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N}
n = bnlp.batch_size
length(xs) == 0 && error("Cannot call _batch_map_weight! without providing arguments.")
@lencheck_tup n xs
@lencheck n obj_weights
outputs = xs[end]
inputs = length(xs) == 1 ? () : Base.ntuple(i -> xs[i], length(xs) - 1)
@lencheck n outputs
for i = 1:n
args_i = (x[i] for x in inputs)
f(bnlp[i], args_i..., outputs[i]; obj_weight = obj_weights[i])
end
return outputs
end

function _batch_map_tuple(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N}
n = bnlp.batch_size
@lencheck_tup n xs
results = _batch_map(f, bnlp, xs...)

first_result = first(results)
T1, T2 = typeof(first_result[1]), typeof(first_result[2])
vec1, vec2 = Vector{T1}(undef, n), Vector{T2}(undef, n)
for i = 1:n
vec1[i], vec2[i] = results[i]
end
return vec1, vec2
end

function _batch_map_tuple!(f::F, bnlp::ForEachBatchNLPModel, outputs, xs::Vararg{T,N}) where {F,T,N}
n = bnlp.batch_size
@lencheck_tup n xs
@lencheck n outputs
firsts = []
resize!(firsts, n)
for i = 1:n
args_i = (x[i] for x in xs)
firsts[i], _ = f(bnlp[i], args_i..., outputs[i])
end
return firsts, outputs
end

for fun in fieldnames(Counters)
@eval function NLPModels.increment!(bnlp::ForEachBatchNLPModel, ::Val{$(Meta.quot(fun))})
# sub-model counters are already incremented since we call their methods
bnlp.counters.$fun += 1
end
end


batch_jac_structure(bnlp::ForEachBatchNLPModel) =
_batch_map(jac_structure, bnlp)
batch_jac_lin_structure(bnlp::ForEachBatchNLPModel) =
_batch_map(jac_lin_structure, bnlp)
batch_jac_nln_structure(bnlp::ForEachBatchNLPModel) =
_batch_map(jac_nln_structure, bnlp)
batch_hess_structure(bnlp::ForEachBatchNLPModel) =
_batch_map(hess_structure, bnlp)
batch_obj(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(obj, bnlp, xs)
batch_grad(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(grad, bnlp, xs)
batch_cons(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(cons, bnlp, xs)
batch_cons_lin(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(cons_lin, bnlp, xs)
batch_cons_nln(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(cons_nln, bnlp, xs)
batch_jac(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac, bnlp, xs)
batch_jac_lin(bnlp::ForEachBatchNLPModel) =
_batch_map(jac_lin, bnlp)
batch_jac_nln(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac_nln, bnlp, xs)
batch_jac_lin_coord(bnlp::ForEachBatchNLPModel) =
_batch_map(jac_lin_coord, bnlp)
batch_jac_coord(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac_coord, bnlp, xs)
batch_jac_nln_coord(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac_nln_coord, bnlp, xs)
batch_varscale(bnlp::ForEachBatchNLPModel) =
_batch_map(varscale, bnlp)
batch_lagscale(bnlp::ForEachBatchNLPModel) =
_batch_map(lagscale, bnlp)
batch_conscale(bnlp::ForEachBatchNLPModel) =
_batch_map(conscale, bnlp)
batch_jprod(bnlp::ForEachBatchNLPModel, xs, vs) =
_batch_map(jprod, bnlp, xs, vs)
batch_jtprod(bnlp::ForEachBatchNLPModel, xs, vs) =
_batch_map(jtprod, bnlp, xs, vs)
batch_jprod_nln(bnlp::ForEachBatchNLPModel, xs, vs) =
_batch_map(jprod_nln, bnlp, xs, vs)
batch_jtprod_nln(bnlp::ForEachBatchNLPModel, xs, vs) =
_batch_map(jtprod_nln, bnlp, xs, vs)
batch_jprod_lin(bnlp::ForEachBatchNLPModel, vs) =
_batch_map(jprod_lin, bnlp, vs)
batch_jtprod_lin(bnlp::ForEachBatchNLPModel, vs) =
_batch_map(jtprod_lin, bnlp, vs)
batch_ghjvprod(bnlp::ForEachBatchNLPModel, xs, gs, vs) =
_batch_map(ghjvprod, bnlp, xs, gs, vs)

batch_jac_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
_batch_map!(jac_structure!, bnlp, rowss, colss)
batch_jac_lin_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
_batch_map!(jac_lin_structure!, bnlp, rowss, colss)
batch_jac_nln_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
_batch_map!(jac_nln_structure!, bnlp, rowss, colss)
batch_hess_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
_batch_map!(hess_structure!, bnlp, rowss, colss)
batch_jac_lin_coord!(bnlp::ForEachBatchNLPModel, valss) =
_batch_map!(jac_lin_coord!, bnlp, valss)
batch_grad!(bnlp::ForEachBatchNLPModel, xs, gs) =
_batch_map!(grad!, bnlp, xs, gs)
batch_cons!(bnlp::ForEachBatchNLPModel, xs, cs) =
_batch_map!(cons!, bnlp, xs, cs)
batch_cons_lin!(bnlp::ForEachBatchNLPModel, xs, cs) =
_batch_map!(cons_lin!, bnlp, xs, cs)
batch_cons_nln!(bnlp::ForEachBatchNLPModel, xs, cs) =
_batch_map!(cons_nln!, bnlp, xs, cs)
batch_jac_coord!(bnlp::ForEachBatchNLPModel, xs, valss) =
_batch_map!(jac_coord!, bnlp, xs, valss)
batch_jac_nln_coord!(bnlp::ForEachBatchNLPModel, xs, valss) =
_batch_map!(jac_nln_coord!, bnlp, xs, valss)
batch_jprod!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) =
_batch_map!(jprod!, bnlp, xs, vs, Jvs)
batch_jtprod!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) =
_batch_map!(jtprod!, bnlp, xs, vs, Jtvs)
batch_jprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) =
_batch_map!(jprod_nln!, bnlp, xs, vs, Jvs)
batch_jtprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) =
_batch_map!(jtprod_nln!, bnlp, xs, vs, Jtvs)
batch_jprod_lin!(bnlp::ForEachBatchNLPModel, vs, Jvs) =
_batch_map!(jprod_lin!, bnlp, vs, Jvs)
batch_jtprod_lin!(bnlp::ForEachBatchNLPModel, vs, Jtvs) =
_batch_map!(jtprod_lin!, bnlp, vs, Jtvs)
batch_ghjvprod!(bnlp::ForEachBatchNLPModel, xs, gs, vs, gHvs) =
_batch_map!(ghjvprod!, bnlp, xs, gs, vs, gHvs)

## jth
batch_jth_con(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_con(m, x, j), bnlp, xs)
batch_jth_congrad(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_congrad(m, x, j), bnlp, xs)
batch_jth_sparse_congrad(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_sparse_congrad(m, x, j), bnlp, xs)
batch_jth_hess_coord(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_hess_coord(m, x, j), bnlp, xs)
batch_jth_hess(bnlp::ForEachBatchNLPModel, xs, j::Integer) =
_batch_map((m, x) -> jth_hess(m, x, j), bnlp, xs)
batch_jth_hprod(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer) =
_batch_map((m, x, v) -> jth_hprod(m, x, v, j), bnlp, xs, vs)

batch_jth_congrad!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) =
_batch_map!((m, x, out) -> jth_congrad!(m, x, j, out), bnlp, xs, outputs)
batch_jth_hess_coord!(bnlp::ForEachBatchNLPModel, xs, j::Integer, outputs) =
_batch_map!((m, x, out) -> jth_hess_coord!(m, x, j, out), bnlp, xs, outputs)
batch_jth_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, j::Integer, outputs) =
_batch_map!((m, x, v, out) -> jth_hprod!(m, x, v, j, out), bnlp, xs, vs, outputs)

# hess (need to treat obj_weight) FIXME: obj_weights is required in batch API
batch_hprod(bnlp::ForEachBatchNLPModel, xs, vs; obj_weights) =
_batch_map_weight(hprod, bnlp, obj_weights, xs, vs)
batch_hprod(bnlp::ForEachBatchNLPModel, xs, ys, vs; obj_weights) =
_batch_map_weight(hprod, bnlp, obj_weights, xs, ys, vs)
batch_hess_coord(bnlp::ForEachBatchNLPModel, xs; obj_weights) =
_batch_map_weight(hess_coord, bnlp, obj_weights, xs)
batch_hess_coord(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) =
_batch_map_weight(hess_coord, bnlp, obj_weights, xs, ys)
batch_hess_op(bnlp::ForEachBatchNLPModel, xs; obj_weights) =
_batch_map_weight(hess_op, bnlp, obj_weights, xs)
batch_hess_op(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) =
_batch_map_weight(hess_op, bnlp, obj_weights, xs, ys)

batch_hprod!(bnlp::ForEachBatchNLPModel, xs, vs, outputs; obj_weights) =
_batch_map_weight!(hprod!, bnlp, obj_weights, xs, vs, outputs)
batch_hprod!(bnlp::ForEachBatchNLPModel, xs, ys, vs, outputs; obj_weights) =
_batch_map_weight!(hprod!, bnlp, obj_weights, xs, ys, vs, outputs)
batch_hess_coord!(bnlp::ForEachBatchNLPModel, xs, outputs; obj_weights) =
_batch_map_weight!(hess_coord!, bnlp, obj_weights, xs, outputs)
batch_hess_coord!(bnlp::ForEachBatchNLPModel, xs, ys, outputs; obj_weights) =
_batch_map_weight!(hess_coord!, bnlp, obj_weights, xs, ys, outputs)
batch_hess_op!(bnlp::ForEachBatchNLPModel, xs, Hvs; obj_weights) =
_batch_map_weight(hess_op!, bnlp, obj_weights, xs, Hvs)
batch_hess_op!(bnlp::ForEachBatchNLPModel, xs, ys, Hvs; obj_weights) =
_batch_map_weight(hess_op!, bnlp, obj_weights, xs, ys, Hvs)

batch_hess(bnlp::ForEachBatchNLPModel, xs; obj_weights) =
_batch_map_weight(hess, bnlp, obj_weights, xs)
batch_hess(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) =
_batch_map_weight(hess, bnlp, obj_weights, xs, ys)

## operators
batch_jac_op(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac_op, bnlp, xs)
batch_jac_lin_op(bnlp::ForEachBatchNLPModel) =
_batch_map(jac_lin_op, bnlp)
batch_jac_nln_op(bnlp::ForEachBatchNLPModel, xs) =
_batch_map(jac_nln_op, bnlp, xs)

batch_jac_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) =
_batch_map(jac_op!, bnlp, xs, Jvs, Jtvs)
batch_jac_lin_op!(bnlp::ForEachBatchNLPModel, Jvs, Jtvs) =
_batch_map(jac_lin_op!, bnlp, Jvs, Jtvs)
batch_jac_nln_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) =
_batch_map(jac_nln_op!, bnlp, xs, Jvs, Jtvs)

## tuple functions
batch_objgrad(bnlp::ForEachBatchNLPModel, xs) =
_batch_map_tuple(objgrad, bnlp, xs)
batch_objcons(bnlp::ForEachBatchNLPModel, xs) =
_batch_map_tuple(objcons, bnlp, xs)

batch_objgrad!(bnlp::ForEachBatchNLPModel, xs, gs) =
_batch_map_tuple!(objgrad!, bnlp, gs, xs)
batch_objcons!(bnlp::ForEachBatchNLPModel, xs, cs) =
_batch_map_tuple!(objcons!, bnlp, cs, xs)
Loading