From 1006b01d30a19248f0d05b27dd76360a2d6b4141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 4 Mar 2026 17:16:47 +0100 Subject: [PATCH 1/4] Implement broadcasting --- Project.toml | 2 ++ perf/neural.jl | 5 +++-- src/JuMP/JuMP.jl | 2 ++ src/JuMP/nlp_expr.jl | 6 ++++++ src/JuMP/operators.jl | 40 ++++++++++++++++++++++++++++++++++++++-- src/JuMP/variables.jl | 4 ++++ 6 files changed, 55 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 5670237..f4aa1ad 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JuMP = "4076af6c-e467-56ae-b986-b466b2749572" MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" +MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -20,6 +21,7 @@ DataStructures = "0.18, 0.19" ForwardDiff = "1" JuMP = "1.29.4" MathOptInterface = "1.40" +MutableArithmetics = "1.6.7" NaNMath = "1" SparseArrays = "1.10" SpecialFunctions = "2.6.1" diff --git a/perf/neural.jl b/perf/neural.jl index bce9a26..bef2298 100644 --- a/perf/neural.jl +++ b/perf/neural.jl @@ -5,5 +5,6 @@ using ArrayDiff n = 2 X = rand(n, n) model = Model() -@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) -W * X +@variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) +@variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) +W2 * tanh.(W1 * X) diff --git a/src/JuMP/JuMP.jl b/src/JuMP/JuMP.jl index ca21a84..f11f294 100644 --- a/src/JuMP/JuMP.jl +++ b/src/JuMP/JuMP.jl @@ -1,9 +1,11 @@ # JuMP extension +import MutableArithmetics as MA import JuMP # Equivalent of `AbstractJuMPScalar` but for arrays abstract type AbstractJuMPArray{T,N} <: AbstractArray{T,N} end +const AbstractJuMPMatrix{T} = AbstractJuMPArray{T,2} include("variables.jl") include("nlp_expr.jl") diff --git a/src/JuMP/nlp_expr.jl b/src/JuMP/nlp_expr.jl index 76ef5ad..e5d04d9 100644 --- a/src/JuMP/nlp_expr.jl +++ b/src/JuMP/nlp_expr.jl @@ -3,9 +3,13 @@ struct GenericArrayExpr{V<:JuMP.AbstractVariableRef,N} <: head::Symbol args::Vector{Any} size::NTuple{N,Int} + broadcasted::Bool end +const GenericMatrixExpr{V<:JuMP.AbstractVariableRef} = GenericArrayExpr{V,2} const ArrayExpr{N} = GenericArrayExpr{JuMP.VariableRef,N} +const MatrixExpr = ArrayExpr{2} +const VectorExpr = ArrayExpr{1} function Base.getindex(::GenericArrayExpr, args...) return error( @@ -14,3 +18,5 @@ function Base.getindex(::GenericArrayExpr, args...) end Base.size(expr::GenericArrayExpr) = expr.size + +JuMP.variable_ref_type(::Type{GenericMatrixExpr{V}}) where {V} = V diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index d81bd35..6b314a9 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -1,7 +1,43 @@ -function Base.:(*)(A::MatrixOfVariables, B::Matrix) - return GenericArrayExpr{JuMP.variable_ref_type(A.model),2}( +function _matmul(::Type{V}, A, B) where {V} + return GenericMatrixExpr{V}( :*, Any[A, B], (size(A, 1), size(B, 2)), + false, ) end + +Base.:(*)(A::AbstractJuMPMatrix, B::Matrix) = _matmul(JuMP.variable_ref_type(A), A, B) +Base.:(*)(A::Matrix, B::AbstractJuMPMatrix) = _matmul(JuMP.variable_ref_type(B), A, B) +Base.:(*)(A::AbstractJuMPMatrix, B::AbstractJuMPMatrix) = _matmul(JuMP.variable_ref_type(A), A, B) + +function __broadcast( + ::Type{V}, + axes::NTuple{N,Base.OneTo{Int}}, + op::Function, + args::Vector{Any}, +) where {V,N} + return GenericArrayExpr{V,N}( + Symbol(op), + args, + length.(axes), + true, + ) +end + +function _broadcast( + ::Type{V}, + op::Function, + args..., +) where {V} + return __broadcast( + V, + Broadcast.combine_axes(args...), + op, + Any[args...], + ) +end + +function Base.broadcasted(op::Function, x::AbstractJuMPArray) + return _broadcast(JuMP.variable_ref_type(x), op, x) +end diff --git a/src/JuMP/variables.jl b/src/JuMP/variables.jl index f70de6a..1c5a59a 100644 --- a/src/JuMP/variables.jl +++ b/src/JuMP/variables.jl @@ -15,6 +15,10 @@ function Base.getindex(A::ArrayOfVariables{T}, I...) where {T} return JuMP.GenericVariableRef{T}(A.model, MOI.VariableIndex(index)) end +function JuMP.variable_ref_type(::Type{ArrayOfVariables{T,N}}) where {T,N} + return JuMP.variable_ref_type(JuMP.GenericModel{T}) +end + function JuMP.Containers.container( f::Function, indices::JuMP.Containers.VectorizedProductIterator{ From fb40ff7c9e0ad91a1959978d888380444d88ddc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 4 Mar 2026 17:17:10 +0100 Subject: [PATCH 2/4] Remove MutableArithmetics --- Project.toml | 2 -- src/JuMP/JuMP.jl | 1 - 2 files changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index f4aa1ad..5670237 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JuMP = "4076af6c-e467-56ae-b986-b466b2749572" MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" -MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -21,7 +20,6 @@ DataStructures = "0.18, 0.19" ForwardDiff = "1" JuMP = "1.29.4" MathOptInterface = "1.40" -MutableArithmetics = "1.6.7" NaNMath = "1" SparseArrays = "1.10" SpecialFunctions = "2.6.1" diff --git a/src/JuMP/JuMP.jl b/src/JuMP/JuMP.jl index f11f294..c75a800 100644 --- a/src/JuMP/JuMP.jl +++ b/src/JuMP/JuMP.jl @@ -1,6 +1,5 @@ # JuMP extension -import MutableArithmetics as MA import JuMP # Equivalent of `AbstractJuMPScalar` but for arrays From 95d99a29ede65d8151fada0ed2f2208aa5dfeb1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 5 Mar 2026 12:06:33 +0100 Subject: [PATCH 3/4] Add tests --- test/JuMP.jl | 46 +++++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/test/JuMP.jl b/test/JuMP.jl index e0de6c8..95ca758 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -16,25 +16,41 @@ function runtests() return end -function test_array_product() +function test_neural() n = 2 X = rand(n, n) model = Model() - @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) - @test W isa ArrayDiff.MatrixOfVariables{Float64} - @test JuMP.index(W[1, 1]) == MOI.VariableIndex(1) - @test JuMP.index(W[2, 1]) == MOI.VariableIndex(2) - @test JuMP.index(W[2]) == MOI.VariableIndex(2) - @test sprint(show, W) == + @variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + @variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + @test W1 isa ArrayDiff.MatrixOfVariables{Float64} + @test JuMP.index(W1[1, 1]) == MOI.VariableIndex(1) + @test JuMP.index(W1[2, 1]) == MOI.VariableIndex(2) + @test JuMP.index(W1[2]) == MOI.VariableIndex(2) + @test sprint(show, W1) == "2×2 ArrayDiff.ArrayOfVariables{Float64, 2} with offset 0" - prod = W * X - @test prod isa ArrayDiff.ArrayExpr{2} - @test sprint(show, prod) == - "2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}" - err = ErrorException( - "`getindex` not implemented, build vectorized expression instead", - ) - @test_throws err prod[1, 1] + for prod in [W1 * X, X * W1] + @test prod isa ArrayDiff.MatrixExpr + @test prod.head == :* + @test !prod.broadcasted + @test sprint(show, prod) == + "2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}" + err = ErrorException( + "`getindex` not implemented, build vectorized expression instead", + ) + @test_throws err prod[1, 1] + end + Y1 = W1 * X + X1 = tanh.(Y1) + @test X1 isa ArrayDiff.MatrixExpr + @test X1.head == :tanh + @test X1.broadcasted + @test X1.args[] === Y1 + Y2 = W2 * X1 + @test Y2.head == :* + @test !Y2.broadcasted + @test length(Y2.args) == 2 + @test Y2.args[1] === W2 + @test Y2.args[2] === X1 return end From 087585c4e8b6933efab9c3f7547431c675f43b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 5 Mar 2026 12:06:40 +0100 Subject: [PATCH 4/4] Fix format --- src/JuMP/operators.jl | 39 +++++++++++++-------------------------- test/JuMP.jl | 2 +- 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index 6b314a9..6907837 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -1,15 +1,16 @@ function _matmul(::Type{V}, A, B) where {V} - return GenericMatrixExpr{V}( - :*, - Any[A, B], - (size(A, 1), size(B, 2)), - false, - ) + return GenericMatrixExpr{V}(:*, Any[A, B], (size(A, 1), size(B, 2)), false) end -Base.:(*)(A::AbstractJuMPMatrix, B::Matrix) = _matmul(JuMP.variable_ref_type(A), A, B) -Base.:(*)(A::Matrix, B::AbstractJuMPMatrix) = _matmul(JuMP.variable_ref_type(B), A, B) -Base.:(*)(A::AbstractJuMPMatrix, B::AbstractJuMPMatrix) = _matmul(JuMP.variable_ref_type(A), A, B) +function Base.:(*)(A::AbstractJuMPMatrix, B::Matrix) + return _matmul(JuMP.variable_ref_type(A), A, B) +end +function Base.:(*)(A::Matrix, B::AbstractJuMPMatrix) + return _matmul(JuMP.variable_ref_type(B), A, B) +end +function Base.:(*)(A::AbstractJuMPMatrix, B::AbstractJuMPMatrix) + return _matmul(JuMP.variable_ref_type(A), A, B) +end function __broadcast( ::Type{V}, @@ -17,25 +18,11 @@ function __broadcast( op::Function, args::Vector{Any}, ) where {V,N} - return GenericArrayExpr{V,N}( - Symbol(op), - args, - length.(axes), - true, - ) + return GenericArrayExpr{V,N}(Symbol(op), args, length.(axes), true) end -function _broadcast( - ::Type{V}, - op::Function, - args..., -) where {V} - return __broadcast( - V, - Broadcast.combine_axes(args...), - op, - Any[args...], - ) +function _broadcast(::Type{V}, op::Function, args...) where {V} + return __broadcast(V, Broadcast.combine_axes(args...), op, Any[args...]) end function Base.broadcasted(op::Function, x::AbstractJuMPArray) diff --git a/test/JuMP.jl b/test/JuMP.jl index 95ca758..95a9fba 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -33,7 +33,7 @@ function test_neural() @test prod.head == :* @test !prod.broadcasted @test sprint(show, prod) == - "2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}" + "2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}" err = ErrorException( "`getindex` not implemented, build vectorized expression instead", )