From 2d228fcceebd7f01a86b7a4cc464dbcca0e08835 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Fri, 24 Apr 2026 17:21:46 -0700
Subject: [PATCH] Reorder slice before permute

Summary:
Move slice_copy ops before permute_copy to reduce permute data volume
when profitable. Only transforms single-user permutes.

Cost model: dim-0 slices are nop-eligible after
MakeSliceAndCatDimOutermostPass. Moving such a slice loses the nop, so
we only move it when the slice removes more than half the data (permute
savings outweigh the nop loss). Non-dim-0 slices have no nop
opportunity, so any permute savings is pure win. Negative slice dims are
normalized before the cost check so that dim == -rank is treated as dim 0.

Not added to CadenceReorderOpsInGraph to avoid interaction with
MakeSliceAndCatDimOutermostPass in the default pipeline.

Differential Revision: D102426699
---
 backends/cadence/aot/reorder_ops.py           |  97 +++++++++-
 .../aot/tests/test_reorder_ops_passes.py      | 163 ++++++++++++++++++
 2 files changed, 259 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py
index e14471bc7ed..1b6ec19be06 100644
--- a/backends/cadence/aot/reorder_ops.py
+++ b/backends/cadence/aot/reorder_ops.py
@@ -11,13 +11,14 @@
 
 from collections import defaultdict
 from math import prod
-from typing import DefaultDict, List, Tuple
+from typing import cast, DefaultDict, List, Tuple
 
 import torch
 import torch.fx
 from executorch.backends.cadence.aot.compiler_utils import get_placeholders, get_shape
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
+    get_arg,
     get_overload_packet,
     register_cadence_pass,
     RemoveOrReplacePassInterface,
@@ -641,6 +642,100 @@ class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(
     pass
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class MoveSliceBeforePermutePass(RemoveOrReplacePassInterface):
+    """Move slice_copy ops before permute_copy to reduce permute data volume.
+
+    Rewrites permute(input, perm) -> slice(dim=D) into
+    slice(input, dim=perm[D]) -> permute(sliced, perm), so the permute
+    operates on a smaller tensor.
+
+    Cost model: dim-0 slices are nop-eligible (zero-copy pointer offset
+    after MakeSliceAndCatDimOutermostPass). Moving such a slice loses the
+    nop, so we only move it when the permute savings outweigh the nop loss,
+    i.e. when the slice removes more than half the data (full > 2 * sliced).
+    Non-dim-0 slices have no nop opportunity, so any permute savings is
+    pure win.
+    """
+
+    STRIDED_SLICE_COST_FACTOR: int = 2
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.permute_copy.default]
+
+    @staticmethod
+    def _is_profitable(
+        slice_dim: int, full_shape: torch.Size, sliced_shape: torch.Size
+    ) -> bool:
+        """Return True if slicing before the permute is expected to pay off.
+
+        ``slice_dim`` must be non-negative; ``full_shape`` and ``sliced_shape``
+        are the permute output shape and the slice output shape.
+        """
+        full_size = prod(full_shape)
+        sliced_size = prod(sliced_shape)
+        if slice_dim == 0:
+            return (
+                full_size
+                > MoveSliceBeforePermutePass.STRIDED_SLICE_COST_FACTOR * sliced_size
+            )
+        return True
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        """Rewrite ``permute -> slice`` into ``slice -> permute`` when profitable.
+
+        Returns True if the graph was modified, False otherwise.
+        """
+        perm = cast(list[int], node.args[1])
+        permute_input = node.args[0]
+        assert isinstance(permute_input, torch.fx.Node)
+
+        if len(node.users) != 1:
+            return False
+
+        slice_node = next(iter(node.users))
+        if slice_node.target != exir_ops.edge.aten.slice_copy.Tensor:
+            return False
+
+        slice_dim = get_arg(slice_node, "dim", int)
+        # Canonicalize a negative dim so the dim-0 cost check below also
+        # catches dim == -rank, which addresses the outermost dimension.
+        if slice_dim < 0:
+            slice_dim += len(perm)
+
+        if not self._is_profitable(
+            slice_dim,
+            node.meta["val"].shape,
+            slice_node.meta["val"].shape,
+        ):
+            return False
+
+        new_dim = perm[slice_dim]
+        graph = node.graph
+
+        with graph.inserting_before(node):
+            new_slice = graph.create_node(
+                "call_function",
+                exir_ops.edge.aten.slice_copy.Tensor,
+                args=(
+                    permute_input,
+                    new_dim,
+                    get_arg(slice_node, "start"),
+                    get_arg(slice_node, "end"),
+                    get_arg(slice_node, "step", int),
+                ),
+            )
+            new_permute = graph.create_node(
+                "call_function",
+                exir_ops.edge.aten.permute_copy.default,
+                args=(new_slice, perm),
+            )
+
+        slice_node.replace_all_uses_with(new_permute)
+        return True
+
+
 # The following class consolidates functions to reoder ops (i.e., either hoist
 # or sink some ops in the graph).
 class CadenceReorderOpsInGraph:
diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py
index 4aa7f46c8a1..46b81f98115 100644
--- a/backends/cadence/aot/tests/test_reorder_ops_passes.py
+++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py
@@ -23,6 +23,7 @@
 from executorch.backends.cadence.aot.reorder_ops import (
     AdvanceQuantizeOpAboveDefChainPass,
     AdvanceQuantizeOpAboveDefInBranchPass,
+    MoveSliceBeforePermutePass,
     PostponeDequantizeOpBelowUseChainPass,
     PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
     SinkOpsCloserToUsePass,
@@ -633,3 +634,165 @@ def test_permute_view_chains_neg(self) -> None:
         self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)
         self.assertTrue(nodes[2] == exir_ops.edge.aten.view_copy)
         self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)
+
+
+class TestMoveSliceBeforePermutePass(unittest.TestCase):
+    def test_basic_move(self) -> None:
+        """permute → slice becomes slice → permute."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
+        permuted = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
+        )
+        sliced = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(permuted, 1, 0, 2, 1),
+        )
+        builder.output([sliced])
+        original = builder.get_graph_module()
+
+        result = transform_and_check_numerics(
+            original,
+            (torch.randn(2, 3, 4, 5),),
+            MoveSliceBeforePermutePass(),
+        )
+        self.assertTrue(result.modified)
+
+        nodes = get_compute_nodes_in_gm(result.graph_module)
+        self.assertEqual(len(nodes), 2)
+        self.assertEqual(nodes[0], exir_ops.edge.aten.slice_copy)
+        self.assertEqual(nodes[1], exir_ops.edge.aten.permute_copy)
+
+    def test_multi_user_permute_no_change(self) -> None:
+        """Permute with multiple users → no change (only single-user supported)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
+        permuted = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
+        )
+        slice1 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(permuted, 1, 0, 2, 1),
+        )
+        slice2 = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(permuted, 2, 1, 3, 1),
+        )
+        builder.output([slice1, slice2])
+        original = builder.get_graph_module()
+
+        result = cast(PassResult, MoveSliceBeforePermutePass()(original))
+        self.assertFalse(result.modified)
+
+    def test_mixed_users_no_change(self) -> None:
+        """Permute with one slice user and one non-slice user → no change."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
+        permuted = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
+        )
+        sliced = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(permuted, 1, 0, 2, 1),
+        )
+        neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(permuted,))
+        builder.output([sliced, neg])
+        original = builder.get_graph_module()
+
+        result = cast(PassResult, MoveSliceBeforePermutePass()(original))
+        self.assertFalse(result.modified)
+
+    def test_no_slice_users_no_change(self) -> None:
+        """Permute with no slice users → no change."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
+        permuted = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
+        )
+        neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(permuted,))
+        builder.output([neg])
+        original = builder.get_graph_module()
+
+        result = cast(PassResult, MoveSliceBeforePermutePass()(original))
+        self.assertFalse(result.modified)
+
+    def test_dim0_slice_large_reduction_moved(self) -> None:
+        """Dim-0 slice removing >50% of data → profitable, moved."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(10, 3, 4, 5))
+        permuted = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
+        )
+        sliced = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(permuted, 0, 0, 2, 1),
+        )
+        builder.output([sliced])
+        original = builder.get_graph_module()
+
+        result = transform_and_check_numerics(
+            original,
+            (torch.randn(10, 3, 4, 5),),
+            MoveSliceBeforePermutePass(),
+        )
+        self.assertTrue(result.modified)
+
+        nodes = get_compute_nodes_in_gm(result.graph_module)
+        self.assertEqual(len(nodes), 2)
+        self.assertEqual(nodes[0], exir_ops.edge.aten.slice_copy)
+        self.assertEqual(nodes[1], exir_ops.edge.aten.permute_copy)
+
+    def test_dim0_slice_small_reduction_not_moved(self) -> None:
+        """Dim-0 slice removing <50% of data → not profitable, kept."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(10, 3, 4, 5))
+        permuted = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
+        )
+        sliced = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(permuted, 0, 0, 8, 1),
+        )
+        builder.output([sliced])
+        original = builder.get_graph_module()
+
+        result = cast(PassResult, MoveSliceBeforePermutePass()(original))
+        self.assertFalse(result.modified)
+
+    def test_non_dim0_slice_always_moved(self) -> None:
+        """Non-dim-0 slice → always profitable, moved regardless of reduction."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(10, 3, 4, 5))
+        permuted = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
+        )
+        sliced = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(permuted, 2, 0, 3, 1),
+        )
+        builder.output([sliced])
+        original = builder.get_graph_module()
+
+        result = transform_and_check_numerics(
+            original,
+            (torch.randn(10, 3, 4, 5),),
+            MoveSliceBeforePermutePass(),
+        )
+        self.assertTrue(result.modified)
+
+    def test_negative_dim0_slice_small_reduction_not_moved(self) -> None:
+        """Negative dim addressing dim 0 uses the dim-0 cost model, kept."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(10, 3, 4, 5))
+        permuted = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
+        )
+        sliced = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(permuted, -4, 0, 8, 1),
+        )
+        builder.output([sliced])
+        original = builder.get_graph_module()
+
+        result = cast(PassResult, MoveSliceBeforePermutePass()(original))
+        self.assertFalse(result.modified)