|
35 | 35 |
|
36 | 36 |
|
37 | 37 | def main(args): |
| 38 | + if args.compile_only and args.pre_gen_pte: |
| 39 | + raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") |
| 40 | + |
38 | 41 | assert ( |
39 | 42 | transformers.__version__ >= TRANSFORMERS_VERSION |
40 | 43 | ), f"Please ensure transformers version >= {TRANSFORMERS_VERSION}, current version is {transformers.__version__}" |
@@ -88,33 +91,40 @@ def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module): |
88 | 91 |
|
89 | 92 | pte_filename = "eurobert_qnn_q16" |
90 | 93 |
|
91 | | - # lower to QNN |
92 | | - passes_job = get_capture_program_passes() |
93 | | - quantizer = make_quantizer( |
94 | | - quant_dtype=QuantDtype.use_16a16w, |
95 | | - ) |
96 | | - quantizer.add_custom_quant_annotations((annotate_eurobert,)) |
97 | | - with torch.no_grad(): |
98 | | - build_executorch_binary( |
99 | | - model, |
100 | | - inputs[0], |
101 | | - args.model, |
102 | | - f"{args.artifact}/{pte_filename}", |
103 | | - dataset=inputs, |
104 | | - skip_node_id_set=skip_node_id_set, |
105 | | - skip_node_op_set=skip_node_op_set, |
106 | | - custom_quantizer=quantizer, |
107 | | - passes_job=passes_job, |
108 | | - shared_buffer=args.shared_buffer, |
| 94 | + # Skip lowering/compilation if using pre-generated PTE |
| 95 | + if not args.pre_gen_pte: |
| 96 | + # lower to QNN |
| 97 | + passes_job = get_capture_program_passes() |
| 98 | + quantizer = make_quantizer( |
| 99 | + quant_dtype=QuantDtype.use_16a16w, |
109 | 100 | ) |
| 101 | + quantizer.add_custom_quant_annotations((annotate_eurobert,)) |
| 102 | + with torch.no_grad(): |
| 103 | + build_executorch_binary( |
| 104 | + model, |
| 105 | + inputs[0], |
| 106 | + args.model, |
| 107 | + f"{args.artifact}/{pte_filename}", |
| 108 | + dataset=inputs, |
| 109 | + skip_node_id_set=skip_node_id_set, |
| 110 | + skip_node_op_set=skip_node_op_set, |
| 111 | + custom_quantizer=quantizer, |
| 112 | + passes_job=passes_job, |
| 113 | + shared_buffer=args.shared_buffer, |
| 114 | + ) |
110 | 115 |
|
111 | 116 | if args.compile_only: |
112 | 117 | return |
113 | 118 |
|
| 119 | + pte_path = ( |
| 120 | + f"{args.pre_gen_pte}/{pte_filename}.pte" |
| 121 | + if args.pre_gen_pte |
| 122 | + else f"{args.artifact}/{pte_filename}.pte" |
| 123 | + ) |
114 | 124 | adb = SimpleADB( |
115 | 125 | qnn_sdk=os.getenv("QNN_SDK_ROOT"), |
116 | 126 | build_path=f"{args.build_folder}", |
117 | | - pte_path=f"{args.artifact}/{pte_filename}.pte", |
| 127 | + pte_path=pte_path, |
118 | 128 | workspace=f"/data/local/tmp/executorch/{pte_filename}", |
119 | 129 | device_id=args.device, |
120 | 130 | host_id=args.host, |
|
0 commit comments