@@ -35,7 +35,7 @@ def get_config(group_size):

@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
class Int4PlainInt32Tensor(TestCase):
class Int4PlainInt32TensorXPU(TestCase):
Contributor:
nit: revert change

@parametrize(
"sizes",
[
@@ -98,8 +98,75 @@ def test_activation_prescaling(self):
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)


instantiate_parametrized_tests(Int4PlainInt32Tensor)
@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
@unittest.skipIf(
torch.accelerator.current_accelerator().type != "npu"
or not torch.accelerator.is_available(),
"NPU not available",
)
class Int4PlainInt32TensorNPU(TestCase):
Contributor @jerryzh168, Oct 31, 2025:
can this be merged with the xpu case?

Contributor Author:
Sure, I’ll combine them into a single test class.
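
For reference, a minimal sketch of how the two classes could be merged into one device-parametrized class; the class name, thresholds, and the in-body skip logic here are hypothetical, and the actual follow-up commit may differ:

@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
class Int4PlainInt32TensorTest(TestCase):
    @parametrize("device", ["xpu", "npu"])
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    @parametrize("group_size", [32, 64])
    def test_linear(self, device, dtype, group_size):
        # Skip per-device inside the test body instead of at class level.
        if device == "xpu" and (
            not torch_version_at_least("2.8.0") or not torch.xpu.is_available()
        ):
            self.skipTest("XPU with pytorch 2.8+ not available")
        if device == "npu" and (
            not torch.accelerator.is_available()
            or torch.accelerator.current_accelerator().type != "npu"
        ):
            self.skipTest("NPU not available")
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        input = torch.randn(32, 128, dtype=dtype, device=device)
        orig_output = linear(input)
        quantize_(linear, get_config(group_size))
        quantized_output = linear(input)
        self.assertTrue(compute_error(orig_output, quantized_output) > 10)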


@parametrize("device", ["npu"])
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 512, 128),
((2, 32, 128), 256, 128),
],
)
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("group_size", [32, 64])
def test_linear(self, device, sizes, dtype, group_size):
M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
orig_output = linear(input)
quantize_(linear, get_config(group_size))
quantized_output = linear(input)
self.assertTrue(compute_error(orig_output, quantized_output) > 10)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_module_path(self, device, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
quantize_(linear, get_config(group_size=64))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

@parametrize("device", ["npu"])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_activation_prescaling(self, device, dtype):
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(64))
qw = linear.weight
assert isinstance(
qw, SupportsActivationPreScaling
), "Expected int4 tensor supports activation prescaling"
assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
_ACT_PRE_SCALE = 2
qw.act_pre_scale = _ACT_PRE_SCALE
quantized = linear(input)

# making sure activation pre scaling is successfully applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)


instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
instantiate_parametrized_tests(Int4PlainInt32TensorNPU)

if __name__ == "__main__":
run_tests()
275 changes: 223 additions & 52 deletions torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -91,60 +91,152 @@ def from_hp(
w: torch.Tensor,
block_size: List[int],
):
assert w.ndim == 2 and w.device.type == "xpu", (
f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim
assert w.dtype in [torch.float16, torch.bfloat16], (
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
)
original_shape = w.shape
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
scale_dtype = None
zero_point_dtype = torch.int32
scale, zero_point = choose_qparams_affine(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)
int_data = quantize_affine(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
)
packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8
)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
return Int4PlainInt32Tensor(
packed_weight,
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous().to(torch.int8),
block_size,
original_shape,
act_pre_scale=None,
)
if w.device.type == "xpu":
return _from_hp_xpu(cls, w, block_size)
elif w.device.type == "npu":
return _from_hp_npu(cls, w, block_size)
else:
raise AssertionError(f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet.")
Contributor:
nit: ValueError or NotImplementedError here.

Contributor Author:
Fixed
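
That is, presumably something like the following sketch; the exact wording in the follow-up commit may differ:

if w.device.type == "xpu":
    return _from_hp_xpu(cls, w, block_size)
elif w.device.type == "npu":
    return _from_hp_npu(cls, w, block_size)
else:
    raise NotImplementedError(
        f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet."
    )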


def _from_hp_xpu(
cls,
w: torch.Tensor,
block_size: List[int],
):
assert w.ndim == 2 and w.device.type == "xpu", (
f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim
assert w.dtype in [torch.float16, torch.bfloat16], (
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
)
original_shape = w.shape
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
scale_dtype = None
zero_point_dtype = torch.int32
scale, zero_point = choose_qparams_affine(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)
int_data = quantize_affine(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
)
packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8
)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
return Int4PlainInt32Tensor(
packed_weight,
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous().to(torch.int8),
block_size,
original_shape,
act_pre_scale=None,
)

def _from_hp_npu(
cls,
w: torch.Tensor,
block_size: List[int],
):
assert w.ndim == 2 and w.device.type == "npu", (
f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim
assert w.dtype in [torch.float16, torch.bfloat16], (
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
)

group_size = block_size[1]
k_dim = w.shape[-1]
assert (
group_size >= 32
and group_size % 32 == 0
and group_size < k_dim
), (
f"Invalid group_size={group_size}: "
f"expected to be a multiple of 32, "
f"in range [32, {k_dim - 1}] for per-group quantization, "
f"but got group_size={group_size} (k_dim={k_dim})."
)

original_shape = w.shape
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = -8
quant_max = 7
eps = 1e-6
scale_dtype = w.dtype
zero_point_dtype = w.dtype

scale, zero_point = choose_qparams_affine(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)

int_data = quantize_affine(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)

assert int_data.dtype == torch.int32, (
"torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype"
)
assert int_data.shape[-1] % 8 == 0, (
f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}"
)

Contributor:
Will we ever run into a case where we have NPU support but this op is missing? Maybe in an earlier version of torch_npu? Should we throw a cleaner error message in that case?

It'd be good to add a comment here on where this op is defined and what version of torch npu is needed.

Contributor Author:
Thanks for the reminder — since torch and torch_npu versions are tightly coupled, I added

# Require PyTorch 2.7.1+ for NPU backend ops and backward compatibility.
assert torch_version_at_least("2.7.1"), (
    "Need PyTorch 2.7.1+ for NPU backend op support."
)

at the beginning. Does this make the version requirement clear enough?

Contributor:
Let's be a tad more explicit, I want to make it clear it's PyTorch NPU >= 2.7.1 and not regular torch

assert (torch.accelerator.is_available() and torch.accelerator.current_accelerator().type == "npu" and torch_version_at_least("2.7.1"), (
    f"PyTorch NPU 2.7.1+ needed for int4 packing and matmul ops, {torch.__version__} found"
)

Contributor Author:
Modified
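
For later readers, a corrected sketch of that combined check, assuming torch_version_at_least is imported from torchao.utils in this module; note that the condition and message must not be wrapped together in a single tuple, or the assert always passes:

# Require PyTorch NPU 2.7.1+ for the int4 packing and matmul ops used below.
assert (
    torch.accelerator.is_available()
    and torch.accelerator.current_accelerator().type == "npu"
    and torch_version_at_least("2.7.1")
), (
    f"PyTorch NPU 2.7.1+ needed for int4 packing and matmul ops, "
    f"{torch.__version__} found"
)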

packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack(
int_data.contiguous(), 0
)

scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)

return Int4PlainInt32Tensor(
packed_weight,
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous(),
block_size,
original_shape,
act_pre_scale=None,
)


implements = Int4PlainInt32Tensor.implements
implements_torch_function = Int4PlainInt32Tensor.implements_torch_function

@@ -157,6 +249,20 @@ def _(func, types, args, kwargs):
args[1],
args[2] if len(args) > 2 else None,
)

if input_tensor.device.type == "xpu":
return _linear_xpu(input_tensor, weight_tensor, bias)
elif input_tensor.device.type == "npu":
return _linear_npu(input_tensor, weight_tensor, bias)
else:
raise AssertionError(f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet.")
Contributor:
nit: NotImplementedError or ValueError is better here

Contributor Author:
Fixed



def _linear_xpu(
input_tensor,
weight_tensor,
bias,
):
assert input_tensor.device.type == "xpu", (
f"For XPU device only but got: {input_tensor.device}"
)
@@ -200,8 +306,73 @@ def _(func, types, args, kwargs):
y += bias
return y.to(orig_dtype)

def _linear_npu(
input_tensor,
weight_tensor,
bias,
):
assert input_tensor.device.type == "npu", (
f"For NPU device only but got: {input_tensor.device.type}"
)
assert isinstance(weight_tensor, Int4PlainInt32Tensor), (
f"Expected weight_tensor to be Int4PlainInt32NPUTensor, got: {type(weight_tensor)}"
)
assert weight_tensor.block_size[0] == 1, (
f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
)
assert input_tensor.shape[-1] == weight_tensor.shape[1], (
f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
)

if weight_tensor.act_pre_scale is not None:
input_tensor = input_tensor * weight_tensor.act_pre_scale

act_mat = input_tensor
packed_weight = weight_tensor.qdata
scale = weight_tensor.scale
zero_point = weight_tensor.zero_point

orig_act_size = act_mat.shape
orig_dtype = act_mat.dtype

# dtype alignment
if act_mat.dtype == torch.float16:
scale = scale.to(torch.float16)
zero_point = zero_point.to(torch.float16)
if bias is not None:
bias = bias.to(torch.float16)
elif act_mat.dtype == torch.bfloat16:
scale = scale.to(torch.bfloat16)
zero_point = zero_point.to(torch.bfloat16)
if bias is not None:
bias = bias.to(torch.float32)

# reshape to 2D
act_mat = act_mat.reshape(-1, act_mat.shape[-1])

# groupwise int4 quantization
groupsize = weight_tensor.block_size[1]

y = torch.ops.npu.npu_weight_quant_batchmatmul(
x=act_mat,
weight=packed_weight.contiguous().transpose(-1, -2),
Contributor:
do we want to call contiguous() every time we do matmul? should we save the packed_weight in contiguous format instead to only do this once?

Contributor Author:
Thanks! Addressed — packed_weight is now made contiguous once when constructing the Int4PlainInt32Tensor (see the sketch after this function).

antiquant_scale=scale,
antiquant_offset=zero_point,
antiquant_group_size=groupsize,
bias=bias,
)

# remove out_feature padding
assert weight_tensor.ndim == 2
orig_out_features = weight_tensor.shape[-2]
y = y[:, :orig_out_features]
y = y.reshape(*orig_act_size[:-1], orig_out_features)

return y.to(orig_dtype)
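
As mentioned in the thread above, a sketch of the "contiguous once" change; the placement is hypothetical and the final commit may differ:

# In _from_hp_npu, at construction time: pack the weight and make it
# contiguous once, so it is stored in contiguous layout on the tensor.
packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack(
    int_data.contiguous(), 0
).contiguous()

# In _linear_npu, at call time: no per-matmul .contiguous() is needed anymore.
y = torch.ops.npu.npu_weight_quant_batchmatmul(
    x=act_mat,
    weight=packed_weight.transpose(-1, -2),
    antiquant_scale=scale,
    antiquant_offset=zero_point,
    antiquant_group_size=groupsize,
    bias=bias,
)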


Int4PlainInt32Tensor.__module__ = "torchao.quantization"

# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4PlainInt32Tensor])
