     deep_compare,
     should_ignore_layer,
 )
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.platforms import current_platform
 
 if TYPE_CHECKING:
@@ -57,7 +58,6 @@ def __init__(
         self.kv_cache_group = kv_cache_group
         self.kv_cache_config = kv_cache_config
         self.pack_method = pack_method
-        self.ignore: list[str] = cast(list[str], self.quant_config.get("exclude", []))
 
     def get_linear_method(self) -> "QuarkLinearMethod":
         return QuarkLinearMethod(self)
@@ -72,14 +72,42 @@ def get_min_capability(cls) -> int:
     def get_name(self) -> QuantizationMethods:
         return "quark"
 
+    def apply_vllm_mapper(  # noqa: B027
+        self, hf_to_vllm_mapper: "WeightsMapper"
+    ):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure
+
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        quant_config_with_hf_to_vllm_mapper = {}
+
+        for k, v in self.quant_config.items():
+            if isinstance(v, list):
+                quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_list(v)
+            elif isinstance(v, dict):
+                quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_dict(v)
+            else:
+                if isinstance(v, str):
+                    mapped_v_list = hf_to_vllm_mapper.apply_list([v])
+                    if mapped_v_list:
+                        quant_config_with_hf_to_vllm_mapper[k] = mapped_v_list[0]
+                else:
+                    quant_config_with_hf_to_vllm_mapper[k] = v
+
+        self.quant_config = quant_config_with_hf_to_vllm_mapper
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
         # Check if the layer is skipped for quantization.
+        exclude_layers = cast(list[str], self.quant_config.get("exclude"))
         if should_ignore_layer(
-            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+            prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
@@ -93,9 +121,6 @@ def get_quant_method(
             return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
         return None
 
-    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
-        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
-
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
         export_config = config.get("export")
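
For readers skimming the diff, here is a minimal sketch of what the new apply_vllm_mapper does to a Quark quant config. _DemoMapper and the config values below are hypothetical stand-ins; the real remapping is driven by vLLM's WeightsMapper, which the diff shows exposing apply_list and apply_dict. The loop mirrors the dispatch logic added above.

# Hypothetical stand-in for vllm's WeightsMapper: rewrites HF-style
# module-name prefixes into vLLM-style prefixes.
class _DemoMapper:
    def __init__(self, orig_to_new_prefix: dict):
        self.orig_to_new_prefix = orig_to_new_prefix

    def _map_name(self, name: str) -> str:
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name

    def apply_list(self, values: list) -> list:
        # Remap every module name in a list (e.g. the "exclude" list).
        return [self._map_name(v) for v in values]

    def apply_dict(self, values: dict) -> dict:
        # Remap the keys of per-layer config dicts, keeping the values.
        return {self._map_name(k): v for k, v in values.items()}


# Illustrative (not real) Quark config keyed by HF module names.
quant_config = {
    "exclude": ["model.layers.0.mlp.gate", "lm_head"],
    "layer_quant_config": {"model.layers.0.self_attn.qkv_proj": {"dtype": "fp8"}},
    "pack_method": "reorder",
}

# Assumed prefix mapping for a multimodal wrapper around the language model.
mapper = _DemoMapper({"model.": "language_model.model."})

# Same dispatch as apply_vllm_mapper in the diff: lists and dicts are
# remapped wholesale, strings individually, everything else passes through.
remapped = {}
for k, v in quant_config.items():
    if isinstance(v, list):
        remapped[k] = mapper.apply_list(v)
    elif isinstance(v, dict):
        remapped[k] = mapper.apply_dict(v)
    elif isinstance(v, str):
        remapped[k] = mapper.apply_list([v])[0]
    else:
        remapped[k] = v

print(remapped["exclude"])
# ['language_model.model.layers.0.mlp.gate', 'lm_head']
print(list(remapped["layer_quant_config"]))
# ['language_model.model.layers.0.self_attn.qkv_proj']

Because the whole config is remapped up front, get_quant_method can read the "exclude" list straight from self.quant_config instead of keeping the separately remapped self.ignore attribute that this change removes.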