Skip to content

Commit b7a0329

Browse files
[ONNX][MLIR] Fix padding size constraint for onnx.maxpool op (#2782)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
1 parent d452c4f commit b7a0329

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
201201
binder.op, "kernel list size does not match the number of axes");
202202
if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
203203
return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
204-
if (padding.size() != 1 && padding.size() != rank - 2)
204+
if (padding.size() != 1 && padding.size() != 2 * (rank - 2))
205205
return rewriter.notifyMatchFailure(
206-
binder.op, "padding list size does not match the number of axes");
206+
binder.op, "padding list must contain (begin,end) pair for each "
207+
"spatial axis");
207208
if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
208209
return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
209210
if (strides.size() != 1 && strides.size() != rank - 2)

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,29 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
274274

275275
// -----
276276

277+
// CHECK-LABEL: func.func @test_maxpool_pad
278+
func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
279+
// CHECK: %[[INT3:.*]] = torch.constant.int 3
280+
// CHECK: %[[INT3_0:.*]] = torch.constant.int 3
281+
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
282+
// CHECK: %[[INT1:.*]] = torch.constant.int 1
283+
// CHECK: %[[INT1_1:.*]] = torch.constant.int 1
284+
// CHECK: %[[INT1_2:.*]] = torch.constant.int 1
285+
// CHECK: %[[INT1_3:.*]] = torch.constant.int 1
286+
// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
287+
// CHECK: %[[INT2:.*]] = torch.constant.int 2
288+
// CHECK: %[[INT2_4:.*]] = torch.constant.int 2
289+
// CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int>
290+
// CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
291+
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
292+
// CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
293+
// CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32>
294+
%0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32>
295+
return %0 : !torch.vtensor<[1,64,56,56],f32>
296+
}
297+
298+
// -----
299+
277300
// CHECK-LABEL: @test_gelu_default_1
278301
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
279302
// CHECK: %[[STR1:.*]] = torch.constant.str "none"

0 commit comments

Comments (0)