From f43d9e3967a09ecf312307ea4f2ef1f013c11627 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 17 May 2026 22:17:31 -0500 Subject: [PATCH 1/4] Make Join and Split axis a static Op property Move axis from a symbolic Apply-node input to a __props__ property. Breaking change: Apply-node layout differs and c_code cache versions are bumped, so old pickled graphs will not load. --- pytensor/graph/rewriting/basic.py | 2 +- pytensor/sparse/basic.py | 4 +- pytensor/tensor/basic.py | 464 ++++++++++++------------------ 3 files changed, 181 insertions(+), 289 deletions(-) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index e39465f416..001ccc8162 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -1481,7 +1481,7 @@ def output_fn(fgraph, node, s): PatternNodeRewriter( ( OpPattern(CAReduce, scalar_op="scalar_op", axis=None), - (Join(), "join_axis", "a", "b"), + (Join(0), "a", "b"), ), output_fn, ) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index bdd4f77777..5f92fc6e8f 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -1600,7 +1600,7 @@ def pullback(self, inputs, outputs, gout): if _is_sparse_variable(gz): gz = dense_from_sparse(gz) - split = Split(len(inputs))(gz, 1, ptb.stack([x.shape[1] for x in inputs])) + split = Split(len(inputs), 1)(gz, ptb.stack([x.shape[1] for x in inputs])) if not isinstance(split, list): split = [split] @@ -1697,7 +1697,7 @@ def pullback(self, inputs, outputs, gout): if _is_sparse_variable(gz): gz = dense_from_sparse(gz) - split = Split(len(inputs))(gz, 0, ptb.stack([x.shape[0] for x in inputs])) + split = Split(len(inputs), 0)(gz, ptb.stack([x.shape[0] for x in inputs])) if not isinstance(split, list): split = [split] diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 7d690a9e9c..95e8d8f886 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -24,7 +24,7 @@ from pytensor.compile.builders import SymbolicOp from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined from pytensor.graph import RewriteDatabaseQuery -from pytensor.graph.basic import Apply, Constant, Variable, equal_computations +from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op from pytensor.graph.replace import _vectorize_node @@ -434,11 +434,13 @@ def _get_underlying_scalar_constant_value( and isinstance(v.owner.inputs[0].owner.op, Join) and len(v.owner.op.idx_list) == 1 ): - # Ensure the Join is joining only (effectively) scalar - # variables (so that the constant value can be found at the - # same index as the one used in the sub-tensor). - if builtins.all( - var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] + # Ensure the Join is along axis 0 and joins only + # (effectively) scalar variables (so that the constant + # value can be found at the same index as the one used in + # the sub-tensor). + join_node = v.owner.inputs[0].owner + if join_node.op.axis == 0 and builtins.all( + var.ndim == 1 for var in join_node.inputs ): idx = v.owner.op.idx_list[0] if isinstance(idx, int): @@ -446,10 +448,9 @@ def _get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) try: - # TODO: assert joined axis is 0. length = 0 loop = False - for joined in v.owner.inputs[0].owner.inputs[1:]: + for joined in join_node.inputs: ll = get_vector_length(joined) if idx < length + ll: v = joined[idx - length] @@ -2156,13 +2157,35 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable: return swapaxes(x, -1, -2) +def _validate_axis_argument(axis: int, op_name: str): + """Resolve a `Join`/`Split` axis to a Python integer. + + Accept a Python integer or a constant scalar variable; reject symbolic + (non-constant) variables. + """ + if isinstance(axis, Variable): + try: + return int(get_scalar_constant_value(axis)) + except NotScalarConstantError: + raise TypeError( + f"The axis of {op_name} must be a constant integer. Symbolic " + "axes are no longer supported; implement a custom Op if you " + "need a runtime-varying axis." + ) + if not isinstance(axis, int | np.integer): + raise TypeError( + f"The axis of {op_name} must be an integer, got {type(axis).__name__}." + ) + return int(axis) + + def split(x, splits_size, *, n_splits=None, axis=0): if n_splits is None: if isinstance(splits_size, Variable): n_splits = get_vector_length(splits_size) else: n_splits = len(splits_size) - return Split(n_splits)(x, axis, splits_size) + return Split(n_splits, axis)(x, splits_size) class Split(COp): @@ -2191,55 +2214,46 @@ class Split(COp): """A Split instance will have this many outputs, and require that the splits argument to `perform` have exactly this many elements. """ - __props__ = ("len_splits",) + __props__ = ("len_splits", "axis") - def __init__(self, len_splits): + def __init__(self, len_splits, axis): self.len_splits = int(len_splits) + self.axis = _validate_axis_argument(axis, "Split") self.view_map = {i: [0] for i in range(self.len_splits)} def __str__(self): - return f"{self.__class__.__name__}{{{self.len_splits}}}" + return f"{self.__class__.__name__}{{len_splits={self.len_splits}, axis={self.axis}}}" - def make_node(self, x, axis, splits): - """WRITEME""" + def make_node(self, x, splits): x = as_tensor_variable(x) - axis = as_tensor_variable(axis) splits = as_tensor_variable(splits) if splits.type.ndim == 1 and splits.type.dtype not in integer_dtypes: raise TypeError("`splits` parameter must be tensors of integer type") - if axis.type.dtype not in integer_dtypes or axis.ndim != 0: - raise TypeError("`axis` parameter must be an integer scalar") - - inputs = [x, axis, splits] - + axis = normalize_axis_index(self.axis, x.type.ndim) + # Bind the node to an Op with a canonical (non-negative) axis so that + # rewrites and the Op's own methods can read `axis` without normalizing. + op = self if axis == self.axis else Split(self.len_splits, axis) x_dtype = x.type.dtype - if isinstance(axis, Constant): - # In this case we can preserve more static shape info - static_axis = axis.data.item() - outputs = [] - x_static_shape = list(x.type.shape) - for i in range(self.len_splits): - try: - static_split_size = int(get_scalar_constant_value(splits[i])) - except NotScalarConstantError: - static_split_size = None - except IndexError: - raise ValueError("Number of splits is larger than splits size") - static_out_shape = x_static_shape.copy() - static_out_shape[static_axis] = static_split_size - outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype)) - else: - outputs = [ - tensor(shape=(None,) * x.type.ndim, dtype=x_dtype) - for i in range(self.len_splits) - ] + outputs = [] + x_static_shape = list(x.type.shape) + for i in range(self.len_splits): + try: + static_split_size = int(get_scalar_constant_value(splits[i])) + except NotScalarConstantError: + static_split_size = None + except IndexError: + raise ValueError("Number of splits is larger than splits size") + static_out_shape = x_static_shape.copy() + static_out_shape[axis] = static_split_size + outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype)) - return Apply(self, inputs, outputs) + return Apply(op, [x, splits], outputs) def perform(self, node, inputs, outputs_storage): - x, axis, splits = inputs + x, splits = inputs + axis = self.axis if len(splits) != self.len_splits: raise ValueError("Length of splits is not equal to n_splits") @@ -2255,50 +2269,43 @@ def perform(self, node, inputs, outputs_storage): out_storage[0] = out def infer_shape(self, fgraph, node, in_shapes): - axis = node.inputs[1] - splits = node.inputs[2] - shp_x, _shp_axis, _shp_splits = in_shapes + splits = node.inputs[1] + shp_x, _shp_splits = in_shapes + axis = self.axis out_shapes = [] for i in range(self.len_splits): temp = as_tensor_variable(shp_x) temp = pytensor.tensor.subtensor.set_subtensor(temp[axis], splits[i]) - temp = [temp[i] for i in range(len(shp_x))] + temp = [temp[j] for j in range(len(shp_x))] out_shapes.append(temp) return out_shapes def connection_pattern(self, node): n_out = len(node.outputs) return [ - [True] * n_out, [True] * n_out, [False] * n_out, ] def pullback(self, inputs, outputs, g_outputs): """Join the gradients along the axis that was used to split x.""" - _x, axis, _n = inputs - # We have to convert disconnected outputs to zeros before joining them - new_g_outputs = [] - for o, g in zip(outputs, g_outputs, strict=True): - if isinstance(g.type, DisconnectedType): - new_g_outputs.append(o.zeros_like()) - else: - new_g_outputs.append(g) - + new_g_outputs = [ + o.zeros_like() if isinstance(g.type, DisconnectedType) else g + for o, g in zip(outputs, g_outputs, strict=True) + ] return [ - join(axis, *new_g_outputs), - grad_undefined(self, 1, axis), + join(self.axis, *new_g_outputs), disconnected_type(), ] def pushforward(self, inputs, outputs, eval_points): if isinstance(eval_points[0].type, DisconnectedType): - return [disconnected_type() for i in self.len_splits] - return self.make_node(eval_points[0], *inputs[1:]).outputs + return [disconnected_type() for _ in range(self.len_splits)] + return self.make_node(eval_points[0], inputs[1]).outputs def c_code_cache_version(self): - return (3,) + return (4,) def c_code(self, node, name, inputs, outputs, sub): if self.len_splits == 0: @@ -2307,35 +2314,16 @@ def c_code(self, node, name, inputs, outputs, sub): # outputs_pointers lists the addresses of the pointers to the outputs. outputs_pointers = "&" + (", &".join(outputs)) - x, axis, splits = inputs + x, splits = inputs fail = sub["fail"] - splits_dtype = node.inputs[2].type.dtype_specs()[1] + splits_dtype = node.inputs[1].type.dtype_specs()[1] len_splits = self.len_splits ndim = node.inputs[0].type.ndim - - # Most times axis is constant, inline it - # This is safe to do because the hash of the c_code includes the constant signature - if isinstance(node.inputs[1], Constant): - static_axis = int(node.inputs[1].data) - static_axis = normalize_axis_index(static_axis, ndim) - axis_def = f"{static_axis};" - axis_check = "" - else: - axis_dtype = node.inputs[1].type.dtype_specs()[1] - axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];" - axis_check = f""" - if (axis < 0){{ - axis = ndim + axis; - }} - if (axis >= ndim || axis < 0) {{ - PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds"); - {fail} - }} - """ + axis = self.axis return f""" int ndim = {ndim}; - int axis = {axis_def} + int axis = {axis}; int splits_count = PyArray_DIM({splits}, 0); npy_intp sum_of_splits = 0, current_split_start = 0; PyArrayObject** outputs[] = {{{outputs_pointers}}}; @@ -2351,8 +2339,6 @@ def c_code(self, node, name, inputs, outputs, sub): {fail} }} - {axis_check}; - for (int i = 0; i < splits_count; ++i) {{ int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i)); if (current_split_length < 0) {{ @@ -2404,12 +2390,12 @@ def c_code(self, node, name, inputs, outputs, sub): class Join(COp): r""" - Concatenate multiple `TensorVariable`\s along some axis. + Concatenate multiple `TensorVariable`\s along an axis. - The axis must be given as first argument. All tensors must have the same - shape along all dimensions other than this axis. - Of course, TensorVariable instances do not have a shape, so this error - cannot be caught until runtime. See `perform()`. + The ``axis`` is a Python integer fixed at construction time and stored as + an Op property. All tensors must have the same shape along all dimensions + other than this axis. Of course, TensorVariable instances do not have a + shape, so this error cannot be caught until runtime. See `perform()`. See Also -------- @@ -2437,41 +2423,27 @@ class Join(COp): """ check_input = False - __props__ = () + __props__ = ("axis",) - def make_node(self, axis, *tensors): + def __init__(self, axis): + self.axis = _validate_axis_argument(axis, "Join") + + def __str__(self): + return f"{self.__class__.__name__}{{axis={self.axis}}}" + + def make_node(self, *tensors): """ Parameters ---------- - axis - The axis upon which to join `tensors`. tensors - A variable number of tensors to join along the specified axis. - These tensors must have the same shape along all dimensions other - than `axis`. + A variable number of tensors to join along the axis stored on the + Op. These tensors must have the same shape along all dimensions + other than `axis`. """ if not tensors: raise ValueError("Cannot join an empty list of tensors") - axis = as_tensor_variable(axis) - if axis.type.dtype not in int_dtypes: - raise TypeError(f"Axis {axis} must be an integer type.") - if axis.type.ndim > 0: - raise TypeError(f"Axis {axis} must be 0-d.") - - # Convert negative constant axis to positive during canonicalization - if isinstance(axis, Constant) and tensors: - # Get the axis value directly from the constant's data - axis_val = axis.data.item() - # Check if it's negative and needs normalization - if axis_val < 0: - ndim = tensors[0].ndim - # Convert negative axis to positive - axis_val = normalize_axis_index(axis_val, ndim) - # Replace the original axis with the normalized one - axis = constant(axis_val, dtype=axis.type.dtype) - tensors = [as_tensor_variable(x) for x in tensors] if not builtins.all(targs.type.ndim > 0 for targs in tensors): @@ -2480,100 +2452,65 @@ def make_node(self, axis, *tensors): " Use `stack` to join scalar values or promote the scalars to vectors." ) + ndim = tensors[0].type.ndim + if not builtins.all(x.ndim == ndim for x in tensors): + raise TypeError( + "Only tensors with the same number of dimensions can be joined. " + f"Input ndims were: {[x.ndim for x in tensors]}" + ) + axis = normalize_axis_index(self.axis, ndim) + # Bind the node to an Op with a canonical (non-negative) axis so that + # rewrites and the Op's own methods can read `axis` without normalizing. + op = self if axis == self.axis else Join(axis) + if len(tensors) == 1: out_shape = tensors[0].type.shape else: - ndim = tensors[0].type.ndim - - if not builtins.all(x.ndim == ndim for x in tensors): - raise TypeError( - "Only tensors with the same number of dimensions can be joined. " - f"Input ndims were: {[x.ndim for x in tensors]}" - ) - - try: - static_axis = int(get_scalar_constant_value(axis)) - except NotScalarConstantError: - static_axis = None - - if static_axis is None: - # When axis isn't static, we can't conclude anything about output dimension - # (unless we had some degenerate zero arrays) that can be removed during rewrites. - # We could also raise errors if any dimensions are pairwise inconsistent across all the axes - # As no matter the join it would be invalid. - # However, dynamic axis is so rare that is not worth the trouble - out_shape = [None] * ndim - - else: # We know the axis statically - static_axis = normalize_axis_index(static_axis, ndim) - static_shapes = [x.type.shape for x in tensors] - - # Determine output shapes from a matrix of input shapes - static_shapes = np.array(static_shapes) - out_shape = [None] * ndim - for d in range(ndim): - ins = static_shapes[:, d] - if d == static_axis: - # Any unknown size along the axis means we can't infer it - if None in ins: - out_shape[d] = None - else: - out_shape[d] = sum(ins) + # Determine output shapes from a matrix of input shapes + static_shapes = np.array([x.type.shape for x in tensors]) + out_shape = [None] * ndim + for d in range(ndim): + ins = static_shapes[:, d] + if d == axis: + # Any unknown size along the axis means we can't infer it + if None in ins: + out_shape[d] = None else: - inset = set(static_shapes[:, d]) - # Other dims must match exactly, - # or if a mix of None and ? the output will be ? - # otherwise the input shapes are incompatible. - if len(inset) == 1: - (out_shape[d],) = inset - elif len(inset - {None}) == 1: - (out_shape[d],) = inset - {None} - else: - raise ValueError( - f"all input array dimensions other than the specified `axis` ({static_axis})" - " must match exactly, or be unknown (None)," - f" but along dimension {d}, the inputs shapes are incompatible: {ins}" - ) + out_shape[d] = sum(ins) + else: + inset = set(ins) + # Other dims must match exactly, + # or if a mix of None and ? the output will be ? + # otherwise the input shapes are incompatible. + if len(inset) == 1: + (out_shape[d],) = inset + elif len(inset - {None}) == 1: + (out_shape[d],) = inset - {None} + else: + raise ValueError( + f"all input array dimensions other than the specified `axis` ({axis})" + " must match exactly, or be unknown (None)," + f" but along dimension {d}, the inputs shapes are incompatible: {ins}" + ) - inputs = [axis, *tensors] out_dtype = ps.upcast(*[x.type.dtype for x in tensors]) - return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)]) + return Apply(op, list(tensors), [tensor(dtype=out_dtype, shape=out_shape)]) def perform(self, node, inputs, output_storage): - axis, *arrays = inputs output_storage[0][0] = np.concatenate( - arrays, axis=axis, dtype=node.outputs[0].type.dtype + inputs, axis=self.axis, dtype=node.outputs[0].type.dtype ) def c_code_cache_version(self): - return (7,) + return (8,) def c_code(self, node, name, inputs, outputs, sub): - axis, *arrays = inputs + arrays = inputs [out] = outputs n = len(arrays) ndim = node.outputs[0].type.ndim fail = sub["fail"] - - # Most times axis is constant, inline it - # This is safe to do because the hash of the c_code includes the constant signature - if isinstance(node.inputs[0], Constant): - static_axis = int(node.inputs[0].data) - static_axis = normalize_axis_index(static_axis, ndim) - axis_def = f"{static_axis};" - axis_check = "" - else: - axis_ctype = node.inputs[0].type.dtype_specs()[1] - axis_def = f"(({axis_ctype} *)PyArray_DATA({axis}))[0];" - axis_check = f""" - if (axis < 0){{ - axis = {ndim} + axis; - }} - if (axis >= {ndim} || axis < 0) {{ - PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds"); - {fail} - }} - """ + axis = self.axis copy_arrays_to_tuple = "\n".join( ( @@ -2583,12 +2520,10 @@ def c_code(self, node, name, inputs, outputs, sub): ) code = f""" - int axis = {axis_def} + int axis = {axis}; PyArrayObject* arrays[{n}] = {{{",".join(arrays)}}}; int out_is_valid = {out} != NULL; - {axis_check} - if (out_is_valid) {{ // Check if we can reuse output npy_intp join_size = 0; @@ -2670,9 +2605,9 @@ def c_code(self, node, name, inputs, outputs, sub): return code def pushforward(self, inputs, outputs, eval_points): - if any(isinstance(t.type, DisconnectedType) for t in eval_points[1:]): + if any(isinstance(t.type, DisconnectedType) for t in eval_points): return [disconnected_type()] - return self.make_node(inputs[0], *eval_points[1:]).outputs + return self.make_node(*eval_points).outputs def pullback(self, inputs, outputs, grads): """The gradient wrt a join op is a `Split`, used to partition @@ -2680,9 +2615,8 @@ def pullback(self, inputs, outputs, grads): """ [gz] = grads [out] = outputs - axis, *tensors = inputs - - rval = [grad_undefined(self, 0, axis)] + tensors = inputs + axis = self.axis out_dtype = out.type.dtype if "float" in out_dtype or "complex" in out_dtype: @@ -2695,7 +2629,7 @@ def pullback(self, inputs, outputs, grads): # Split.make_node isn't always able to infer the right # broadcast. As the grad need to keep the information, # read it if needed. - split_gz = [ + return [ g if g.type.shape == t.type.shape == 1 else specify_broadcastable( @@ -2703,66 +2637,36 @@ def pullback(self, inputs, outputs, grads): ) for t, g in zip(tensors, split_gz, strict=True) ] - rval = rval + split_gz else: # the output has integer type, so the gradient through it is 0 - rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors] - - return rval + return [t.zeros_like(dtype=config.floatX) for t in tensors] def infer_shape(self, fgraph, node, ishapes): - from pytensor.tensor.math import eq, ge - - # ishapes[0] contains the size of the axis on which we join # Join op should get at least one input to join - assert len(ishapes) > 1 - n_dim = len(ishapes[1]) - for shp in ishapes[1:]: + assert len(ishapes) > 0 + n_dim = len(ishapes[0]) + for shp in ishapes: assert shp is not None assert len(shp) == n_dim - # The joining dimension could be negative, but we need it to be - # in [0, n_dim) in the loop below. - # An axis < -n_dim or >= ndim would be invalid, but this is - # not checked here. A `CheckAndRaise` `Op` would be a way of - # addressing that, but it may disrupt optimizations. - axis = node.inputs[0] - join_dim = switch(ge(axis, 0), axis, axis + n_dim) - out_shapes = [] - for dim in range(n_dim): - # we have to deal with 2 possible cases in here : - # a) we are dealing with the dimension for which we join - # (called t_side from true side of the if, where the if - # compares current dimension with the joining dimension) - # b) a non joining dimension ( in which maybe a symbolic - # assertion can be used to make sure all tensors have - # the same number of elements on this non-joined dimension - # this is f_side - # initialize - t_side = ishapes[1][dim] - f_side = ishapes[1][dim] - # loop over tensors and sum for the joining dimension - for shp in ishapes[2:]: - t_side = t_side + shp[dim] - # return the dimensions found - out_shapes.append(switch(eq(dim, join_dim), t_side, f_side)) - - return [tuple(out_shapes)] - - -_join = Join() + axis = self.axis + out_shape = list(ishapes[0]) + join_size = out_shape[axis] + for shp in ishapes[1:]: + join_size = join_size + shp[axis] + out_shape[axis] = join_size + return [tuple(out_shape)] + + pprint.assign(Join, printing.FunctionPrinter(["join"])) @_get_vector_length.register(Join) def _get_vector_length_Join(op, var): - axis, *arrays = var.owner.inputs - try: - axis = get_scalar_constant_value(axis) - assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays) + arrays = var.owner.inputs + if op.axis == 0 and builtins.all(a.ndim == 1 for a in arrays): return builtins.sum(get_vector_length(a) for a in arrays) - except NotScalarConstantError: - raise ValueError(f"Length of {var} cannot be determined") + raise ValueError(f"Length of {var} cannot be determined") def join(axis, *tensors_list): @@ -2775,15 +2679,11 @@ def join(axis, *tensors_list): Parameters ---------- - axis : int (symbolic or literal) - On which dimension should the tensors be joined? The `axis` - must be a valid index into the shape of the tensors to be - concatenated. - The `axis` parameter may either be an integer or an object that - can be converted to a scalar using `as_scalar`(`axis`). In the - former case, the axis is fixed at construction, while in the - latter it may vary over time depending on the value of the - `axis` variable. + axis : int + On which dimension should the tensors be joined? The `axis` must be an + integer (or a constant scalar variable) and a valid index into the + shape of the tensors to be concatenated. Symbolic axes are not + supported. tensors_list : list of TensorVariable (or list-like) A list of tensors to be concatenated along the given axis. The shapes of the tensors to be concatenated must be all @@ -2792,42 +2692,34 @@ def join(axis, *tensors_list): """ if len(tensors_list) == 1: return tensors_list[0] - else: - return _join(axis, *tensors_list) + if len(tensors_list) == 0: + raise ValueError("Cannot join an empty list of tensors") + return Join(axis)(*tensors_list) @_vectorize_node.register(Join) -def vectorize_join(op: Join, node, batch_axis, *batch_inputs): - original_axis, *old_inputs = node.inputs - # We can vectorize join as a shifted axis on the batch inputs if: - # 1. The batch axis is a constant and has not changed - # 2. All inputs are batched with the same broadcastable pattern +def vectorize_join(op: Join, node, *batch_inputs): + old_inputs = node.inputs + # We can vectorize join as a shifted axis on the batch inputs if all inputs + # are batched with the same broadcastable pattern. - # TODO: We can relax the second condition by broadcasting the batch dimensions + # TODO: We can relax this condition by broadcasting the batch dimensions # This can be done with `broadcast_arrays` if the tensors shape match at the axis or reduction # Or otherwise by calling `broadcast_to` for each tensor that needs it - if ( - original_axis.type.ndim == 0 - and isinstance(original_axis, Constant) - and equal_computations([original_axis], [batch_axis]) - ): - batch_ndims = { - batch_input.type.ndim - old_input.type.ndim - for batch_input, old_input in zip(batch_inputs, old_inputs, strict=True) - } - if len(batch_ndims) == 1: - [batch_ndim] = batch_ndims - batch_bcast = batch_inputs[0].type.broadcastable[:batch_ndim] - if all( - batch_input.type.broadcastable[:batch_ndim] == batch_bcast - for batch_input in batch_inputs[1:] - ): - original_ndim = node.outputs[0].type.ndim - original_axis = normalize_axis_index(original_axis.data, original_ndim) - batch_axis = original_axis + batch_ndim - return op.make_node(batch_axis, *batch_inputs) + batch_ndims = { + batch_input.type.ndim - old_input.type.ndim + for batch_input, old_input in zip(batch_inputs, old_inputs, strict=True) + } + if len(batch_ndims) == 1: + [batch_ndim] = batch_ndims + batch_bcast = batch_inputs[0].type.broadcastable[:batch_ndim] + if all( + batch_input.type.broadcastable[:batch_ndim] == batch_bcast + for batch_input in batch_inputs[1:] + ): + return Join(op.axis + batch_ndim)(*batch_inputs).owner - return vectorize_node_fallback(op, node, batch_axis, *batch_inputs) + return vectorize_node_fallback(op, node, *batch_inputs) def roll(x, shift, axis=None): @@ -2842,7 +2734,7 @@ def roll(x, shift, axis=None): Input tensor. shift : int (symbolic or literal) The number of places by which elements are shifted. - axis : int (symbolic or literal), optional + axis : int, optional The axis along which elements are shifted. By default, the array is flattened before shifting, after which the original shape is restored. From 862651f9998ed70b3dbc9e7f8708b383ef233b2a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 17 May 2026 22:17:42 -0500 Subject: [PATCH 2/4] Update Join/Split rewrites for static axis property --- pytensor/tensor/rewriting/basic.py | 52 +++++++-------------- pytensor/tensor/rewriting/math.py | 10 ++-- pytensor/tensor/rewriting/subtensor.py | 8 +--- pytensor/tensor/rewriting/subtensor_lift.py | 8 +--- 4 files changed, 25 insertions(+), 53 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index e2910da9c3..011ef6118d 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -26,7 +26,6 @@ from collections.abc import Sequence import numpy as np -from numpy.lib.array_utils import normalize_axis_index from pytensor import compile, config from pytensor.compile.ops import ViewOp @@ -859,7 +858,7 @@ def local_join_1(fgraph, node): """ if not isinstance(node.op, Join): return - tensors = node.inputs[1:] + tensors = node.inputs if len(tensors) == 1: # We don't need to copy over any stacktrace here, because the # input variable should already have its own stacktrace. @@ -875,16 +874,10 @@ def local_join_empty(fgraph, node): Remove empty inputs to joins. The empty inputs can be anywhere. """ - axis, *tensors = node.inputs + tensors = node.inputs + axis = node.op.axis - try: - static_axis = get_scalar_constant_value( - node.inputs[0], only_process_constants=True - ) - except NotScalarConstantError: - return - - new_tensors = [tensor for tensor in tensors if tensor.type.shape[static_axis] != 0] + new_tensors = [tensor for tensor in tensors if tensor.type.shape[axis] != 0] # If there are zero tensors, the join is useless but so is any other operation # Another rewrite will (one day) handle all those cases @@ -917,8 +910,8 @@ def local_join_make_vector(fgraph, node): """ if not isinstance(node.op, Join) or node.outputs[0].ndim != 1: return - new_inputs = [node.inputs[1]] - for idx in range(2, len(node.inputs)): + new_inputs = [node.inputs[0]] + for idx in range(1, len(node.inputs)): inp = node.inputs[idx] if ( inp.owner @@ -938,8 +931,8 @@ def local_join_make_vector(fgraph, node): copy_stack_trace(node.outputs, new_inputs[-1]) else: new_inputs.append(inp) - if len(new_inputs) < len(node.inputs) - 1: - ret = join(node.inputs[0], *new_inputs) + if len(new_inputs) < len(node.inputs): + ret = join(node.op.axis, *new_inputs) # Copy over stacktrace from previous output (after join op) # to new output, because an error in the new op must be caused @@ -961,15 +954,9 @@ def local_join_to_repeat(fgraph, node): join(0, x, x, x) -> tile(x, (3, 1, 1, ...)) join(1, x, x) -> tile(x, (1, 2, 1, ...)) """ - # Extract axis and the tensors being joined - axis, *tensors = node.inputs - - # Optimization only applies when axis is constant - if not isinstance(axis, Constant): - return None - - # Extract the Python integer from the constant - axis_val = axis.data + # Extract the tensors being joined + tensors = node.inputs + axis_val = node.op.axis # Need at least 2 tensors to consider optimization if len(tensors) <= 1: @@ -1174,11 +1161,11 @@ def local_useless_split(fgraph, node): """ if isinstance(node.op, Split): if node.op.len_splits == 1: - x, axis, splits = node.inputs + x, splits = node.inputs out = assert_op(x, eq(splits.shape[0], 1)) # Copy over stacktrace from previous output node. copy_stack_trace(node.outputs, out) - out2 = assert_op(out, eq(x.shape[axis], splits[0])) + out2 = assert_op(out, eq(x.shape[node.op.axis], splits[0])) # Copy over stacktrace from previous output node. copy_stack_trace(out, out2) @@ -1359,15 +1346,12 @@ def local_dimshuffle_alloc(fgraph, node): @node_rewriter([Join]) def local_join_of_alloc(fgraph, node): """Rewrite a Join of Alloc nodes to an Alloc of the Join nodes.""" - axis, *tensors = node.inputs + tensors = node.inputs if len(tensors) < 2: # Let other rewrite handle the useless Join return - if not isinstance(axis, Constant): - return - core_tensors = [] alloc_shapes = [] for tensor in tensors: @@ -1388,7 +1372,7 @@ def local_join_of_alloc(fgraph, node): # Axis can never be lifted # Non-axis allocated dimensions can be lifted if they are all broadcastable [out] = node.outputs - static_axis = normalize_axis_index(axis.data, tensors[0].type.ndim) + axis = node.op.axis broadcasted_dims = list( zip( @@ -1410,7 +1394,7 @@ def local_join_of_alloc(fgraph, node): lifteable_alloc_dims = { dim for dim in range(out.type.ndim) - if dim != static_axis and all(broadcasted_dims[dim]) + if dim != axis and all(broadcasted_dims[dim]) } if not lifteable_alloc_dims: @@ -1427,13 +1411,13 @@ def local_join_of_alloc(fgraph, node): copy_stack_trace(tensor, new_tensor) new_tensors.append(new_tensor) - new_join = node.op(static_axis, *new_tensors) + new_join = join(axis, *new_tensors) copy_stack_trace(node.outputs[0], new_join) # Reintroduce the lifted dims post_join_shape = [] for i, alloc_dims in enumerate(zip(*alloc_shapes, strict=True)): - if i == static_axis: + if i == axis: # The alloc dim along the axis is the sum of all the pre-join alloc dims post_join_shape.append(add(*alloc_dims)) else: diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 0572721236..cc05d65438 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1856,9 +1856,7 @@ def investigate_if_shape(node) -> bool: elif isinstance(node.op, Subtensor) and node.inputs[0].owner: return investigate_if_shape(node.inputs[0].owner) elif isinstance(node.op, Join): - return all( - v.owner and investigate_if_shape(v.owner) for v in node.inputs[1:] - ) + return all(v.owner and investigate_if_shape(v.owner) for v in node.inputs) elif isinstance(node.op, MakeVector): return all(v.owner and investigate_if_shape(v.owner) for v in node.inputs) return False @@ -1962,7 +1960,7 @@ def local_reduce_join(fgraph, node): [joined_out] = node.inputs joined_node = joined_out.owner - join_axis_tensor, *joined_inputs = joined_node.inputs + joined_inputs = joined_node.inputs n_joined_inputs = len(joined_inputs) if n_joined_inputs < 2: @@ -1972,9 +1970,7 @@ def local_reduce_join(fgraph, node): # We don't rewrite if a single Elemwise cannot take all inputs at once return None - if not isinstance(join_axis_tensor, Constant): - return None - join_axis = join_axis_tensor.data + join_axis = joined_node.op.axis # Check whether reduction happens on joined axis reduce_op = node.op diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 80a6c296bb..2b852f8e3c 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -2512,12 +2512,8 @@ def local_join_subtensors(fgraph, node): """ # TODO: Generalize to AdvancedSubtensors - axis, tensors = node.inputs[0], node.inputs[1:] - - try: - axis = get_scalar_constant_value(axis) - except NotScalarConstantError: - return + tensors = node.inputs + axis = node.op.axis for subtensor1_idx, (subtensor1, subtensor2) in enumerate(pairwise(tensors)): # Check that two consecutive Subtensors are operating on the same base tensor diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 4b71ae34c0..8589c4269b 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -1074,14 +1074,10 @@ def local_subtensor_of_join(fgraph, node): # Join involves a full_copy, so we don't want to do it twice return None - join_axis, *join_components = join_var.owner.inputs - - # Rewrite only works when the join axis is a constant along a non-indexed dimension - if not isinstance(join_axis, Constant): - return None + join_components = join_var.owner.inputs [old_out] = node.outputs - axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim) + axis = join_var.owner.op.axis idx_tuple = indices_from_subtensor(idx, node.op.idx_list) if _axis_is_indexed_by_basic_index(idx_tuple, axis): return _lift_subtensor_non_axis( From 23f0370bd79e27d33cc03cca0b37763263851541 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 17 May 2026 22:17:53 -0500 Subject: [PATCH 3/4] Update Join/Split backend dispatches for static axis --- pytensor/link/jax/dispatch/tensor_basic.py | 28 ++++++----------- pytensor/link/mlx/dispatch/tensor_basic.py | 22 ++++---------- pytensor/link/numba/dispatch/tensor_basic.py | 13 +++++--- pytensor/link/pytorch/dispatch/basic.py | 32 +++++++------------- 4 files changed, 35 insertions(+), 60 deletions(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index e70fd67b72..71a2232f85 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -83,7 +83,9 @@ def arange(start, stop, step): @jax_funcify.register(Join) def jax_funcify_Join(op, **kwargs): - def join(axis, *tensors): + axis = op.axis + + def join(*tensors): # tensors could also be tuples, and in this case they don't have a ndim tensors = [jnp.asarray(tensor) for tensor in tensors] return jnp.concatenate(tensors, axis=axis) @@ -93,14 +95,8 @@ def join(axis, *tensors): @jax_funcify.register(Split) def jax_funcify_Split(op: Split, node, **kwargs): - _, axis, splits = node.inputs - try: - constant_axis = get_scalar_constant_value(axis) - except NotScalarConstantError: - constant_axis = None - warnings.warn( - "Split node does not have constant axis. Jax implementation will likely fail" - ) + _x, splits = node.inputs + axis = op.axis try: constant_splits = np.array( @@ -115,25 +111,21 @@ def jax_funcify_Split(op: Split, node, **kwargs): "Split node does not have constant split positions. Jax implementation will likely fail" ) - def split(x, axis, splits): - if constant_axis is not None: - axis = constant_axis - if len(splits) != op.len_splits: - raise ValueError("Length of splits is not equal to n_splits") + def split(x, splits): + if len(splits) != op.len_splits: + raise ValueError("Length of splits is not equal to n_splits") if constant_splits is not None: splits = constant_splits cumsum_splits = np.cumsum(splits[:-1]) if (splits < 0).any(): raise ValueError("Split sizes cannot be negative") - else: - cumsum_splits = jnp.cumsum(splits[:-1]) - - if constant_axis is not None and constant_splits is not None: if splits.sum() != x.shape[axis]: raise ValueError( f"Split sizes do not sum up to input length along axis: {x.shape[axis]}" ) + else: + cumsum_splits = jnp.cumsum(splits[:-1]) return jnp.split(x, cumsum_splits, axis=axis) diff --git a/pytensor/link/mlx/dispatch/tensor_basic.py b/pytensor/link/mlx/dispatch/tensor_basic.py index 730aa140c4..3cdc47323f 100644 --- a/pytensor/link/mlx/dispatch/tensor_basic.py +++ b/pytensor/link/mlx/dispatch/tensor_basic.py @@ -32,7 +32,9 @@ @mlx_funcify.register(Join) def mlx_funcify_Join(op, **kwargs): - def join(axis, *tensors): + axis = op.axis + + def join(*tensors): return mx.concatenate(tensors, axis=axis) return join @@ -40,12 +42,8 @@ def join(axis, *tensors): @mlx_funcify.register(Split) def mlx_funcify_Split(op: Split, node, **kwargs): - _, axis_sym, splits_sym = node.inputs - - try: - constant_axis = get_scalar_constant_value(axis_sym) - except NotScalarConstantError: - constant_axis = None + _x, splits_sym = node.inputs + axis = op.axis try: constant_splits = np.array( @@ -57,15 +55,7 @@ def mlx_funcify_Split(op: Split, node, **kwargs): except (ValueError, NotScalarConstantError): constant_splits = None - def split(x, axis, splits): - # Resolve constants for significant performance improvement (14x speedup) - if constant_axis is not None: - axis = int(constant_axis) - else: - raise ValueError( - "Symbolic axis is not supported in MLX Split implementation." - ) - + def split(x, splits): if constant_splits is not None: splits_arr = mx.array(constant_splits) else: diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index a5babae7de..dcaa540d26 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -116,20 +116,23 @@ def arange(start, stop, step): @register_funcify_default_op_cache_key(Join) def numba_funcify_Join(op, **kwargs): + axis = op.axis + @numba_basic.numba_njit - def join(axis, *tensors): - return np.concatenate(tensors, axis.item()) + def join(*tensors): + return np.concatenate(tensors, axis) return join @register_funcify_default_op_cache_key(Split) def numba_funcify_Split(op, **kwargs): + axis = op.axis + @numba_basic.numba_njit - def split(x, axis, sizes): + def split(x, sizes): if (sizes < 0).any(): raise ValueError("Split sizes cannot be negative") - axis = axis.item() split_indices = np.cumsum(sizes) if split_indices[-1] != x.shape[axis]: raise ValueError( @@ -137,7 +140,7 @@ def split(x, axis, sizes): ) return np.split(x, split_indices[:-1], axis=axis) - cache_version = 1 + cache_version = 2 return split, cache_version diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index c66f3210b6..89a2b4ec4a 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -141,22 +141,12 @@ def arange(start, stop, step): @pytorch_funcify.register(Join) def pytorch_funcify_Join(op, node, **kwargs): - axis = node.inputs[0] + axis = op.axis - if isinstance(axis, Constant): - axis = int(axis.data) + def join(*tensors): + return torch.cat(tensors, dim=axis) - def join_constant_axis(_, *tensors): - return torch.cat(tensors, dim=axis) - - return join_constant_axis - - else: - - def join(axis, *tensors): - return torch.cat(tensors, dim=axis) - - return join + return join @pytorch_funcify.register(Eye) @@ -224,19 +214,19 @@ def tensorfromscalar(x): @pytorch_funcify.register(Split) def pytorch_funcify_Split(op, node, **kwargs): - _x, dim, split_sizes = node.inputs - if isinstance(dim, Constant) and isinstance(split_sizes, Constant): - dim = int(dim.data) + _x, split_sizes = node.inputs + dim = op.axis + if isinstance(split_sizes, Constant): split_sizes = tuple(int(size) for size in split_sizes.data) - def split_constant_axis_and_sizes(x, *_): + def split_constant_sizes(x, _): return x.split(split_sizes, dim=dim) - return split_constant_axis_and_sizes + return split_constant_sizes else: - def inner_fn(x, dim, split_amounts): - return x.split(split_amounts.tolist(), dim=dim.item()) + def inner_fn(x, split_amounts): + return x.split(split_amounts.tolist(), dim=dim) return inner_fn From 9927d10a8fa721a62e9e97cfda748d79d4d9c6bf Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 17 May 2026 22:18:06 -0500 Subject: [PATCH 4/4] Update tests for static Join/Split axis --- tests/link/jax/test_tensor_basic.py | 9 +- tests/link/mlx/test_tensor_basic.py | 12 +- tests/tensor/rewriting/test_basic.py | 8 +- tests/tensor/rewriting/test_subtensor_lift.py | 8 - tests/tensor/test_basic.py | 197 +++++++++--------- 5 files changed, 102 insertions(+), 132 deletions(-) diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 9a86463bbc..66b8791cfb 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -12,7 +12,7 @@ import pytensor import pytensor.tensor.basic as ptb from pytensor.configdefaults import config -from pytensor.tensor.type import iscalar, matrix, scalar, vector +from pytensor.tensor.type import matrix, scalar, vector from tests.link.jax.test_basic import compare_jax_and_py from tests.tensor.test_basic import check_alloc_runtime_broadcast @@ -187,13 +187,6 @@ def test_jax_split_not_supported(self): with pytest.raises(ConcretizationTypeError): fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) - split_axis = iscalar("split_axis") - a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis) - with pytest.warns(UserWarning, match="Split node does not have constant axis."): - fn = pytensor.function([a, split_axis], a_splits, mode="JAX") - with pytest.raises(ConcretizationTypeError): - fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0) - def test_jax_eye(): """Tests jaxification of the Eye operator""" diff --git a/tests/link/mlx/test_tensor_basic.py b/tests/link/mlx/test_tensor_basic.py index 48e51d65cc..8ab3e07542 100644 --- a/tests/link/mlx/test_tensor_basic.py +++ b/tests/link/mlx/test_tensor_basic.py @@ -2,7 +2,6 @@ import pytest import pytensor -from pytensor import config from pytensor import tensor as pt from pytensor.tensor.basic import Alloc, arange from tests.link.mlx.test_basic import ( @@ -152,18 +151,13 @@ def test_split_const_axis_const_splits_compiled(): compare_mlx_and_py([x], outs, [np.arange(5, dtype="float32")]) -def test_split_dynamic_axis_const_splits(): +def test_split_symbolic_axis_rejected(): x = pt.matrix("x") axis = pt.scalar("axis", dtype="int64") splits = [1, 2, 3] - outs = pt.split(x, splits, n_splits=len(splits), axis=axis) - test_input = np.arange(12).astype(config.floatX).reshape(2, 6) - - with pytest.raises( - ValueError, match="Symbolic axis is not supported in MLX Split implementation" - ): - compare_mlx_and_py([x, axis], outs, [test_input, np.array(1)]) + with pytest.raises(TypeError, match="Symbolic axes are no longer supported"): + pt.split(x, splits, n_splits=len(splits), axis=axis) def test_arange(): diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 93610b1585..3447b0d59b 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1351,12 +1351,6 @@ def test_local_join_empty(): [new_s], [specify_shape(int_mat, (2, None)).astype(s.dtype)] ) - # Dynamic axis, can't apply rewrite - axis = scalar("axis", dtype=int) - s = join(axis, empty_mat, int_mat, empty_sym_mat) - new_s = rewrite_graph(s) - assert equal_computations([new_s], [s]) - # Stack introduces an expand_dims in the join, that's a nonzero dim! s = pt.stack([vec, vec, empty_vec]) new_s = rewrite_graph(s) @@ -1374,7 +1368,7 @@ def test_local_join_make_vector(): e = f.maker.fgraph.toposort() assert len([n for n in e if isinstance(n.op, Join)]) == 1 assert all( - not isinstance(n.op, Join) or len(n.inputs) == 4 + not isinstance(n.op, Join) or len(n.inputs) == 3 for n in e if isinstance(n.op, Join) ) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 16fa9e9738..3cbe6876e2 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -874,9 +874,6 @@ def test_empty_subtensor(self): assert local_subtensor_make_vector.transform(fgraph, node) == [v] -shared_axis = shared(1, "axis") - - @pytest.mark.parametrize( "original_fn, expected_fn", [ @@ -902,11 +899,6 @@ def test_empty_subtensor(self): lambda x, y: concatenate([x, y], axis=1)[:, 1:], lambda x, y: concatenate([x, y], axis=1)[:, 1:], ), - # Not supported, axis of concatenation is dynamically determined - ( - lambda x, y: concatenate([x, y], axis=shared_axis)[1], - lambda x, y: concatenate([x, y], axis=shared_axis)[1], - ), ], ) def test_local_subtensor_of_join(original_fn, expected_fn): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 2ab01c047d..1f3fa01e03 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -1,4 +1,5 @@ import itertools +import pickle from functools import partial from tempfile import mkstemp @@ -10,7 +11,6 @@ import pytensor.tensor.basic as ptb import pytensor.tensor.math as ptm from pytensor import compile, config, function, shared -from pytensor.compile import SharedVariable from pytensor.compile.io import In, Out from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp @@ -1296,11 +1296,8 @@ def test_get_vector_length(): assert isinstance(z.owner.op, Join) assert get_vector_length(z) == 3 - z = join( - lscalar(), - as_tensor_variable([1, 2], ndim=1), - as_tensor_variable([3, 4], ndim=1), - ) + # A `Join` whose components have undeterminable length cannot be measured + z = Join(0)(vector("v"), vector("w")) with pytest.raises(ValueError, match=r"^Length of .*"): get_vector_length(z) @@ -1334,7 +1331,7 @@ def setup_method(self): Join.debug = False self.mode = pytensor.compile.get_default_mode().excluding("constant_folding") - self.join_op = Join() + self.join_op = Join(0) self.split_op_class = Split self.make_vector_op = MakeVector() self.floatX = config.floatX @@ -1361,14 +1358,19 @@ def eval_outputs_and_check_vector(self, outputs, make_vector_op=None): return variables def test_input_validation(self): + # `splits` must be of integer type with pytest.raises(TypeError, match=r".*integer.*"): - Split(2)(matrix(), dscalar(), [1, 1]) + Split(2, 0)(matrix(), dvector()) - with pytest.raises(TypeError, match=r".*integer.*"): - Split(2)(matrix(), ivector(), [1, 1]) + # A symbolic (non-constant) axis is no longer supported + with pytest.raises(TypeError, match=r".*[Ss]ymbolic.*"): + Split(2, lscalar()) - with pytest.raises(TypeError, match=r".*integer.*"): - join(dscalar(), matrix(), matrix()) + with pytest.raises(TypeError, match=r".*[Ss]ymbolic.*"): + Join(lscalar()) + + with pytest.raises(TypeError, match=r".*constant integer.*"): + join(lscalar(), matrix(), matrix()) def test_join_scalar(self): a = as_tensor_variable(1) @@ -1724,60 +1726,16 @@ def test_join_matrix1_using_horizontal_stack(self): utt.verify_grad(lambda a, b: join(1, a, b), [av, bv], mode=self.mode) - def test_join_matrixV(self): - # variable join axis + def test_symbolic_axis_rejected(self): + # A symbolic (non-constant) axis is no longer supported by Join/Split. v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX) a = self.shared(v) b = as_tensor_variable(v) - ax = lscalar() - s = join(ax, a, b) - - f = inplace_func([ax], [s], mode=self.mode) - topo = f.maker.fgraph.toposort() - assert [True for node in topo if isinstance(node.op, type(self.join_op))] - - want = np.array( - [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - ) - got = f(0) - assert np.allclose(got, want) - - want = np.array( - [[0.1, 0.2, 0.3, 0.1, 0.2, 0.3], [0.4, 0.5, 0.6, 0.4, 0.5, 0.6]] - ) - got = f(1) - assert np.allclose(got, want) - - utt.verify_grad(lambda a, b: join(0, a, b), [v, 2 * v], mode=self.mode) - utt.verify_grad(lambda a, b: join(1, a, b), [v, 2 * v], mode=self.mode) - - def test_join_matrixV_negative_axis(self): - # variable join negative axis - v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX) - a = self.shared(v) - b = as_tensor_variable(v) - ax = lscalar() - s = join(ax, a, b) - - f = inplace_func([ax], [s], mode=self.mode) - topo = f.maker.fgraph.toposort() - assert [True for node in topo if isinstance(node.op, type(self.join_op))] - - want = np.array( - [[0.1, 0.2, 0.3, 0.1, 0.2, 0.3], [0.4, 0.5, 0.6, 0.4, 0.5, 0.6]] - ) - - got = f(-1) - assert np.allclose(got, want) - - want = np.array( - [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - ) - got = f(-2) - assert np.allclose(got, want) - - with pytest.raises((ValueError, IndexError)): - f(-3) + with pytest.raises(TypeError, match=r".*constant integer.*"): + join(lscalar(), a, b) + # A constant scalar variable is still accepted and normalized. + out = join(constant(-1), a, b) + assert out.owner.op == Join(1) @pytest.mark.parametrize("py_impl", (False, True)) def test_join_matrixC_negative_axis(self, py_impl): @@ -1835,15 +1793,15 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self): a = self.shared(a_val, shape=(None, None, 1)) b = self.shared(b_val, shape=(1, None, 1)) - c = self.join_op(1, a, b) + c = join(1, a, b) assert c.type.shape == (1, None, 1) # Opt can remplace the int by an PyTensor constant - c = self.join_op(constant(1), a, b) + c = join(constant(1), a, b) assert c.type.shape == (1, None, 1) # In case futur opt insert other useless stuff - c = self.join_op(cast(constant(1), dtype="int32"), a, b) + c = join(cast(constant(1), dtype="int32"), a, b) assert c.type.shape == (1, None, 1) f = function([], c, mode=self.mode) @@ -1870,7 +1828,7 @@ def test_broadcastable_flag_assignment_mixed_thisaxes(self): a = self.shared(a_val, shape=(None, None, 1)) b = self.shared(b_val, shape=(1, None, 1)) - c = self.join_op(0, a, b) + c = join(0, a, b) assert c.type.shape[0] != 1 f = function([], c, mode=self.mode) @@ -1887,7 +1845,7 @@ def test_broadcastable_flag_assignment_mixed_thisaxes(self): b.set_value(rng.random((3, 4, 1)).astype(self.floatX)) a = TensorType(dtype=self.floatX, shape=(None, None, 1))() b = TensorType(dtype=self.floatX, shape=(1, None, 1))() - c = self.join_op(0, a, b) + c = join(0, a, b) f = function([a, b], c, mode=self.mode) bad_b_val = rng.random((3, 4, 1)).astype(self.floatX) with pytest.raises(TypeError): @@ -1903,7 +1861,7 @@ def test_broadcastable_flags_all_broadcastable_on_joinaxis(self): a = self.shared(a_val, shape=(1, None, 1)) b = self.shared(b_val, shape=(1, None, 1)) - c = self.join_op(0, a, b) + c = join(0, a, b) assert c.type.shape[0] != 1 f = function([], c, mode=self.mode) @@ -1921,7 +1879,7 @@ def test_broadcastable_single_input_broadcastable_dimension(self): rng = np.random.default_rng(seed=utt.fetch_seed()) a_val = rng.random((1, 4, 1)).astype(self.floatX) a = self.shared(a_val, shape=(1, None, 1)) - b = self.join_op(0, a) + b = join(0, a) assert b.type.shape[0] == 1 assert b.type.shape[2] == 1 assert b.type.shape[1] != 1 @@ -1950,17 +1908,17 @@ def test_broadcastable_flags_many_dims_and_inputs(self): d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))() e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))() - f = self.join_op(0, a, b, c, d, e) + f = join(0, a, b, c, d, e) fb = tuple(s == 1 for s in f.type.shape) assert f.type.shape == (5, 1, 1, 1, None, 1) assert fb == (False, True, True, True, False, True) - g = self.join_op(1, a, b, c, d, e) + g = join(1, a, b, c, d, e) gb = tuple(s == 1 for s in g.type.shape) assert g.type.shape == (1, None, 1, 1, None, 1) assert gb == (True, False, True, True, False, True) - h = self.join_op(4, a, b, c, d, e) + h = join(4, a, b, c, d, e) hb = tuple(s == 1 for s in h.type.shape) assert h.type.shape == (1, 1, 1, 1, None, 1) assert hb == (True, True, True, True, False, True) @@ -2020,7 +1978,7 @@ def get_mat(s1, s2): x3 = self.shared(get_mat(1, 4)) # Test dim 0 - z = self.join_op(0, x1, x2, x3) + z = join(0, x1, x2, x3) f = pytensor.function([], z.shape, mode=self.mode) topo = f.maker.fgraph.toposort() @@ -2035,7 +1993,7 @@ def get_mat(s1, s2): x1.set_value(get_mat(3, 4)) x2.set_value(get_mat(3, 4)) x3.set_value(get_mat(3, 5)) - z = self.join_op(1, x1, x2, x3) + z = join(1, x1, x2, x3) f = pytensor.function([], z.shape, mode=self.mode) topo = f.maker.fgraph.toposort() out = f() @@ -2057,7 +2015,7 @@ def test_mixed_ndim_error(self): v = self.shared(rng.random(4).astype(self.floatX)) m = self.shared(rng.random((4, 4)).astype(self.floatX)) with pytest.raises(TypeError, match="same number of dimensions"): - self.join_op(0, v, m) + join(0, v, m) def test_static_shape_inference(self): a = ptb.tensor(dtype="int8", shape=(2, 3)) @@ -2084,7 +2042,7 @@ def test_static_shape_inference(self): def test_split_0elem(self): rng = np.random.default_rng(seed=utt.fetch_seed()) m = self.shared(rng.random((4, 6)).astype(self.floatX)) - o = self.split_op_class(2)(m, 0, [4, 0]) + o = self.split_op_class(2, 0)(m, [4, 0]) f = function([], o, mode=self.mode) assert any( isinstance(node.op, self.split_op_class) @@ -2097,7 +2055,7 @@ def test_split_0elem(self): def test_split_neg(self): rng = np.random.default_rng(seed=utt.fetch_seed()) m = self.shared(rng.random((4, 6)).astype(self.floatX)) - o = self.split_op_class(2)(m, 0, [5, -1]) + o = self.split_op_class(2, 0)(m, [5, -1]) f = function([], o, mode=self.mode) assert any( isinstance(node.op, self.split_op_class) @@ -2109,7 +2067,7 @@ def test_split_neg(self): def test_split_static_shape(self): x = TensorType("floatX", shape=(5,))("x") s = iscalar("s") - y = Split(2)(x, 0, [s, 5 - s])[0] + y = Split(2, 0)(x, [s, 5 - s])[0] assert y.type.shape == (None,) def test_join_oneInput(self): @@ -2131,10 +2089,9 @@ def test_join_oneInput(self): @pytest.mark.parametrize("linker", ("py", "c")) def test_split_view(self, linker): x = vector("x") - axis = 0 - op = Split(len_splits=3) + op = Split(len_splits=3, axis=0) assert op.view_map == {0: [0], 1: [0], 2: [0]} - splits = op(x, axis, [0, 3, 2]) + splits = op(x, [0, 3, 2]) mode = Mode(linker) f = pytensor.function( @@ -2147,7 +2104,7 @@ def test_split_view(self, linker): assert r.base is x_test def test_join_negative_axis_rewrite(self): - """Test that constant negative axis is rewritten to positive axis in make_node.""" + """Test that a constant negative axis is normalized to a positive axis.""" v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX) a = self.shared(v) b = as_tensor_variable(v) @@ -2155,6 +2112,48 @@ def test_join_negative_axis_rewrite(self): assert equal_computations([join(-1, a, b)], [join(1, a, b)]) assert equal_computations([join(-2, a, b)], [join(0, a, b)]) + def test_axis_is_op_property(self): + # The axis is an Op property, not an Apply-node input. + a = matrix("a") + b = matrix("b") + + join_node = join(1, a, b).owner + assert join_node.op.axis == 1 + assert join_node.inputs == [a, b] + assert Join(0) == Join(0) + assert Join(0) != Join(1) + assert hash(Join(0)) == hash(Join(0)) + + split_node = Split(2, 1)(a, [1, 2])[0].owner + assert split_node.op.axis == 1 + assert split_node.op.len_splits == 2 + assert split_node.inputs[0] is a + assert Split(2, 0) != Split(2, 1) + assert Split(2, 0) != Split(3, 0) + + def test_negative_axis_normalized_in_make_node(self): + # `make_node` binds the node to an Op with a canonical non-negative + # axis, so a directly-constructed negative-axis Op still yields a node + # whose `op.axis` is normalized. + a = matrix("a") + b = matrix("b") + assert Join(-1)(a, b).owner.op == Join(1) + assert Split(2, -1)(a, [1, 2])[0].owner.op == Split(2, 1) + + def test_pickle_roundtrip(self): + # Compiled functions with Join/Split round-trip through pickle. + a = matrix("a") + b = matrix("b") + joined = join(1, a, b) + split_out = Split(2, 0)(a, [1, 1]) + f = function([a, b], [joined, *split_out], mode=self.mode) + + reloaded = pickle.loads(pickle.dumps(f)) + a_val = np.ones((2, 3), dtype=config.floatX) + b_val = np.zeros((2, 4), dtype=config.floatX) + for orig, new in zip(f(a_val, b_val), reloaded(a_val, b_val), strict=True): + assert np.array_equal(orig, new) + def test_TensorFromScalar(): s = ps.constant(56) @@ -3869,43 +3868,41 @@ def test_ExtractDiag(self): self._compile_and_check([atens3], [atens3_diag], [atens3_val], ExtractDiag) def test_Split(self): - aiscal = iscalar() aivec = ivector() adtens = tensor3() adtens_val = random(4, 10, 3) aivec_val = [2, 5, 3] - for aiscal_val in [1, -2]: + for axis in [1, -2]: self._compile_and_check( - [adtens, aiscal, aivec], - [Split(3)(adtens, aiscal, aivec)[0]], - [adtens_val, aiscal_val, aivec_val], + [adtens, aivec], + [Split(3, axis)(adtens, aivec)[0]], + [adtens_val, aivec_val], (Split), ) def test_Join(self): - aiscal = iscalar() cdmat = dmatrix() admat_val = random(1, 3) bdmat_val = random(2, 3) cdmat_val = random(4, 3) admat = dmatrix() bdmat = dmatrix() - for aiscal_val in [0, -2]: + for axis in [0, -2]: self._compile_and_check( - [aiscal, admat, bdmat, cdmat], - [Join()(aiscal, admat, bdmat, cdmat)], - [aiscal_val, admat_val, bdmat_val, cdmat_val], + [admat, bdmat, cdmat], + [Join(axis)(admat, bdmat, cdmat)], + [admat_val, bdmat_val, cdmat_val], Join, ) admat_val = random(4, 1) bdmat_val = random(4, 3) cdmat_val = random(4, 2) - for aiscal_val in [-1, 1]: + for axis in [-1, 1]: self._compile_and_check( - [aiscal, admat, bdmat, cdmat], - [Join()(aiscal, admat, bdmat, cdmat)], - [aiscal_val, admat_val, bdmat_val, cdmat_val], + [admat, bdmat, cdmat], + [Join(axis)(admat, bdmat, cdmat)], + [admat_val, bdmat_val, cdmat_val], Join, ) @@ -4505,7 +4502,7 @@ def core_np(*scalars): ) -@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)]) +@pytest.mark.parametrize("axis", [1, -2, constant(1)]) @pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"]) @config.change_flags(cxx="") # C code not needed def test_vectorize_join(axis, broadcasting_y): @@ -4516,7 +4513,7 @@ def core_pt(x, y): return join(axis, x, y) def core_np(x, y): - return np.concatenate([x, y], axis=axis.eval()) + return np.concatenate([x, y], axis=int(getattr(axis, "data", axis))) x = tensor(shape=(4, 2, 3, 5)) y_shape = {"none": (4, 2, 3, 5), "implicit": (2, 3, 5), "explicit": (1, 2, 3, 5)} @@ -4524,7 +4521,7 @@ def core_np(x, y): vectorize_pt = function([x, y], vectorize(core_pt, signature=signature)(x, y)) - blockwise_needed = isinstance(axis, SharedVariable) or broadcasting_y != "none" + blockwise_needed = broadcasting_y != "none" has_blockwise = any( isinstance(node.op, Blockwise | BlockwiseWithCoreShape) for node in vectorize_pt.maker.fgraph.apply_nodes