@register_cadence_pass(CadencePassAttribute(opt_level=1))
class MoveSliceBeforeViewPass(RemoveOrReplacePassInterface):
    """Move a slice_copy above a view_copy when the slice is re-expressible as a
    single slice on one dim of the pre-view tensor.

    Rewrites view(x) -> slice(dim=d, start, end, step) into
    slice(x, dim=d', start', end', step') -> view(sliced, slice_out_shape), so the
    slice lands directly on x. This may be useful in attention patterns, where
    we view the output of a large linear into a new shape in which the number of
    attention heads is the last dim, and we need to run independent computation
    per head. Moving the slice before the view can allow us to then directly
    slice the constant linear weights.

    A view is a contiguous reshape: it never moves or reorders elements, it only
    re-groups the shared row-major index space into different dims. A slice keeps
    an arithmetic progression of indices (start, start+step, ...) along one viewed
    dim, and that progression collapses back to a *single* slice on one pre-view
    dim exactly when the row-major strides line up. ``_derive_pre_view_slice``
    handles the three cases that qualify:

    * untouched: the viewed dim is left unchanged by the view -- same size and
      same inner stride as some pre-view dim -- so the slice copies over
      verbatim (any step).
    * contiguous: the viewed dim and a pre-view dim span the same flat extent
      (a split's outermost factor, or a merge that aligns), so a contiguous
      (step==1) slice maps to a contiguous pre-view slice.
    * strided: the viewed dim is an innermost factor of a pre-view dim
      (identical inner stride) selected width-1, so it maps to a strided
      pre-view slice with step == the viewed dim's size.

    Everything else -- middle factors, wider strided selections -- is
    block-strided (runs separated by gaps), which no single slice can express,
    so it is left unchanged.

    Each slice is handled independently, so a view that fans out to several
    slices is rewritten one slice at a time; there is no single-user requirement
    on the view. This pass only redirects the users of the slice -- the
    now-unused view is expected to be cleaned up by dead-code elimination.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        """Ops this pass visits: only ``slice_copy.Tensor`` nodes."""
        return [exir_ops.edge.aten.slice_copy.Tensor]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Rewrite ``slice(view(x))`` into ``view(slice(x))`` when possible.

        Args:
            node: a ``slice_copy.Tensor`` node.

        Returns:
            True if new slice/view nodes were inserted and every use of ``node``
            was redirected to the new view; False if the graph was left alone.
        """
        view_node = get_arg(node, "input", torch.fx.Node)
        if view_node.target != exir_ops.edge.aten.view_copy.default:
            return False

        x_node = get_arg(view_node, "input", torch.fx.Node)
        pre_view_shape = tuple(x_node.meta["val"].shape)
        post_view_shape = tuple(view_node.meta["val"].shape)
        # Empty tensors are skipped: a zero-sized dim produces zero row-major
        # strides, which the modulo/division in _derive_pre_view_slice cannot
        # handle.
        if 0 in pre_view_shape or 0 in post_view_shape:
            return False

        dim = get_arg(node, "dim", int)
        if dim < 0:
            dim += len(post_view_shape)
        post_view_size = post_view_shape[dim]

        bounds = self._normalize_slice(node, post_view_size)
        if bounds is None:
            return False
        start, stop, step = bounds

        # The slice's own output shape gives the selected-element count along the
        # sliced dim directly -- it is exactly output_shape[dim].
        slice_out_shape = tuple(node.meta["val"].shape)
        post_view_count = slice_out_shape[dim]
        if post_view_count == 0:
            return False

        # Row-major stride of the sliced viewed dim, and of every pre-view dim.
        post_view_stride = prod(post_view_shape[dim + 1 :])
        pre_view_strides = self._row_major_strides(pre_view_shape)

        derived = self._derive_pre_view_slice(
            pre_view_shape,
            pre_view_strides,
            post_view_stride,
            post_view_size,
            start,
            stop,
            step,
            post_view_count,
        )
        if derived is None:
            return False
        pre_view_dim, pre_view_start, pre_view_stop, pre_view_step = derived

        graph = node.graph
        with graph.inserting_before(node):
            new_slice_args = (
                x_node,
                pre_view_dim,
                pre_view_start,
                pre_view_stop,
                pre_view_step,
            )
            new_slice = graph.create_node(
                "call_function",
                exir_ops.edge.aten.slice_copy.Tensor,
                args=new_slice_args,
            )
            # Recompute the metadata value by running the op on the input's
            # metadata value, so downstream shape queries see the sliced shape.
            new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
                x_node.meta["val"], *new_slice_args[1:]
            )
            # The final view restores exactly the shape the original slice
            # produced, so users of `node` observe no shape change.
            new_view = graph.create_node(
                "call_function",
                exir_ops.edge.aten.view_copy.default,
                args=(new_slice, list(slice_out_shape)),
            )
            new_view.meta["val"] = exir_ops.edge.aten.view_copy.default(
                new_slice.meta["val"], list(slice_out_shape)
            )

        node.replace_all_uses_with(new_view)
        return True

    @staticmethod
    def _row_major_strides(shape: tuple[int, ...]) -> list[int]:
        """Row-major (contiguous) strides for ``shape``.

        The innermost dim has stride 1; each outer dim's stride is the product
        of all sizes to its right. An empty shape yields an empty list.
        """
        strides = [1] * len(shape)
        acc = 1
        for i in range(len(shape) - 1, -1, -1):
            strides[i] = acc
            acc *= shape[i]
        return strides

    def _normalize_slice(
        self, node: torch.fx.Node, post_view_size: int
    ) -> Optional[tuple[int, int, int]]:
        """Resolve the slice to concrete, clamped ``(start, stop, step)`` ints.

        Args:
            node: the ``slice_copy`` node whose ``start``/``end``/``step``
                arguments are read.
            post_view_size: size of the sliced dim; used to resolve ``None`` and
                negative bounds and to clamp both bounds into ``[0, size]``.

        Returns:
            ``(start, stop, step)``, or None if any bound is not a plain int
            (e.g. symbolic) or the step is non-positive -- neither of which this
            pass handles.
        """
        step = get_arg(node, "step")

        if not isinstance(step, int):
            return None

        if step <= 0:
            return None

        raw_start = get_arg(node, "start")
        raw_stop = get_arg(node, "end")

        # Make sure raw_start/raw_stop are not symbolic.
        if (raw_start is not None and not isinstance(raw_start, int)) or (
            raw_stop is not None and not isinstance(raw_stop, int)
        ):
            return None

        # None means "from the beginning" / "to the end" of the dim.
        start = 0 if raw_start is None else raw_start
        stop = post_view_size if raw_stop is None else raw_stop
        # Negative bounds count from the end of the dim.
        if start < 0:
            start += post_view_size
        if stop < 0:
            stop += post_view_size
        start = max(0, min(start, post_view_size))
        stop = max(0, min(stop, post_view_size))
        return start, stop, step

    def _derive_pre_view_slice(
        self,
        pre_view_shape: tuple[int, ...],
        pre_view_strides: list[int],
        post_view_stride: int,
        post_view_size: int,
        start: int,
        stop: int,
        step: int,
        post_view_count: int,
    ) -> Optional[tuple[int, int, int, int]]:
        """Return ``(dim, start, stop, step)`` for the single pre-view-tensor slice
        equivalent to slicing the viewed dim, or None if no single pre-view slice
        reproduces it.

        Both shapes index the same row-major flat space, so the sliced viewed dim
        (size ``post_view_size``, inner stride ``post_view_stride``) lines up with
        one pre-view dim (size ``pre_view_size``, inner stride ``pre_view_stride``)
        in one of three ways. Pre-view dims are scanned outermost-first and the
        first match wins.

        Args:
            pre_view_shape: shape of the tensor feeding the view.
            pre_view_strides: row-major strides of ``pre_view_shape``; every
                entry must be non-zero (the caller rejects zero-sized shapes).
            post_view_stride: row-major stride of the sliced dim in the viewed
                shape.
            post_view_size: size of the sliced dim in the viewed shape.
            start, stop, step: normalized slice bounds on the viewed dim.
            post_view_count: number of elements the slice keeps along that dim.
        """
        for pre_view_dim, (pre_view_stride, pre_view_size) in enumerate(
            zip(pre_view_strides, pre_view_shape)
        ):
            # Untouched: the viewed dim is identical to this pre-view dim (same
            # size and same inner stride), so the slice applies verbatim, any step.
            if pre_view_stride == post_view_stride and pre_view_size == post_view_size:
                return pre_view_dim, start, stop, step

            # Contiguous: the viewed dim and this pre-view dim span the same flat
            # extent (same period), and the selected band aligns to this dim's
            # boundaries. A contiguous (step==1) viewed slice
            # [start, start+post_view_count) is the flat band
            # [start*post_view_stride, (start+post_view_count)*post_view_stride),
            # a contiguous slice on this pre-view dim iff both ends are multiples
            # of its stride.
            if (
                step == 1
                and post_view_size * post_view_stride == pre_view_size * pre_view_stride
            ):
                flat_start = start * post_view_stride
                flat_stop = (start + post_view_count) * post_view_stride
                if (
                    flat_start % pre_view_stride == 0
                    and flat_stop % pre_view_stride == 0
                ):
                    return (
                        pre_view_dim,
                        flat_start // pre_view_stride,
                        flat_stop // pre_view_stride,
                        1,
                    )

            # Strided is the ONLY way the reshape itself introduces a stride, and
            # it requires a width-1 selection (post_view_count == 1): the viewed
            # dim is an innermost factor of this pre-view dim (identical inner
            # stride), so fixing that single factor index and letting the rest of
            # the pre-view dim run yields a uniform stride equal to the viewed
            # dim's size. Any wider selection (post_view_count > 1) of an inner
            # factor leaves runs separated by gaps -- block-strided, not a single
            # slice -- so width-1 is required.
            if (
                post_view_count == 1
                and post_view_size > 1
                and pre_view_stride == post_view_stride
                and pre_view_size % post_view_size == 0
            ):
                pre_view_count = pre_view_size // post_view_size
                # Tight exclusive bound: one past the last selected index.
                pre_view_stop = start + (pre_view_count - 1) * post_view_size + 1
                return pre_view_dim, start, pre_view_stop, post_view_size
        return None
class TestMoveSliceBeforeViewPass(unittest.TestCase):
    """Tests for MoveSliceBeforeViewPass on a (1, 16, 256) input tensor."""

    @staticmethod
    def _shapes_by_target(
        gm: torch.fx.GraphModule, target: object
    ) -> list[tuple[int, ...]]:
        """Collect, in graph order, the output shape of each node whose target
        equals ``target``."""
        shapes = []
        for candidate in gm.graph.nodes:
            if candidate.target == target:
                shapes.append(tuple(candidate.meta["val"].shape))
        return shapes

    def _assert_slice_and_view_shapes(
        self,
        gm: torch.fx.GraphModule,
        slice_shapes: list[tuple[int, ...]],
        view_shapes: list[tuple[int, ...]],
    ) -> None:
        """Check the slice_copy output shapes, then the view_copy output shapes."""
        expectations = (
            (exir_ops.edge.aten.slice_copy.Tensor, slice_shapes),
            (exir_ops.edge.aten.view_copy.default, view_shapes),
        )
        for target, expected_shapes in expectations:
            self.assertEqual(self._shapes_by_target(gm, target), expected_shapes)

    @staticmethod
    def _view_then_slice_graph(
        view_shape: list[int], *slice_specs: tuple[int, int, int, int]
    ) -> torch.fx.GraphModule:
        """Build x -> view_copy(view_shape) -> one slice_copy per
        ``(dim, start, end, step)`` spec, with all slices as graph outputs."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(1, 16, 256))
        viewed = builder.call_operator(
            op=exir_ops.edge.aten.view_copy.default,
            args=(x, view_shape),
        )
        slices = [
            builder.call_operator(
                op=exir_ops.edge.aten.slice_copy.Tensor,
                args=(viewed, *spec),
            )
            for spec in slice_specs
        ]
        builder.output(slices)
        return builder.get_graph_module()

    def _apply_and_expect_move(
        self, original: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Run the pass with a numerics check, require that it modified the
        graph, and return the transformed graph module."""
        result = transform_and_check_numerics(
            original,
            (torch.randn(1, 16, 256),),
            MoveSliceBeforeViewPass(),
        )
        self.assertTrue(result.modified)
        return result.graph_module

    def _apply_and_expect_no_change(self, original: torch.fx.GraphModule) -> None:
        """Run the pass and require that it reports no modification."""
        result = cast(PassResult, MoveSliceBeforeViewPass()(original))
        self.assertFalse(result.modified)

    def _assert_one_slice_then_view(
        self,
        gm: torch.fx.GraphModule,
        slice_shape: tuple[int, ...],
        view_shape: tuple[int, ...],
    ) -> None:
        """Require exactly two compute nodes -- slice_copy followed by
        view_copy -- with the given output shapes."""
        nodes = get_compute_nodes_in_gm(gm)
        self.assertEqual(len(nodes), 2)
        self.assertEqual(nodes[0], exir_ops.edge.aten.slice_copy)
        self.assertEqual(nodes[1], exir_ops.edge.aten.view_copy)
        self._assert_slice_and_view_shapes(gm, [slice_shape], [view_shape])

    def test_strided_innermost_move(self) -> None:
        """Splitting the last dim into [...,2] then slicing that 2 (innermost,
        width 1) becomes a strided slice on the pre-view tensor, then view."""
        original = self._view_then_slice_graph([1, 16, 4, 32, 2], (4, 0, 1, 1))
        transformed = self._apply_and_expect_move(original)
        # The strided pre-view slice [0:255:2] keeps 128 of the 256 elements,
        # then the view restores the (sliced) viewed shape.
        self._assert_one_slice_then_view(
            transformed, (1, 16, 128), (1, 16, 4, 32, 1)
        )

    def test_fanout_both_slices_move(self) -> None:
        """A view that fans out to even/odd slices: each is pushed before the
        view independently and the now-dead shared view is removed."""
        original = self._view_then_slice_graph(
            [1, 16, 4, 32, 2],
            (4, 0, 1, 1),  # even
            (4, 1, 2, 1),  # odd
        )
        transformed = self._apply_and_expect_move(original)

        nodes = get_compute_nodes_in_gm(transformed)
        slice_total = len([n for n in nodes if n == exir_ops.edge.aten.slice_copy])
        view_total = len([n for n in nodes if n == exir_ops.edge.aten.view_copy])
        self.assertEqual(slice_total, 2)
        self.assertEqual(view_total, 2)
        # Each fanned-out slice keeps half the elements and gets its own view.
        self._assert_slice_and_view_shapes(
            transformed,
            [(1, 16, 128), (1, 16, 128)],
            [(1, 16, 4, 32, 1), (1, 16, 4, 32, 1)],
        )

    def test_contiguous_outermost_move(self) -> None:
        """Slicing the outermost factor of a split dim → contiguous pre-view
        slice, then view."""
        original = self._view_then_slice_graph([1, 16, 2, 128], (2, 0, 1, 1))
        transformed = self._apply_and_expect_move(original)
        self._assert_one_slice_then_view(transformed, (1, 16, 128), (1, 16, 1, 128))

    def test_contiguous_outer_factor_width_two_move(self) -> None:
        """Slicing the first two of the outermost factor (size 4) is still a
        contiguous pre-view slice [0:128], then view → (1,16,2,32,2)."""
        original = self._view_then_slice_graph([1, 16, 4, 32, 2], (2, 0, 2, 1))
        transformed = self._apply_and_expect_move(original)
        self._assert_one_slice_then_view(
            transformed, (1, 16, 128), (1, 16, 2, 32, 2)
        )

    def test_strided_outer_factor_not_moved(self) -> None:
        """A strided (step>1, width>1) selection of the outermost factor is a
        block-strided pattern, not a single pre-view slice → left unchanged."""
        original = self._view_then_slice_graph([1, 16, 4, 32, 2], (2, 0, 4, 2))
        self._apply_and_expect_no_change(original)

    def test_block_strided_not_moved(self) -> None:
        """Slicing a middle factor yields a block-strided selection that is not a
        single pre-view slice → left unchanged."""
        original = self._view_then_slice_graph([1, 16, 4, 2, 32], (3, 0, 1, 1))
        self._apply_and_expect_no_change(original)

    def test_non_view_input_no_change(self) -> None:
        """A slice whose input is not a view is left unchanged."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(1, 16, 256))
        sliced = builder.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(x, 2, 0, 128, 1),
        )
        builder.output([sliced])
        self._apply_and_expect_no_change(builder.get_graph_module())
TestPropagateSlice(unittest.TestCase): def test_swap_quantize_slice(self) -> None: builder = GraphBuilder()