-
Notifications
You must be signed in to change notification settings - Fork 376
Add NPU (Ascend) backend support for INT4 weight-only quantization workflow #3172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
f3aefca
68eea61
164435e
06c77d1
498f052
ea2aa7a
ca8f056
05af947
25360da
fa3220f
623c589
89ad729
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,7 +35,7 @@ def get_config(group_size): | |
|
|
||
| @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") | ||
| @unittest.skipIf(not torch.xpu.is_available(), "XPU not available") | ||
| class Int4PlainInt32Tensor(TestCase): | ||
| class Int4PlainInt32TensorXPU(TestCase): | ||
| @parametrize( | ||
| "sizes", | ||
| [ | ||
|
|
@@ -98,8 +98,75 @@ def test_activation_prescaling(self): | |
| self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20) | ||
|
|
||
|
|
||
| instantiate_parametrized_tests(Int4PlainInt32Tensor) | ||
| @unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+") | ||
| @unittest.skipIf( | ||
| torch.accelerator.current_accelerator().type != "npu" | ||
| or not torch.accelerator.is_available(), | ||
| "NPU not available", | ||
| ) | ||
| class Int4PlainInt32TensorNPU(TestCase): | ||
|
||
|
|
||
| @parametrize("device", ["npu"]) | ||
| @parametrize( | ||
| "sizes", | ||
| [ | ||
| ((128,), 256, 128), | ||
| ((32, 128), 512, 128), | ||
| ((2, 32, 128), 256, 128), | ||
| ], | ||
| ) | ||
| @parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| @parametrize("group_size", [32, 64]) | ||
| def test_linear(self, device, sizes, dtype, group_size): | ||
| M, N, K = sizes | ||
| input = torch.randn(*M, K, dtype=dtype, device=device) | ||
| linear = torch.nn.Linear(K, N, dtype=dtype, device=device) | ||
| orig_output = linear(input) | ||
| quantize_(linear, get_config(group_size)) | ||
| quantized_output = linear(input) | ||
| self.assertTrue(compute_error(orig_output, quantized_output) > 10) | ||
|
|
||
| @parametrize("device", ["npu"]) | ||
| @parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| def test_module_path(self, device, dtype): | ||
| linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) | ||
| quantize_(linear, get_config(group_size=64)) | ||
| self.assertEqual( | ||
| str(type(linear.weight)), | ||
| "<class 'torchao.quantization.Int4PlainInt32Tensor'>", | ||
| ) | ||
|
|
||
| with tempfile.NamedTemporaryFile() as f: | ||
| torch.save(linear.state_dict(), f) | ||
| f.seek(0) | ||
| state_dict = torch.load(f) | ||
| self.assertEqual( | ||
| str(type(state_dict["weight"])), | ||
| "<class 'torchao.quantization.Int4PlainInt32Tensor'>", | ||
| ) | ||
|
|
||
| @parametrize("device", ["npu"]) | ||
| @parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
| def test_activation_prescaling(self, device, dtype): | ||
| input = torch.randn(1, 128, dtype=dtype, device=device) | ||
| linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) | ||
| original = linear(input) | ||
| quantize_(linear, get_config(64)) | ||
| qw = linear.weight | ||
| assert isinstance( | ||
| qw, SupportsActivationPreScaling | ||
| ), "Expected int4 tensor supports activation prescaling" | ||
| assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" | ||
| _ACT_PRE_SCALE = 2 | ||
| qw.act_pre_scale = _ACT_PRE_SCALE | ||
| quantized = linear(input) | ||
|
|
||
| # making sure activation pre scaling is successfully applied to the activation | ||
| self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10) | ||
|
|
||
|
|
||
| instantiate_parametrized_tests(Int4PlainInt32TensorXPU) | ||
| instantiate_parametrized_tests(Int4PlainInt32TensorNPU) | ||
|
|
||
| if __name__ == "__main__": | ||
| run_tests() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,60 +91,152 @@ def from_hp( | |
| w: torch.Tensor, | ||
| block_size: List[int], | ||
| ): | ||
| assert w.ndim == 2 and w.device.type == "xpu", ( | ||
| f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}" | ||
| ) | ||
| assert len(block_size) == w.ndim | ||
| assert w.dtype in [torch.float16, torch.bfloat16], ( | ||
| f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" | ||
| ) | ||
| original_shape = w.shape | ||
| mapping_type = MappingType.ASYMMETRIC | ||
| target_dtype = torch.int32 | ||
| quant_min = 0 | ||
| quant_max = 15 | ||
| eps = 1e-6 | ||
| scale_dtype = None | ||
| zero_point_dtype = torch.int32 | ||
| scale, zero_point = choose_qparams_affine( | ||
| w, | ||
| mapping_type, | ||
| block_size, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| eps, | ||
| scale_dtype, | ||
| zero_point_dtype, | ||
| ) | ||
| int_data = quantize_affine( | ||
| w, | ||
| block_size, | ||
| scale, | ||
| zero_point, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| ) | ||
| assert int_data.dtype == torch.int32, ( | ||
| "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype" | ||
| ) | ||
| packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) | ||
| packed_weight = torch.ops.aten._convert_weight_to_int4pack( | ||
| packed_weight.contiguous(), 8 | ||
| ) | ||
| scale = scale.reshape(int_data.shape[0], -1) | ||
| zero_point = zero_point.reshape(int_data.shape[0], -1) | ||
| return Int4PlainInt32Tensor( | ||
| packed_weight, | ||
| scale.transpose(0, 1).contiguous(), | ||
| zero_point.transpose(0, 1).contiguous().to(torch.int8), | ||
| block_size, | ||
| original_shape, | ||
| act_pre_scale=None, | ||
| ) | ||
| if w.device.type == "xpu": | ||
| return _from_hp_xpu(cls, w, block_size) | ||
| elif w.device.type == "npu": | ||
| return _from_hp_npu(cls, w, block_size) | ||
| else: | ||
| raise AssertionError(f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet.") | ||
|
||
|
|
||
| def _from_hp_xpu( | ||
| cls, | ||
| w: torch.Tensor, | ||
| block_size: List[int], | ||
| ): | ||
| assert w.ndim == 2 and w.device.type == "xpu", ( | ||
| f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}" | ||
| ) | ||
| assert len(block_size) == w.ndim | ||
| assert w.dtype in [torch.float16, torch.bfloat16], ( | ||
| f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" | ||
| ) | ||
| original_shape = w.shape | ||
| mapping_type = MappingType.ASYMMETRIC | ||
| target_dtype = torch.int32 | ||
| quant_min = 0 | ||
| quant_max = 15 | ||
| eps = 1e-6 | ||
| scale_dtype = None | ||
| zero_point_dtype = torch.int32 | ||
| scale, zero_point = choose_qparams_affine( | ||
| w, | ||
| mapping_type, | ||
| block_size, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| eps, | ||
| scale_dtype, | ||
| zero_point_dtype, | ||
| ) | ||
| int_data = quantize_affine( | ||
| w, | ||
| block_size, | ||
| scale, | ||
| zero_point, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| ) | ||
| assert int_data.dtype == torch.int32, ( | ||
| "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype" | ||
| ) | ||
| packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) | ||
| packed_weight = torch.ops.aten._convert_weight_to_int4pack( | ||
| packed_weight.contiguous(), 8 | ||
| ) | ||
| scale = scale.reshape(int_data.shape[0], -1) | ||
| zero_point = zero_point.reshape(int_data.shape[0], -1) | ||
| return Int4PlainInt32Tensor( | ||
| packed_weight, | ||
| scale.transpose(0, 1).contiguous(), | ||
| zero_point.transpose(0, 1).contiguous().to(torch.int8), | ||
| block_size, | ||
| original_shape, | ||
| act_pre_scale=None, | ||
| ) | ||
|
|
||
| def _from_hp_npu( | ||
| cls, | ||
| w: torch.Tensor, | ||
| block_size: List[int], | ||
| ): | ||
| assert w.ndim == 2 and w.device.type == "npu", ( | ||
| f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}" | ||
| ) | ||
| assert len(block_size) == w.ndim | ||
| assert w.dtype in [torch.float16, torch.bfloat16], ( | ||
| f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" | ||
| ) | ||
|
|
||
| group_size = block_size[1] | ||
| k_dim = w.shape[-1] | ||
| assert ( | ||
| group_size >= 32 | ||
| and group_size % 32 == 0 | ||
| and group_size < k_dim | ||
| ), ( | ||
| f"Invalid group_size={group_size}: " | ||
| f"expected to be a multiple of 32, " | ||
| f"in range [32, {k_dim - 1}] for per-group quantization, " | ||
| f"but got group_size={group_size} (k_dim={k_dim})." | ||
| ) | ||
|
|
||
| original_shape = w.shape | ||
| mapping_type = MappingType.ASYMMETRIC | ||
| target_dtype = torch.int32 | ||
| quant_min = -8 | ||
| quant_max = 7 | ||
| eps = 1e-6 | ||
| scale_dtype = w.dtype | ||
| zero_point_dtype = w.dtype | ||
|
|
||
| scale, zero_point = choose_qparams_affine( | ||
| w, | ||
| mapping_type, | ||
| block_size, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| eps, | ||
| scale_dtype, | ||
| zero_point_dtype, | ||
| ) | ||
|
|
||
| int_data = quantize_affine( | ||
| w, | ||
| block_size, | ||
| scale, | ||
| zero_point, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| ) | ||
|
|
||
| assert int_data.dtype == torch.int32, ( | ||
| "torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype" | ||
| ) | ||
| assert int_data.shape[-1] % 8 == 0, ( | ||
| f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}" | ||
| ) | ||
|
|
||
|
||
| packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack( | ||
| int_data.contiguous(), 0 | ||
| ) | ||
|
|
||
| scale = scale.reshape(int_data.shape[0], -1) | ||
| zero_point = zero_point.reshape(int_data.shape[0], -1) | ||
|
|
||
| return Int4PlainInt32Tensor( | ||
| packed_weight, | ||
| scale.transpose(0, 1).contiguous(), | ||
| zero_point.transpose(0, 1).contiguous(), | ||
| block_size, | ||
| original_shape, | ||
| act_pre_scale=None, | ||
| ) | ||
|
|
||
|
|
||
| implements = Int4PlainInt32Tensor.implements | ||
| implements_torch_function = Int4PlainInt32Tensor.implements_torch_function | ||
|
|
||
|
|
@@ -157,6 +249,20 @@ def _(func, types, args, kwargs): | |
| args[1], | ||
| args[2] if len(args) > 2 else None, | ||
| ) | ||
|
|
||
| if input_tensor.device.type == "xpu": | ||
| return _linear_xpu(input_tensor, weight_tensor, bias) | ||
| elif input_tensor.device.type == "npu": | ||
| return _linear_npu(input_tensor, weight_tensor, bias) | ||
| else: | ||
| raise AssertionError(f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet.") | ||
|
||
|
|
||
|
|
||
| def _linear_xpu( | ||
| input_tensor, | ||
| weight_tensor, | ||
| bias, | ||
| ): | ||
| assert input_tensor.device.type == "xpu", ( | ||
| f"For XPU device only but got: {input_tensor.device}" | ||
| ) | ||
|
|
@@ -200,8 +306,73 @@ def _(func, types, args, kwargs): | |
| y += bias | ||
| return y.to(orig_dtype) | ||
|
|
||
| def _linear_npu( | ||
| input_tensor, | ||
| weight_tensor, | ||
| bias, | ||
| ): | ||
| assert input_tensor.device.type == "npu", ( | ||
| f"For NPU device only but got: {input_tensor.device.type}" | ||
| ) | ||
| assert isinstance(weight_tensor, Int4PlainInt32Tensor), ( | ||
| f"Expected weight_tensor to be Int4PlainInt32NPUTensor, got: {type(weight_tensor)}" | ||
| ) | ||
| assert weight_tensor.block_size[0] == 1, ( | ||
| f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" | ||
| ) | ||
| assert input_tensor.shape[-1] == weight_tensor.shape[1], ( | ||
| f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" | ||
| ) | ||
|
|
||
| if weight_tensor.act_pre_scale is not None: | ||
| input_tensor = input_tensor * weight_tensor.act_pre_scale | ||
|
|
||
| act_mat = input_tensor | ||
| packed_weight = weight_tensor.qdata | ||
| scale = weight_tensor.scale | ||
| zero_point = weight_tensor.zero_point | ||
|
|
||
| orig_act_size = act_mat.shape | ||
| orig_dtype = act_mat.dtype | ||
|
|
||
| # dtype alignment | ||
| if act_mat.dtype == torch.float16: | ||
| scale = scale.to(torch.float16) | ||
| zero_point = zero_point.to(torch.float16) | ||
| if bias is not None: | ||
| bias = bias.to(torch.float16) | ||
| elif act_mat.dtype == torch.bfloat16: | ||
| scale = scale.to(torch.bfloat16) | ||
| zero_point = zero_point.to(torch.bfloat16) | ||
| if bias is not None: | ||
| bias = bias.to(torch.float32) | ||
|
|
||
| # reshape to 2D | ||
| act_mat = act_mat.reshape(-1, act_mat.shape[-1]) | ||
|
|
||
| # groupwise int4 quantization | ||
| groupsize = weight_tensor.block_size[1] | ||
|
|
||
| y = torch.ops.npu.npu_weight_quant_batchmatmul( | ||
| x=act_mat, | ||
| weight=packed_weight.contiguous().transpose(-1, -2), | ||
|
||
| antiquant_scale=scale, | ||
| antiquant_offset=zero_point, | ||
| antiquant_group_size=groupsize, | ||
| bias=bias, | ||
| ) | ||
|
|
||
| # remove out_feature padding | ||
| assert weight_tensor.ndim == 2 | ||
| orig_out_features = weight_tensor.shape[-2] | ||
| y = y[:, :orig_out_features] | ||
| y = y.reshape(*orig_act_size[:-1], orig_out_features) | ||
|
|
||
| return y.to(orig_dtype) | ||
|
|
||
|
|
||
| Int4PlainInt32Tensor.__module__ = "torchao.quantization" | ||
|
|
||
| # Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True` | ||
| torch.serialization.add_safe_globals([Int4PlainInt32Tensor]) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: revert change