@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface):
    """Remove or sink permute ops that precede mean reductions through unary chains.

    When a permute feeds into a mean (possibly through unary ops like
    dequantize/quantize), two optimizations apply:

    1. If non-reduced dims maintain their relative order and positions, the
       permute is fully removed and the mean's reduction dims are remapped.
    2. Otherwise, the permute is moved after the mean so it operates on
       smaller data.
    """

    # Elementwise single-input ops that commute with a permute: applying them
    # before or after the permute yields the same values, so the permute can be
    # walked through them.  Per-tensor (de)quantize qualifies because its
    # scale/zero-point do not depend on dim order.
    _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset(
        {
            exir_ops.edge.cadence.dequantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.neg.default,
            exir_ops.edge.aten.abs.default,
        }
    )

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only mean.dim nodes are visited; the permute is discovered by
        # walking backward from each mean.
        return [exir_ops.edge.aten.mean.dim]

    def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]:
        """Walk backward from mean through single-user unary ops to find a permute.

        Returns the permute_copy node, or None if the chain is broken by a
        non-unary op, a multi-user intermediate (moving the permute would
        change what the other users see), or a non-Node input.
        """
        current = mean_node.args[0]
        if not isinstance(current, Node):
            return None
        while True:
            if current.target == exir_ops.edge.aten.permute_copy.default:
                return current
            if current.target not in self._UNARY_TARGETS:
                return None
            # Intermediate unary ops must be single-user: their output values
            # change layout once the permute is removed/sunk.
            if len(current.users) != 1:
                return None
            parent = current.args[0]
            if not isinstance(parent, Node):
                return None
            current = parent

    @staticmethod
    def _get_keepdim(node: Node) -> bool:
        # keepdim is mean.dim's third positional arg; falls back to kwargs,
        # defaulting to False per the aten schema.
        if len(node.args) >= 3:
            return bool(node.args[2])
        return bool(node.kwargs.get("keepdim", False))

    @staticmethod
    def _can_fully_remove(
        perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool
    ) -> bool:
        """Check whether the post-mean permute would be a no-op."""
        # new_reduction_dims are dims of the permute's *input*; % ndim
        # canonicalizes any negatives.
        canonical_reduction = {d % ndim for d in new_reduction_dims}
        if keepdim:
            # With keepdim the post-permute is `perm` itself.  If perm is the
            # identity on every non-reduced dim, it can only shuffle the
            # kept size-1 dims among themselves, which preserves both shape
            # and data -> no-op.
            return all(
                perm[d] == d for d in range(ndim) if d not in canonical_reduction
            )
        # Without keepdim the reduced dims vanish; the permute is a no-op iff
        # the surviving dims appear in ascending (i.e. original) order.
        non_reduced_in_perm_order = [d for d in perm if d not in canonical_reduction]
        return non_reduced_in_perm_order == sorted(non_reduced_in_perm_order)

    @staticmethod
    def _compute_post_mean_perm(
        perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool
    ) -> list[int]:
        """Compute the permutation to insert after the mean."""
        if keepdim:
            # Rank is unchanged, so the original permutation applies verbatim
            # (reduced dims are now size 1 but keep their positions).
            return list(perm)
        canonical_reduction = {d % ndim for d in new_reduction_dims}
        # Dims of the new (pre-permute) mean output, in ascending input order.
        non_reduced_original = sorted(
            d for d in range(ndim) if d not in canonical_reduction
        )
        # Order those same dims had in the original (post-permute) output.
        non_reduced_permuted = [d for d in perm if d not in canonical_reduction]
        # Map each desired output dim to its index in the new mean output.
        return [non_reduced_original.index(d) for d in non_reduced_permuted]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        # node is a mean.dim; args[1] is the list of reduction dims
        # (possibly negative).
        reduction_dims = cast(list[int], node.args[1])

        permute_node = self._find_permute_through_unary_chain(node)
        if permute_node is None:
            return False

        perm = cast(list[int], permute_node.args[1])
        ndim = len(perm)

        # The permute itself must be single-user, otherwise removing it
        # would corrupt its other consumers.
        if len(permute_node.users) != 1:
            return False

        permute_input = permute_node.args[0]
        assert isinstance(permute_input, Node)

        # Dim d of the permuted tensor is dim perm[d] of the permute input.
        new_reduction_dims = [perm[d % ndim] for d in reduction_dims]
        keepdim = self._get_keepdim(node)
        can_remove = self._can_fully_remove(perm, new_reduction_dims, ndim, keepdim)

        # Bypass the permute and reduce over the remapped dims.  The permute
        # node itself is left dangling here; presumably dead-code elimination
        # in the pass framework erases it -- TODO confirm.
        permute_node.replace_all_uses_with(permute_input)
        node.args = (node.args[0], new_reduction_dims) + node.args[2:]

        if not can_remove:
            # Sink the permute below the mean so it runs on the reduced
            # (smaller) tensor.
            post_perm = self._compute_post_mean_perm(
                perm, new_reduction_dims, ndim, keepdim
            )
            graph = node.graph
            with graph.inserting_after(node):
                # NOTE(review): create_node does not copy node.meta onto
                # new_permute -- confirm shape/val metadata is re-derived
                # downstream.
                new_permute = graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    args=(node, post_perm),
                )
            # Reroute every consumer of the mean (except the new permute
            # itself, which must keep mean as its input).
            for user in list(node.users):
                if user is not new_permute:
                    user.replace_input_with(node, new_permute)

        return True
        # Output should remain the same.
        self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs))

    def test_remove_permute_before_mean_fully_removed(self) -> None:
        """Permute → relu → mean where non-reduced dims preserve order → fully remove."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
        permuted = builder.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
        )
        relu = builder.call_operator(
            op=exir_ops.edge.aten.relu.default, args=(permuted,)
        )
        mean = builder.call_operator(
            op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], False)
        )
        builder.output([mean])
        original = builder.get_graph_module()
        gm_before = copy.deepcopy(original)

        graph_after = cast(
            PassResult, RemovePermuteBeforeMeanPass()(original)
        ).graph_module

        # Permute should be fully removed.
        self.assertEqual(
            count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0
        )

        # Mean reduction dims should be remapped to original space:
        # permuted dims [1, 2] correspond to input dims [perm[1], perm[2]] = [2, 3].
        mean_nodes = graph_after.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.mean.dim
        )
        self.assertEqual(len(mean_nodes), 1)
        self.assertEqual(mean_nodes[0].args[1], [2, 3])

        # Numerical equivalence of the graph before/after the pass.
        validate(
            gm_before,
            graph_after,
            (torch.randn(2, 3, 4, 5),),
            "RemovePermuteBeforeMeanPass",
        )

    def test_remove_permute_before_mean_sunk_after(self) -> None:
        """Permute → relu → mean where non-reduced dims reorder → move permute after mean."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
        permuted = builder.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(x, [2, 0, 3, 1])
        )
        relu = builder.call_operator(
            op=exir_ops.edge.aten.relu.default, args=(permuted,)
        )
        mean = builder.call_operator(
            op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], False)
        )
        builder.output([mean])
        original = builder.get_graph_module()
        gm_before = copy.deepcopy(original)

        graph_after = cast(
            PassResult, RemovePermuteBeforeMeanPass()(original)
        ).graph_module

        # One permute should remain (after the mean).
        self.assertEqual(
            count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 1
        )

        # Mean reduction dims should be remapped: [perm[2], perm[3]] = [3, 1].
        mean_nodes = graph_after.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.mean.dim
        )
        self.assertEqual(len(mean_nodes), 1)
        self.assertEqual(mean_nodes[0].args[1], [3, 1])

        # The permute should come after the mean, not before.
        permute_nodes = graph_after.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.permute_copy.default
        )
        self.assertEqual(permute_nodes[0].args[0], mean_nodes[0])
        self.assertEqual(permute_nodes[0].args[1], [1, 0])

        validate(
            gm_before,
            graph_after,
            (torch.randn(2, 3, 4, 5),),
            "RemovePermuteBeforeMeanPass",
        )

    def test_remove_permute_before_mean_keepdim_true(self) -> None:
        """Permute → relu → mean(keepdim=True) where only reduced dims shuffle → fully remove."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
        permuted = builder.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 1, 3, 2])
        )
        relu = builder.call_operator(
            op=exir_ops.edge.aten.relu.default, args=(permuted,)
        )
        mean = builder.call_operator(
            op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], True)
        )
        builder.output([mean])
        original = builder.get_graph_module()
        gm_before = copy.deepcopy(original)

        graph_after = cast(
            PassResult, RemovePermuteBeforeMeanPass()(original)
        ).graph_module

        # Permute fully removed (only reduced dims were shuffled).
        self.assertEqual(
            count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0
        )

        validate(
            gm_before,
            graph_after,
            (torch.randn(2, 3, 4, 5),),
            "RemovePermuteBeforeMeanPass",
        )

    def test_remove_permute_before_mean_keepdim_true_sunk(self) -> None:
        """Permute → relu → mean(keepdim=True) where non-reduced dims move → sink permute."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
        permuted = builder.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
        )
        relu = builder.call_operator(
            op=exir_ops.edge.aten.relu.default, args=(permuted,)
        )
        mean = builder.call_operator(
            op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], True)
        )
        builder.output([mean])
        original = builder.get_graph_module()
        gm_before = copy.deepcopy(original)

        graph_after = cast(
            PassResult, RemovePermuteBeforeMeanPass()(original)
        ).graph_module

        # One permute should remain (sunk after mean).
        self.assertEqual(
            count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 1
        )

        # The post-mean permute uses the original perm since keepdim=True
        # (rank is unchanged, reduced dims are kept as size 1).
        permute_nodes = graph_after.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.permute_copy.default
        )
        self.assertEqual(permute_nodes[0].args[1], [0, 2, 3, 1])

        validate(
            gm_before,
            graph_after,
            (torch.randn(2, 3, 4, 5),),
            "RemovePermuteBeforeMeanPass",
        )

    def test_remove_permute_before_mean_no_permute(self) -> None:
        """No permute before mean → no change."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
        relu = builder.call_operator(op=exir_ops.edge.aten.relu.default, args=(x,))
        mean = builder.call_operator(
            op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], False)
        )
        builder.output([mean])
        original = builder.get_graph_module()

        result = cast(PassResult, RemovePermuteBeforeMeanPass()(original))
        self.assertFalse(result.modified)

    def test_remove_permute_before_mean_multi_user(self) -> None:
        """Permute with multiple users → no change."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
        permuted = builder.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
        )
        relu = builder.call_operator(
            op=exir_ops.edge.aten.relu.default, args=(permuted,)
        )
        mean = builder.call_operator(
            op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], False)
        )
        # Second user of the permute prevents optimization.
        neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(permuted,))
        builder.output([mean, neg])
        original = builder.get_graph_module()

        result = cast(PassResult, RemovePermuteBeforeMeanPass()(original))
        self.assertFalse(result.modified)