@@ -1125,6 +1125,31 @@ def test_fuse_add_bias_into_conv_squeeze_4d_bias_no_fuse(self):
11251125 assert optimized_model .graph .node [0 ].op_type == 'Conv'
11261126 assert optimized_model .graph .node [1 ].op_type == 'Add'
11271127
1128+ # type: () -> None
1129+ def test_fuse_add_bias_into_conv_with_non_constant_bias (self ):
1130+ nodes = [helper .make_node ("Conv" , ["X" , "Y" ], ["Z" ]),
1131+ helper .make_node ("Sin" , ["A" ], ["B" ]),
1132+ helper .make_node ("Add" , ["Z" , "B" ], ["C" ])]
1133+ graph = helper .make_graph (
1134+ nodes ,
1135+ "test" ,
1136+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (1 , 5 , 3 , 3 )),
1137+ helper .make_tensor_value_info (
1138+ "Y" , TensorProto .FLOAT , (16 , 5 , 3 , 3 )),
1139+ helper .make_tensor_value_info ("A" , TensorProto .FLOAT , (16 , 1 , 1 ))],
1140+ [helper .make_tensor_value_info (
1141+ "C" , TensorProto .FLOAT , (1 , 16 , 1 , 1 ))],
1142+ value_info = [helper .make_tensor_value_info (
1143+ "B" , TensorProto .FLOAT , (16 , 1 , 1 ))]
1144+ )
1145+ optimized_model = self ._optimized (graph , ["fuse_add_bias_into_conv" ])
1146+
1147+ assert len (list (optimized_model .graph .node )) == 3
1148+ assert optimized_model .graph .node [0 ].op_type == 'Sin'
1149+ assert optimized_model .graph .node [1 ].op_type == 'Squeeze'
1150+ assert optimized_model .graph .node [2 ].op_type == 'Conv'
1151+ assert optimized_model .graph .output [0 ].name == 'C'
1152+
11281153 def test_fuse_matmul_add_bias_into_gemm (self ): # type: () -> None
11291154 matmul = helper .make_node ("MatMul" , ["X" , "Y" ], ["Z" ])
11301155 add = helper .make_node ("Add" , ["Z" , "B" ], ["A" ])
0 commit comments