@@ -38,11 +38,22 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
3838 }
3939 static Node *makeSqueezeOrUnsqueeze (Graph &graph, std::vector<int64_t > &axes,
4040 Value *input, Node *target_node,
41- BuiltinSymbol k) {
41+ BuiltinSymbol k, bool is_input_qdq ) {
4242 assert (k == kSqueeze || k == kUnsqueeze );
4343 Node *squeeze = graph.create (k, 1 );
44- int opset_version = getOpsetVersion (graph);
44+ Node *dequant_node = nullptr ;
45+ Node *quant_node = nullptr ;
46+ // insert squeeze op before qdq
47+ if (is_input_qdq) {
48+ dequant_node = input->node ();
49+ quant_node = dequant_node->input (0 )->node ();
50+ target_node = quant_node;
51+ input = target_node->input (0 );
52+ dequant_node->output ()->clearMetadata ();
53+ quant_node->output ()->clearMetadata ();
54+ }
4555 squeeze->addInput (input);
56+ int opset_version = getOpsetVersion (graph);
4657 int version_threshold = 13 ;
4758 if (opset_version < version_threshold && opset_version != 0 ) {
4859 squeeze->is_ (kaxes, std::move (axes));
@@ -54,7 +65,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
5465 Value *tv = graph.addInitializerAndInput (t);
5566 squeeze->addInput (tv);
5667 }
68+ if (is_input_qdq) {
69+ quant_node->replaceInput (0 , squeeze->output ());
70+ }
5771 squeeze->insertBefore (target_node);
72+ if (is_input_qdq) {
73+ return dequant_node;
74+ }
5875 return squeeze;
5976 }
6077 bool runTransform (Node *n, Graph &graph,
@@ -115,13 +132,13 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
115132 if (bias_shape.size () > 1 ) {
116133 std::vector<int64_t > axes (bias_shape.size () - 1 );
117134 std::iota (axes.begin (), axes.end (), 0 );
118- Node *squeeze = makeSqueezeOrUnsqueeze (graph, axes, conv_3rd_input,
119- orig_conv->node (), kSqueeze );
135+ Node *squeeze = makeSqueezeOrUnsqueeze (
136+ graph, axes, conv_3rd_input, orig_conv->node (), kSqueeze , false );
120137 conv_3rd_input = squeeze->output ();
121138 } else if (bias_shape.size () == 0 ) {
122139 std::vector<int64_t > axes = {0 };
123- Node *unsqueeze = makeSqueezeOrUnsqueeze (graph, axes, conv_3rd_input,
124- orig_conv->node (), kUnsqueeze );
140+ Node *unsqueeze = makeSqueezeOrUnsqueeze (
141+ graph, axes, conv_3rd_input, orig_conv->node (), kUnsqueeze , false );
125142 conv_3rd_input = unsqueeze->output ();
126143 }
127144 if (M > 1 ) {
@@ -149,17 +166,25 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
149166 bias_shape[1 + bias_shape.size () - static_cast <unsigned >(rank)]
150167 .dim == M) {
151168 ONNX_ASSERT (bias_shape.size () > 1 );
169+ const bool is_input_qdq =
170+ orig_bias->node ()->kind () == Symbol (" DequantizeLinear" ) &&
171+ orig_bias->node ()->input (0 )->node ()->kind () ==
172+ Symbol (" QuantizeLinear" );
152173 if (orig_bias->node ()->kind () != kParam &&
153174 orig_conv->node ()->isBefore (orig_bias->node ())) {
175+ if (is_input_qdq) {
176+ orig_bias->node ()->input (0 )->node ()->moveBefore (orig_conv->node ());
177+ }
154178 orig_bias->node ()->moveBefore (orig_conv->node ());
155179 }
156180 std::vector<int64_t > axes (bias_shape.size ());
157181 std::iota (axes.begin (), axes.end (), static_cast <int64_t >(0 ));
158182 axes.erase (axes.begin () +
159183 (1 + bias_shape.size () - static_cast <unsigned >(rank)));
160- Node *squeeze = makeSqueezeOrUnsqueeze (graph, axes, orig_bias,
161- orig_conv->node (), kSqueeze );
162- orig_conv->node ()->addInput (squeeze->output ());
184+
185+ Node *new_bias = makeSqueezeOrUnsqueeze (
186+ graph, axes, orig_bias, orig_conv->node (), kSqueeze , is_input_qdq);
187+ orig_conv->node ()->addInput (new_bias->output ());
163188 } else {
164189 return false ;
165190 }
0 commit comments