Skip to content

Commit dc056e5

Browse files
authored
[MLIR][TORCH] Add onnx.cast cases used by OPT-1.25M (#2787)
1 parent c9d8ffb commit dc056e5

File tree

2 files changed

+91
-67
lines changed

2 files changed

+91
-67
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 81 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,39 @@
1010
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
1111
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1212
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
13+
#include "llvm/Support/FormatVariadic.h"
1314

1415
using namespace mlir;
1516
using namespace mlir::torch;
1617
using namespace mlir::torch::onnx_c;
1718

1819
static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
19-
int64_t dtypeIntTorch;
2020
// TODO: Add complete mapping.
21-
switch (dtypeIntOnnx) {
22-
case 1:
23-
dtypeIntTorch = 6; // float
24-
break;
25-
case 10:
26-
dtypeIntTorch = 5; // half
27-
break;
28-
case 11:
29-
dtypeIntTorch = 7; // double
30-
break;
31-
case 16:
32-
dtypeIntTorch = 15; // bfloat16
33-
break;
34-
default:
35-
dtypeIntTorch = -1; // No dtype
36-
}
21+
// Where are the ONNX and PyTorch dtype enums defined?
22+
// ONNX:
23+
// https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
24+
// PyTorch:
25+
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88
26+
27+
int64_t dtypeIntTorch = [dtypeIntOnnx]() {
28+
switch (dtypeIntOnnx) {
29+
case 1:
30+
return 6; // float
31+
case 7:
32+
return 5; // int64
33+
case 9:
34+
return 11; // bool
35+
case 10:
36+
return 5; // half
37+
case 11:
38+
return 7; // double
39+
case 16:
40+
return 15; // bfloat16
41+
default:
42+
return -1; // No dtype
43+
}
44+
}();
45+
3746
return dtypeIntTorch;
3847
}
3948

@@ -415,30 +424,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
415424
}
416425
return success();
417426
});
418-
patterns.onOp(
419-
"BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
420-
Torch::ValueTensorType resultType;
421-
Value lhs, rhs;
422-
std::string direction;
423-
if (binder.tensorOperands(lhs, rhs) ||
424-
binder.tensorResultType(resultType))
425-
return failure();
426-
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
427-
binder.op, resultType, lhs, rhs);
428-
return success();
429-
});
430-
patterns.onOp(
431-
"BitwiseOr", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
432-
Torch::ValueTensorType resultType;
433-
Value lhs, rhs;
434-
std::string direction;
435-
if (binder.tensorOperands(lhs, rhs) ||
436-
binder.tensorResultType(resultType))
437-
return failure();
438-
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
439-
binder.op, resultType, lhs, rhs);
440-
return success();
441-
});
427+
patterns.onOp("BitwiseAnd", 18,
428+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
429+
Torch::ValueTensorType resultType;
430+
Value lhs, rhs;
431+
std::string direction;
432+
if (binder.tensorOperands(lhs, rhs) ||
433+
binder.tensorResultType(resultType))
434+
return failure();
435+
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
436+
binder.op, resultType, lhs, rhs);
437+
return success();
438+
});
439+
patterns.onOp("BitwiseOr", 18,
440+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
441+
Torch::ValueTensorType resultType;
442+
Value lhs, rhs;
443+
std::string direction;
444+
if (binder.tensorOperands(lhs, rhs) ||
445+
binder.tensorResultType(resultType))
446+
return failure();
447+
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
448+
binder.op, resultType, lhs, rhs);
449+
return success();
450+
});
442451
patterns.onOp("BitwiseNot", 18,
443452
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
444453
Torch::ValueTensorType resultType;
@@ -450,18 +459,18 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
450459
binder.op, resultType, operand);
451460
return success();
452461
});
453-
patterns.onOp(
454-
"BitwiseXor", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
455-
Torch::ValueTensorType resultType;
456-
Value lhs, rhs;
457-
std::string direction;
458-
if (binder.tensorOperands(lhs, rhs) ||
459-
binder.tensorResultType(resultType))
460-
return failure();
461-
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseXorTensorOp>(
462-
binder.op, resultType, lhs, rhs);
463-
return success();
464-
});
462+
patterns.onOp("BitwiseXor", 18,
463+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
464+
Torch::ValueTensorType resultType;
465+
Value lhs, rhs;
466+
std::string direction;
467+
if (binder.tensorOperands(lhs, rhs) ||
468+
binder.tensorResultType(resultType))
469+
return failure();
470+
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseXorTensorOp>(
471+
binder.op, resultType, lhs, rhs);
472+
return success();
473+
});
465474
patterns.onOp(
466475
"Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
467476
Torch::ValueTensorType resultType;
@@ -474,9 +483,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
474483

475484
dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
476485
if (dtypeIntTorch == -1) {
477-
return rewriter.notifyMatchFailure(
478-
binder.op,
479-
"unimplemented support for the given dtype conversion");
486+
auto message = llvm::formatv("unimplemented support for the given "
487+
"dtype conversion (onnx 'type' = {0})",
488+
dtypeIntOnnx);
489+
llvm::errs() << message << "\n";
490+
auto y = rewriter.notifyMatchFailure(binder.op, message);
491+
492+
return y;
480493
}
481494
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
482495
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -864,7 +877,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
864877
unsigned rank = *maybeRank;
865878

866879
SmallVector<int64_t> padding, strides, dilations, outputPadding;
867-
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
880+
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations,
881+
defaultOutputPadding;
868882
for (unsigned i = 0; i < rank - 2; i++) {
869883
defaultPadding.push_back(0);
870884
defaultStrides.push_back(1);
@@ -1018,30 +1032,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
10181032
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
10191033
Value rankVal = rewriter.create<Torch::ConstantIntOp>(
10201034
binder.getLoc(), rewriter.getType<Torch::IntType>(),
1021-
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
1022-
rank));
1035+
rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank));
10231036
Value zero = rewriter.create<Torch::ConstantIntOp>(
10241037
loc, rewriter.getI64IntegerAttr(0));
1025-
1038+
10261039
Value axisScalar = rewriter.create<Torch::AtenItemOp>(
10271040
binder.getLoc(), rewriter.getType<Torch::IntType>(), axisTensor);
1028-
Value isNegative =
1029-
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axisScalar, zero);
1030-
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
1031-
isNegative);
1041+
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
1042+
binder.getLoc(), axisScalar, zero);
1043+
isNegative =
1044+
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative);
10321045
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
10331046
binder.getLoc(), isNegative, rankVal);
10341047
Value dim = rewriter.create<Torch::AtenAddIntOp>(
10351048
binder.getLoc(), axisScalar, finalOffset);
10361049

1037-
Torch::BaseTensorType resultTensorType = resultType.cast<Torch::BaseTensorType>();
1050+
Torch::BaseTensorType resultTensorType =
1051+
resultType.cast<Torch::BaseTensorType>();
10381052
if (!resultTensorType.hasDtype()) {
10391053
return rewriter.notifyMatchFailure(
10401054
binder.op, "expected result type to have a dtype");
10411055
}
10421056
// resultTensorType.print(llvm::outs());
1043-
Value resultDType =
1044-
Torch::getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype());
1057+
Value resultDType = Torch::getDtypeIntValueForType(
1058+
rewriter, loc, resultTensorType.getDtype());
10451059

10461060
rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
10471061
binder.op, resultType, operand, dim, resultDType);

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,16 @@ func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !tor
332332
return %0 : !torch.vtensor<[3,4],f64>
333333
}
334334

335+
// CHECK-LABEL: @test_cast_FLOAT_to_BOOL
336+
func.func @test_cast_FLOAT_to_BOOL(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
337+
// CHECK: %[[INT:.*]] = torch.constant.int 11
338+
// CHECK: %[[NONE:.*]] = torch.constant.none
339+
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
340+
// CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],i1>
341+
%0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1>
342+
return %0 : !torch.vtensor<[3,4],i1>
343+
}
344+
335345
// CHECK-LABEL: @test_cast_FLOAT16_to_FLOAT
336346
func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
337347
// CHECK: %[[INT:.*]] = torch.constant.int 6

0 commit comments

Comments
 (0)