Skip to content

Commit b84f372

Browse files
authored
Arm backend: Improve support for multiple output ops (#15695)
The pattern of using .name from the TosaArg output is used in essentially all node visitors. Refactoring output to possibly be a list of TosaArgs would be very cumbersome. Instead, use a new multiple_output_names field to provide the names of all getitem tensors. This also means that we don't have to insert identity ops for getitem operators, they just need to insert a tensor. Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent dc03dc9 commit b84f372

File tree

5 files changed

+28
-10
lines changed

5 files changed

+28
-10
lines changed

backends/arm/operators/op_cond_if.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ def define_node(
5656
inputs[0].name,
5757
*(subgraph_input.name for subgraph_input in inputs[-1].special),
5858
],
59-
[output.name],
59+
output.multiple_output_names,
6060
attr,
6161
)

backends/arm/operators/ops_identity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def define_node(
4040
inputs: List[TosaArg],
4141
output: TosaArg,
4242
) -> None:
43-
validate_num_inputs(self.target, inputs, [1, 2])
43+
validate_num_inputs(self.target, inputs, 1)
4444
validate_same_dtype(self.target, [inputs[0], output], ts)
4545

4646
# Simply add an identityOp
@@ -58,5 +58,4 @@ def define_node(
5858
register_node_visitor(IdentityOperatorVisitor)
5959

6060

61-
identity_operator_factory("getitem")
6261
identity_operator_factory("aten.alias_copy.default")

backends/arm/process_node.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
#
66

7+
import operator
78
from typing import Any, cast, Dict
89

910
import numpy as np
@@ -45,9 +46,15 @@ def process_call_function(
4546
f"Failed processing call_function: {node.name}. "
4647
"Is the original torch function supported?"
4748
) from e
48-
tosa_graph.currRegion.currBasicBlock.addTensor(
49-
output.name, tosa_shape(output.shape, output.dim_order), output.dtype
50-
)
49+
50+
if not output.multiple_output_names:
51+
tosa_graph.currRegion.currBasicBlock.addTensor(
52+
output.name, tosa_shape(output.shape, output.dim_order), output.dtype
53+
)
54+
55+
# Get item nodes just add tensors, no node visitor is needed.
56+
if node.target == operator.getitem:
57+
return
5158

5259
# Visiting each Node
5360
# pyre-ignore[16]: Undefined attribute.

backends/arm/test/ops/test_cond.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,7 @@ def _create() -> tuple[torch.nn.Module, input_t2]:
174174
"case",
175175
test_cases,
176176
xfails={
177-
"one_arg_two_outputs": "Multiple outputs is not supported.",
178177
"one_arg_and_scalar_one_output": "Scalars become get_attr nodes that are not supported.",
179-
"two_args_two_outputs": "Nodes with multiple outputs are not properly supported.",
180178
"multiple_one_arg_one_output": "Scalars become get_attr nodes that are not supported.",
181179
},
182180
)

backends/arm/tosa/mapping.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
1111
"""
1212

13+
import operator
1314
from enum import Enum
1415
from typing import Any, Optional, Sequence
1516

@@ -136,7 +137,7 @@ class TosaArg:
136137
special (list | None): Captured list when the argument is a sequence.
137138
number (float | int | None): Captured numeric value when given.
138139
tosa_spec (TosaSpecification): Active specification used for mapping.
139-
140+
multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise.
140141
"""
141142

142143
def __process_node(self, argument: torch.fx.Node):
@@ -146,7 +147,8 @@ def __process_node(self, argument: torch.fx.Node):
146147
argument (torch.fx.Node): FX node to inspect.
147148
148149
"""
149-
self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "")
150+
suffix = argument.meta.get(TOSA_TENSOR_NAME_META, "")
151+
self.name = argument.name + suffix
150152

151153
if "val" in argument.meta:
152154
output_dtype, self.shape, self.dim_order = extract_tensor_meta(
@@ -158,6 +160,16 @@ def __process_node(self, argument: torch.fx.Node):
158160

159161
self.dtype = output_dtype
160162

163+
# If all users of the node are getitems, node visitors should connect the output of this node directly to the getitem tensors.
164+
# Add a new attribute 'multiple_output_names' instead of making 'name' a list to avoid ambiguity regarding the type of 'name'.
165+
# Make name of the output is the first getitem since we in most cases only handle that output.
166+
users = list(argument.users)
167+
if len(users) > 0 and all(user.target == operator.getitem for user in users):
168+
self.multiple_output_names: list = [user.name + suffix for user in users]
169+
self.name = self.multiple_output_names[0]
170+
else:
171+
self.multiple_output_names = []
172+
161173
def __process_list(self, argument):
162174
"""Capture a sequence argument as ``special``.
163175
@@ -244,4 +256,6 @@ def __repr__(self):
244256
attrs.append(f"number={self.number!r}")
245257
if hasattr(self, "tosa_spec") and self.tosa_spec is not None:
246258
attrs.append(f"tosa_spec={self.tosa_spec!r}")
259+
if hasattr(self, "names"):
260+
attrs.append(f"names={self.multiple_output_names!r}")
247261
return f"{self.__class__.__name__}({', '.join(attrs)})"

0 commit comments

Comments
 (0)