diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index d6ee88e94c6..09deabd86fe 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -1003,6 +1003,108 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class FuseMeanKeepDimWithViewPass(RemoveOrReplacePassInterface): + """Fuse mean + view_copy when the view toggles keepdim behavior. + + Case 1: mean(keepdim=True) + view that squeezes reduction dims + → mean(keepdim=False), view removed. + Case 2: mean(keepdim=False) + view that unsqueezes at reduction dims + → mean(keepdim=True), view removed. + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mean.dim] + + @staticmethod + def _get_keepdim(node: torch.fx.Node) -> bool: + if len(node.args) >= 3: + return bool(node.args[2]) + return bool(node.kwargs.get("keepdim", False)) + + @staticmethod + def _set_keepdim(node: torch.fx.Node, keepdim: bool) -> None: + if "keepdim" in node.kwargs: + new_kwargs = dict(node.kwargs) + new_kwargs["keepdim"] = keepdim + node.kwargs = new_kwargs + elif len(node.args) > 2: + new_args = list(node.args) + new_args[2] = keepdim + node.args = tuple(new_args) + else: + node.args = tuple(node.args) + (keepdim,) + + @staticmethod + def _is_squeeze_of_reduction_dims( + mean_shape: list[int], + view_shape: list[int], + reduction_dims: list[int], + ndim: int, + ) -> bool: + canonical_reduction = {d % ndim for d in reduction_dims} + expected = [s for i, s in enumerate(mean_shape) if i not in canonical_reduction] + return list(view_shape) == expected + + @staticmethod + def _is_unsqueeze_at_reduction_dims( + mean_output_shape: list[int], + view_shape: list[int], + reduction_dims: list[int], + input_ndim: int, + ) -> bool: + canonical_reduction = {d % input_ndim for d in reduction_dims} + if len(view_shape) != input_ndim: + return False + non_reduced_idx = 0 + for i in range(input_ndim): + if i in canonical_reduction: + if view_shape[i] != 1: + return False + else: + if ( + non_reduced_idx >= len(mean_output_shape) + or view_shape[i] != mean_output_shape[non_reduced_idx] + ): + return False + non_reduced_idx += 1 + return non_reduced_idx == len(mean_output_shape) + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + if len(node.users) != 1: + return False + + view_node = next(iter(node.users)) + if view_node.target != exir_ops.edge.aten.view_copy.default: + return False + + reduction_dims = cast(list[int], node.args[1]) + keepdim = self._get_keepdim(node) + mean_output_shape = list(node.meta["val"].shape) + view_output_shape = list(view_node.meta["val"].shape) + + if keepdim: + ndim = len(mean_output_shape) + if not self._is_squeeze_of_reduction_dims( + mean_output_shape, view_output_shape, reduction_dims, ndim + ): + return False + self._set_keepdim(node, False) + else: + input_node = node.args[0] + assert isinstance(input_node, torch.fx.Node) + input_ndim = len(input_node.meta["val"].shape) + if not self._is_unsqueeze_at_reduction_dims( + mean_output_shape, view_output_shape, reduction_dims, input_ndim + ): + return False + self._set_keepdim(node, True) + + view_node.replace_all_uses_with(node) + return True + + class HierarchicalCSEPass(HierarchicalInplacePassInterface): """ A hierarchical Common Subexpression Elimination (CSE) pass that recursively @@ -1035,4 +1137,5 @@ class CadenceFuseOpsInGraph: FuseMulScalarIntoDequantPass, FuseFullThenReshapePass, 
FuseTransposeOrPermuteOpPairsPass, + FuseMeanKeepDimWithViewPass, ] diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index dabab032116..e4f38cdc4d2 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -387,6 +387,127 @@ def maybe_remove_or_replace(self, node: Node) -> bool: return False +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface): + """Remove or sink permute ops that precede mean reductions through unary chains. + + When a permute feeds into a mean (possibly through unary ops like + dequantize/quantize), two optimizations apply: + + 1. If non-reduced dims maintain their relative order and positions, the + permute is fully removed and the mean's reduction dims are remapped. + 2. Otherwise, the permute is moved after the mean so it operates on + smaller data. + """ + + _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset( + { + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.abs.default, + } + ) + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mean.dim] + + def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]: + """Walk backward from mean through single-user unary ops to find a permute.""" + current = mean_node.args[0] + if not isinstance(current, Node): + return None + while True: + if current.target == exir_ops.edge.aten.permute_copy.default: + return current + if current.target not in self._UNARY_TARGETS: + return None + if len(current.users) != 1: + return None + parent = current.args[0] + if not isinstance(parent, Node): + return None + current = parent + + @staticmethod + def _get_keepdim(node: Node) -> bool: + if len(node.args) >= 3: + return bool(node.args[2]) + return bool(node.kwargs.get("keepdim", False)) + + @staticmethod + def _can_fully_remove( + perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool + ) -> bool: + """Check whether the post-mean permute would be a no-op.""" + canonical_reduction = {d % ndim for d in new_reduction_dims} + if keepdim: + return all( + perm[d] == d for d in range(ndim) if d not in canonical_reduction + ) + non_reduced_in_perm_order = [d for d in perm if d not in canonical_reduction] + return non_reduced_in_perm_order == sorted(non_reduced_in_perm_order) + + @staticmethod + def _compute_post_mean_perm( + perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool + ) -> list[int]: + """Compute the permutation to insert after the mean.""" + if keepdim: + return list(perm) + canonical_reduction = {d % ndim for d in new_reduction_dims} + non_reduced_original = sorted( + d for d in range(ndim) if d not in canonical_reduction + ) + non_reduced_permuted = [d for d in perm if d not in canonical_reduction] + return [non_reduced_original.index(d) for d in non_reduced_permuted] + + def maybe_remove_or_replace(self, node: Node) -> bool: + reduction_dims = cast(list[int], node.args[1]) + + permute_node = self._find_permute_through_unary_chain(node) + if permute_node is None: + return False + + perm = cast(list[int], permute_node.args[1]) + ndim = len(perm) + + if len(permute_node.users) != 1: + return 
False + + permute_input = permute_node.args[0] + assert isinstance(permute_input, Node) + + new_reduction_dims = [perm[d % ndim] for d in reduction_dims] + keepdim = self._get_keepdim(node) + can_remove = self._can_fully_remove(perm, new_reduction_dims, ndim, keepdim) + + permute_node.replace_all_uses_with(permute_input) + node.args = (node.args[0], new_reduction_dims) + node.args[2:] + + if not can_remove: + post_perm = self._compute_post_mean_perm( + perm, new_reduction_dims, ndim, keepdim + ) + graph = node.graph + with graph.inserting_after(node): + new_permute = graph.create_node( + "call_function", + exir_ops.edge.aten.permute_copy.default, + args=(node, post_perm), + ) + for user in list(node.users): + if user is not new_permute: + user.replace_input_with(node, new_permute) + + return True + + @register_cadence_pass(CadencePassAttribute(opt_level=2)) class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps): permutable_ops: set[EdgeOpOverload] = ( @@ -646,6 +767,7 @@ class CommonRemovePasses: RemoveNopSliceOrViewOpPass, RemoveToOpsPass, RemoveZeroSizedCatArgsPass, + RemovePermuteBeforeMeanPass, RemovePermutesAroundElementwiseOps, FuseTransposeOrPermuteOpPairsPass, RemoveSqueezeViewBeforeElementwiseOps, diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index f5afbe243f8..f519febba37 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -19,6 +19,7 @@ FuseCascadedTransposeOrPermuteOps, FuseCascadedViewOps, FuseFullThenReshapePass, + FuseMeanKeepDimWithViewPass, FuseMMWithAdd, FuseMulScalarIntoDequantPass, FuseMulTensorIntoDequantPass, @@ -1696,3 +1697,168 @@ def __init__(self) -> None: # Verify fusion occurred: bn should be removed, conv remains self.assertEqual(count_node(gm, conv_op), 1) self.assertEqual(count_node(gm, bn_op), 0) + + +class TestFuseMeanKeepDimWithViewPass(TestFusionPassesBase): + def test_keepdim_true_to_false(self) -> None: + """mean(keepdim=True) + view that squeezes reduction dims → mean(keepdim=False).""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], True) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [2, 62]) + ) + builder.output([view]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + gm = result.graph_module + self.assertTrue(result.modified) + + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + mean_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertFalse(mean_nodes[0].args[2]) + + validate_numerics( + gm_before, gm, (torch.randn(2, 62, 4, 4),), "FuseMeanKeepDimWithViewPass" + ) + + def test_keepdim_false_to_true(self) -> None: + """mean(keepdim=False) + view that unsqueezes at reduction dims → mean(keepdim=True).""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], False) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [2, 62, 1, 1]) + ) + builder.output([view]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = 
cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + gm = result.graph_module + self.assertTrue(result.modified) + + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + mean_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertTrue(mean_nodes[0].args[2]) + + validate_numerics( + gm_before, gm, (torch.randn(2, 62, 4, 4),), "FuseMeanKeepDimWithViewPass" + ) + + def test_keepdim_true_view_does_not_match(self) -> None: + """View reshapes to something other than squeezing reduction dims → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], True) + ) + # Reshape to a different layout, not a simple squeeze of reduction dims. + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [1, 2, 62]) + ) + builder.output([view]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertFalse(result.modified) + + def test_keepdim_false_view_wrong_unsqueeze(self) -> None: + """View inserts 1s at wrong positions → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], False) + ) + # 1s at positions 0 and 1 instead of 2 and 3. + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [1, 1, 2, 62]) + ) + builder.output([view]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertFalse(result.modified) + + def test_mean_multiple_users_no_change(self) -> None: + """Mean has multiple users → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], True) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [2, 62]) + ) + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(mean,)) + builder.output([view, neg]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertFalse(result.modified) + + def test_reduce_single_dim(self) -> None: + """Reduction over a single dim, both directions.""" + # keepdim=True → False + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 4, 5)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [1], True) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [3, 5]) + ) + builder.output([view]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.view_copy.default), 0 + ) + + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(3, 4, 5),), + "FuseMeanKeepDimWithViewPass", + ) + + # keepdim=False → True + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 4, 5)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [1], False) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [3, 1, 5]) + ) + 
builder.output([view]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.view_copy.default), 0 + ) + + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(3, 4, 5),), + "FuseMeanKeepDimWithViewPass", + ) diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 11bceff0a05..14299e54a39 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -28,6 +28,7 @@ RemoveNopLinalgVectorNormOpPass, RemoveNopMulOpPass, RemoveNopSliceOrViewOpPass, + RemovePermuteBeforeMeanPass, RemovePermutesAroundElementwiseOps, RemoveSqueezeViewBeforeElementwiseOps, RemoveToOpsPass, @@ -1013,3 +1014,197 @@ def test_remove_cat_from_slice_copy_second_input(self) -> None: # Output should remain the same. self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs)) + + def test_remove_permute_before_mean_fully_removed(self) -> None: + """Permute → relu → mean where non-reduced dims preserve order → fully remove.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # Permute should be fully removed. + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + + # Mean reduction dims should be remapped to original space. + mean_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertEqual(mean_nodes[0].args[1], [2, 3]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_sunk_after(self) -> None: + """Permute → relu → mean where non-reduced dims reorder → move permute after mean.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [2, 0, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # One permute should remain (after the mean). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 1 + ) + + # Mean reduction dims should be remapped. + mean_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertEqual(mean_nodes[0].args[1], [3, 1]) + + # The permute should come after the mean, not before. 
+ permute_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(permute_nodes[0].args[0], mean_nodes[0]) + self.assertEqual(permute_nodes[0].args[1], [1, 0]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_keepdim_true(self) -> None: + """Permute → relu → mean(keepdim=True) where only reduced dims shuffle → fully remove.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 1, 3, 2]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], True) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # Permute fully removed (only reduced dims were shuffled). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_keepdim_true_sunk(self) -> None: + """Permute → relu → mean(keepdim=True) where non-reduced dims move → sink permute.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], True) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # One permute should remain (sunk after mean). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 1 + ) + + # The post-mean permute uses the original perm since keepdim=True. 
+ permute_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(permute_nodes[0].args[1], [0, 2, 3, 1]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_no_permute(self) -> None: + """No permute before mean → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + relu = builder.call_operator(op=exir_ops.edge.aten.relu.default, args=(x,)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + + result = cast(PassResult, RemovePermuteBeforeMeanPass()(original)) + self.assertFalse(result.modified) + + def test_remove_permute_before_mean_multi_user(self) -> None: + """Permute with multiple users → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], False) + ) + # Second user of the permute prevents optimization. + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(permuted,)) + builder.output([mean, neg]) + original = builder.get_graph_module() + + result = cast(PassResult, RemovePermuteBeforeMeanPass()(original)) + self.assertFalse(result.modified)
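
Reviewer note, not part of the patch: `FuseMeanKeepDimWithViewPass` rests on the identity that toggling `keepdim` on a mean is exactly a reshape of its output, so the view can be folded into the flag. A minimal standalone sketch in plain PyTorch, reusing the shapes from the tests above, that checks both directions:

```python
import torch

x = torch.randn(2, 62, 4, 4)

# Case 1: mean(keepdim=True) followed by a view that squeezes the
# reduction dims is equivalent to mean(keepdim=False).
assert torch.allclose(
    x.mean(dim=[-1, -2], keepdim=True).view(2, 62),
    x.mean(dim=[-1, -2], keepdim=False),
)

# Case 2: mean(keepdim=False) followed by a view that re-inserts
# size-1 dims at the reduction positions is equivalent to keepdim=True.
assert torch.allclose(
    x.mean(dim=[-1, -2], keepdim=False).view(2, 62, 1, 1),
    x.mean(dim=[-1, -2], keepdim=True),
)
```

Either direction drops a `view_copy` from the graph at no numeric cost, which is why the pass can simply flip the flag and replace all uses of the view with the mean.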
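Likewise for `RemovePermuteBeforeMeanPass`: the pass distinguishes the case where the permute can vanish entirely (the surviving dims keep their relative order under the permutation) from the case where a smaller permute must be sunk below the mean. A standalone sketch of both cases, with the perms and reduction dims taken from the two `keepdim=False` tests above:

```python
import torch

x = torch.randn(2, 3, 4, 5)

# Full removal: perm = [0, 2, 3, 1], reducing permuted dims [1, 2].
# The surviving x-dims (0 and 1) keep their relative order, so the mean
# over the remapped dims [perm[1], perm[2]] = [2, 3] already matches.
perm, red = [0, 2, 3, 1], [1, 2]
lhs = x.permute(perm).mean(dim=red)
rhs = x.mean(dim=[perm[d] for d in red])
assert torch.allclose(lhs, rhs)

# Sinking: perm = [2, 0, 3, 1], reducing permuted dims [2, 3]. The
# surviving x-dims come out as (2, 0), i.e. reordered, so a small
# post-mean permute [1, 0] is still required.
perm, red = [2, 0, 3, 1], [2, 3]
lhs = x.permute(perm).mean(dim=red)
rhs = x.mean(dim=[perm[d] for d in red]).permute(1, 0)
assert torch.allclose(lhs, rhs)
```

Even the sunk case is a win, matching the docstring's rationale: the post-mean permute shuffles an already-reduced tensor instead of the full input.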