From f3aefca51d9bf3d136a1e2afbad1593334b178a6 Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Mon, 13 Oct 2025 11:07:20 +0000 Subject: [PATCH 1/9] Add NPU (Ascend) backend support for INT4 weight-only quantization workflow --- .../int4/test_int4_plain_int32_tensor_npu.py | 113 ++++++++ torchao/quantization/__init__.py | 1 + torchao/quantization/quant_api.py | 15 +- .../quantize_/workflows/__init__.py | 4 + .../int4/int4_plain_int32_tensor_npu.py | 243 ++++++++++++++++++ 5 files changed, 372 insertions(+), 4 deletions(-) create mode 100644 test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py create mode 100644 torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py new file mode 100644 index 0000000000..a5e91952c2 --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +import tempfile +from packaging import version + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import ( + Int4WeightOnlyConfig, + quantize_, +) +from torchao.quantization.quantize_.common import SupportsActivationPreScaling +from torchao.quantization.utils import compute_error +from torchao.utils import ( + torch_version_at_least, +) + +try: + import torch_npu +except ImportError: + torch_npu = None + + +def get_config(group_size): + return Int4WeightOnlyConfig( + group_size=group_size, + int4_packing_format="plain_int32", + ) + + +@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+") +@unittest.skipIf(torch_npu is None, "torch_npu is not available") +@unittest.skipIf(not torch_npu.npu.is_available(), "NPU not available") +@unittest.skipIf( + version.parse(torch_npu.__version__) < version.parse("2.7.1rc1"), + "Need torch_npu 2.7.1rc1+", +) +class Int4PlainInt32TensorNPU(TestCase): + + @parametrize("device", ["npu"]) + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 512, 128), + ((2, 32, 128), 256, 128), + ], + ) + @parametrize("dtype", [torch.float16, torch.bfloat16]) + @parametrize("group_size", [32, 64]) + def test_linear(self, device, sizes, dtype, group_size): + M, N, K = sizes + input = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) + orig_output = linear(input) + quantize_(linear, get_config(group_size)) + quantized_output = linear(input) + self.assertTrue(compute_error(orig_output, quantized_output) > 10) + + @parametrize("device", ["npu"]) + @parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_module_path(self, device, dtype): + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + quantize_(linear, get_config(group_size=64)) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + @parametrize("device", ["npu"]) + @parametrize("dtype", [torch.float16, 
torch.bfloat16]) + def test_activation_prescaling(self, device, dtype): + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, get_config(64)) + qw = linear.weight + assert isinstance( + qw, SupportsActivationPreScaling + ), "Expected int4 tensor supports activation prescaling" + assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" + _ACT_PRE_SCALE = 2 + qw.act_pre_scale = _ACT_PRE_SCALE + quantized = linear(input) + + # making sure activation pre scaling is successfully applied to the activation + self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10) + + +instantiate_parametrized_tests(Int4PlainInt32TensorNPU) + +if __name__ == "__main__": + run_tests() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index c8774e9426..090298c10a 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -94,6 +94,7 @@ Int4MarlinSparseTensor, Int4OpaqueTensor, Int4PlainInt32Tensor, + Int4PlainInt32TensorNPU, Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3bda8f91ab..6abc7819a7 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -78,6 +78,7 @@ Int4OpaqueTensor, Int4PackingFormat, Int4PlainInt32Tensor, + Int4PlainInt32TensorNPU, Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, @@ -1210,10 +1211,16 @@ def _int4_weight_only_quantize_tensor(weight, config): ) return new_weight elif int4_packing_format == Int4PackingFormat.PLAIN_INT32: - new_weight = Int4PlainInt32Tensor.from_hp( - weight, - block_size, - ) + if weight.device.type == "npu": + new_weight = Int4PlainInt32TensorNPU.from_hp( + weight, + block_size, + ) + else: + new_weight = Int4PlainInt32Tensor.from_hp( + weight, + block_size, + ) return new_weight elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE: new_weight = Int4MarlinSparseTensor.from_hp( diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 4307637f8e..ea05afd733 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -13,6 +13,9 @@ from .int4.int4_plain_int32_tensor import ( Int4PlainInt32Tensor, ) +from .int4.int4_plain_int32_tensor_npu import ( + Int4PlainInt32TensorNPU, +) from .int4.int4_preshuffled_tensor import ( Int4PreshuffledTensor, ) @@ -36,6 +39,7 @@ "Int4PreshuffledTensor", "Int4MarlinSparseTensor", "Int4PlainInt32Tensor", + "Int4PlainInt32TensorNPU", "Int4TilePackedTo4dTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py new file mode 100644 index 0000000000..6bca2cd345 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List, Optional + +import torch + +from torchao.quantization.quant_primitives import ( + MappingType, + choose_qparams_affine, + quantize_affine, +) +from torchao.utils import ( + TorchAOBaseTensor, +) + +__all__ = ["Int4PlainInt32TensorNPU"] + +aten = torch.ops.aten + +try: + import torch_npu +except ImportError: + torch_npu = None + + +class Int4PlainInt32TensorNPU(TorchAOBaseTensor): + """ + int4 weight-only quantization on Ascend NPU backend (groupwise quantization only) + + Tensor Attributes: + qdata: (N, K/8), packed int4 weight, the data type is int32 here with 8*int4, the original dtype can be float16 or bfloat16 + scale: (K/group_size, N), dtype is the same as the original Tensor type (float16 or bfloat16) + zero_point: (K/group_size, N), dtype is the same as the original Tensor type (float16 or bfloat16) + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity + shape: shape of the original Tensor + + Optional Tensor Data Attributes: + act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present, + we'll multiply activation Tensor with act_pre_scale before applying dynamic + quantization to activation or running quantized mm op + + """ + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["block_size", "shape"] + optional_tensor_data_names = ["act_pre_scale"] + + def __new__( + cls, + qdata, + scale, + zero_point, + block_size, + shape, + act_pre_scale: Optional[torch.Tensor] = None, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata, + scale, + zero_point, + block_size, + shape, + act_pre_scale: Optional[torch.Tensor] = None, + ): + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.block_size = block_size + self.act_pre_scale = act_pre_scale + + def _quantization_type(self): + s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + if self.act_pre_scale is not None: + s += f", act_pre_scale.shape={self.act_pre_scale.shape}" + return s + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + ): + if torch_npu is None: + raise ImportError("Requires torch_npu but it is not installed") + + assert w.ndim == 2 and w.device.type == "npu", ( + f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}" + ) + assert len(block_size) == w.ndim + assert w.dtype in [torch.float16, torch.bfloat16], ( + f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" + ) + + original_shape = w.shape + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.int32 + quant_min = -8 + quant_max = 7 + eps = 1e-6 + scale_dtype = w.dtype + zero_point_dtype = w.dtype + + scale, zero_point = choose_qparams_affine( + w, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + ) + + int_data = quantize_affine( + w, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + ) + + assert int_data.dtype == torch.int32, ( + f"torch_npu.npu_convert_weight_to_int4pack expects `int32` dtype" + ) + + assert int_data.shape[-1] % 8 == 0, ( + f"torch_npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" + ) + + packed_weight = torch_npu.npu_convert_weight_to_int4pack( + 
int_data.contiguous(), 0 + ) + + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + + return Int4PlainInt32TensorNPU( + packed_weight, + scale.transpose(0, 1).contiguous(), + zero_point.transpose(0, 1).contiguous(), + block_size, + original_shape, + act_pre_scale=None, + ) + + +implements = Int4PlainInt32TensorNPU.implements +implements_torch_function = Int4PlainInt32TensorNPU.implements_torch_function + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + if torch_npu is None: + raise ImportError("Requires torch_npu but it is not installed") + + assert input_tensor.device.type == "npu", ( + f"For NPU device only but got: {input_tensor.device.type}" + ) + assert isinstance(weight_tensor, Int4PlainInt32TensorNPU), ( + f"Expected weight_tensor to be Int4PlainInt32NPUTensor, got: {type(weight_tensor)}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" + ) + + if weight_tensor.act_pre_scale is not None: + input_tensor = input_tensor * weight_tensor.act_pre_scale + + act_mat = input_tensor + packed_weight = weight_tensor.qdata + scale = weight_tensor.scale + zero_point = weight_tensor.zero_point + + orig_act_size = act_mat.shape + orig_dtype = act_mat.dtype + + # dtype alignment + if act_mat.dtype == torch.float16: + scale = scale.to(torch.float16) + zero_point = zero_point.to(torch.float16) + if bias is not None: + bias = bias.to(torch.float16) + elif act_mat.dtype == torch.bfloat16: + scale = scale.to(torch.bfloat16) + zero_point = zero_point.to(torch.bfloat16) + if bias is not None: + bias = bias.to(torch.float32) + + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + + y = torch_npu.npu_weight_quant_batchmatmul( + x=act_mat, + weight=packed_weight.contiguous().transpose(-1, -2), + antiquant_scale=scale, + antiquant_offset=zero_point, + antiquant_group_size=groupsize, + bias=bias, + ) + + # remove out_feature padding + assert weight_tensor.ndim == 2 + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) + + +Int4PlainInt32TensorNPU.__module__ = "torchao.quantization" + +# Allow a model with Int4PlainInt32TensorNPU weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4PlainInt32TensorNPU]) From 68eea614160a0adae4d797282f13fa2e014a72bb Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Tue, 14 Oct 2025 09:51:48 +0000 Subject: [PATCH 2/9] use torch.ops.npu prefix and drop redundant torch_npu import --- .../workflows/int4/int4_plain_int32_tensor_npu.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py index 6bca2cd345..3c86f0e805 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py +++ 
b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py @@ -21,11 +21,6 @@ aten = torch.ops.aten -try: - import torch_npu -except ImportError: - torch_npu = None - class Int4PlainInt32TensorNPU(TorchAOBaseTensor): """ @@ -93,9 +88,6 @@ def from_hp( w: torch.Tensor, block_size: List[int], ): - if torch_npu is None: - raise ImportError("Requires torch_npu but it is not installed") - assert w.ndim == 2 and w.device.type == "npu", ( f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}" ) @@ -143,7 +135,7 @@ def from_hp( f"torch_npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" ) - packed_weight = torch_npu.npu_convert_weight_to_int4pack( + packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack( int_data.contiguous(), 0 ) @@ -174,9 +166,6 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) - if torch_npu is None: - raise ImportError("Requires torch_npu but it is not installed") - assert input_tensor.device.type == "npu", ( f"For NPU device only but got: {input_tensor.device.type}" ) @@ -219,7 +208,7 @@ def _(func, types, args, kwargs): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch_npu.npu_weight_quant_batchmatmul( + y = torch.ops.npu.npu_weight_quant_batchmatmul( x=act_mat, weight=packed_weight.contiguous().transpose(-1, -2), antiquant_scale=scale, From 06c77d1fc616f13ee7195c714f2aab848e7677cf Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Wed, 15 Oct 2025 02:29:44 +0000 Subject: [PATCH 3/9] Modify test file and update comments --- .../int4/test_int4_plain_int32_tensor_npu.py | 12 +++--------- .../workflows/int4/int4_plain_int32_tensor_npu.py | 4 ++-- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py index a5e91952c2..f21977310d 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py @@ -26,11 +26,6 @@ torch_version_at_least, ) -try: - import torch_npu -except ImportError: - torch_npu = None - def get_config(group_size): return Int4WeightOnlyConfig( @@ -40,11 +35,10 @@ def get_config(group_size): @unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+") -@unittest.skipIf(torch_npu is None, "torch_npu is not available") -@unittest.skipIf(not torch_npu.npu.is_available(), "NPU not available") @unittest.skipIf( - version.parse(torch_npu.__version__) < version.parse("2.7.1rc1"), - "Need torch_npu 2.7.1rc1+", + torch.accelerator.current_accelerator(True).type == "npu" + and torch.accelerator.is_available(), + "NPU not available", ) class Int4PlainInt32TensorNPU(TestCase): diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py index 3c86f0e805..80ddcd9619 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py @@ -128,11 +128,11 @@ def from_hp( ) assert int_data.dtype == torch.int32, ( - f"torch_npu.npu_convert_weight_to_int4pack expects `int32` dtype" + f"torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype" ) assert int_data.shape[-1] % 8 == 0, ( - 
f"torch_npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" + f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" ) packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack( From ea2aa7a861e97aee82d2cfef3a8da2211c1ce72c Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Tue, 21 Oct 2025 08:08:46 +0000 Subject: [PATCH 4/9] add: merge NPU(Ascend) backend logic in Int4PlainInt32Tensor subclass --- .../int4/test_int4_plain_int32_tensor.py | 71 ++++- .../int4/test_int4_plain_int32_tensor_npu.py | 107 ------- torchao/quantization/__init__.py | 1 - torchao/quantization/quant_api.py | 15 +- .../quantize_/workflows/__init__.py | 4 - .../workflows/int4/int4_plain_int32_tensor.py | 275 ++++++++++++++---- .../int4/int4_plain_int32_tensor_npu.py | 232 --------------- 7 files changed, 296 insertions(+), 409 deletions(-) delete mode 100644 test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py delete mode 100644 torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py index becb44a5e0..d8f6640c8d 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py @@ -35,7 +35,7 @@ def get_config(group_size): @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") @unittest.skipIf(not torch.xpu.is_available(), "XPU not available") -class Int4PlainInt32Tensor(TestCase): +class Int4PlainInt32TensorXPU(TestCase): @parametrize( "sizes", [ @@ -98,8 +98,75 @@ def test_activation_prescaling(self): self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20) -instantiate_parametrized_tests(Int4PlainInt32Tensor) +@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+") +@unittest.skipIf( + torch.accelerator.current_accelerator().type != "npu" + or not torch.accelerator.is_available(), + "NPU not available", +) +class Int4PlainInt32TensorNPU(TestCase): + + @parametrize("device", ["npu"]) + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 512, 128), + ((2, 32, 128), 256, 128), + ], + ) + @parametrize("dtype", [torch.float16, torch.bfloat16]) + @parametrize("group_size", [32, 64]) + def test_linear(self, device, sizes, dtype, group_size): + M, N, K = sizes + input = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) + orig_output = linear(input) + quantize_(linear, get_config(group_size)) + quantized_output = linear(input) + self.assertTrue(compute_error(orig_output, quantized_output) > 10) + + @parametrize("device", ["npu"]) + @parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_module_path(self, device, dtype): + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + quantize_(linear, get_config(group_size=64)) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + @parametrize("device", ["npu"]) + @parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_activation_prescaling(self, device, dtype): + input = torch.randn(1, 128, 
dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, get_config(64)) + qw = linear.weight + assert isinstance( + qw, SupportsActivationPreScaling + ), "Expected int4 tensor supports activation prescaling" + assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" + _ACT_PRE_SCALE = 2 + qw.act_pre_scale = _ACT_PRE_SCALE + quantized = linear(input) + + # making sure activation pre scaling is successfully applied to the activation + self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10) + +instantiate_parametrized_tests(Int4PlainInt32TensorXPU) +instantiate_parametrized_tests(Int4PlainInt32TensorNPU) if __name__ == "__main__": run_tests() diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py deleted file mode 100644 index f21977310d..0000000000 --- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -import tempfile -from packaging import version - -import torch -from torch.testing._internal.common_utils import ( - TestCase, - instantiate_parametrized_tests, - parametrize, - run_tests, -) - -from torchao.quantization import ( - Int4WeightOnlyConfig, - quantize_, -) -from torchao.quantization.quantize_.common import SupportsActivationPreScaling -from torchao.quantization.utils import compute_error -from torchao.utils import ( - torch_version_at_least, -) - - -def get_config(group_size): - return Int4WeightOnlyConfig( - group_size=group_size, - int4_packing_format="plain_int32", - ) - - -@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+") -@unittest.skipIf( - torch.accelerator.current_accelerator(True).type == "npu" - and torch.accelerator.is_available(), - "NPU not available", -) -class Int4PlainInt32TensorNPU(TestCase): - - @parametrize("device", ["npu"]) - @parametrize( - "sizes", - [ - ((128,), 256, 128), - ((32, 128), 512, 128), - ((2, 32, 128), 256, 128), - ], - ) - @parametrize("dtype", [torch.float16, torch.bfloat16]) - @parametrize("group_size", [32, 64]) - def test_linear(self, device, sizes, dtype, group_size): - M, N, K = sizes - input = torch.randn(*M, K, dtype=dtype, device=device) - linear = torch.nn.Linear(K, N, dtype=dtype, device=device) - orig_output = linear(input) - quantize_(linear, get_config(group_size)) - quantized_output = linear(input) - self.assertTrue(compute_error(orig_output, quantized_output) > 10) - - @parametrize("device", ["npu"]) - @parametrize("dtype", [torch.float16, torch.bfloat16]) - def test_module_path(self, device, dtype): - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - quantize_(linear, get_config(group_size=64)) - self.assertEqual( - str(type(linear.weight)), - "", - ) - - with tempfile.NamedTemporaryFile() as f: - torch.save(linear.state_dict(), f) - f.seek(0) - state_dict = torch.load(f) - self.assertEqual( - str(type(state_dict["weight"])), - "", - ) - - @parametrize("device", ["npu"]) - @parametrize("dtype", [torch.float16, torch.bfloat16]) - def test_activation_prescaling(self, device, dtype): - input = torch.randn(1, 128, dtype=dtype, device=device) - linear = 
torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) - original = linear(input) - quantize_(linear, get_config(64)) - qw = linear.weight - assert isinstance( - qw, SupportsActivationPreScaling - ), "Expected int4 tensor supports activation prescaling" - assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" - _ACT_PRE_SCALE = 2 - qw.act_pre_scale = _ACT_PRE_SCALE - quantized = linear(input) - - # making sure activation pre scaling is successfully applied to the activation - self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10) - - -instantiate_parametrized_tests(Int4PlainInt32TensorNPU) - -if __name__ == "__main__": - run_tests() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 33aeb7512c..87e011c57b 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -94,7 +94,6 @@ Int4MarlinSparseTensor, Int4OpaqueTensor, Int4PlainInt32Tensor, - Int4PlainInt32TensorNPU, Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8fc9c9de75..ae8210a41a 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -78,7 +78,6 @@ Int4OpaqueTensor, Int4PackingFormat, Int4PlainInt32Tensor, - Int4PlainInt32TensorNPU, Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, @@ -1195,16 +1194,10 @@ def _int4_weight_only_quantize_tensor(weight, config): ) return new_weight elif int4_packing_format == Int4PackingFormat.PLAIN_INT32: - if weight.device.type == "npu": - new_weight = Int4PlainInt32TensorNPU.from_hp( - weight, - block_size, - ) - else: - new_weight = Int4PlainInt32Tensor.from_hp( - weight, - block_size, - ) + new_weight = Int4PlainInt32Tensor.from_hp( + weight, + block_size, + ) return new_weight elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE: new_weight = Int4MarlinSparseTensor.from_hp( diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index ea05afd733..4307637f8e 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -13,9 +13,6 @@ from .int4.int4_plain_int32_tensor import ( Int4PlainInt32Tensor, ) -from .int4.int4_plain_int32_tensor_npu import ( - Int4PlainInt32TensorNPU, -) from .int4.int4_preshuffled_tensor import ( Int4PreshuffledTensor, ) @@ -39,7 +36,6 @@ "Int4PreshuffledTensor", "Int4MarlinSparseTensor", "Int4PlainInt32Tensor", - "Int4PlainInt32TensorNPU", "Int4TilePackedTo4dTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py index 2c8de1a2d0..0c8d7d65d4 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py @@ -91,60 +91,152 @@ def from_hp( w: torch.Tensor, block_size: List[int], ): - assert w.ndim == 2 and w.device.type == "xpu", ( - f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}" - ) - assert len(block_size) == w.ndim - assert w.dtype in [torch.float16, torch.bfloat16], ( - f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" - ) - original_shape = w.shape - mapping_type = MappingType.ASYMMETRIC - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - scale_dtype = None - 
zero_point_dtype = torch.int32 - scale, zero_point = choose_qparams_affine( - w, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - scale_dtype, - zero_point_dtype, - ) - int_data = quantize_affine( - w, - block_size, - scale, - zero_point, - target_dtype, - quant_min, - quant_max, - ) - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype" - ) - packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - packed_weight.contiguous(), 8 - ) - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) - return Int4PlainInt32Tensor( - packed_weight, - scale.transpose(0, 1).contiguous(), - zero_point.transpose(0, 1).contiguous().to(torch.int8), - block_size, - original_shape, - act_pre_scale=None, - ) + if w.device.type == "xpu": + return _from_hp_xpu(cls, w, block_size) + elif w.device.type == "npu": + return _from_hp_npu(cls, w, block_size) + else: + raise AssertionError(f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet.") +def _from_hp_xpu( + cls, + w: torch.Tensor, + block_size: List[int], +): + assert w.ndim == 2 and w.device.type == "xpu", ( + f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}" + ) + assert len(block_size) == w.ndim + assert w.dtype in [torch.float16, torch.bfloat16], ( + f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" + ) + original_shape = w.shape + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + scale_dtype = None + zero_point_dtype = torch.int32 + scale, zero_point = choose_qparams_affine( + w, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + ) + int_data = quantize_affine( + w, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + ) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype" + ) + packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + packed_weight.contiguous(), 8 + ) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + return Int4PlainInt32Tensor( + packed_weight, + scale.transpose(0, 1).contiguous(), + zero_point.transpose(0, 1).contiguous().to(torch.int8), + block_size, + original_shape, + act_pre_scale=None, + ) +def _from_hp_npu( + cls, + w: torch.Tensor, + block_size: List[int], +): + assert w.ndim == 2 and w.device.type == "npu", ( + f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}" + ) + assert len(block_size) == w.ndim + assert w.dtype in [torch.float16, torch.bfloat16], ( + f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" + ) + + group_size = block_size[1] + k_dim = w.shape[-1] + assert ( + group_size >= 32 + and group_size % 32 == 0 + and group_size < k_dim + ), ( + f"Invalid group_size={group_size}: " + f"expected to be a multiple of 32, " + f"in range [32, {k_dim - 1}] for per-group quantization, " + f"but got group_size={group_size} (k_dim={k_dim})." 
+ ) + + original_shape = w.shape + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.int32 + quant_min = -8 + quant_max = 7 + eps = 1e-6 + scale_dtype = w.dtype + zero_point_dtype = w.dtype + + scale, zero_point = choose_qparams_affine( + w, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + ) + + int_data = quantize_affine( + w, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + ) + + assert int_data.dtype == torch.int32, ( + "torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype" + ) + assert int_data.shape[-1] % 8 == 0, ( + f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" + ) + + packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack( + int_data.contiguous(), 0 + ) + + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + + return Int4PlainInt32Tensor( + packed_weight, + scale.transpose(0, 1).contiguous(), + zero_point.transpose(0, 1).contiguous(), + block_size, + original_shape, + act_pre_scale=None, + ) + + implements = Int4PlainInt32Tensor.implements implements_torch_function = Int4PlainInt32Tensor.implements_torch_function @@ -157,6 +249,20 @@ def _(func, types, args, kwargs): args[1], args[2] if len(args) > 2 else None, ) + + if input_tensor.device.type == "xpu": + return _linear_xpu(input_tensor, weight_tensor, bias) + elif input_tensor.device.type == "npu": + return _linear_npu(input_tensor, weight_tensor, bias) + else: + raise AssertionError(f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet.") + + +def _linear_xpu( + input_tensor, + weight_tensor, + bias, +): assert input_tensor.device.type == "xpu", ( f"For XPU device only but got: {input_tensor.device}" ) @@ -200,8 +306,73 @@ def _(func, types, args, kwargs): y += bias return y.to(orig_dtype) +def _linear_npu( + input_tensor, + weight_tensor, + bias, +): + assert input_tensor.device.type == "npu", ( + f"For NPU device only but got: {input_tensor.device.type}" + ) + assert isinstance(weight_tensor, Int4PlainInt32Tensor), ( + f"Expected weight_tensor to be Int4PlainInt32NPUTensor, got: {type(weight_tensor)}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" + ) + + if weight_tensor.act_pre_scale is not None: + input_tensor = input_tensor * weight_tensor.act_pre_scale + + act_mat = input_tensor + packed_weight = weight_tensor.qdata + scale = weight_tensor.scale + zero_point = weight_tensor.zero_point + + orig_act_size = act_mat.shape + orig_dtype = act_mat.dtype + + # dtype alignment + if act_mat.dtype == torch.float16: + scale = scale.to(torch.float16) + zero_point = zero_point.to(torch.float16) + if bias is not None: + bias = bias.to(torch.float16) + elif act_mat.dtype == torch.bfloat16: + scale = scale.to(torch.bfloat16) + zero_point = zero_point.to(torch.bfloat16) + if bias is not None: + bias = bias.to(torch.float32) + + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + + y = torch.ops.npu.npu_weight_quant_batchmatmul( + x=act_mat, + weight=packed_weight.contiguous().transpose(-1, -2), + antiquant_scale=scale, + 
antiquant_offset=zero_point, + antiquant_group_size=groupsize, + bias=bias, + ) + + # remove out_feature padding + assert weight_tensor.ndim == 2 + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) + Int4PlainInt32Tensor.__module__ = "torchao.quantization" # Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True` torch.serialization.add_safe_globals([Int4PlainInt32Tensor]) + diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py deleted file mode 100644 index 80ddcd9619..0000000000 --- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -from typing import List, Optional - -import torch - -from torchao.quantization.quant_primitives import ( - MappingType, - choose_qparams_affine, - quantize_affine, -) -from torchao.utils import ( - TorchAOBaseTensor, -) - -__all__ = ["Int4PlainInt32TensorNPU"] - -aten = torch.ops.aten - - -class Int4PlainInt32TensorNPU(TorchAOBaseTensor): - """ - int4 weight-only quantization on Ascend NPU backend (groupwise quantization only) - - Tensor Attributes: - qdata: (N, K/8), packed int4 weight, the data type is int32 here with 8*int4, the original dtype can be float16 or bfloat16 - scale: (K/group_size, N), dtype is the same as the original Tensor type (float16 or bfloat16) - zero_point: (K/group_size, N), dtype is the same as the original Tensor type (float16 or bfloat16) - - Non-Tensor Attributes: - block_size: the block size for quantization, representing the granularity - shape: shape of the original Tensor - - Optional Tensor Data Attributes: - act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present, - we'll multiply activation Tensor with act_pre_scale before applying dynamic - quantization to activation or running quantized mm op - - """ - - tensor_data_names = ["qdata", "scale", "zero_point"] - tensor_attribute_names = ["block_size", "shape"] - optional_tensor_data_names = ["act_pre_scale"] - - def __new__( - cls, - qdata, - scale, - zero_point, - block_size, - shape, - act_pre_scale: Optional[torch.Tensor] = None, - ): - kwargs = {} - kwargs["device"] = qdata.device - kwargs["dtype"] = scale.dtype - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - qdata, - scale, - zero_point, - block_size, - shape, - act_pre_scale: Optional[torch.Tensor] = None, - ): - self.qdata = qdata - self.scale = scale - self.zero_point = zero_point - self.block_size = block_size - self.act_pre_scale = act_pre_scale - - def _quantization_type(self): - s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}" - if self.act_pre_scale is not None: - s += f", act_pre_scale.shape={self.act_pre_scale.shape}" - return s - - @classmethod - def from_hp( - cls, - w: torch.Tensor, - block_size: List[int], - ): - assert w.ndim == 2 and w.device.type == "npu", ( - f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}" - ) - assert len(block_size) == w.ndim - assert w.dtype in [torch.float16, torch.bfloat16], 
( - f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" - ) - - original_shape = w.shape - mapping_type = MappingType.ASYMMETRIC - target_dtype = torch.int32 - quant_min = -8 - quant_max = 7 - eps = 1e-6 - scale_dtype = w.dtype - zero_point_dtype = w.dtype - - scale, zero_point = choose_qparams_affine( - w, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - scale_dtype, - zero_point_dtype, - ) - - int_data = quantize_affine( - w, - block_size, - scale, - zero_point, - target_dtype, - quant_min, - quant_max, - ) - - assert int_data.dtype == torch.int32, ( - f"torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype" - ) - - assert int_data.shape[-1] % 8 == 0, ( - f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" - ) - - packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack( - int_data.contiguous(), 0 - ) - - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) - - return Int4PlainInt32TensorNPU( - packed_weight, - scale.transpose(0, 1).contiguous(), - zero_point.transpose(0, 1).contiguous(), - block_size, - original_shape, - act_pre_scale=None, - ) - - -implements = Int4PlainInt32TensorNPU.implements -implements_torch_function = Int4PlainInt32TensorNPU.implements_torch_function - - -@implements(aten.linear.default) -@implements_torch_function(torch.nn.functional.linear) -def _(func, types, args, kwargs): - - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - - assert input_tensor.device.type == "npu", ( - f"For NPU device only but got: {input_tensor.device.type}" - ) - assert isinstance(weight_tensor, Int4PlainInt32TensorNPU), ( - f"Expected weight_tensor to be Int4PlainInt32NPUTensor, got: {type(weight_tensor)}" - ) - assert weight_tensor.block_size[0] == 1, ( - f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" - ) - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" - ) - - if weight_tensor.act_pre_scale is not None: - input_tensor = input_tensor * weight_tensor.act_pre_scale - - act_mat = input_tensor - packed_weight = weight_tensor.qdata - scale = weight_tensor.scale - zero_point = weight_tensor.zero_point - - orig_act_size = act_mat.shape - orig_dtype = act_mat.dtype - - # dtype alignment - if act_mat.dtype == torch.float16: - scale = scale.to(torch.float16) - zero_point = zero_point.to(torch.float16) - if bias is not None: - bias = bias.to(torch.float16) - elif act_mat.dtype == torch.bfloat16: - scale = scale.to(torch.bfloat16) - zero_point = zero_point.to(torch.bfloat16) - if bias is not None: - bias = bias.to(torch.float32) - - # reshape to 2D - act_mat = act_mat.reshape(-1, act_mat.shape[-1]) - - # groupwise int4 quantization - groupsize = weight_tensor.block_size[1] - - y = torch.ops.npu.npu_weight_quant_batchmatmul( - x=act_mat, - weight=packed_weight.contiguous().transpose(-1, -2), - antiquant_scale=scale, - antiquant_offset=zero_point, - antiquant_group_size=groupsize, - bias=bias, - ) - - # remove out_feature padding - assert weight_tensor.ndim == 2 - orig_out_features = weight_tensor.shape[-2] - y = y[:, :orig_out_features] - y = y.reshape(*orig_act_size[:-1], orig_out_features) - - return y.to(orig_dtype) - - -Int4PlainInt32TensorNPU.__module__ = "torchao.quantization" - -# Allow a model with 
Int4PlainInt32TensorNPU weights to be loaded with `weights_only=True` -torch.serialization.add_safe_globals([Int4PlainInt32TensorNPU]) From ca8f0566d257060e38da99eef2baf15176d8277c Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Wed, 29 Oct 2025 08:51:17 +0000 Subject: [PATCH 5/9] ruff format cleanup, replace error types, add torch version check --- .../int4/test_int4_plain_int32_tensor.py | 7 +-- .../workflows/int4/int4_plain_int32_tensor.py | 61 ++++++++++--------- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py index d8f6640c8d..5586469041 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py @@ -105,7 +105,6 @@ def test_activation_prescaling(self): "NPU not available", ) class Int4PlainInt32TensorNPU(TestCase): - @parametrize("device", ["npu"]) @parametrize( "sizes", @@ -153,9 +152,9 @@ def test_activation_prescaling(self, device, dtype): original = linear(input) quantize_(linear, get_config(64)) qw = linear.weight - assert isinstance( - qw, SupportsActivationPreScaling - ), "Expected int4 tensor supports activation prescaling" + assert isinstance(qw, SupportsActivationPreScaling), ( + "Expected int4 tensor supports activation prescaling" + ) assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" _ACT_PRE_SCALE = 2 qw.act_pre_scale = _ACT_PRE_SCALE diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py index 0c8d7d65d4..3f7d1a27e9 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py @@ -14,9 +14,7 @@ choose_qparams_affine, quantize_affine, ) -from torchao.utils import ( - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor, torch_version_at_least __all__ = [ "Int4PlainInt32Tensor", @@ -96,7 +94,10 @@ def from_hp( elif w.device.type == "npu": return _from_hp_npu(cls, w, block_size) else: - raise AssertionError(f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet.") + raise NotImplementedError( + f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet." + ) + def _from_hp_xpu( cls, @@ -156,11 +157,17 @@ def _from_hp_xpu( act_pre_scale=None, ) + def _from_hp_npu( cls, w: torch.Tensor, block_size: List[int], ): + # Require PyTorch 2.7.1+ for NPU backend ops and backward compatibility. + assert torch_version_at_least("2.7.1"), ( + "Need pytorch 2.7.1+ for NPU backend op support." + ) + assert w.ndim == 2 and w.device.type == "npu", ( f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}" ) @@ -168,20 +175,16 @@ def _from_hp_npu( assert w.dtype in [torch.float16, torch.bfloat16], ( f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" ) - + group_size = block_size[1] k_dim = w.shape[-1] - assert ( - group_size >= 32 - and group_size % 32 == 0 - and group_size < k_dim - ), ( + assert group_size >= 32 and group_size % 32 == 0 and group_size < k_dim, ( f"Invalid group_size={group_size}: " f"expected to be a multiple of 32, " f"in range [32, {k_dim - 1}] for per-group quantization, " f"but got group_size={group_size} (k_dim={k_dim})." 
) - + original_shape = w.shape mapping_type = MappingType.ASYMMETRIC target_dtype = torch.int32 @@ -190,7 +193,7 @@ def _from_hp_npu( eps = 1e-6 scale_dtype = w.dtype zero_point_dtype = w.dtype - + scale, zero_point = choose_qparams_affine( w, mapping_type, @@ -202,7 +205,7 @@ def _from_hp_npu( scale_dtype, zero_point_dtype, ) - + int_data = quantize_affine( w, block_size, @@ -212,31 +215,31 @@ def _from_hp_npu( quant_min, quant_max, ) - + assert int_data.dtype == torch.int32, ( "torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype" ) assert int_data.shape[-1] % 8 == 0, ( f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" ) - + packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack( int_data.contiguous(), 0 ) - + scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) - + return Int4PlainInt32Tensor( - packed_weight, + packed_weight.contiguous(), scale.transpose(0, 1).contiguous(), zero_point.transpose(0, 1).contiguous(), block_size, original_shape, act_pre_scale=None, ) - - + + implements = Int4PlainInt32Tensor.implements implements_torch_function = Int4PlainInt32Tensor.implements_torch_function @@ -249,20 +252,22 @@ def _(func, types, args, kwargs): args[1], args[2] if len(args) > 2 else None, ) - + if input_tensor.device.type == "xpu": return _linear_xpu(input_tensor, weight_tensor, bias) elif input_tensor.device.type == "npu": return _linear_npu(input_tensor, weight_tensor, bias) else: - raise AssertionError(f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet.") + raise NotImplementedError( + f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet." + ) def _linear_xpu( input_tensor, weight_tensor, bias, -): +): assert input_tensor.device.type == "xpu", ( f"For XPU device only but got: {input_tensor.device}" ) @@ -306,11 +311,12 @@ def _linear_xpu( y += bias return y.to(orig_dtype) + def _linear_npu( input_tensor, weight_tensor, bias, -): +): assert input_tensor.device.type == "npu", ( f"For NPU device only but got: {input_tensor.device.type}" ) @@ -355,19 +361,19 @@ def _linear_npu( y = torch.ops.npu.npu_weight_quant_batchmatmul( x=act_mat, - weight=packed_weight.contiguous().transpose(-1, -2), + weight=packed_weight.transpose(-1, -2), antiquant_scale=scale, antiquant_offset=zero_point, antiquant_group_size=groupsize, bias=bias, ) - + # remove out_feature padding assert weight_tensor.ndim == 2 orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] y = y.reshape(*orig_act_size[:-1], orig_out_features) - + return y.to(orig_dtype) @@ -375,4 +381,3 @@ def _linear_npu( # Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True` torch.serialization.add_safe_globals([Int4PlainInt32Tensor]) - From 05af947347c89d4537ed65f91ca4863c7de3cf0c Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Fri, 31 Oct 2025 09:52:47 +0000 Subject: [PATCH 6/9] add torch_npu version assertion and show downstream testing result --- README.md | 6 ++++++ .../quantize_/workflows/int4/int4_plain_int32_tensor.py | 9 ++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ff24a989e5..c9c9689c66 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,12 @@ TorchAO is an easy to use quantization library for native PyTorch. TorchAO works Check out our [docs](https://docs.pytorch.org/ao/main/) for more details! 
+## Third-party Pipeline Status + +| Backend | Inference | +| ---------- | ------------------------------------------------------------ | +| Ascend NPU | [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)]( | + ## 🚀 Quick Start First, install TorchAO. We recommend installing the latest stable version: diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py index 3f7d1a27e9..51e09dbe9c 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py @@ -163,9 +163,12 @@ def _from_hp_npu( w: torch.Tensor, block_size: List[int], ): - # Require PyTorch 2.7.1+ for NPU backend ops and backward compatibility. - assert torch_version_at_least("2.7.1"), ( - "Need pytorch 2.7.1+ for NPU backend op support." + assert ( + torch.accelerator.is_available() + and torch.accelerator.current_accelerator().type == "npu" + and torch_version_at_least("2.7.1") + ), ( + f"PyTorch NPU 2.7.1+ needed for int4 packing and matmul ops, {torch.__version__} found" ) assert w.ndim == 2 and w.device.type == "npu", ( From 25360da10406a4ee93afed2deabd475a8a534295 Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Fri, 31 Oct 2025 09:59:00 +0000 Subject: [PATCH 7/9] add downstream testing result --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c9c9689c66..81d8a4a65a 100644 --- a/README.md +++ b/README.md @@ -62,9 +62,9 @@ Check out our [docs](https://docs.pytorch.org/ao/main/) for more details! ## Third-party Pipeline Status -| Backend | Inference | -| ---------- | ------------------------------------------------------------ | -| Ascend NPU | [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)]( | +| Backend | Inference | +| ----------- | -------------------------------------------------------------------------------------------------------------------- | +| Ascend NPU | [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml) | ## 🚀 Quick Start From fa3220ff5aa8201554f262608847f3b2ac20fec1 Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Mon, 3 Nov 2025 07:59:27 +0000 Subject: [PATCH 8/9] unify NPU and XPU test cases into a single class --- .../int4/test_int4_plain_int32_tensor.py | 137 +++++++----------- 1 file changed, 52 insertions(+), 85 deletions(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py index 5586469041..dd44734f2e 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py @@ -5,12 +5,12 @@ # LICENSE file in the root directory of this source tree. 
import tempfile -import unittest +import pytest import torch +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import ( TestCase, - instantiate_parametrized_tests, parametrize, run_tests, ) @@ -33,9 +33,19 @@ def get_config(group_size): ) -@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") -@unittest.skipIf(not torch.xpu.is_available(), "XPU not available") -class Int4PlainInt32TensorXPU(TestCase): +class Int4PlainInt32Tensor(TestCase): + _MIN_VER = { + "xpu": "2.8.0", + "npu": "2.7.1", + } + + def setUp(self): + min_req = type(self)._MIN_VER.get(self.device_type) + if not torch_version_at_least(min_req): + self.skipTest( + f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}" + ) + @parametrize( "sizes", [ @@ -46,90 +56,36 @@ class Int4PlainInt32TensorXPU(TestCase): ) @parametrize("dtype", [torch.bfloat16, torch.half]) @parametrize("group_size", [32, 64, 128]) - def test_linear(self, sizes, dtype, group_size): - device = "xpu" + @parametrize("thresholds", [{"xpu": 20, "npu": 10}]) + def test_linear(self, device, sizes, dtype, group_size, thresholds): M, N, K = sizes + if "npu" in device and group_size == K: + pytest.skip( + f"{device} does not support group_size equal to K dimension ({group_size} == {K})" + ) + threshold = thresholds.get(device.split(":")[0]) + input = torch.randn(*M, K, dtype=dtype, device=device) linear = torch.nn.Linear(K, N, dtype=dtype, device=device) original = linear(input) quantize_(linear, get_config(group_size)) quantized = linear(input) - self.assertTrue(compute_error(original, quantized) > 20) + self.assertTrue(compute_error(original, quantized) > threshold) - compiled_linear = torch.compile(linear) - quantized_and_compiled = compiled_linear(input) - self.assertTrue(compute_error(original, quantized_and_compiled) > 20) + if "xpu" in device: + compiled_linear = torch.compile(linear) + quantized_and_compiled = compiled_linear(input) + self.assertTrue(compute_error(original, quantized_and_compiled) > threshold) @parametrize("dtype", [torch.bfloat16, torch.half]) - def test_module_path(self, dtype): - linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu") - quantize_(linear, get_config(group_size=128)) - self.assertEqual( - str(type(linear.weight)), - "", - ) - - with tempfile.NamedTemporaryFile() as f: - torch.save(linear.state_dict(), f) - f.seek(0) - state_dict = torch.load(f) - self.assertEqual( - str(type(state_dict["weight"])), - "", - ) - - def test_activation_prescaling(self): - dtype = torch.bfloat16 - device = "xpu" - input = torch.randn(1, 128, dtype=dtype, device=device) - linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) - original = linear(input) - quantize_(linear, get_config(128)) - qw = linear.weight - assert isinstance(qw, SupportsActivationPreScaling), ( - "Expected int4 tensor supports activation prescaling" - ) - assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" - _ACT_PRE_SCALE = 2 - qw.act_pre_scale = _ACT_PRE_SCALE - quantized = linear(input) - - # making sure activation pre scaling is successfully applied to the activation - self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20) - + def test_module_path(self, device, dtype): + device = self.device_type + K, N, group_size = 128, 256, 128 + if "npu" in device: + group_size = 64 -@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+") -@unittest.skipIf( - 
torch.accelerator.current_accelerator().type != "npu" - or not torch.accelerator.is_available(), - "NPU not available", -) -class Int4PlainInt32TensorNPU(TestCase): - @parametrize("device", ["npu"]) - @parametrize( - "sizes", - [ - ((128,), 256, 128), - ((32, 128), 512, 128), - ((2, 32, 128), 256, 128), - ], - ) - @parametrize("dtype", [torch.float16, torch.bfloat16]) - @parametrize("group_size", [32, 64]) - def test_linear(self, device, sizes, dtype, group_size): - M, N, K = sizes - input = torch.randn(*M, K, dtype=dtype, device=device) linear = torch.nn.Linear(K, N, dtype=dtype, device=device) - orig_output = linear(input) quantize_(linear, get_config(group_size)) - quantized_output = linear(input) - self.assertTrue(compute_error(orig_output, quantized_output) > 10) - - @parametrize("device", ["npu"]) - @parametrize("dtype", [torch.float16, torch.bfloat16]) - def test_module_path(self, device, dtype): - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - quantize_(linear, get_config(group_size=64)) self.assertEqual( str(type(linear.weight)), "", @@ -144,13 +100,22 @@ def test_module_path(self, device, dtype): "", ) - @parametrize("device", ["npu"]) @parametrize("dtype", [torch.float16, torch.bfloat16]) - def test_activation_prescaling(self, device, dtype): - input = torch.randn(1, 128, dtype=dtype, device=device) - linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) + @parametrize("thresholds", [{"xpu": 20, "npu": 10}]) + def test_activation_prescaling(self, device, dtype, thresholds): + device = self.device_type + if "xpu" in device and dtype == torch.float16: + pytest.skip(f"{device} test_activation_prescaling don't test {dtype}") + + threshold = thresholds.get(device.split(":")[0]) + K, N, group_size = 128, 256, 128 + if "npu" in device: + group_size = 64 + + input = torch.randn(1, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device) original = linear(input) - quantize_(linear, get_config(64)) + quantize_(linear, get_config(group_size)) qw = linear.weight assert isinstance(qw, SupportsActivationPreScaling), ( "Expected int4 tensor supports activation prescaling" @@ -161,11 +126,13 @@ def test_activation_prescaling(self, device, dtype): quantized = linear(input) # making sure activation pre scaling is successfully applied to the activation - self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10) + self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold) + +instantiate_device_type_tests( + Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True +) -instantiate_parametrized_tests(Int4PlainInt32TensorXPU) -instantiate_parametrized_tests(Int4PlainInt32TensorNPU) if __name__ == "__main__": run_tests() From 623c589ff911d434d7004549475975503c33f70b Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Wed, 12 Nov 2025 03:48:21 +0000 Subject: [PATCH 9/9] move CI display to quantization README and update test file --- README.md | 6 ------ .../workflows/int4/test_int4_plain_int32_tensor.py | 2 -- torchao/quantization/README.md | 8 ++++++-- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 81d8a4a65a..ff24a989e5 100644 --- a/README.md +++ b/README.md @@ -60,12 +60,6 @@ TorchAO is an easy to use quantization library for native PyTorch. TorchAO works Check out our [docs](https://docs.pytorch.org/ao/main/) for more details! 
-## Third-party Pipeline Status - -| Backend | Inference | -| ----------- | -------------------------------------------------------------------------------------------------------------------- | -| Ascend NPU | [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml) | - ## 🚀 Quick Start First, install TorchAO. We recommend installing the latest stable version: diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py index dd44734f2e..1b17c40fb0 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py @@ -79,7 +79,6 @@ def test_linear(self, device, sizes, dtype, group_size, thresholds): @parametrize("dtype", [torch.bfloat16, torch.half]) def test_module_path(self, device, dtype): - device = self.device_type K, N, group_size = 128, 256, 128 if "npu" in device: group_size = 64 @@ -103,7 +102,6 @@ def test_module_path(self, device, dtype): @parametrize("dtype", [torch.float16, torch.bfloat16]) @parametrize("thresholds", [{"xpu": 20, "npu": 10}]) def test_activation_prescaling(self, device, dtype, thresholds): - device = self.device_type if "xpu" in device and dtype == torch.float16: pytest.skip(f"{device} test_activation_prescaling don't test {dtype}") diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index dbd7983b8e..a4b5d2801e 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -71,8 +71,12 @@ use_hqq = False quantize_(model, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")) ``` -Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model. - +Note: +- The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model. +- Third-party backend CI status: + - Ascend NPU(requires torch_npu ≥ 2.7.1) + [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml) + #### A16W8 Int8 WeightOnly Quantization ```python
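
---

For reviewers, a minimal usage sketch of the NPU int4 weight-only path added by this series, mirroring the test setup above (a sketch only: it assumes `torch_npu` is installed and an Ascend NPU is visible as the `npu` device; the module and shapes are illustrative, not part of the patches):

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Illustrative module; any nn.Linear on the NPU device is handled the same way.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="npu")

# With this series, the "plain_int32" packing format dispatches to the Ascend NPU
# packing/matmul path whenever the weight tensor lives on the "npu" device.
quantize_(linear, Int4WeightOnlyConfig(group_size=64, int4_packing_format="plain_int32"))

x = torch.randn(32, 128, dtype=torch.bfloat16, device="npu")
y = linear(x)  # runs torch.ops.npu.npu_weight_quant_batchmatmul under the hood
```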