diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py
index 133f31a6c..9be41f72e 100644
--- a/examples/vllm_serve/fakequant_worker.py
+++ b/examples/vllm_serve/fakequant_worker.py
@@ -149,12 +149,34 @@ def disable_compilation(model):
 quant_config: dict[str, Any] = {
     "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
     "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)),
-    "quant_cfg": os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"),
+    "quant_cfg": os.environ.get("QUANT_CFG", None),
     "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None),
     "amax_file_path": os.environ.get("AMAX_FILE_PATH", None),
 }
 
 
+def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) -> dict[str, Any]:
+    """Update KV cache quantization config for MLA models.
+
+    MLA uses `kv_c_bmm_quantizer` (compressed KV) instead of separate
+    `k_bmm_quantizer` and `v_bmm_quantizer`. This function copies the
+    config from `*[kv]_bmm_quantizer` to also cover `*kv_c_bmm_quantizer`.
+    """
+    try:
+        from vllm.attention.layer import MLAAttention
+    except ImportError:
+        return kv_quant_cfg
+
+    if not any(isinstance(m, MLAAttention) for m in model.modules()):
+        return kv_quant_cfg
+
+    if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"):
+        kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config
+        print("MLA detected: added *kv_c_bmm_quantizer config")
+
+    return kv_quant_cfg
+
+
 def _create_new_data_cls(data_cls, **kwargs):
     """vLLM's low-level API changes frequently. This function creates a class with parameters
     compatible with the different vLLM versions."""
@@ -237,15 +259,21 @@ def calibrate_loop(model: Any = None) -> None:
         self.sample_tokens(None)
 
     quant_cfg = getattr(mtq, quant_config["quant_cfg"])
-    if quant_config["kv_quant_cfg"] is not None:
-        quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
-            quant_cfg, getattr(mtq, quant_config["kv_quant_cfg"])["quant_cfg"]
-        )
+    quant_kv_cfg = getattr(mtq, quant_config["kv_quant_cfg"]) if quant_config["kv_quant_cfg"] else None
 
     model = self.model_runner.model
     if hasattr(model, "unwrap"):
         model = model.unwrap()
 
+    # Check if the model has MLA and update the KV config accordingly
+    if quant_kv_cfg:
+        quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"])
+
+    if quant_kv_cfg:
+        quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
+            quant_cfg, quant_kv_cfg["quant_cfg"]
+        )
+
     with disable_compilation(model):
         print("quantizing model...")
         mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
@@ -314,6 +342,6 @@ def determine_available_memory(self) -> int:
         return super().determine_available_memory()
 
     def compile_or_warm_up_model(self) -> None:
-        if quant_config["quant_cfg"]:
+        if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]:
             _fakequant_run_prolog_worker(self)
         super().compile_or_warm_up_model()
diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py
index b4b230ade..25483f2be 100644
--- a/examples/vllm_serve/vllm_serve_fakequant.py
+++ b/examples/vllm_serve/vllm_serve_fakequant.py
@@ -70,7 +70,13 @@
 
 
 # Adding the envs you want to pass to the workers
-additional_env_vars = {"QUANT_DATASET", "QUANT_CALIB_SIZE", "QUANT_CFG", "AMAX_FILE_PATH"}
+additional_env_vars = {
+    "QUANT_DATASET",
+    "QUANT_CALIB_SIZE",
+    "QUANT_CFG",
+    "AMAX_FILE_PATH",
+    "KV_QUANT_CFG",
+}
 RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars)
 
 
diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py
index 99bc3d9ee..bd4b998f9 100644
--- a/modelopt/torch/quantization/plugins/vllm.py
+++ b/modelopt/torch/quantization/plugins/vllm.py
@@ -40,6 +40,11 @@
     except ImportError:
         continue
 
+try:
+    from vllm.attention.layer import MLAAttention as VllmMLAAttention
+except ImportError:
+    VllmMLAAttention = None
+
 
 vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
 
@@ -262,13 +267,17 @@ class _QuantVLLMAttention(QuantModule):
     def _setup(self):
         self.q_bmm_quantizer = TensorQuantizer()
         self.k_bmm_quantizer = TensorQuantizer()
-        self.v_bmm_quantizer = TensorQuantizer()
+        # required for vllm < 0.11.1
+        if not self.use_mla:
+            self.v_bmm_quantizer = TensorQuantizer()
         self.parallel_state = create_parallel_state()
 
     def forward(self, query, key, value, *args, **kwargs):
         query = self.q_bmm_quantizer(query)
         key = self.k_bmm_quantizer(key)
-        value = self.v_bmm_quantizer(value)
+        # required for vllm < 0.11.1
+        if not self.use_mla:
+            value = self.v_bmm_quantizer(value)
 
         return super().forward(query, key, value, *args, **kwargs)
 
@@ -281,3 +290,18 @@ class _QuantVLLMCrossAttention(_QuantVLLMAttention):
 @QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"})
 class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention):
     pass
+
+
+if VllmMLAAttention is not None:
+
+    @QuantModuleRegistry.register({VllmMLAAttention: "vllm_MLAAttention"})
+    class _QuantVLLMMLAAttention(QuantModule):
+        def _setup(self):
+            self.q_bmm_quantizer = TensorQuantizer()
+            self.kv_c_bmm_quantizer = TensorQuantizer()
+            self.parallel_state = create_parallel_state()
+
+        def forward(self, query, kv_c, *args, **kwargs):
+            query = self.q_bmm_quantizer(query)
+            kv_c = self.kv_c_bmm_quantizer(kv_c)
+            return super().forward(query, kv_c, *args, **kwargs)
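
Illustrative sketch (not part of the patch): the snippet below mirrors what update_kv_cfg_for_mla does to the "quant_cfg" section of a KV-cache quantization config once an MLA module is detected, namely reusing the regular K/V quantizer pattern for the compressed-KV quantizer. The quantizer settings shown are placeholders, not values taken from modelopt.

    # Placeholder KV-cache quantizer patterns; a real config comes from
    # modelopt.torch.quantization (the "quant_cfg" dict of a KV-cache config).
    kv_quant_cfg = {
        "*[kv]_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
        "default": {"enable": False},
    }

    # MLA feeds a single compressed KV tensor through kv_c_bmm_quantizer, so the
    # existing K/V pattern is reused under the "*kv_c_bmm_quantizer" key.
    if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"):
        kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config

    print(kv_quant_cfg["*kv_c_bmm_quantizer"])  # same settings as "*[kv]_bmm_quantizer"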
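
The worker-side lookup of these configs can be exercised on its own roughly as below. This is a hedged sketch: it assumes the usual modelopt import alias and that a KV-cache config attribute such as NVFP4_KV_CFG exists in the installed modelopt version (that name is an assumption; NVFP4_DEFAULT_CFG is the weight/activation config referenced by the patch).

    import os

    import modelopt.torch.quantization as mtq

    # Config names arrive via the env vars forwarded to the Ray workers.
    quant_cfg_name = os.environ.get("QUANT_CFG")        # e.g. "NVFP4_DEFAULT_CFG"
    kv_quant_cfg_name = os.environ.get("KV_QUANT_CFG")  # e.g. "NVFP4_KV_CFG" (assumed name)

    # Only resolve names that were actually provided; None means "skip that step".
    quant_cfg = getattr(mtq, quant_cfg_name) if quant_cfg_name else None
    quant_kv_cfg = getattr(mtq, kv_quant_cfg_name) if kv_quant_cfg_name else None

    if quant_cfg is not None and quant_kv_cfg is not None:
        # Merge the KV-cache quantizer patterns into the weight/activation config,
        # as the worker does before calling mtq.quantize.
        quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
            quant_cfg, quant_kv_cfg["quant_cfg"]
        )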