Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 63 additions & 35 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
return None


def _merge_qparams(qspec_1: QuantArgs, qspec_2: QuantArgs) -> QuantArgs:
"""Merge two QuantArgs when inputs are quantized differently.

Requires same dtype; picks the first's parameters by default.

"""
if qspec_1.dtype != qspec_2.dtype:
raise RuntimeError(
f"Cannot merge qparams of different dtypes: {qspec_1.dtype} vs {qspec_2.dtype}"
)
return qspec_1


def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
"""Get the input quantization parameters from a node, set by the
'FoldAndAnnotateQParamsPass'.
Expand Down Expand Up @@ -121,57 +134,72 @@ def __init__(
super().__init__(*args, **kwargs)
self.exported_program = exported_program

def fold_and_annotate_arg(
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
) -> None:
input_qparams = None
nodes_to_remove = set()
def _extract_input_params(
self, arg_list: list[Node]
) -> tuple[Optional[QuantArgs], set[Node]]:
input_qparams: Optional[QuantArgs] = None
nodes_to_remove: set[Node] = set()
for arg in arg_list:
if not isinstance(arg, Node):
return

arg_quant_params = None
return None, set()
arg_quant: Optional[QuantArgs] = None
if arg.target in DQ_OPS:
args = arg.args
scales = args[1]
if (
isinstance(args[1], Node)
isinstance(scales, Node)
and self.exported_program is not None
and is_param_node(self.exported_program, args[1])
and is_param_node(self.exported_program, scales)
):
scales = get_param_tensor(self.exported_program, args[1])
scales = get_param_tensor(self.exported_program, scales)
zps = args[2]
if (
isinstance(args[2], Node)
isinstance(zps, Node)
and self.exported_program is not None
and is_param_node(self.exported_program, args[2])
and is_param_node(self.exported_program, zps)
):
zps = get_param_tensor(self.exported_program, args[2])
arg_quant_params = QuantArgs.from_operator(
zps = get_param_tensor(self.exported_program, zps)
arg_quant = QuantArgs.from_operator(
arg.target, (args[0], scales, zps, *args[3:])
)
# add arg to nodes_to_remove to fold the dq-node
nodes_to_remove.add(arg)
if input_qparams is not None and input_qparams != arg_quant_params:
# Two args are quantized differently
raise RuntimeError("Input qparams do not match")
input_qparams = arg_quant_params
if input_qparams is not None:
node.meta["input_qparams"][i] = input_qparams
for n in nodes_to_remove:
if n.target not in DQ_OPS:
raise RuntimeError(
f"Expected one of {DQ_OPS} dq_op, got {n.target}"
)
if arg_quant is not None:
if input_qparams is None:
input_qparams = arg_quant
elif input_qparams != arg_quant:
input_qparams = _merge_qparams(input_qparams, arg_quant)
return input_qparams, nodes_to_remove

def _annotate_input_params(
self,
graph_module: GraphModule,
node: Node,
index: int,
input_qparams: QuantArgs,
nodes_to_remove: set[Node],
) -> None:
node.meta["input_qparams"][index] = input_qparams

for dq in nodes_to_remove:
if dq.target not in DQ_OPS:
raise RuntimeError(f"Expected one of {DQ_OPS} dq_op, got {dq.target}")
node.replace_input_with(dq, cast(Node, dq.args[0]))
if not dq.users:
graph_module.graph.erase_node(dq)

special = _get_special_dtype(input_qparams)
if special:
node.all_input_nodes[index].meta[TosaSpecialDtype.meta_key()] = special

node.replace_input_with(n, cast(Node, n.args[0]))
if len(n.users) == 0:
graph_module.graph.erase_node(n)
special_dtype = _get_special_dtype(input_qparams)
if special_dtype:
node.all_input_nodes[i].meta[
TosaSpecialDtype.meta_key()
] = special_dtype
def fold_and_annotate_arg(
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
) -> None:
input_qparams, nodes_to_remove = self._extract_input_params(arg_list)
if input_qparams is None:
return
self._annotate_input_params(
graph_module, node, i, input_qparams, nodes_to_remove
)

def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
"""Fold outmost quant nodes inside submodule.
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/normalize_while_initial_args_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -82,6 +82,8 @@ def _normalize_node(self, graph_module: GraphModule, node: Node) -> bool:
new_carried = tuple(carried_inputs + additional_inputs)
node.update_arg(2, new_carried)
node.update_arg(3, ())
# annotate node so later keying of captured vs loop‐carried args is possible
node.meta["additional_inputs"] = additional_inputs

body_module_name = str(cast(Node, node.args[1]).target)
body_module = cast(GraphModule, graph_module.get_submodule(body_module_name)) # type: ignore
Expand Down
44 changes: 24 additions & 20 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,29 +890,33 @@ def any_or_hardtanh_min_zero(n: Node):
submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2
submodule_args = node.args[submodule_args_pos]
output_qspec = output_act_qspec
if len(submodule_args) > 0: # type: ignore[arg-type]
# The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a
# conditional graph) need shared quantization.
shared_qspec = SharedQuantizationSpec(
(cast(list[Node], submodule_args)[0], node)
)
quant_properties.quant_inputs = [
_QuantProperty(
submodule_args_pos,
[
input_act_qspec,
*([shared_qspec] * (len(submodule_args) - 1)), # type: ignore[arg-type]
],
# Annotate each control-flow tensor independently using the default input qspec
if submodule_args:
if node.meta.get("additional_inputs", None):
qspecs = [input_act_qspec] * len(cast(Sequence[Node], submodule_args)) # type: ignore[arg-type]
quant_properties.quant_inputs = [
_QuantProperty(submodule_args_pos, qspecs)
]
else:
shared_qspec = SharedQuantizationSpec(
(cast(list[Node], submodule_args)[0], node)
)
]
if node.target == torch.ops.higher_order.while_loop:
# The output of the while loop body can either re-enter the body, or exit the while loop.
# Therefore, A and B in the diagram below need to share the same quantization parameters.
# A -> while ( RESCALE -> ... RESCALE -> ) -> B
output_qspec = shared_qspec
quant_properties.quant_inputs = [
_QuantProperty(
submodule_args_pos,
[
input_act_qspec,
*([shared_qspec] * (len(submodule_args) - 1)), # type: ignore[arg-type]
],
)
]
if node.target == torch.ops.higher_order.while_loop:
# The output of the while loop body can either re-enter the body, or exit the while loop.
# Therefore, A and B in the diagram below need to share the same quantization parameters.
# A -> while ( RESCALE -> ... RESCALE -> ) -> B
output_qspec = shared_qspec

quant_properties.quant_output = _QuantProperty(0, output_qspec)

else:
return None

Expand Down
3 changes: 0 additions & 3 deletions backends/arm/test/ops/test_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,6 @@ def test_while_loop_tosa_FP(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
@common.parametrize(
"case",
test_cases,
xfails={
"large_threshold": "MLETORCH-1808 - Handle different scales for different parameters"
},
)
def test_while_loop_tosa_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
module, example_inputs = case()
Expand Down
Loading