Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/qualcomm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ backends/qualcomm
| ├── wrappers # Wrapper of QNN data structures for ease of use.
| └── python # Python interface for using QNN libraries.
├── builders # Code for lowering each operator (AoT Part).
├── custom_op # APIs for using custom ops with QNN backend
├── partition # QNN Partitioner (AoT Part).
├── _passes # Various private passes helping lower models to QNN backend (AoT Part).
├── python # Places to put pybind artifacts for accessing QNN APIs, structures, etc (AoT Part).
Expand Down
137 changes: 137 additions & 0 deletions backends/qualcomm/custom_op/annotator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Union

import torch
from executorch.backends.qualcomm.quantizer.rules import _is_float_tensor
from torchao.quantization.pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

logger = logging.getLogger(__name__)


@dataclass
class IOQuantConfig:
    """
    Per-op quantization configuration for a custom op's inputs and outputs.

    Attributes:
        input_quant_specs: Maps an input argument index to the
            QuantizationSpec applied to that input. Indices absent from the
            dict are left unannotated; None disables input annotation
            entirely.
        output_quant_specs: Maps an output index to its QuantizationSpec.
            Single-output ops are annotated directly on the op node; for
            multi-output ops, each index corresponds to a downstream getitem
            user. None disables output annotation entirely.
    """

    input_quant_specs: Optional[
        Dict[int, Union[QuantizationSpec, SharedQuantizationSpec]]
    ] = None
    output_quant_specs: Optional[
        Dict[int, Union[QuantizationSpec, SharedQuantizationSpec]]
    ] = None


class CustomOpsQuantAnnotator:
    """
    Holds op IOQuantConfigs and builds a single annotation function
    compatible with make_quantizer(custom_annotations=...).
    """

    def __init__(self):
        # Maps op target -> IOQuantConfig; lookup is by exact node.target
        # identity when annotating a graph.
        self._registry: Dict = {}

    def register_annotation(
        self,
        op_target,
        io_quant_config: "IOQuantConfig",
    ) -> "CustomOpsQuantAnnotator":
        """
        Register quantization config for a custom op.

        Args:
            op_target: The torch op target (e.g. torch.ops.my_ops.custom_op.default).
            io_quant_config: IOQuantConfig specifying how to quantize inputs
                and outputs. Registering the same target again overwrites the
                earlier config.

        Returns self for method chaining.
        """
        self._registry[op_target] = io_quant_config
        return self

    def build_annotation_fn(self) -> Callable[[torch.fx.GraphModule], None]:
        """
        Build and return an annotation function for all registered ops.

        The returned function has signature (gm: GraphModule) -> None and
        can be passed directly to make_quantizer(custom_annotations=(fn,)).

        The returned function raises ValueError if a registered
        input_quant_specs index is out of range for a matched node.
        """
        # Snapshot the registry so configs registered after this call do not
        # leak into an already-built annotation function.
        registry = dict(self._registry)

        def annotate_custom_ops(gm: torch.fx.GraphModule) -> None:
            for node in gm.graph.nodes:
                cfg = registry.get(node.target)
                if cfg is None:
                    continue

                input_qspec_map = {}
                if cfg.input_quant_specs is not None:
                    for arg_idx, spec in cfg.input_quant_specs.items():
                        # Reject negative as well as too-large indices so a bad
                        # config fails loudly instead of silently aliasing
                        # node.args[-1].
                        if not 0 <= arg_idx < len(node.args):
                            raise ValueError(
                                f"IOQuantConfig error for '{node.name}' ({node.target}): "
                                f"input_quant_specs index {arg_idx} is out of range "
                                f"(op has {len(node.args)} args)"
                            )
                        if not _is_float_tensor(node.args[arg_idx]):
                            logger.debug(
                                "Skipping quantization of input %s for '%s' (%s): "
                                "expected a float tensor.",
                                arg_idx,
                                node.name,
                                node.target,
                            )
                            continue
                        logger.debug(
                            "Annotating input %s of '%s' (%s) with %s",
                            arg_idx,
                            node.name,
                            node.target,
                            spec,
                        )
                        input_qspec_map[node.args[arg_idx]] = spec

                output_specs = cfg.output_quant_specs or {}
                if set(output_specs) <= {0}:
                    # Single output — annotate directly on the op node.
                    # (Any key other than 0 implies a tuple output and is
                    # handled by the branch below, so a single non-zero entry
                    # is no longer silently dropped.)
                    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=output_specs.get(0),
                        _annotated=True,
                    )
                else:
                    # Tuple output — mark the op annotated with no output spec
                    # and push quantization down to its getitem users.
                    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=None,
                        _annotated=True,
                    )
                    for user in node.users:
                        # Only operator.getitem users select one element of the
                        # tuple output; for any other user, args[1] would not
                        # be an output index (and may not exist at all).
                        if (
                            user.op != "call_function"
                            or user.target is not operator.getitem
                        ):
                            continue
                        spec = output_specs.get(user.args[1])
                        if spec is not None:
                            user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                                output_qspec=spec,
                                _annotated=True,
                            )

        return annotate_custom_ops
112 changes: 112 additions & 0 deletions backends/qualcomm/custom_op/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

try:
from qti.aisw.op_package_generator.generator import QnnOpPackageGenerator
except ImportError as e:
raise ImportError(
"Failed to import QnnOpPackageGenerator. "
"Please run 'source $QNN_SDK_ROOT/bin/envsetup.sh' to set up the QNN SDK environment."
) from e

from executorch.backends.qualcomm.serialization.qc_schema import (
QnnExecuTorchOpPackageInfo,
QnnExecuTorchOpPackageOptions,
QnnExecuTorchOpPackagePlatform,
QnnExecuTorchOpPackageTarget,
)


class QnnCustomOpPackageBuilder:
    """
    Parses a QNN XML op package config and manages registration of
    target/platform/implementation for use with ExecuTorch.

    Validates that all keys in torch_op_name_map are present in the parsed
    package before any implementations are registered.
    """

    def __init__(
        self,
        xml_path: str,
        torch_op_name_map,
        interface_provider: Optional[str] = None,
    ):
        """
        Args:
            xml_path: Path to the QNN XML OpDef config file.
            torch_op_name_map: Maps QNN op type names to their corresponding
                PyTorch op targets,
                e.g. {"ExampleCustomOp": torch.ops.my_ops.custom_op.default}.
            interface_provider: Interface provider symbol name. Defaults to
                "{PackageName}InterfaceProvider" if not specified.

        Raises:
            ValueError: If any key in torch_op_name_map is not found in the
                parsed op package.
        """
        op_package_generator = QnnOpPackageGenerator()
        op_package_generator.parse_config([xml_path])

        # One OpPackage per XML config; each OpPackage may contain multiple ops.
        pkg_info = op_package_generator.package_infos[0]
        self.op_package_name = pkg_info.name
        # Default symbol name follows the QNN "<PackageName>InterfaceProvider"
        # convention when none is supplied.
        self.interface_provider = (
            interface_provider or pkg_info.name + "InterfaceProvider"
        )
        self.torch_op_name_map = torch_op_name_map
        self._collection: List[QnnExecuTorchOpPackageInfo] = []
        self.operator_names = {op.type_name for op in pkg_info.operators}

        # Fail fast if the map references ops the package does not define.
        missing_ops = set(self.torch_op_name_map) - self.operator_names
        if missing_ops:
            raise ValueError(f"Ops missing from OpPackage: {missing_ops}")

    def register_implementation(
        self,
        target: QnnExecuTorchOpPackageTarget,
        platform: QnnExecuTorchOpPackagePlatform,
        op_package_path: str,
    ) -> "QnnCustomOpPackageBuilder":
        """
        Register one (target, platform, path) combination.
        Creates one QnnExecuTorchOpPackageInfo per op in torch_op_name_map.
        Returns self for method chaining.

        Args:
            target: QnnExecuTorchOpPackageTarget
            platform: QnnExecuTorchOpPackagePlatform
            op_package_path: Path to the implementation for the target/platform.
        """
        for qnn_op_type_name, torch_name in self.torch_op_name_map.items():
            self._collection.append(
                QnnExecuTorchOpPackageInfo(
                    op_package_name=self.op_package_name,
                    op_package_path=op_package_path,
                    interface_provider=self.interface_provider,
                    target=target,
                    custom_op_name=str(torch_name),
                    qnn_op_type_name=qnn_op_type_name,
                    platform=platform,
                )
            )
        return self

    def get_op_package_options(self) -> QnnExecuTorchOpPackageOptions:
        """
        Build and return QnnExecuTorchOpPackageOptions from all registered implementations.
        Call after all register_implementation() calls are complete.
        """
        options = QnnExecuTorchOpPackageOptions()
        # Copy so later register_implementation() calls cannot mutate options
        # already handed out to a caller.
        options.op_package_infos = list(self._collection)
        return options
38 changes: 37 additions & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9204,7 +9204,7 @@ def test_cli_with_input_list_assignment(self):
golden_output = ep.module()(sample_input, sample_input2)
self._assert_outputs_equal(golden_output, device_output)

def test_custom_op(self):
def test_custom_op_1(self):
if not self.required_envs([self.op_package_dir]):
self.skipTest("missing required envs")
cmds = [
Expand Down Expand Up @@ -9240,6 +9240,42 @@ def test_custom_op(self):
msg = json.loads(conn.recv())
self.assertTrue(msg["is_close"])

def test_custom_op_2(self):
    """End-to-end run of the multi-output custom op example on device."""
    if not self.required_envs([self.op_package_dir]):
        self.skipTest("missing required envs")

    script = f"{self.executorch_root}/examples/qualcomm/custom_op/custom_ops_2.py"
    flag_values = [
        ("--artifact", self.artifact_dir),
        ("--build_folder", self.build_folder),
        ("--device", self.device),
        ("--model", self.model),
        ("--target", self.target),
        ("--ip", self.ip),
        ("--port", str(self.port)),
        ("--op_package_dir", self.op_package_dir),
    ]
    cmds = ["python", script]
    for flag, value in flag_values:
        cmds += [flag, value]
    cmds.append("--build_op_package")
    if self.host:
        cmds += ["--host", self.host]
    if self.enable_x86_64:
        cmds.append("--enable_x86_64")

    # The example script connects back over (ip, port) and reports results
    # as a JSON message once inference on device has finished.
    proc = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
    with Listener((self.ip, self.port)) as listener:
        conn = listener.accept()
        proc.communicate()
        msg = json.loads(conn.recv())
        self.assertTrue(msg["is_close"])

def test_debugger_generate_optrace(self):
cmds = [
"python",
Expand Down
27 changes: 15 additions & 12 deletions docs/source/backends-qualcomm.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,18 @@ i.e., the directory containing `QNN_README.txt`.

### Setup environment variables

We set `LD_LIBRARY_PATH` to make sure the dynamic linker can find QNN libraries.
Source the QNN SDK environment setup script to configure paths and environment variables:

Further, we set `PYTHONPATH` because it's easier to develop and import ExecuTorch
Python APIs.
```bash
source $QNN_SDK_ROOT/bin/envsetup.sh
```

This sets up `LD_LIBRARY_PATH` and other required variables for the QNN SDK tools and libraries.

Additionally, set `PYTHONPATH` for ExecuTorch Python APIs:

```bash
export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/:$LD_LIBRARY_PATH
export PYTHONPATH=$EXECUTORCH_ROOT/..
export PYTHONPATH=$EXECUTORCH_ROOT/..:$PYTHONPATH
```

## Build
Expand Down Expand Up @@ -615,14 +619,13 @@ This matrix directly corresponds to the implementations in: [executorch/backends

### Custom Ops Support

You can extend QNN backend support for your own operators.
Follow the [tutorial](https://github.com/pytorch/executorch/tree/f32cdc3de6f7176d70a80228f1a60bcd45d93437/examples/qualcomm/custom_op#custom-operator-support):
The QNN backend supports custom PyTorch operators with the op package mechanism.
See the [custom op tutorial](https://github.com/pytorch/executorch/tree/main/examples/qualcomm/custom_op) for the full end-to-end flow. It covers:

It covers:
- Writing new NodeVisitor for your op
- Registering via @register_node_visitor
- Creating and linking libQnnOp*.so for the delegate
- Testing and verifying custom kernels on HTP
- Defining a custom PyTorch op (single-output and multi-output)
- Writing and building a QNN op package (XML and Op Implementation)
- Registering the op package with ExecuTorch via `QnnCustomOpPackageBuilder`
- Annotating custom ops for quantization via `CustomOpsQuantAnnotator` / `IOQuantConfig`

## FAQ

Expand Down
Loading
Loading