Skip to content

Commit e774b77

Browse files
[Executorch][QNN] Improve qualcomm examples (#15477)
1 parent 52b0b3b commit e774b77

File tree

5 files changed

+103
-65
lines changed

5 files changed

+103
-65
lines changed

examples/qualcomm/oss_scripts/albert.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030

3131

3232
def main(args):
33+
if args.compile_only and args.pre_gen_pte:
34+
raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")
35+
3336
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
3437

3538
os.makedirs(args.artifact, exist_ok=True)
@@ -60,26 +63,32 @@ def main(args):
6063
module = AutoModelForMaskedLM.from_pretrained(model_name, config=config).eval()
6164
pte_filename = "albert_qnn_q16"
6265

63-
# lower to QNN
64-
passes_job = get_capture_program_passes()
65-
build_executorch_binary(
66-
module,
67-
inputs[0],
68-
args.model,
69-
f"{args.artifact}/{pte_filename}",
70-
dataset=inputs,
71-
skip_node_id_set=skip_node_id_set,
72-
skip_node_op_set=skip_node_op_set,
73-
quant_dtype=QuantDtype.use_16a16w,
74-
passes_job=passes_job,
75-
shared_buffer=args.shared_buffer,
76-
)
66+
# Skip lowering/compilation if using pre-generated PTE
67+
if not args.pre_gen_pte:
68+
# lower to QNN
69+
passes_job = get_capture_program_passes()
70+
build_executorch_binary(
71+
module,
72+
inputs[0],
73+
args.model,
74+
f"{args.artifact}/{pte_filename}",
75+
dataset=inputs,
76+
skip_node_id_set=skip_node_id_set,
77+
skip_node_op_set=skip_node_op_set,
78+
quant_dtype=QuantDtype.use_16a16w,
79+
passes_job=passes_job,
80+
shared_buffer=args.shared_buffer,
81+
)
7782

7883
if args.compile_only:
7984
return
8085

8186
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}"
82-
pte_path = f"{args.artifact}/{pte_filename}.pte"
87+
pte_path = (
88+
f"{args.pre_gen_pte}/{pte_filename}.pte"
89+
if args.pre_gen_pte
90+
else f"{args.artifact}/{pte_filename}.pte"
91+
)
8392

8493
adb = SimpleADB(
8594
qnn_sdk=os.getenv("QNN_SDK_ROOT"),

examples/qualcomm/oss_scripts/bert.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030

3131

3232
def main(args):
33+
if args.compile_only and args.pre_gen_pte:
34+
raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")
35+
3336
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
3437

3538
os.makedirs(args.artifact, exist_ok=True)
@@ -57,26 +60,32 @@ def main(args):
5760
).eval()
5861
pte_filename = "bert_qnn_q16"
5962

60-
# lower to QNN
61-
passes_job = get_capture_program_passes()
62-
build_executorch_binary(
63-
module,
64-
inputs[0],
65-
args.model,
66-
f"{args.artifact}/{pte_filename}",
67-
dataset=inputs,
68-
skip_node_id_set=skip_node_id_set,
69-
skip_node_op_set=skip_node_op_set,
70-
quant_dtype=QuantDtype.use_16a8w,
71-
passes_job=passes_job,
72-
shared_buffer=args.shared_buffer,
73-
)
63+
# Skip lowering/compilation if using pre-generated PTE
64+
if not args.pre_gen_pte:
65+
# lower to QNN
66+
passes_job = get_capture_program_passes()
67+
build_executorch_binary(
68+
module,
69+
inputs[0],
70+
args.model,
71+
f"{args.artifact}/{pte_filename}",
72+
dataset=inputs,
73+
skip_node_id_set=skip_node_id_set,
74+
skip_node_op_set=skip_node_op_set,
75+
quant_dtype=QuantDtype.use_16a8w,
76+
passes_job=passes_job,
77+
shared_buffer=args.shared_buffer,
78+
)
7479

7580
if args.compile_only:
7681
return
7782

7883
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}"
79-
pte_path = f"{args.artifact}/{pte_filename}.pte"
84+
pte_path = (
85+
f"{args.pre_gen_pte}/{pte_filename}.pte"
86+
if args.pre_gen_pte
87+
else f"{args.artifact}/{pte_filename}.pte"
88+
)
8089

8190
adb = SimpleADB(
8291
qnn_sdk=os.getenv("QNN_SDK_ROOT"),

examples/qualcomm/oss_scripts/distilbert.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131

3232

3333
def main(args):
34+
if args.compile_only and args.pre_gen_pte:
35+
raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")
36+
3437
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
3538

3639
os.makedirs(args.artifact, exist_ok=True)
@@ -58,26 +61,32 @@ def main(args):
5861
).eval()
5962
pte_filename = "distilbert_qnn_q16"
6063

61-
# lower to QNN
62-
passes_job = get_capture_program_passes()
63-
build_executorch_binary(
64-
module,
65-
inputs[0],
66-
args.model,
67-
f"{args.artifact}/{pte_filename}",
68-
dataset=inputs,
69-
skip_node_id_set=skip_node_id_set,
70-
skip_node_op_set=skip_node_op_set,
71-
quant_dtype=QuantDtype.use_16a8w,
72-
passes_job=passes_job,
73-
shared_buffer=args.shared_buffer,
74-
)
64+
# Skip lowering/compilation if using pre-generated PTE
65+
if not args.pre_gen_pte:
66+
# lower to QNN
67+
passes_job = get_capture_program_passes()
68+
build_executorch_binary(
69+
module,
70+
inputs[0],
71+
args.model,
72+
f"{args.artifact}/{pte_filename}",
73+
dataset=inputs,
74+
skip_node_id_set=skip_node_id_set,
75+
skip_node_op_set=skip_node_op_set,
76+
quant_dtype=QuantDtype.use_16a8w,
77+
passes_job=passes_job,
78+
shared_buffer=args.shared_buffer,
79+
)
7580

7681
if args.compile_only:
7782
return
7883

7984
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}"
80-
pte_path = f"{args.artifact}/{pte_filename}.pte"
85+
pte_path = (
86+
f"{args.pre_gen_pte}/{pte_filename}.pte"
87+
if args.pre_gen_pte
88+
else f"{args.artifact}/{pte_filename}.pte"
89+
)
8190

8291
adb = SimpleADB(
8392
qnn_sdk=os.getenv("QNN_SDK_ROOT"),

examples/qualcomm/oss_scripts/eurobert.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535

3636

3737
def main(args):
38+
if args.compile_only and args.pre_gen_pte:
39+
raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")
40+
3841
assert (
3942
transformers.__version__ >= TRANSFORMERS_VERSION
4043
), f"Please ensure transformers version >= {TRANSFORMERS_VERSION}, current version is {transformers.__version__}"
@@ -88,33 +91,40 @@ def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
8891

8992
pte_filename = "eurobert_qnn_q16"
9093

91-
# lower to QNN
92-
passes_job = get_capture_program_passes()
93-
quantizer = make_quantizer(
94-
quant_dtype=QuantDtype.use_16a16w,
95-
)
96-
quantizer.add_custom_quant_annotations((annotate_eurobert,))
97-
with torch.no_grad():
98-
build_executorch_binary(
99-
model,
100-
inputs[0],
101-
args.model,
102-
f"{args.artifact}/{pte_filename}",
103-
dataset=inputs,
104-
skip_node_id_set=skip_node_id_set,
105-
skip_node_op_set=skip_node_op_set,
106-
custom_quantizer=quantizer,
107-
passes_job=passes_job,
108-
shared_buffer=args.shared_buffer,
94+
# Skip lowering/compilation if using pre-generated PTE
95+
if not args.pre_gen_pte:
96+
# lower to QNN
97+
passes_job = get_capture_program_passes()
98+
quantizer = make_quantizer(
99+
quant_dtype=QuantDtype.use_16a16w,
109100
)
101+
quantizer.add_custom_quant_annotations((annotate_eurobert,))
102+
with torch.no_grad():
103+
build_executorch_binary(
104+
model,
105+
inputs[0],
106+
args.model,
107+
f"{args.artifact}/{pte_filename}",
108+
dataset=inputs,
109+
skip_node_id_set=skip_node_id_set,
110+
skip_node_op_set=skip_node_op_set,
111+
custom_quantizer=quantizer,
112+
passes_job=passes_job,
113+
shared_buffer=args.shared_buffer,
114+
)
110115

111116
if args.compile_only:
112117
return
113118

119+
pte_path = (
120+
f"{args.pre_gen_pte}/{pte_filename}.pte"
121+
if args.pre_gen_pte
122+
else f"{args.artifact}/{pte_filename}.pte"
123+
)
114124
adb = SimpleADB(
115125
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
116126
build_path=f"{args.build_folder}",
117-
pte_path=f"{args.artifact}/{pte_filename}.pte",
127+
pte_path=pte_path,
118128
workspace=f"/data/local/tmp/executorch/{pte_filename}",
119129
device_id=args.device,
120130
host_id=args.host,

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,8 @@ def post_process():
10231023
runner=f"examples/qualcomm/oss_scripts/llama/qnn_llama_runner",
10241024
)
10251025
# No pregen inputs, input_list is not required
1026-
adb.push(inputs=[], files=[runtime_tokenizer_path])
1026+
if not args.skip_push:
1027+
adb.push(inputs=[], files=[runtime_tokenizer_path])
10271028
adb.execute(custom_runner_cmd=runner_cmd)
10281029
adb.pull(output_path=args.artifact, callback=post_process)
10291030

0 commit comments

Comments
 (0)