@@ -236,6 +236,14 @@ def __init__(
236236 # setting cls.module_tree
237237 type(self ).module_tree = apply_module_tree_override (self .module_tree , self .module_tree_overrides [quant_method ])
238238
239+ if type (self ).module_tree is None :
240+ type(self ).module_tree = self ._auto_detect_module_tree (model , quant_method )
241+
242+ # If module_tree is still None after auto-detection, raise an error indicating unsupported model type
243+ if type (self ).module_tree is None :
244+ raise ValueError (f"Unsupport model_type { model .config .model_type } , and failed to auto-detect module tree for model { model } " )
245+
246+
239247 # record configuration early so model lifecycle hooks can rely on them
240248 self .compiled = False # set to True while compile() is triggered successfully
241249 self .quantized = quantized
@@ -1657,6 +1665,91 @@ def __getattr__(self, item):
16571665 return getattr (model , item )
16581666 raise exc
16591667
1668+ def _auto_detect_module_tree (self , model : PreTrainedModel , quant_method : METHOD ):
1669+ log .warn ("Model not yet support, attempting Module Tree AutoCompat..." )
1670+
1671+ if quant_method != METHOD .GPTQ :
1672+ log .warn (f"Module Tree AutoCompat: Failed, quant_method={ quant_method } , only support GPTQ" )
1673+ return None
1674+
1675+ def _get (path ):
1676+ base = model
1677+ for p in path .split ("." ):
1678+ base = getattr (base , p , None )
1679+ if base is None :
1680+ return None
1681+ return base
1682+
1683+ candidates = [
1684+ "model.layers" ,
1685+ "language_model.layers" ,
1686+ "model.decoder.layers" ,
1687+ "transformer.h" ,
1688+ "transformer.blocks" ,
1689+ "layers" ,
1690+ "blocks" ,
1691+ "model.blocks" ,
1692+ ]
1693+
1694+ chosen = None
1695+ for c in candidates :
1696+ m = _get (c )
1697+ if isinstance (m , (nn .ModuleList , list , tuple )) and len (m ) > 0 and isinstance (m [0 ], nn .Module ):
1698+ chosen = c
1699+ log .warn (f"Module Tree AutoCompat: Matched candidate path '{ c } ', type={ type (m ).__name__ } " )
1700+ break
1701+
1702+ if chosen is None :
1703+ log .warn ("Module Tree AutoCompat: All candidate paths invalid, return None" )
1704+ return None
1705+
1706+ layer0 = _get (chosen )[0 ]
1707+ log .warn (f"Module Tree AutoCompat: Using layer0: { type (layer0 ).__name__ } " )
1708+
1709+ def _linear_names (module ):
1710+ mods = find_modules (module , layers = [nn .Linear , nn .Conv1d , nn .Conv2d ])
1711+ log .warn (f"Module Tree AutoCompat: _linear_names: found { len (mods )} Linear/Conv modules in { type (module ).__name__ } " )
1712+ return list (mods .keys ())
1713+
1714+ all_linear = _linear_names (layer0 )
1715+ if len (all_linear )> 0 :
1716+ log .warn (f"Module Tree AutoCompat: found { len (all_linear )} Linear/Conv modules in { type (layer0 ).__name__ } : { all_linear } " )
1717+ else :
1718+ log .warn (f"Module Tree AutoCompat: No Linear/Conv names in layer0, return None" )
1719+ return None
1720+
1721+ mapping = {}
1722+
1723+ def _find_parents (module , possible_names ):
1724+ found = set ()
1725+ for n , _ in module .named_children ():
1726+ l = n .lower ()
1727+ if any (k in l for k in possible_names ):
1728+ found .add (n )
1729+ return found
1730+
1731+ def _leaf_tokens (prefix ):
1732+ return tuple (x .split ("." )[- 1 ] for x in all_linear if x .startswith (f"{ prefix } ." ))
1733+
1734+ possible_parent = ["attn" , "attention" , "self_attn" , "mlp" , "ffn" , "feed" , "dense" ]
1735+
1736+ found_parents = _find_parents (layer0 , possible_parent )
1737+
1738+ for p in found_parents :
1739+ t = _leaf_tokens (p )
1740+ if t :
1741+ mapping [p ] = t
1742+
1743+ if not mapping :
1744+ blocks = tuple (n .split ("." )[- 1 ] for n in all_linear )
1745+ mapping ["" ] = blocks
1746+ log .warn (f"Module Tree AutoCompat: Mapping empty, using all Linear as fallback: { blocks } " )
1747+
1748+ parts = chosen .split ("." )
1749+ tree = parts + ["#" , mapping ]
1750+ log .warn (f"Module Tree AutoCompat: Final module_tree: { tree } " )
1751+ return tree
1752+
16601753__all__ = ["BaseQModel" ]
16611754
16621755BaseQModel = ModelLoader (ModelWriter (BaseQModel ))
0 commit comments