48 changes: 48 additions & 0 deletions pytensor/tensor/rewriting/basic.py
@@ -77,6 +77,7 @@
register_infer_shape,
switch,
tensor_copy,
tile,
zeros,
zeros_like,
)
@@ -910,6 +911,53 @@ def local_join_make_vector(fgraph, node):
return [ret]


@register_canonicalize
@node_rewriter([Join])
def local_join_to_tile(fgraph, node):
"""Join(axis, x, x, x, ...) -> tile(x, reps)

When the same tensor is concatenated multiple times along an axis,
replace the join with a single tile operation, which is more efficient.

Examples
--------
join(0, x, x, x) -> tile(x, (3, 1, 1, ...))
join(1, x, x) -> tile(x, (1, 2, 1, ...))
"""
# Extract axis and the tensors being joined
axis, *tensors = node.inputs

# The rewrite only applies when the axis is a compile-time constant
if not isinstance(axis, Constant):
    return None

# Extract the Python integer from the constant
axis_val = int(axis.data)

# Need at least 2 tensors for the rewrite to be worthwhile
if len(tensors) <= 1:
    return None

# Check that every input is literally the same variable (identity, not
# symbolic equality)
if not all(t is tensors[0] for t in tensors[1:]):
    return None

n_reps = len(tensors)
first_tensor = tensors[0]
ndim = first_tensor.ndim

# Normalize a possibly negative axis so it matches the positional
# indices used to build the reps tuple below
if axis_val < 0:
    axis_val += ndim

# Build a reps tuple that repeats only along the join axis.
# For shape (a, b, c) joined at axis 1: reps = (1, n_reps, 1),
# which concatenates n_reps copies along axis_val, exactly as join does.
reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))

result = tile(first_tensor, reps)

# Preserve debugging information
copy_stack_trace(node.outputs[0], result)
return [result]


@register_specialize
@register_canonicalize
@register_useless
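For reference, a minimal sketch (not part of the diff) of how the new rewrite is expected to behave once registered. It assumes the default compilation mode runs the canonicalize pass and that `Join` is importable from `pytensor.tensor.basic`, as in the tests in tests/tensor/rewriting/test_basic.py below:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Join

x = pt.vector("x")
# Concatenating the same variable three times along axis 0 ...
f = pytensor.function([x], pt.join(0, x, x, x))

# ... should leave no Join node in the compiled graph: the rewrite
# replaces it with tile(x, (3,)).
assert not any(isinstance(node.op, Join) for node in f.maker.fgraph.toposort())

print(f(np.array([1.0, 2.0], dtype=pytensor.config.floatX)))
# -> [1. 2. 1. 2. 1. 2.]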
87 changes: 76 additions & 11 deletions tests/tensor/rewriting/test_basic.py
@@ -1237,33 +1237,98 @@ def test_local_join_1():
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.fgraph.outputs[0].dtype == config.floatX

# test we don't apply when their is 2 inputs
s = join(1, a, a)
# Test that join with 2 different inputs remains (not optimized away)
s = join(1, a, a[:, ::-1])
f = function([a], s, mode=rewrite_mode)
val = f([[1]])
assert np.all(val == [[1]])
val = f([[1, 2]])
assert np.all(val == [[1, 2, 2, 1]]) # joined along axis 1
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert len([n for n in e if isinstance(n.op, Join)]) == 1 # join remains
assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_tile():
"""Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.

This optimization applies whenever we concatenate the *same* tensor multiple
times along a given axis. It replaces the Join/concatenate with a Tile op.
"""

# ---- Case 1: joining same vector along axis 0 ----
x = vector("x")
s = join(0, x, x, x) # (3n,)
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
assert np.allclose(result, expected)

# Join should be optimized away
ops = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, Join) for n in ops)

# ---- Case 2: joining same matrix along axis 0 ----
a = matrix("a")
s = join(0, a, a) # (2m, n)
f = function([a], s, mode=rewrite_mode)

test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
result = f(test_mat)
expected = np.vstack([test_mat, test_mat])
assert np.allclose(result, expected)

ops = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, Join) for n in ops)

# ---- Case 3: joining same matrix along axis 1 ----
s = join(1, a, a, a) # (m, 3n)
f = function([a], s, mode=rewrite_mode)

result = f(test_mat)
expected = np.hstack([test_mat, test_mat, test_mat])
assert np.allclose(result, expected)

ops = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, Join) for n in ops)

# ---- Case 4: different tensors -> should NOT optimize ----
y = vector("y")
s = join(0, x, y) # inputs differ
f = function([x, y], s, mode=rewrite_mode)

test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
result = f(test_vec1, test_vec2)
expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=config.floatX)
assert np.allclose(result, expected)

# Join should still be present since inputs aren't identical
ops = f.maker.fgraph.toposort()
assert any(isinstance(n.op, Join) for n in ops)


def test_local_join_empty():
# Vector case
# Vector case - empty tensors should be removed
empty_vec = np.asarray([], dtype=config.floatX)
vec = vector("vec")
s = pt.join(0, vec, vec, empty_vec)
s = pt.join(0, vec, vec[::-1], empty_vec)
new_s = rewrite_graph(s)
assert equal_computations([new_s], [join(0, vec, vec)])
assert new_s.dtype == s.dtype
# Verify that empty tensors are removed from the join
expected = pt.join(0, vec, vec[::-1])
assert equal_computations([new_s], [expected])

# Matrix case
# Matrix case - empty tensors should be removed
empty_mat = np.zeros((2, 0), dtype=config.floatX)
empty_sym_mat = matrix("m", shape=(2, 0))
mat = matrix("mat", shape=(2, 10))
s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
s = join(1, empty_mat, mat, empty_sym_mat, mat[:, ::-1])
new_s = rewrite_graph(s)
assert equal_computations([new_s], [join(1, mat, mat, mat)])
assert new_s.dtype == s.dtype
# Verify that empty tensors are removed from the join
expected = join(1, mat, mat[:, ::-1])
assert equal_computations([new_s], [expected])

# Join can be completely removed, but casting and specify_shape are propagated
int_mat = matrix("int_mat", dtype=int)
9 changes: 3 additions & 6 deletions tests/tensor/test_basic.py
@@ -2029,12 +2029,9 @@ def test_concatenate_same(self):
Tout = ptb.concatenate([T_shared, T_shared])
f = function([], Tout, mode=self.mode)
out = f()
if config.mode != "FAST_COMPILE":
assert [
True
for node in f.maker.fgraph.toposort()
if isinstance(node.op, type(self.join_op))
]
# Note: Join operations are optimized away when concatenating identical
# tensors (converted to tile operations). The important check is numerical
# correctness, not the presence of Join operations in the graph.

Reviewer comment (Member): Since this isn't testing join, we don't have a GPU backend anymore, and this case is covered by the new tests, let's remove this old test altogether?
assert np.allclose(
out, np.concatenate([T_shared.get_value(), T_shared.get_value()])
)