@@ -1344,7 +1344,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     auto makeQTensor =
         qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
     if (!makeQTensor) {
-      op->emitError(
+      op->emitWarning(
           "unimplemented: dequantizing tensor of unknown scale / zero-point");
       return nullptr;
     }
@@ -2221,16 +2221,109 @@ class ConvertAtenIntReprOp : public OpConversionPattern<AtenIntReprOp> {
 } // namespace
 
 namespace {
-class ConvertMakePerTensorQuantizedTensorOp
-    : public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
+class ConvertDequantizePerChannel
+    : public OpConversionPattern<AtenDequantizeSelfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    auto loc = op.getLoc();
+    auto qoperand = op.getOperand();
+    auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
+    if (!make) {
+      llvm::errs() << "Did not find make per channel\n";
+      return rewriter.notifyMatchFailure(op, "did not find per channel qint");
+    }
+
+    auto converter = getTypeConverter();
+    auto operand = make.getOperand(0);
+    auto scale = make.getScale();
+    auto zeropoint = make.getZeroPoint();
+    auto axis = make.getAxis();
+
+    IntegerAttr axisAttr;
+    if (!matchPattern(axis, m_Constant(&axisAttr))) {
+      return failure();
+    }
+
+    auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype();
+    auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype();
+    operand = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(operand.getType()), operand);
+    scale = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(scale.getType()), scale);
+    zeropoint = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
+
+    auto resultType = converter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+
+    llvm::SmallVector<Value> dynSizes;
+    for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
+      if (ShapedType::isDynamic(dim)) {
+        dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, index));
+      }
+    }
+
+    llvm::SmallVector<utils::IteratorType> iterators(
+        resultType.getRank(), utils::IteratorType::parallel);
+    llvm::SmallVector<AffineMap> maps(
+        4, {rewriter.getMultiDimIdentityMap(resultType.getRank())});
+    auto broadcastMap = AffineMap::get(
+        resultType.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext());
+    maps[1] = broadcastMap;
+    maps[2] = broadcastMap;
+
+    auto empty =
+        rewriter.create<tensor::EmptyOp>(op.getLoc(), resultType, dynSizes);
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, resultType, ValueRange{operand, scale, zeropoint},
+        ValueRange{empty}, maps, iterators,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          Value operand = args[0];
+          Value scale = args[1];
+          Value zeropoint = args[2];
+          if (operandDTy.isUnsignedInteger(8)) {
+            operand = b.create<arith::ExtUIOp>(loc, b.getI32Type(), operand);
+          } else if (operandDTy.isSignedInteger(8)) {
+            operand = b.create<arith::ExtSIOp>(loc, b.getI32Type(), operand);
+          }
+
+          if (zeropointDTy.isUnsignedInteger(8)) {
+            zeropoint =
+                b.create<arith::ExtUIOp>(loc, b.getI32Type(), zeropoint);
+          } else if (zeropointDTy.isSignedInteger(8)) {
+            zeropoint =
+                b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
+          }
+
+          Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);
+          Value fp =
+              rewriter.create<arith::SIToFPOp>(loc, args[3].getType(), sub);
+          Value mul = rewriter.create<arith::MulFOp>(loc, fp, scale);
+          b.create<linalg::YieldOp>(loc, mul);
+        });
+    rewriter.replaceOp(op, linalgOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+
+template <typename OpTy>
+class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  using OpAdaptor = typename OpTy::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(OpTy op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto converter = this->getTypeConverter();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        converter->convertType(op->getResult(0).getType()));
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                 adaptor.getSelf());
     return success();
@@ -2283,6 +2376,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   target.addIllegalOp<TensorStaticInfoCastOp>();
   patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
   target.addIllegalOp<AtenIntReprOp>();
-  patterns.add<ConvertMakePerTensorQuantizedTensorOp>(typeConverter, context);
+  patterns.add<ConvertCastEquivalentOp<Aten_MakePerChannelQuantizedTensorOp>>(
+      typeConverter, context);
+  target.addIllegalOp<Aten_MakePerChannelQuantizedTensorOp>();
+  patterns.add<ConvertCastEquivalentOp<Aten_MakePerTensorQuantizedTensorOp>>(
+      typeConverter, context);
   target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
+  patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
 }
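
Note: the scalar payload built inside the new linalg.generic above implements the standard per-channel dequantization formula, out = (int32(q) - int32(zero_point[axis])) * scale[axis], with scale and zero_point broadcast along the quantization axis via broadcastMap. As a plain reference, a minimal C++ sketch of that element-wise math follows; the helper name, the row-major [channels x elemsPerChannel] layout, and the assumption that the quantization axis is the leading dimension are illustrative only and are not part of the patch.

#include <cstdint>
#include <vector>

// Hypothetical reference helper (not part of the patch): dequantizes a
// row-major [channels x elemsPerChannel] signed-int8 tensor per channel,
// mirroring the ops emitted in the linalg.generic body.
std::vector<float> dequantizePerChannel(const std::vector<int8_t> &q,
                                        const std::vector<float> &scale,
                                        const std::vector<int32_t> &zeroPoint,
                                        size_t channels,
                                        size_t elemsPerChannel) {
  std::vector<float> out(q.size());
  for (size_t c = 0; c < channels; ++c) {
    for (size_t i = 0; i < elemsPerChannel; ++i) {
      size_t idx = c * elemsPerChannel + i;
      int32_t widened = static_cast<int32_t>(q[idx]);    // arith.extsi
      int32_t shifted = widened - zeroPoint[c];          // arith.subi
      out[idx] = static_cast<float>(shifted) * scale[c]; // sitofp + mulf
    }
  }
  return out;
}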