Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,108 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMeanKeepDimWithViewPass(RemoveOrReplacePassInterface):
    """Fold a view_copy into a preceding mean by flipping its keepdim flag.

    Two symmetric patterns are handled:
      - mean(keepdim=True) followed by a view that drops exactly the reduced
        dims → rewrite to mean(keepdim=False) and drop the view.
      - mean(keepdim=False) followed by a view that re-inserts size-1 dims at
        exactly the reduced positions → rewrite to mean(keepdim=True) and
        drop the view.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.mean.dim]

    @staticmethod
    def _get_keepdim(node: torch.fx.Node) -> bool:
        # keepdim is positional arg 2 when present, otherwise a kwarg
        # defaulting to False (aten.mean.dim schema).
        return bool(
            node.args[2] if len(node.args) >= 3 else node.kwargs.get("keepdim", False)
        )

    @staticmethod
    def _set_keepdim(node: torch.fx.Node, keepdim: bool) -> None:
        # Write the flag back wherever the node currently carries it,
        # appending positionally if it was omitted entirely.
        if "keepdim" in node.kwargs:
            node.kwargs = {**node.kwargs, "keepdim": keepdim}
        elif len(node.args) > 2:
            node.args = node.args[:2] + (keepdim,) + node.args[3:]
        else:
            node.args = (*node.args, keepdim)

    @staticmethod
    def _is_squeeze_of_reduction_dims(
        mean_shape: list[int],
        view_shape: list[int],
        reduction_dims: list[int],
        ndim: int,
    ) -> bool:
        """True iff view_shape equals mean_shape with the reduced dims removed."""
        reduced = {d % ndim for d in reduction_dims}
        squeezed = [size for axis, size in enumerate(mean_shape) if axis not in reduced]
        return squeezed == list(view_shape)

    @staticmethod
    def _is_unsqueeze_at_reduction_dims(
        mean_output_shape: list[int],
        view_shape: list[int],
        reduction_dims: list[int],
        input_ndim: int,
    ) -> bool:
        """True iff view_shape re-inserts a size-1 dim at every reduced position
        and otherwise matches mean_output_shape in order."""
        if len(view_shape) != input_ndim:
            return False
        reduced = {d % input_ndim for d in reduction_dims}
        kept = 0  # how many non-reduced dims of mean_output_shape we've matched
        for axis, size in enumerate(view_shape):
            if axis in reduced:
                if size != 1:
                    return False
                continue
            if kept >= len(mean_output_shape) or size != mean_output_shape[kept]:
                return False
            kept += 1
        # Every non-reduced dim of the mean output must be accounted for.
        return kept == len(mean_output_shape)

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        # Only fuse when the view is the mean's sole consumer.
        if len(node.users) != 1:
            return False

        (view_node,) = node.users
        if view_node.target != exir_ops.edge.aten.view_copy.default:
            return False

        dims = cast(list[int], node.args[1])
        mean_shape = list(node.meta["val"].shape)
        view_shape = list(view_node.meta["val"].shape)

        if self._get_keepdim(node):
            # keepdim=True: the view must be a pure squeeze of the reduced dims.
            if not self._is_squeeze_of_reduction_dims(
                mean_shape, view_shape, dims, len(mean_shape)
            ):
                return False
            self._set_keepdim(node, False)
        else:
            # keepdim=False: canonicalize dims against the *input* rank, since
            # the mean output no longer has the reduced dims.
            mean_input = node.args[0]
            assert isinstance(mean_input, torch.fx.Node)
            input_rank = len(mean_input.meta["val"].shape)
            if not self._is_unsqueeze_at_reduction_dims(
                mean_shape, view_shape, dims, input_rank
            ):
                return False
            self._set_keepdim(node, True)

        # The retargeted mean now produces the view's shape directly.
        view_node.replace_all_uses_with(node)
        return True


class HierarchicalCSEPass(HierarchicalInplacePassInterface):
"""
A hierarchical Common Subexpression Elimination (CSE) pass that recursively
Expand Down Expand Up @@ -1035,4 +1137,5 @@ class CadenceFuseOpsInGraph:
FuseMulScalarIntoDequantPass,
FuseFullThenReshapePass,
FuseTransposeOrPermuteOpPairsPass,
FuseMeanKeepDimWithViewPass,
]
122 changes: 122 additions & 0 deletions backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,127 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
return False


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface):
    """Eliminate or sink a permute that feeds a mean through a unary-op chain.

    If a permute reaches a mean (possibly via single-user unary ops such as
    dequantize/quantize), remap the mean's reduction dims into the permute's
    input space. Then either:

    1. drop the permute entirely when the surviving (non-reduced) dims keep
       their relative order and positions, or
    2. re-emit an equivalent permute *after* the mean, where it acts on the
       smaller reduced tensor.
    """

    _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset(
        {
            exir_ops.edge.cadence.dequantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.neg.default,
            exir_ops.edge.aten.abs.default,
        }
    )

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.mean.dim]

    def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]:
        """Walk upward from the mean through single-user unary ops; return the
        first permute_copy encountered, or None if the chain breaks."""
        cursor = mean_node.args[0]
        while isinstance(cursor, Node):
            if cursor.target == exir_ops.edge.aten.permute_copy.default:
                return cursor
            # The chain may only pass through known unary ops with one user.
            if cursor.target not in self._UNARY_TARGETS or len(cursor.users) != 1:
                return None
            cursor = cursor.args[0]
        return None

    @staticmethod
    def _get_keepdim(node: Node) -> bool:
        # Positional arg 2 wins; otherwise fall back to the kwarg (default False).
        if len(node.args) < 3:
            return bool(node.kwargs.get("keepdim", False))
        return bool(node.args[2])

    @staticmethod
    def _can_fully_remove(
        perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool
    ) -> bool:
        """True iff the permute that would follow the mean is a no-op."""
        reduced = {d % ndim for d in new_reduction_dims}
        if keepdim:
            # Reduced dims become size 1, so only they may move; every
            # surviving dim must map to itself.
            return all(p == d for d, p in enumerate(perm) if d not in reduced)
        # Without keepdim the output only has surviving dims: removal is safe
        # when the permute keeps them in ascending (original) order.
        surviving = [d for d in perm if d not in reduced]
        return surviving == sorted(surviving)

    @staticmethod
    def _compute_post_mean_perm(
        perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool
    ) -> list[int]:
        """Build the permutation to insert after the retargeted mean."""
        if keepdim:
            # Rank is unchanged, so the original permutation applies as-is.
            return list(perm)
        reduced = {d % ndim for d in new_reduction_dims}
        kept_sorted = sorted(d for d in range(ndim) if d not in reduced)
        kept_permuted = [d for d in perm if d not in reduced]
        # Map each surviving source dim to its position in the mean's output.
        return [kept_sorted.index(d) for d in kept_permuted]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        permute_node = self._find_permute_through_unary_chain(node)
        if permute_node is None or len(permute_node.users) != 1:
            return False

        perm = cast(list[int], permute_node.args[1])
        rank = len(perm)
        reduction_dims = cast(list[int], node.args[1])

        source = permute_node.args[0]
        assert isinstance(source, Node)

        # Reducing dim d of the permuted tensor reduces dim perm[d] of source.
        remapped_dims = [perm[d % rank] for d in reduction_dims]
        keepdim = self._get_keepdim(node)
        fully_removable = self._can_fully_remove(perm, remapped_dims, rank, keepdim)

        # Bypass the permute and retarget the mean's reduction dims.
        permute_node.replace_all_uses_with(source)
        node.args = (node.args[0], remapped_dims, *node.args[2:])

        if not fully_removable:
            # Sink the permute below the mean so it runs on the reduced tensor.
            post_perm = self._compute_post_mean_perm(
                perm, remapped_dims, rank, keepdim
            )
            graph = node.graph
            with graph.inserting_after(node):
                sunk_permute = graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    args=(node, post_perm),
                )
            for consumer in list(node.users):
                if consumer is not sunk_permute:
                    consumer.replace_input_with(node, sunk_permute)

        return True


@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps):
permutable_ops: set[EdgeOpOverload] = (
Expand Down Expand Up @@ -646,6 +767,7 @@ class CommonRemovePasses:
RemoveNopSliceOrViewOpPass,
RemoveToOpsPass,
RemoveZeroSizedCatArgsPass,
RemovePermuteBeforeMeanPass,
RemovePermutesAroundElementwiseOps,
FuseTransposeOrPermuteOpPairsPass,
RemoveSqueezeViewBeforeElementwiseOps,
Expand Down
Loading
Loading