diff --git a/perf/neural.jl b/perf/neural.jl index bce9a26..bef2298 100644 --- a/perf/neural.jl +++ b/perf/neural.jl @@ -5,5 +5,6 @@ using ArrayDiff n = 2 X = rand(n, n) model = Model() -@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) -W * X +@variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) +@variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) +W2 * tanh.(W1 * X) diff --git a/src/JuMP/JuMP.jl b/src/JuMP/JuMP.jl index ca21a84..c75a800 100644 --- a/src/JuMP/JuMP.jl +++ b/src/JuMP/JuMP.jl @@ -4,6 +4,7 @@ import JuMP # Equivalent of `AbstractJuMPScalar` but for arrays abstract type AbstractJuMPArray{T,N} <: AbstractArray{T,N} end +const AbstractJuMPMatrix{T} = AbstractJuMPArray{T,2} include("variables.jl") include("nlp_expr.jl") diff --git a/src/JuMP/nlp_expr.jl b/src/JuMP/nlp_expr.jl index 76ef5ad..e5d04d9 100644 --- a/src/JuMP/nlp_expr.jl +++ b/src/JuMP/nlp_expr.jl @@ -3,9 +3,13 @@ struct GenericArrayExpr{V<:JuMP.AbstractVariableRef,N} <: head::Symbol args::Vector{Any} size::NTuple{N,Int} + broadcasted::Bool end +const GenericMatrixExpr{V<:JuMP.AbstractVariableRef} = GenericArrayExpr{V,2} const ArrayExpr{N} = GenericArrayExpr{JuMP.VariableRef,N} +const MatrixExpr = ArrayExpr{2} +const VectorExpr = ArrayExpr{1} function Base.getindex(::GenericArrayExpr, args...) return error( @@ -14,3 +18,5 @@ function Base.getindex(::GenericArrayExpr, args...) end Base.size(expr::GenericArrayExpr) = expr.size + +JuMP.variable_ref_type(::Type{GenericMatrixExpr{V}}) where {V} = V diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index d81bd35..6907837 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -1,7 +1,30 @@ -function Base.:(*)(A::MatrixOfVariables, B::Matrix) - return GenericArrayExpr{JuMP.variable_ref_type(A.model),2}( - :*, - Any[A, B], - (size(A, 1), size(B, 2)), - ) +function _matmul(::Type{V}, A, B) where {V} + return GenericMatrixExpr{V}(:*, Any[A, B], (size(A, 1), size(B, 2)), false) +end + +function Base.:(*)(A::AbstractJuMPMatrix, B::Matrix) + return _matmul(JuMP.variable_ref_type(A), A, B) +end +function Base.:(*)(A::Matrix, B::AbstractJuMPMatrix) + return _matmul(JuMP.variable_ref_type(B), A, B) +end +function Base.:(*)(A::AbstractJuMPMatrix, B::AbstractJuMPMatrix) + return _matmul(JuMP.variable_ref_type(A), A, B) +end + +function __broadcast( + ::Type{V}, + axes::NTuple{N,Base.OneTo{Int}}, + op::Function, + args::Vector{Any}, +) where {V,N} + return GenericArrayExpr{V,N}(Symbol(op), args, length.(axes), true) +end + +function _broadcast(::Type{V}, op::Function, args...) where {V} + return __broadcast(V, Broadcast.combine_axes(args...), op, Any[args...]) +end + +function Base.broadcasted(op::Function, x::AbstractJuMPArray) + return _broadcast(JuMP.variable_ref_type(x), op, x) end diff --git a/src/JuMP/variables.jl b/src/JuMP/variables.jl index f70de6a..1c5a59a 100644 --- a/src/JuMP/variables.jl +++ b/src/JuMP/variables.jl @@ -15,6 +15,10 @@ function Base.getindex(A::ArrayOfVariables{T}, I...) where {T} return JuMP.GenericVariableRef{T}(A.model, MOI.VariableIndex(index)) end +function JuMP.variable_ref_type(::Type{ArrayOfVariables{T,N}}) where {T,N} + return JuMP.variable_ref_type(JuMP.GenericModel{T}) +end + function JuMP.Containers.container( f::Function, indices::JuMP.Containers.VectorizedProductIterator{ diff --git a/test/JuMP.jl b/test/JuMP.jl index e0de6c8..95a9fba 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -16,25 +16,41 @@ function runtests() return end -function test_array_product() +function test_neural() n = 2 X = rand(n, n) model = Model() - @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) - @test W isa ArrayDiff.MatrixOfVariables{Float64} - @test JuMP.index(W[1, 1]) == MOI.VariableIndex(1) - @test JuMP.index(W[2, 1]) == MOI.VariableIndex(2) - @test JuMP.index(W[2]) == MOI.VariableIndex(2) - @test sprint(show, W) == + @variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + @variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + @test W1 isa ArrayDiff.MatrixOfVariables{Float64} + @test JuMP.index(W1[1, 1]) == MOI.VariableIndex(1) + @test JuMP.index(W1[2, 1]) == MOI.VariableIndex(2) + @test JuMP.index(W1[2]) == MOI.VariableIndex(2) + @test sprint(show, W1) == "2×2 ArrayDiff.ArrayOfVariables{Float64, 2} with offset 0" - prod = W * X - @test prod isa ArrayDiff.ArrayExpr{2} - @test sprint(show, prod) == - "2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}" - err = ErrorException( - "`getindex` not implemented, build vectorized expression instead", - ) - @test_throws err prod[1, 1] + for prod in [W1 * X, X * W1] + @test prod isa ArrayDiff.MatrixExpr + @test prod.head == :* + @test !prod.broadcasted + @test sprint(show, prod) == + "2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}" + err = ErrorException( + "`getindex` not implemented, build vectorized expression instead", + ) + @test_throws err prod[1, 1] + end + Y1 = W1 * X + X1 = tanh.(Y1) + @test X1 isa ArrayDiff.MatrixExpr + @test X1.head == :tanh + @test X1.broadcasted + @test X1.args[] === Y1 + Y2 = W2 * X1 + @test Y2.head == :* + @test !Y2.broadcasted + @test length(Y2.args) == 2 + @test Y2.args[1] === W2 + @test Y2.args[2] === X1 return end