Commit 65a295e

Update modified bit for a variety of remove ops
Differential Revision: D86355740
Pull Request resolved: #15722
1 parent 5167f37 commit 65a295e
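
The "modified bit" in the title is the boolean each rewritten pass now reports to the pass manager: True when the graph was actually changed, False otherwise, so dead-code elimination and recompilation run only when something changed. Below is a minimal sketch of that pattern using the RemoveOrReplacePassInterface hooks visible in the diff (targets and maybe_remove_or_replace); the pass name, the chosen target op, and the interface's import path are illustrative assumptions, not part of this commit.

    from torch.fx import Node

    # Assumed import path: RemoveOrReplacePassInterface is defined alongside the
    # Cadence AOT passes, but its module is not shown in this diff.
    from executorch.backends.cadence.aot.pass_utils import RemoveOrReplacePassInterface
    from executorch.exir.dialects._ops import ops as exir_ops
    from executorch.exir.dialects.edge._ops import EdgeOpOverload


    class RemoveAliasCopyExamplePass(RemoveOrReplacePassInterface):
        """Hypothetical example pass: drop alias_copy nodes, which return their input unchanged."""

        @property
        def targets(self) -> list[EdgeOpOverload]:
            # Only nodes with these targets are offered to maybe_remove_or_replace().
            return [exir_ops.edge.aten.alias_copy.default]

        def maybe_remove_or_replace(self, node: Node) -> bool:
            input_node = node.args[0]
            if not isinstance(input_node, Node):
                # Nothing changed: returning False keeps the pass result unmodified.
                return False
            # Rewire all users to the input and report that the graph was modified.
            node.replace_all_uses_with(input_node)
            return True

Passes in this file that still override call() directly (RemoveBranchedQuantDequant and the squeeze/view pass) apply the same idea by hand: they run eliminate_dead_code() and recompile() only when something changed, and return PassResult(graph_module, False) otherwise.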

File tree

2 files changed: +153 -128 lines changed


backends/cadence/aot/remove_ops.py

Lines changed: 135 additions & 112 deletions
@@ -6,9 +6,8 @@
 
 # pyre-strict
 
-import logging
 from dataclasses import dataclass, field
-from typing import cast, List, Optional, Sequence, Set, Type
+from typing import cast, List, Optional, Set, Type
 
 # Import these for the cadence function signatures.
 import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
@@ -69,45 +68,57 @@ class RemoveRedundantOps:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class RemoveZeroSizedCatArgsPass(ExportPass):
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.aten.cat.default:
-            return super().call_operator(op, args, kwargs, meta)
-
-        # Remove any zero-sized tensor arg to form a new args list.
-        cat_inputs: list[ProxyValue] = []
-        for arg in cast(Sequence[ProxyValue], args[0]):
-            if arg.to_tensor().numel() > 0:
-                cat_inputs.append(arg)
+class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface):
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.cat.default]
 
-        # If all the tensors were empty, we just return an empty tensor with
-        # the right shape.
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        # Get the cat inputs (first argument is a list of tensors)
+        cat_inputs_arg = node.args[0]
+
+        # Assert that cat_inputs_arg is iterable
+        assert isinstance(
+            cat_inputs_arg, (list, tuple)
+        ), "cat_inputs_arg must be a sequence type"
+
+        # Filter out zero-sized tensors
+        cat_inputs: list[Node] = []
+        for arg in cat_inputs_arg:
+            if isinstance(arg, Node) and arg.meta.get("val") is not None:
+                if arg.meta["val"].numel() > 0:
+                    cat_inputs.append(arg)
+
+        # If all tensors were empty, create a full op with the right shape
         if not cat_inputs:
-            empty_shape = meta["val"].shape
-            dtype = meta["val"].dtype
-            return super().call_operator(
-                exir_ops.edge.aten.full.default,
-                (tuple(empty_shape), 0),
-                {"dtype": dtype},
-                meta,
-            )
+            empty_shape = node.meta["val"].shape
+            dtype = node.meta["val"].dtype
+            # Create a new full node
+            with node.graph.inserting_before(node):
+                full_node = node.graph.call_function(
+                    exir_ops.edge.aten.full.default,
+                    args=(tuple(empty_shape), 0),
+                    kwargs={"dtype": dtype},
+                )
+            full_node.meta = node.meta.copy()
+            node.replace_all_uses_with(full_node)
+            return True
 
-        # If there was only one tensor in the cat_inputs list,
-        # we can safely erase this cat op.
+        # If only one tensor remains, replace with it
        if len(cat_inputs) == 1:
-            return cat_inputs[0]
+            node.replace_all_uses_with(cat_inputs[0])
+            return True
+
+        # If the number of inputs changed, update the cat args
+        if len(cat_inputs) < len(cat_inputs_arg):
+            # Update the first argument with filtered inputs
+            new_args = list(node.args)
+            new_args[0] = cat_inputs
+            node.args = tuple(new_args)
+            return True
 
-        # Otherwise, we replace args[0] with cat_inputs.
-        new_args = list(args)
-        # pyre error introduced after D66937105
-        new_args[0] = cat_inputs  # pyre-ignore[6]
-        return super().call_operator(op, tuple(new_args), kwargs, meta)
+        # No changes needed
+        return False
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -151,25 +162,29 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class RemoveZeroSizedConstantPadNd(ExportPass):
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[ProxyValue, tuple[int, ...], Argument],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.aten.constant_pad_nd.default:
-            return super().call_operator(op, args, kwargs, meta)
+class RemoveZeroSizedConstantPadNd(RemoveOrReplacePassInterface):
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.constant_pad_nd.default]
 
-        input_tensor = args[0]
-        padding = args[1]
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        # Get padding argument (second argument)
+        if len(node.args) < 2:
+            return False
+
+        padding = node.args[1]
+        if not isinstance(padding, (list, tuple)):
+            return False
 
+        # If any padding value is non-zero, keep the node
         if any(x != 0 for x in padding):
-            return super().call_operator(op, args, kwargs, meta)
+            return False
 
-        logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}")
-        return input_tensor
+        # All padding is zero, replace with input
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        node.replace_all_uses_with(input_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -721,27 +736,27 @@ def get_squeeze_indices(self, view_node: Node) -> List[int]:
 
         return squeeze_indices
 
-    def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None:
+    def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> bool:
         if view_node in visited_view_nodes:
-            return
+            return False
 
         squeeze_indices = self.get_squeeze_indices(view_node)
         if not squeeze_indices:
-            return
+            return False
 
         # Only handle simple chains for now.
         if len(view_node.users) != 1:
-            return
+            return False
         node = next(iter(view_node.users))
 
         # Traverse down from the node until finding another view op.
         intermediate_slices = []
         while node.target != exir_ops.edge.aten.view_copy.default:
             # Only handle simple chains for now
             if len(node.users) != 1:
-                return
+                return False
             if node.target not in self.intermediate_ops:
-                return
+                return False
             if node.target == exir_ops.edge.aten.slice_copy.Tensor:
                 intermediate_slices.append(node)
             node = next(iter(node.users))
@@ -764,18 +779,22 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None
         # Skip the initial view node.
         input_node = cast(Node, get_arg(view_node, "input"))
         view_node.replace_all_uses_with(input_node)
+        return True
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         visited_view_nodes = set()
+        modified = False
         for view_node in graph_module.graph.find_nodes(
             op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True
         ):
-            self.handle_squeeze(view_node, visited_view_nodes)
+            modified |= self.handle_squeeze(view_node, visited_view_nodes)
 
-        graph_module.graph.eliminate_dead_code()
-        graph_module.recompile()
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            return super().call(graph_module)
 
-        return super().call(graph_module)
+        return PassResult(graph_module, False)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -798,23 +817,27 @@ class RemoveBranchedQuantDequant(ExportPass):
     }
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.remove_branched(
+        modified = self.remove_branched(
             graph_module, self.quantize_op_packets, self.dequantize_op_packets
         )
-        self.remove_branched(
+        modified |= self.remove_branched(
             graph_module, self.dequantize_op_packets, self.quantize_op_packets
         )
 
-        graph_module.graph.eliminate_dead_code()
-        result = super().call(graph_module)
-        return result
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            result = super().call(graph_module)
+            return result
+
+        return PassResult(graph_module, False)
 
     def remove_branched(
         self,
         graph_module: torch.fx.GraphModule,
         producer_pkts: set[EdgeOpOverloadPacket],
         consumer_pkts: set[EdgeOpOverloadPacket],
-    ) -> None:
+    ) -> bool:
+        modified = False
         for node in graph_module.graph.nodes:
             if (
                 node.op != "call_function"
@@ -838,61 +861,62 @@ def remove_branched(
                     continue
 
                 user.replace_all_uses_with(node.args[0])
+                modified = True
 
+        return modified
 
-class RemoveCatFromSliceCopyPass(ExportPass):
+
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class RemoveCatFromSliceCopyPass(RemoveOrReplacePassInterface):
     """
     Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
     to the slice_copy.
     """
 
-    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
-        for slice_copy_node in graph_module.graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
-        ):
-            cat_node = cast(Node, get_arg(slice_copy_node, "input"))
-            slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
-            start_idx = cast(int, get_arg(slice_copy_node, "start"))
-            end_idx = cast(int, get_arg(slice_copy_node, "end"))
-            step = cast(int, get_arg(slice_copy_node, "step"))
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.slice_copy.Tensor]
 
-            if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
-                continue
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        cat_node = cast(Node, get_arg(node, "input"))
+        slice_dim = cast(int, get_arg(node, "dim"))
+        start_idx = cast(int, get_arg(node, "start"))
+        end_idx = cast(int, get_arg(node, "end"))
+        step = cast(int, get_arg(node, "step"))
 
-            # Make sure cat and slice happens on the same dimension.
-            cat_dim = cast(Node, get_arg(cat_node, "dim"))
-            if cat_dim != slice_dim:
-                continue
+        if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
+            return False
+
+        # Make sure cat and slice happens on the same dimension.
+        cat_dim = cast(int, get_arg(cat_node, "dim"))
+        if cat_dim != slice_dim:
+            return False
 
-            # Canonicalize slice indices.
-            cat_output_shape = cat_node.meta["val"].shape
-            if start_idx is None:
-                start_idx = 0
-            elif start_idx < 0:
-                start_idx += cat_output_shape[cat_dim]
-            if end_idx is None or end_idx > cat_output_shape[cat_dim]:
-                end_idx = cat_output_shape[cat_dim]
-            elif end_idx < 0:
-                end_idx += cat_output_shape[cat_dim]
-
-            offset = 0
-            for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
-                cat_input_shape = cat_input_node.meta["val"].shape
-
-                # Check if the slice range overlaps with the cat input range.
-                if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
-                    slice_copy_node.replace_input_with(cat_node, cat_input_node)
-                    set_arg(slice_copy_node, "start", start_idx - offset)
-                    set_arg(slice_copy_node, "end", end_idx - offset)
-                    break
-
-                offset += cat_input_shape[cat_dim]
+        # Canonicalize slice indices.
+        cat_output_shape = cat_node.meta["val"].shape
+        if start_idx is None:
+            start_idx = 0
+        elif start_idx < 0:
+            start_idx += cat_output_shape[cat_dim]
+        if end_idx is None or end_idx > cat_output_shape[cat_dim]:
+            end_idx = cat_output_shape[cat_dim]
+        elif end_idx < 0:
+            end_idx += cat_output_shape[cat_dim]
+
+        offset = 0
+        for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
+            cat_input_shape = cat_input_node.meta["val"].shape
+
+            # Check if the slice range overlaps with the cat input range.
+            if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
+                node.replace_input_with(cat_node, cat_input_node)
+                set_arg(node, "start", start_idx - offset)
+                set_arg(node, "end", end_idx - offset)
+                return True
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self._remove_unused_cat(graph_module)
-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
-        return super().call(graph_module)
+            offset += cat_input_shape[cat_dim]
+
+        return False
 
 
 class CommonRemovePasses:
@@ -901,7 +925,6 @@ class CommonRemovePasses:
         RemoveAliasCopyOpPass,
         RemoveNopExpandOpPass,
         RemoveNopSliceOrViewOpPass,
-        RemoveNopSelectOpPass,
         RemoveToOpsPass,
         RemoveZeroSizedCatArgsPass,
         RemovePermutesAroundElementwiseOps,
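
As a concrete illustration of the index canonicalization in RemoveCatFromSliceCopyPass above (the shapes here are assumed, not from the commit): if two tensors of shapes (2, 3) and (4, 3) are concatenated on dim 0 and followed by slice_copy(dim=0, start=3, end=5, step=1), the slice lies entirely inside the second cat input, whose offset along dim 0 is 2. The pass therefore rewires the slice_copy to read that input directly, rewrites its indices to start=1, end=3, and returns True to flag the modification.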
