From df5a0e543427ad34d76b42f432840d0f0fcbf36e Mon Sep 17 00:00:00 2001
From: binghanc <176802681+binghanc@users.noreply.github.com>
Date: Tue, 18 Nov 2025 02:01:11 -0800
Subject: [PATCH 1/2] support for dpskr1_nvfp4_v3

---
 examples/deepseek/ptq.py | 79 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 77 insertions(+), 2 deletions(-)

diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py
index b091ddc0e..ffeed4cce 100644
--- a/examples/deepseek/ptq.py
+++ b/examples/deepseek/ptq.py
@@ -300,7 +300,7 @@ def calibrate_loop(model):
     # disable head that corresponds to lm_head (for the huggingface checkpoint)
     mtq_cfg["quant_cfg"]["*head*"] = {"enable": False}
 
-    allowed_mla_quant = [None, "per_tensor_fp8"]
+    allowed_mla_quant = [None, "per_tensor_fp8", "nvfp4_wq_a_wkv_a_wq_b_wo", "nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b"]
     assert mla_quant in allowed_mla_quant, f"mla_quant must be {allowed_mla_quant}"
 
     if not mla_quant:
@@ -308,12 +308,81 @@ def calibrate_loop(model):
     elif mla_quant == "per_tensor_fp8":
         mtq_cfg["quant_cfg"]["*attn*weight_quantizer"] = {"num_bits": (4, 3), "axis": None}
         mtq_cfg["quant_cfg"]["*attn*input_quantizer"] = {"num_bits": (4, 3), "axis": None}
+    elif mla_quant == "nvfp4_wq_a_wkv_a_wq_b_wo":  # for DeepSeek-R1-0528-v3_1
+        # Only quantize linear layers (wq_a, wq_b, wkv_a, wo) in MLA, not BMM operations
+        mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"]
+        mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"]
+        for layer in mla_linear_layers:
+            if layer in mla_nvfp4_linear_layers:
+                mtq_cfg["quant_cfg"][layer + "_quantizer"] = {
+                    "num_bits": (2, 1),
+                    "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+                    "axis": None,
+                    "enable": True,
+                }
+            else:
+                mtq_cfg["quant_cfg"][layer + "_quantizer"] = {"enable": False}
+
+        # Disable BMM quantizers
+        mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False}
+        mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False}
+
+    elif mla_quant == "nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b":  # for DeepSeek-R1-0528-v3_2
+        # wq_a, wkv_a, wq_b, wo use NVFP4
+        # wkv_b uses FP8 per-tensor quantization (weight: normal scale, activation: scale=1)
+        mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"]
+        mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"]
+
+        for layer in mla_linear_layers:
+            if layer in mla_nvfp4_linear_layers:
+                # NVFP4 quantization
+                mtq_cfg["quant_cfg"][layer + "_quantizer"] = {
+                    "num_bits": (2, 1),
+                    "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+                    "axis": None,
+                    "enable": True,
+                }
+            elif layer == "*wkv_b*":
+                # wkv_b uses FP8 per-tensor quantization
+                mtq_cfg["quant_cfg"][layer + "weight_quantizer"] = {
+                    "num_bits": (4, 3),  # FP8
+                    "axis": None,
+                    "enable": True,
+                }
+                mtq_cfg["quant_cfg"][layer + "input_quantizer"] = {
+                    "num_bits": (4, 3),  # FP8
+                    "axis": None,
+                    "enable": True,
+                }
+            else:
+                mtq_cfg["quant_cfg"][layer + "_quantizer"] = {"enable": False}
+
+        # Disable BMM quantizers
+        mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False}
+        mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False}
 
     if not args.disable_wo_quant and "FP4" in quant_cfg:
         mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"]
         mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"]
 
     ## ptq
     transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop)
+
+    # Force wkv_b activation scale=1 for nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b
+    if mla_quant == "nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b":
+        fp8_max_value = 448.0  # FP8 E4M3 max value
+
+        for name, module in transformer.named_modules():
+            # Match wkv_b layers that have an enabled activation quantizer
+            if "wkv_b" in name and hasattr(module, "input_quantizer") and module.input_quantizer.is_enabled:
+                # Force activation amax = 448.0, so scale = amax / 448.0 = 1.0
+                old_amax = module.input_quantizer._amax.data.clone()
+                module.input_quantizer._amax.data.fill_(fp8_max_value)
+                if int(os.environ.get("LOCAL_RANK", "0")) == 0:
+                    print(
+                        f"[INFO] Forced {name}.input_quantizer amax "
+                        f"from {old_amax.item()} to {fp8_max_value}"
+                    )
 
     if int(os.environ["LOCAL_RANK"]) == 0:
         mtq.print_quant_summary(transformer)
@@ -396,11 +465,17 @@ def state_dict_filter(state_dict):
     parser.add_argument("--disable_fp8_kvcache", action="store_true", help="disable fp8 kvcache.")
     parser.add_argument("--disable_wo_quant", action="store_true", help="disable MLA wo quant.")
     parser.add_argument("--trust_remote_code", action="store_true", help="trust remote code.")
+    parser.add_argument(
+        "--mla_quant",
+        type=str,
+        default=None,
+        help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4 (all), or nvfp4_linear_only (linear layers only)"
+    )
     args = parser.parse_args()
 
     model = load_deepseek_model(args.config, args.model_path, args.batch_size)
     tokenizer = AutoTokenizer.from_pretrained(
         args.model_path, trust_remote_code=args.trust_remote_code
     )
-    model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size)
+    model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, args.mla_quant)
     save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)

From d2e618cefbc11d69e51910758e72e4be2adad14e Mon Sep 17 00:00:00 2001
From: binghanc <176802681+binghanc@users.noreply.github.com>
Date: Thu, 20 Nov 2025 03:15:47 +0000
Subject: [PATCH 2/2] modify argparse help info for mla_quant

---
 examples/deepseek/ptq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py
index ffeed4cce..4f236d6eb 100644
--- a/examples/deepseek/ptq.py
+++ b/examples/deepseek/ptq.py
@@ -469,7 +469,7 @@ def state_dict_filter(state_dict):
         "--mla_quant",
         type=str,
         default=None,
-        help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4 (all), or nvfp4_linear_only (linear layers only)"
+        help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4_wq_a_wkv_a_wq_b_wo, or nvfp4_wq_a_wkv_a_wq_b_wo_fp8_wkv_b"
     )
     args = parser.parse_args()
 