@@ -152,12 +152,17 @@ func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>)
 }

 // CHECK-LABEL: func.func @tranConv2dNegativePadding(
-// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32>
-// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[IN_TENSOR]][0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[INIT_TENSOR:.*]][0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32>
-// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[INSERTED_SLICE]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
-// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
+// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32> attributes {torch.assume_strict_symbolic_shapes} {
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
+// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x8x7xf32>
+// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x8x7xf32>) -> tensor<1x1x8x7xf32>
+// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 1, 1] : tensor<1x1x4x7xf32> into tensor<1x1x8x7xf32>
+// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 8, 5] [1, 1, 1, 1] : tensor<1x1x8x7xf32> to tensor<1x1x8x5xf32>
+// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
+// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
 func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32> attributes {torch.assume_strict_symbolic_shapes} {
   %int0 = torch.constant.int 0
   %true = torch.constant.bool true
@@ -174,3 +179,31 @@ func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1,1,4,7],f32>) ->
   %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1,1,4,7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,6,3],f32>
   return %6 : !torch.vtensor<[1,2,6,3],f32>
 }
+
+// CHECK-LABEL: func.func @tranConv2dNegativeAndPositivePadding(
+// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>,
+// CHECK-SAME: %[[WEIGHTS_VTENSOR:.*]]: !torch.vtensor<[1,2,3,3],f32>,
+// CHECK-SAME: %[[BIAS_VTENSOR:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
+// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x17x25xf32>
+// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x17x25xf32>) -> tensor<1x1x17x25xf32>
+// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 4, 4] : tensor<1x1x4x7xf32> into tensor<1x1x17x25xf32>
+// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 17, 23] [1, 1, 1, 1] : tensor<1x1x17x25xf32> to tensor<1x1x17x23xf32>
+// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x17x23xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x15x21xf32>) -> tensor<1x2x15x21xf32>
+// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x15x21xf32> -> !torch.vtensor<[1,2,15,21],f32>
+func.func @tranConv2dNegativeAndPositivePadding(%arg0: !torch.vtensor<[1,1,4,7],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %int0 = torch.constant.int 0
+  %int4 = torch.constant.int 4
+  %true = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int4, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int0, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %true, %3, %int1 : !torch.vtensor<[1,1,4,7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,15,21],f32>
+  return %4 : !torch.vtensor<[1,2,15,21],f32>
+}
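
For a quick sanity check of the result shapes asserted in the new CHECK lines, the second test case can be replayed through PyTorch's reference implementation; a minimal sketch, assuming a stock `torch` install (this is not part of the lit test itself):

    import torch
    import torch.nn.functional as F

    # tranConv2dNegativeAndPositivePadding: stride=[4,4], padding=[0,3],
    # dilation=[1,1], output_padding=[0,0], transposed=true, groups=1.
    x = torch.zeros(1, 1, 4, 7)   # NCHW input
    w = torch.zeros(1, 2, 3, 3)   # transposed-conv weight: [Cin, Cout/groups, kH, kW]
    b = torch.zeros(2)            # one bias per output channel
    y = F.conv_transpose2d(x, w, b, stride=(4, 4), padding=(0, 3))
    # H_out = (4-1)*4 - 2*0 + 1*(3-1) + 0 + 1 = 15
    # W_out = (7-1)*4 - 2*3 + 1*(3-1) + 0 + 1 = 21
    assert y.shape == (1, 2, 15, 21)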