Commit 2ef2283
[torch] torch.dequantize for per channel tensors to linalg (#2769)
Adds a lowering for dequantization of per-channel quantized tensors from the `torch` dialect to a linalg decomposition, verified with a numerical `torch` e2e test.
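For readers unfamiliar with the scheme: per-channel dequantization subtracts a per-channel zero point and multiplies by a per-channel scale, with both parameters broadcast along the quantization axis. A minimal NumPy sketch of the semantics the lowering implements (illustrative, not part of the commit):

import numpy as np

def dequantize_per_channel(x, scales, zero_points, axis):
    # y = (x - zero_point) * scale, broadcast along `axis`; mirrors the
    # linalg payload below (integer subtract, int-to-float, float multiply).
    shape = [1] * x.ndim
    shape[axis] = -1
    diff = x.astype(np.int32) - zero_points.astype(np.int32).reshape(shape)
    return diff.astype(np.float32) * scales.astype(np.float32).reshape(shape)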
1 parent 0aed231 commit 2ef2283

File tree

8 files changed: +258 -8 lines changed
include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 53 additions & 0 deletions
@@ -14465,6 +14465,33 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
+def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$scales,
+    AnyTorchTensorType:$zero_points,
+    Torch_IntType:$axis,
+    Torch_IntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
   AllowsTypeRefinement,
   HasValueSemantics,

@@ -14560,6 +14587,32 @@ def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
   }];
 }
 
+def Torch_Aten_MakePerChannelQuantizedTensorOp : Torch_Op<"aten._make_per_channel_quantized_tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$scale,
+    AnyTorchTensorType:$zero_point,
+    Torch_IntType:$axis
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_MakePerChannelQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void Aten_MakePerChannelQuantizedTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [
   AllowsTypeRefinement,
   HasValueSemantics,
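Both generated ops mirror existing PyTorch quantization APIs. A quick eager-mode illustration of the underlying aten signatures (hedged: the zero-point dtypes accepted may vary by PyTorch version):

import torch

x = torch.randn(3, 4)
scales = torch.rand(4)
zero_points = torch.zeros(4, dtype=torch.int64)

# aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)
qx = torch.quantize_per_channel(x, scales, zero_points, axis=1, dtype=torch.qint8)

# aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)
qy = torch._make_per_channel_quantized_tensor(qx.int_repr(), scales, zero_points, axis=1)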

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 106 additions & 8 deletions
@@ -1344,7 +1344,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     auto makeQTensor =
         qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
     if (!makeQTensor) {
-      op->emitError(
+      op->emitWarning(
           "unimplemented: dequantizing tensor of unknown scale / zero-point");
       return nullptr;
     }

@@ -2221,16 +2221,109 @@ class ConvertAtenIntReprOp : public OpConversionPattern<AtenIntReprOp> {
 } // namespace
 
 namespace {
-class ConvertMakePerTensorQuantizedTensorOp
-    : public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
+class ConvertDequantizePerChannel
+    : public OpConversionPattern<AtenDequantizeSelfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    auto loc = op.getLoc();
+    auto qoperand = op.getOperand();
+    auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
+    if (!make) {
+      llvm::errs() << "Did not find make per channel\n";
+      return rewriter.notifyMatchFailure(op, "did not find per channel qint");
+    }
+
+    auto converter = getTypeConverter();
+    auto operand = make.getOperand(0);
+    auto scale = make.getScale();
+    auto zeropoint = make.getZeroPoint();
+    auto axis = make.getAxis();
+
+    IntegerAttr axisAttr;
+    if (!matchPattern(axis, m_Constant(&axisAttr))) {
+      return failure();
+    }
+
+    auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype();
+    auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype();
+    operand = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(operand.getType()), operand);
+    scale = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(scale.getType()), scale);
+    zeropoint = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
+
+    auto resultType = converter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+
+    llvm::SmallVector<Value> dynSizes;
+    for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
+      if (ShapedType::isDynamic(dim)) {
+        dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, index));
+      }
+    }
+
+    llvm::SmallVector<utils::IteratorType> iterators(
+        resultType.getRank(), utils::IteratorType::parallel);
+    llvm::SmallVector<AffineMap> maps(
+        4, {rewriter.getMultiDimIdentityMap(resultType.getRank())});
+    auto broadcastMap = AffineMap::get(
+        resultType.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext());
+    maps[1] = broadcastMap;
+    maps[2] = broadcastMap;
+
+    auto empty =
+        rewriter.create<tensor::EmptyOp>(op.getLoc(), resultType, dynSizes);
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, resultType, ValueRange{operand, scale, zeropoint},
+        ValueRange{empty}, maps, iterators,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          Value operand = args[0];
+          Value scale = args[1];
+          Value zeropoint = args[2];
+          if (operandDTy.isUnsignedInteger(8)) {
+            operand = b.create<arith::ExtUIOp>(loc, b.getI32Type(), operand);
+          } else if (operandDTy.isSignedInteger(8)) {
+            operand = b.create<arith::ExtSIOp>(loc, b.getI32Type(), operand);
+          }
+
+          if (zeropointDTy.isUnsignedInteger(8)) {
+            zeropoint =
+                b.create<arith::ExtUIOp>(loc, b.getI32Type(), zeropoint);
+          } else if (zeropointDTy.isSignedInteger(8)) {
+            zeropoint =
+                b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
+          }
+
+          Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);
+          Value fp =
+              rewriter.create<arith::SIToFPOp>(loc, args[3].getType(), sub);
+          Value mul = rewriter.create<arith::MulFOp>(loc, fp, scale);
+          b.create<linalg::YieldOp>(loc, mul);
+        });
+    rewriter.replaceOp(op, linalgOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+
+template <typename OpTy>
+class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  using OpAdaptor = typename OpTy::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(OpTy op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto converter = this->getTypeConverter();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        converter->convertType(op->getResult(0).getType()));
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                 adaptor.getSelf());
     return success();

@@ -2283,6 +2376,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   target.addIllegalOp<TensorStaticInfoCastOp>();
   patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
   target.addIllegalOp<AtenIntReprOp>();
-  patterns.add<ConvertMakePerTensorQuantizedTensorOp>(typeConverter, context);
+  patterns.add<ConvertCastEquivalentOp<Aten_MakePerChannelQuantizedTensorOp>>(
+      typeConverter, context);
+  target.addIllegalOp<Aten_MakePerChannelQuantizedTensorOp>();
+  patterns.add<ConvertCastEquivalentOp<Aten_MakePerTensorQuantizedTensorOp>>(
+      typeConverter, context);
   target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
+  patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
 }
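To unpack the indexing maps: for a rank-2 result with axis = 1, the operand and init use the identity map (d0, d1) -> (d0, d1), while scale and zero point use the broadcast map (d0, d1) -> (d1). A rough Python analogue of the generated linalg.generic (illustrative sketch only):

import numpy as np

def generic_dequant_rank2(x, scale, zp):
    # Iterate the parallel dims; scale/zp are indexed by the axis dim only.
    out = np.empty(x.shape, dtype=np.float32)
    for d0 in range(x.shape[0]):
        for d1 in range(x.shape[1]):
            sub = np.int32(x[d0, d1]) - np.int32(zp[d1])           # arith.subi
            out[d0, d1] = np.float32(sub) * np.float32(scale[d1])  # sitofp + mulf
    return out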

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 32 additions & 0 deletions
@@ -6549,6 +6549,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -6565,6 +6569,10 @@
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -12632,6 +12640,9 @@
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
+"    return %arg4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
 "    return %arg3 : !torch.int\n"
 "  }\n"

@@ -12664,6 +12675,27 @@
 "  }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int) -> !torch.int {\n"
+"    %int14 = torch.constant.int 14\n"
+"    %int12 = torch.constant.int 12\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int13 = torch.constant.int 13\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int13 : !torch.int\n"
+"    } else {\n"
+"      %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int12 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %int14 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %4 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n"
 "    %int14 = torch.constant.int 14\n"
 "    %int12 = torch.constant.int 12\n"

projects/ltc/csrc/base_lazy_backend/shape_inference.cpp

Lines changed: 20 additions & 0 deletions
@@ -39,6 +39,20 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape>
+compute_shape__make_per_channel_quantized_tensor(const at::Tensor &self,
+                                                 const at::Tensor &scale,
+                                                 const at::Tensor &zero_point,
+                                                 int64_t axis) {
+  if (self.scalar_type() == at::kChar)
+    return {Shape(at::kQInt8, self.sizes().vec())};
+  if (self.scalar_type() == at::kByte)
+    return {Shape(at::kQUInt8, self.sizes().vec())};
+  if (self.scalar_type() == at::kInt)
+    return {Shape(at::kQInt32, self.sizes().vec())};
+  assert(false);
+}
+
 std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(
     const at::Tensor &self, double scale, int64_t zero_point) {
   if (self.scalar_type() == at::kChar)

@@ -75,6 +89,12 @@ std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
   return {Shape(at::kBool, self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(
+    const at::Tensor &self, const at::Tensor &scales,
+    const at::Tensor &zero_points, int64_t axis, at::ScalarType dtype) {
+  return {Shape(dtype, self.sizes().vec())};
+}
+
 std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
     const at::Tensor& self, at::IntArrayRef kernel_size,
     at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -313,6 +313,7 @@
     "GroupNormNoWeightAndBiasModule_basic",
 
     # Dynamo does not support tracing quantized tensors
+    "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "AtenMmQuint8_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 17 additions & 0 deletions
@@ -251,6 +251,9 @@ def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]:
 def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
     return upstream_shape_functions.unary(self)

@@ -263,6 +266,9 @@ def aten〇dequantize〇tensor〡shape(qtensor: List[int]) -> List[int]:
 def aten〇int_repr〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇_make_per_channel_quantized_tensor〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]:
     return upstream_shape_functions.unary(self)

@@ -4280,6 +4286,9 @@ def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int
     return a_dtype
 
 
+def aten〇quantize_per_channel〡dtype(self_rank_dtype: Tuple[int, int], scales_rank_dtype: Tuple[int, int], zero_points_rank_dtype: Tuple[int, int], axis: int, dtype: int) -> int:
+    return dtype
+
 def aten〇quantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int:
     return dtype

@@ -4297,6 +4306,14 @@ def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
         return torch.int8
     return torch.int32
 
+def aten〇_make_per_channel_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if (self_dtype == torch.uint8):
+        return torch.quint8
+    if (self_dtype == torch.int8):
+        return torch.qint8
+    return torch.qint32
+
 def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int:
     self_rank, self_dtype = self_rank_dtype
     if (self_dtype == torch.uint8):
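These Python functions are the source from which the AbstractInterpLibrary.cpp changes above are regenerated. Since upstream_shape_functions.unary passes the input shape through unchanged, the new shape functions are pure passthroughs; for example (illustrative, assuming the generator module's definitions are in scope):

# A [3, 4] tensor quantized per channel keeps its shape; only the dtype changes.
assert aten〇quantize_per_channel〡shape([3, 4], [4], [4], 1, 12) == [3, 4]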

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions
@@ -820,10 +820,12 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
 
     # quantized ops
+    emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
     emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)")
     emit("aten::dequantize.self : (Tensor) -> (Tensor)")
     emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
     emit("aten::int_repr : (Tensor) -> (Tensor)")
+    emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)")
     emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
 
     # ==========================================================================

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 27 additions & 0 deletions
@@ -4328,6 +4328,33 @@ def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ElementwiseDequantizePerChannelModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.int8, True),
+        ([4], torch.int8, True),
+        ([4], torch.float, True),
+    ])
+    def forward(self, x, zeropoint, scale):
+        qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1)
+        qx = torch.dequantize(qx)
+        return qx
+
+@register_test_case(module_factory=lambda: ElementwiseDequantizePerChannelModule())
+def ElementwiseDequantizePerChannelModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(3, 4, low=-128, high=127).to(torch.int8),
+        tu.randint(4, low=-128, high=127).to(torch.int8),
+        tu.rand(4)
+    )
+
+# ==============================================================================
+
 class GluStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
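The new test can be sanity-checked outside the e2e harness; a hedged eager-mode repro of the expected numerics (assumes a PyTorch build with quantized tensor support and int8 zero points accepted, as in the test above):

import torch

x = torch.randint(-128, 127, (3, 4), dtype=torch.int8)
zeropoint = torch.randint(-128, 127, (4,), dtype=torch.int8)
scale = torch.rand(4)

qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1)
out = torch.dequantize(qx)

# Matches the linalg decomposition: (x - zero_point) * scale along axis 1.
expected = (x.to(torch.int32) - zeropoint.to(torch.int32)).to(torch.float32) * scale
assert torch.allclose(out, expected)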
