Skip to content

Commit da7c6d2

Browse files
[MLIR][TORCH] Add support for dynamic shape for Onnx.Transpose op (#2803)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 4964977 commit da7c6d2

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
12441244
current[i] = i;
12451245
}
12461246

1247+
// Convert dynamic shape dimension.
1248+
for (unsigned i = 0; i < shape.size(); i++){
1249+
if (shape[i] == ShapedType::kDynamic)
1250+
shape[i] = Torch::kUnknownSize;
1251+
}
1252+
12471253
for (int64_t i = 0; i < rank; ++i) {
12481254
if (current[i] == permutations[i])
12491255
continue;

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,18 @@ func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>)
968968
return %0 : !torch.vtensor<[4,2,3],f32>
969969
}
970970

971+
// -----
972+
973+
// CHECK-LABEL: func.func @test_transpose_dynamic
974+
func.func @test_transpose_dynamic(%arg0: !torch.vtensor<[?,32,5,128],f32>) -> !torch.vtensor<[?,5,32,128],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
975+
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
976+
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
977+
// CHECK: %[[TRANSPOSE:.+]] = torch.aten.transpose.int %arg0, %[[I1]], %[[I2]] : !torch.vtensor<[?,32,5,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,5,32,128],f32>
978+
%0 = torch.operator "onnx.Transpose"(%arg0) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[?,32,5,128],f32>) -> !torch.vtensor<[?,5,32,128],f32>
979+
return %0 : !torch.vtensor<[?,5,32,128],f32>
980+
}
981+
982+
971983
// -----
972984

973985
// CHECK-LABEL: func.func @test_slice

0 commit comments

Comments
 (0)