 
 # pyre-strict
 
-import logging
 from dataclasses import dataclass, field
-from typing import cast, List, Optional, Sequence, Set, Type
+from typing import cast, List, Optional, Set, Type
 
 
 # Import these for the cadence function signatures.
 import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
@@ -69,45 +68,57 @@ class RemoveRedundantOps:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class RemoveZeroSizedCatArgsPass(ExportPass):
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.aten.cat.default:
-            return super().call_operator(op, args, kwargs, meta)
-
-        # Remove any zero-sized tensor arg to form a new args list.
-        cat_inputs: list[ProxyValue] = []
-        for arg in cast(Sequence[ProxyValue], args[0]):
-            if arg.to_tensor().numel() > 0:
-                cat_inputs.append(arg)
+class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface):
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.cat.default]
 
-        # If all the tensors were empty, we just return an empty tensor with
-        # the right shape.
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        # Get the cat inputs (the first argument is a list of tensors).
+        cat_inputs_arg = node.args[0]
+
+        # Assert that cat_inputs_arg is iterable.
+        assert isinstance(
+            cat_inputs_arg, (list, tuple)
+        ), "cat_inputs_arg must be a sequence type"
+
+        # Filter out zero-sized tensors.
+        cat_inputs: list[Node] = []
+        for arg in cat_inputs_arg:
+            if isinstance(arg, Node) and arg.meta.get("val") is not None:
+                if arg.meta["val"].numel() > 0:
+                    cat_inputs.append(arg)
+
+        # If all tensors were empty, create a full op with the right shape.
         if not cat_inputs:
-            empty_shape = meta["val"].shape
-            dtype = meta["val"].dtype
-            return super().call_operator(
-                exir_ops.edge.aten.full.default,
-                (tuple(empty_shape), 0),
-                {"dtype": dtype},
-                meta,
-            )
+            empty_shape = node.meta["val"].shape
+            dtype = node.meta["val"].dtype
+            # Create a new full node.
+            with node.graph.inserting_before(node):
+                full_node = node.graph.call_function(
+                    exir_ops.edge.aten.full.default,
+                    args=(tuple(empty_shape), 0),
+                    kwargs={"dtype": dtype},
+                )
+                full_node.meta = node.meta.copy()
+            node.replace_all_uses_with(full_node)
+            return True
 
-        # If there was only one tensor in the cat_inputs list,
-        # we can safely erase this cat op.
+        # If only one tensor remains, replace the cat with it.
         if len(cat_inputs) == 1:
-            return cat_inputs[0]
+            node.replace_all_uses_with(cat_inputs[0])
+            return True
+
+        # If the number of inputs changed, update the cat args.
+        if len(cat_inputs) < len(cat_inputs_arg):
+            # Update the first argument with the filtered inputs.
+            new_args = list(node.args)
+            new_args[0] = cat_inputs
+            node.args = tuple(new_args)
+            return True
 
-        # Otherwise, we replace args[0] with cat_inputs.
-        new_args = list(args)
-        # pyre error introduced after D66937105
-        new_args[0] = cat_inputs  # pyre-ignore[6]
-        return super().call_operator(op, tuple(new_args), kwargs, meta)
+        # No changes needed.
+        return False
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
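
Note on the pattern: the hunks above and below migrate passes from `ExportPass.call_operator` to a node-rewriting interface. As a minimal sketch of how such an interface plausibly drives `maybe_remove_or_replace`, consider the code below; only the `targets` property and the `maybe_remove_or_replace(node) -> bool` method are taken from the diff, and everything else (class name, driver loop, import locations) is an assumption, not the actual `RemoveOrReplacePassInterface`.

# Illustrative sketch only; the real RemoveOrReplacePassInterface in the
# ExecuTorch Cadence AOT code may differ in detail.
from abc import ABC, abstractmethod

from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassResult


class RemoveOrReplaceSketch(ABC):
    @property
    @abstractmethod
    def targets(self) -> list:
        """Ops whose call_function nodes this pass inspects."""

    @abstractmethod
    def maybe_remove_or_replace(self, node: Node) -> bool:
        """Rewrite `node` in place; return True iff the graph changed."""

    def call(self, graph_module: GraphModule) -> PassResult:
        modified = False
        for target in self.targets:
            # Snapshot the matches so mutation does not disturb iteration.
            for node in list(
                graph_module.graph.find_nodes(op="call_function", target=target)
            ):
                modified |= self.maybe_remove_or_replace(node)
        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified)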
@@ -151,25 +162,29 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class RemoveZeroSizedConstantPadNd(ExportPass):
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[ProxyValue, tuple[int, ...], Argument],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.aten.constant_pad_nd.default:
-            return super().call_operator(op, args, kwargs, meta)
+class RemoveZeroSizedConstantPadNd(RemoveOrReplacePassInterface):
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.constant_pad_nd.default]
 
-        input_tensor = args[0]
-        padding = args[1]
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        # Get the padding argument (the second argument).
+        if len(node.args) < 2:
+            return False
+
+        padding = node.args[1]
+        if not isinstance(padding, (list, tuple)):
+            return False
 
+        # If any padding value is non-zero, keep the node.
         if any(x != 0 for x in padding):
-            return super().call_operator(op, args, kwargs, meta)
+            return False
 
-        logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}")
-        return input_tensor
+        # All padding is zero, so replace the node with its input.
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        node.replace_all_uses_with(input_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
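
Sanity check on the rewrite above: `constant_pad_nd` with an all-zero pad list is an identity, so replacing the node with its input preserves behavior. A quick eager-mode demonstration in plain PyTorch (not part of the patch):

import torch

x = torch.randn(2, 3)
# Zero padding on every side leaves the tensor untouched.
y = torch.ops.aten.constant_pad_nd(x, [0, 0, 0, 0], 0.0)
assert torch.equal(x, y)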
@@ -721,27 +736,27 @@ def get_squeeze_indices(self, view_node: Node) -> List[int]:
 
         return squeeze_indices
 
-    def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None:
+    def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> bool:
         if view_node in visited_view_nodes:
-            return
+            return False
 
         squeeze_indices = self.get_squeeze_indices(view_node)
         if not squeeze_indices:
-            return
+            return False
 
         # Only handle simple chains for now.
         if len(view_node.users) != 1:
-            return
+            return False
         node = next(iter(view_node.users))
 
         # Traverse down from the node until finding another view op.
         intermediate_slices = []
         while node.target != exir_ops.edge.aten.view_copy.default:
             # Only handle simple chains for now.
             if len(node.users) != 1:
-                return
+                return False
             if node.target not in self.intermediate_ops:
-                return
+                return False
             if node.target == exir_ops.edge.aten.slice_copy.Tensor:
                 intermediate_slices.append(node)
             node = next(iter(node.users))
@@ -764,18 +779,22 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None
         # Skip the initial view node.
         input_node = cast(Node, get_arg(view_node, "input"))
         view_node.replace_all_uses_with(input_node)
+        return True
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         visited_view_nodes = set()
+        modified = False
         for view_node in graph_module.graph.find_nodes(
             op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True
         ):
-            self.handle_squeeze(view_node, visited_view_nodes)
+            modified |= self.handle_squeeze(view_node, visited_view_nodes)
 
-        graph_module.graph.eliminate_dead_code()
-        graph_module.recompile()
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            return super().call(graph_module)
 
-        return super().call(graph_module)
+        return PassResult(graph_module, False)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
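
Why the unmodified path now returns `PassResult(graph_module, False)`: pass pipelines commonly consult the `modified` flag to decide whether another round of passes is worth running, so reporting a change that did not happen can cause needless re-runs. A hedged sketch of such a driver loop (illustrative only, not ExecuTorch's actual pass manager):

from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult


def run_until_stable(passes, graph_module: GraphModule, max_iters: int = 10) -> GraphModule:
    # Re-run the whole pipeline until no pass reports a modification.
    for _ in range(max_iters):
        changed = False
        for p in passes:
            result: PassResult = p.call(graph_module)
            graph_module = result.graph_module
            changed |= result.modified
        if not changed:
            break
    return graph_module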
@@ -798,23 +817,27 @@ class RemoveBranchedQuantDequant(ExportPass):
     }
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.remove_branched(
+        modified = self.remove_branched(
             graph_module, self.quantize_op_packets, self.dequantize_op_packets
         )
-        self.remove_branched(
+        modified |= self.remove_branched(
             graph_module, self.dequantize_op_packets, self.quantize_op_packets
         )
 
-        graph_module.graph.eliminate_dead_code()
-        result = super().call(graph_module)
-        return result
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            result = super().call(graph_module)
+            return result
+
+        return PassResult(graph_module, False)
 
     def remove_branched(
         self,
         graph_module: torch.fx.GraphModule,
         producer_pkts: set[EdgeOpOverloadPacket],
         consumer_pkts: set[EdgeOpOverloadPacket],
-    ) -> None:
+    ) -> bool:
+        modified = False
         for node in graph_module.graph.nodes:
             if (
                 node.op != "call_function"
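
Background for `remove_branched`: forwarding the producer's float input past a quantize/dequantize pair is sound because the round trip reproduces the input up to one quantization step when the qparams match. A small eager-mode illustration in plain PyTorch (not part of the patch):

import torch

x = torch.tensor([0.10, 0.20, 0.30])
scale, zero_point = 0.01, 0
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
# quantize -> dequantize recovers the input up to quantization error,
# which is why replacing the dequantize's uses with the original tensor
# is safe when the quantization parameters agree.
assert torch.allclose(q.dequantize(), x, atol=scale)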
@@ -838,61 +861,62 @@ def remove_branched(
                     continue
 
                 user.replace_all_uses_with(node.args[0])
+                modified = True
 
+        return modified
 
-class RemoveCatFromSliceCopyPass(ExportPass):
+
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class RemoveCatFromSliceCopyPass(RemoveOrReplacePassInterface):
     """
     Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
     to the slice_copy.
     """
 
-    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
-        for slice_copy_node in graph_module.graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
-        ):
-            cat_node = cast(Node, get_arg(slice_copy_node, "input"))
-            slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
-            start_idx = cast(int, get_arg(slice_copy_node, "start"))
-            end_idx = cast(int, get_arg(slice_copy_node, "end"))
-            step = cast(int, get_arg(slice_copy_node, "step"))
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.slice_copy.Tensor]
 
-            if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
-                continue
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        cat_node = cast(Node, get_arg(node, "input"))
+        slice_dim = cast(int, get_arg(node, "dim"))
+        start_idx = cast(int, get_arg(node, "start"))
+        end_idx = cast(int, get_arg(node, "end"))
+        step = cast(int, get_arg(node, "step"))
 
-            # Make sure cat and slice happens on the same dimension.
-            cat_dim = cast(Node, get_arg(cat_node, "dim"))
-            if cat_dim != slice_dim:
-                continue
+        if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
+            return False
+
+        # Make sure cat and slice happen on the same dimension.
+        cat_dim = cast(int, get_arg(cat_node, "dim"))
+        if cat_dim != slice_dim:
+            return False
 
-            # Canonicalize slice indices.
-            cat_output_shape = cat_node.meta["val"].shape
-            if start_idx is None:
-                start_idx = 0
-            elif start_idx < 0:
-                start_idx += cat_output_shape[cat_dim]
-            if end_idx is None or end_idx > cat_output_shape[cat_dim]:
-                end_idx = cat_output_shape[cat_dim]
-            elif end_idx < 0:
-                end_idx += cat_output_shape[cat_dim]
-
-            offset = 0
-            for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
-                cat_input_shape = cat_input_node.meta["val"].shape
-
-                # Check if the slice range overlaps with the cat input range.
-                if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
-                    slice_copy_node.replace_input_with(cat_node, cat_input_node)
-                    set_arg(slice_copy_node, "start", start_idx - offset)
-                    set_arg(slice_copy_node, "end", end_idx - offset)
-                    break
-
-                offset += cat_input_shape[cat_dim]
+        # Canonicalize slice indices.
+        cat_output_shape = cat_node.meta["val"].shape
+        if start_idx is None:
+            start_idx = 0
+        elif start_idx < 0:
+            start_idx += cat_output_shape[cat_dim]
+        if end_idx is None or end_idx > cat_output_shape[cat_dim]:
+            end_idx = cat_output_shape[cat_dim]
+        elif end_idx < 0:
+            end_idx += cat_output_shape[cat_dim]
+
+        offset = 0
+        for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
+            cat_input_shape = cat_input_node.meta["val"].shape
+
+            # Check if the slice range falls within the cat input range.
+            if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
+                node.replace_input_with(cat_node, cat_input_node)
+                set_arg(node, "start", start_idx - offset)
+                set_arg(node, "end", end_idx - offset)
+                return True
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self._remove_unused_cat(graph_module)
-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
-        return super().call(graph_module)
+            offset += cat_input_shape[cat_dim]
+
+        return False
 
 
 class CommonRemovePasses:
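
The index arithmetic in `maybe_remove_or_replace` above can be checked in eager mode: when a slice of `cat(...)` falls entirely within one input, it equals a slice of that input with start/end shifted by the input's offset along the cat dimension. Plain PyTorch, not part of the patch:

import torch

a, b = torch.randn(4, 3), torch.randn(5, 3)
full = torch.cat([a, b], dim=0)
# The slice [5:8) lies wholly inside `b`, whose offset along dim 0 is 4,
# so it equals `b` sliced with start/end shifted by that offset.
assert torch.equal(full[5:8], b[5 - 4 : 8 - 4])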
@@ -901,7 +925,6 @@ class CommonRemovePasses:
         RemoveAliasCopyOpPass,
         RemoveNopExpandOpPass,
         RemoveNopSliceOrViewOpPass,
-        RemoveNopSelectOpPass,
         RemoveToOpsPass,
         RemoveZeroSizedCatArgsPass,
         RemovePermutesAroundElementwiseOps,