Skip to content

Commit c4c943f

Browse files
martinlsmMartin Lindström
andauthored
Arm backend: Make remaining passes inherit from ArmPass (#15764)
cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent 7600df8 commit c4c943f

10 files changed

+83
-17
lines changed

backends/arm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
6565
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
6666
from .decompose_round_pass import DecomposeRoundPass # noqa
67+
from .decompose_sdpa_pass import DecomposeScaledDotProductAttention # noqa
6768
from .decompose_select import DecomposeSelectPass # noqa
6869
from .decompose_sign_pass import DecomposeSignPass # noqa
6970
from .decompose_silu_pass import DecomposeSiluPass # noqa
@@ -83,6 +84,7 @@
8384
from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa
8485
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
8586
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
87+
from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa
8688
from .insert_int32_casts_after_int64_placeholders import ( # noqa
8789
InsertInt32CastsAfterInt64PlaceholdersPass,
8890
)
@@ -91,6 +93,8 @@
9193
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
9294
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
9395
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
96+
from .remove_getitem_pass import RemoveGetItemPass # noqa
97+
from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa
9498
from .remove_noop_pass import RemoveNoopPass # noqa
9599
from .replace_scalar_with_tensor_pass import ( # noqa
96100
ReplaceScalarWithTensorByProfilePass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
DecomposeNotEqualPass,
6767
DecomposeRemainderPass,
6868
DecomposeRoundPass,
69+
DecomposeScaledDotProductAttention,
6970
DecomposeSelectPass,
7071
DecomposeSignPass,
7172
DecomposeSiluPass,
@@ -82,13 +83,16 @@
8283
FuseDuplicateUsersPass,
8384
FuseEqualPlaceholdersPass,
8485
FuseQuantizedActivationPass,
86+
FuseViewCopyTransformPass,
8587
InsertInt32CastsAfterInt64PlaceholdersPass,
8688
InsertRescaleInt32Pass,
8789
InsertRescalePass,
8890
InsertTableOpsPass,
8991
MatchArgDtypePass,
9092
MatchArgRanksPass,
9193
QuantizeOperatorArguments,
94+
RemoveGetItemPass,
95+
RemoveGraphAssertsPass,
9296
RemoveNoopPass,
9397
ReplaceInfValues,
9498
ReplaceScalarWithTensorByProfilePass,
@@ -107,14 +111,8 @@
107111
TosaLoweringContext,
108112
TosaSpecification,
109113
)
110-
from executorch.backends.transforms.decompose_sdpa import (
111-
DecomposeScaledDotProductAttention,
112-
)
113-
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
114-
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
115114
from executorch.exir import ExportedProgram
116115
from executorch.exir.pass_manager import PassManager
117-
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
118116
from torch.fx import GraphModule
119117
from torch.fx.passes.infra.pass_base import PassResult
120118
from torch.nn.modules import Module
@@ -258,7 +256,7 @@ def _tosa_pipeline(
258256
self.add_pass(CastToInt32Pass())
259257
self.add_pass(BroadcastArgsPass())
260258
self.add_pass(ConvertPermuteSingletonToViewPass())
261-
self.add_pass(FuseViewCopyTransform())
259+
self.add_pass(FuseViewCopyTransformPass())
262260
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
263261
self.add_pass(DecomposeSumPass())
264262
self.add_pass(InsertTableOpsPass(exported_program))

backends/arm/_passes/convert_permute_singleton_to_view_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from typing import Sequence, Set, Tuple, Type
88

9+
from executorch.backends.arm._passes.arm_pass import ArmPass
10+
911
from executorch.exir.dialects._ops import ops as exir_ops
1012
from executorch.exir.pass_base import ExportPass
1113

@@ -18,7 +20,7 @@
1820
)
1921

2022

21-
class ConvertPermuteSingletonToViewPass(ExportPass):
23+
class ConvertPermuteSingletonToViewPass(ArmPass):
2224
"""Replace permutations that only move singleton axes with a reshape.
2325
2426
Examples:

backends/arm/_passes/convert_squeezes_to_view.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -8,9 +7,9 @@
87
from typing import Set, Type
98

109
from executorch.backends.arm._passes import ArmPass
11-
12-
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
13-
10+
from executorch.backends.arm._passes.fuse_view_copy_transform_pass import (
11+
FuseViewCopyTransformPass,
12+
)
1413
from executorch.exir.dialects._ops import ops as exir_ops
1514
from executorch.exir.pass_base import ExportPass
1615

@@ -20,7 +19,7 @@ class ConvertSqueezesToViewPass(ArmPass):
2019
Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors.
2120
"""
2221

23-
_passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransform}
22+
_passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass}
2423

2524
def call_operator(self, op, args, kwargs, meta):
2625
if op not in [

backends/arm/_passes/decompose_embedding_pass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
import torch
1212
from executorch.backends.arm._passes import ArmPass
13-
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
13+
from executorch.backends.arm._passes.fuse_view_copy_transform_pass import (
14+
FuseViewCopyTransformPass,
15+
)
1416
from executorch.exir.dialects._ops import ops as exir_ops
1517
from executorch.exir.pass_base import ExportPass, PassResult
1618

@@ -33,7 +35,7 @@ class DecomposeEmbeddingPass(ArmPass):
3335
i = indices is expected to be int32 before this pass
3436
"""
3537

36-
_passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransform}
38+
_passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass}
3739

3840
aten_ops = (torch.ops.aten.embedding.default,)
3941
edge_ops = (exir_ops.edge.aten.embedding.default,)

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,27 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from typing import cast
7+
from typing import cast, Set, Type
88

99
import torch
10+
from executorch.backends.arm._passes.arm_pass import ArmPass
1011
from executorch.backends.arm._passes.quant_args import QuantArgs
1112

1213
from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
1314
from executorch.exir.dialects._ops import ops as exir_ops
1415
from executorch.exir.pass_base import ExportPass
1516

1617

17-
class DecomposeConv2dWithInt16ActivationPass(ExportPass):
18+
class DecomposeConv2dWithInt16ActivationPass(ArmPass):
1819
"""
1920
This pass decomposes a convolution with input dtype int16 and bias
2021
into a convolution without bias followed by an addition of the bias
2122
since the TOSA op requires the bias to be int48 which is hard to represent
2223
in torch. Instead rescale the int48 output to int16 and add the bias in int16.
2324
"""
2425

26+
_passes_required_after: Set[Type[ExportPass]] = set()
27+
2528
def call_operator(self, op, args, kwargs, meta):
2629
if op != exir_ops.edge.aten.convolution.default:
2730
return super().call_operator(op, args, kwargs, meta)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
from executorch.backends.arm._passes.arm_pass import ArmPass
9+
from executorch.backends.transforms import decompose_sdpa
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class DecomposeScaledDotProductAttention(
14+
ArmPass, decompose_sdpa.DecomposeScaledDotProductAttention
15+
):
16+
_passes_required_after: Set[Type[ExportPass]] = set()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
from executorch.backends.arm._passes.arm_pass import ArmPass
9+
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class FuseViewCopyTransformPass(ArmPass, FuseViewCopyTransform):
14+
_passes_required_after: Set[Type[ExportPass]] = set()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
from executorch.backends.arm._passes.arm_pass import ArmPass
9+
from executorch.backends.transforms import remove_getitem_op
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class RemoveGetItemPass(ArmPass, remove_getitem_op.RemoveGetItemPass):
14+
_passes_required_after: Set[Type[ExportPass]] = set()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
from executorch.backends.arm._passes.arm_pass import ArmPass
9+
from executorch.exir.pass_base import ExportPass
10+
from executorch.exir.passes import remove_graph_asserts_pass
11+
12+
13+
class RemoveGraphAssertsPass(remove_graph_asserts_pass.RemoveGraphAssertsPass, ArmPass):
14+
_passes_required_after: Set[Type[ExportPass]] = set()

0 commit comments

Comments
 (0)