From fe0d927a1a8cb4126b475e2d2c2a8a8b61d4dbac Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 23 Apr 2026 20:05:58 -0700 Subject: [PATCH] Remove or move permute after mean Summary: If we have a permute -> unary chain -> mean, based on the reduction dims of the mean, we can either fully remove the preceding permute or move the permute after the mean. Case 1: Dims after permute are still in same order with respect to each other, we can fully get rid of the permute and just update the reduction dims of the mean. Case 2: Not case 1. In this case, it's better to move the permute after the mean, since the permute will operate on less data. Differential Revision: D102268214 --- backends/cadence/aot/remove_ops.py | 122 +++++++++++ .../aot/tests/test_remove_ops_passes.py | 195 ++++++++++++++++++ 2 files changed, 317 insertions(+) diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index dabab032116..e4f38cdc4d2 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -387,6 +387,127 @@ def maybe_remove_or_replace(self, node: Node) -> bool: return False +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface): + """Remove or sink permute ops that precede mean reductions through unary chains. + + When a permute feeds into a mean (possibly through unary ops like + dequantize/quantize), two optimizations apply: + + 1. If non-reduced dims maintain their relative order and positions, the + permute is fully removed and the mean's reduction dims are remapped. + 2. Otherwise, the permute is moved after the mean so it operates on + smaller data. + """ + + _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset( + { + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.abs.default, + } + ) + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mean.dim] + + def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]: + """Walk backward from mean through single-user unary ops to find a permute.""" + current = mean_node.args[0] + if not isinstance(current, Node): + return None + while True: + if current.target == exir_ops.edge.aten.permute_copy.default: + return current + if current.target not in self._UNARY_TARGETS: + return None + if len(current.users) != 1: + return None + parent = current.args[0] + if not isinstance(parent, Node): + return None + current = parent + + @staticmethod + def _get_keepdim(node: Node) -> bool: + if len(node.args) >= 3: + return bool(node.args[2]) + return bool(node.kwargs.get("keepdim", False)) + + @staticmethod + def _can_fully_remove( + perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool + ) -> bool: + """Check whether the post-mean permute would be a no-op.""" + canonical_reduction = {d % ndim for d in new_reduction_dims} + if keepdim: + return all( + perm[d] == d for d in range(ndim) if d not in canonical_reduction + ) + non_reduced_in_perm_order = [d for d in perm if d not in canonical_reduction] + return non_reduced_in_perm_order == sorted(non_reduced_in_perm_order) + + @staticmethod + def _compute_post_mean_perm( + perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool + ) -> list[int]: + """Compute the permutation to insert after the mean.""" + if keepdim: + return list(perm) + canonical_reduction = {d % ndim for d in new_reduction_dims} + non_reduced_original = sorted( + d for d in range(ndim) if d not in canonical_reduction + ) + non_reduced_permuted = [d for d in perm if d not in canonical_reduction] + return [non_reduced_original.index(d) for d in non_reduced_permuted] + + def maybe_remove_or_replace(self, node: Node) -> bool: + reduction_dims = cast(list[int], node.args[1]) + + permute_node = self._find_permute_through_unary_chain(node) + if permute_node is None: + return False + + perm = cast(list[int], permute_node.args[1]) + ndim = len(perm) + + if len(permute_node.users) != 1: + return False + + permute_input = permute_node.args[0] + assert isinstance(permute_input, Node) + + new_reduction_dims = [perm[d % ndim] for d in reduction_dims] + keepdim = self._get_keepdim(node) + can_remove = self._can_fully_remove(perm, new_reduction_dims, ndim, keepdim) + + permute_node.replace_all_uses_with(permute_input) + node.args = (node.args[0], new_reduction_dims) + node.args[2:] + + if not can_remove: + post_perm = self._compute_post_mean_perm( + perm, new_reduction_dims, ndim, keepdim + ) + graph = node.graph + with graph.inserting_after(node): + new_permute = graph.create_node( + "call_function", + exir_ops.edge.aten.permute_copy.default, + args=(node, post_perm), + ) + for user in list(node.users): + if user is not new_permute: + user.replace_input_with(node, new_permute) + + return True + + @register_cadence_pass(CadencePassAttribute(opt_level=2)) class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps): permutable_ops: set[EdgeOpOverload] = ( @@ -646,6 +767,7 @@ class CommonRemovePasses: RemoveNopSliceOrViewOpPass, RemoveToOpsPass, RemoveZeroSizedCatArgsPass, + RemovePermuteBeforeMeanPass, RemovePermutesAroundElementwiseOps, FuseTransposeOrPermuteOpPairsPass, RemoveSqueezeViewBeforeElementwiseOps, diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 11bceff0a05..14299e54a39 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -28,6 +28,7 @@ RemoveNopLinalgVectorNormOpPass, RemoveNopMulOpPass, RemoveNopSliceOrViewOpPass, + RemovePermuteBeforeMeanPass, RemovePermutesAroundElementwiseOps, RemoveSqueezeViewBeforeElementwiseOps, RemoveToOpsPass, @@ -1013,3 +1014,197 @@ def test_remove_cat_from_slice_copy_second_input(self) -> None: # Output should remain the same. self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs)) + + def test_remove_permute_before_mean_fully_removed(self) -> None: + """Permute → relu → mean where non-reduced dims preserve order → fully remove.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # Permute should be fully removed. + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + + # Mean reduction dims should be remapped to original space. + mean_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertEqual(mean_nodes[0].args[1], [2, 3]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_sunk_after(self) -> None: + """Permute → relu → mean where non-reduced dims reorder → move permute after mean.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [2, 0, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # One permute should remain (after the mean). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 1 + ) + + # Mean reduction dims should be remapped. + mean_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertEqual(mean_nodes[0].args[1], [3, 1]) + + # The permute should come after the mean, not before. + permute_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(permute_nodes[0].args[0], mean_nodes[0]) + self.assertEqual(permute_nodes[0].args[1], [1, 0]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_keepdim_true(self) -> None: + """Permute → relu → mean(keepdim=True) where only reduced dims shuffle → fully remove.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 1, 3, 2]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], True) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # Permute fully removed (only reduced dims were shuffled). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_keepdim_true_sunk(self) -> None: + """Permute → relu → mean(keepdim=True) where non-reduced dims move → sink permute.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], True) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # One permute should remain (sunk after mean). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 1 + ) + + # The post-mean permute uses the original perm since keepdim=True. + permute_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(permute_nodes[0].args[1], [0, 2, 3, 1]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_no_permute(self) -> None: + """No permute before mean → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + relu = builder.call_operator(op=exir_ops.edge.aten.relu.default, args=(x,)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + + result = cast(PassResult, RemovePermuteBeforeMeanPass()(original)) + self.assertFalse(result.modified) + + def test_remove_permute_before_mean_multi_user(self) -> None: + """Permute with multiple users → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], False) + ) + # Second user of the permute prevents optimization. + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(permuted,)) + builder.output([mean, neg]) + original = builder.get_graph_module() + + result = cast(PassResult, RemovePermuteBeforeMeanPass()(original)) + self.assertFalse(result.modified)