diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index e08132946a1..6813416eec4 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -40,6 +40,19 @@ def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None: return None +def _merge_qparams(qspec_1: QuantArgs, qspec_2: QuantArgs) -> QuantArgs: + """Merge two QuantArgs when inputs are quantized differently. + + Requires same dtype; picks the first's parameters by default. + + """ + if qspec_1.dtype != qspec_2.dtype: + raise RuntimeError( + f"Cannot merge qparams of different dtypes: {qspec_1.dtype} vs {qspec_2.dtype}" + ) + return qspec_1 + + def get_input_qparams(node: Node) -> dict[int, QuantArgs]: """Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. @@ -121,57 +134,72 @@ def __init__( super().__init__(*args, **kwargs) self.exported_program = exported_program - def fold_and_annotate_arg( - self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int - ) -> None: - input_qparams = None - nodes_to_remove = set() + def _extract_input_params( + self, arg_list: list[Node] + ) -> tuple[Optional[QuantArgs], set[Node]]: + input_qparams: Optional[QuantArgs] = None + nodes_to_remove: set[Node] = set() for arg in arg_list: if not isinstance(arg, Node): - return - - arg_quant_params = None + return None, set() + arg_quant: Optional[QuantArgs] = None if arg.target in DQ_OPS: args = arg.args scales = args[1] if ( - isinstance(args[1], Node) + isinstance(scales, Node) and self.exported_program is not None - and is_param_node(self.exported_program, args[1]) + and is_param_node(self.exported_program, scales) ): - scales = get_param_tensor(self.exported_program, args[1]) + scales = get_param_tensor(self.exported_program, scales) zps = args[2] if ( - isinstance(args[2], Node) + isinstance(zps, 
Node) and self.exported_program is not None - and is_param_node(self.exported_program, args[2]) + and is_param_node(self.exported_program, zps) ): - zps = get_param_tensor(self.exported_program, args[2]) - arg_quant_params = QuantArgs.from_operator( + zps = get_param_tensor(self.exported_program, zps) + arg_quant = QuantArgs.from_operator( arg.target, (args[0], scales, zps, *args[3:]) ) - # add arg to nodes_to_remove to fold the dq-node nodes_to_remove.add(arg) - if input_qparams is not None and input_qparams != arg_quant_params: - # Two args are quantized differently - raise RuntimeError("Input qparams do not match") - input_qparams = arg_quant_params - if input_qparams is not None: - node.meta["input_qparams"][i] = input_qparams - for n in nodes_to_remove: - if n.target not in DQ_OPS: - raise RuntimeError( - f"Expected one of {DQ_OPS} dq_op, got {n.target}" - ) + if arg_quant is not None: + if input_qparams is None: + input_qparams = arg_quant + elif input_qparams != arg_quant: + input_qparams = _merge_qparams(input_qparams, arg_quant) + return input_qparams, nodes_to_remove + + def _annotate_input_params( + self, + graph_module: GraphModule, + node: Node, + index: int, + input_qparams: QuantArgs, + nodes_to_remove: set[Node], + ) -> None: + node.meta["input_qparams"][index] = input_qparams + + for dq in nodes_to_remove: + if dq.target not in DQ_OPS: + raise RuntimeError(f"Expected one of {DQ_OPS} dq_op, got {dq.target}") + node.replace_input_with(dq, cast(Node, dq.args[0])) + if not dq.users: + graph_module.graph.erase_node(dq) + + special = _get_special_dtype(input_qparams) + if special: + node.all_input_nodes[index].meta[TosaSpecialDtype.meta_key()] = special - node.replace_input_with(n, cast(Node, n.args[0])) - if len(n.users) == 0: - graph_module.graph.erase_node(n) - special_dtype = _get_special_dtype(input_qparams) - if special_dtype: - node.all_input_nodes[i].meta[ - TosaSpecialDtype.meta_key() - ] = special_dtype + def fold_and_annotate_arg( + self, 
graph_module: GraphModule, node: Node, arg_list: list[Node], i: int + ) -> None: + input_qparams, nodes_to_remove = self._extract_input_params(arg_list) + if input_qparams is None: + return + self._annotate_input_params( + graph_module, node, i, input_qparams, nodes_to_remove + ) def _handle_control_flow_node(self, node: Node, graph_module: GraphModule): """Fold outmost quant nodes inside submodule. diff --git a/backends/arm/_passes/normalize_while_initial_args_pass.py b/backends/arm/_passes/normalize_while_initial_args_pass.py index 00ccf7817a1..b5d255e9520 100644 --- a/backends/arm/_passes/normalize_while_initial_args_pass.py +++ b/backends/arm/_passes/normalize_while_initial_args_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -82,6 +82,8 @@ def _normalize_node(self, graph_module: GraphModule, node: Node) -> bool: new_carried = tuple(carried_inputs + additional_inputs) node.update_arg(2, new_carried) node.update_arg(3, ()) + # Annotate the node so captured and loop-carried args can be told apart later + node.meta["additional_inputs"] = additional_inputs body_module_name = str(cast(Node, node.args[1]).target) body_module = cast(GraphModule, graph_module.get_submodule(body_module_name)) # type: ignore diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index efc8320f0b9..7c2ce6b74d1 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -890,29 +890,33 @@ def any_or_hardtanh_min_zero(n: Node): submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2 submodule_args = node.args[submodule_args_pos] output_qspec = output_act_qspec - if len(submodule_args) > 0: # type: ignore[arg-type] - # The way the TOSA 
backend handles quantized inputs, arrays of input tensors (such as the input to a - # conditional graph) need shared quantization. - shared_qspec = SharedQuantizationSpec( - (cast(list[Node], submodule_args)[0], node) - ) - quant_properties.quant_inputs = [ - _QuantProperty( - submodule_args_pos, - [ - input_act_qspec, - *([shared_qspec] * (len(submodule_args) - 1)), # type: ignore[arg-type] - ], + # Use the default input qspec per tensor if additional_inputs is set; else share qparams + if submodule_args: + if node.meta.get("additional_inputs", None): + qspecs = [input_act_qspec] * len(cast(Sequence[Node], submodule_args)) # type: ignore[arg-type] + quant_properties.quant_inputs = [ + _QuantProperty(submodule_args_pos, qspecs) + ] + else: + shared_qspec = SharedQuantizationSpec( + (cast(list[Node], submodule_args)[0], node) ) - ] - if node.target == torch.ops.higher_order.while_loop: - # The output of the while loop body can either re-enter the body, or exit the while loop. - # Therefore, A and B in the diagram below need to share the same quantization parameters. - # A -> while ( RESCALE -> ... RESCALE -> ) -> B - output_qspec = shared_qspec + quant_properties.quant_inputs = [ + _QuantProperty( + submodule_args_pos, + [ + input_act_qspec, + *([shared_qspec] * (len(submodule_args) - 1)), # type: ignore[arg-type] + ], + ) + ] + if node.target == torch.ops.higher_order.while_loop: + # The output of the while loop body can either re-enter the body, or exit the while loop. + # Therefore, A and B in the diagram below need to share the same quantization parameters. + # A -> while ( RESCALE -> ... 
RESCALE -> ) -> B + output_qspec = shared_qspec quant_properties.quant_output = _QuantProperty(0, output_qspec) - else: return None diff --git a/backends/arm/test/ops/test_while.py b/backends/arm/test/ops/test_while.py index c150b194ead..b5cab047a50 100644 --- a/backends/arm/test/ops/test_while.py +++ b/backends/arm/test/ops/test_while.py @@ -210,9 +210,6 @@ def test_while_loop_tosa_FP(case: Callable[[], Tuple[torch.nn.Module, Tuple]]): @common.parametrize( "case", test_cases, - xfails={ - "large_threshold": "MLETORCH-1808 - Handle different scales for different parameters" - }, ) def test_while_loop_tosa_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]): module, example_inputs = case()