Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def perform(self, node, inputs, output_storage):
def pullback(self, inputs, outputs, output_gradients):
return [disconnected_type(), *output_gradients]

def infer_shape(self, fgraph, inputs, input_shapes):
def infer_shape(self, inputs, input_shapes):
# Return the shape of every input but the condition (first input)
return input_shapes[1:]

Expand Down
81 changes: 32 additions & 49 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import Callable, Sequence
from copy import copy
from functools import partial
from itertools import chain
from typing import cast

from pytensor.compile.maker import function
Expand All @@ -23,70 +22,54 @@
from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.traversal import graph_inputs
from pytensor.graph.utils import MissingInputError


def infer_shape(outs, inputs, input_shapes):
"""
Compute the shape of the outputs given the shape of the inputs of an PyTensor
graph.

We do it this way to avoid compiling the inner function just to get
the shape. Changes to ShapeFeature could require changes in this function.
"""Compute the shape of ``outs`` given the shape of ``inputs``.

Builds per-Apply shape kernels via ``ShapeFeature`` and then
rebinds each inner-input leaf — surfaced as ``Shape_i(j)(inp)`` in
the materialized exprs — to the caller-supplied outer dim. No
compile of the inner function required.
"""
# We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually

# TODO: ShapeFeature should live elsewhere
from pytensor.tensor.rewriting.shape import ShapeFeature

for inp, inp_shp in zip(inputs, input_shapes, strict=True):
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim
raise ValueError(
f"input {inp} has {inp.type.ndim} dims, got shape with {len(inp_shp)}"
)

shape_feature = ShapeFeature()
fgraph = FunctionGraph([], [], features=[shape_feature])
for v in chain.from_iterable(s for s in input_shapes if s is not None):
# Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before
if (node := v.owner) is not None:
fgraph.import_node(node, import_missing=True)
feature = ShapeFeature()
out_shapes = [feature.shape_tuple(o) for o in outs]

# Initialize shape_of with the input shapes
# ``feature.get_shape(inp, j)`` is the same memoized instance that
# appears at the leaves of ``out_shapes`` — ``Shape_i(j)(inp)`` for
# unknown dims, ``Constant`` for static dims. Rebind the Shape_i
# leaves to the caller-supplied scalars; static-dim Constants are
# skipped (no owner) so the static type wins, matching prior behavior.
replacements = {}
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
shape_feature.set_shape(inp, inp_shp, override=True)

def local_traverse(out):
"""
Go back in the graph, from out, adding computable shapes to shape_of.

"""
if out in shape_feature.shape_of:
# Its shape is already known
return
elif out.owner is None:
# This is an input of the graph
shape_feature.init_r(out)
else:
# Recurse over inputs
for inp in out.owner.inputs:
if inp not in shape_feature.shape_of:
local_traverse(inp)
if inp_shp is None or not hasattr(inp.type, "ndim"):
continue
for j in range(inp.type.ndim):
leaf = feature.get_shape(inp, j)
if leaf.owner is not None:
replacements[leaf] = inp_shp[j]

# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
if not replacements:
return out_shapes

ret = []
for o in outs:
local_traverse(o)
ret.append(shape_feature.shape_of[o])
return ret
# ``strict=False``: an inner input may not be reachable from every
# output, so its leaf won't appear in every shape expression.
return [
None if s is None else tuple(graph_replace(list(s), replacements, strict=False))
for s in out_shapes
]


def construct_nominal_fgraph(
Expand Down Expand Up @@ -884,7 +867,7 @@ def connection_pattern(self, node):
self._connection_pattern = ret
return ret

def infer_shape(self, fgraph, node, shapes):
def infer_shape(self, node, shapes):
# TODO: Use `fgraph.shape_feature` to do this instead.
out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes)

Expand Down
10 changes: 5 additions & 5 deletions pytensor/compile/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ViewOp(TypeCastingOp):
def make_node(self, x):
return Apply(self, [x], [x.type()])

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return input_shapes

def pullback(self, args, outputs, g_outs):
Expand Down Expand Up @@ -179,7 +179,7 @@ def c_code(self, node, name, inames, onames, sub):
# Else, no C code
raise NotImplementedError()

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return input_shapes


Expand Down Expand Up @@ -251,8 +251,8 @@ def __reduce__(self):
)
return load_back, (mod, name)

def _infer_shape(self, fgraph, node, input_shapes):
return self.__infer_shape(fgraph, node, input_shapes)
def _infer_shape(self, node, input_shapes):
return self.__infer_shape(node, input_shapes)


def as_op(itypes, otypes, infer_shape=None):
Expand All @@ -275,7 +275,7 @@ def wrap_py(itypes, otypes, infer_shape=None):
It takes an optional infer_shape parameter that should be a callable with
this signature:

def infer_shape(fgraph, node, input_shapes):
def infer_shape(node, input_shapes):
...
return output_shapes

Expand Down
108 changes: 107 additions & 1 deletion pytensor/graph/replace.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
from functools import singledispatch
from typing import cast, overload

Expand All @@ -11,6 +11,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.traversal import (
general_toposort,
toposort,
truncated_graph_inputs,
)
Expand Down Expand Up @@ -212,6 +213,111 @@ def toposort_key(
return fg.outputs[0]


def break_aliasing_cycles(
outputs: Sequence[Variable],
destroyers_of: Callable[[Variable], Collection[Apply]],
) -> list[Variable]:
"""Break aliasing-induced ordering cycles in ``outputs``.

An inplace Op ``D`` overwrites one of its inputs ``x`` in place, so
``D``'s output ``y`` aliases ``x``'s storage. Any client that reads
the pre-overwrite ``x`` must therefore run *before* ``D``, and any
client that reads ``y`` must run *after*. A cycle arises when a
single Apply ``A`` does both — reads ``x`` directly *and* has another
input that (directly or transitively) depends on ``y``. ``A`` would
have to run before ``D`` and after it. No valid schedule exists.

This function finds every such ``A`` in ``outputs``' ancestry and
re-routes ``x`` *on that one Apply only* through ``deep_copy_op``.
``A`` then reads the copy instead of the aliased original, lifting
the "before" constraint. ``D`` keeps reading ``x`` directly; the
rest of the graph is untouched.

Multiple outputs share one topological pass; an Apply reachable from
more than one output is analyzed once, and an aliased value patched
across outputs gets a single shared ``deep_copy_op`` wrapper. Returns
``outputs`` unchanged when no Apply exhibits the pattern.

Parameters
----------
outputs
Roots of the sub-graph to scan.
destroyers_of
Callable returning the Apply nodes that overwrite a given
Variable in place (empty when none). Typically
``fgraph.destroyers`` from a ``FunctionGraph`` with an attached
``DestroyHandler``, but this function makes no assumption about
provenance — the caller is responsible for the lookup's
meaningfulness, and for skipping the call when there are no
inplace ops in the first place (the ancestry is walked
unconditionally).
"""
from pytensor.compile.ops import deep_copy_op

deps: dict[Variable, frozenset[Variable]] = {}
substitutes: dict[Variable, Variable] = {}
replacements: dict[Variable, Variable] = {}
# ``general_toposort`` guarantees inputs are visited before consumers,
# so ``deps`` for every input is final by the time we look at an Apply.
for v in general_toposort(
outputs, lambda v: v.owner.inputs if v.owner is not None else []
):
if v.owner is None:
deps[v] = frozenset()
continue
node = v.owner

# Accumulate this Variable's destroyer-output reach: union of the
# parents' reaches, plus any parent that is itself an output of an
# inplace Apply.
d: set[Variable] = set()
for inp in node.inputs:
d |= deps[inp]
if inp.owner is not None and inp.owner.op.destroy_map:
d.add(inp)
deps[v] = frozenset(d)

if node.op.destroy_map:
# Inplace Apply — preserve as-is; never enters ``replacements``
# so ``graph_replace`` leaves it alone.
continue

# Cycle-pattern check per destroyed input on ``node``: a destroyed
# input ``i`` triggers the pattern iff some *other* input has the
# destroyer's output in its reach.
new_inputs = list(node.inputs)
changed = False
for i, inp in enumerate(node.inputs):
inp_destroyers = destroyers_of(inp)
if not inp_destroyers:
continue
other_deps: set[Variable] = set()
for j, other_inp in enumerate(node.inputs):
if j == i:
continue
other_deps |= deps[other_inp]
if other_inp.owner is not None and other_inp.owner.op.destroy_map:
other_deps.add(other_inp)
if any(
out in other_deps for c_app in inp_destroyers for out in c_app.outputs
):
if inp not in substitutes:
substitutes[inp] = cast(Variable, deep_copy_op(inp))
new_inputs[i] = substitutes[inp]
changed = True
if changed:
new_node = node.op.make_node(*new_inputs)
replacements.update(zip(node.outputs, new_node.outputs, strict=True))

if not replacements:
return list(outputs)

# ``graph_replace`` walks each output, substitutes any matched Apply
# outputs with the patched version, and rebuilds whatever's downstream
# — composing stacked patches automatically.
return graph_replace(list(outputs), replace=replacements)


@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]:
# Default implementation is provided in pytensor.tensor.blockwise
Expand Down
2 changes: 1 addition & 1 deletion pytensor/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __str__(self):
args.append("inplace")
return f"if{{{','.join(args)}}}"

def infer_shape(self, fgraph, node, inputs_shapes):
def infer_shape(self, node, inputs_shapes):
# By construction, corresponding then/else pairs have the same number
# of dimensions

Expand Down
2 changes: 1 addition & 1 deletion pytensor/raise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def c_code(self, node, name, inames, onames, props):
def c_code_cache_version(self):
return (2,)

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return [input_shapes[0]]

def do_constant_folding(self, fgraph, node):
Expand Down
2 changes: 1 addition & 1 deletion pytensor/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,7 +2251,7 @@ def perform(self, node, inputs, output_storage):
self.t_call = t_call
self.t_fn = t_fn

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
# input_shapes correspond to the shapes of node.inputs
for inp, inp_shp in zip(node.inputs, input_shapes, strict=True):
assert inp_shp is None or len(inp_shp) == inp.type.ndim
Expand Down
33 changes: 17 additions & 16 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from pytensor.tensor.shape import shape
from pytensor.tensor.shape import Shape_i, shape
from pytensor.tensor.subtensor import (
IncSubtensor,
Subtensor,
Expand Down Expand Up @@ -1358,6 +1358,15 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
return not broadcasted_by(init_value_.squeeze(0), init_buffer[0])


def _inferred_shape_or_fallback(shape_feature, v, i, fallback):
"""Return ``shape_feature.get_shape(v, i)`` if it's better than ``Shape_i(v)``, else *fallback*."""
if shape_feature is not None:
s = shape_feature.get_shape(v, i)
if not (s.owner and isinstance(s.owner.op, Shape_i) and s.owner.inputs[0] is v):
return s
return fallback


def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
r"""Graph optimizer that reduces scan memory consumption.

Expand Down Expand Up @@ -1405,13 +1414,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
position in the outer circular buffer. This would invalidate results,
if the input is still needed for some other output computation.
"""
if hasattr(fgraph, "shape_feature"):
shape_of = fgraph.shape_feature.shape_of
else:
# Each access to shape_of is in a try..except block in order to
# use a default version when the variable is not in the shape_of
# dictionary.
shape_of = {}
shape_feature = getattr(fgraph, "shape_feature", None)
# 1. Initialization of variables
# Note 1) We do not actually care about outputs representing shared
# variables (those have no intermediate values) so it is safer to
Expand Down Expand Up @@ -1503,15 +1506,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:

# 2.3.2 extract the begin/end of the first dimension
if i >= op_info.n_mit_mot:
try:
length = shape_of[out][0]
except KeyError:
length = node.inputs[0] + init_l[i]
length = _inferred_shape_or_fallback(
shape_feature, out, 0, node.inputs[0] + init_l[i]
)
else:
try:
length = shape_of[out][0]
except KeyError:
length = out.shape[0]
length = _inferred_shape_or_fallback(
shape_feature, out, 0, out.shape[0]
)
cf_slice = get_canonical_form_slice(this_slice[0], length)
slices[i] += [(cf_slice, this_slice)] # type: ignore

Expand Down
Loading
Loading