From 7ea862c96dad277774bb90054204f6e51fcf8f02 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 30 Apr 2026 23:38:05 +0200 Subject: [PATCH 1/5] Drop `fgraph` parameter from `Op.infer_shape` signature Breaking API change: the `fgraph` argument was unused by every in-tree `infer_shape` implementation. Removing it makes `infer_shape` a pure function of `(node, input_shapes)`, simpler to call from outside an fgraph context (e.g. ShapeFeature's lazy kernel build) and tighter as a contract. External Ops with custom `infer_shape(self, fgraph, node, input_shapes)` must drop the `fgraph` parameter. --- pytensor/breakpoint.py | 2 +- pytensor/compile/builders.py | 2 +- pytensor/compile/ops.py | 10 +++--- pytensor/ifelse.py | 2 +- pytensor/raise_op.py | 2 +- pytensor/scan/op.py | 2 +- pytensor/sparse/basic.py | 36 +++++++++---------- pytensor/sparse/math.py | 32 ++++++++--------- pytensor/sparse/rewriting.py | 2 +- pytensor/tensor/basic.py | 28 +++++++-------- pytensor/tensor/blas.py | 12 +++---- pytensor/tensor/blockwise.py | 10 ++---- pytensor/tensor/elemwise.py | 6 ++-- pytensor/tensor/extra_ops.py | 18 +++++----- pytensor/tensor/fourier.py | 2 +- pytensor/tensor/linalg/constructors.py | 2 +- .../tensor/linalg/decomposition/cholesky.py | 2 +- pytensor/tensor/linalg/decomposition/eigen.py | 6 ++-- pytensor/tensor/linalg/decomposition/lu.py | 4 +-- pytensor/tensor/linalg/decomposition/qr.py | 2 +- pytensor/tensor/linalg/decomposition/schur.py | 4 +-- pytensor/tensor/linalg/decomposition/svd.py | 2 +- pytensor/tensor/linalg/inverse.py | 6 ++-- pytensor/tensor/linalg/products.py | 2 +- pytensor/tensor/linalg/solvers/core.py | 2 +- .../tensor/linalg/solvers/linear_control.py | 2 +- pytensor/tensor/linalg/summary.py | 4 +-- pytensor/tensor/math.py | 4 +-- pytensor/tensor/random/op.py | 2 +- pytensor/tensor/reshape.py | 4 +-- pytensor/tensor/rewriting/numba.py | 3 +- pytensor/tensor/rewriting/shape.py | 10 +++--- pytensor/tensor/shape.py | 8 ++--- pytensor/tensor/signal/conv.py | 2 +- pytensor/tensor/sort.py | 4 +-- pytensor/tensor/special.py | 6 ++-- pytensor/tensor/subtensor.py | 12 +++---- pytensor/xtensor/basic.py | 2 +- tests/compile/test_ops.py | 2 +- tests/sparse/test_basic.py | 2 +- tests/tensor/rewriting/test_shape.py | 4 +-- tests/tensor/test_blockwise.py | 4 +-- tests/tensor/test_elemwise.py | 8 ++--- 43 files changed, 135 insertions(+), 146 deletions(-) diff --git a/pytensor/breakpoint.py b/pytensor/breakpoint.py index f9c74950dc..c2913cfea1 100644 --- a/pytensor/breakpoint.py +++ b/pytensor/breakpoint.py @@ -144,7 +144,7 @@ def perform(self, node, inputs, output_storage): def pullback(self, inputs, outputs, output_gradients): return [disconnected_type(), *output_gradients] - def infer_shape(self, fgraph, inputs, input_shapes): + def infer_shape(self, inputs, input_shapes): # Return the shape of every input but the condition (first input) return input_shapes[1:] diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index c55af519a4..29189b7fe3 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -884,7 +884,7 @@ def connection_pattern(self, node): self._connection_pattern = ret return ret - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): # TODO: Use `fgraph.shape_feature` to do this instead. out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes) diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 57d34bbf25..028c854854 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -90,7 +90,7 @@ class ViewOp(TypeCastingOp): def make_node(self, x): return Apply(self, [x], [x.type()]) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def pullback(self, args, outputs, g_outs): @@ -179,7 +179,7 @@ def c_code(self, node, name, inames, onames, sub): # Else, no C code raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes @@ -251,8 +251,8 @@ def __reduce__(self): ) return load_back, (mod, name) - def _infer_shape(self, fgraph, node, input_shapes): - return self.__infer_shape(fgraph, node, input_shapes) + def _infer_shape(self, node, input_shapes): + return self.__infer_shape(node, input_shapes) def as_op(itypes, otypes, infer_shape=None): @@ -275,7 +275,7 @@ def wrap_py(itypes, otypes, infer_shape=None): It takes an optional infer_shape parameter that should be a callable with this signature: - def infer_shape(fgraph, node, input_shapes): + def infer_shape(node, input_shapes): ... return output_shapes diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 5f7cb58470..8ad47ac1c8 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -103,7 +103,7 @@ def __str__(self): args.append("inplace") return f"if{{{','.join(args)}}}" - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): # By construction, corresponding then/else pairs have the same number # of dimensions diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index c2962e4d7d..87f52c3c4b 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -137,7 +137,7 @@ def c_code(self, node, name, inames, onames, props): def c_code_cache_version(self): return (2,) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] def do_constant_folding(self, fgraph, node): diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index c2396d7948..6cfbbdbea9 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2251,7 +2251,7 @@ def perform(self, node, inputs, output_storage): self.t_call = t_call self.t_fn = t_fn - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): # input_shapes correspond to the shapes of node.inputs for inp, inp_shp in zip(node.inputs, input_shapes, strict=True): assert inp_shp is None or len(inp_shp) == inp.type.ndim diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index bdd4f77777..1fa1a8e181 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -494,7 +494,7 @@ def pullback(self, inputs, outputs, gout): disconnected_type(), ] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): # node.inputs[3] is of length as we only support sparse matrix. return [(node.inputs[3][0], node.inputs[3][1])] @@ -584,7 +584,7 @@ def perform(self, node, inputs, outputs): g_out[0] = gout_data - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] @@ -629,7 +629,7 @@ def pullback(self, inputs, outputs, outputs_gradients): else: return [Cast(inputs[0].dtype)(gz)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return ins_shapes def __str__(self): @@ -742,7 +742,7 @@ def pullback(self, inputs, outputs, gout): else: return [SparseFromDense(x.type.format)(gz)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -806,7 +806,7 @@ def pullback(self, inputs, outputs, gout): ) return (gx,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -820,7 +820,7 @@ class GetItemList(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[1][0], shapes[0][1])] def make_node(self, x, index): @@ -865,7 +865,7 @@ def pullback(self, inputs, outputs, g_outputs): class GetItemListGrad(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0])] def make_node(self, x, index, gz): @@ -958,7 +958,7 @@ def pullback(self, inputs, outputs, g_outputs): class GetItem2ListsGrad(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0])] def make_node(self, x, ind1, ind2, gz): @@ -1139,7 +1139,7 @@ class GetItemScalar(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [()] def make_node(self, x, index): @@ -1239,7 +1239,7 @@ def pullback(self, inputs, outputs, gout): assert _is_sparse_variable(x) and _is_sparse_variable(gz) return (transpose(gz),) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0][::-1]] @@ -1288,7 +1288,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [col_scale(gz, s), sp_sum(x * gz, axis=0)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1339,7 +1339,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [row_scale(gz, s), sp_sum(x * gz, axis=1)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1438,7 +1438,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [square_diagonal(gz)] - def infer_shape(self, fgraph, nodes, shapes): + def infer_shape(self, nodes, shapes): return [(minimum(*shapes[0]),)] @@ -1498,7 +1498,7 @@ def perform(self, node, inputs, outputs): def pullback(self, inputs, outputs, output_grad): return [output_grad[0]] - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): return i0_shapes def __str__(self): @@ -1614,7 +1614,7 @@ def choose(continuous, derivative): return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): d = sum(shape[1] for shape in ins_shapes) return [(ins_shapes[0][0], d)] @@ -1711,7 +1711,7 @@ def choose(continuous, derivative): return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): d = sum(shape[0] for shape in ins_shapes) return [(d, ins_shapes[0][1])] @@ -1800,7 +1800,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [gz] - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): return i0_shapes @@ -1880,7 +1880,7 @@ def perform(self, node, inp, out_): (data, indices, indptr), shape=out_shape, dtype=values.dtype ) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x = node.inputs[0] return [[x[0], x[1]]] diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index 57fe280f88..d59915ed45 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -325,7 +325,7 @@ def pullback(self, inputs, outputs, gout): r = psb.SparseFromDense(o_format)(r) return [r] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): r = None if self.axis is None: r = [()] @@ -404,7 +404,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return gz, gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -465,7 +465,7 @@ def pullback(self, inputs, outputs, gout): derivative = {True: gz, False: None} return [derivative[b] for b in is_continuous] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -507,7 +507,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_dense_variable(gz) return psb.sp_ones_like(x) * gz, gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] @@ -567,7 +567,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return gz, sp_sum(gz, axis=0, sparse_grad=True) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -697,7 +697,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return y * gz, x * gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -786,7 +786,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return y * gz, psb.dense_from_sparse(x * gz) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -869,7 +869,7 @@ def pullback(self, inputs, outputs, gout): return mul_s_v(gz, y), sp_sum(x * gz, axis=0, sparse_grad=True) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -987,7 +987,7 @@ def perform(self, node, inputs, outputs): self.comparison(x, y).astype("uint8").asformat(node.outputs[0].type.format) ) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1032,7 +1032,7 @@ def perform(self, node, inputs, outputs): o = np.asarray(o) out[0] = o - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1282,7 +1282,7 @@ def pullback(self, inputs, outputs, gout): rval[1] = psb.dense_from_sparse(rval[1]) return rval - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0][0], shapes[1][1])] @@ -1410,7 +1410,7 @@ def pullback(self, inputs, outputs, gout): (g_out,) = gout return [structured_dot_grad(a, b, g_out), structured_dot(a.T, g_out)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0][0], shapes[1][1])] @@ -1594,7 +1594,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -1729,7 +1729,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -1827,7 +1827,7 @@ def pullback(self, inputs, outputs, gout): return rval - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[2]] @@ -1840,7 +1840,7 @@ class Dot(Op): def __str__(self): return "Sparse" + self.__class__.__name__ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes x, y = node.inputs if x.ndim == 2 and y.ndim == 2: diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index 29b9ff8bae..8d6138bd2c 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -176,7 +176,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ return code - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[3]] def c_code_cache_version(self): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index d1d19f41df..9bc6232a14 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -628,7 +628,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.asarray(s) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [()] def pullback(self, inp, outputs, grads): @@ -689,7 +689,7 @@ def perform(self, node, inputs, output_storage): # not using .item() because that returns a Python scalar, not a numpy scalar output_storage[0][0] = inputs[0][()] - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [()] def pullback(self, inp, outputs, grads): @@ -1370,7 +1370,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.eye(n, m, k, dtype=self.dtype) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): out_shape = [node.inputs[0], node.inputs[1]] return [out_shape] @@ -1699,7 +1699,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): return (5,) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [node.inputs[1:]] def connection_pattern(self, node): @@ -1957,7 +1957,7 @@ def c_code(self, node, name, inp, out_, props): """ return ret - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): return [(len(ishapes),)] def pullback(self, inputs, outputs, output_gradients): @@ -2245,7 +2245,7 @@ def perform(self, node, inputs, outputs_storage): for out_storage, out in zip(outputs_storage, split_outs, strict=False): out_storage[0] = out - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): axis = node.inputs[1] splits = node.inputs[2] shp_x, _shp_axis, _shp_splits = in_shapes @@ -2701,7 +2701,7 @@ def pullback(self, inputs, outputs, grads): return rval - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): from pytensor.tensor.math import eq, ge # ishapes[0] contains the size of the axis on which we join @@ -3255,7 +3255,7 @@ def make_node(self, start, stop, step): return Apply(self, inputs, outputs) @config.change_flags(warn_float64="ignore") - def infer_shape(self, fgraph, node, i_shapes): + def infer_shape(self, node, i_shapes): from pytensor.tensor.math import ceil, maximum # Note start, stop and step can be float numbers. @@ -3632,7 +3632,7 @@ def perform(self, node, inp, out): self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): from pytensor.tensor.math import maximum shp_x = in_shapes[0] @@ -3884,7 +3884,7 @@ def pullback(self, inputs, outputs, gout): x_grad = moveaxis(x_grad, (0, 1), (axis1, axis2)) return [x_grad] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): from pytensor.tensor.math import clip, minimum (in_shape,) = shapes @@ -4233,7 +4233,7 @@ def __init__(self, mode): assert mode in ("raise", "wrap", "clip") self.mode = mode - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): a_shape, choices_shape = shapes if choices_shape is None: # choices is a TypedList, not a tensor; no shape to broadcast @@ -4264,9 +4264,7 @@ def make_node(self, a, choices): choice = as_tensor_variable(choices) choice_dtype = choice.dtype - (out_shape,) = self.infer_shape( - None, None, [shape_tuple(a), shape_tuple(choice)] - ) + (out_shape,) = self.infer_shape(None, [shape_tuple(a), shape_tuple(choice)]) static_out_shape = () for s in out_shape: @@ -4369,7 +4367,7 @@ def c_code(self, node, name, inputs, out_, sub): """ return str - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [node.inputs] def c_code_cache_version(self): diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 247034451b..bf5ff6e777 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -247,7 +247,7 @@ def perform(self, node, inputs, out_storage): out += y out_storage[0][0] = np.asarray(out, dtype=y.dtype) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] @@ -320,7 +320,7 @@ def perform(self, node, inputs, output_storage): A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive) output_storage[0][0] = A - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] @@ -945,7 +945,7 @@ def perform(self, node, inp, out): z += a * np.dot(x, y) zout[0] = z - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): z_shape, _, x_shape, y_shape, _ = input_shapes return [ ( @@ -1150,7 +1150,7 @@ def make_node(self, x, y): def perform(self, node, inputs, output_storage): output_storage[0][0] = np.dot(*inputs) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] setup_z_Nz_Sz = """ @@ -1253,7 +1253,7 @@ def perform(self, node, inp, out): e.args = (*e.args, x.shape, y.shape) raise - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz @@ -1642,7 +1642,7 @@ def pushforward(self, inputs, outputs, eval_points): else: return [t2] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes return [xshp[:-1] + yshp[2:]] diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index ecc4ad92d1..2c8a2c99bc 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -6,7 +6,6 @@ from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType -from pytensor.graph import FunctionGraph from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.null_type import NullType from pytensor.graph.op import Op @@ -321,9 +320,7 @@ def make_node(self, *inputs): def batch_ndim(self, node: Apply) -> int: return cast(int, node.outputs[0].type.ndim - len(self.outputs_sig[0])) - def infer_shape( - self, fgraph, node, input_shapes - ) -> list[tuple[TensorVariable, ...]]: + def infer_shape(self, node, input_shapes) -> list[tuple[TensorVariable, ...]]: from pytensor.tensor import broadcast_shape from pytensor.tensor.shape import Shape_i @@ -354,13 +351,10 @@ def extract_core_shape_from_infer_shape(): return_dummy_inputs=True, propagate_unbatched_core_inputs=True, ) - dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False) core_input_shapes = [ input_shape[batch_ndims:] for input_shape in input_shapes ] - core_output_shapes = core_op_infer_shape( - dummy_fgraph, dummy_core_node, core_input_shapes - ) + core_output_shapes = core_op_infer_shape(dummy_core_node, core_input_shapes) if not dummy_core_inputs: # All inputs are unbatched, so the core_shape can be used as is diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 1196a5ca77..f9c6c2b2db 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -257,7 +257,7 @@ def perform(self, node, inp, out): new_shape.insert(augm, 1) out[0][0] = res.reshape(new_shape) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishp,) = shapes # transpose rval = [ishp[i] for i in self.shuffle] @@ -755,7 +755,7 @@ def _check_runtime_broadcast(node, inputs): "If broadcasting was intended, use `specify_broadcastable` on the relevant input." ) - def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...]]: + def infer_shape(self, node, i_shapes) -> list[tuple[TensorVariable, ...]]: from pytensor.tensor.extra_ops import broadcast_shape out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True) @@ -1426,7 +1426,7 @@ def perform(self, node, inp, out): output[0] = np.asarray(out, dtype=out_dtype) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishape,) = shapes axis = self.axis if axis is None: diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 7091f2fd36..96e6eca2b5 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -150,7 +150,7 @@ def make_node(self, x, v, sorter=None): raise TypeError("sorter must be an integer vector", sorter.type) return Apply(self, [x, v, sorter], [out_type()]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] def perform(self, node, inputs, output_storage): @@ -340,7 +340,7 @@ def pullback(self, inputs, outputs, output_gradients): f'{type(self).__name__}: unknown gradient for mode "{self.mode}"' ) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return shapes def c_code(self, node, name, inames, onames, sub): @@ -717,7 +717,7 @@ def pullback(self, inputs, outputs, gout): return [gx, disconnected_type()] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): i0_shapes = ins_shapes[0] repeats = node.inputs[1] out_shape = list(i0_shapes) @@ -849,7 +849,7 @@ def perform(self, node, inputs, out_): (out,) = out_ out[0] = np.bartlett(M) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): temp = node.inputs[0] M = ptb.switch(lt(temp, 0), ptb.cast(0, temp.dtype), temp) return [[M]] @@ -892,7 +892,7 @@ class FillDiagonal(Op): # See function fill_diagonal for docstring __props__ = () - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [in_shapes[0]] def make_node(self, a, val): @@ -993,7 +993,7 @@ class FillDiagonalOffset(Op): # See function fill_diagonal_offset for docstring __props__ = () - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [in_shapes[0]] def make_node(self, a, val, offset): @@ -1240,7 +1240,7 @@ def perform(self, node, inputs, output_storage): else: output_storage[0][0] = outs - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): [x_shape] = i0_shapes shape0_op = Shape_i(0) out_shapes = [(shape0_op(out),) for out in node.outputs] @@ -1310,7 +1310,7 @@ def make_node(self, indices, dims): [out_type() for _i in range(ptb.get_vector_length(dims))], ) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] * len(node.outputs) def perform(self, node, inp, out): @@ -1387,7 +1387,7 @@ def make_node(self, *inp): [out_type()], ) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] def perform(self, node, inp, out): diff --git a/pytensor/tensor/fourier.py b/pytensor/tensor/fourier.py index 9d35955c6f..260fc2dc09 100644 --- a/pytensor/tensor/fourier.py +++ b/pytensor/tensor/fourier.py @@ -100,7 +100,7 @@ def make_node(self, a, n, axis): ], ) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): shape_a = in_shapes[0] n = node.inputs[1] axis = node.inputs[2] diff --git a/pytensor/tensor/linalg/constructors.py b/pytensor/tensor/linalg/constructors.py index 47a5ec2922..449c4b8009 100644 --- a/pytensor/tensor/linalg/constructors.py +++ b/pytensor/tensor/linalg/constructors.py @@ -35,7 +35,7 @@ def pullback(self, inputs, outputs, gout): ] return [gout[0][slc] for slc in slices] - def infer_shape(self, fgraph, nodes, shapes): + def infer_shape(self, nodes, shapes): first, second = unzip(shapes, n=2, strict=True) return [(pt.add(*first), pt.add(*second))] diff --git a/pytensor/tensor/linalg/decomposition/cholesky.py b/pytensor/tensor/linalg/decomposition/cholesky.py index 9dd69d58cb..fca8a0dd74 100644 --- a/pytensor/tensor/linalg/decomposition/cholesky.py +++ b/pytensor/tensor/linalg/decomposition/cholesky.py @@ -33,7 +33,7 @@ def __init__( if self.overwrite_a: self.destroy_map = {0: [0]} - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def make_node(self, x): diff --git a/pytensor/tensor/linalg/decomposition/eigen.py b/pytensor/tensor/linalg/decomposition/eigen.py index 4cdc1cd0ac..0db83e1ce9 100644 --- a/pytensor/tensor/linalg/decomposition/eigen.py +++ b/pytensor/tensor/linalg/decomposition/eigen.py @@ -65,7 +65,7 @@ def perform(self, node, inputs, outputs): outputs[0][0] = w.astype(dtype, copy=False) outputs[1][0] = v.astype(dtype, copy=False) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shapes,) = shapes n, _ = x_shapes @@ -206,7 +206,7 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": return self return type(self)(**new_props) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [(n,), (n, n)] @@ -416,7 +416,7 @@ def make_node(self, a, b=None): w = vector(dtype=out_dtype, shape=(N,)) return Apply(self, inputs, [w]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [ (n,), diff --git a/pytensor/tensor/linalg/decomposition/lu.py b/pytensor/tensor/linalg/decomposition/lu.py index b64b1a37e4..5373b825f3 100644 --- a/pytensor/tensor/linalg/decomposition/lu.py +++ b/pytensor/tensor/linalg/decomposition/lu.py @@ -43,7 +43,7 @@ def __init__(self, *, permute_l=False, overwrite_a=False, p_indices=False): if self.overwrite_a: self.destroy_map = {0: [0]} if self.permute_l else {1: [0]} - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] if self.permute_l: return [(n, n), (n, n)] @@ -259,7 +259,7 @@ def make_node(self, A): return Apply(self, [A], [LU, pivots]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [(n, n), (n,)] diff --git a/pytensor/tensor/linalg/decomposition/qr.py b/pytensor/tensor/linalg/decomposition/qr.py index 0de3d57327..f499f8d918 100644 --- a/pytensor/tensor/linalg/decomposition/qr.py +++ b/pytensor/tensor/linalg/decomposition/qr.py @@ -103,7 +103,7 @@ def make_node(self, x): return Apply(self, [x], outputs) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shape,) = shapes M, N = x_shape diff --git a/pytensor/tensor/linalg/decomposition/schur.py b/pytensor/tensor/linalg/decomposition/schur.py index 7b34a64efd..d2260b4d40 100644 --- a/pytensor/tensor/linalg/decomposition/schur.py +++ b/pytensor/tensor/linalg/decomposition/schur.py @@ -147,7 +147,7 @@ def perform(self, node, inputs, outputs): T_out[0] = T Z_out[0] = Z - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0], shapes[0]] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": @@ -490,7 +490,7 @@ def perform(self, node, inputs, outputs): alpha_out[0] = alpha beta_out[0] = beta - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): A_shape, B_shape = shapes if self.return_eigenvalues: return [A_shape, B_shape, (A_shape[0],), (A_shape[0],), A_shape, B_shape] diff --git a/pytensor/tensor/linalg/decomposition/svd.py b/pytensor/tensor/linalg/decomposition/svd.py index 86291b5104..6a55a4e7b8 100644 --- a/pytensor/tensor/linalg/decomposition/svd.py +++ b/pytensor/tensor/linalg/decomposition/svd.py @@ -74,7 +74,7 @@ def perform(self, node, inputs, outputs): (s,) = outputs s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shape,) = shapes M, N = x_shape K = ptm.minimum(M, N) diff --git a/pytensor/tensor/linalg/inverse.py b/pytensor/tensor/linalg/inverse.py index 82878c8777..6c205d6ae7 100644 --- a/pytensor/tensor/linalg/inverse.py +++ b/pytensor/tensor/linalg/inverse.py @@ -61,7 +61,7 @@ def pullback(self, inputs, outputs, g_outputs): ).T return [grad] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [list(reversed(shapes[0]))] @@ -159,7 +159,7 @@ def pushforward(self, inputs, outputs, eval_points): return [-matrix_dot(xi, ev, xi)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return shapes @@ -187,7 +187,7 @@ def perform(self, node, inputs, outputs): (x,) = outputs x[0] = np.linalg.tensorinv(a, self.ind) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): sp = shapes[0][self.ind :] + shapes[0][: self.ind] return [sp] diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 1dd41fe51b..7a92792df2 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -65,7 +65,7 @@ def pullback(self, inputs, outputs, output_grads): return [result] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] diff --git a/pytensor/tensor/linalg/solvers/core.py b/pytensor/tensor/linalg/solvers/core.py index 9805d86c8f..39a46e84a8 100644 --- a/pytensor/tensor/linalg/solvers/core.py +++ b/pytensor/tensor/linalg/solvers/core.py @@ -71,7 +71,7 @@ def make_node(self, A, b): x = tensor(dtype=o_dtype, shape=b.type.shape) return Apply(self, [A, b], [x]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): Ashape, Bshape = shapes rows = Ashape[1] if len(Bshape) == 1: diff --git a/pytensor/tensor/linalg/solvers/linear_control.py b/pytensor/tensor/linalg/solvers/linear_control.py index 0ce4fbc3c5..930e5f0db3 100644 --- a/pytensor/tensor/linalg/solvers/linear_control.py +++ b/pytensor/tensor/linalg/solvers/linear_control.py @@ -82,7 +82,7 @@ def perform(self, node, inputs, outputs_storage): Y *= scale X[0] = Y - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[2]] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": diff --git a/pytensor/tensor/linalg/summary.py b/pytensor/tensor/linalg/summary.py index bba599e17f..c76753885f 100644 --- a/pytensor/tensor/linalg/summary.py +++ b/pytensor/tensor/linalg/summary.py @@ -71,7 +71,7 @@ def pullback(self, inputs, outputs, g_outputs): (x,) = inputs return [gz * self(x) * matrix_inverse(x).T] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [()] def __str__(self): @@ -106,7 +106,7 @@ def perform(self, node, inputs, outputs): except Exception as e: raise ValueError("Failed to compute determinant", x) from e - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(), ()] def __str__(self): diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 139d22529d..7e93db8c01 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -254,7 +254,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): return (3,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishape,) = shapes if self.axis is None: return [()] @@ -3106,7 +3106,7 @@ def pushforward(self, inputs, outputs, eval_points): else: return [t2] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes return [[xshp[0], yshp[1]]] diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 34479e142c..7201d631e8 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -306,7 +306,7 @@ def extract_batch_shape(p, ps, n): return shape - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): _, size, *dist_params = node.inputs _, _, *param_shapes = input_shapes diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py index b25308fccc..b43245d31e 100644 --- a/pytensor/tensor/reshape.py +++ b/pytensor/tensor/reshape.py @@ -67,7 +67,7 @@ def make_node(self, x: Variable) -> Apply: # type: ignore[override] return Apply(self, [x], [output_type]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): [input_shape] = shapes joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int) return [self.output_shapes(input_shape, joined_shape)] @@ -188,7 +188,7 @@ def make_node(self, x, shape): ) return Apply(self, [x, shape], [output]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): [input_shape, _] = shapes _, shape = node.inputs output_shapes = list(input_shape) diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index 6bb9ed5bd9..047893de49 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -77,8 +77,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): else: input_shapes = [tuple(inp.shape) for inp in node.inputs] core_shapes = [ - out_shape[batch_ndim:] - for out_shape in op.infer_shape(None, node, input_shapes) + out_shape[batch_ndim:] for out_shape in op.infer_shape(node, input_shapes) ] core_shapes = [ diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index d033609b2e..c9f3012a31 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -135,12 +135,10 @@ def get_node_infer_shape(self, node): shape_infer = self.default_infer_shape try: - o_shapes = shape_infer( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) + o_shapes = shape_infer(node, [self.shape_of[r] for r in node.inputs]) except ShapeError: o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] + node, [self.shape_of[r] for r in node.inputs] ) except NotImplementedError as e: raise NotImplementedError( @@ -161,7 +159,7 @@ def get_node_infer_shape(self, node): else: warn(msg) o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] + node, [self.shape_of[r] for r in node.inputs] ) return o_shapes @@ -230,7 +228,7 @@ def shape_tuple(self, r): return None return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) - def default_infer_shape(self, fgraph, node, i_shapes): + def default_infer_shape(self, node, i_shapes): """Return a list of shape tuple or None for the outputs of node. This function is used for Ops that don't implement infer_shape. diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 3a7202acfc..74c445f2f5 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -81,7 +81,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.asarray(np.shape(x), dtype="int64") - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [[len(in_shapes[0])]] def connection_pattern(self, node): @@ -297,7 +297,7 @@ def c_code(self, node, name, inames, onames, sub): # Else, no C code raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [()] def connection_pattern(self, node): @@ -452,7 +452,7 @@ def perform(self, node, inp, out_): ) out[0] = x - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshape, *_ = shapes shape = node.inputs[1:] # Use x shape if specified dim is None, otherwise the specified shape @@ -727,7 +727,7 @@ def pushforward(self, inputs, outputs, eval_points): return [disconnected_type()] return self(eval_points[0], *inputs[1:], return_list=True) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): from pytensor.tensor.math import eq, maximum, mul # inputs[1] can contain at most one value of '-1', meaning the actual diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index d470f25975..2c4f374756 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -78,7 +78,7 @@ def make_node(self, in1, in2, full_mode): out = tensor(dtype=dtype, shape=out_shape) return Apply(self, [in1, in2, full_mode], [out]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): _, _, full_mode = node.inputs in1_shape, in2_shape, _ = shapes out_shape = [ diff --git a/pytensor/tensor/sort.py b/pytensor/tensor/sort.py index af695d9e42..c911be988d 100644 --- a/pytensor/tensor/sort.py +++ b/pytensor/tensor/sort.py @@ -54,7 +54,7 @@ def perform(self, node, inputs, output_storage): z = output_storage[0] z[0] = np.sort(a, axis, self.kind) - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): assert node.inputs[0].ndim == node.outputs[0].ndim assert inputs_shapes[1] == () return [inputs_shapes[0]] @@ -185,7 +185,7 @@ def perform(self, node, inputs, output_storage): dtype=node.outputs[0].dtype, ) - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): assert node.inputs[0].ndim == node.outputs[0].ndim assert inputs_shapes[1] == () return [inputs_shapes[0]] diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index cf9be0ac9f..ca4906bf4c 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -57,7 +57,7 @@ def pullback(self, inp, outputs, grads): return g_dy, g_sm - def infer_shape(self, fgraph, node, shape): + def infer_shape(self, node, shape): return [shape[1]] def c_code_cache_version(self): @@ -284,7 +284,7 @@ def pushforward(self, inputs, outputs, eval_points): return [disconnected_type()] return self.pullback(inputs, outputs, eval_points) - def infer_shape(self, fgraph, node, shape): + def infer_shape(self, node, shape): return shape def c_headers(self, **kwargs): @@ -537,7 +537,7 @@ def pushforward(self, inputs, outputs, eval_points): return [disconnected_type()] return self.pullback(inputs, outputs, eval_points) - def infer_shape(self, fgraph, node, shape): + def infer_shape(self, node, shape): return shape def c_headers(self, **kwargs): diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index e21bf866bb..40f54ed887 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -843,7 +843,7 @@ def perform(self, node, inputs, out_): cdata = unflatten_index_variables(index_variables, self.idx_list) out[0] = np.asarray(x.__getitem__(tuple(cdata))) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): def _is_constant(const, x): return isinstance(const, Constant) and const.data.item() == x @@ -1689,7 +1689,7 @@ def add_to_zview(self, name, x, fail): {fail}; }}""" - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def pushforward(self, inputs, outputs, eval_points): @@ -1867,7 +1867,7 @@ def pushforward(self, inputs, outputs, eval_points): _x, *index_variables = inputs return self.make_node(eval_points[0], *index_variables).outputs - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x, ilist = ishapes return [ilist + x[1:]] @@ -2217,7 +2217,7 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = x - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x, _y, _ilist = ishapes return [x] @@ -2392,7 +2392,7 @@ def pushforward(self, inputs, outputs, eval_points): _x, *index_variables = inputs return self.make_node(eval_points[0], *index_variables).outputs - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): def is_bool_index(idx): return ( isinstance(idx, np.bool_ | bool) @@ -2619,7 +2619,7 @@ def perform(self, node, inputs, out_): else: np.add.at(out[0], tuple(full_indices), y) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): return [ishapes[0]] def connection_pattern(self, node): diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 79210f6958..b9f27e521c 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -30,7 +30,7 @@ class XTypeCastOp(TypeCastingOp): This is like a `ViewOp` but without the expectation the input and output have identical types. """ - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def vectorize_node( diff --git a/tests/compile/test_ops.py b/tests/compile/test_ops.py index a30ed6475d..1954c6cbb6 100644 --- a/tests/compile/test_ops.py +++ b/tests/compile/test_ops.py @@ -65,7 +65,7 @@ def test_infer_shape(self): x = dmatrix("x") y = dvector("y") - def infer_shape(fgraph, node, shapes): + def infer_shape(node, shapes): _x, y = shapes return [y] diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 5b9d677bd9..8881a20b0b 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -310,7 +310,7 @@ def grad(self, inputs, gout): else: return (gz,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def test_grad_fail(self): diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index ba02feabb6..f24860d804 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -162,7 +162,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = x.copy() - # def infer_shape(self, fgraph, node, (xshp,)): + # def infer_shape(self, node, (xshp,)): # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])] identity_noshape = IdentityNoShape() @@ -179,7 +179,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = x.copy() - def infer_shape(self, fgraph, node, xshp_): + def infer_shape(self, node, xshp_): # Could also just return. (xshp,) = xshp_ return (xshp,) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index dfded3fbc3..c4023b3ea5 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -310,7 +310,7 @@ def perform(self, node, inputs, outputs): c[0] = np.arange(a.size + b.size, dtype=config.floatX) d[0] = np.arange(a.sum() + b.sum(), dtype=config.floatX) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): # First output shape depends only on input_shapes # Second output shape depends on input values a_identity, b_identity = node.inputs @@ -362,7 +362,7 @@ def make_node(self, x): def perform(self, node, inputs, outputs): raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): y = node.outputs[0] # Apparently it's valid to return integers in infer_shape. # DimShuffle does this. Modify test if that is no longer allowed. diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 7e217ab3ea..4819281451 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -873,7 +873,7 @@ def test_partial_static_shape_info(self): x_inferred_shape = (ps.constant(1), ps.constant(1)) res_shape = z.owner.op.infer_shape( - None, z.owner, [x_inferred_shape, x_inferred_shape] + z.owner, [x_inferred_shape, x_inferred_shape] ) assert len(res_shape) == 1 @@ -902,7 +902,7 @@ def make_node(self, *args): as_tensor_variable(np.eye(1)), ) in_1_shape = (ps.constant(1), ps.constant(1)) - outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) + outs = z_1.owner.op.infer_shape(z_1.owner, [in_1_shape, in_1_shape]) for out in outs: assert out[0].eval() == 1 assert out[1].eval() == 1 @@ -911,7 +911,7 @@ def make_node(self, *args): as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(3)) ) in_2_shape = (ps.constant(3), ps.constant(3)) - outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape]) + outs = z_1.owner.op.infer_shape(z_1.owner, [in_1_shape, in_2_shape]) for out in outs: assert out[0].eval() == 3 assert out[1].eval() == 3 @@ -924,7 +924,7 @@ def test_shape_types(self): assert isinstance(z.owner.op, Elemwise) - (out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)]) + (out_shape,) = z.owner.op.infer_shape(z.owner, [(lscalar(), 1), (50, 10)]) assert all(isinstance(v.type, TensorType) for v in out_shape) From cdd49e1473632e4b39a6369e538355e71fd82517 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 29 Apr 2026 19:32:53 +0200 Subject: [PATCH 2/5] Don't constant-fold `Alloc` consumed by `Subtensor` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Alloc.do_constant_folding` listed `Elemwise | DimShuffle | Alloc | Join` and batched-`Blockwise` as protected client ops, but not `Subtensor`. `local_subtensor_of_alloc` rewrites `alloc(val, *shape)[idx]` into `alloc(val[...], *new_shape)` — preserving the Alloc structure that downstream rewrites like `local_blockwise_alloc_inputs` depend on. Folding the Alloc here short-circuited that lift and produced broadcast-equivalent `Constant` matrices whose batch dim was no longer type-broadcastable, so `local_blockwise_reshape` couldn't unwrap the surrounding `Blockwise(Reshape)`. Surfaced by the lazy-kernel `ShapeFeature` (which resolves `Subtensor(Shape(out), const)` to a scalar `Constant` earlier and makes more upstream Allocs constant-foldable), but the fix belongs here — the protection was too narrow. --- pytensor/tensor/basic.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 9bc6232a14..7cfc68cedf 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1753,14 +1753,32 @@ def do_constant_folding(self, fgraph, node): if not clients: return False + from pytensor.tensor.blas import Gemv, Ger + from pytensor.tensor.blas_c import CGemv, CGer + from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + IncSubtensor, + Subtensor, + ) + for client, idx in clients: client_op = client.op if isinstance(client_op, Output): # If the output is a constant, it will have to be deepcopied # each time the function is called. So we do not fold. return False - # Op's through which Alloc can be lifted - elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join): + # Op's through which Alloc can be lifted. ``Subtensor`` is + # included because ``local_subtensor_of_alloc`` rewrites + # ``alloc(val, *shape)[idx]`` into ``alloc(val[...], *new_shape)``, + # preserving the Alloc structure that downstream rewrites + # (e.g. ``local_blockwise_alloc_inputs``) rely on. Folding the + # Alloc here would short-circuit that lift and produce a + # broadcast-equivalent constant whose batch dim is no longer + # type-broadcastable. + elif isinstance( + client_op, Elemwise | DimShuffle | Alloc | Join | Subtensor + ): return False # Same for Blockwise, unless it has no batch_dims elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client): @@ -1770,13 +1788,13 @@ def do_constant_folding(self, fgraph, node): idx == 0 and isinstance( client_op, - pytensor.tensor.subtensor.IncSubtensor - | pytensor.tensor.subtensor.AdvancedIncSubtensor1 - | pytensor.tensor.subtensor.AdvancedIncSubtensor - | pytensor.tensor.blas.Gemv - | pytensor.tensor.blas_c.CGemv - | pytensor.tensor.blas.Ger - | pytensor.tensor.blas_c.CGer, + IncSubtensor + | AdvancedIncSubtensor1 + | AdvancedIncSubtensor + | Gemv + | CGemv + | Ger + | CGer, ) ): # Ops that will work inplace on the Alloc. So if they From 3da36ed4fc6453af3ba1b7a580a0af022d0eb4a3 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 30 Apr 2026 23:39:06 +0200 Subject: [PATCH 3/5] Rewrite ShapeFeature as lazy kernel-based feature Build one FrozenFunctionGraph "kernel" per Apply, rooted in NominalVariable clones of node.inputs. Each kernel is cached in self._cache and materialized on demand against today's live inputs via a custom walker (graph_replace would mutate the globally-interned FrozenApply's inputs). The kernel never holds live variables, so stale references can't leak into shape expressions across rewrites. local_track_shape_i rewrites Shape_i(v, i) with the kernel-inferred expression directly. on_change_input installs r's inferred shape as an override on new_r when new_r's Op has no infer_shape. Also includes: - break_aliasing_cycles (graph/replace.py) for sub-graphs where a single Apply reads an inplace-destroyed input and has another input that depends on the destroyer's output - Hash-cons materialized get_shape(v, i) results - Canonicalize Subtensor(Shape(x), const) / Shape(x) patterns into Shape_i / MakeVector post-materialization - Drop set_shape; route overrides through borrowed kernels - Drop fallback_out role (redundant with layout-None branch) - Updated builders.infer_shape: leaf-rebinding approach --- pytensor/compile/builders.py | 83 +- pytensor/graph/replace.py | 108 +- pytensor/scan/rewriting.py | 20 +- pytensor/tensor/random/rewriting/basic.py | 2 +- pytensor/tensor/random/rewriting/numba.py | 7 +- pytensor/tensor/rewriting/basic.py | 2 +- pytensor/tensor/rewriting/numba.py | 4 +- pytensor/tensor/rewriting/shape.py | 1467 +++++++++++---------- pytensor/tensor/rewriting/subtensor.py | 26 +- pytensor/tensor/shape.py | 16 +- pytensor/tensor/utils.py | 73 +- tests/compile/test_builders.py | 4 +- tests/graph/test_replace.py | 121 +- tests/tensor/random/test_basic.py | 8 +- tests/tensor/rewriting/test_shape.py | 306 ++++- tests/tensor/test_utils.py | 54 +- tests/xtensor/test_rewriting.py | 39 +- 17 files changed, 1483 insertions(+), 857 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 29189b7fe3..481cf1a80c 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -7,7 +7,6 @@ from collections.abc import Callable, Sequence from copy import copy from functools import partial -from itertools import chain from typing import cast from pytensor.compile.maker import function @@ -23,70 +22,54 @@ from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import clone_replace, graph_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError def infer_shape(outs, inputs, input_shapes): - """ - Compute the shape of the outputs given the shape of the inputs of an PyTensor - graph. - - We do it this way to avoid compiling the inner function just to get - the shape. Changes to ShapeFeature could require changes in this function. + """Compute the shape of ``outs`` given the shape of ``inputs``. + Builds per-Apply shape kernels via ``ShapeFeature`` and then + rebinds each inner-input leaf — surfaced as ``Shape_i(j)(inp)`` in + the materialized exprs — to the caller-supplied outer dim. No + compile of the inner function required. """ - # We use a ShapeFeature because it has all the necessary logic - # inside. We don't use the full ShapeFeature interface, but we - # let it initialize itself with an empty fgraph, otherwise we will - # need to do it manually - # TODO: ShapeFeature should live elsewhere from pytensor.tensor.rewriting.shape import ShapeFeature for inp, inp_shp in zip(inputs, input_shapes, strict=True): if inp_shp is not None and len(inp_shp) != inp.type.ndim: - assert len(inp_shp) == inp.type.ndim + raise ValueError( + f"input {inp} has {inp.type.ndim} dims, got shape with {len(inp_shp)}" + ) - shape_feature = ShapeFeature() - fgraph = FunctionGraph([], [], features=[shape_feature]) - for v in chain.from_iterable(s for s in input_shapes if s is not None): - # Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before - if (node := v.owner) is not None: - fgraph.import_node(node, import_missing=True) + feature = ShapeFeature() + out_shapes = [feature.shape_tuple(o) for o in outs] - # Initialize shape_of with the input shapes + # ``feature.get_shape(inp, j)`` is the same memoized instance that + # appears at the leaves of ``out_shapes`` — ``Shape_i(j)(inp)`` for + # unknown dims, ``Constant`` for static dims. Rebind the Shape_i + # leaves to the caller-supplied scalars; static-dim Constants are + # skipped (no owner) so the static type wins, matching prior behavior. + replacements = {} for inp, inp_shp in zip(inputs, input_shapes, strict=True): - shape_feature.set_shape(inp, inp_shp, override=True) - - def local_traverse(out): - """ - Go back in the graph, from out, adding computable shapes to shape_of. - - """ - if out in shape_feature.shape_of: - # Its shape is already known - return - elif out.owner is None: - # This is an input of the graph - shape_feature.init_r(out) - else: - # Recurse over inputs - for inp in out.owner.inputs: - if inp not in shape_feature.shape_of: - local_traverse(inp) - - # shape_feature.on_import does not actually use an fgraph - # It will call infer_shape and set_shape appropriately - dummy_fgraph = None - shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy") - - ret = [] - for o in outs: - local_traverse(o) - ret.append(shape_feature.shape_of[o]) - return ret + if inp_shp is None or not hasattr(inp.type, "ndim"): + continue + for j in range(inp.type.ndim): + leaf = feature.get_shape(inp, j) + if leaf.owner is not None: + replacements[leaf] = inp_shp[j] + + if not replacements: + return out_shapes + + # ``strict=False``: an inner input may not be reachable from every + # output, so its leaf won't appear in every shape expression. + return [ + None if s is None else tuple(graph_replace(list(s), replacements, strict=False)) + for s in out_shapes + ] def construct_nominal_fgraph( diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index ad309161aa..e7357d38ef 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -1,5 +1,5 @@ import warnings -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Collection, Iterable, Mapping, Sequence from functools import singledispatch from typing import cast, overload @@ -11,6 +11,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.traversal import ( + general_toposort, toposort, truncated_graph_inputs, ) @@ -212,6 +213,111 @@ def toposort_key( return fg.outputs[0] +def break_aliasing_cycles( + outputs: Sequence[Variable], + destroyers_of: Callable[[Variable], Collection[Apply]], +) -> list[Variable]: + """Break aliasing-induced ordering cycles in ``outputs``. + + An inplace Op ``D`` overwrites one of its inputs ``x`` in place, so + ``D``'s output ``y`` aliases ``x``'s storage. Any client that reads + the pre-overwrite ``x`` must therefore run *before* ``D``, and any + client that reads ``y`` must run *after*. A cycle arises when a + single Apply ``A`` does both — reads ``x`` directly *and* has another + input that (directly or transitively) depends on ``y``. ``A`` would + have to run before ``D`` and after it. No valid schedule exists. + + This function finds every such ``A`` in ``outputs``' ancestry and + re-routes ``x`` *on that one Apply only* through ``deep_copy_op``. + ``A`` then reads the copy instead of the aliased original, lifting + the "before" constraint. ``D`` keeps reading ``x`` directly; the + rest of the graph is untouched. + + Multiple outputs share one topological pass; an Apply reachable from + more than one output is analyzed once, and an aliased value patched + across outputs gets a single shared ``deep_copy_op`` wrapper. Returns + ``outputs`` unchanged when no Apply exhibits the pattern. + + Parameters + ---------- + outputs + Roots of the sub-graph to scan. + destroyers_of + Callable returning the Apply nodes that overwrite a given + Variable in place (empty when none). Typically + ``fgraph.destroyers`` from a ``FunctionGraph`` with an attached + ``DestroyHandler``, but this function makes no assumption about + provenance — the caller is responsible for the lookup's + meaningfulness, and for skipping the call when there are no + inplace ops in the first place (the ancestry is walked + unconditionally). + """ + from pytensor.compile.ops import deep_copy_op + + deps: dict[Variable, frozenset[Variable]] = {} + substitutes: dict[Variable, Variable] = {} + replacements: dict[Variable, Variable] = {} + # ``general_toposort`` guarantees inputs are visited before consumers, + # so ``deps`` for every input is final by the time we look at an Apply. + for v in general_toposort( + outputs, lambda v: v.owner.inputs if v.owner is not None else [] + ): + if v.owner is None: + deps[v] = frozenset() + continue + node = v.owner + + # Accumulate this Variable's destroyer-output reach: union of the + # parents' reaches, plus any parent that is itself an output of an + # inplace Apply. + d: set[Variable] = set() + for inp in node.inputs: + d |= deps[inp] + if inp.owner is not None and inp.owner.op.destroy_map: + d.add(inp) + deps[v] = frozenset(d) + + if node.op.destroy_map: + # Inplace Apply — preserve as-is; never enters ``replacements`` + # so ``graph_replace`` leaves it alone. + continue + + # Cycle-pattern check per destroyed input on ``node``: a destroyed + # input ``i`` triggers the pattern iff some *other* input has the + # destroyer's output in its reach. + new_inputs = list(node.inputs) + changed = False + for i, inp in enumerate(node.inputs): + inp_destroyers = destroyers_of(inp) + if not inp_destroyers: + continue + other_deps: set[Variable] = set() + for j, other_inp in enumerate(node.inputs): + if j == i: + continue + other_deps |= deps[other_inp] + if other_inp.owner is not None and other_inp.owner.op.destroy_map: + other_deps.add(other_inp) + if any( + out in other_deps for c_app in inp_destroyers for out in c_app.outputs + ): + if inp not in substitutes: + substitutes[inp] = cast(Variable, deep_copy_op(inp)) + new_inputs[i] = substitutes[inp] + changed = True + if changed: + new_node = node.op.make_node(*new_inputs) + replacements.update(zip(node.outputs, new_node.outputs, strict=True)) + + if not replacements: + return list(outputs) + + # ``graph_replace`` walks each output, substitutes any matched Apply + # outputs with the patched version, and rebuilds whatever's downstream + # — composing stacked patches automatically. + return graph_replace(list(outputs), replace=replacements) + + @singledispatch def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]: # Default implementation is provided in pytensor.tensor.blockwise diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index b50f3642fc..9addb3929e 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -1405,13 +1405,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: position in the outer circular buffer. This would invalidate results, if the input is still needed for some other output computation. """ - if hasattr(fgraph, "shape_feature"): - shape_of = fgraph.shape_feature.shape_of - else: - # Each access to shape_of is in a try..except block in order to - # use a default version when the variable is not in the shape_of - # dictionary. - shape_of = {} + shape_feature = getattr(fgraph, "shape_feature", None) # 1. Initialization of variables # Note 1) We do not actually care about outputs representing shared # variables (those have no intermediate values) so it is safer to @@ -1503,14 +1497,14 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: # 2.3.2 extract the begin/end of the first dimension if i >= op_info.n_mit_mot: - try: - length = shape_of[out][0] - except KeyError: + if shape_feature is not None and shape_feature.tracks_shape(out): + length = shape_feature.get_shape(out, 0) + else: length = node.inputs[0] + init_l[i] else: - try: - length = shape_of[out][0] - except KeyError: + if shape_feature is not None and shape_feature.tracks_shape(out): + length = shape_feature.get_shape(out, 0) + else: length = out.shape[0] cf_slice = get_canonical_form_slice(this_slice[0], length) slices[i] += [(cf_slice, this_slice)] # type: ignore diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index 28b462f7e1..43390ec907 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -275,7 +275,7 @@ def is_nd_advanced_idx(idx, dtype) -> bool: # Use shape_feature to facilitate inferring final shape. # Check that neither the RV nor the old Subtensor are in the shape graph. - output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None) + output_shape = shape_feature.shape_tuple(indexed_rv) if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)): return None diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py index e171e03b45..41ed103d1a 100644 --- a/pytensor/tensor/random/rewriting/numba.py +++ b/pytensor/tensor/random/rewriting/numba.py @@ -58,9 +58,10 @@ def introduce_explicit_core_shape_rv(fgraph, node): _next_rng, rv = node.outputs shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) if shape_feature: - core_shape = [ - shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp)) - ] + # Trailing ``op.ndim_supp`` dims are the core shape, cycle-broken. + core_shape = list( + shape_feature.unaliased_shape_tuple(rv, range(-op.ndim_supp, 0)) + ) else: core_shape = op._supp_shape_from_params(op.dist_params(node)) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 280e20f41e..64c5c459cf 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -144,7 +144,7 @@ def alloc_like( if value.type.is_super(template.type): return value if hasattr(fgraph, "shape_feature"): - new_shape = fgraph.shape_feature.shape_of[template] + new_shape = fgraph.shape_feature.shape_tuple(template) else: new_shape = template.shape rval = alloc(value, *new_shape) diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index 047893de49..476f629937 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -70,8 +70,10 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) if shape_feature: + # Core dims only, cycle-broken — import won't trip the destroy + # handler. core_shapes = [ - [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)] + shape_feature.unaliased_shape_tuple(out, range(batch_ndim, out.type.ndim)) for out in node.outputs ] else: diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index c9f3012a31..4ea5817c4f 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -1,26 +1,22 @@ -import traceback -from io import StringIO -from typing import cast as type_cast from warnings import warn import numpy as np import pytensor from pytensor.configdefaults import config -from pytensor.graph.basic import Constant, Variable, equal_computations +from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.features import AlreadyThere, Feature -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.rewriting.basic import ( GraphRewriter, copy_stack_trace, node_rewriter, ) from pytensor.graph.traversal import ancestors -from pytensor.graph.utils import InconsistencyError, get_variable_trace_string +from pytensor.graph.utils import get_variable_trace_string from pytensor.tensor.basic import ( MakeVector, as_tensor_variable, - cast, constant, expand_dims, get_scalar_constant_value, @@ -34,7 +30,6 @@ register_specialize, register_stabilize, register_useless, - topo_constant_folding, ) from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.shape import ( @@ -51,612 +46,787 @@ Subtensor, get_idx_list, ) -from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes +from pytensor.tensor.type import TensorType, integer_dtypes, lscalar from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable class ShapeFeature(Feature): - r"""A `Feature` that tracks shape information in a graph. - - This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with - `Shape_i` and `MakeVector` `Op`\s. - - This `Feature` and its associated rewrites have several goals: - - 1. to "lift" `Shape`\s to as close to the inputs as possible, - 2. to infer the shape of every node in the graph in terms of the - input shapes, and - 3. remove fill `Op`\s (e.g. `Second`) from the graph. - - Lifting shapes as close to the inputs as possible is important for - canonicalization because it is very bad form to have to compute - something just to know how big it will be. Firstly, it is a waste - of time to compute such outputs. But it is important to get rid - of these outputs as early as possible in the compilation process - because the extra computations make it appear as if many internal - graph nodes have multiple clients. Many rewrites refuse to - work on nodes with multiple clients. - - Lifting is done by using an `.infer_shape` function if one is - present, or else using a conservative default. An Op that - supports shape-lifting should define a infer_shape(self, fgraph, node, - input_shapes) function. The argument input_shapes is a tuple of - tuples... there is an interior tuple for each input to the node. - The tuple has as many elements as dimensions. The element in - position i of tuple j represents the i'th shape component of the - j'th input. The function should return a tuple of tuples. One - output tuple for each node.output. Again, the i'th element of the - j'th output tuple represents the output[j].shape[i] of the - function. If an output is not a TensorType, then None should be - returned instead of a tuple for that output. - - For example the infer_shape for a matrix-matrix product would accept - input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). - - Inferring the shape of internal nodes in the graph is important - for doing size-driven rewrites. If we know how big various - intermediate results will be, we can estimate the cost of many Ops - accurately, and generate c-code that is specific [e.g. unrolled] - to particular sizes. - - In cases where you cannot figure out the shape, raise a ShapeError. - - Notes - ----- - To use this shape information in rewrites, use the - ``shape_of`` dictionary. - - For example: - - .. code-block:: python - - try: - shape_of = fgraph.shape_feature.shape_of - except AttributeError: - # This can happen when the mode doesn't include the ShapeFeature. - return - - shape_of_output_zero = shape_of[node.output[0]] - - The ``shape_of_output_zero`` symbol will contain a tuple, whose - elements are either integers or symbolic integers. - - TODO: check to see if the symbols are necessarily - non-constant... or are integer literals sometimes PyTensor - constants?? That would be confusing. - + r"""Kernel-based `Feature` that tracks shape information in a graph. + + For each `Apply`, a `FrozenFunctionGraph` "kernel" is built once and + stored in ``self._cache[node]``. The kernel is rooted in *dummy* + variables — never the live outer variables — so it can't go stale + as the fgraph mutates. Shape requests materialize the kernel + against today's ``node.inputs`` (and recursive shape lookups), so + returned expressions are always rooted in live variables. + + Public API: + + - ``get_shape(v, i)`` — materialize ``v.shape[i]``. + - ``shape_tuple(v)`` — materialize ``tuple(v.shape)``. + - ``unaliased_shape_tuple(v, dims=None)`` — like ``shape_tuple`` but + breaks aliasing-induced cycles so the result is safe to import + into the attached fgraph alongside its inplace destroyers. + - ``tracks_shape(v)`` — does the feature know a shape for ``v``? + - ``same_shape(x, y, dim_x=None, dim_y=None)`` — via content-addressed ``shape_key``. """ - def get_node_infer_shape(self, node): - try: - shape_infer = node.op.infer_shape - except AttributeError: - shape_infer = self.default_infer_shape - - try: - o_shapes = shape_infer(node, [self.shape_of[r] for r in node.inputs]) - except ShapeError: - o_shapes = self.default_infer_shape( - node, [self.shape_of[r] for r in node.inputs] - ) - except NotImplementedError as e: - raise NotImplementedError( - "Code called by infer_shape failed raising a " - "NotImplementedError. Raising NotImplementedError to " - "indicate that a shape cannot be computed is no longer " - "supported, and one should now use ShapeError " - f"instead. The original exception message is: {e}" - ).with_traceback(e.__traceback__) - except Exception as e: - msg = ( - f"Failed to infer_shape from Op {node.op}.\nInput shapes: " - f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: " - f"{type(e)}\nException message: {e!s}\nTraceback: {traceback.format_exc()}" - ) - if config.on_shape_error == "raise": - raise Exception(msg).with_traceback(e.__traceback__) - else: - warn(msg) - o_shapes = self.default_infer_shape( - node, [self.shape_of[r] for r in node.inputs] - ) - - return o_shapes - - def get_shape(self, var, idx): - """Rewrites can call this to get a `Shape_i`. - - It is better to call this then use directly ``shape_of[var][idx]`` - as this method should update `shape_of` if needed. - - TODO: Up to now, we don't update it in all cases. Update in all cases. + def __init__(self): + # Per-Apply kernel cache: ``node -> (kernel, meta)`` from + # ``_build_kernel``. The kernel is a ``FrozenFunctionGraph`` rooted + # in dummy inputs; ``get_shape`` materializes it against today's + # live ``node.inputs``. Populated lazily on ``get_shape`` / + # ``shape_key`` / ``_reroute_dim``; dropped in ``on_prune``. + self._cache: dict = {} + # Kernel-borrow overrides keyed by ``Variable``: an ndim-tuple + # whose entries are either ``None`` (Shape_i fallback) or + # ``(dim_kernel, role_bindings)``. ``dim_kernel`` is the per-dim + # ``FrozenFunctionGraph`` borrowed from the *replaced* var's + # kernel; ``role_bindings`` is a tuple of ``(input_idx, dim)`` + # aligned with ``dim_kernel.inputs``, indexing into the keying + # var's ``.owner.inputs``. No live Variables are pinned: the + # live shape is rebuilt at access time by walking the dim_kernel + # against ``v.owner.inputs[input_idx].shape[dim]``. Installed by + # ``on_change_input`` when ``new_r`` replaces ``r`` and + # ``new_r``'s Op has no ``infer_shape``. + self._overrides: dict = {} + # Memoizes ``Shape_i(i)(v)`` for leaves/fallbacks so callers that + # cross-reference shape entries with ``Shape_i`` nodes in the graph + # observe Apply identity (the graph's MergeFeature would otherwise + # merge structurally equal copies, but by then compare-by-identity + # rewrites may have already bailed out). + # Keyed by ``(id(v), i)``; safe because the fgraph holds a strong + # ref to ``v`` for the feature's lifetime. ``on_prune`` drops the + # entries for removed Apply outputs; graph-input removal would + # leak entries but is not a path we currently exercise. + self._shape_i_cache: dict = {} + # Memoize the canonicalized result of ``get_shape(v, i)`` so a + # second caller observes identity, not a fresh equivalent tree. + # Safe to hold strong refs because the cached expression is + # canonical: ``Shape_i{j}(graph_input_leaf)``, lscalars, constants, + # and arithmetic — none of those participate in the rewrite cycles + # that would otherwise replace nodes out from under us. Dropped + # in ``on_prune`` when the keying ``v`` is removed. + self._materialized: dict = {} + # Per-dim sub-views of the per-node kernel, used by + # ``same_shape``/``shape_key``. Keyed ``node -> {slot: (dim_kernel, + # used_roles) | None}``. ``dim_kernel`` is a single-output + # ``FrozenFunctionGraph`` over only the kernel inputs reachable + # from ``kernel.outputs[slot]``. Because ``FrozenApply`` and + # ``NominalVariable`` are globally interned, structurally + # identical shape expressions yield ``__eq__`` dim kernels — so + # ``same_shape`` reduces to a content-addressed kernel match plus + # a roles/binding compare, instead of a recursive op-tree walk. + self._dim_kernel_cache: dict = {} + self.fgraph: FunctionGraph | None = None + + def tracks_shape(self, v) -> bool: + """``True`` iff this feature has shape information for ``v``. + + A var is tracked when its owner has a kernel cached (it was + hit by a ``get_shape`` / ``shape_key`` call), or it carries an + explicit override, or it's a graph input of the attached fgraph. """ - r = self.shape_of[var][idx] - if ( - r.owner - and isinstance(r.owner.op, Shape_i) - and r.owner.inputs[0] not in self.fgraph.variables - ): - assert var.owner - node = var.owner - # recur on inputs - for i in node.inputs: - if getattr(i.type, "ndim", None) > 0: - self.get_shape(i, 0) - o_shapes = self.get_node_infer_shape(node) - assert len(o_shapes) == len(node.outputs) - - # Only change the variables and dimensions that would introduce - # extra computation - for new_shps, out in zip(o_shapes, node.outputs, strict=True): - if not hasattr(out.type, "ndim"): - continue - - merged_shps = list(self.shape_of[out]) - changed = False - for i in range(out.type.ndim): - n_r = merged_shps[i] - if ( - n_r.owner - and isinstance(n_r.owner.op, Shape_i) - and n_r.owner.inputs[0] not in self.fgraph.variables - ): - changed = True - merged_shps[i] = new_shps[i] - if changed: - self.set_shape(out, merged_shps, override=True) - r = self.shape_of[var][idx] - return r - - def shape_ir(self, i, r): - """Return symbolic r.shape[i] for tensor variable r, int i.""" - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - return constant(r.type.shape[i], dtype="int64") - else: - s = Shape_i(i)(r) - try: - s = get_scalar_constant_value(s) - except NotScalarConstantError: - pass - return s - - def shape_tuple(self, r): - """Return a tuple of symbolic shape vars for tensor variable r.""" - if not hasattr(r.type, "ndim"): - # This happen for NoneConst. - return None - return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) - - def default_infer_shape(self, node, i_shapes): - """Return a list of shape tuple or None for the outputs of node. - - This function is used for Ops that don't implement infer_shape. - Ops that do implement infer_shape should use the i_shapes parameter, - but this default implementation ignores it. - - """ - rval = [] - for r in node.outputs: - try: - rval.append(self.shape_tuple(r)) - except AttributeError: - rval.append(None) - return rval - - def unpack(self, s_i, var): - """Return a symbolic integer scalar for the shape element s_i. - - The s_i argument was produced by the infer_shape() of an Op subclass. - - var: the variable that correspond to s_i. This is just for - error reporting. - - """ - assert s_i is not None - - if s_i == 1: - return self.lscalar_one - if isinstance(s_i, float) and int(s_i) == s_i: - s_i = int(s_i) - if isinstance(s_i, np.integer | int) or ( - isinstance(s_i, np.ndarray) and s_i.ndim == 0 - ): - # this shape is a constant - if s_i < 0: - msg = "There is a negative shape in the graph!" - msg += get_variable_trace_string(var) - # The rest of the pipeline don't handle correctly this - # case. So we have 2 choices, stop compilation or - # consider the shape as unknown. As we have more - # chance to give the stack trace here then later, I - # choose that options as it would give better error - # message. - raise AssertionError(msg) - return constant(s_i, dtype="int64") - if isinstance(s_i, tuple | list): - # this dimension is the same as many of the inputs - # which tells us that if one of the inputs is known, - # the others all become known. - # TODO: should be implemented in Elemwise, and Dot - # - # worst case, we loop over shape_of and replace things - raise NotImplementedError(s_i) - - # s_i is x.shape[i] for some x, we change it to shape_of[x][i] - if ( - s_i.owner - and isinstance(s_i.owner.op, Subtensor) - and s_i.owner.inputs[0].owner - and isinstance(s_i.owner.inputs[0].owner.op, Shape) - ): - assert s_i.type.ndim == 0 - assert len(s_i.owner.op.idx_list) == 1 - - # The current Subtensor always put constant index in the graph. - # This was not True in the past. So call the Subtensor function - # that will return the right index. - idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list) - assert len(idx) == 1 - idx = idx[0] - try: - i = get_scalar_constant_value(idx) - except NotScalarConstantError: - pass - else: - # Executed only if no exception was raised - x = s_i.owner.inputs[0].owner.inputs[0] - # x should already have been imported, and should be in shape_of. - s_i = self.shape_of[x][i] - - if s_i.type.dtype in integer_dtypes: - if getattr(s_i.type, "ndim", 0): - raise TypeError("Shape element must be scalar", s_i) - return s_i - else: - raise TypeError( - "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None) - ) - - def set_shape(self, r, s, override=False): - """Assign the shape `s` to previously un-shaped variable `r`. - - Parameters - ---------- - r : a variable - s : None or a tuple of symbolic integers - override : If False, it mean r is a new object in the fgraph. - If True, it mean r is already in the fgraph and we want to - override its shape. - - """ - if not override: - assert r not in self.shape_of, "r already in shape_of" - if s is None: - self.shape_of[r] = s - else: - if not isinstance(s, tuple | list): - raise TypeError("shapes must be tuple/list", (r, s)) - - if r.type.ndim != len(s): - sio = StringIO() - pytensor.printing.debugprint(r, file=sio, print_type=True) - raise AssertionError( - f"Something inferred a shape with {len(s)} dimensions " - f"for a variable with {int(r.type.ndim)} dimensions" - f" for the variable:\n{sio.getvalue()}" - ) - - shape_vars = [] - for i in range(r.type.ndim): - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - shape_vars.append(constant(r.type.shape[i], dtype="int64")) - else: - shape_vars.append(self.unpack(s[i], r)) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[i] != 1 - or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals( - get_scalar_constant_value(shape_vars[i], raise_not_constant=False) - ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(shape_vars) - for sv in shape_vars: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def update_shape(self, r, other_r): - """Replace shape of r by shape of other_r. - - If, on some dimensions, the shape of other_r is not informative, - keep the shape of r on those dimensions. - + if v is None or not hasattr(v.type, "ndim"): + return False + if v in self._overrides: + return True + if v.owner is not None: + return v.owner in self._cache + fg = self.fgraph + return fg is not None and v in fg.inputs + + def _shape_i_var(self, v, i): + key = (id(v), i) + cached = self._shape_i_cache.get(key) + if cached is not None: + return cached + res = Shape_i(i)(v) + self._shape_i_cache[key] = res + return res + + def _canonicalize_live_shape(self, s, memo=None): + """Rewrite ``Shape(x)`` / ``Subtensor(Shape(x), const_i)`` patterns + into ``MakeVector(Shape_i_0, …)`` / ``Shape_i(const_i)(x)``. + + Why: some ``infer_shape`` impls (e.g. ``Alloc``: ``return [node.inputs[1:]]``) + pipe live shape inputs through verbatim. Those live inputs were + often built by user code as ``v.shape[axis]`` — i.e. + ``Subtensor(Shape(v), axis)`` Applies. If those reach the + materialized output of ``get_shape`` unchanged, EquilibriumGraphRewriter + keeps re-firing ``local_shape_to_shape_i`` on each fresh ``Shape(v)`` + we re-emit, never reaching a fixed point. Pre-canonicalizing here + means the materialized shape never contains ``Shape(...)`` Apply + nodes — only ``Shape_i`` leaves the optimizer leaves alone. """ - # other_r should already have a shape - assert other_r in self.shape_of, ("other_r not in shape_of", other_r) - other_shape = self.shape_of[other_r] - - # If other_shape has no information, call is pointless. - if other_shape is None: - return - - if r in self.shape_of: - r_shape = self.shape_of[r] - else: - # If no info is known on r's shape, use other_shape - self.set_shape(r, other_shape) - return - if ( - other_r.owner - and r.owner - and other_r.owner.inputs == r.owner.inputs - and other_r.owner.op == r.owner.op - ): - # We are doing a merge, so the two shape graphs will be the - # same. This is only done so that we call `ancestors` less - # frequently. - return - - # Merge other_shape with r_shape, giving the priority to other_shape - merged_shape = [] - for i, ps in enumerate(other_shape): - if r_shape is None and other_shape: - merged_shape.append(other_shape[i]) - elif ( - ps.owner - and isinstance(ps.owner.op, Shape_i) - and ps.owner.op.i == i - and ps.owner.inputs[0] in (r, other_r) - ): - # If other_shape[i] is uninformative, use r_shape[i]. - # For now, we consider 2 cases of uninformative other_shape[i]: - # - Shape_i(i)(other_r); - # - Shape_i(i)(r). - merged_shape.append(r_shape[i]) - elif isinstance(r_shape[i], Constant | int): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(r_shape[i]) - elif isinstance(other_shape[i], Constant | int): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(other_shape[i]) - elif other_shape[i] == r_shape[i]: - # This mean the shape is equivalent - # We do not want to do the ancestor check in those cases - merged_shape.append(r_shape[i]) - elif any( - ( - r_shape[i] == anc - or ( - anc.owner - and isinstance(anc.owner.op, Shape) - and anc.owner.inputs[0] == r - ) - ) - for anc in ancestors([other_shape[i]]) - ): - # Another case where we want to use r_shape[i] is when - # other_shape[i] actually depends on r_shape[i]. In that case, - # we do not want to substitute an expression with another that - # is strictly more complex. Such a substitution could also lead - # to cycles: if (in the future) r_shape[i] gets replaced by an - # expression of other_shape[i], other_shape[i] may end up - # depending on itself. - merged_shape.append(r_shape[i]) - else: - merged_shape.append(other_shape[i]) - assert all( - ( - not hasattr(r.type, "shape") - or (r.type.shape[i] != 1 and other_r.type.shape[i] != 1) - ) - or self.lscalar_one.equals(merged_shape[i]) - or self.lscalar_one.equals( - get_scalar_constant_value( - merged_shape[i], - only_process_constants=True, - raise_not_constant=False, - ) - ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(merged_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def set_shape_i(self, r, i, s_i): - """Replace element i of shape_of[r] by s_i""" - assert r in self.shape_of - prev_shape = self.shape_of[r] - # prev_shape is a tuple, so we cannot change it inplace, - # so we build another one. - new_shape = [] - for j, s_j in enumerate(prev_shape): - if j == i: - new_shape.append(self.unpack(s_i, r)) - else: - new_shape.append(s_j) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[idx] != 1 - or self.lscalar_one.equals(new_shape[idx]) - or self.lscalar_one.equals( - get_scalar_constant_value(new_shape[idx], raise_not_constant=False) - ) - for idx in range(r.type.ndim) - ) - self.shape_of[r] = tuple(new_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def init_r(self, r): - """Register r's shape in the shape_of dictionary.""" - if r not in self.shape_of: - self.set_shape(r, self.shape_tuple(r)) + if memo is None: + memo = {} + cached = memo.get(s) + if cached is not None: + return cached + if not isinstance(s, Variable) or s.owner is None: + memo[s] = s + return s - def make_vector_shape(self, r): - return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64") + node = s.owner + op = node.op + + if isinstance(op, Subtensor): + base = node.inputs[0] + if base.owner is not None and isinstance(base.owner.op, Shape): + x = base.owner.inputs[0] + if hasattr(x.type, "ndim"): + try: + idx_list = get_idx_list(node.inputs, op.idx_list) + if len(idx_list) == 1: + i = int(get_scalar_constant_value(idx_list[0])) + if 0 <= i < x.type.ndim: + result = self.get_shape(x, i) + memo[s] = result + return result + except (NotScalarConstantError, IndexError, TypeError): + pass + + if isinstance(op, Shape): + x = node.inputs[0] + if hasattr(x.type, "ndim") and x.type.ndim > 0: + result = stack([self.get_shape(x, j) for j in range(x.type.ndim)]) + memo[s] = result + return result + + new_inputs = [self._canonicalize_live_shape(inp, memo) for inp in node.inputs] + if all(ni is oi for ni, oi in zip(new_inputs, node.inputs, strict=True)): + memo[s] = s + return s + new_node = op.make_node(*new_inputs) + new_out = new_node.outputs[node.outputs.index(s)] + memo[s] = new_out + return new_out def on_attach(self, fgraph): if hasattr(fgraph, "shape_feature"): raise AlreadyThere("This FunctionGraph already has a ShapeFeature") - - if hasattr(self, "fgraph") and self.fgraph != fgraph: + if self.fgraph is not None and self.fgraph is not fgraph: raise Exception("This ShapeFeature is already attached to a graph") - self.fgraph = fgraph - fgraph.shape_feature = self - # Must be local to the object as otherwise we reuse the same - # variable for multiple fgraph! - self.lscalar_one = constant(1, dtype="int64") - assert self.lscalar_one.type.dtype == "int64" - - self.fgraph = fgraph - # Variable -> tuple(scalars) or None (All tensor vars map to tuple) - self.shape_of = {} - # Variable -> - self.scheduled = {} - # shape var -> graph v - self.shape_of_reverse_index = {} - - for node in fgraph.toposort(): - self.on_import(fgraph, node, reason="on_attach") def on_detach(self, fgraph): - self.shape_of = {} - self.scheduled = {} - self.shape_of_reverse_index = {} + self._cache.clear() + self._overrides.clear() + self._shape_i_cache.clear() + self._materialized.clear() + self._dim_kernel_cache.clear() self.fgraph = None - del fgraph.shape_feature + if hasattr(fgraph, "shape_feature"): + del fgraph.shape_feature + + def on_prune(self, fgraph, node, reason): + self._cache.pop(node, None) + self._dim_kernel_cache.pop(node, None) + # Drop cached Shape_i variables whose owner is being pruned — without + # this the memo grows monotonically over a long canonicalize pass. + for out in node.outputs: + oid = id(out) + for j in range(getattr(out.type, "ndim", 0) or 0): + self._shape_i_cache.pop((oid, j), None) + self._materialized.pop((oid, j), None) + self._overrides.pop(out, None) - def on_import(self, fgraph, node, reason): - if node.outputs[0] in self.shape_of: - # this is a revert, not really an import - for r in node.outputs + node.inputs: - assert r in self.shape_of + def on_change_input(self, fgraph, node, i, r, new_r, reason): + # Carry r's shape forward as a *kernel-borrow* override when + # ``new_r``'s Op has no ``infer_shape``. Per-dim, we rederive r's + # shape kernel against ``new_r.owner.inputs`` by matching + # kernel-input bindings via ``shape_key``; if every binding finds + # a structurally-equal counterpart we store the dim_kernel plus + # the ``(input_idx, dim)`` positions to look up in + # ``new_r.owner.inputs`` at access time. No live Variables are + # pinned. Per-dim ``None`` means "couldn't reroute; fall back to + # ``Shape_i``". + if r is new_r or not hasattr(new_r.type, "ndim"): return + if new_r in self._overrides: + return + if new_r.owner is None: + return # graph inputs have their own Shape_i fallback + if getattr(new_r.owner.op, "infer_shape", None) is not None: + return # new_r's own kernel will produce a real shape + if not hasattr(r.type, "ndim") or r.type.ndim != new_r.type.ndim: + return + new_owner_inputs = new_r.owner.inputs + entries = [] + any_set = False + for k in range(r.type.ndim): + e = self._reroute_dim(r, k, new_owner_inputs) + if e is not None: + any_set = True + entries.append(e) + if any_set: + self._overrides[new_r] = tuple(entries) + + def _reroute_dim(self, r, k, new_r_owner_inputs): + """Try to rederive ``r.shape[k]`` against ``new_r_owner_inputs``. + + Returns ``(dim_kernel, role_bindings)`` on success, where + ``role_bindings`` is a tuple of ``(input_idx, dim)`` aligned with + ``dim_kernel.inputs`` — indices into ``new_r_owner_inputs`` whose + ``shape_key`` matches the corresponding live binding under r's + owner. + + Returns ``None`` when (a) the kernel uses any role other than + ``input_shape_slot`` (``input_slot`` would need value-equality; + ``self_out`` references r's own outputs and can't reroute), + or (b) some role's binding has no structurally-equal + counterpart in ``new_r_owner_inputs``. + """ + if r.owner is None: + return None + if (entry := self._cache.get(r.owner)) is None: + entry = self._build_kernel(r.owner) + self._cache[r.owner] = entry + kernel, meta = entry + if kernel is None: + return None + out_idx = r.owner.outputs.index(r) + layout = meta["output_layout"] + if layout[out_idx] is None: + return None + slot = sum((layout[k_] or 0) for k_ in range(out_idx)) + k + dk = self._dim_kernel(r.owner, slot) + if dk is None: + return None + dim_kernel, used_roles = dk - for i, r in enumerate(node.inputs): - # make sure we have shapes for the inputs - self.init_r(r) - - o_shapes = self.get_node_infer_shape(node) + if any(role[0] != "input_shape_slot" for role in used_roles): + return None - # this is packed information - # an element of o_shapes is either None or a tuple - # elements of the tuple can be either strings, or ints - if len(o_shapes) != len(node.outputs): - raise Exception( - f'The infer_shape method for the Op "{node.op}" returned a list ' - f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} " - f" != len(node.outputs) = {len(node.outputs)}" + slot_to_input_idx = meta["slot_to_input_idx"] + role_bindings = [] + for role in used_roles: + s, j = role[1], role[2] + r_inp = r.owner.inputs[slot_to_input_idx[s]] + target_key = self.shape_key(r_inp, j) + match = None + for idx, inp in enumerate(new_r_owner_inputs): + if not hasattr(inp.type, "ndim"): + continue + for d in range(inp.type.ndim): + if self.shape_key(inp, d) == target_key: + match = (idx, d) + break + if match is not None: + break + if match is None: + return None + role_bindings.append(match) + return (dim_kernel, tuple(role_bindings)) + + def _materialize_override(self, v, i, entry): + """Walk a borrowed dim_kernel against ``v.owner.inputs``.""" + if entry is None: + return self._shape_i_var(v, i) + dim_kernel, role_bindings = entry + new_owner_inputs = v.owner.inputs + memo: dict = { + k_input: self.get_shape(new_owner_inputs[idx], dim) + for k_input, (idx, dim) in zip( + dim_kernel.inputs, role_bindings, strict=True ) + } + for fa in dim_kernel.toposort(): + new_inputs = [memo.get(inp, inp) for inp in fa.inputs] + new_node = fa.op.make_node(*new_inputs) + memo.update(zip(fa.outputs, new_node.outputs, strict=True)) + raw = memo.get(dim_kernel.outputs[0], dim_kernel.outputs[0]) + return self._canonicalize_live_shape(raw) + + def _override_shape_key(self, v, i, entry): + """Content-addressed key for an override entry; see ``shape_key``.""" + if entry is None: + return ("leaf", id(v), i) + dim_kernel, role_bindings = entry + new_owner_inputs = v.owner.inputs + sv = dim_kernel.outputs[0] + if sv.owner is None: + # Passthrough: the dim_kernel output is a kernel input directly. + if isinstance(sv, Constant): + try: + return ("const", int(sv.data)) + except Exception: + return ("const", id(sv)) + try: + k_idx = dim_kernel.inputs.index(sv) + except ValueError: + return ("opaque", id(sv)) + idx, dim = role_bindings[k_idx] + return self.shape_key(new_owner_inputs[idx], dim) + bindings = tuple( + self.shape_key(new_owner_inputs[idx], dim) for idx, dim in role_bindings + ) + return (dim_kernel, bindings) + + def _build_kernel(self, node): + # When the same live input appears at multiple positions (e.g. + # ``Elemwise.add(x, x)``), share the dummy clone AND the dummy + # input-shape lscalars between those positions. Ops like Elemwise + # call ``broadcast_shape(*i_shapes)``, which only drops the runtime + # ``Assert`` guard when the incoming shape expressions are + # identical — so identity here is what lets ``x + x`` infer a + # clean shape instead of ``Assert(x.shape[0], ...)``. + input_slot: dict[int, int] = {} + unique_dummies: list[Variable] = [] + unique_shape_tuples: list[tuple | None] = [] + + dummy_inputs: list[Variable] = [] + dummy_input_shapes: list[tuple | None] = [] + for inp in node.inputs: + key = id(inp) + slot = input_slot.get(key) + if slot is None: + slot = len(unique_dummies) + input_slot[key] = slot + d = inp.clone() + unique_dummies.append(d) + if hasattr(inp.type, "ndim"): + static_shape = getattr(inp.type, "shape", (None,) * inp.type.ndim) + shp_tuple = tuple( + constant(s, dtype="int64") if s is not None else lscalar() + for s in static_shape + ) + else: + shp_tuple = None + unique_shape_tuples.append(shp_tuple) + dummy_inputs.append(unique_dummies[slot]) + dummy_input_shapes.append(unique_shape_tuples[slot]) - # Ensure shapes are in 'int64'. This is to make sure the assert - # found in the `local_useless_subtensor` rewrite does not fail. - for sh_idx, sh in enumerate(o_shapes): - if sh is None: - continue - if not isinstance(sh, list | tuple): - raise ValueError( - f"infer_shape of {node} didn't return a list of" - f" list. It returned '{o_shapes}'" + dummy_outputs = [out.clone() for out in node.outputs] + dummy_node = Apply(node.op, dummy_inputs, dummy_outputs) + + output_shapes = None + shape_infer = getattr(node.op, "infer_shape", None) + if shape_infer is not None: + try: + output_shapes = shape_infer(dummy_node, dummy_input_shapes) + except ShapeError: + output_shapes = None + except NotImplementedError: + output_shapes = None + except Exception as exc: + if config.on_shape_error == "raise": + raise + warn( + f"Failed to infer_shape from Op {node.op}: " + f"{type(exc).__name__}: {exc}" ) - new_shape = [] - for i, d in enumerate(sh): - # Note: we ignore any shape element that is not typed (i.e., - # does not have a 'dtype' attribute). This means there may - # still remain int elements that are int32 on 32-bit platforms, - # but this works with `local_useless_subtensor`, so for now we - # keep it this way. See #266 for a better long-term fix. - if getattr(d, "dtype", "int64") != "int64": - assert d.dtype in discrete_dtypes, (node, d.dtype) - assert str(d.dtype) != "uint64", node - new_shape += sh[len(new_shape) : i + 1] - if isinstance(d, Constant): - casted_d = constant(d.data, dtype="int64") - else: - casted_d = cast(d, "int64") - new_shape[i] = casted_d - if new_shape: - # We replace the shape with wrong dtype by the one with - # 'int64'. - new_shape += sh[len(new_shape) :] - o_shapes[sh_idx] = tuple(new_shape) - - for r, s in zip(node.outputs, o_shapes, strict=True): - self.set_shape(r, s) + output_shapes = None + + if output_shapes is None: + output_shapes = [None] * len(dummy_outputs) + + # Fallback: Shape_i(i)(dummy_output) where the op couldn't provide + # an infer_shape for a given output. Reuse dummy_outputs — no extra + # placeholders. + def coerce_shape_el(s, dummy_out): + # Accept any integer scalar Variable verbatim, and any Python / + # NumPy integer scalar as an int64 constant. Floats and + # non-scalar arrays are buggy returns and raise. + if isinstance(s, np.ndarray): + if s.ndim != 0: + raise TypeError( + f"infer_shape for {node.op} returned a non-scalar " + f"ndarray for shape element: {s!r}" + ) + s = s.item() + if isinstance(s, Variable): + if s.type.dtype not in integer_dtypes: + raise TypeError( + f"infer_shape for {node.op} returned a non-integer " + f"Variable for shape element: {s!r}" + ) + if getattr(s.type, "ndim", 0): + raise TypeError( + f"infer_shape for {node.op} returned a non-scalar " + f"Variable for shape element: {s!r}" + ) + return s + if isinstance(s, int | np.integer): + if int(s) < 0: + raise ValueError( + f"infer_shape for {node.op} returned a negative " + f"shape: {int(s)}" + get_variable_trace_string(dummy_out) + ) + return constant(int(s), dtype="int64") + raise TypeError( + f"infer_shape for {node.op} returned an unsupported " + f"shape element of type {type(s).__name__}: {s!r}" + ) - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if new_r not in self.shape_of: - # It happen that the fgraph didn't called on_import for some - # new_r. This happen when new_r don't have an - # owner(i.e. it is a constant or an input of the graph) - # update_shape suppose that r and new_r are in shape_of. - self.init_r(new_r) - - # This tells us that r and new_r must have the same shape if - # we didn't know that the shapes are related, now we do. - self.update_shape(new_r, r) - - # change_input happens in two cases: - # 1) we are trying to get rid of r, or - # 2) we are putting things back after a failed transaction. - - # In case 1, if r has a shape_i client, we will want to - # replace the shape_i of r with the shape of new_r. Say that r is *scheduled*. - # At that point, node is no longer a client of r, but of new_r - # This schedule is processed by `local_track_shape_i`. - for shpnode, idx in fgraph.clients[r] + [(node, i)]: - if isinstance(shpnode.op, Shape_i): - idx = shpnode.op.i - repl = self.shape_of[new_r][idx] - if repl.owner is shpnode: - # This mean the replacement shape object is - # exactly the same as the current shape object. So - # no need for replacement. - continue + # An output with missing/malformed ``infer_shape`` gets ``None`` + # here, which propagates to ``output_layout[k] = None``. ``get_shape`` + # / ``shape_key`` short-circuit that case to ``_shape_i_var(v, i)`` + # — no kernel slot, no ``fallback_out`` role. + coerced_output_shapes = [] + for k, dummy_out in enumerate(dummy_outputs): + sh = output_shapes[k] if k < len(output_shapes) else None + if not hasattr(dummy_out.type, "ndim"): + coerced_output_shapes.append(None) + continue + if sh is None or not isinstance(sh, list | tuple): + coerced_output_shapes.append(None) + continue + coerced = [] + for i, s in enumerate(sh): if ( - repl.owner - and repl.owner.inputs[0] is shpnode.inputs[0] - and isinstance(repl.owner.op, Shape_i) - and repl.owner.op.i == shpnode.op.i + hasattr(dummy_out.type, "shape") + and dummy_out.type.shape[i] is not None ): - # The replacement is a shape_i of the same - # input. So no need to do this equivalent - # replacement. + coerced.append(constant(dummy_out.type.shape[i], dtype="int64")) continue + coerced.append(coerce_shape_el(s, dummy_out)) + coerced_output_shapes.append(tuple(coerced)) - if shpnode.outputs[0] in ancestors([repl]): - raise InconsistencyError( - "This substitution would insert a cycle in the graph:" - f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" + flat_out = [] + layout = [] + for sh in coerced_output_shapes: + if sh is None: + layout.append(None) + continue + layout.append(len(sh)) + flat_out.extend(sh) + + # ``meta`` carries only what ``get_shape`` / ``shape_key`` need to + # re-wire the frozen kernel against live ``node.inputs``. + meta = {"output_layout": tuple(layout)} + if not flat_out: + return (None, meta) + + # Build kernel_inputs with unique dummies only. Shape slots are + # attached by unique-slot index so duplicate live inputs share the + # same set of kernel-input positions. Each kernel_input needs a + # role that maps back to the live graph at materialization time. + kernel_inputs: list[Variable] = [] + roles: list[tuple] = [] + for slot, dummy in enumerate(unique_dummies): + kernel_inputs.append(dummy) + roles.append(("input_slot", slot)) + for slot, shape_tuple in enumerate(unique_shape_tuples): + if shape_tuple is None: + continue + for j, s in enumerate(shape_tuple): + kernel_inputs.append(s) + roles.append(("input_shape_slot", slot, j)) + + # Some ``infer_shape`` impls (e.g. Scan) reference ``dummy_node.outputs`` + # directly inside the returned shape expression. Without an explicit + # substitution, ``_materialize_frozen`` would walk into ``dummy_node`` + # and rebuild it via ``make_node`` against live inputs, producing + # fresh-but-equivalent Apply nodes on every call and stalling + # EquilibriumGraphRewriter (``local_track_shape_i``). + anc_set = set(ancestors(flat_out)) + for k, dummy_out in enumerate(dummy_outputs): + if dummy_out in anc_set and dummy_out not in kernel_inputs: + kernel_inputs.append(dummy_out) + roles.append(("self_out", k)) + + # Sanity: every free Variable in flat_out should be in kernel_inputs. + # An orphan indicates a buggy ``infer_shape`` that leaked a variable + # outside of ``node.inputs`` / their shape scalars. In development + # mode (config.on_shape_error == "raise") we surface this eagerly + # instead of silently falling back to ``Shape_i``. + kernel_input_set = set(kernel_inputs) + for anc in ancestors(flat_out): + if anc.owner is None: + if isinstance(anc, Constant): + continue + if anc not in kernel_input_set: + msg = ( + f"Op {node.op}.infer_shape leaked an orphan variable " + f"{anc!r} that is not one of node.inputs or their " + f"shape scalars; falling back to Shape_i." ) + if config.on_shape_error == "raise": + raise ShapeError(msg) + return (None, dict(meta, kernel_build_error=msg)) + + # Find any live input index that maps to this slot, so materialization + # can look up ``node.inputs[]``. + slot_to_input_idx: list[int] = [-1] * len(unique_dummies) + for inp_idx, inp in enumerate(node.inputs): + s = input_slot[id(inp)] + if slot_to_input_idx[s] == -1: + slot_to_input_idx[s] = inp_idx - self.scheduled[shpnode] = new_r - # In case 2, if r is a variable that we've scheduled for shape update, - # then we should cancel it. - unscheduled = [k for k, v in self.scheduled.items() if v == r] - for k in unscheduled: - del self.scheduled[k] - - # In either case, r could be in shape_of.values(), that is, r itself - # is the shape of something. In that case, we want to update - # the value in shape_of, to keep it up-to-date. - for v in self.shape_of_reverse_index.get(r, []): - # The reverse index is only approximate. It is not updated on - # deletion of variables, or on change_input so it might be the - # case that there are a few extra `v`'s in it that no longer have - # a shape of r or possibly have been deleted from shape_of - # entirely. The important thing is that it permits to recall - # all variables with r in their shape. - for ii, svi in enumerate(self.shape_of.get(v, [])): - if svi == r: - self.set_shape_i(v, ii, new_r) - self.shape_of_reverse_index[r] = set() + try: + kernel = FrozenFunctionGraph(kernel_inputs, flat_out) + except Exception as exc: + return (None, dict(meta, kernel_build_error=str(exc))) + + meta["roles"] = tuple(roles) + meta["slot_to_input_idx"] = tuple(slot_to_input_idx) + return (kernel, meta) + + def get_shape(self, v, i): + if hasattr(v.type, "shape") and v.type.shape[i] is not None: + return constant(v.type.shape[i], dtype="int64") + cache_key = (id(v), i) + if (ov := self._overrides.get(v)) is not None: + cached = self._materialized.get(cache_key) + if cached is not None: + return cached + result = self._materialize_override(v, i, ov[i]) + self._materialized[cache_key] = result + return result + if v.owner is None: + return self._shape_i_var(v, i) + + cached = self._materialized.get(cache_key) + if cached is not None: + return cached + + node = v.owner + if (entry := self._cache.get(node)) is None: + entry = self._build_kernel(node) + self._cache[node] = entry + kernel, meta = entry + if kernel is None: + result = self._shape_i_var(v, i) + self._materialized[cache_key] = result + return result + + out_idx = node.outputs.index(v) + layout = meta["output_layout"] + if layout[out_idx] is None: + result = self._shape_i_var(v, i) + self._materialized[cache_key] = result + return result + slot = sum((layout[k] or 0) for k in range(out_idx)) + i + dk = self._dim_kernel(node, slot) + if dk is None: + result = self._shape_i_var(v, i) + self._materialized[cache_key] = result + return result + dim_kernel, used_roles = dk + + # Seed memo with the live binding for each used kernel input, + # then walk the kernel's cached topological order rebuilding + # each ``FrozenApply`` against live ``make_node`` calls. Fresh + # ``make_node`` (rather than ``graph_replace``) is required — + # the latter would mutate the globally-interned ``FrozenApply`` + # nodes via ``Apply.clone_with_new_inputs``. + slot_to_input_idx = meta["slot_to_input_idx"] + memo: dict = {} + for k_input, role in zip(dim_kernel.inputs, used_roles, strict=True): + tag = role[0] + if tag == "input_slot": + memo[k_input] = node.inputs[slot_to_input_idx[role[1]]] + elif tag == "input_shape_slot": + memo[k_input] = self.get_shape( + node.inputs[slot_to_input_idx[role[1]]], role[2] + ) + else: + # self_out + memo[k_input] = node.outputs[role[1]] + for fa in dim_kernel.toposort(): + new_inputs = [memo.get(inp, inp) for inp in fa.inputs] + new_node = fa.op.make_node(*new_inputs) + memo.update(zip(fa.outputs, new_node.outputs, strict=True)) + raw = memo.get(dim_kernel.outputs[0], dim_kernel.outputs[0]) + result = self._canonicalize_live_shape(raw) + self._materialized[cache_key] = result + return result + + def unaliased_shape_tuple(self, v, dims=None): + """Like :meth:`shape_tuple`, but free of aliasing-induced cycles + so the result can be imported into ``self.fgraph`` alongside its + inplace destroyers. + + ``shape_tuple`` returns dim expressions that may share live Apply + nodes with the rest of the fgraph. If one of those Applies reads + a destroyed scalar ``x`` directly *and* (via another input) + depends on its destroyer's output, importing the shape into the + fgraph would trip the destroy handler — that Apply would have to + run both before and after the destroyer. This wrapper materializes + the requested dims and runs the per-Apply cycle break in one pass + via :func:`pytensor.graph.replace.break_aliasing_cycles`. No-op + when the fgraph has no ``DestroyHandler`` or no destroyers; in + that case it's equivalent to ``shape_tuple``. + + Parameters + ---------- + v + Variable whose shape we want. + dims + Optional iterable of dim indices to materialize (defaults to + all dims of ``v``). Negative indices follow Python convention. + Pass an explicit subset to avoid materializing dims you don't + need — both the kernel call and the cycle-break walk are + scoped to the dims actually requested. + + Returns ``None`` if ``v`` has no ``ndim``. + """ + if not hasattr(v.type, "ndim"): + return None + if dims is None: + dims = range(v.type.ndim) + shape = [self.get_shape(v, i) for i in dims] + fgraph = self.fgraph + dh = getattr(fgraph, "destroy_handler", None) if fgraph is not None else None + if dh is None or not dh.destroyers: + return tuple(shape) + + from pytensor.graph.replace import break_aliasing_cycles + + return tuple(break_aliasing_cycles(shape, fgraph.destroyers)) + + def shape_tuple(self, r): + """Return a tuple of symbolic shape vars for tensor variable r.""" + if not hasattr(r.type, "ndim"): + return None + return tuple(self.get_shape(r, i) for i in range(r.type.ndim)) + + def _dim_kernel(self, node, slot): + """Lazily-built per-dim ``FrozenFunctionGraph`` view of the + per-node kernel for ``kernel.outputs[slot]``. + + Returns ``(dim_kernel, used_roles)`` or ``None`` if no kernel. + ``dim_kernel`` is a single-output ``FrozenFunctionGraph`` whose + inputs are the subset of ``kernel.inputs`` reachable from the + slot, in their original kernel order. Two structurally identical + slot DAGs produce ``__eq__`` ``FrozenFunctionGraph`` objects (via + global ``FrozenApply``/``NominalVariable`` interning), letting + ``shape_key`` collapse the structural comparison to one hash and + only descend into inputs that are themselves shape lookups. + """ + per_node = self._dim_kernel_cache.get(node) + if per_node is None: + per_node = {} + self._dim_kernel_cache[node] = per_node + if slot in per_node: + return per_node[slot] + if (entry := self._cache.get(node)) is None: + entry = self._build_kernel(node) + self._cache[node] = entry + kernel, meta = entry + if kernel is None: + per_node[slot] = None + return None + sv = kernel.outputs[slot] + kernel_input_set = set(kernel.inputs) + used = {anc for anc in ancestors([sv]) if anc in kernel_input_set} + used_inputs = tuple(inp for inp in kernel.inputs if inp in used) + roles = meta["roles"] + used_roles = tuple( + roles[i] for i, inp in enumerate(kernel.inputs) if inp in used + ) + try: + dim_kernel = FrozenFunctionGraph(used_inputs, [sv]) + except Exception: + per_node[slot] = None + return None + result = (dim_kernel, used_roles) + per_node[slot] = result + return result + + def shape_key(self, v, i): + """Hashable key for ``v.shape[i]`` such that two keys compare equal + iff this feature can prove the two shapes are the same. + + The key is shaped ``(dim_kernel, bindings)``: + + - ``dim_kernel`` is the per-dim ``FrozenFunctionGraph`` view from + ``_dim_kernel``. ``FrozenApply`` and ``NominalVariable`` are + globally interned, so two structurally identical shape + expressions produce ``__eq__`` kernels — content-addressed + structural equality with no manual op-tree walk on this side. + - ``bindings`` records what's bound at each kernel-input + position. An ``id`` for the live var at ``input_slot`` / + ``self_out`` leaves, and a recursive ``shape_key`` call for + ``input_shape_slot`` leaves — whose binding is itself a + sub-shape (``node.inputs[k]``'s dim j), which can in turn + hit any of these branches again. The recursion is bounded + by graph depth. + + Special cases handled before the kernel path: + + - **static dim** → ``("const", value)``. + - **override** → routed through ``_override_shape_key`` against + the borrowed ``(dim_kernel, role_bindings)`` tuple. Same + structure as the kernel path below: passthrough slots collapse + to the underlying live var's key, otherwise + ``(dim_kernel, recursive_shape_keys)``. No identity-only + fallback — a rerouted override compares equal to any + structurally-equal kernel shape. + - **untracked leaf** (no owner, kernel build failed, or this + output isn't laid out) → ``("leaf", id(v), i)``. + - **passthrough slot** (kernel output is a kernel input + directly) → return the underlying live var's binding so the + key matches that var's own ``shape_key``. + + Known limitation: shape sub-expressions baked into a kernel via + ``Op(input).shape`` (e.g. an ``infer_shape`` impl that takes + ``foo(node.inputs[0]).shape[0]``) are compared *structurally* as + part of the parent kernel — ``same_shape`` will not equate two + such kernels even when the inner ops have equivalent shape + kernels. Cross-kernel shape-equivalence is only detected through + ``input_shape_slot`` bindings, which are the explicit seams + ``_build_kernel`` creates. A follow-up could inline sub-kernels + at build time (analogous to how ``_canonicalize_live_shape`` + resolves ``Subtensor(Shape(...))`` at materialization) to close + this gap. + """ + if hasattr(v.type, "shape") and v.type.shape[i] is not None: + return ("const", int(v.type.shape[i])) + if (ov := self._overrides.get(v)) is not None: + return self._override_shape_key(v, i, ov[i]) + node = v.owner + if node is None: + return ("leaf", id(v), i) + if (entry := self._cache.get(node)) is None: + entry = self._build_kernel(node) + self._cache[node] = entry + kernel, meta = entry + if kernel is None: + return ("leaf", id(v), i) + out_idx = node.outputs.index(v) + layout = meta["output_layout"] + if layout[out_idx] is None: + return ("leaf", id(v), i) + slot = sum((layout[k] or 0) for k in range(out_idx)) + i + sv = kernel.outputs[slot] + slot_to_input_idx = meta["slot_to_input_idx"] + + # Bind one kernel-input role to a live key. Only ``input_shape_slot`` + # needs to recurse (its leaf is a sub-shape, not a live var); every + # other role bottoms out at a live ``node.inputs``/``node.outputs``, + # whose ``id`` already discriminates by identity. Heterogeneous + # element types (int id vs recursive tuple) don't collide. + def bind(role): + if role[0] == "input_shape_slot": + return self.shape_key(node.inputs[slot_to_input_idx[role[1]]], role[2]) + if role[0] == "input_slot": + return id(node.inputs[slot_to_input_idx[role[1]]]) + return id(node.outputs[role[1]]) # self_out + + # Passthrough slot: sv is a kernel input (or Constant) directly, + # no shape function around it. Skip the dim-kernel wrapper so the + # key matches the underlying live var's own shape_key. + if sv.owner is None: + if isinstance(sv, Constant): + try: + return ("const", int(sv.data)) + except Exception: + return ("const", id(sv)) + try: + k_idx = kernel.inputs.index(sv) + except ValueError: + return ("opaque", id(sv)) + return bind(meta["roles"][k_idx]) + dk = self._dim_kernel(node, slot) + if dk is None: + return ("leaf", id(v), i) + dim_kernel, used_roles = dk + return (dim_kernel, tuple(bind(role) for role in used_roles)) def same_shape( self, @@ -665,64 +835,24 @@ def same_shape( dim_x: int | None = None, dim_y: int | None = None, ) -> bool: - """Return ``True`` if `x` and `y` have the same shape. - - Parameters - ========== - x - The `Variable` for which its shape is to be compared with `y`'s shape. - y - The `Variable` for which its shape is to be compared with `x`'s shape. - dim_x - If non ``None``, compare only the dimension of `x` equal to - `dim_x`. - dim_y - If non ``None``, compare only the dimension of `y` equal to - `dim_y`. - + """Return ``True`` if ``x`` and ``y`` have the same shape (along + ``dim_x`` / ``dim_y`` if given, else all dims). """ - sx = self.shape_of[x] - sy = self.shape_of[y] - - if sx is None or sy is None: - return False - - if dim_x is not None: - sx = [sx[dim_x]] - - if dim_y is not None: - sy = [sy[dim_y]] - - if len(sx) != len(sy): - return False - - # Canonicalize the graphs so that comparisons are reasonable - # TODO FIXME: This should *not* need to be performed manually here. - # Instead, the shape information in `self.shape_of` should be operated - # upon alongside all the other elements in a `FunctionGraph` (e.g. as - # if `self.shape_of.values()` were additional outputs). - shapes_fg = FunctionGraph( - outputs=sx + sy, - # features=[self], - clone=True, - # copy_inputs=False, - ) - from pytensor.graph.rewriting.utils import rewrite_graph - - canon_shapes_fg = type_cast( - FunctionGraph, - rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding), - ) - canon_shapes = canon_shapes_fg.outputs - - sx = canon_shapes[: len(sx)] - sy = canon_shapes[len(sx) :] - - for dx, dy in zip(sx, sy, strict=True): - if not equal_computations([dx], [dy]): + if dim_x is None and dim_y is None: + if x.type.ndim != y.type.ndim: return False - - return True + for i in range(x.type.ndim): + if not self.same_shape(x, y, i, i): + return False + return True + if dim_x is None: + dim_x = dim_y + if dim_y is None: + dim_y = dim_x + # Force IndexError semantics matching the legacy impl. + x.type.shape[dim_x] + y.type.shape[dim_y] + return bool(self.shape_key(x, dim_x) == self.shape_key(y, dim_y)) def clone(self): return type(self)() @@ -1289,7 +1419,9 @@ def local_shape_to_shape_i(fgraph, node): if not hasattr(fgraph, "shape_feature"): return shape_feature = fgraph.shape_feature - ret = shape_feature.make_vector_shape(node.inputs[0]) + r = node.inputs[0] + elems = [shape_feature.get_shape(r, i) for i in range(r.type.ndim)] + ret = as_tensor_variable(elems, ndim=1, dtype="int64") # We need to copy over stack trace from input to output copy_stack_trace(node.outputs[0], ret) @@ -1301,44 +1433,37 @@ def local_shape_to_shape_i(fgraph, node): @register_canonicalize @node_rewriter([Shape_i]) def local_track_shape_i(fgraph, node): + """Rewrite ``Shape_i(v, i)`` to the kernel-inferred shape expression. + + With the kernel-based `ShapeFeature`, per-node shape kernels are + always rooted in live inputs. Whenever ``v`` has an ``infer_shape`` + available, the kernel yields a non-``Shape_i`` expression for + ``v.shape[i]``. Rewriting the literal ``Shape_i(v, i)`` with the + kernel expression lets rewrites downstream see the inferred form and + typically lets the original producer node of ``v`` be pruned when + only its shape is consumed. """ - Update `Shape_i` nodes to match `ShapeFeature`'s internal state. - - This rewrite is essential for propagating shape information during graph - transformations (like lowering). When a node is replaced or updated, - `ShapeFeature` calculates the shape of the new node and "schedules" - dependent `Shape_i` nodes for update, so they use the latest inferred graph. - - If we start with an fgraph containing the two nodes below: - >> out = OpWithoutInferShape(a, b) - >> out_shape_i = Shape_i(out) - - And then rewrite - >> new_out = OpWithInferShape(a, b) - >> fgraph.replace(out, new_out) - - We end up with - >> out_shape_i == Shape_i(new_out) - - If installed, ShapeFeature will do this work in the background - >> new_out_shape = infer_shape(new_out) # Usually some f(a, b) - >> fgraph.shape_feature.scheduled[out_shape_i.owner] = new_out_shape + shape_feature = getattr(fgraph, "shape_feature", None) + if shape_feature is None: + return False - And this rewrite will ultimately propagate the inference back to the fgraph - >> new_out_shape_i = fgraph.shape_feature.scheduled[out_shape_i.owner][i] - >> fgraph.replace(out_shape_i, new_out_shape_i) + [v] = node.inputs + if v.owner is None: + return False - """ - try: - shape_feature = fgraph.shape_feature - except AttributeError: + i = node.op.i + new_shape = shape_feature.get_shape(v, i) + if new_shape is None: return False - if node not in shape_feature.scheduled: + # Avoid rewriting Shape_i(v, i) to itself. + if new_shape.owner is node or ( + isinstance(new_shape, Variable) + and new_shape.owner is not None + and isinstance(new_shape.owner.op, Shape_i) + and new_shape.owner.op.i == i + and new_shape.owner.inputs[0] is v + ): return False - # Don't unschedule node as it could be reinserted in the - # fgraph as we don't change it in the shapefeature internal - # structure. - replacement = shape_feature.scheduled[node] - return [shape_feature.shape_of[replacement][node.op.i]] + return [new_shape] diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 8500c957a0..92a69dacc7 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -415,8 +415,9 @@ def local_subtensor_merge(fgraph, node): # Get the shapes of the vectors ! try: # try not to introduce new shape into the graph - xshape = fgraph.shape_feature.shape_of[x] - ushape = fgraph.shape_feature.shape_of[u] + sf = fgraph.shape_feature + xshape = sf.shape_tuple(x) + ushape = sf.shape_tuple(u) except AttributeError: # Following the suggested use of shape_feature which should # consider the case when the compilation mode doesn't @@ -683,7 +684,7 @@ def local_useless_subtensor(fgraph, node): if not hasattr(fgraph, "shape_feature"): return - shape_of = fgraph.shape_feature.shape_of + shape_feature = fgraph.shape_feature cdata = get_constant_idx( node.op.idx_list, @@ -705,7 +706,7 @@ def local_useless_subtensor(fgraph, node): # is not a useless subtensor return False - length_pos = shape_of[node.inputs[0]][pos] + length_pos = shape_feature.get_shape(node.inputs[0], pos) if isinstance(idx.stop, int | np.integer): length_pos_data = sys.maxsize @@ -809,12 +810,12 @@ def local_useless_AdvancedSubtensor1(fgraph, node): if not hasattr(fgraph, "shape_feature"): return - shape_of = fgraph.shape_feature.shape_of + shape_feature = fgraph.shape_feature # get length of the indexed tensor along the first axis try: length = get_scalar_constant_value( - shape_of[node.inputs[0]][0], only_process_constants=True + shape_feature.get_shape(node.inputs[0], 0), only_process_constants=True ) except NotScalarConstantError: return False @@ -1640,7 +1641,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node): # need it for this optimization, so don't continue. return False - shape_of = shape_feature.shape_of same_shape = shape_feature.same_shape # Get the subtensor of `x` indexed by `i` in order to compare @@ -1654,22 +1654,12 @@ def local_useless_inc_subtensor_alloc(fgraph, node): else: raise Exception("Should never happen!") - reason = "local_useless_incsubtensor_alloc" - - # Add `xi` to the shape feature `fgraph`. This is important for - # shape inference later because the variable must be part of the - # function graph in order to call `same_shape` on it. - if xi not in shape_of: - shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`") - # `xi` may have more dimensions than `y` since the subtensor ops # do automatic broadcasting of the increment internally. Thus, we # need to make the leading implicitly broadcasted dimensions # explicit for shape comparison later. if xi.ndim > y.ndim: y = shape_padleft(y, xi.ndim - y.ndim) - if y not in shape_of: - shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`") # Build `z_broad` explicitly to include extra implicit dimensions. z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable @@ -1702,7 +1692,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): if ( z_broad[k] and not same_shape(xi, y, dim_x=k, dim_y=k) - and shape_of[y][k] != 1 + and shape_feature.get_shape(y, k) != 1 ) ] diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 74c445f2f5..1b43eaec6e 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -339,21 +339,7 @@ def shape_i(var, i, fgraph=None): """ if fgraph and hasattr(fgraph, "shape_feature"): - shape_feature = fgraph.shape_feature - shape_of = shape_feature.shape_of - - def recur(node): - if node.outputs[0] not in shape_of: - for inp in node.inputs: - if inp.owner: - recur(inp.owner) - # If the output var isn't marked as being in the graph, - # we need to add it in the ShapeFeature. - shape_feature.on_import(fgraph, node, "graph.ops.shape_i") - - if var not in shape_of: - recur(var.owner) - return shape_of[var][i] + return fgraph.shape_feature.get_shape(var, i) # If we are not able to use the shape feature, we should not put # Shape_i in the graph. Otherwise, the shape feature optimization diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 1f12e47cc1..1789678e5e 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -7,8 +7,7 @@ from numpy import nditer from numpy.lib.array_utils import normalize_axis_tuple -import pytensor -from pytensor.graph import FunctionGraph, Variable +from pytensor.graph import Variable from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.utils import hash_from_code @@ -40,76 +39,6 @@ def hash_from_ndarray(data) -> str: ) -def shape_of_variables( - fgraph: FunctionGraph, input_shapes -) -> dict[Variable, tuple[int, ...]]: - """ - Compute the numeric shape of all intermediate variables given input shapes. - - Parameters - ---------- - fgraph - The FunctionGraph in question. - input_shapes : dict - A dict mapping input to shape. - - Returns - ------- - shapes : dict - A dict mapping variable to shape - - .. warning:: This modifies the fgraph. Not pure. - - Examples - -------- - >>> import pytensor.tensor as pt - >>> from pytensor.graph.fg import FunctionGraph - >>> x = pt.matrix("x") - >>> y = x[512:] - >>> y.name = "y" - >>> fgraph = FunctionGraph([x], [y], clone=False) - >>> d = shape_of_variables(fgraph, {x: (1024, 1024)}) - >>> d[y] - (array(512), array(1024)) - >>> d[x] - (array(1024), array(1024)) - """ - - if not hasattr(fgraph, "shape_feature"): - from pytensor.tensor.rewriting.shape import ShapeFeature - - fgraph.attach_feature(ShapeFeature()) - - shape_feature = fgraph.shape_feature # type: ignore[attr-defined] - - input_dims = [ - dimension for inp in fgraph.inputs for dimension in shape_feature.shape_of[inp] - ] - - output_dims = [ - dimension for shape in shape_feature.shape_of.values() for dimension in shape - ] - - compute_shapes = pytensor.function(input_dims, output_dims) - - if any(i not in fgraph.inputs for i in input_shapes): - raise ValueError( - "input_shapes keys aren't in the fgraph.inputs. FunctionGraph()" - " interface changed. Now by default, it clones the graph it receives." - " To have the old behavior, give it this new parameter `clone=False`." - ) - - numeric_input_dims = [dim for inp in fgraph.inputs for dim in input_shapes[inp]] - numeric_output_dims = compute_shapes(*numeric_input_dims) - - sym_to_num_dict = dict(zip(output_dims, numeric_output_dims, strict=True)) - - l = {} - for var in shape_feature.shape_of: - l[var] = tuple(sym_to_num_dict[sym] for sym in shape_feature.shape_of[var]) - return l - - def import_func_from_string(func_string: str): # -> Optional[Callable]: func = getattr(np, func_string, None) if func is not None: diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index f8180773cc..c6f1f1e74c 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -460,8 +460,8 @@ def test_infer_shape(self): fg = FunctionGraph(outputs=[op_var[1]], clone=False) opt_res = rewrite_graph(fg, custom_rewrite=ShapeOptimizer()) - assert opt_res.shape_feature.shape_of[x] is None - assert opt_res.shape_feature.shape_of[z][0].data == 2 + assert opt_res.shape_feature.shape_tuple(x) is None + assert opt_res.shape_feature.shape_tuple(z)[0].data == 2 def test_make_node_shared(self): """Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`.""" diff --git a/tests/graph/test_replace.py b/tests/graph/test_replace.py index 2605c15b8f..9e9dcfe189 100644 --- a/tests/graph/test_replace.py +++ b/tests/graph/test_replace.py @@ -4,15 +4,22 @@ import pytensor.tensor as pt from pytensor import config, function, shared +from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.graph.basic import equal_computations +from pytensor.graph.destroyhandler import DestroyHandler, _contains_cycle +from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import ( _vectorize_node, + break_aliasing_cycles, clone_replace, graph_replace, vectorize_graph, ) -from pytensor.graph.traversal import graph_inputs +from pytensor.graph.traversal import ancestors, graph_inputs from pytensor.tensor import dvector, fvector, vector +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.math import maximum +from pytensor.tensor.type import scalar from tests import unittest_tools as utt from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs from tests.unittest_tools import assert_equal_computations @@ -373,3 +380,115 @@ def test_non_variable_raises(self): batch_out.eval({x: 3, y: 4}), np.zeros((2, 3, 4)), ) + + +class TestBreakAliasingCycles: + """``break_aliasing_cycles`` re-routes aliased-scalar reads on Apply + nodes whose dual reference to ``x`` and the inplace op's output ``y`` + would otherwise be unschedulable. + + The bad pattern (the one we patch): a single Apply ``A`` that reads + an aliased scalar ``x`` directly *and* whose other inputs reach the + inplace op's output ``y``. The fix: replace ``x`` on that one Apply + with ``deep_copy_op(x)``. + """ + + @staticmethod + def _inplace_add(): + import pytensor.scalar as ps + + return Elemwise(ps.add, inplace_pattern={0: 0}) + + def _setup_with_inplace(self): + """Build an fgraph containing an inplace destroyer ``y = x + x'`` + (destroys ``x``). + """ + x = scalar("x") + x2 = scalar("x2") + y = self._inplace_add()(x, x2) + fg = FunctionGraph([x, x2], [y], clone=False) + fg.attach_feature(DestroyHandler()) + # Sanity: dh sees the destroyer. + assert fg.destroy_handler.destroyers + return fg, x, x2, y + + @staticmethod + def _has_cycle(fgraph, patched): + """Add ``patched`` as a fresh output of ``fgraph`` and ask the + destroy handler whether the resulting graph contains a cycle. + """ + for var in patched: + fgraph.add_output(var, reason="test_break_aliasing_cycles") + dh = fgraph.destroy_handler + return _contains_cycle(fgraph, dh.orderings(fgraph, ordered=False)) + + def test_clean_subgraph_unchanged(self): + """Subgraph that only reads ``x`` (no ``y``) passes through.""" + fg, x, _x2, _y = self._setup_with_inplace() + v = pt.add(x, x) + (result,) = break_aliasing_cycles([v], fg.destroyers) + assert result is v + assert not self._has_cycle(fg, [result]) + + def test_subgraph_reads_only_y_unchanged(self): + """Subgraph that only reads ``y`` (not ``x`` directly) passes.""" + fg, _x, x2, y = self._setup_with_inplace() + v = pt.add(y, x2) + (result,) = break_aliasing_cycles([v], fg.destroyers) + assert result is v + assert not self._has_cycle(fg, [result]) + + def test_split_pattern_unchanged(self): + """``zar(y, bar(x))`` schedules cleanly: no Apply has both x and + y as direct inputs.""" + fg, x, x2, y = self._setup_with_inplace() + bar = pt.add(x, x2) + zar = pt.add(y, bar) + (result,) = break_aliasing_cycles([zar], fg.destroyers) + assert result is zar + # Original x edge in bar still present (no surgery). + assert x in set(ancestors([result])) + assert not self._has_cycle(fg, [result]) + + def test_dual_reference_pattern_is_patched(self): + """A single Apply that reads x AND y is the bad pattern — patch it.""" + fg, x, _x2, y = self._setup_with_inplace() + bad = maximum(y, x) + (result,) = break_aliasing_cycles([bad], fg.destroyers) + expected = maximum(y, deep_copy_op(x)) + utt.assert_equal_computations([result], [expected], original=bad) + # The destroyer (inplace add) still reads x directly — left alone. + for a in fg.apply_nodes: + if a.op.destroy_map: + assert x in a.inputs + assert not self._has_cycle(fg, [result]) + + def test_transitive_y_via_sibling(self): + """A's other input transitively reaches y via an intermediate node. + + The intermediate ``add(y, x2)`` doesn't equal ``y`` by identity, + but its ancestors include ``y``. Walking ``deps`` upward must + catch that, so ``maximum(intermediate, x)`` still gets the surgery. + """ + fg, x, x2, y = self._setup_with_inplace() + intermediate = pt.add(y, x2) + bad = maximum(intermediate, x) + (result,) = break_aliasing_cycles([bad], fg.destroyers) + expected = maximum(intermediate, deep_copy_op(x)) + utt.assert_equal_computations([result], [expected], original=bad) + assert not self._has_cycle(fg, [result]) + + def test_multiple_outputs_share_substitute(self): + """When two outputs both need to patch the same destroyed scalar, + a single DeepCopyOp(x) Apply is shared between them.""" + fg, x, _x2, y = self._setup_with_inplace() + bad1 = maximum(y, x) + bad2 = pt.add(y, x) + results = break_aliasing_cycles([bad1, bad2], fg.destroyers) + deep_copies = set() + for r in results: + for a in {anc.owner for anc in ancestors([r]) if anc.owner}: + if isinstance(a.op, DeepCopyOp): + deep_copies.add(a) + assert len(deep_copies) == 1 + assert not self._has_cycle(fg, results) diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 358c95fc66..96cecdc333 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -284,7 +284,7 @@ def test_normal_ShapeFeature(): clone=False, features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) f = function([M_pt, sd_pt], [s1, s2, d_rv], mode=py_mode, on_unused_input="ignore") s1_val, s2_val, d_rv_val = f(3, np.array(1.0, dtype=config.floatX)) @@ -657,7 +657,7 @@ def test_mvnormal_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) f = function([M_pt], [s1, s2], mode=py_mode) s1_val, s2_val = f(2) @@ -679,7 +679,7 @@ def test_mvnormal_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2, s3, s4 = fg.shape_feature.shape_of[d_rv] + s1, s2, s3, s4 = fg.shape_feature.shape_tuple(d_rv) mean_val = np.array([[0, 1, 2]], dtype=config.floatX) f = function([mean, cov], [s1, s2, s3, s4], mode=py_mode, on_unused_input="ignore") @@ -810,7 +810,7 @@ def test_dirichlet_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) assert M_pt in graph_inputs([s1]) assert N_pt in graph_inputs([s2]) diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index f24860d804..124852c723 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import pytensor.scalar as ps import pytensor.tensor as pt from pytensor import shared from pytensor.compile.maker import function @@ -10,7 +11,8 @@ from pytensor.compile.ops import deep_copy_op from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable, equal_computations -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.destroyhandler import DestroyHandler, _contains_cycle +from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in from pytensor.graph.rewriting.utils import rewrite_graph @@ -33,6 +35,7 @@ shape, specify_shape, ) +from pytensor.tensor.signal import convolve1d from pytensor.tensor.subtensor import set_subtensor from pytensor.tensor.type import ( fmatrix, @@ -592,6 +595,307 @@ def test_vector_dim_err(self): with pytest.raises(IndexError): shape_feature.same_shape(x, o, 0, 1) + def test_distinct_passthrough_ops(self): + # Different unary Elemwises (exp vs cos) over the same input have + # passthrough kernels that bottom out at the same input shape. + x = vector() + a = pt.exp(x) + b = pt.cos(x) + fgraph = FunctionGraph([x], [a, b], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + assert shape_feature.same_shape(a, b) + + def test_chained_passthrough(self): + # ``exp(x)`` and ``exp(x + 1)`` should be same_shape: the inner Add + # passthrough cascades through the outer Elemwise's passthrough + # back to ``shape_key(x, 0)``. + x = vector() + a = pt.exp(x) + b = pt.exp(x + 1) + fgraph = FunctionGraph([x], [a, b], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + assert shape_feature.same_shape(a, b) + + def test_distinct_sources_shared_shape_arg(self): + # ``alloc(0., n)`` and ``alloc(1., n)`` have different sources but + # share the same shape input ``n``. The dim_kernel for Alloc has + # ``input_slot`` bindings to the shape args; both Allocs bind to + # the same live ``n``, so same_shape must hold. + n = iscalar("n") + a = alloc(np.float64(0.0), n) + b = alloc(np.float64(1.0), n) + fgraph = FunctionGraph([n], [a, b], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + assert shape_feature.same_shape(a, b) + + # Same shape vars at swapped positions: ``alloc(0., n, n+1)`` + # vs ``alloc(0., n+1, n)``. Per-dim queries should detect the + # cross-dim equivalences; the overall ``same_shape`` (no dims) + # compares dim-by-dim positionally and should fail. + n_plus_1 = n + 1 + c = alloc(np.float64(0.0), n, n_plus_1) + d = alloc(np.float64(0.0), n_plus_1, n) + fgraph = FunctionGraph([n], [c, d], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + # Overall fails (dim 0 of c is n, dim 0 of d is n+1, etc.). + assert not shape_feature.same_shape(c, d) + # Cross-dim works: same shape var on each side. + assert shape_feature.same_shape(c, d, 0, 1) # n == n + assert shape_feature.same_shape(c, d, 1, 0) # n+1 == n+1 + # Same-dim comparisons should not match. + assert not shape_feature.same_shape(c, d, 0, 0) + assert not shape_feature.same_shape(c, d, 1, 1) + + def test_baked_in_shape_subexpr_limitation(self): + # KNOWN LIMITATION (documented on ``shape_key``): kernel input + # bindings compare the live ``node.inputs[k]`` by ``id``, not by + # structural shape-equivalence. Two structurally-equivalent + # live shape inputs that happen to be distinct ``Variable`` + # objects yield different ``same_shape`` results. + # + # ``reshape(x, exp(s).shape)`` and ``reshape(x, cos(s).shape)`` + # both have output shape equal to ``s.shape`` at runtime, but + # the live shape inputs (``exp(s).shape`` vs ``cos(s).shape``) + # are distinct Variables. Reshape's kernel binds the shape + # input via ``input_slot``, which compares by ``id`` only, so + # ``same_shape`` returns ``False``. + # + # Closing this gap requires inlining sub-kernels at build time + # (so the parent kernel resolves through ``exp/cos`` into ``s`` + # directly, content-addressing the whole chain). If that lands, + # flip the assert. + x = vector() + s = vector() + a = reshape(x, pt.exp(s).shape) + b = reshape(x, pt.cos(s).shape) + fgraph = FunctionGraph([x, s], [a, b], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + assert not shape_feature.same_shape(a, b) + + +def test_unaliased_shape_tuple_blockwise_convolve(): + """Recreate the ``Blockwise(Convolve1d)`` situation from the convolve1d + gradient that originally triggered the destroy-handler cycle. + + The setup mirrors what late inplace fusion produces: an inplace + ``Composite{(i + j) - 1}`` over two ``Shape_i`` scalars feeding an + ``Alloc`` that's the first input of a ``Blockwise(Convolve1d)``. Lazy + shape materialization traces through the Alloc and pulls the live + destroyer output into the shape arithmetic. With the second convolve + input also derived from ``larger`` (same source as the destroyed + scalar), the shape ends up with an Apply reading *both* the destroyed + ``Shape_i`` and the destroyer's output — the dual-reference pattern + that breaks scheduling. + + Asserts: + + 1. the naive ``shape_tuple``, once imported into the fgraph alongside + the inplace destroyer, is flagged as cyclic by the destroy handler + (this is the bug); + 2. ``unaliased_shape_tuple`` produces the same shape with the cycle- + pattern Applys rerouted through ``deep_copy_op``, so importing it + is cycle-free. + """ + larger = pt.matrix("larger", shape=(8, None)) + smaller = pt.matrix("smaller", shape=(8, None)) + + # Pre-warm the ShapeFeature ``Shape_i`` cache so the destroyer's + # destroyed inputs are the *same* Apply nodes the lazy shape + # materialization will return later. + warm_fg = FunctionGraph([larger, smaller], [larger], clone=False) + warm_sf = ShapeFeature() + warm_fg.attach_feature(warm_sf) + larger_s1 = warm_sf.get_shape(larger, 1) + smaller_s1 = warm_sf.get_shape(smaller, 1) + + # Inplace Composite{(i + j) - 1}: destroys input 0 (= ``larger.shape[1]``). + sx, sy = ps.int64(), ps.int64() + inplace_comp = Elemwise( + ps.Composite([sx, sy], [ps.sub(ps.add(sx, sy), ps.constant(1, dtype="int64"))]), + inplace_pattern={0: 0}, + ) + new_dim = inplace_comp(larger_s1, smaller_s1) + a = alloc(pt.zeros((1, 1)), 1, new_dim) + # Slice of ``larger`` as the second convolve input — its shape depends + # on ``larger.shape[1]`` (= the destroyed scalar) too. That's what + # makes the convolve shape arithmetic combine the destroyer's output + # with the destroyed scalar in a single Apply. + out = convolve1d(a, larger[:, ::-1], mode="full") + + fg = FunctionGraph([larger, smaller], [out], clone=False) + sf = ShapeFeature() + fg.attach_feature(sf) + fg.attach_feature(DestroyHandler()) + sf._shape_i_cache[(id(larger), 1)] = larger_s1 + sf._shape_i_cache[(id(smaller), 1)] = smaller_s1 + + naive_shape = sf.shape_tuple(out) + safe_shape = sf.unaliased_shape_tuple(out) + + # Importing each shape into a fresh fgraph (with the destroyer present) + # tells us whether the destroy handler accepts it. A new fgraph per + # check keeps the cycle from the naive case from poisoning the safe one. + def imports_with_cycle(shape_vars): + check_fg = FunctionGraph([larger, smaller], [out, *shape_vars], clone=False) + check_fg.attach_feature(DestroyHandler()) + dh = check_fg.destroy_handler + return _contains_cycle(check_fg, dh.orderings(check_fg, ordered=False)) + + # Naive lazy shape: destroy handler rejects it. + assert imports_with_cycle(naive_shape) + # Cycle-broken version: imports cleanly. + assert not imports_with_cycle(safe_shape) + + +class _NoShapeOp(Op): + """Op without ``infer_shape``, used to drive the kernel-borrow + override path in ``ShapeFeature.on_change_input``.""" + + __props__ = () + + def make_node(self, x): + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, outputs): + outputs[0][0] = inputs[0] + + +_no_shape = _NoShapeOp() + + +class TestKernelReroute: + """When ``r`` is replaced by ``new_r`` whose Op has no + ``infer_shape``, ``ShapeFeature.on_change_input`` rederives r's + shape kernel against ``new_r.owner.inputs`` by matching kernel- + input bindings via ``shape_key``. The override stores + ``(dim_kernel, role_bindings)`` per dim — no live ``Variable`` + is pinned. None of these paths are exercised by in-tree rewriters + today (every common Op has an ``infer_shape``), so each test + wires up the replacement explicitly. + """ + + def test_passthrough_reroute(self): + """Passthrough kernel (``exp(x)``: shape = ``[x.shape]``) + reroutes against ``new_r.owner.inputs = [x]``. The override + stores only ``(dim_kernel, ((0, 0),))``; ``shape_key`` matches + ``x.shape[0]`` exactly. + """ + x = vector("x") + r = exp(x) + new_r = _no_shape(x) + + fg = FunctionGraph([x], [r], clone=False) + sf = ShapeFeature() + fg.attach_feature(sf) + fg.replace(r, new_r, reason="reroute_test") + + assert new_r in sf._overrides + ov = sf._overrides[new_r] + assert len(ov) == 1 + dim_kernel, role_bindings = ov[0] + + # Pure-kernel structure: no live Variables. + assert isinstance(dim_kernel, FrozenFunctionGraph) + assert role_bindings == ((0, 0),) + for binding in role_bindings: + assert all(isinstance(b, int) for b in binding) + + live = sf.get_shape(new_r, 0) + assert live.owner is not None and isinstance(live.owner.op, Shape_i) + assert live.owner.op.i == 0 + assert live.owner.inputs[0] is x + + assert sf.shape_key(new_r, 0) == sf.shape_key(x, 0) + assert sf.same_shape(new_r, x, 0, 0) + + def test_no_shape_key_match(self): + """When the only role's binding has no ``shape_key`` match in + ``new_r.owner.inputs``, reroute gives up for that dim. Here r + depends on ``a.shape[0]`` but ``new_r.owner.inputs = [b]`` — + ``id``-keyed leaves of two distinct vectors don't match, so + no override is installed. + """ + a = vector("a") + b = vector("b") + r = exp(a) + new_r = _no_shape(b) + + fg = FunctionGraph([a, b], [r], clone=False) + sf = ShapeFeature() + fg.attach_feature(sf) + fg.replace(r, new_r, reason="reroute_test") + + # Reroute returned None for every dim (here, just dim 0), so + # no override entry is recorded. + assert new_r not in sf._overrides + # Shape lookups fall back to Shape_i on new_r itself. + live = sf.get_shape(new_r, 0) + assert isinstance(live.owner.op, Shape_i) and live.owner.inputs[0] is new_r + + def test_input_slot_kernel_skipped(self): + """A kernel with an ``input_slot`` role (Alloc — its shape + elements are input *values*, not input *shapes*) can't be + rerouted by shape-key matching. Reroute bails before searching + new_r's inputs. + """ + n = iscalar("n") + r = alloc(np.float64(0.0), n) + other = pt.tensor("other", dtype="float64", shape=(None,)) + new_r = _no_shape(other) + + fg = FunctionGraph([n, other], [r], clone=False) + sf = ShapeFeature() + fg.attach_feature(sf) + fg.replace(r, new_r, reason="reroute_test") + + # input_slot role → reroute returns None for every dim. + assert new_r not in sf._overrides + + def test_partial_reroute(self): + """Per-dim independence: with ``r = dot(A, B)`` and + ``new_r = no_shape(A)``, dim 0 (``A.shape[0]``) reroutes + cleanly while dim 1 (``B.shape[1]``) has no counterpart in + ``[A]``. The override is installed but with ``None`` at dim 1 + — falling back to ``Shape_i`` for that dim only. + """ + A = matrix("A") + B = matrix("B") + r = pt.dot(A, B) + new_r = _no_shape(A) + + fg = FunctionGraph([A, B], [r], clone=False) + sf = ShapeFeature() + fg.attach_feature(sf) + fg.replace(r, new_r, reason="reroute_test") + + assert new_r in sf._overrides + ov = sf._overrides[new_r] + assert len(ov) == 2 + + # Dim 0 rerouted to A.shape[0]. + assert ov[0] is not None + dim_kernel0, role_bindings0 = ov[0] + assert isinstance(dim_kernel0, FrozenFunctionGraph) + assert role_bindings0 == ((0, 0),) + + # Dim 1's binding (B.shape[1]) has no counterpart in [A]. + assert ov[1] is None + + s0 = sf.get_shape(new_r, 0) + assert isinstance(s0.owner.op, Shape_i) and s0.owner.op.i == 0 + assert s0.owner.inputs[0] is A + s1 = sf.get_shape(new_r, 1) + assert isinstance(s1.owner.op, Shape_i) and s1.owner.op.i == 1 + assert s1.owner.inputs[0] is new_r + + assert sf.shape_key(new_r, 0) == sf.shape_key(A, 0) + assert sf.shape_key(new_r, 1) == ("leaf", id(new_r), 1) + def test_useless_specify_shape(): x = tensor("x", shape=(None, 5, 3)) diff --git a/tests/tensor/test_utils.py b/tests/tensor/test_utils.py index e2fd3d2958..fd12a5871b 100644 --- a/tests/tensor/test_utils.py +++ b/tests/tensor/test_utils.py @@ -1,10 +1,60 @@ import numpy as np import pytest +import pytensor import pytensor.tensor as pt +from pytensor.graph.basic import Variable from pytensor.graph.fg import FunctionGraph -from pytensor.tensor.type import matrix -from pytensor.tensor.utils import hash_from_ndarray, shape_of_variables +from pytensor.tensor.type import lscalar, matrix +from pytensor.tensor.utils import hash_from_ndarray + + +def shape_of_variables( + fgraph: FunctionGraph, input_shapes +) -> dict[Variable, tuple[int, ...]]: + """Compute the numeric shape of every variable in ``fgraph`` given + the numeric shapes of its inputs (test helper). + + Builds scalar placeholders for each input dim, walks + ``builders.infer_shape`` over the fgraph variables, then compiles a + scalar-in / scalar-out function. Used only by the tests in this + module. + """ + if any(i not in fgraph.inputs for i in input_shapes): + raise ValueError( + "input_shapes keys aren't in the fgraph.inputs. FunctionGraph()" + " interface changed. Now by default, it clones the graph it receives." + " To have the old behavior, give it this new parameter `clone=False`." + ) + from pytensor.compile.builders import infer_shape + + input_shape_scalars: dict[Variable, tuple[Variable, ...]] = { + inp: tuple(lscalar() for _ in range(inp.type.ndim)) for inp in fgraph.inputs + } + input_dims = [s for inp in fgraph.inputs for s in input_shape_scalars[inp]] + + all_vars: list[Variable] = [v for v in fgraph.variables if hasattr(v.type, "ndim")] + inferred_shapes = infer_shape( + outs=all_vars, + inputs=list(fgraph.inputs), + input_shapes=[input_shape_scalars[inp] for inp in fgraph.inputs], + ) + per_var_shape: dict = dict(zip(all_vars, inferred_shapes, strict=True)) + output_dims = [dim for shape in per_var_shape.values() for dim in shape] + + compute_shapes = pytensor.function( + input_dims, output_dims, on_unused_input="ignore" + ) + + numeric_input_dims = [dim for inp in fgraph.inputs for dim in input_shapes[inp]] + numeric_output_dims = compute_shapes(*numeric_input_dims) + + sym_to_num_dict = dict(zip(output_dims, numeric_output_dims, strict=True)) + + return { + var: tuple(sym_to_num_dict[sym] for sym in shape) + for var, shape in per_var_shape.items() + } def test_hash_from_ndarray(): diff --git a/tests/xtensor/test_rewriting.py b/tests/xtensor/test_rewriting.py index da076b1824..bb46310c07 100644 --- a/tests/xtensor/test_rewriting.py +++ b/tests/xtensor/test_rewriting.py @@ -1,8 +1,15 @@ +import pytest + +from pytensor.compile import optdb from pytensor.graph import FunctionGraph +from pytensor.tensor import tensor from pytensor.tensor.basic import infer_shape_db +from pytensor.tensor.random.type import random_generator_type from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.shape import Shape_i -from pytensor.xtensor import xtensor +from pytensor.xtensor import as_xtensor, xtensor +from pytensor.xtensor.random import normal +from pytensor.xtensor.vectorization import XRV from tests.unittest_tools import assert_equal_computations @@ -22,3 +29,33 @@ def test_infer_shape_db_handles_xtensor_lowering(): infer_shape_db.default_query.rewrite(fgraph) [rewritten_shape_y] = fgraph.outputs assert_equal_computations([rewritten_shape_y], [Shape_i(1)(x)]) + + +@pytest.mark.parametrize("with_shape_feature", [False, True]) +def test_nested_xrv_lowering_does_not_leak_stale_xrv(with_shape_feature): + # Nested XRV where the outer's extra_dims aren't in the inner's dims. + # Lowering needs the inner's shape for the outer's size, which drags the + # pre-lowering XRV back into the graph via a dormant Shape_i cached in + # ShapeFeature.shape_of. + a_size = tensor("a_size", shape=(), dtype="int64") + b_size = tensor("b_size", shape=(), dtype="int64") + rng1 = random_generator_type("rng1") + rng2 = random_generator_type("rng2") + mu = normal(0.0, 0.1, extra_dims={"a": as_xtensor(a_size)}, rng=rng1) + out = normal(mu, 1.0, extra_dims={"b": as_xtensor(b_size)}, rng=rng2) + + features = [ShapeFeature()] if with_shape_feature else [] + fgraph = FunctionGraph( + [a_size, b_size, rng1, rng2], + [out.values], + features=features, + copy_inputs=False, + ) + optdb.query( + "+lower_xtensor", + "+canonicalize", + "-local_eager_useless_unbatched_blockwise", + ).rewrite(fgraph) + + stale = [n for n in fgraph.apply_nodes if isinstance(n.op, XRV)] + assert not stale, f"XRV remained after lowering: {stale}" From 07778b478b6b6968a4fa675a3cef38f7d07c7b8f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 1 May 2026 00:15:00 +0200 Subject: [PATCH 4/5] WIP --- pytensor/tensor/rewriting/shape.py | 159 +++++++++++------------------ 1 file changed, 57 insertions(+), 102 deletions(-) diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 4ea5817c4f..570f23844b 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -44,7 +44,6 @@ AdvancedIncSubtensor1, IncSubtensor, Subtensor, - get_idx_list, ) from pytensor.tensor.type import TensorType, integer_dtypes, lscalar from pytensor.tensor.type_other import NoneTypeT @@ -72,53 +71,20 @@ class ShapeFeature(Feature): - ``same_shape(x, y, dim_x=None, dim_y=None)`` — via content-addressed ``shape_key``. """ + _scalar_shape = constant(np.array([], dtype="int64")) + def __init__(self): - # Per-Apply kernel cache: ``node -> (kernel, meta)`` from - # ``_build_kernel``. The kernel is a ``FrozenFunctionGraph`` rooted - # in dummy inputs; ``get_shape`` materializes it against today's - # live ``node.inputs``. Populated lazily on ``get_shape`` / - # ``shape_key`` / ``_reroute_dim``; dropped in ``on_prune``. + # node -> (kernel, meta) from _build_kernel, lazily populated self._cache: dict = {} - # Kernel-borrow overrides keyed by ``Variable``: an ndim-tuple - # whose entries are either ``None`` (Shape_i fallback) or - # ``(dim_kernel, role_bindings)``. ``dim_kernel`` is the per-dim - # ``FrozenFunctionGraph`` borrowed from the *replaced* var's - # kernel; ``role_bindings`` is a tuple of ``(input_idx, dim)`` - # aligned with ``dim_kernel.inputs``, indexing into the keying - # var's ``.owner.inputs``. No live Variables are pinned: the - # live shape is rebuilt at access time by walking the dim_kernel - # against ``v.owner.inputs[input_idx].shape[dim]``. Installed by - # ``on_change_input`` when ``new_r`` replaces ``r`` and - # ``new_r``'s Op has no ``infer_shape``. + # node -> {slot: (dim_kernel, used_roles) | None}, per-dim views of _cache + self._dim_kernel_cache: dict = {} + # var -> ndim-tuple of (dim_kernel, role_bindings) | None, + # installed by on_change_input when new_r's Op has no infer_shape self._overrides: dict = {} - # Memoizes ``Shape_i(i)(v)`` for leaves/fallbacks so callers that - # cross-reference shape entries with ``Shape_i`` nodes in the graph - # observe Apply identity (the graph's MergeFeature would otherwise - # merge structurally equal copies, but by then compare-by-identity - # rewrites may have already bailed out). - # Keyed by ``(id(v), i)``; safe because the fgraph holds a strong - # ref to ``v`` for the feature's lifetime. ``on_prune`` drops the - # entries for removed Apply outputs; graph-input removal would - # leak entries but is not a path we currently exercise. + # (id(v), i) -> Shape_i(i)(v), ensures Apply identity for leaves self._shape_i_cache: dict = {} - # Memoize the canonicalized result of ``get_shape(v, i)`` so a - # second caller observes identity, not a fresh equivalent tree. - # Safe to hold strong refs because the cached expression is - # canonical: ``Shape_i{j}(graph_input_leaf)``, lscalars, constants, - # and arithmetic — none of those participate in the rewrite cycles - # that would otherwise replace nodes out from under us. Dropped - # in ``on_prune`` when the keying ``v`` is removed. + # (id(v), i) -> canonicalized get_shape result, avoids re-materialization self._materialized: dict = {} - # Per-dim sub-views of the per-node kernel, used by - # ``same_shape``/``shape_key``. Keyed ``node -> {slot: (dim_kernel, - # used_roles) | None}``. ``dim_kernel`` is a single-output - # ``FrozenFunctionGraph`` over only the kernel inputs reachable - # from ``kernel.outputs[slot]``. Because ``FrozenApply`` and - # ``NominalVariable`` are globally interned, structurally - # identical shape expressions yield ``__eq__`` dim kernels — so - # ``same_shape`` reduces to a content-addressed kernel match plus - # a roles/binding compare, instead of a recursive op-tree walk. - self._dim_kernel_cache: dict = {} self.fgraph: FunctionGraph | None = None def tracks_shape(self, v) -> bool: @@ -172,31 +138,29 @@ def _canonicalize_live_shape(self, s, memo=None): node = s.owner op = node.op - if isinstance(op, Subtensor): - base = node.inputs[0] - if base.owner is not None and isinstance(base.owner.op, Shape): + if isinstance(op, Subtensor) and op.idx_list == (0,): + base, idx = node.inputs + if isinstance(base.owner_op, Shape): x = base.owner.inputs[0] - if hasattr(x.type, "ndim"): - try: - idx_list = get_idx_list(node.inputs, op.idx_list) - if len(idx_list) == 1: - i = int(get_scalar_constant_value(idx_list[0])) - if 0 <= i < x.type.ndim: - result = self.get_shape(x, i) - memo[s] = result - return result - except (NotScalarConstantError, IndexError, TypeError): - pass + try: + idx_const = int(get_scalar_constant_value(idx)) + memo[s] = result = self.get_shape(x, idx_const) + return result + except (NotScalarConstantError, IndexError): + pass if isinstance(op, Shape): x = node.inputs[0] - if hasattr(x.type, "ndim") and x.type.ndim > 0: - result = stack([self.get_shape(x, j) for j in range(x.type.ndim)]) - memo[s] = result - return result + dims = [self.get_shape(x, j) for j in range(x.type.ndim)] + if dims: + result = stack(dims) + else: + result = self._scalar_shape + memo[s] = result + return result new_inputs = [self._canonicalize_live_shape(inp, memo) for inp in node.inputs] - if all(ni is oi for ni, oi in zip(new_inputs, node.inputs, strict=True)): + if all(ni is oi for ni, oi in zip(new_inputs, node.inputs)): memo[s] = s return s new_node = op.make_node(*new_inputs) @@ -322,25 +286,6 @@ def _reroute_dim(self, r, k, new_r_owner_inputs): role_bindings.append(match) return (dim_kernel, tuple(role_bindings)) - def _materialize_override(self, v, i, entry): - """Walk a borrowed dim_kernel against ``v.owner.inputs``.""" - if entry is None: - return self._shape_i_var(v, i) - dim_kernel, role_bindings = entry - new_owner_inputs = v.owner.inputs - memo: dict = { - k_input: self.get_shape(new_owner_inputs[idx], dim) - for k_input, (idx, dim) in zip( - dim_kernel.inputs, role_bindings, strict=True - ) - } - for fa in dim_kernel.toposort(): - new_inputs = [memo.get(inp, inp) for inp in fa.inputs] - new_node = fa.op.make_node(*new_inputs) - memo.update(zip(fa.outputs, new_node.outputs, strict=True)) - raw = memo.get(dim_kernel.outputs[0], dim_kernel.outputs[0]) - return self._canonicalize_live_shape(raw) - def _override_shape_key(self, v, i, entry): """Content-addressed key for an override entry; see ``shape_key``.""" if entry is None: @@ -570,41 +515,52 @@ def get_shape(self, v, i): if hasattr(v.type, "shape") and v.type.shape[i] is not None: return constant(v.type.shape[i], dtype="int64") cache_key = (id(v), i) - if (ov := self._overrides.get(v)) is not None: - cached = self._materialized.get(cache_key) - if cached is not None: - return cached - result = self._materialize_override(v, i, ov[i]) - self._materialized[cache_key] = result - return result - if v.owner is None: - return self._shape_i_var(v, i) - cached = self._materialized.get(cache_key) - if cached is not None: + if (cached := self._materialized.get(cache_key)) is not None: return cached + if (ov := self._overrides.get(v)) is not None: + entry = ov[i] + if entry is None: + return self._shape_i_var(v, i) + dim_kernel, role_bindings = entry + new_owner_inputs = v.owner.inputs + memo: dict = { + k_input: self.get_shape(new_owner_inputs[idx], dim) + for k_input, (idx, dim) in zip( + dim_kernel.inputs, role_bindings, strict=True + ) + } + for fa in dim_kernel.toposort(): + new_inputs = [memo.get(inp, inp) for inp in fa.inputs] + new_node = fa.op.make_node(*new_inputs) + memo.update(zip(fa.outputs, new_node.outputs)) + raw = memo.get(dim_kernel.outputs[0], dim_kernel.outputs[0]) + result = self._materialized[cache_key] = self._canonicalize_live_shape(raw) + return result + node = v.owner + + if node is None: + return self._shape_i_var(v, i) + if (entry := self._cache.get(node)) is None: - entry = self._build_kernel(node) - self._cache[node] = entry + self._cache[node] = entry = self._build_kernel(node) + kernel, meta = entry if kernel is None: - result = self._shape_i_var(v, i) - self._materialized[cache_key] = result + self._materialized[cache_key] = result = self._shape_i_var(v, i) return result out_idx = node.outputs.index(v) layout = meta["output_layout"] if layout[out_idx] is None: - result = self._shape_i_var(v, i) - self._materialized[cache_key] = result + self._materialized[cache_key] = result = self._shape_i_var(v, i) return result slot = sum((layout[k] or 0) for k in range(out_idx)) + i dk = self._dim_kernel(node, slot) if dk is None: - result = self._shape_i_var(v, i) - self._materialized[cache_key] = result + self._materialized[cache_key] = result = self._shape_i_var(v, i) return result dim_kernel, used_roles = dk @@ -624,8 +580,7 @@ def get_shape(self, v, i): memo[k_input] = self.get_shape( node.inputs[slot_to_input_idx[role[1]]], role[2] ) - else: - # self_out + else: # self_out memo[k_input] = node.outputs[role[1]] for fa in dim_kernel.toposort(): new_inputs = [memo.get(inp, inp) for inp in fa.inputs] From 33fc896e64654575b336b13a347dbbd17e2d0c6d Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 1 May 2026 01:28:54 +0200 Subject: [PATCH 5/5] Clean up ShapeFeature caches and remove tracks_shape - Rename _cache -> _shape_kernel_cache, _materialized -> _materialized_dim_cache - Key _materialized_dim_cache by node instead of (id(v), i) for cheap invalidation - Remove tracks_shape (broken with lazy design), replace with _inferred_shape_or_fallback helper in scan rewriting - Simplify on_prune/on_change_input cache invalidation - Clean up getattr(out.type, "ndim", 0) or 0 pattern --- pytensor/scan/rewriting.py | 25 ++-- pytensor/tensor/rewriting/shape.py | 232 +++++++++++------------------ 2 files changed, 104 insertions(+), 153 deletions(-) diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 9addb3929e..df65d48566 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -65,7 +65,7 @@ ) from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink -from pytensor.tensor.shape import shape +from pytensor.tensor.shape import Shape_i, shape from pytensor.tensor.subtensor import ( IncSubtensor, Subtensor, @@ -1358,6 +1358,15 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool: return not broadcasted_by(init_value_.squeeze(0), init_buffer[0]) +def _inferred_shape_or_fallback(shape_feature, v, i, fallback): + """Return ``shape_feature.get_shape(v, i)`` if it's better than ``Shape_i(v)``, else *fallback*.""" + if shape_feature is not None: + s = shape_feature.get_shape(v, i) + if not (s.owner and isinstance(s.owner.op, Shape_i) and s.owner.inputs[0] is v): + return s + return fallback + + def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool): r"""Graph optimizer that reduces scan memory consumption. @@ -1497,15 +1506,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: # 2.3.2 extract the begin/end of the first dimension if i >= op_info.n_mit_mot: - if shape_feature is not None and shape_feature.tracks_shape(out): - length = shape_feature.get_shape(out, 0) - else: - length = node.inputs[0] + init_l[i] + length = _inferred_shape_or_fallback( + shape_feature, out, 0, node.inputs[0] + init_l[i] + ) else: - if shape_feature is not None and shape_feature.tracks_shape(out): - length = shape_feature.get_shape(out, 0) - else: - length = out.shape[0] + length = _inferred_shape_or_fallback( + shape_feature, out, 0, out.shape[0] + ) cf_slice = get_canonical_form_slice(this_slice[0], length) slices[i] += [(cf_slice, this_slice)] # type: ignore diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 570f23844b..805ad61e78 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -54,7 +54,7 @@ class ShapeFeature(Feature): r"""Kernel-based `Feature` that tracks shape information in a graph. For each `Apply`, a `FrozenFunctionGraph` "kernel" is built once and - stored in ``self._cache[node]``. The kernel is rooted in *dummy* + stored in ``_shape_kernel_cache[node]``. The kernel is rooted in *dummy* variables — never the live outer variables — so it can't go stale as the fgraph mutates. Shape requests materialize the kernel against today's ``node.inputs`` (and recursive shape lookups), so @@ -67,7 +67,6 @@ class ShapeFeature(Feature): - ``unaliased_shape_tuple(v, dims=None)`` — like ``shape_tuple`` but breaks aliasing-induced cycles so the result is safe to import into the attached fgraph alongside its inplace destroyers. - - ``tracks_shape(v)`` — does the feature know a shape for ``v``? - ``same_shape(x, y, dim_x=None, dim_y=None)`` — via content-addressed ``shape_key``. """ @@ -75,34 +74,18 @@ class ShapeFeature(Feature): def __init__(self): # node -> (kernel, meta) from _build_kernel, lazily populated - self._cache: dict = {} - # node -> {slot: (dim_kernel, used_roles) | None}, per-dim views of _cache + self._shape_kernel_cache: dict = {} + # node -> {slot: (dim_kernel, used_roles) | None}, per-dim views of _shape_kernel_cache self._dim_kernel_cache: dict = {} # var -> ndim-tuple of (dim_kernel, role_bindings) | None, # installed by on_change_input when new_r's Op has no infer_shape self._overrides: dict = {} # (id(v), i) -> Shape_i(i)(v), ensures Apply identity for leaves self._shape_i_cache: dict = {} - # (id(v), i) -> canonicalized get_shape result, avoids re-materialization - self._materialized: dict = {} + # node -> {(out_idx, i): result}, canonicalized get_shape results + self._materialized_dim_cache: dict = {} self.fgraph: FunctionGraph | None = None - def tracks_shape(self, v) -> bool: - """``True`` iff this feature has shape information for ``v``. - - A var is tracked when its owner has a kernel cached (it was - hit by a ``get_shape`` / ``shape_key`` call), or it carries an - explicit override, or it's a graph input of the attached fgraph. - """ - if v is None or not hasattr(v.type, "ndim"): - return False - if v in self._overrides: - return True - if v.owner is not None: - return v.owner in self._cache - fg = self.fgraph - return fg is not None and v in fg.inputs - def _shape_i_var(self, v, i): key = (id(v), i) cached = self._shape_i_cache.get(key) @@ -177,25 +160,23 @@ def on_attach(self, fgraph): fgraph.shape_feature = self def on_detach(self, fgraph): - self._cache.clear() + self._shape_kernel_cache.clear() self._overrides.clear() self._shape_i_cache.clear() - self._materialized.clear() + self._materialized_dim_cache.clear() self._dim_kernel_cache.clear() self.fgraph = None if hasattr(fgraph, "shape_feature"): del fgraph.shape_feature def on_prune(self, fgraph, node, reason): - self._cache.pop(node, None) + self._shape_kernel_cache.pop(node, None) self._dim_kernel_cache.pop(node, None) - # Drop cached Shape_i variables whose owner is being pruned — without - # this the memo grows monotonically over a long canonicalize pass. + self._materialized_dim_cache.pop(node, None) for out in node.outputs: oid = id(out) - for j in range(getattr(out.type, "ndim", 0) or 0): + for j in range(getattr(out.type, "ndim", 0)): self._shape_i_cache.pop((oid, j), None) - self._materialized.pop((oid, j), None) self._overrides.pop(out, None) def on_change_input(self, fgraph, node, i, r, new_r, reason): @@ -210,6 +191,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason): # ``Shape_i``". if r is new_r or not hasattr(new_r.type, "ndim"): return + self._materialized_dim_cache.pop(node, None) if new_r in self._overrides: return if new_r.owner is None: @@ -246,9 +228,9 @@ def _reroute_dim(self, r, k, new_r_owner_inputs): """ if r.owner is None: return None - if (entry := self._cache.get(r.owner)) is None: + if (entry := self._shape_kernel_cache.get(r.owner)) is None: entry = self._build_kernel(r.owner) - self._cache[r.owner] = entry + self._shape_kernel_cache[r.owner] = entry kernel, meta = entry if kernel is None: return None @@ -294,12 +276,11 @@ def _override_shape_key(self, v, i, entry): new_owner_inputs = v.owner.inputs sv = dim_kernel.outputs[0] if sv.owner is None: - # Passthrough: the dim_kernel output is a kernel input directly. + # Passthrough: kernel output is one of its inputs (no computation). + # Collapse to the source input's shape_key directly so equality + # is transitive across chains of passthrough overrides. if isinstance(sv, Constant): - try: - return ("const", int(sv.data)) - except Exception: - return ("const", id(sv)) + return ("const", int(sv.data)) try: k_idx = dim_kernel.inputs.index(sv) except ValueError: @@ -312,13 +293,11 @@ def _override_shape_key(self, v, i, entry): return (dim_kernel, bindings) def _build_kernel(self, node): + # Phase 1: Deduplicate inputs. # When the same live input appears at multiple positions (e.g. - # ``Elemwise.add(x, x)``), share the dummy clone AND the dummy - # input-shape lscalars between those positions. Ops like Elemwise - # call ``broadcast_shape(*i_shapes)``, which only drops the runtime - # ``Assert`` guard when the incoming shape expressions are - # identical — so identity here is what lets ``x + x`` infer a - # clean shape instead of ``Assert(x.shape[0], ...)``. + # ``add(x, x)``), share the dummy clone and shape scalars so + # ``broadcast_shape`` sees identity-equal shapes and elides + # the runtime Assert. input_slot: dict[int, int] = {} unique_dummies: list[Variable] = [] unique_shape_tuples: list[tuple | None] = [] @@ -345,6 +324,7 @@ def _build_kernel(self, node): dummy_inputs.append(unique_dummies[slot]) dummy_input_shapes.append(unique_shape_tuples[slot]) + # Phase 2: Call infer_shape with dummy node. dummy_outputs = [out.clone() for out in node.outputs] dummy_node = Apply(node.op, dummy_inputs, dummy_outputs) @@ -369,13 +349,10 @@ def _build_kernel(self, node): if output_shapes is None: output_shapes = [None] * len(dummy_outputs) - # Fallback: Shape_i(i)(dummy_output) where the op couldn't provide - # an infer_shape for a given output. Reuse dummy_outputs — no extra - # placeholders. + # Phase 3: Coerce and validate each shape element returned by + # infer_shape. Static type shape overrides infer_shape when known, + # ensuring the canonical constant form. def coerce_shape_el(s, dummy_out): - # Accept any integer scalar Variable verbatim, and any Python / - # NumPy integer scalar as an int64 constant. Floats and - # non-scalar arrays are buggy returns and raise. if isinstance(s, np.ndarray): if s.ndim != 0: raise TypeError( @@ -407,10 +384,9 @@ def coerce_shape_el(s, dummy_out): f"shape element of type {type(s).__name__}: {s!r}" ) - # An output with missing/malformed ``infer_shape`` gets ``None`` - # here, which propagates to ``output_layout[k] = None``. ``get_shape`` - # / ``shape_key`` short-circuit that case to ``_shape_i_var(v, i)`` - # — no kernel slot, no ``fallback_out`` role. + # Outputs with missing/malformed infer_shape get None, which + # propagates to output_layout[k] = None. get_shape / shape_key + # short-circuit to _shape_i_var(v, i) for those. coerced_output_shapes = [] for k, dummy_out in enumerate(dummy_outputs): sh = output_shapes[k] if k < len(output_shapes) else None @@ -431,6 +407,8 @@ def coerce_shape_el(s, dummy_out): coerced.append(coerce_shape_el(s, dummy_out)) coerced_output_shapes.append(tuple(coerced)) + # Phase 4: Flatten per-output shape tuples into a single list. + # layout[k] records how many dims output k contributed (or None). flat_out = [] layout = [] for sh in coerced_output_shapes: @@ -440,16 +418,14 @@ def coerce_shape_el(s, dummy_out): layout.append(len(sh)) flat_out.extend(sh) - # ``meta`` carries only what ``get_shape`` / ``shape_key`` need to - # re-wire the frozen kernel against live ``node.inputs``. meta = {"output_layout": tuple(layout)} if not flat_out: return (None, meta) - # Build kernel_inputs with unique dummies only. Shape slots are - # attached by unique-slot index so duplicate live inputs share the - # same set of kernel-input positions. Each kernel_input needs a - # role that maps back to the live graph at materialization time. + # Phase 5: Build kernel inputs with roles. + # Three role types: input_slot (the dummy tensor itself), + # input_shape_slot (a shape scalar of a dummy), self_out (a dummy + # output referenced by infer_shape, e.g. Scan). kernel_inputs: list[Variable] = [] roles: list[tuple] = [] for slot, dummy in enumerate(unique_dummies): @@ -462,40 +438,16 @@ def coerce_shape_el(s, dummy_out): kernel_inputs.append(s) roles.append(("input_shape_slot", slot, j)) - # Some ``infer_shape`` impls (e.g. Scan) reference ``dummy_node.outputs`` - # directly inside the returned shape expression. Without an explicit - # substitution, ``_materialize_frozen`` would walk into ``dummy_node`` - # and rebuild it via ``make_node`` against live inputs, producing - # fresh-but-equivalent Apply nodes on every call and stalling - # EquilibriumGraphRewriter (``local_track_shape_i``). + # Some infer_shape impls (e.g. Scan) reference dummy_node.outputs + # in the returned expression. Register those as self_out inputs so + # materialization can substitute live outputs. anc_set = set(ancestors(flat_out)) for k, dummy_out in enumerate(dummy_outputs): if dummy_out in anc_set and dummy_out not in kernel_inputs: kernel_inputs.append(dummy_out) roles.append(("self_out", k)) - # Sanity: every free Variable in flat_out should be in kernel_inputs. - # An orphan indicates a buggy ``infer_shape`` that leaked a variable - # outside of ``node.inputs`` / their shape scalars. In development - # mode (config.on_shape_error == "raise") we surface this eagerly - # instead of silently falling back to ``Shape_i``. - kernel_input_set = set(kernel_inputs) - for anc in ancestors(flat_out): - if anc.owner is None: - if isinstance(anc, Constant): - continue - if anc not in kernel_input_set: - msg = ( - f"Op {node.op}.infer_shape leaked an orphan variable " - f"{anc!r} that is not one of node.inputs or their " - f"shape scalars; falling back to Shape_i." - ) - if config.on_shape_error == "raise": - raise ShapeError(msg) - return (None, dict(meta, kernel_build_error=msg)) - - # Find any live input index that maps to this slot, so materialization - # can look up ``node.inputs[]``. + # Phase 6: Map unique slots back to live input indices. slot_to_input_idx: list[int] = [-1] * len(unique_dummies) for inp_idx, inp in enumerate(node.inputs): s = input_slot[id(inp)] @@ -514,64 +466,65 @@ def coerce_shape_el(s, dummy_out): def get_shape(self, v, i): if hasattr(v.type, "shape") and v.type.shape[i] is not None: return constant(v.type.shape[i], dtype="int64") - cache_key = (id(v), i) - if (cached := self._materialized.get(cache_key)) is not None: - return cached + node = v.owner + if node is None: + return self._shape_i_var(v, i) + + node_cache = self._materialized_dim_cache.get(node) + if node_cache is not None: + cached = node_cache.get((v, i)) + if cached is not None: + return cached + else: + node_cache = {} + self._materialized_dim_cache[node] = node_cache + + def _walk(dim_kernel, memo): + for fa in dim_kernel.toposort(): + new_inputs = [memo.get(inp, inp) for inp in fa.inputs] + new_node = fa.op.make_node(*new_inputs) + memo.update(zip(fa.outputs, new_node.outputs, strict=True)) + raw = memo.get(dim_kernel.outputs[0], dim_kernel.outputs[0]) + return self._canonicalize_live_shape(raw) + + def _fallback(): + node_cache[(v, i)] = result = self._shape_i_var(v, i) + return result if (ov := self._overrides.get(v)) is not None: entry = ov[i] if entry is None: - return self._shape_i_var(v, i) + return _fallback() dim_kernel, role_bindings = entry - new_owner_inputs = v.owner.inputs - memo: dict = { - k_input: self.get_shape(new_owner_inputs[idx], dim) + memo = { + k_input: self.get_shape(node.inputs[idx], dim) for k_input, (idx, dim) in zip( dim_kernel.inputs, role_bindings, strict=True ) } - for fa in dim_kernel.toposort(): - new_inputs = [memo.get(inp, inp) for inp in fa.inputs] - new_node = fa.op.make_node(*new_inputs) - memo.update(zip(fa.outputs, new_node.outputs)) - raw = memo.get(dim_kernel.outputs[0], dim_kernel.outputs[0]) - result = self._materialized[cache_key] = self._canonicalize_live_shape(raw) + node_cache[(v, i)] = result = _walk(dim_kernel, memo) return result - node = v.owner - - if node is None: - return self._shape_i_var(v, i) - - if (entry := self._cache.get(node)) is None: - self._cache[node] = entry = self._build_kernel(node) + if (entry := self._shape_kernel_cache.get(node)) is None: + self._shape_kernel_cache[node] = entry = self._build_kernel(node) kernel, meta = entry if kernel is None: - self._materialized[cache_key] = result = self._shape_i_var(v, i) - return result + return _fallback() out_idx = node.outputs.index(v) layout = meta["output_layout"] if layout[out_idx] is None: - self._materialized[cache_key] = result = self._shape_i_var(v, i) - return result + return _fallback() slot = sum((layout[k] or 0) for k in range(out_idx)) + i dk = self._dim_kernel(node, slot) if dk is None: - self._materialized[cache_key] = result = self._shape_i_var(v, i) - return result + return _fallback() dim_kernel, used_roles = dk - # Seed memo with the live binding for each used kernel input, - # then walk the kernel's cached topological order rebuilding - # each ``FrozenApply`` against live ``make_node`` calls. Fresh - # ``make_node`` (rather than ``graph_replace``) is required — - # the latter would mutate the globally-interned ``FrozenApply`` - # nodes via ``Apply.clone_with_new_inputs``. slot_to_input_idx = meta["slot_to_input_idx"] - memo: dict = {} + memo = {} for k_input, role in zip(dim_kernel.inputs, used_roles, strict=True): tag = role[0] if tag == "input_slot": @@ -580,15 +533,9 @@ def get_shape(self, v, i): memo[k_input] = self.get_shape( node.inputs[slot_to_input_idx[role[1]]], role[2] ) - else: # self_out + else: memo[k_input] = node.outputs[role[1]] - for fa in dim_kernel.toposort(): - new_inputs = [memo.get(inp, inp) for inp in fa.inputs] - new_node = fa.op.make_node(*new_inputs) - memo.update(zip(fa.outputs, new_node.outputs, strict=True)) - raw = memo.get(dim_kernel.outputs[0], dim_kernel.outputs[0]) - result = self._canonicalize_live_shape(raw) - self._materialized[cache_key] = result + node_cache[(v, i)] = result = _walk(dim_kernel, memo) return result def unaliased_shape_tuple(self, v, dims=None): @@ -653,18 +600,18 @@ def _dim_kernel(self, node, slot): ``shape_key`` collapse the structural comparison to one hash and only descend into inputs that are themselves shape lookups. """ - per_node = self._dim_kernel_cache.get(node) - if per_node is None: - per_node = {} - self._dim_kernel_cache[node] = per_node - if slot in per_node: - return per_node[slot] - if (entry := self._cache.get(node)) is None: + node_cache = self._dim_kernel_cache.get(node) + if node_cache is None: + node_cache = {} + self._dim_kernel_cache[node] = node_cache + if slot in node_cache: + return node_cache[slot] + if (entry := self._shape_kernel_cache.get(node)) is None: entry = self._build_kernel(node) - self._cache[node] = entry + self._shape_kernel_cache[node] = entry kernel, meta = entry if kernel is None: - per_node[slot] = None + node_cache[slot] = None return None sv = kernel.outputs[slot] kernel_input_set = set(kernel.inputs) @@ -677,10 +624,10 @@ def _dim_kernel(self, node, slot): try: dim_kernel = FrozenFunctionGraph(used_inputs, [sv]) except Exception: - per_node[slot] = None + node_cache[slot] = None return None result = (dim_kernel, used_roles) - per_node[slot] = result + node_cache[slot] = result return result def shape_key(self, v, i): @@ -737,9 +684,9 @@ def shape_key(self, v, i): node = v.owner if node is None: return ("leaf", id(v), i) - if (entry := self._cache.get(node)) is None: + if (entry := self._shape_kernel_cache.get(node)) is None: entry = self._build_kernel(node) - self._cache[node] = entry + self._shape_kernel_cache[node] = entry kernel, meta = entry if kernel is None: return ("leaf", id(v), i) @@ -768,10 +715,7 @@ def bind(role): # key matches the underlying live var's own shape_key. if sv.owner is None: if isinstance(sv, Constant): - try: - return ("const", int(sv.data)) - except Exception: - return ("const", id(sv)) + return ("const", int(sv.data)) try: k_idx = kernel.inputs.index(sv) except ValueError: