1010#include " torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
1111#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
1212#include " torch-mlir/Dialect/Torch/Utils/Utils.h"
13+ #include " llvm/Support/FormatVariadic.h"
1314
1415using namespace mlir ;
1516using namespace mlir ::torch;
1617using namespace mlir ::torch::onnx_c;
1718
static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
  // Maps an ONNX TensorProto.DataType integer to the corresponding torch
  // ScalarType integer, returning -1 when the dtype has no mapping here.
  //
  // Enum sources:
  //   ONNX:    TensorProto.DataType in onnx/onnx-ml.proto
  //   PyTorch: ScalarType in
  //     torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
  //
  // TODO: Add complete mapping.
  switch (dtypeIntOnnx) {
  case 1:
    return 6; // float -> ScalarType::Float
  case 7:
    return 4; // int64 -> ScalarType::Long (4); 5 would be Half, not int64
  case 9:
    return 11; // bool -> ScalarType::Bool
  case 10:
    return 5; // half -> ScalarType::Half
  case 11:
    return 7; // double -> ScalarType::Double
  case 16:
    return 15; // bfloat16 -> ScalarType::BFloat16
  default:
    return -1; // no known torch dtype for this ONNX dtype
  }
}
3948
@@ -415,30 +424,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
415424 }
416425 return success ();
417426 });
418- patterns.onOp (
419- " BitwiseAnd " , 18 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
420- Torch::ValueTensorType resultType;
421- Value lhs, rhs;
422- std::string direction;
423- if (binder.tensorOperands (lhs, rhs) ||
424- binder.tensorResultType (resultType))
425- return failure ();
426- rewriter.replaceOpWithNewOp <Torch::AtenBitwiseAndTensorOp>(
427- binder.op , resultType, lhs, rhs);
428- return success ();
429- });
430- patterns.onOp (
431- " BitwiseOr " , 18 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
432- Torch::ValueTensorType resultType;
433- Value lhs, rhs;
434- std::string direction;
435- if (binder.tensorOperands (lhs, rhs) ||
436- binder.tensorResultType (resultType))
437- return failure ();
438- rewriter.replaceOpWithNewOp <Torch::AtenBitwiseOrTensorOp>(
439- binder.op , resultType, lhs, rhs);
440- return success ();
441- });
427+ patterns.onOp (" BitwiseAnd " , 18 ,
428+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
429+ Torch::ValueTensorType resultType;
430+ Value lhs, rhs;
431+ std::string direction;
432+ if (binder.tensorOperands (lhs, rhs) ||
433+ binder.tensorResultType (resultType))
434+ return failure ();
435+ rewriter.replaceOpWithNewOp <Torch::AtenBitwiseAndTensorOp>(
436+ binder.op , resultType, lhs, rhs);
437+ return success ();
438+ });
439+ patterns.onOp (" BitwiseOr " , 18 ,
440+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
441+ Torch::ValueTensorType resultType;
442+ Value lhs, rhs;
443+ std::string direction;
444+ if (binder.tensorOperands (lhs, rhs) ||
445+ binder.tensorResultType (resultType))
446+ return failure ();
447+ rewriter.replaceOpWithNewOp <Torch::AtenBitwiseOrTensorOp>(
448+ binder.op , resultType, lhs, rhs);
449+ return success ();
450+ });
442451 patterns.onOp (" BitwiseNot" , 18 ,
443452 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
444453 Torch::ValueTensorType resultType;
@@ -450,18 +459,18 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
450459 binder.op , resultType, operand);
451460 return success ();
452461 });
453- patterns.onOp (
454- " BitwiseXor " , 18 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
455- Torch::ValueTensorType resultType;
456- Value lhs, rhs;
457- std::string direction;
458- if (binder.tensorOperands (lhs, rhs) ||
459- binder.tensorResultType (resultType))
460- return failure ();
461- rewriter.replaceOpWithNewOp <Torch::AtenBitwiseXorTensorOp>(
462- binder.op , resultType, lhs, rhs);
463- return success ();
464- });
462+ patterns.onOp (" BitwiseXor " , 18 ,
463+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
464+ Torch::ValueTensorType resultType;
465+ Value lhs, rhs;
466+ std::string direction;
467+ if (binder.tensorOperands (lhs, rhs) ||
468+ binder.tensorResultType (resultType))
469+ return failure ();
470+ rewriter.replaceOpWithNewOp <Torch::AtenBitwiseXorTensorOp>(
471+ binder.op , resultType, lhs, rhs);
472+ return success ();
473+ });
465474 patterns.onOp (
466475 " Cast" , 1 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
467476 Torch::ValueTensorType resultType;
@@ -474,9 +483,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
474483
475484 dtypeIntTorch = onnxDtypeIntToTorchDtypeInt (dtypeIntOnnx);
476485 if (dtypeIntTorch == -1 ) {
477- return rewriter.notifyMatchFailure (
478- binder.op ,
479- " unimplemented support for the given dtype conversion" );
486+ auto message = llvm::formatv (" unimplemented support for the given "
487+ " dtype conversion (onnx 'type' = {0})" ,
488+ dtypeIntOnnx);
489+ llvm::errs () << message << " \n " ;
490+ auto y = rewriter.notifyMatchFailure (binder.op , message);
491+
492+ return y;
480493 }
481494 Value constDtype = rewriter.create <Torch::ConstantIntOp>(
482495 binder.getLoc (), rewriter.getType <Torch::IntType>(),
@@ -864,7 +877,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
864877 unsigned rank = *maybeRank;
865878
866879 SmallVector<int64_t > padding, strides, dilations, outputPadding;
867- SmallVector<int64_t > defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
880+ SmallVector<int64_t > defaultPadding, defaultStrides, defaultDilations,
881+ defaultOutputPadding;
868882 for (unsigned i = 0 ; i < rank - 2 ; i++) {
869883 defaultPadding.push_back (0 );
870884 defaultStrides.push_back (1 );
@@ -1018,30 +1032,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
10181032 cast<Torch::ValueTensorType>(operand.getType ()).getSizes ().size ();
10191033 Value rankVal = rewriter.create <Torch::ConstantIntOp>(
10201034 binder.getLoc (), rewriter.getType <Torch::IntType>(),
1021- rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
1022- rank));
1035+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), rank));
10231036 Value zero = rewriter.create <Torch::ConstantIntOp>(
10241037 loc, rewriter.getI64IntegerAttr (0 ));
1025-
1038+
10261039 Value axisScalar = rewriter.create <Torch::AtenItemOp>(
10271040 binder.getLoc (), rewriter.getType <Torch::IntType>(), axisTensor);
1028- Value isNegative =
1029- rewriter. create <Torch::AtenLtIntOp>( binder.getLoc (), axisScalar, zero);
1030- isNegative = rewriter. create <Torch::AtenIntBoolOp>(binder. getLoc (),
1031- isNegative);
1041+ Value isNegative = rewriter. create <Torch::AtenLtIntOp>(
1042+ binder.getLoc (), axisScalar, zero);
1043+ isNegative =
1044+ rewriter. create <Torch::AtenIntBoolOp>(binder. getLoc (), isNegative);
10321045 Value finalOffset = rewriter.create <Torch::AtenMulIntOp>(
10331046 binder.getLoc (), isNegative, rankVal);
10341047 Value dim = rewriter.create <Torch::AtenAddIntOp>(
10351048 binder.getLoc (), axisScalar, finalOffset);
10361049
1037- Torch::BaseTensorType resultTensorType = resultType.cast <Torch::BaseTensorType>();
1050+ Torch::BaseTensorType resultTensorType =
1051+ resultType.cast <Torch::BaseTensorType>();
10381052 if (!resultTensorType.hasDtype ()) {
10391053 return rewriter.notifyMatchFailure (
10401054 binder.op , " expected result type to have a dtype" );
10411055 }
10421056 // resultTensorType.print(llvm::outs());
1043- Value resultDType =
1044- Torch::getDtypeIntValueForType ( rewriter, loc, resultTensorType.getDtype ());
1057+ Value resultDType = Torch::getDtypeIntValueForType (
1058+ rewriter, loc, resultTensorType.getDtype ());
10451059
10461060 rewriter.replaceOpWithNewOp <Torch::AtenCumsumOp>(
10471061 binder.op , resultType, operand, dim, resultDType);
0 commit comments