 import json
 import logging
 import os
+from contextlib import contextmanager

 from multiprocessing.connection import Client

 import numpy as np

 import torch
+import torch.nn.functional as F
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel
 from executorch.examples.qualcomm.utils import (
 )

+# Copied from torch/nn/functional.py.
+# QNN has no optimization for 5D permutes, so fuse the shuffle into a single 4D
+# permute: unsqueeze(0).transpose(0, -2).squeeze(-2) is changed to permute(2, 0, 1, 3).
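+# (For the 3D (L, N, 3 * E) projection below, unflatten(-1, (3, E)) gives a 4D
+# (L, N, 3, E) tensor; the old chain routed it through a 5D view, while
+# permute(2, 0, 1, 3) reaches the same (3, L, N, E) layout in one 4D op.)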
+def _in_projection_packed_custom(q, k, v, w, b=None) -> list[torch.Tensor]:
+    from torch.nn.functional import linear
+
+    E = q.size(-1)
+    if k is v:
+        if q is k:
+            # self-attention
+            proj = linear(q, w, b)
+            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+            proj = proj.unflatten(-1, (3, E)).permute(2, 0, 1, 3).contiguous()
+            # pyrefly: ignore # bad-return
+            return proj[0], proj[1], proj[2]
+        else:
+            # encoder-decoder attention
+            w_q, w_kv = w.split([E, E * 2])
+            if b is None:
+                b_q = b_kv = None
+            else:
+                b_q, b_kv = b.split([E, E * 2])
+            q_proj = linear(q, w_q, b_q)
+            kv_proj = linear(k, w_kv, b_kv)
+            # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
+            kv_proj = kv_proj.unflatten(-1, (2, E)).permute(2, 0, 1, 3).contiguous()
+            # pyrefly: ignore # bad-return
+            return (q_proj, kv_proj[0], kv_proj[1])
+    else:
+        w_q, w_k, w_v = w.chunk(3)
+        if b is None:
+            b_q = b_k = b_v = None
+        else:
+            b_q, b_k, b_v = b.chunk(3)
+        # pyrefly: ignore # bad-return
+        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+
+# Context manager to patch F._in_projection_packed temporarily, so the change
+# does not leak to other users of the function.
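+# nn.MultiheadAttention looks up _in_projection_packed in torch.nn.functional
+# at call time, so swapping the module attribute here is enough for the export
+# and quantization tracing below to capture the 4D-permute variant.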
+@contextmanager
+def PermuteInProjectionPacked():
+    # Save the original function so it can be restored later
+    _original_in_projection_packed = F._in_projection_packed
+    F._in_projection_packed = _in_projection_packed_custom
+    try:
+        yield
+    finally:
+        F._in_projection_packed = _original_in_projection_packed
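+# Illustrative sanity check (assumed usage, not part of this example's flow):
+# the patched projection should match the stock implementation exactly, e.g.
+#   q = torch.randn(197, 1, 768)   # ViT-B/16: 197 tokens, embed dim 768 (assumed shapes)
+#   w = torch.randn(3 * 768, 768)  # packed q/k/v projection weight
+#   expected = F._in_projection_packed(q, q, q, w)
+#   with PermuteInProjectionPacked():
+#       patched = F._in_projection_packed(q, q, q, w)
+#   assert all(torch.equal(a, b) for a, b in zip(expected, patched))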
+
+
 def main(args):
     # ensure the working directory exists.
     os.makedirs(args.artifact, exist_ok=True)
@@ -44,16 +96,18 @@ def main(args):
     )

     pte_filename = "vit_qnn_q8"
-    instance = TorchVisionViTModel()
-    build_executorch_binary(
-        instance.get_eager_model().eval(),
-        instance.get_example_inputs(),
-        args.model,
-        f"{args.artifact}/{pte_filename}",
-        inputs,
-        quant_dtype=QuantDtype.use_8a8w,
-        shared_buffer=args.shared_buffer,
-    )
+    instance = TorchVisionViTModel().get_eager_model().eval()
+
+    with PermuteInProjectionPacked():
+        build_executorch_binary(
+            instance,
+            inputs[0],
+            args.model,
+            f"{args.artifact}/{pte_filename}",
+            inputs,
+            quant_dtype=QuantDtype.use_8a8w,
+            shared_buffer=args.shared_buffer,
+        )

     if args.compile_only:
         return