-
Notifications
You must be signed in to change notification settings - Fork 963
Qualcomm AI Engine Direct - Addition of new APIs for QNN custom op package and quantization annotation #19094
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
abhinaykukkadapu
merged 1 commit into
pytorch:main
from
CodeLinaro:dev1/mmadhava/custom_op_refactor
Apr 28, 2026
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| # Copyright (c) Qualcomm Innovation Center, Inc. | ||
| # All rights reserved | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import logging | ||
| from dataclasses import dataclass | ||
| from typing import Callable, Dict, Optional, Union | ||
|
|
||
| import torch | ||
| from executorch.backends.qualcomm.quantizer.rules import _is_float_tensor | ||
| from torchao.quantization.pt2e.quantizer import ( | ||
| QuantizationAnnotation, | ||
| QuantizationSpec, | ||
| SharedQuantizationSpec, | ||
| ) | ||
| from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
@dataclass
class IOQuantConfig:
    """
    Per-op quantization configuration for a custom op's inputs and outputs.

    Attributes:
        input_quant_specs: Maps input index to its QuantizationSpec.
            Only indices present in the dict are annotated; if None, no
            inputs are annotated at all.
        output_quant_specs: Maps output index to its QuantizationSpec.
            For a single-output op the annotation is placed on the op node
            itself; for a multi-output op each index corresponds to a
            downstream getitem user. If None, no outputs are annotated.
    """

    # Both sides default to None, meaning "annotate nothing" for that side.
    input_quant_specs: Optional[
        Dict[int, Union[QuantizationSpec, SharedQuantizationSpec]]
    ] = None
    output_quant_specs: Optional[
        Dict[int, Union[QuantizationSpec, SharedQuantizationSpec]]
    ] = None
|
|
||
|
|
||
class CustomOpsQuantAnnotator:
    """
    Holds op IOQuantConfigs and builds a single annotation function
    compatible with make_quantizer(custom_annotations=...).
    """

    def __init__(self):
        # {op_target: IOQuantConfig}; registering the same target twice
        # overwrites the earlier config.
        self._registry: Dict = {}

    def register_annotation(
        self,
        op_target,
        io_quant_config: "IOQuantConfig",
    ) -> "CustomOpsQuantAnnotator":
        """
        Register quantization config for custom op.

        Args:
            op_target: The torch op target (e.g. torch.ops.my_ops.custom_op.default).
            io_quant_config: IOQuantConfig specifying how to quantize inputs and outputs.

        Returns self for method chaining.
        """
        self._registry[op_target] = io_quant_config
        return self

    def build_annotation_fn(self) -> Callable[[torch.fx.GraphModule], None]:
        """
        Build and return an annotation function for all registered ops.

        The returned function has signature (gm: GraphModule) -> None and
        can be passed directly to make_quantizer(custom_annotations=(fn,)).

        Raises (from the returned function):
            ValueError: If an input index in a registered IOQuantConfig is
                out of range for the matched op.
        """
        import operator  # only needed by the returned closure

        # Snapshot the registry so register_annotation() calls made after
        # build do not silently change an already-built annotation function.
        registry = dict(self._registry)

        def annotate_custom_ops(gm: torch.fx.GraphModule) -> None:
            for node in gm.graph.nodes:
                if node.target not in registry:
                    continue

                cfg = registry[node.target]
                input_qspec_map = {}
                if cfg.input_quant_specs is not None:
                    for arg_idx, spec in cfg.input_quant_specs.items():
                        # Reject negative indices too: a negative arg_idx
                        # would pass a plain upper-bound check and silently
                        # index from the end of node.args.
                        if not 0 <= arg_idx < len(node.args):
                            raise ValueError(
                                f"IOQuantConfig error for '{node.name}' ({node.target}): "
                                f"input_quant_specs index {arg_idx} is out of range "
                                f"(op has {len(node.args)} args)"
                            )
                        if not _is_float_tensor(node.args[arg_idx]):
                            logger.debug(
                                f"Skipping quantization of input {arg_idx} for "
                                f"'{node.name}' ({node.target}): expected a float tensor."
                            )
                            continue
                        logger.debug(
                            f"Annotating input {arg_idx} of '{node.name}' ({node.target}) "
                            f"with {spec}"
                        )
                        input_qspec_map[node.args[arg_idx]] = spec

                if not cfg.output_quant_specs or len(cfg.output_quant_specs) <= 1:
                    # Single output — annotate on the op node
                    output_spec = (
                        cfg.output_quant_specs.get(0)
                        if cfg.output_quant_specs
                        else None
                    )
                    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=output_spec,
                        _annotated=True,
                    )
                else:
                    # Tuple output — push quantization down to getitem users
                    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=None,
                        _annotated=True,
                    )
                    for user in node.users:
                        # Only getitem users select a single output; skip any
                        # other consumer (e.g. an op taking the whole tuple)
                        # instead of misreading its args[1].
                        if (
                            user.op != "call_function"
                            or user.target is not operator.getitem
                        ):
                            continue
                        output_idx = user.args[1]
                        spec = cfg.output_quant_specs.get(output_idx)

                        if spec is not None:
                            user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                                output_qspec=spec,
                                _annotated=True,
                            )

        return annotate_custom_ops
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # Copyright (c) Qualcomm Innovation Center, Inc. | ||
| # All rights reserved | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import List, Optional | ||
|
|
||
| try: | ||
| from qti.aisw.op_package_generator.generator import QnnOpPackageGenerator | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "Failed to import QnnOpPackageGenerator. " | ||
| "Please run 'source $QNN_SDK_ROOT/bin/envsetup.sh' to set up the QNN SDK environment." | ||
| ) from e | ||
|
|
||
| from executorch.backends.qualcomm.serialization.qc_schema import ( | ||
| QnnExecuTorchOpPackageInfo, | ||
| QnnExecuTorchOpPackageOptions, | ||
| QnnExecuTorchOpPackagePlatform, | ||
| QnnExecuTorchOpPackageTarget, | ||
| ) | ||
|
|
||
|
|
||
class QnnCustomOpPackageBuilder:
    """
    Parses a QNN XML op package config and manages registration of
    target/platform/implementation for use with ExecuTorch.

    Validates that all keys in torch_op_name_map are present in the parsed
    package before any implementations are registered.
    """

    def __init__(
        self,
        xml_path: str,
        torch_op_name_map,
        interface_provider: Optional[str] = None,
    ):
        """
        Args:
            xml_path: Path to the QNN XML OpDef config file.
            torch_op_name_map: Maps QNN op type names to their corresponding
                PyTorch op targets.
                e.g. {"ExampleCustomOp": torch.ops.my_ops.custom_op.default}
            interface_provider: Interface provider symbol name. Defaults to
                "{PackageName}InterfaceProvider" if not specified.

        Raises:
            ValueError: If the parsed config contains no op package, or if
                any key in torch_op_name_map is not found in the parsed
                op package.
        """
        op_package_generator = QnnOpPackageGenerator()
        op_package_generator.parse_config([xml_path])

        # One op package per XML config; fail with a clear message instead of
        # a bare IndexError if parsing produced nothing.
        if not op_package_generator.package_infos:
            raise ValueError(f"No op package found in config: {xml_path}")
        pkg_info = op_package_generator.package_infos[0]
        self.op_package_name = pkg_info.name
        # Default follows the QNN convention "<PackageName>InterfaceProvider".
        self.interface_provider = (
            interface_provider
            if interface_provider
            else pkg_info.name + "InterfaceProvider"
        )
        self.torch_op_name_map = torch_op_name_map
        self._collection: List[QnnExecuTorchOpPackageInfo] = []
        self.operator_names = {op.type_name for op in pkg_info.operators}

        # Validate up front: every mapped QNN op must exist in the package.
        missing_ops = set(self.torch_op_name_map) - self.operator_names
        if missing_ops:
            raise ValueError(f"Ops missing from OpPackage: {missing_ops}")

    def register_implementation(
        self,
        target: QnnExecuTorchOpPackageTarget,
        platform: QnnExecuTorchOpPackagePlatform,
        op_package_path: str,
    ) -> "QnnCustomOpPackageBuilder":
        """
        Register one (target, platform, path) combination.
        Creates one QnnExecuTorchOpPackageInfo per op in torch_op_name_map.
        Returns self for method chaining.

        Args:
            target: QnnExecuTorchOpPackageTarget
            platform: QnnExecuTorchOpPackagePlatform
            op_package_path: Path to the implementation for the target/platform.
        """
        for qnn_op_type_name, torch_name in self.torch_op_name_map.items():
            self._collection.append(
                QnnExecuTorchOpPackageInfo(
                    op_package_name=self.op_package_name,
                    op_package_path=op_package_path,
                    interface_provider=self.interface_provider,
                    target=target,
                    custom_op_name=str(torch_name),
                    qnn_op_type_name=qnn_op_type_name,
                    platform=platform,
                )
            )
        return self

    def get_op_package_options(self) -> QnnExecuTorchOpPackageOptions:
        """
        Build and return QnnExecuTorchOpPackageOptions from all registered implementations.
        Call after all register_implementation() calls are complete.
        """
        options = QnnExecuTorchOpPackageOptions()
        # Copy so later register_implementation() calls don't mutate the
        # options object already handed out.
        options.op_package_infos = list(self._collection)
        return options
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will there be only one op package allowed in the xml?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, there's one OpPackage per XML - each OpPackage can have multiple ops.