Commit d0a7362

[ROCm][Quantization] add apply_vllm_mapper in quark config for models like gpt-oss (#28638)
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 88ab591 commit d0a7362

File tree

1 file changed: +30 -5 lines changed

  • vllm/model_executor/layers/quantization/quark


vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 30 additions & 5 deletions
@@ -32,6 +32,7 @@
     deep_compare,
     should_ignore_layer,
 )
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.platforms import current_platform
 
 if TYPE_CHECKING:
@@ -57,7 +58,6 @@ def __init__(
         self.kv_cache_group = kv_cache_group
         self.kv_cache_config = kv_cache_config
         self.pack_method = pack_method
-        self.ignore: list[str] = cast(list[str], self.quant_config.get("exclude", []))
 
     def get_linear_method(self) -> "QuarkLinearMethod":
         return QuarkLinearMethod(self)
@@ -72,14 +72,42 @@ def get_min_capability(cls) -> int:
     def get_name(self) -> QuantizationMethods:
         return "quark"
 
+    def apply_vllm_mapper(  # noqa: B027
+        self, hf_to_vllm_mapper: "WeightsMapper"
+    ):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure
+
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        quant_config_with_hf_to_vllm_mapper = {}
+
+        for k, v in self.quant_config.items():
+            if isinstance(v, list):
+                quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_list(v)
+            elif isinstance(v, dict):
+                quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_dict(v)
+            else:
+                if isinstance(v, str):
+                    mapped_v_list = hf_to_vllm_mapper.apply_list([v])
+                    if mapped_v_list:
+                        quant_config_with_hf_to_vllm_mapper[k] = mapped_v_list[0]
+                else:
+                    quant_config_with_hf_to_vllm_mapper[k] = v
+
+        self.quant_config = quant_config_with_hf_to_vllm_mapper
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
         # Check if the layer is skipped for quantization.
+        exclude_layers = cast(list[str], self.quant_config.get("exclude"))
         if should_ignore_layer(
-            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+            prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
@@ -93,9 +121,6 @@ def get_quant_method(
             return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
         return None
 
-    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
-        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
-
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
         export_config = config.get("export")
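To make the behavior of the new apply_vllm_mapper concrete, here is a minimal standalone sketch. PrefixMapper, the config entries, and the prefix mapping are hypothetical stand-ins (not vLLM's WeightsMapper and not gpt-oss's real Quark config); only the "exclude" key comes from the diff above. The list / dict / str branching mirrors the loop added in this commit, rewriting HF-style module names in the quant config so they match the vLLM module layout.

# Hypothetical stand-in for WeightsMapper: renames module-name prefixes.
class PrefixMapper:
    def __init__(self, orig_to_new_prefix: dict[str, str]):
        self.orig_to_new_prefix = orig_to_new_prefix

    def _map_name(self, name: str) -> str:
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name

    def apply_list(self, values: list[str]) -> list[str]:
        return [self._map_name(v) for v in values]

    def apply_dict(self, values: dict) -> dict:
        return {self._map_name(k): v for k, v in values.items()}


# Hypothetical quant config written against the HF checkpoint layout.
quant_config = {
    "exclude": ["model.embed_tokens", "lm_head"],
    "scale_source": "model.layers",            # made-up string-valued entry
    "per_layer_overrides": {"model.layers": {}},  # made-up dict-valued entry
}
mapper = PrefixMapper({"model.": "language_model.model."})

# Same list / dict / str handling as QuarkConfig.apply_vllm_mapper above.
remapped = {}
for k, v in quant_config.items():
    if isinstance(v, list):
        remapped[k] = mapper.apply_list(v)
    elif isinstance(v, dict):
        remapped[k] = mapper.apply_dict(v)
    elif isinstance(v, str):
        mapped = mapper.apply_list([v])
        if mapped:
            remapped[k] = mapped[0]
    else:
        remapped[k] = v

print(remapped["exclude"])  # ['language_model.model.embed_tokens', 'lm_head']

With the mapping applied up front, an "exclude" pattern written for the HF layout still matches the renamed vLLM modules, which is why get_quant_method can now read self.quant_config.get("exclude") directly instead of the removed self.ignore attribute.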
