Commit 52d2c42

fix 3 bit packing regression, fixed #1278 (#1280)
* fix torch's pack()
* allow setting bits & quant backend
* add 3 bit test
* fix 3 bit packing in base
* revert data clone changes
* remove 3bits test in q4 cuda
* fix error was printed but ignored
* add delta windows size
* update scores
* fix score
1 parent 3ead8c1 commit 52d2c42

File tree

gptqmodel/nn_modules/qlinear/__init__.py
tests/models/model_test.py
tests/test_bits.py
tests/test_q4_cuda.py

4 files changed: +47 -39 lines changed


gptqmodel/nn_modules/qlinear/__init__.py

Lines changed: 19 additions & 18 deletions
@@ -344,23 +344,24 @@ def pack(self, linear, scales, zeros, g_idx=None):
         elif self.bits == 3:
             i = 0
             col = 0
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
-            i += 10
-            qzeros[:, col] |= zeros[:, i] << 30
-            col += 1
-            qzeros[:, col] |= (zeros[:, i] >> 2) & 1
-            i += 1
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
-            i += 10
-            qzeros[:, col] |= zeros[:, i] << 31
-            col += 1
-            qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
-            i += 1
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
-            i += 10
-            col += 1
+            while col < qzeros.shape[1]:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
 
         self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype))
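For reference, the 3-bit layout packs each group of 32 zero-point values into exactly three 32-bit words: ten values fill bits 0-29 of a word, and the eleventh straddles the boundary into the next word. The removed code ran that sequence only once, so only the first 32 columns of qzeros were ever packed; the restored while loop repeats it until every output column is filled. A minimal illustrative sketch of the same scheme, using plain Python ints rather than the library's numpy arrays:

def pack_3bit(zeros):
    # zeros: flat list of 3-bit ints, length a multiple of 32
    packed = [0] * (len(zeros) // 32 * 3)  # 32 values * 3 bits = 96 bits = three 32-bit words
    i, col = 0, 0
    while col < len(packed):  # the outer loop the fix restores
        for j in range(i, i + 10):
            packed[col] |= zeros[j] << (3 * (j - i))        # first 10 values -> bits 0..29
        i += 10
        packed[col] |= (zeros[i] << 30) & 0xFFFFFFFF        # 11th value: low 2 bits -> bits 30..31
        col += 1
        packed[col] |= (zeros[i] >> 2) & 1                  # 11th value: top bit -> bit 0 of next word
        i += 1
        for j in range(i, i + 10):
            packed[col] |= zeros[j] << (3 * (j - i) + 1)    # next 10 values -> bits 1..30
        i += 10
        packed[col] |= (zeros[i] << 31) & 0xFFFFFFFF        # 22nd value: low bit -> bit 31
        col += 1
        packed[col] |= (zeros[i] >> 1) & 0x3                # 22nd value: top 2 bits -> bits 0..1 of next word
        i += 1
        for j in range(i, i + 10):
            packed[col] |= zeros[j] << (3 * (j - i) + 2)    # last 10 values -> bits 2..31
        i += 10
        col += 1
    return packed

# Without the while loop a 64-value row packs only its first three words and
# leaves the rest as zeros -- the regression this commit fixes.
assert len(pack_3bit([5] * 64)) == 6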

tests/models/model_test.py

Lines changed: 5 additions & 2 deletions
@@ -58,6 +58,7 @@ class ModelTest(unittest.TestCase):
     TORCH_DTYPE = "auto"
     BATCH_SIZE = "auto"
     LOAD_BACKEND = BACKEND.AUTO
+    QUANT_BACKEND = BACKEND.AUTO
     USE_VLLM = False
     INPUTS_MAX_LENGTH = 2048
     MODEL_MAX_LEN = 4096
@@ -83,6 +84,8 @@ class ModelTest(unittest.TestCase):
     LM_HEAD_LOSS_MAX_DELTA_PERCENT = 0.1 # ±10%
     EXPECT_LM_HEAD_LOSS = None
 
+    QUANTIZE_CONFIG_BITS = 4
+
     def assertInference(self, model, tokenizer=None, keywords=None, prompt=INFERENCE_PROMPT):
         # gptqmodel can auto init tokenizer internally
         if keywords is None:
@@ -148,7 +151,7 @@ def check_kernel(self, model, expected_kernels):
 
     def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="auto", need_eval=True, batch_size: int = 4, **kwargs):
         quantize_config = QuantizeConfig(
-            bits=4,
+            bits=self.QUANTIZE_CONFIG_BITS,
             group_size=128,
             format=self.QUANT_FORMAT,
             desc_act=self.DESC_ACT,
@@ -189,7 +192,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="aut
         is_ovis_model = model.__class__.__name__ == "OvisGPTQ"
         need_create_processor = is_image_to_text_model and not is_ovis_model
         if not is_quantized:
-            model.quantize(calibration_dataset, batch_size=batch_size)
+            model.quantize(calibration_dataset, backend=self.QUANT_BACKEND, batch_size=batch_size)
 
         self.check_kernel(model, self.KERNEL_QUANT)
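Together, these two knobs let a test subclass pin both the quantization bit width and the packing backend without overriding quantModel. A hypothetical usage sketch (the subclass name, model path, and test method are illustrative, not part of this commit):

from models.model_test import ModelTest
from gptqmodel import BACKEND

class TestTinyLlama3Bit(ModelTest):
    MODEL_PATH = "/path/to/TinyLlama-1.1B-Chat-v1.0"  # placeholder path
    QUANTIZE_CONFIG_BITS = 3       # quantModel() now builds QuantizeConfig(bits=3, ...)
    QUANT_BACKEND = BACKEND.AUTO   # forwarded to model.quantize(..., backend=...);
                                   # could be a specific kernel backend instead

    def test_quant_3bit(self):
        self.quantModel(self.MODEL_PATH)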

tests/test_bits.py

Lines changed: 19 additions & 17 deletions
@@ -54,14 +54,14 @@ class TestBits(unittest.TestCase):
         BACKEND.MARLIN: MarlinQuantLinear,
     }
 
-    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.025 # -2.5%
-    QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.025 # +2.5%
+    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.1
+    QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.1
 
     CUDA_QLINEAR_QUANTIZED_MODEL_ARC_CHALLENGE_EXPECTS = {
-        2: {'acc,none': 0.22610921501706485, 'acc_norm,none': 0.2909556313993174},
-        3: {'acc,none': 0.21245733788395904, 'acc_norm,none': 0.24744027303754265},
-        4: {'acc,none': 0.2738907849829352, 'acc_norm,none': 0.3122866894197952},
-        8: {'acc,none': 0.2841296928327645, 'acc_norm,none': 0.302901023890785},
+        2: {'acc,none': 0.2175767918088737, 'acc_norm,none': 0.26535836177474403},
+        3: {'acc,none': 0.22696245733788395, 'acc_norm,none': 0.2627986348122867},
+        4: {'acc,none': 0.26621160409556316, 'acc_norm,none': 0.3148464163822526},
+        8: {'acc,none': 0.29948805460750855, 'acc_norm,none': 0.3293515358361775},
     }
 
     def calculatorPer(self, filter, value, base_value):
@@ -92,22 +92,29 @@ def test_bits(self):
         # quantize
         model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        dataset = [
-            "gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
+        dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
         calibration_dataset = [tokenizer(example) for example in dataset]
+
+        errors = []
         for quant_backend in self.pack_backends:
             supports_bits = self.QLINEAR_DICT[quant_backend].SUPPORTS_BITS
             for bits in supports_bits:
-                print("-----------------------quant-----------------------")
+                print(f"-----------------------quant backend: {quant_backend}-- bits: {bits} ---------------------")
                 quantize_config = QuantizeConfig(bits=bits, group_size=128, sym=True, desc_act=False)
                 print(f"bits: {quantize_config.bits}, quant_backend: {quant_backend} start quant")
                 try:
                     self.quant_and_eval(calibration_dataset, model_id, quant_backend, quantize_config, tokenizer)
                 except Exception:
-                    print(f"bits: {quantize_config.bits}, quant_backend: {quant_backend} An error occurred")
+                    error_log=f"bits: {quantize_config.bits}, quant_backend: {quant_backend} An error occurred"
+                    print(error_log)
+                    errors.append(error_log)
+
                     traceback.print_exc()
+
                     continue
 
+        self.assertTrue(len(errors) == 0, '\n'.join(errors))
+
     def quant_and_eval(self, calibration_dataset, model_id, quant_backend, quantize_config, tokenizer):
         model = GPTQModel.load(
             model_id,
@@ -127,11 +134,7 @@ def quant_and_eval(self, calibration_dataset, model_id, quant_backend, quantize_
                 # Skip inference_backend that does not support the current bits
                 continue
 
-            try:
-                self.eval(inference_backend, quant_backend, quantize_config, tmp_dir)
-            except Exception:
-                traceback.print_exc()
-                continue
+            self.eval(inference_backend, quant_backend, quantize_config, tmp_dir)
 
     def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir):
         print("-----------------------eval-----------------------")
@@ -165,8 +168,7 @@ def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir):
             metric: value for metric, value in results['results'].get(TASK_NAME, {}).items()
             if metric != 'alias' and 'stderr' not in metric
         }
-        print(
-            f"bits is: {quantize_config.bits}, quant_backend: {quant_backend}, inference_backend: {inference_backend} -> task_results: {task_results}")
+        print(f"bits is: {quantize_config.bits}, quant_backend: {quant_backend}, inference_backend: {inference_backend} -> task_results: {task_results}")
         del model
 
         self.check_results(quantize_config.bits, task_results)
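The loosened delta constants define the accept window used when each run's ARC-Challenge metrics are compared against the per-bit baselines above. A rough standalone sketch of that ±10% window check (illustrative only, not the test's actual calculatorPer/check_results code; the sample run values are made up):

QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.1
QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.1

EXPECTS_3BIT = {'acc,none': 0.22696245733788395, 'acc_norm,none': 0.2627986348122867}

def within_window(metric, value, base_value):
    low = base_value * (1 - QUANT_ARC_MAX_DELTA_FLOOR_PERCENT)           # -10% floor
    high = base_value * (1 + QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT)  # +10% ceiling
    print(f"{metric}: got {value:.4f}, accept [{low:.4f}, {high:.4f}]")
    return low <= value <= high

task_results = {'acc,none': 0.2304, 'acc_norm,none': 0.2534}  # hypothetical run output
assert all(within_window(m, v, EXPECTS_3BIT[m]) for m, v in task_results.items())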

tests/test_q4_cuda.py

Lines changed: 4 additions & 2 deletions
@@ -16,13 +16,16 @@
 
 # -- do not touch
 import os
+import tempfile
+
+from gptqmodel.utils import Perplexity
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 # -- end do not touch
 
 
 import torch # noqa: E402
-from gptqmodel import BACKEND, GPTQModel # noqa: E402
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402
 from models.model_test import ModelTest # noqa: E402
 from parameterized import parameterized # noqa: E402
 from transformers import AutoTokenizer # noqa: E402
@@ -74,4 +77,3 @@ def test_generation_desc_act_false(self, torch_dtype, device):
         self.assertInference(model=model_q,tokenizer=self.tokenizer)
         # This one does not.
         self.assertInference(model=model_q.model,tokenizer=self.tokenizer)
-
0 commit comments
