@@ -38,11 +38,21 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
3838 }
3939 static Node *makeSqueezeOrUnsqueeze (Graph &graph, std::vector<int64_t > &axes,
4040 Value *input, Node *target_node,
41- BuiltinSymbol k) {
41+ BuiltinSymbol k, bool is_input_qdq ) {
4242 assert (k == kSqueeze || k == kUnsqueeze );
4343 Node *squeeze = graph.create (k, 1 );
44- int opset_version = getOpsetVersion (graph);
44+ Node *dequant_node = nullptr ;
45+ Node *quant_node = nullptr ;
46+ if (is_input_qdq) {
47+ dequant_node = input->node ();
48+ quant_node = dequant_node->input (0 )->node ();
49+ target_node = quant_node;
50+ input = target_node->input (0 );
51+ dequant_node->output ()->clearMetadata ();
52+ quant_node->output ()->clearMetadata ();
53+ }
4554 squeeze->addInput (input);
55+ int opset_version = getOpsetVersion (graph);
4656 int version_threshold = 13 ;
4757 if (opset_version < version_threshold && opset_version != 0 ) {
4858 squeeze->is_ (kaxes, std::move (axes));
@@ -54,7 +64,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
5464 Value *tv = graph.addInitializerAndInput (t);
5565 squeeze->addInput (tv);
5666 }
67+ if (is_input_qdq) {
68+ quant_node->replaceInput (0 , squeeze->output ());
69+ }
5770 squeeze->insertBefore (target_node);
71+ if (is_input_qdq) {
72+ return dequant_node;
73+ }
5874 return squeeze;
5975 }
6076 bool runTransform (Node *n, Graph &graph,
@@ -115,13 +131,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
115131 if (bias_shape.size () > 1 ) {
116132 std::vector<int64_t > axes (bias_shape.size () - 1 );
117133 std::iota (axes.begin (), axes.end (), 0 );
118- Node *squeeze = makeSqueezeOrUnsqueeze (graph, axes, conv_3rd_input,
119- orig_conv->node (), kSqueeze );
134+ Node *squeeze = makeSqueezeOrUnsqueeze (
135+ graph, axes, conv_3rd_input, orig_conv->node (), kSqueeze , false );
120136 conv_3rd_input = squeeze->output ();
121137 } else if (bias_shape.size () == 0 ) {
122138 std::vector<int64_t > axes = {0 };
123- Node *unsqueeze = makeSqueezeOrUnsqueeze (graph, axes, conv_3rd_input,
124- orig_conv->node (), kUnsqueeze );
139+ Node *unsqueeze = makeSqueezeOrUnsqueeze (
140+ graph, axes, conv_3rd_input, orig_conv->node (), kUnsqueeze , false );
125141 conv_3rd_input = unsqueeze->output ();
126142 }
127143 if (M > 1 ) {
@@ -149,17 +165,25 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
149165 bias_shape[1 + bias_shape.size () - static_cast <unsigned >(rank)]
150166 .dim == M) {
151167 ONNX_ASSERT (bias_shape.size () > 1 );
168+ const bool is_input_qdq =
169+ orig_bias->node ()->kind () == Symbol (" DequantizeLinear" ) &&
170+ orig_bias->node ()->input (0 )->node ()->kind () ==
171+ Symbol (" QuantizeLinear" );
152172 if (orig_bias->node ()->kind () != kParam &&
153173 orig_conv->node ()->isBefore (orig_bias->node ())) {
174+ if (is_input_qdq) {
175+ orig_bias->node ()->input (0 )->node ()->moveBefore (orig_conv->node ());
176+ }
154177 orig_bias->node ()->moveBefore (orig_conv->node ());
155178 }
156179 std::vector<int64_t > axes (bias_shape.size ());
157180 std::iota (axes.begin (), axes.end (), static_cast <int64_t >(0 ));
158181 axes.erase (axes.begin () +
159182 (1 + bias_shape.size () - static_cast <unsigned >(rank)));
160- Node *squeeze = makeSqueezeOrUnsqueeze (graph, axes, orig_bias,
161- orig_conv->node (), kSqueeze );
162- orig_conv->node ()->addInput (squeeze->output ());
183+
184+ Node *new_bias = makeSqueezeOrUnsqueeze (
185+ graph, axes, orig_bias, orig_conv->node (), kSqueeze , is_input_qdq);
186+ orig_conv->node ()->addInput (new_bias->output ());
163187 } else {
164188 return false ;
165189 }
0 commit comments