@@ -274,6 +274,29 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
274274
275275// -----
276276
277+ // CHECK-LABEL: func.func @test_maxpool_pad
278+ func.func @test_maxpool_pad (%arg0: !torch.vtensor <[1 ,64 ,112 ,112 ],f32 >) -> !torch.vtensor <[1 ,64 ,56 ,56 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 } {
279+ // CHECK: %[[INT3:.*]] = torch.constant.int 3
280+ // CHECK: %[[INT3_0:.*]] = torch.constant.int 3
281+ // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
282+ // CHECK: %[[INT1:.*]] = torch.constant.int 1
283+ // CHECK: %[[INT1_1:.*]] = torch.constant.int 1
284+ // CHECK: %[[INT1_2:.*]] = torch.constant.int 1
285+ // CHECK: %[[INT1_3:.*]] = torch.constant.int 1
286+ // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
287+ // CHECK: %[[INT2:.*]] = torch.constant.int 2
288+ // CHECK: %[[INT2_4:.*]] = torch.constant.int 2
289+ // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int>
290+ // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
291+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
292+ // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
293+ // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32>
294+ %0 = torch.operator " onnx.MaxPool" (%arg0 ) {torch.onnx.ceil_mode = 0 : si64 , torch.onnx.kernel_shape = [3 : si64 , 3 : si64 ], torch.onnx.pads = [1 : si64 , 1 : si64 , 1 : si64 , 1 : si64 ], torch.onnx.strides = [2 : si64 , 2 : si64 ]} : (!torch.vtensor <[1 ,64 ,112 ,112 ],f32 >) -> !torch.vtensor <[1 ,64 ,56 ,56 ],f32 >
295+ return %0 : !torch.vtensor <[1 ,64 ,56 ,56 ],f32 >
296+ }
297+
298+ // -----
299+
277300// CHECK-LABEL: @test_gelu_default_1
278301func.func @test_gelu_default_1 (%arg0: !torch.vtensor <[3 ],f32 >) -> !torch.vtensor <[3 ],f32 > attributes {torch.onnx_meta.ir_version = 9 : si64 , torch.onnx_meta.opset_version = 20 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
279302 // CHECK: %[[STR1:.*]] = torch.constant.str "none"
0 commit comments