examples/vllm_serve/fakequant_worker.py (34 additions, 6 deletions)
@@ -149,12 +149,34 @@ def disable_compilation(model):
 quant_config: dict[str, Any] = {
     "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
     "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)),
-    "quant_cfg": os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"),
+    "quant_cfg": os.environ.get("QUANT_CFG", None),
     "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None),
     "amax_file_path": os.environ.get("AMAX_FILE_PATH", None),
 }


+def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) -> dict[str, Any]:
+    """Update KV cache quantization config for MLA models.

+    MLA uses `kv_c_bmm_quantizer` (compressed KV) instead of separate
+    `k_bmm_quantizer` and `v_bmm_quantizer`. This function copies the
+    config from `*[kv]_bmm_quantizer` to also cover `*kv_c_bmm_quantizer`.
+    """
+    try:
+        from vllm.attention.layer import MLAAttention
+    except ImportError:
+        return kv_quant_cfg

+    if not any(isinstance(m, MLAAttention) for m in model.modules()):
+        return kv_quant_cfg

+    if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"):
+        kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config
+        print("MLA detected: added *kv_c_bmm_quantizer config")

+    return kv_quant_cfg


 def _create_new_data_cls(data_cls, **kwargs):
     """vLLM's low-level API changes frequently. This function creates a class with parameters
     compatible with the different vLLM versions."""
@@ -237,15 +259,21 @@ def calibrate_loop(model: Any = None) -> None:
         self.sample_tokens(None)

     quant_cfg = getattr(mtq, quant_config["quant_cfg"])
-    if quant_config["kv_quant_cfg"] is not None:
-        quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
-            quant_cfg, getattr(mtq, quant_config["kv_quant_cfg"])["quant_cfg"]
-        )
+    quant_kv_cfg = getattr(mtq, quant_config["kv_quant_cfg"])

     model = self.model_runner.model
     if hasattr(model, "unwrap"):
         model = model.unwrap()

+    # Check if model has MLA and update KV config accordingly
+    if quant_kv_cfg:
+        quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"])

+    if quant_kv_cfg:
+        quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
+            quant_cfg, quant_kv_cfg["quant_cfg"]
+        )

     with disable_compilation(model):
         print("quantizing model...")
         mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
@@ -314,6 +342,6 @@ def determine_available_memory(self) -> int:
         return super().determine_available_memory()

     def compile_or_warm_up_model(self) -> None:
-        if quant_config["quant_cfg"]:
+        if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]:
             _fakequant_run_prolog_worker(self)
         super().compile_or_warm_up_model()
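
Taken together, the worker-side change is a small config rewrite before calibration. Below is a minimal sketch of what update_kv_cfg_for_mla does to a KV cache quant config once an MLA module is detected; the quantizer settings are illustrative placeholders, not a specific mtq config.

# Illustrative only: the quantizer settings below are placeholders, not a real mtq KV config.
kv_quant_cfg = {
    "*[kv]_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
}

# Effect of update_kv_cfg_for_mla when an MLAAttention module is found in the model:
# the same settings are copied under the MLA-specific wildcard key.
kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_quant_cfg["*[kv]_bmm_quantizer"]

# Standard attention still matches "*[kv]_bmm_quantizer"; MLA's compressed-KV
# quantizer is now covered by "*kv_c_bmm_quantizer".
print(kv_quant_cfg)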
examples/vllm_serve/vllm_serve_fakequant.py (7 additions, 1 deletion)
@@ -70,7 +70,13 @@


 # Adding the envs you want to pass to the workers
-additional_env_vars = {"QUANT_DATASET", "QUANT_CALIB_SIZE", "QUANT_CFG", "AMAX_FILE_PATH"}
+additional_env_vars = {
+    "QUANT_DATASET",
+    "QUANT_CALIB_SIZE",
+    "QUANT_CFG",
+    "AMAX_FILE_PATH",
+    "KV_QUANT_CFG",
+}

 RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars)

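With KV_QUANT_CFG added to the propagated set, KV cache calibration can be driven entirely from the environment. A hedged launch sketch follows; the serve script's CLI is not shown in this diff, so the model argument is a placeholder, and the KV config name stands in for whichever mtq attribute your modelopt version exposes.

# Sketch only: "<model-id>" and the CLI are placeholders. Both env var values must name
# attributes of modelopt.torch.quantization (mtq); the worker resolves them via getattr.
import os
import subprocess

env = dict(os.environ)
env["QUANT_CFG"] = "NVFP4_DEFAULT_CFG"  # weight/activation config (the previous hard-coded default)
env["KV_QUANT_CFG"] = "FP8_KV_CFG"      # placeholder KV cache config name

subprocess.run(
    ["python", "examples/vllm_serve/vllm_serve_fakequant.py", "<model-id>"],
    env=env,
    check=True,
)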
modelopt/torch/quantization/plugins/vllm.py (26 additions, 2 deletions)
@@ -40,6 +40,11 @@
     except ImportError:
         continue

+try:
+    from vllm.attention.layer import MLAAttention as VllmMLAAttention
+except ImportError:
+    VllmMLAAttention = None

 vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")


@@ -262,13 +267,17 @@ class _QuantVLLMAttention(QuantModule):
     def _setup(self):
         self.q_bmm_quantizer = TensorQuantizer()
         self.k_bmm_quantizer = TensorQuantizer()
-        self.v_bmm_quantizer = TensorQuantizer()
+        # required for vllm < 0.11.1
+        if not self.use_mla:
+            self.v_bmm_quantizer = TensorQuantizer()
         self.parallel_state = create_parallel_state()

     def forward(self, query, key, value, *args, **kwargs):
         query = self.q_bmm_quantizer(query)
         key = self.k_bmm_quantizer(key)
-        value = self.v_bmm_quantizer(value)
+        # required for vllm < 0.11.1
+        if not self.use_mla:
+            value = self.v_bmm_quantizer(value)

         return super().forward(query, key, value, *args, **kwargs)

@@ -281,3 +290,18 @@ class _QuantVLLMCrossAttention(_QuantVLLMAttention):
 @QuantModuleRegistry.register({EncoderOnlyAttention: "vllm_EncoderOnlyAttention"})
 class _QuantVLLMEncoderOnlyAttention(_QuantVLLMAttention):
     pass


+if VllmMLAAttention is not None:

+    @QuantModuleRegistry.register({VllmMLAAttention: "vllm_MLAAttention"})
+    class _QuantVLLMMLAAttention(QuantModule):
+        def _setup(self):
+            self.q_bmm_quantizer = TensorQuantizer()
+            self.kv_c_bmm_quantizer = TensorQuantizer()
+            self.parallel_state = create_parallel_state()

+        def forward(self, query, kv_c, *args, **kwargs):
+            query = self.q_bmm_quantizer(query)
+            kv_c = self.kv_c_bmm_quantizer(kv_c)
+            return super().forward(query, kv_c, *args, **kwargs)
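
The quantizer attributes created in _setup above are what the worker-side wildcard keys have to match. A small self-contained sketch of why the extra key is needed; fnmatch stands in for modelopt's wildcard matching, which may differ in detail, and the module paths are made up for illustration.

import fnmatch

# Placeholder module paths for a standard attention quantizer and the new MLA quantizer.
standard_kv = "model.layers.0.self_attn.attn.k_bmm_quantizer"
mla_kv_c = "model.layers.0.self_attn.mla_attn.kv_c_bmm_quantizer"

# "*[kv]_bmm_quantizer" matches k_/v_bmm_quantizer but never kv_c_bmm_quantizer,
# which is why update_kv_cfg_for_mla copies that entry to "*kv_c_bmm_quantizer".
for name in (standard_kv, mla_kv_c):
    for pattern in ("*[kv]_bmm_quantizer", "*kv_c_bmm_quantizer"):
        print(f"{name!r} matches {pattern!r}: {fnmatch.fnmatch(name, pattern)}")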