
Commit 3e90b44

Qualcomm AI Engine Direct - VIT Optimization (#15696)
### Summary
QNN doesn't have much 5D permute optimization, which causes ViT to run slower than on CPU.

![image](https://github.com/user-attachments/assets/bd468a83-09c0-4dff-a75c-029c12d85ea7)

Switched the pattern from unsqueeze -> 5D permute -> squeeze to a single 4D permute. Improvement: 150 ms/inference -> 4.2 ms/inference.

### Test plan
Passes the ViT unit test.
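Not part of the original commit message, but a minimal sketch (using hypothetical ViT-base-like shapes) of why a single 4D `permute(2, 0, 1, 3)` can replace the old `unsqueeze(0).transpose(0, -2).squeeze(-2)` chain on the packed q/k/v projection:

```python
import torch

# Hypothetical shapes: seq_len 197, batch 1, embed_dim 768 (ViT-base-like).
L, N, E = 197, 1, 768
proj = torch.randn(L, N, 3 * E)  # packed q/k/v projection output

# Old pattern: unsqueeze makes the tensor 5D, so QNN sees a 5D transpose.
old = (
    proj.unflatten(-1, (3, E))   # (L, N, 3, E)
    .unsqueeze(0)                # (1, L, N, 3, E)
    .transpose(0, -2)            # (3, L, N, 1, E)
    .squeeze(-2)                 # (3, L, N, E)
    .contiguous()
)

# New pattern: one 4D permute, same stacking of q/k/v.
new = proj.unflatten(-1, (3, E)).permute(2, 0, 1, 3).contiguous()

assert torch.equal(old, new)  # both are (3, L, N, E) with identical values
```

The old chain inserts a leading singleton dimension, so the transpose lowers to a 5D permute on QNN; the new form stays 4D and produces the same `(3, L, N, E)` layout.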
1 parent ed72daf commit 3e90b44

File tree

1 file changed (+64, -10 lines)


examples/qualcomm/scripts/torchvision_vit.py

Lines changed: 64 additions & 10 deletions
@@ -7,12 +7,14 @@
 import json
 import logging
 import os
+from contextlib import contextmanager

 from multiprocessing.connection import Client

 import numpy as np

 import torch
+import torch.nn.functional as F
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel
 from executorch.examples.qualcomm.utils import (
@@ -25,6 +27,56 @@
 )


+# Copied from torch/nn/functional.py
+# QNN does not have 5D permute optimization. Fuse to a single 4D optimization
+# Changed unsqueeze(0).transpose(0, -2).squeeze(-2) to permute(2, 0, 1, 3)
+def _in_projection_packed_custom(q, k, v, w, b=None) -> list[torch.Tensor]:
+    from torch.nn.functional import linear
+
+    E = q.size(-1)
+    if k is v:
+        if q is k:
+            # self-attention
+            proj = linear(q, w, b)
+            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+            proj = proj.unflatten(-1, (3, E)).permute(2, 0, 1, 3).contiguous()
+            # pyrefly: ignore  # bad-return
+            return proj[0], proj[1], proj[2]
+        else:
+            # encoder-decoder attention
+            w_q, w_kv = w.split([E, E * 2])
+            if b is None:
+                b_q = b_kv = None
+            else:
+                b_q, b_kv = b.split([E, E * 2])
+            q_proj = linear(q, w_q, b_q)
+            kv_proj = linear(k, w_kv, b_kv)
+            # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
+            kv_proj = kv_proj.unflatten(-1, (2, E)).permute(2, 0, 1, 3).contiguous()
+            # pyrefly: ignore  # bad-return
+            return (q_proj, kv_proj[0], kv_proj[1])
+    else:
+        w_q, w_k, w_v = w.chunk(3)
+        if b is None:
+            b_q = b_k = b_v = None
+        else:
+            b_q, b_k, b_v = b.chunk(3)
+        # pyrefly: ignore  # bad-return
+        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+
+# Context manager to patch temporarily, so it won't affect other users using F._in_projection_packed
+@contextmanager
+def PermuteInProjectionPacked():
+    # Save the original function so it can be restored later
+    _original_in_projection_packed = F._in_projection_packed
+    F._in_projection_packed = _in_projection_packed_custom
+    try:
+        yield
+    finally:
+        F._in_projection_packed = _original_in_projection_packed
+
+
 def main(args):
     # ensure the working directory exist.
     os.makedirs(args.artifact, exist_ok=True)
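Not in the diff, but a quick sanity check (with hypothetical small shapes, and assuming the two helpers above are in scope) that the patched projection reproduces PyTorch's stock private helper `F._in_projection_packed` for the self-attention path:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for a quick check: seq_len 8, batch 2, embed_dim 16.
L, N, E = 8, 2, 16
x = torch.randn(L, N, E)
w = torch.randn(3 * E, E)  # packed q/k/v projection weight
b = torch.randn(3 * E)     # packed q/k/v projection bias

# Stock implementation (private torch helper) vs. the permute-based variant.
q_ref, k_ref, v_ref = F._in_projection_packed(x, x, x, w, b)
q_new, k_new, v_new = _in_projection_packed_custom(x, x, x, w, b)

assert torch.allclose(q_ref, q_new)
assert torch.allclose(k_ref, k_new)
assert torch.allclose(v_ref, v_new)
```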
@@ -44,16 +96,18 @@ def main(args):
     )

     pte_filename = "vit_qnn_q8"
-    instance = TorchVisionViTModel()
-    build_executorch_binary(
-        instance.get_eager_model().eval(),
-        instance.get_example_inputs(),
-        args.model,
-        f"{args.artifact}/{pte_filename}",
-        inputs,
-        quant_dtype=QuantDtype.use_8a8w,
-        shared_buffer=args.shared_buffer,
-    )
+    instance = TorchVisionViTModel().get_eager_model().eval()
+
+    with PermuteInProjectionPacked():
+        build_executorch_binary(
+            instance,
+            inputs[0],
+            args.model,
+            f"{args.artifact}/{pte_filename}",
+            inputs,
+            quant_dtype=QuantDtype.use_8a8w,
+            shared_buffer=args.shared_buffer,
+        )

     if args.compile_only:
         return
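As a side note (not part of the commit), the context manager confines the monkey patch to the export call. A sketch of the expected behavior, assuming the names from the diff are in scope:

```python
import torch.nn.functional as F

# The stock helper is swapped only inside the `with` block and restored
# afterwards, even if the body raises.
original = F._in_projection_packed

with PermuteInProjectionPacked():
    assert F._in_projection_packed is _in_projection_packed_custom

assert F._in_projection_packed is original
```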

0 commit comments
