Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion backends/cadence/aot/reorder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@

from collections import defaultdict
from math import prod
from typing import DefaultDict, List, Tuple
from typing import cast, DefaultDict, List, Tuple

import torch
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import get_placeholders, get_shape
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
get_arg,
get_overload_packet,
register_cadence_pass,
RemoveOrReplacePassInterface,
Expand Down Expand Up @@ -641,6 +642,87 @@ class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(
pass


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class MoveSliceBeforePermutePass(RemoveOrReplacePassInterface):
    """Move slice_copy ops before permute_copy to reduce permute data volume.

    Rewrites permute(input, perm) -> slice(dim=D) into
    slice(input, dim=perm[D]) -> permute(sliced, perm), so the permute
    operates on a smaller tensor.

    Cost model: dim-0 slices are nop-eligible (zero-copy pointer offset
    after MakeSliceAndCatDimOutermostPass). Moving such a slice loses the
    nop, so we only move it when the permute savings outweigh the nop loss,
    i.e. when the slice removes more than half the data (full > 2 * sliced).
    Non-dim-0 slices have no nop opportunity, so any permute savings is
    pure win.
    """

    # A dim-0 slice is only worth moving when the full tensor is more than
    # this many times larger than the sliced result.
    STRIDED_SLICE_COST_FACTOR: int = 2

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Match on the permute node; the candidate slice is discovered
        # through its users in maybe_remove_or_replace.
        return [exir_ops.edge.aten.permute_copy.default]

    @staticmethod
    def _is_profitable(
        slice_dim: int, full_shape: torch.Size, sliced_shape: torch.Size
    ) -> bool:
        """Return True when moving the slice above the permute pays off.

        Args:
            slice_dim: Dim of the original slice, normalized to be
                non-negative (a negative dim here would silently skip the
                dim-0 cost check).
            full_shape: Shape of the permute output (same volume as input).
            sliced_shape: Shape of the slice output.
        """
        full_size = prod(full_shape)
        sliced_size = prod(sliced_shape)
        if slice_dim == 0:
            # Moving a dim-0 slice forfeits its nop opportunity; require the
            # slice to remove more than half the data to compensate.
            return (
                full_size
                > MoveSliceBeforePermutePass.STRIDED_SLICE_COST_FACTOR * sliced_size
            )
        # Non-dim-0 slices were never nop-eligible, so any reduction in the
        # data volume the permute touches is a pure win.
        return True

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Rewrite `node` (a permute_copy) with its sole slice_copy user.

        Returns True when the graph was rewritten, False when the pattern
        does not match or the move is not profitable.
        """
        perm = cast(list[int], node.args[1])
        permute_input = node.args[0]
        assert isinstance(permute_input, torch.fx.Node)

        # Only handle a permute whose sole consumer is a slice; with more
        # users the permute must be kept anyway, so nothing would be saved.
        if len(node.users) != 1:
            return False

        slice_node = next(iter(node.users))
        if slice_node.target != exir_ops.edge.aten.slice_copy.Tensor:
            return False

        slice_dim = get_arg(slice_node, "dim", int)
        # Normalize negative dims (e.g. dim=-ndim) so the outermost dim is
        # always seen as 0 by the cost model, and perm is indexed uniformly.
        slice_dim %= len(node.meta["val"].shape)

        if not self._is_profitable(
            slice_dim,
            node.meta["val"].shape,
            slice_node.meta["val"].shape,
        ):
            return False

        # Output dim D of permute(input, perm) is input dim perm[D], so the
        # equivalent slice on the permute input uses dim perm[D].
        new_dim = perm[slice_dim]
        graph = node.graph

        with graph.inserting_before(node):
            # NOTE(review): the new nodes carry no meta["val"]; presumably
            # the pass interface retraces/repropagates shapes afterwards —
            # confirm before re-running shape-dependent passes.
            new_slice = graph.create_node(
                "call_function",
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(
                    permute_input,
                    new_dim,
                    get_arg(slice_node, "start"),
                    get_arg(slice_node, "end"),
                    get_arg(slice_node, "step", int),
                ),
            )
            new_permute = graph.create_node(
                "call_function",
                exir_ops.edge.aten.permute_copy.default,
                args=(new_slice, perm),
            )

        # Reroute the slice's users to the new permute; the old permute and
        # slice become dead and are presumably cleaned up by the
        # RemoveOrReplacePassInterface machinery on a True return.
        slice_node.replace_all_uses_with(new_permute)
        return True


# The following class consolidates functions to reoder ops (i.e., either hoist
# or sink some ops in the graph).
class CadenceReorderOpsInGraph:
Expand Down
146 changes: 146 additions & 0 deletions backends/cadence/aot/tests/test_reorder_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from executorch.backends.cadence.aot.reorder_ops import (
AdvanceQuantizeOpAboveDefChainPass,
AdvanceQuantizeOpAboveDefInBranchPass,
MoveSliceBeforePermutePass,
PostponeDequantizeOpBelowUseChainPass,
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
SinkOpsCloserToUsePass,
Expand Down Expand Up @@ -633,3 +634,148 @@ def test_permute_view_chains_neg(self) -> None:
self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)
self.assertTrue(nodes[2] == exir_ops.edge.aten.view_copy)
self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)


class TestMoveSliceBeforePermutePass(unittest.TestCase):
    """Unit tests for MoveSliceBeforePermutePass."""

    @staticmethod
    def _permuted(shape, perm):
        """Return (builder, permute_node) for a fresh graph permuting `shape` by `perm`."""
        builder = GraphBuilder()
        inp = builder.placeholder("x", torch.randn(*shape))
        permute_node = builder.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(inp, perm)
        )
        return builder, permute_node

    @staticmethod
    def _slice_of(builder, src, dim, start, end):
        """Add a unit-step slice_copy of `src` over [start, end) along `dim`."""
        return builder.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(src, dim, start, end, 1),
        )

    def _assert_reordered(self, result) -> None:
        """Assert the pass fired and left exactly slice → permute."""
        self.assertTrue(result.modified)
        ops = get_compute_nodes_in_gm(result.graph_module)
        self.assertEqual(len(ops), 2)
        self.assertEqual(ops[0], exir_ops.edge.aten.slice_copy)
        self.assertEqual(ops[1], exir_ops.edge.aten.permute_copy)

    def test_basic_move(self) -> None:
        """permute → slice becomes slice → permute."""
        builder, permuted = self._permuted((2, 3, 4, 5), [0, 2, 3, 1])
        builder.output([self._slice_of(builder, permuted, 1, 0, 2)])
        result = transform_and_check_numerics(
            builder.get_graph_module(),
            (torch.randn(2, 3, 4, 5),),
            MoveSliceBeforePermutePass(),
        )
        self._assert_reordered(result)

    def test_multi_user_permute_no_change(self) -> None:
        """Permute with multiple users → no change (only single-user supported)."""
        builder, permuted = self._permuted((2, 3, 4, 5), [0, 2, 3, 1])
        first = self._slice_of(builder, permuted, 1, 0, 2)
        second = self._slice_of(builder, permuted, 2, 1, 3)
        builder.output([first, second])
        result = cast(
            PassResult, MoveSliceBeforePermutePass()(builder.get_graph_module())
        )
        self.assertFalse(result.modified)

    def test_mixed_users_no_change(self) -> None:
        """Permute with one slice user and one non-slice user → no change."""
        builder, permuted = self._permuted((2, 3, 4, 5), [0, 2, 3, 1])
        sliced = self._slice_of(builder, permuted, 1, 0, 2)
        negated = builder.call_operator(
            op=exir_ops.edge.aten.neg.default, args=(permuted,)
        )
        builder.output([sliced, negated])
        result = cast(
            PassResult, MoveSliceBeforePermutePass()(builder.get_graph_module())
        )
        self.assertFalse(result.modified)

    def test_no_slice_users_no_change(self) -> None:
        """Permute with no slice users → no change."""
        builder, permuted = self._permuted((2, 3, 4, 5), [0, 2, 3, 1])
        negated = builder.call_operator(
            op=exir_ops.edge.aten.neg.default, args=(permuted,)
        )
        builder.output([negated])
        result = cast(
            PassResult, MoveSliceBeforePermutePass()(builder.get_graph_module())
        )
        self.assertFalse(result.modified)

    def test_dim0_slice_large_reduction_moved(self) -> None:
        """Dim-0 slice removing >50% of data → profitable, moved."""
        builder, permuted = self._permuted((10, 3, 4, 5), [0, 2, 3, 1])
        builder.output([self._slice_of(builder, permuted, 0, 0, 2)])
        result = transform_and_check_numerics(
            builder.get_graph_module(),
            (torch.randn(10, 3, 4, 5),),
            MoveSliceBeforePermutePass(),
        )
        self._assert_reordered(result)

    def test_dim0_slice_small_reduction_not_moved(self) -> None:
        """Dim-0 slice removing <50% of data → not profitable, kept."""
        builder, permuted = self._permuted((10, 3, 4, 5), [0, 2, 3, 1])
        builder.output([self._slice_of(builder, permuted, 0, 0, 8)])
        result = cast(
            PassResult, MoveSliceBeforePermutePass()(builder.get_graph_module())
        )
        self.assertFalse(result.modified)

    def test_non_dim0_slice_always_moved(self) -> None:
        """Non-dim-0 slice → always profitable, moved regardless of reduction."""
        builder, permuted = self._permuted((10, 3, 4, 5), [0, 2, 3, 1])
        builder.output([self._slice_of(builder, permuted, 2, 0, 3)])
        result = transform_and_check_numerics(
            builder.get_graph_module(),
            (torch.randn(10, 3, 4, 5),),
            MoveSliceBeforePermutePass(),
        )
        self.assertTrue(result.modified)
Loading