From 5665a2626dfe569da3f6a77e643cef460429e64a Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 23 Apr 2026 19:42:19 -0700 Subject: [PATCH 1/2] Remove or move permute after mean Summary: If we have a permute -> unary chain -> mean, based on the reduction dims of the mean, we can either fully remove the preceding permute or move the permute after the mean. Case 1: Dims after permute are still in same order with respect to each other, we can fully get rid of the permute and just update the reduction dims of the mean. Case 2: Not case 1. In this case, it's better to move the permute after the mean, since the permute will operate on less data. Differential Revision: D102268214 --- backends/cadence/aot/remove_ops.py | 122 +++++++++++ .../aot/tests/test_remove_ops_passes.py | 195 ++++++++++++++++++ 2 files changed, 317 insertions(+) diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index dabab032116..e4f38cdc4d2 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -387,6 +387,127 @@ def maybe_remove_or_replace(self, node: Node) -> bool: return False +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface): + """Remove or sink permute ops that precede mean reductions through unary chains. + + When a permute feeds into a mean (possibly through unary ops like + dequantize/quantize), two optimizations apply: + + 1. If non-reduced dims maintain their relative order and positions, the + permute is fully removed and the mean's reduction dims are remapped. + 2. Otherwise, the permute is moved after the mean so it operates on + smaller data. + """ + + _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset( + { + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.abs.default, + } + ) + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mean.dim] + + def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]: + """Walk backward from mean through single-user unary ops to find a permute.""" + current = mean_node.args[0] + if not isinstance(current, Node): + return None + while True: + if current.target == exir_ops.edge.aten.permute_copy.default: + return current + if current.target not in self._UNARY_TARGETS: + return None + if len(current.users) != 1: + return None + parent = current.args[0] + if not isinstance(parent, Node): + return None + current = parent + + @staticmethod + def _get_keepdim(node: Node) -> bool: + if len(node.args) >= 3: + return bool(node.args[2]) + return bool(node.kwargs.get("keepdim", False)) + + @staticmethod + def _can_fully_remove( + perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool + ) -> bool: + """Check whether the post-mean permute would be a no-op.""" + canonical_reduction = {d % ndim for d in new_reduction_dims} + if keepdim: + return all( + perm[d] == d for d in range(ndim) if d not in canonical_reduction + ) + non_reduced_in_perm_order = [d for d in perm if d not in canonical_reduction] + return non_reduced_in_perm_order == sorted(non_reduced_in_perm_order) + + @staticmethod + def _compute_post_mean_perm( + perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool + ) -> list[int]: + """Compute the permutation to insert after the mean.""" + if keepdim: + return list(perm) + canonical_reduction = {d % ndim for d in new_reduction_dims} + non_reduced_original = sorted( + d for d in range(ndim) if d not in canonical_reduction + ) + non_reduced_permuted = [d for d in perm if d not in canonical_reduction] + return [non_reduced_original.index(d) for d in non_reduced_permuted] + + def maybe_remove_or_replace(self, node: Node) -> bool: + reduction_dims = cast(list[int], node.args[1]) + + permute_node = self._find_permute_through_unary_chain(node) + if permute_node is None: + return False + + perm = cast(list[int], permute_node.args[1]) + ndim = len(perm) + + if len(permute_node.users) != 1: + return False + + permute_input = permute_node.args[0] + assert isinstance(permute_input, Node) + + new_reduction_dims = [perm[d % ndim] for d in reduction_dims] + keepdim = self._get_keepdim(node) + can_remove = self._can_fully_remove(perm, new_reduction_dims, ndim, keepdim) + + permute_node.replace_all_uses_with(permute_input) + node.args = (node.args[0], new_reduction_dims) + node.args[2:] + + if not can_remove: + post_perm = self._compute_post_mean_perm( + perm, new_reduction_dims, ndim, keepdim + ) + graph = node.graph + with graph.inserting_after(node): + new_permute = graph.create_node( + "call_function", + exir_ops.edge.aten.permute_copy.default, + args=(node, post_perm), + ) + for user in list(node.users): + if user is not new_permute: + user.replace_input_with(node, new_permute) + + return True + + @register_cadence_pass(CadencePassAttribute(opt_level=2)) class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps): permutable_ops: set[EdgeOpOverload] = ( @@ -646,6 +767,7 @@ class CommonRemovePasses: RemoveNopSliceOrViewOpPass, RemoveToOpsPass, RemoveZeroSizedCatArgsPass, + RemovePermuteBeforeMeanPass, RemovePermutesAroundElementwiseOps, FuseTransposeOrPermuteOpPairsPass, RemoveSqueezeViewBeforeElementwiseOps, diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 11bceff0a05..14299e54a39 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -28,6 +28,7 @@ RemoveNopLinalgVectorNormOpPass, RemoveNopMulOpPass, RemoveNopSliceOrViewOpPass, + RemovePermuteBeforeMeanPass, RemovePermutesAroundElementwiseOps, RemoveSqueezeViewBeforeElementwiseOps, RemoveToOpsPass, @@ -1013,3 +1014,197 @@ def test_remove_cat_from_slice_copy_second_input(self) -> None: # Output should remain the same. self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs)) + + def test_remove_permute_before_mean_fully_removed(self) -> None: + """Permute → relu → mean where non-reduced dims preserve order → fully remove.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # Permute should be fully removed. + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + + # Mean reduction dims should be remapped to original space. + mean_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertEqual(mean_nodes[0].args[1], [2, 3]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_sunk_after(self) -> None: + """Permute → relu → mean where non-reduced dims reorder → move permute after mean.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [2, 0, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # One permute should remain (after the mean). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 1 + ) + + # Mean reduction dims should be remapped. + mean_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertEqual(mean_nodes[0].args[1], [3, 1]) + + # The permute should come after the mean, not before. + permute_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(permute_nodes[0].args[0], mean_nodes[0]) + self.assertEqual(permute_nodes[0].args[1], [1, 0]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_keepdim_true(self) -> None: + """Permute → relu → mean(keepdim=True) where only reduced dims shuffle → fully remove.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 1, 3, 2]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], True) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # Permute fully removed (only reduced dims were shuffled). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_keepdim_true_sunk(self) -> None: + """Permute → relu → mean(keepdim=True) where non-reduced dims move → sink permute.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], True) + ) + builder.output([mean]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + graph_after = cast( + PassResult, RemovePermuteBeforeMeanPass()(original) + ).graph_module + + # One permute should remain (sunk after mean). + self.assertEqual( + count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 1 + ) + + # The post-mean permute uses the original perm since keepdim=True. + permute_nodes = graph_after.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(permute_nodes[0].args[1], [0, 2, 3, 1]) + + validate( + gm_before, + graph_after, + (torch.randn(2, 3, 4, 5),), + "RemovePermuteBeforeMeanPass", + ) + + def test_remove_permute_before_mean_no_permute(self) -> None: + """No permute before mean → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + relu = builder.call_operator(op=exir_ops.edge.aten.relu.default, args=(x,)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], False) + ) + builder.output([mean]) + original = builder.get_graph_module() + + result = cast(PassResult, RemovePermuteBeforeMeanPass()(original)) + self.assertFalse(result.modified) + + def test_remove_permute_before_mean_multi_user(self) -> None: + """Permute with multiple users → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permuted = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + relu = builder.call_operator( + op=exir_ops.edge.aten.relu.default, args=(permuted,) + ) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(relu, [1, 2], False) + ) + # Second user of the permute prevents optimization. + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(permuted,)) + builder.output([mean, neg]) + original = builder.get_graph_module() + + result = cast(PassResult, RemovePermuteBeforeMeanPass()(original)) + self.assertFalse(result.modified) From c6baa03f566bcf3f3daccafd1e1a4116489bcd2e Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 23 Apr 2026 21:11:34 -0700 Subject: [PATCH 2/2] Fuse mean and view if possible Summary: I've seen lots of cases where we have a mean followed by a view which can either be fused into mean with keepdim True or False. These views typically become no-ops, however it's nice to remove them if we don't need them. Differential Revision: D102276633 --- backends/cadence/aot/fuse_ops.py | 103 +++++++++++ .../aot/tests/test_fusion_ops_passes.py | 166 ++++++++++++++++++ 2 files changed, 269 insertions(+) diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index d6ee88e94c6..09deabd86fe 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -1003,6 +1003,108 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class FuseMeanKeepDimWithViewPass(RemoveOrReplacePassInterface): + """Fuse mean + view_copy when the view toggles keepdim behavior. + + Case 1: mean(keepdim=True) + view that squeezes reduction dims + → mean(keepdim=False), view removed. + Case 2: mean(keepdim=False) + view that unsqueezes at reduction dims + → mean(keepdim=True), view removed. + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mean.dim] + + @staticmethod + def _get_keepdim(node: torch.fx.Node) -> bool: + if len(node.args) >= 3: + return bool(node.args[2]) + return bool(node.kwargs.get("keepdim", False)) + + @staticmethod + def _set_keepdim(node: torch.fx.Node, keepdim: bool) -> None: + if "keepdim" in node.kwargs: + new_kwargs = dict(node.kwargs) + new_kwargs["keepdim"] = keepdim + node.kwargs = new_kwargs + elif len(node.args) > 2: + new_args = list(node.args) + new_args[2] = keepdim + node.args = tuple(new_args) + else: + node.args = tuple(node.args) + (keepdim,) + + @staticmethod + def _is_squeeze_of_reduction_dims( + mean_shape: list[int], + view_shape: list[int], + reduction_dims: list[int], + ndim: int, + ) -> bool: + canonical_reduction = {d % ndim for d in reduction_dims} + expected = [s for i, s in enumerate(mean_shape) if i not in canonical_reduction] + return list(view_shape) == expected + + @staticmethod + def _is_unsqueeze_at_reduction_dims( + mean_output_shape: list[int], + view_shape: list[int], + reduction_dims: list[int], + input_ndim: int, + ) -> bool: + canonical_reduction = {d % input_ndim for d in reduction_dims} + if len(view_shape) != input_ndim: + return False + non_reduced_idx = 0 + for i in range(input_ndim): + if i in canonical_reduction: + if view_shape[i] != 1: + return False + else: + if ( + non_reduced_idx >= len(mean_output_shape) + or view_shape[i] != mean_output_shape[non_reduced_idx] + ): + return False + non_reduced_idx += 1 + return non_reduced_idx == len(mean_output_shape) + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + if len(node.users) != 1: + return False + + view_node = next(iter(node.users)) + if view_node.target != exir_ops.edge.aten.view_copy.default: + return False + + reduction_dims = cast(list[int], node.args[1]) + keepdim = self._get_keepdim(node) + mean_output_shape = list(node.meta["val"].shape) + view_output_shape = list(view_node.meta["val"].shape) + + if keepdim: + ndim = len(mean_output_shape) + if not self._is_squeeze_of_reduction_dims( + mean_output_shape, view_output_shape, reduction_dims, ndim + ): + return False + self._set_keepdim(node, False) + else: + input_node = node.args[0] + assert isinstance(input_node, torch.fx.Node) + input_ndim = len(input_node.meta["val"].shape) + if not self._is_unsqueeze_at_reduction_dims( + mean_output_shape, view_output_shape, reduction_dims, input_ndim + ): + return False + self._set_keepdim(node, True) + + view_node.replace_all_uses_with(node) + return True + + class HierarchicalCSEPass(HierarchicalInplacePassInterface): """ A hierarchical Common Subexpression Elimination (CSE) pass that recursively @@ -1035,4 +1137,5 @@ class CadenceFuseOpsInGraph: FuseMulScalarIntoDequantPass, FuseFullThenReshapePass, FuseTransposeOrPermuteOpPairsPass, + FuseMeanKeepDimWithViewPass, ] diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index f5afbe243f8..f519febba37 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -19,6 +19,7 @@ FuseCascadedTransposeOrPermuteOps, FuseCascadedViewOps, FuseFullThenReshapePass, + FuseMeanKeepDimWithViewPass, FuseMMWithAdd, FuseMulScalarIntoDequantPass, FuseMulTensorIntoDequantPass, @@ -1696,3 +1697,168 @@ def __init__(self) -> None: # Verify fusion occurred: bn should be removed, conv remains self.assertEqual(count_node(gm, conv_op), 1) self.assertEqual(count_node(gm, bn_op), 0) + + +class TestFuseMeanKeepDimWithViewPass(TestFusionPassesBase): + def test_keepdim_true_to_false(self) -> None: + """mean(keepdim=True) + view that squeezes reduction dims → mean(keepdim=False).""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], True) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [2, 62]) + ) + builder.output([view]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + gm = result.graph_module + self.assertTrue(result.modified) + + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + mean_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertFalse(mean_nodes[0].args[2]) + + validate_numerics( + gm_before, gm, (torch.randn(2, 62, 4, 4),), "FuseMeanKeepDimWithViewPass" + ) + + def test_keepdim_false_to_true(self) -> None: + """mean(keepdim=False) + view that unsqueezes at reduction dims → mean(keepdim=True).""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], False) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [2, 62, 1, 1]) + ) + builder.output([view]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + gm = result.graph_module + self.assertTrue(result.modified) + + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + mean_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mean.dim + ) + self.assertEqual(len(mean_nodes), 1) + self.assertTrue(mean_nodes[0].args[2]) + + validate_numerics( + gm_before, gm, (torch.randn(2, 62, 4, 4),), "FuseMeanKeepDimWithViewPass" + ) + + def test_keepdim_true_view_does_not_match(self) -> None: + """View reshapes to something other than squeezing reduction dims → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], True) + ) + # Reshape to a different layout, not a simple squeeze of reduction dims. + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [1, 2, 62]) + ) + builder.output([view]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertFalse(result.modified) + + def test_keepdim_false_view_wrong_unsqueeze(self) -> None: + """View inserts 1s at wrong positions → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], False) + ) + # 1s at positions 0 and 1 instead of 2 and 3. + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [1, 1, 2, 62]) + ) + builder.output([view]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertFalse(result.modified) + + def test_mean_multiple_users_no_change(self) -> None: + """Mean has multiple users → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 62, 4, 4)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [-1, -2], True) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [2, 62]) + ) + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(mean,)) + builder.output([view, neg]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertFalse(result.modified) + + def test_reduce_single_dim(self) -> None: + """Reduction over a single dim, both directions.""" + # keepdim=True → False + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 4, 5)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [1], True) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [3, 5]) + ) + builder.output([view]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.view_copy.default), 0 + ) + + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(3, 4, 5),), + "FuseMeanKeepDimWithViewPass", + ) + + # keepdim=False → True + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 4, 5)) + mean = builder.call_operator( + op=exir_ops.edge.aten.mean.dim, args=(x, [1], False) + ) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mean, [3, 1, 5]) + ) + builder.output([view]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseMeanKeepDimWithViewPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.view_copy.default), 0 + ) + + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(3, 4, 5),), + "FuseMeanKeepDimWithViewPass", + )