Skip to content

Commit 6550a37

Browse files
Arm backend: Propagate node info from quantizer to backend (#15300)
Use the Node meta 'custom' field to propagate information from the quantizer to the partitioner using a new ArmAnnotationInfo data class. This allows us to track quantized nodes reliably, which is useful for determining which nodes should 'fold' their quantization parameters and which should be kept in fp when mixing integer and float in a sub-graph. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Co-authored-by: Per Åstrand <per.astrand@arm.com>
1 parent 208695c commit 6550a37

File tree

11 files changed

+251
-19
lines changed

11 files changed

+251
-19
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from dataclasses import dataclass
7+
8+
9+
@dataclass(frozen=True)
10+
class ArmAnnotationInfo:
11+
"""
12+
Data class to carry Arm-specific annotation information through the pipeline.
13+
This is intended to be attached to node.meta['custom'] and propagated
14+
through partitioning and backend stages. As it's propagated through the pipeline,
15+
it's intentionally minimal and only carries whether the node is quantized or not.
16+
"""
17+
18+
quantized: bool
19+
CUSTOM_META_KEY: str = "_arm_annotation_info"

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
FuseQuantizedActivationPass,
2222
)
2323
from executorch.backends.arm._passes.insert_table_ops import TableOps
24+
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
2425
from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
2526
from executorch.backends.arm.operator_support.ethos_u55_support import (
2627
EthosU55CastCheck,
@@ -140,6 +141,7 @@ def tosa_support_factory(
140141
]
141142

142143
if not tosa_spec.support_float():
144+
negative_checks.append(CheckArmQuantized(reporter))
143145
negative_checks.append(CheckProperQuantization(reporter))
144146
if tosa_spec.is_U55_subset:
145147
negative_checks.append(EthosU55NotSupported(reporter))
@@ -167,7 +169,6 @@ class TOSAProINTSupportList(OperatorSupportBase):
167169
def is_node_supported(
168170
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
169171
) -> bool:
170-
171172
return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList
172173

173174

@@ -180,10 +181,80 @@ class TOSAProFPSupportList(OperatorSupportBase):
180181
def is_node_supported(
181182
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
182183
) -> bool:
183-
184184
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
185185

186186

187+
class CheckArmQuantized(OperatorSupportBase):
188+
"""
189+
Check if the node was marked as quantized in the Arm backend.
190+
This is used to ensure that nodes that were quantized in the Arm backend
191+
are only partitioned if they are supported by the TOSA backend.
192+
"""
193+
194+
def __init__(self, reporter: WhyNoPartitionReporter):
195+
self.reporter = reporter
196+
197+
def _is_quantized(self, node: torch.fx.Node) -> bool:
198+
"""Checks if the node is quantized.
199+
200+
A node is considered quantized if at least one criteria is met:
201+
- Its dtype is not floating point or complex => integer
202+
- It is one of the special cases where the node has been created in to_edge, e.g.
203+
.Scalar operations that have been promoted .Tensor operations
204+
where the scalar is replaced by a full op.
205+
- It has been marked as quantized in the ArmAnnotationInfo custom meta.
206+
207+
Args:
208+
node (torch.fx.Node): The FX node to check.
209+
210+
Returns:
211+
bool: True if the node is quantized, False otherwise.
212+
"""
213+
node_dtype = get_first_fake_tensor(node).dtype
214+
if not node_dtype.is_complex and not node_dtype.is_floating_point:
215+
return True
216+
if node.target in (
217+
exir_ops.edge.aten.full_like.default,
218+
*ComputeConstantOpsAOT.targeted_ops,
219+
):
220+
# Special cases where nodes have been created in to_edge, e.g.
221+
# .Scalar operations that have been promoted to .Tensor operations
222+
# where the scalar is replaced by a full op.
223+
if all(user.target in Q_OPS for user in node.users):
224+
return True
225+
for user in node.users:
226+
if (
227+
user.target
228+
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
229+
):
230+
dim_order_dtype = get_first_fake_tensor(user).dtype
231+
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
232+
return False
233+
else:
234+
return False
235+
return True
236+
return (
237+
ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
238+
and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized
239+
)
240+
241+
def is_node_supported(
242+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
243+
) -> bool:
244+
if node.op != "call_function":
245+
return False
246+
247+
if node.target in (*DQ_OPS, *Q_OPS):
248+
return True
249+
250+
if not self._is_quantized(node):
251+
self.reporter.report_reject(
252+
node, "Node was not marked as quantized in the Arm backend."
253+
)
254+
return False
255+
return True
256+
257+
187258
class CheckProperQuantization(OperatorSupportBase):
188259
"""
189260
For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize
@@ -427,7 +498,6 @@ def is_node_supported(
427498

428499

429500
class CheckFloat64Inputs(OperatorSupportBase):
430-
431501
def __init__(
432502
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
433503
):

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -14,6 +14,8 @@
1414

1515
from typing import cast
1616

17+
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
18+
1719
from torch.fx import Node
1820

1921
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
@@ -65,4 +67,10 @@ def mark_node_as_annotated(node: Node) -> None:
6567
"""
6668
if Q_ANNOTATION_KEY not in node.meta:
6769
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
70+
annotation_info = ArmAnnotationInfo(
71+
quantized=True,
72+
)
6873
node.meta[Q_ANNOTATION_KEY]._annotated = True
74+
meta_custom = node.meta.get("custom", {})
75+
meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info
76+
node.meta["custom"] = meta_custom

backends/arm/quantizer/quantization_annotator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ def _match_pattern(
394394
torch.ops.aten.view.default,
395395
torch.ops.aten.view_as.default,
396396
torch.ops.aten.view_copy.default,
397+
torch.ops.aten._unsafe_view.default,
397398
torch.ops.aten.select.int,
398399
torch.ops.aten.select_copy.int,
399400
torch.ops.aten.slice.Tensor,
@@ -426,6 +427,7 @@ def _match_pattern(
426427
]
427428

428429
_one_to_one_shared_input_or_input_act_qspec = [
430+
torch.ops.aten.alias.default,
429431
torch.ops.aten.clone.default,
430432
torch.ops.aten.hardtanh.default,
431433
torch.ops.aten.hardtanh_.default,
@@ -693,10 +695,10 @@ def any_or_hardtanh_min_zero(n: Node):
693695
]
694696
quant_properties.quant_output = None
695697
elif node.target in [
696-
torch.ops.aten.scalar_tensor.default,
697698
torch.ops.aten.full.default,
698699
torch.ops.aten.full,
699700
torch.ops.aten.fill_.Scalar,
701+
torch.ops.aten.scalar_tensor.default,
700702
]:
701703
quant_properties.quant_inputs = []
702704
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)

backends/arm/test/misc/test_int64.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor):
6868
ConstAdd(torch.int64, 2**40),
6969
(torch.rand(10) - 0.5,),
7070
),
71-
"int64_in+float_const": (
72-
ConstAdd(torch.float32),
73-
(torch.randint(0, 10, (10,)),),
74-
),
7571
"fp32_in+int64_buffer_chain": (
7672
BufferChainAdd(torch.int64),
7773
(torch.rand(2, 5, 3) - 0.5,),
@@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple):
9490
ArmTester(
9591
model,
9692
inputs,
97-
common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
93+
common.get_tosa_compile_spec("TOSA-1.0+FP"),
9894
)
9995
.export()
10096
.to_edge_transform_and_lower()
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm.quantizer import (
8+
get_symmetric_quantization_config,
9+
TOSAQuantizer,
10+
)
11+
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
12+
from executorch.backends.arm.tosa import TosaSpecification
13+
from executorch.backends.xnnpack.test.tester import Quantize
14+
15+
16+
class AddSigmoidMul(torch.nn.Module):
17+
def __init__(self, *args, **kwargs):
18+
super().__init__(*args, **kwargs)
19+
self.sigmoid = torch.nn.Sigmoid()
20+
21+
def forward(self, x, y):
22+
return self.sigmoid(x + y) * x
23+
24+
25+
def get_selective_quantizer(modules):
26+
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
27+
quantizer.set_global(get_symmetric_quantization_config())
28+
for module in modules:
29+
quantizer.set_module_type(module, None)
30+
31+
return Quantize(quantizer, get_symmetric_quantization_config())
32+
33+
34+
def test_qdq_squeezed_fp_op():
35+
"""Test that a float operation surrounded by quantize-dequantize pairs
36+
is correctly handled by the partitioner and the TOSA backend.
37+
Pattern:
38+
q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q
39+
|_____Non-delegated____|
40+
"""
41+
aten_op = "torch.ops.aten.add.Tensor"
42+
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
43+
module = AddSigmoidMul()
44+
x = torch.randn(2, 3, 4)
45+
y = torch.randn(2, 3, 4)
46+
pipeline = TosaPipelineINT(
47+
module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
48+
)
49+
pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid]))
50+
pipeline.change_args(
51+
"check_count.exir",
52+
{
53+
"torch.ops.higher_order.executorch_call_delegate": 2,
54+
"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
55+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
56+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
57+
},
58+
)
59+
pipeline.run()
60+
61+
62+
class MulAddSigmoidConv(torch.nn.Module):
63+
def __init__(self, *args, **kwargs):
64+
super().__init__(*args, **kwargs)
65+
self.sigmoid = torch.nn.Sigmoid()
66+
self.conv = torch.nn.Conv1d(3, 3, 1)
67+
68+
def forward(self, x, y):
69+
return self.conv(self.sigmoid(x + y * x))
70+
71+
72+
def test_quantized_to_float_transition():
73+
"""Test that a model executing quantized ops followed by float ops
74+
is correctly handled by the partitioner and the TOSA backend.
75+
Pattern:
76+
q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv
77+
|____Non-delegated___|
78+
"""
79+
aten_op = "torch.ops.aten.add.Tensor"
80+
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
81+
module = MulAddSigmoidConv()
82+
x = torch.randn(2, 3, 4)
83+
y = torch.randn(2, 3, 4)
84+
pipeline = TosaPipelineINT(
85+
module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op
86+
)
87+
pipeline.change_args(
88+
"quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d])
89+
)
90+
pipeline.change_args(
91+
"check_count.exir",
92+
{
93+
"torch.ops.higher_order.executorch_call_delegate": 1,
94+
"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1,
95+
"executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
96+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
97+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
98+
},
99+
)
100+
pipeline.run()

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class TestSD3Transformer2DModel:
3939

4040
ops_after_partitioner_INT = {
4141
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
42-
"torch.ops.higher_order.executorch_call_delegate": 2,
42+
"torch.ops.higher_order.executorch_call_delegate": 3,
43+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
4344
}
4445

4546
def _prepare_inputs(

backends/arm/test/models/test_nn_functional.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data):
102102
@parametrize(
103103
"test_data",
104104
module_tests,
105-
{"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
106105
)
107106
def test_nn_functional_INT(test_data):
108107
module, inputs = test_data

backends/arm/test/models/test_torch_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,12 @@ def test_torch_fns_FP(test_data):
126126
xfails={
127127
"nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). "
128128
"Requires dynamic output shape.",
129+
"eye": "ValueError: Failed processing buffer placeholder: aten_arange_start_step_1_pre_computed_common. "
130+
"Is the original torch function supported?",
129131
"topk": "NotImplementedError: No registered serialization name for <class 'torch.return_types.topk'> found",
130132
"sort": "NotImplementedError: No registered serialization name for <class 'torch.return_types.sort'> found",
131133
},
132-
strict=False,
134+
strict=True,
133135
)
134136
def test_torch_fns_INT(test_data):
135137
module, inputs = test_data

backends/arm/test/ops/test_eye.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t):
9595
input_data(),
9696
EyeAdd.aten_op,
9797
use_to_edge_transform_and_lower=True,
98-
).dump_artifact("to_edge_transform_and_lower")
98+
)
9999
pipeline.pop_stage("check.quant_nodes")
100100
pipeline.run()
101101

0 commit comments

Comments
 (0)