Skip to content

Commit 9c7cb61

Browse files
NXP backend: Update cat delegation. (#14580)
### Summary This PR updates the `aten.cat.default` delegation condition to more accurately match the requirements of Neutron. ### Test plan Unit tests provided. cc @robert-kalmar
1 parent 65a295e commit 9c7cb61

File tree

3 files changed

+129
-18
lines changed

3 files changed

+129
-18
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,6 @@
2222
]
2323

2424

25-
def _is_dequantize(node_: Node) -> bool:
26-
return node_.op == "call_function" and node_.target in DEQUANTIZE_OPERATORS
27-
28-
29-
def _is_quantize(node_: Node) -> bool:
30-
return node_.op == "call_function" and node_.target in QUANTIZE_OPERATORS
31-
32-
3325
def input_tensor(node: Node, input_index: int) -> torch.Tensor:
3426
if len(node.all_input_nodes) <= input_index:
3527
raise IndexError
@@ -103,3 +95,33 @@ def try_get_tensor_constant_from_node(
10395
return None
10496
attr_itr = getattr(attr_itr, atom)
10597
return attr_itr
98+
99+
100+
def _is_dequantize(node_: Node) -> bool:
101+
return node_.op == "call_function" and node_.target in [
102+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
103+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
104+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
105+
torch.ops.quantized_decomposed.dequantize_per_channel.default,
106+
]
107+
108+
109+
def _is_quantize(node_: Node) -> bool:
110+
return node_.op == "call_function" and node_.target in [
111+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
112+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
113+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
114+
torch.ops.quantized_decomposed.quantize_per_channel.default,
115+
]
116+
117+
118+
def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None:
119+
"""Return the first node which is not a `quantize` or `dequantize`, found by traversing the graph backwards
120+
starting with the `node.args[input_index]`,
121+
"""
122+
current_node = node.args[input_index]
123+
while True:
124+
if _is_quantize(current_node) or _is_dequantize(current_node):
125+
current_node = current_node.args[0]
126+
else:
127+
return current_node

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from executorch.backends.nxp.backend.custom_delegation_options import (
99
CustomDelegationOptions,
1010
)
11+
from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node
1112
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1213
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
14+
apply_permutation_to,
1315
create_channels_first_to_channels_last_permutation,
1416
)
1517
from executorch.backends.nxp.backend.ir.converter.node_converter import (
@@ -23,6 +25,7 @@
2325
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
2426
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2527
from torch.fx import Node
28+
from torch.fx.passes.infra.partitioner import Partition
2629
from torch.nn import Parameter
2730

2831

@@ -85,10 +88,6 @@ def _is_supported_on_target(
8588

8689
dim = CatConverter._get_normalized_dim(node)
8790

88-
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491
89-
if dim == 0:
90-
return False
91-
9291
# Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the
9392
# last dimension, depending on the formats of the node.
9493
if node.meta[NXP_NODE_FORMAT].is_channels_first():
@@ -151,6 +150,46 @@ def _is_supported_in_IR(
151150

152151
return True
153152

153+
@classmethod
154+
def supports_partitioning_result(
155+
cls,
156+
node: Node,
157+
partition_list: list[Partition],
158+
custom_delegation_options: CustomDelegationOptions,
159+
):
160+
# There is a bug in the NeutronConverter, where if none of the input dimensions before the one referenced by
161+
# `dim` are `!= 1`, the `Concat` is not delegated.
162+
# This only happens when the inputs to the `Concat` are model inputs, and not outputs of other
163+
# operators.
164+
cat_partition = [p for p in partition_list if node in p.nodes][0]
165+
cat_inputs = map(previous_non_qdq_node, node.args[0])
166+
167+
if not all(
168+
input_.op == "call_function" and input_ in cat_partition.nodes
169+
for input_ in cat_inputs
170+
):
171+
# Some inputs of the `cat` are NOT in the same partition as `cat`.
172+
dim = CatConverter._get_normalized_dim(node)
173+
input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]
174+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
175+
# Transform the shapes to channels last.
176+
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
177+
len(node.meta["val"].shape), True
178+
)
179+
input_shapes = [
180+
apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes
181+
]
182+
183+
# Transform the `dim` to refer to a channels last dimension.
184+
dim = to_nhwc_perm.index(dim)
185+
186+
for input_shape in input_shapes:
187+
if not any(d != 1 for d in input_shape[:dim]):
188+
# Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension.
189+
return False
190+
191+
return True
192+
154193
def convert(self, node: Node):
155194
"""Convert the 'aten.cat' operator to TFLite 'Concatenation'."""
156195
self.assert_convertible(node)

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def forward(self, *inputs: torch.Tensor):
4444
return torch.cat(list(inputs), self.dim)
4545

4646

47+
class AddCatModule(torch.nn.Module):
48+
49+
def __init__(self, dim: int):
50+
super().__init__()
51+
self.dim = dim
52+
53+
def forward(self, *inputs: torch.Tensor):
54+
inputs = [input_ + input_ for input_ in inputs]
55+
56+
return torch.cat(list(inputs), self.dim)
57+
58+
4759
class CatConvModule(torch.nn.Module):
4860

4961
def __init__(self, dim: int, channels: int = 4):
@@ -73,7 +85,7 @@ def forward(self, *inputs: torch.Tensor):
7385
],
7486
)
7587
def test_cat__same_shapes(dim, num_inputs, rank, mocker):
76-
input_shape = tuple([2, 8, 8, 8, 8][-rank:])
88+
input_shape = tuple([8, 8, 8, 8][:rank])
7789

7890
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
7991

@@ -134,11 +146,23 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker):
134146
)
135147

136148

137-
@pytest.mark.parametrize("dim", [0, -4])
138-
@pytest.mark.parametrize("num_inputs", [2])
139-
def test_cat__unsupported_dim__imxrt700(dim, num_inputs):
140-
input_shape = (2, 8, 6, 8)
141-
149+
@pytest.mark.parametrize(
150+
"dim, input_shape",
151+
[
152+
pytest.param(0, (1, 8, 8, 8), id="axis = 0"),
153+
pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."),
154+
pytest.param(-4, (1, 8, 8, 8), id="axis = -4"),
155+
pytest.param(1, (1, 1, 8, 8), id="axis = 1"),
156+
pytest.param(-3, (1, 1, 8, 8), id="axis = -3"),
157+
pytest.param(2, (1, 1, 1, 8), id="axis = 2"),
158+
pytest.param(-2, (1, 1, 1, 8), id="axis = -2"),
159+
],
160+
)
161+
def test_cat__unsupported__imxrt700(dim, input_shape):
162+
"""This test is conjoined with the one below (`test_cat__context_dependent__imxrt700`).
163+
In this case, the inputs of the `cat` are NOT compute ops, so the `cat` is NOT delegated.
164+
"""
165+
num_inputs = 2
142166
quantized_program = to_quantized_edge_program(
143167
CatModule(dim), [input_shape] * num_inputs, target="imxrt700"
144168
).exported_program()
@@ -152,6 +176,32 @@ def test_cat__unsupported_dim__imxrt700(dim, num_inputs):
152176
)
153177

154178

179+
@pytest.mark.parametrize(
180+
"dim, input_shape",
181+
[
182+
pytest.param(0, (1, 8, 8, 8), id="axis = 0"),
183+
pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."),
184+
pytest.param(-4, (1, 8, 8, 8), id="axis = -4"),
185+
pytest.param(1, (1, 1, 8, 8), id="axis = 1"),
186+
pytest.param(-3, (1, 1, 8, 8), id="axis = -3"),
187+
pytest.param(2, (1, 1, 1, 8), id="axis = 2"),
188+
pytest.param(-2, (1, 1, 1, 8), id="axis = -2"),
189+
],
190+
)
191+
def test_cat__context_dependent__imxrt700(dim, input_shape):
192+
"""This test is conjoined with the one above (`test_cat__unsupported__imxrt700`).
193+
In this case, the inputs of the `cat` are compute ops, so the `cat` is delegated.
194+
"""
195+
num_inputs = 2
196+
ep = to_quantized_edge_program(
197+
AddCatModule(dim), [input_shape] * num_inputs, target="imxrt700"
198+
).exported_program()
199+
200+
# Make sure the `Cat` was delegated.
201+
assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.cat.default])
202+
assert any("lowered_module" in node.name for node in ep.graph.nodes)
203+
204+
155205
@pytest.mark.parametrize(
156206
"rank, num_inputs, dim",
157207
[

0 commit comments

Comments
 (0)