Skip to content

Commit 927de72

Browse files
Auto module tree (#2204)
* add _auto_detect_module_tree * cleanup * add test_auto_detect_module_tree * cleanup * fix from_quantized * use warn level * mod log * cleanup * pass quant_method * cleanup
1 parent ec70ee6 commit 927de72

File tree

3 files changed

+150
-9
lines changed

3 files changed

+150
-9
lines changed

gptqmodel/models/auto.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,15 @@ def _is_supported_quantization_config(config: AutoConfig) -> bool:
271271
return False
272272

273273

274-
def check_and_get_model_definition(model_dir, trust_remote_code=False):
    """Resolve the model class to use for the checkpoint at `model_dir`.

    Loads the HF config to read `model_type`. Known architectures map to
    their dedicated class via MODEL_MAP; anything else falls back to
    BaseQModel, whose auto-detection will generate the module tree.
    """
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    model_type = config.model_type.lower()

    if model_type in SUPPORTED_MODELS:
        return MODEL_MAP[model_type]

    # Unknown model_type: use BaseQModel, which will run
    # _auto_detect_module_tree to generate the module tree.
    return BaseQModel
280283

281284
class GPTQModel:
282285
def __init__(self):
@@ -372,8 +375,9 @@ def from_pretrained(
372375
log.warn(
373376
"GPTQModel's per-module `dynamic` quantization feature is fully supported in latest vLLM and SGLang but not yet available in hf transformers.")
374377

375-
model_type = check_and_get_model_type(model_id_or_path, trust_remote_code)
376-
return MODEL_MAP[model_type].from_pretrained(
378+
model_definition = check_and_get_model_definition(model_id_or_path, trust_remote_code)
379+
380+
return model_definition.from_pretrained(
377381
pretrained_model_id_or_path=model_id_or_path,
378382
quantize_config=quantize_config,
379383
trust_remote_code=trust_remote_code,
@@ -395,12 +399,12 @@ def from_quantized(
395399
adapter = normalize_adapter(adapter)
396400

397401
print(f"from_quantized: adapter: {adapter}")
398-
model_type = check_and_get_model_type(model_id_or_path, trust_remote_code)
402+
model_definition = check_and_get_model_definition(model_id_or_path, trust_remote_code)
399403

400404
if isinstance(backend, str):
401405
backend = BACKEND(backend)
402406

403-
return MODEL_MAP[model_type].from_quantized(
407+
return model_definition.from_quantized(
404408
model_id_or_path=model_id_or_path,
405409
device_map=device_map,
406410
device=device,

gptqmodel/models/base.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,14 @@ def __init__(
236236
# setting cls.module_tree
237237
type(self).module_tree = apply_module_tree_override(self.module_tree, self.module_tree_overrides[quant_method])
238238

239+
if type(self).module_tree is None:
240+
type(self).module_tree = self._auto_detect_module_tree(model, quant_method)
241+
242+
# If module_tree is still None after auto-detection, raise an error indicating unsupported model type
243+
if type(self).module_tree is None:
244+
raise ValueError(f"Unsupport model_type {model.config.model_type}, and failed to auto-detect module tree for model {model}")
245+
246+
239247
# record configuration early so model lifecycle hooks can rely on them
240248
self.compiled = False # set to True while compile() is triggered successfully
241249
self.quantized = quantized
@@ -1657,6 +1665,91 @@ def __getattr__(self, item):
16571665
return getattr(model, item)
16581666
raise exc
16591667

1668+
def _auto_detect_module_tree(self, model: PreTrainedModel, quant_method: METHOD):
    """Best-effort fallback that builds a quantization module tree for an
    architecture without an explicit MODEL_MAP entry.

    Strategy: probe a fixed list of well-known layer-container paths
    (HF decoder conventions), inspect the first layer for Linear/Conv
    leaf modules, and group them under attention/MLP-like parent
    submodules.

    Returns the module-tree list (container path parts + "#" +
    {parent_name: leaf_name_tuple}), or None when detection fails
    (non-GPTQ method, no layer container found, or no quantizable leaves).
    """
    log.warn("Model not yet supported, attempting Module Tree AutoCompat...")

    # Auto-detection is only validated for GPTQ quantization.
    if quant_method != METHOD.GPTQ:
        log.warn(f"Module Tree AutoCompat: Failed, quant_method={quant_method}, only support GPTQ")
        return None

    def _resolve(path):
        # Walk a dotted attribute path from the model root;
        # None as soon as any hop is missing.
        node = model
        for attr in path.split("."):
            node = getattr(node, attr, None)
            if node is None:
                return None
        return node

    # Common locations of the repeating transformer-layer container.
    candidates = [
        "model.layers",
        "language_model.layers",
        "model.decoder.layers",
        "transformer.h",
        "transformer.blocks",
        "layers",
        "blocks",
        "model.blocks",
    ]

    chosen = None
    layers = None
    for candidate in candidates:
        container = _resolve(candidate)
        # Accept only a non-empty sequence of nn.Module layers.
        if isinstance(container, (nn.ModuleList, list, tuple)) and len(container) > 0 and isinstance(container[0], nn.Module):
            chosen = candidate
            layers = container
            log.warn(f"Module Tree AutoCompat: Matched candidate path '{candidate}', type={type(container).__name__}")
            break

    if chosen is None:
        log.warn("Module Tree AutoCompat: All candidate paths invalid, return None")
        return None

    # Layers are assumed structurally identical; inspect only the first.
    # (Reuse the container resolved above instead of re-walking the path.)
    layer0 = layers[0]
    log.warn(f"Module Tree AutoCompat: Using layer0: {type(layer0).__name__}")

    def _linear_names(module):
        # Dotted names of all quantizable (Linear/Conv) leaves in `module`.
        mods = find_modules(module, layers=[nn.Linear, nn.Conv1d, nn.Conv2d])
        log.warn(f"Module Tree AutoCompat: _linear_names: found {len(mods)} Linear/Conv modules in {type(module).__name__}")
        return list(mods.keys())

    all_linear = _linear_names(layer0)
    if all_linear:
        log.warn(f"Module Tree AutoCompat: found {len(all_linear)} Linear/Conv modules in {type(layer0).__name__}: {all_linear}")
    else:
        log.warn("Module Tree AutoCompat: No Linear/Conv names in layer0, return None")
        return None

    def _find_parents(module, keywords):
        # Direct children whose lowercased name contains any keyword.
        return {
            name
            for name, _ in module.named_children()
            if any(k in name.lower() for k in keywords)
        }

    def _leaf_tokens(prefix):
        # Leaf module names (last dotted segment) living under `prefix`.
        return tuple(x.split(".")[-1] for x in all_linear if x.startswith(f"{prefix}."))

    possible_parent = ["attn", "attention", "self_attn", "mlp", "ffn", "feed", "dense"]

    mapping = {}
    # sorted() makes the generated tree deterministic across runs
    # (set iteration order is not).
    for parent in sorted(_find_parents(layer0, possible_parent)):
        tokens = _leaf_tokens(parent)
        if tokens:
            mapping[parent] = tokens

    if not mapping:
        # No recognizable parent groups: flatten every Linear leaf at the layer root.
        blocks = tuple(n.split(".")[-1] for n in all_linear)
        mapping[""] = blocks
        log.warn(f"Module Tree AutoCompat: Mapping empty, using all Linear as fallback: {blocks}")

    tree = chosen.split(".") + ["#", mapping]
    log.warn(f"Module Tree AutoCompat: Final module_tree: {tree}")
    return tree
1752+
16601753
__all__ = ["BaseQModel"]
16611754

16621755
BaseQModel = ModelLoader(ModelWriter(BaseQModel))
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import unittest
2+
import torch.nn as nn
3+
4+
from gptqmodel.models.base import BaseQModel
5+
6+
7+
class DummyAttention(nn.Module):
    """Stub attention module exposing q/k projection Linears for the detector tests."""

    def __init__(self):
        super().__init__()
        # Registered in the same order as before so RNG consumption is unchanged.
        for proj_name in ("q_proj", "k_proj"):
            setattr(self, proj_name, nn.Linear(4, 4))
12+
13+
14+
class DummyMLP(nn.Module):
    """Stub feed-forward module with two Linear layers for the detector tests."""

    def __init__(self):
        super().__init__()
        # Registered in the same order as before so RNG consumption is unchanged.
        for fc_name in ("fc1", "fc2"):
            setattr(self, fc_name, nn.Linear(4, 4))
19+
20+
21+
class DummyBlock(nn.Module):
    """One toy decoder layer: attention + MLP children, matching the
    parent-name keywords the auto-detector looks for."""

    def __init__(self):
        super().__init__()
        attn = DummyAttention()
        ffn = DummyMLP()
        self.self_attn = attn
        self.mlp = ffn
26+
27+
28+
class DummyModel:
    """Minimal model stand-in: exposes only a top-level `layers` ModuleList,
    which is one of the candidate container paths probed by auto-detection."""

    def __init__(self):
        blocks = [DummyBlock()]
        self.layers = nn.ModuleList(blocks)
31+
32+
33+
class TestAutoDetectModuleTree(unittest.TestCase):
    """Exercise BaseQModel._auto_detect_module_tree against a toy decoder stack."""

    def test_layers_with_parents(self):
        model = DummyModel()
        # Bypass __init__ (which expects a real model/checkpoint) and invoke
        # the detector directly on a bare instance.
        base = BaseQModel.__new__(BaseQModel)
        # NOTE(review): passes the string "gptq" where the signature expects a
        # METHOD enum — presumably METHOD is a str-enum; confirm.
        tree = base._auto_detect_module_tree(model, quant_method="gptq")

        self.assertEqual(tree[0], "layers")
        self.assertEqual(tree[1], "#")

        mapping = tree[2]
        expected = (
            ("self_attn", {"q_proj", "k_proj"}),
            ("mlp", {"fc1", "fc2"}),
        )
        for parent, leaves in expected:
            self.assertIn(parent, mapping)
            self.assertSetEqual(set(mapping[parent]), leaves)

0 commit comments

Comments
 (0)