From 7430041bd082698723c0fcceb1bee9f67fac33e9 Mon Sep 17 00:00:00 2001 From: William Arnold Date: Wed, 11 Feb 2026 15:22:49 -0800 Subject: [PATCH 1/3] Add GLM-5 (HF bf16) support to DeepSeek quantization pipeline - Make DeepSeek-specific imports lazy (only loaded for --model_type deepseek) - Add load_hf_model() using device_map="auto" for single-node multi-GPU - Add dynamic MoE class discovery and registration for calibration - Add HF layer name patterns for quant config (q_a_proj, kv_a_proj, etc.) - Disable GLM-5 indexer/MTP layers from quantization - Make dist calls conditional for non-distributed HF path - Add --model_type flag to ptq.py, quantize_to_nvfp4.py, and shell script - Skip key remapping in quantize_to_nvfp4.py for HF models - Guard quantization_config removal for bf16 checkpoints - Add run_glm5_ptq.sh launch script Co-Authored-By: Claude Opus 4.6 --- examples/deepseek/ptq.py | 229 +++++++++++++++++---- examples/deepseek/quantize_fp8_to_nvfp4.sh | 10 +- examples/deepseek/quantize_to_nvfp4.py | 36 +++- examples/deepseek/run_glm5_ptq.sh | 27 +++ 4 files changed, 256 insertions(+), 46 deletions(-) create mode 100755 examples/deepseek/run_glm5_ptq.sh diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index 6b1086a30..c5386baf9 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -64,24 +64,43 @@ from modelopt.torch.utils.dataset_utils import get_dataset_dataloader from modelopt.torch.utils.distributed import ParallelState -DS_V3_PATH = Path(__file__).resolve().parent / "DeepSeek-V3/inference" -DS_V3_2_PATH = Path(__file__).resolve().parent / "DeepSeek-V3.2-Exp/inference" - -if DS_V3_2_PATH.exists(): - sys.path.append(str(DS_V3_2_PATH)) -elif DS_V3_PATH.exists(): - sys.path.append(str(DS_V3_PATH)) -else: - raise ValueError( - f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in {Path(__file__).resolve().parent}" - ) +deekseep_model = None +weight_dequant = None +act_quant = None +fp8_gemm = None + + +def _import_deepseek_deps(): 
+ """Lazily import DeepSeek-specific dependencies (only needed for --model_type deepseek).""" + global deekseep_model, weight_dequant, act_quant, fp8_gemm + if deekseep_model is not None: + return + + DS_V3_PATH = Path(__file__).resolve().parent / "DeepSeek-V3/inference" + DS_V3_2_PATH = Path(__file__).resolve().parent / "DeepSeek-V3.2-Exp/inference" + + if DS_V3_2_PATH.exists(): + sys.path.append(str(DS_V3_2_PATH)) + elif DS_V3_PATH.exists(): + sys.path.append(str(DS_V3_PATH)) + else: + raise ValueError( + f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in {Path(__file__).resolve().parent}" + ) -import model as deekseep_model # noqa: E402 -from ds_kernel import weight_dequant # noqa: E402 -from kernel import act_quant, fp8_gemm # noqa: E402 + import model as _model # noqa: E402 + from ds_kernel import weight_dequant as _weight_dequant # noqa: E402 + from kernel import act_quant as _act_quant, fp8_gemm as _fp8_gemm # noqa: E402 + + deekseep_model = _model + weight_dequant = _weight_dequant + act_quant = _act_quant + fp8_gemm = _fp8_gemm def monkey_patch_deepseek_model(): + _import_deepseek_deps() + gemm_impl: Literal["bf16", "fp8"] = "bf16" block_size = 128 @@ -231,8 +250,106 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: mtq.register(original_cls=deekseep_model.MoE, quantized_cls=CalibMoe) +def _register_hf_moe_for_calibration(model): + """Discover and register the MoE module class from an HF model for calibration. + + This dynamically finds the MoE class (has `experts` + `gate` attributes) and registers it + with modelopt's _QuantSparseMoe so all experts see tokens during calibration. 
+ """ + from modelopt.torch.quantization.nn import QuantModuleRegistry + from modelopt.torch.quantization.plugins.huggingface import _QuantSparseMoe + + moe_cls = None + for _name, module in model.named_modules(): + cls = type(module) + if hasattr(module, "experts") and hasattr(module, "gate") and cls not in QuantModuleRegistry: + moe_cls = cls + break + + if moe_cls is None: + print("Warning: No unregistered MoE module class found in model.") + return + + # Check if the MoE class uses standard attribute names (top_k, num_experts) + sample = next(m for m in model.modules() if isinstance(m, moe_cls)) + has_top_k = hasattr(sample, "top_k") + has_num_experts = hasattr(sample, "num_experts") + + if has_top_k and has_num_experts: + # Standard attribute names - register directly + QuantModuleRegistry.register({moe_cls: f"hf.{moe_cls.__name__}"})(_QuantSparseMoe) + else: + # Need property adapters for non-standard attribute names + # Discover the actual attribute names + top_k_attr = "top_k" if has_top_k else None + num_experts_attr = "num_experts" if has_num_experts else None + + if top_k_attr is None: + for attr in ["num_experts_per_tok", "top_k", "topk"]: + if hasattr(sample, attr): + top_k_attr = attr + break + if num_experts_attr is None: + for attr in ["num_experts", "n_routed_experts", "num_local_experts"]: + if hasattr(sample, attr): + num_experts_attr = attr + break + + if top_k_attr is None or num_experts_attr is None: + print( + f"Warning: Could not find top_k/num_experts attributes on {moe_cls.__name__}. " + f"Skipping MoE calibration registration." 
+ ) + return + + # Create adapter subclass + _top_k_attr = top_k_attr + _num_experts_attr = num_experts_attr + + class _QuantAdaptedSparseMoe(_QuantSparseMoe): + @property + def top_k(self): + return getattr(self, _top_k_attr) + + @top_k.setter + def top_k(self, value): + setattr(self, _top_k_attr, value) + + @property + def num_experts(self): + return getattr(self, _num_experts_attr) + + QuantModuleRegistry.register({moe_cls: f"hf.{moe_cls.__name__}"})(_QuantAdaptedSparseMoe) + + print(f"Registered MoE class {moe_cls.__name__} for calibration.") + + +def load_hf_model(model_path: str): + """Load a HuggingFace model (e.g. GLM-5 bf16 checkpoint). + + Uses device_map="auto" to shard the model across all visible GPUs (single process). + """ + from transformers import AutoModelForCausalLM + + torch.set_default_dtype(torch.bfloat16) + + model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map="auto", + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + model.eval() + + _register_hf_moe_for_calibration(model) + + return model + + def load_deepseek_model(model_config: str, model_path: str, batch_size: int): """Loads the deepseek model to memory.""" + _import_deepseek_deps() + # get distributed info world_size = int(os.getenv("WORLD_SIZE", "1")) rank = int(os.getenv("RANK", "0")) @@ -280,8 +397,11 @@ def ptq( batch_size: int, calib_size: int, mla_quant: str | None = None, + model_type: str = "deepseek", + disable_wo_quant: bool = False, ): - """Runs Deepseek model PTQ and returns the quantized model.""" + """Runs Deepseek/HF model PTQ and returns the quantized model.""" + is_hf = model_type != "deepseek" # quantize the model ## create dataset @@ -303,7 +423,8 @@ def calibrate_loop(model): transformer = model.model if hasattr(model, "model") else model # make sure all processes are ready before starting the calibration - dist.barrier() + if dist.is_initialized(): + dist.barrier() ## quant config mtq_cfg = getattr(mtq, quant_cfg) @@ -311,6 +432,14 @@ 
def calibrate_loop(model): # disable head that corresponds to lm_head (for the huggingface checkpoint) mtq_cfg["quant_cfg"]["*head*"] = {"enable": False} + if is_hf: + # Disable GLM-5 / HF-specific layers that should not be quantized + mtq_cfg["quant_cfg"]["*indexer*"] = {"enable": False} + mtq_cfg["quant_cfg"]["*eh_proj*"] = {"enable": False} + mtq_cfg["quant_cfg"]["*enorm*"] = {"enable": False} + mtq_cfg["quant_cfg"]["*hnorm*"] = {"enable": False} + mtq_cfg["quant_cfg"]["*shared_head*"] = {"enable": False} + allowed_mla_quant = [None, "per_tensor_fp8", "nvfp4"] assert mla_quant in allowed_mla_quant, f"mla_quant must be {allowed_mla_quant}" @@ -320,11 +449,14 @@ def calibrate_loop(model): mtq_cfg["quant_cfg"]["*attn*weight_quantizer"] = {"num_bits": (4, 3), "axis": None} mtq_cfg["quant_cfg"]["*attn*input_quantizer"] = {"num_bits": (4, 3), "axis": None} elif mla_quant == "nvfp4": # for DeepSeek-R1-0528-NVFP4-Turbo - mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"] - mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"] + if is_hf: + mla_linear_layers = ["*q_a_proj*", "*q_b_proj*", "*kv_a_proj*", "*kv_b_proj*", "*o_proj*"] + mla_nvfp4_linear_layers = ["*q_a_proj*", "*kv_a_proj*", "*q_b_proj*", "*o_proj*"] + else: + mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"] + mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"] for layer in mla_linear_layers: if layer in mla_nvfp4_linear_layers: - # wq_a, wkv_a, wq_b, wo use NVFP4 quantization mtq_cfg["quant_cfg"][layer + "_quantizer"] = { "num_bits": (2, 1), "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, @@ -338,14 +470,15 @@ def calibrate_loop(model): mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False} mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False} - if not args.disable_wo_quant and "FP4" in quant_cfg: - mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"] - 
mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"] + if not disable_wo_quant and "FP4" in quant_cfg: + wo_pattern = "*o_proj*" if is_hf else "*wo*" + mtq_cfg["quant_cfg"][wo_pattern + "weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"] + mtq_cfg["quant_cfg"][wo_pattern + "input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"] ## ptq transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop) - if int(os.environ["LOCAL_RANK"]) == 0: + if int(os.environ.get("LOCAL_RANK", "0")) == 0: mtq.print_quant_summary(transformer) return model @@ -353,13 +486,15 @@ def calibrate_loop(model): def save_amax_and_quant_config(model, output_path: str, enable_fp8_kvcache: bool = True): """Saves the amax values of the model to the output path.""" - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) + is_distributed = dist.is_initialized() + world_size = int(os.getenv("WORLD_SIZE", "1")) if is_distributed else 1 + rank = int(os.getenv("RANK", "0")) if is_distributed else 0 if rank == 0 and not os.path.exists(output_path): os.mkdir(output_path) - dist.barrier() + if is_distributed: + dist.barrier() # save amax def state_dict_filter(state_dict): @@ -371,20 +506,16 @@ def state_dict_filter(state_dict): os.path.join(output_path, f"amax_dict_rank{rank}-mp{world_size}.pt"), ) - # if rank == 0: - # with open("expert_activation_counts.txt", "w") as f: - # for name, module in model.named_modules(): - # if isinstance(module, deekseep_model.MoE): - # counts = module.activated_expert_counts() - # f.writelines(f"{name}: {count}\n" for count in counts) - quant_config = get_quant_config(model) if enable_fp8_kvcache: quant_config["quantization"]["kv_cache_quant_algo"] = KV_CACHE_FP8 - all_quant_configs = [None] * dist.get_world_size() - dist.all_gather_object(all_quant_configs, quant_config) + if is_distributed: + all_quant_configs = [None] * dist.get_world_size() + dist.all_gather_object(all_quant_configs, 
quant_config) + else: + all_quant_configs = [quant_config] if rank == 0: exclude_modules = set() @@ -400,7 +531,8 @@ def state_dict_filter(state_dict): if exclude_modules: quant_config["quantization"]["exclude_modules"] = sorted(exclude_modules) # add the last layer to the exclude module as the mtp is not loaded in the quantized model - quant_config["quantization"]["exclude_modules"].append(f"layers.{len(model.layers)}*") + layers = model.layers if hasattr(model, "layers") else model.model.layers + quant_config["quantization"]["exclude_modules"].append(f"layers.{len(layers)}*") if quantized_layers: quant_config["quantization"]["quantized_layers"] = quantized_layers @@ -433,11 +565,32 @@ def state_dict_filter(state_dict): default=None, help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4", ) + parser.add_argument( + "--model_type", + type=str, + choices=["deepseek", "hf"], + default="deepseek", + help="Model type: 'deepseek' for DeepSeek FP8 ckpt, 'hf' for standard HF bf16 ckpt (e.g. 
GLM-5).", + ) args = parser.parse_args() - model = load_deepseek_model(args.config, args.model_path, args.batch_size) + + if args.model_type == "hf": + model = load_hf_model(args.model_path) + else: + model = load_deepseek_model(args.config, args.model_path, args.batch_size) + tokenizer = AutoTokenizer.from_pretrained( args.model_path, trust_remote_code=args.trust_remote_code ) - model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, args.mla_quant) + model = ptq( + model, + tokenizer, + args.quant_cfg, + args.batch_size, + args.calib_size, + args.mla_quant, + model_type=args.model_type, + disable_wo_quant=args.disable_wo_quant, + ) save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache) diff --git a/examples/deepseek/quantize_fp8_to_nvfp4.sh b/examples/deepseek/quantize_fp8_to_nvfp4.sh index ae24e2bfd..b6e78e1c9 100755 --- a/examples/deepseek/quantize_fp8_to_nvfp4.sh +++ b/examples/deepseek/quantize_fp8_to_nvfp4.sh @@ -17,7 +17,7 @@ set -e # Exit immediately if any command fails usage() { - echo "Usage: $0 --amax_path --fp4_output_path --fp8_hf_path [--world_size ]" + echo "Usage: $0 --amax_path --fp4_output_path --fp8_hf_path [--world_size ] [--model_type ]" exit 1 } @@ -26,6 +26,7 @@ AMAX_PATH="" FP4_PATH="" FP8_HF_PATH="" WORLD_SIZE=8 +MODEL_TYPE="deepseek" # Parse command-line arguments while [[ $# -gt 0 ]]; do @@ -47,6 +48,10 @@ while [[ $# -gt 0 ]]; do WORLD_SIZE="$2" shift 2 ;; + --model_type) + MODEL_TYPE="$2" + shift 2 + ;; *) echo "Unknown argument: $1" usage @@ -88,6 +93,7 @@ python quantize_to_nvfp4.py \ --amax_path "$AMAX_PATH" \ --fp4_path "$FP4_PATH" \ --fp8_hf_path "$FP8_HF_PATH" \ - --world_size "$WORLD_SIZE" + --world_size "$WORLD_SIZE" \ + --model_type "$MODEL_TYPE" echo "Quantization command completed successfully." 
diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py index a18cbbc16..0a51c34ff 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -44,8 +44,12 @@ from typing import Any import torch -from ds_kernel import weight_dequant from safetensors.torch import load_file, save_file + +try: + from ds_kernel import weight_dequant +except ImportError: + weight_dequant = None from tqdm import tqdm from modelopt.torch.quantization.qtensor import NVFP4QTensor @@ -90,13 +94,15 @@ def remove_quantization_config_from_original_config(export_dir: str) -> None: config_path = os.path.join(export_dir, "config.json") with open(config_path) as f: cfg = json.load(f) + if "quantization_config" not in cfg: + return del cfg["quantization_config"] with open(config_path, "w") as f: json.dump(cfg, f, indent=2, sort_keys=True) f.write("\n") -def load_and_preprocess_state_dict(modelopt_state_root, world_size=8): +def load_and_preprocess_state_dict(modelopt_state_root, world_size=8, skip_remap=False): state_dict_list = [ torch.load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt") for rank in range(world_size) @@ -110,7 +116,8 @@ def load_and_preprocess_state_dict(modelopt_state_root, world_size=8): amax = torch.max(amax, merged_state_dict[key].to(amax.device)) merged_state_dict[key] = amax - _remap_key(merged_state_dict) + if not skip_remap: + _remap_key(merged_state_dict) # set amax for modules to be fused and make sure they share the same input for key, amax in merged_state_dict.items(): @@ -128,18 +135,20 @@ def load_and_preprocess_state_dict(modelopt_state_root, world_size=8): return merged_state_dict -def process_quant_config(quant_config_path: str, save_path: str) -> dict[str, Any]: +def process_quant_config(quant_config_path: str, save_path: str, skip_remap: bool = False) -> dict[str, Any]: with open(quant_config_path) as f: quant_config = json.load(f) if "exclude_modules" in 
quant_config["quantization"]: exclude_dict = dict.fromkeys(quant_config["quantization"]["exclude_modules"]) - _remap_key(exclude_dict) + if not skip_remap: + _remap_key(exclude_dict) quant_config["quantization"]["exclude_modules"] = list(exclude_dict.keys()) per_layer_quant_config = {} if "quantized_layers" in quant_config["quantization"]: - _remap_key(quant_config["quantization"]["quantized_layers"]) + if not skip_remap: + _remap_key(quant_config["quantization"]["quantized_layers"]) per_layer_quant_config = quant_config["quantization"]["quantized_layers"] with open(save_path, "w") as f: @@ -207,6 +216,10 @@ def get_tensor(tensor_name): # Get scale_inv from the correct file scale_inv = get_tensor(scale_inv_name) fp8_weight_names.append(key) + assert weight_dequant is not None, ( + "ds_kernel.weight_dequant is required for FP8 checkpoint conversion. " + "Install ds_kernel or use --model_type hf for bf16 checkpoints." + ) bf16_state_dict[key] = weight_dequant(item, scale_inv) except KeyError: print(f"Warning: Missing scale_inv tensor for {key}, skipping conversion") @@ -299,16 +312,27 @@ def get_tensor(tensor_name): ) parser.add_argument("--fp8_hf_path", type=str, required=True, help="fp8 hf ckpt.") parser.add_argument("--world_size", type=int, required=True, help="world size used by ptq.") + parser.add_argument( + "--model_type", + type=str, + choices=["deepseek", "hf"], + default="deepseek", + help="Model type: 'deepseek' for DeepSeek FP8 ckpt, 'hf' for standard HF bf16 ckpt (e.g. 
GLM-5).", + ) args = parser.parse_args() + skip_remap = args.model_type == "hf" + per_layer_quant_config = process_quant_config( quant_config_path=os.path.join(args.amax_path, "hf_quant_config.json"), save_path=os.path.join(args.fp4_path, "hf_quant_config.json"), + skip_remap=skip_remap, ) renamed_state_dict = load_and_preprocess_state_dict( modelopt_state_root=args.amax_path, world_size=args.world_size, + skip_remap=skip_remap, ) convert_fp8_ckpt_to_nvfp4( renamed_state_dict, diff --git a/examples/deepseek/run_glm5_ptq.sh b/examples/deepseek/run_glm5_ptq.sh new file mode 100755 index 000000000..4035f6601 --- /dev/null +++ b/examples/deepseek/run_glm5_ptq.sh @@ -0,0 +1,27 @@ +#!/bin/bash +set -e + +JOBID=${1:?Usage: $0 } + +CONTAINER_IMAGE=$(readlink -f ~/fsw/containers/modelopt.sqsh) +CONTAINER_MOUNTS=$(readlink -f ~/fsw):/fsw + +srun --overlap --jobid=${JOBID} --nodes=1 --ntasks=1 \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${CONTAINER_MOUNTS}" \ + bash -c ' +set -e +pip install --no-deps -e /fsw/Model-Optimizer +pip install git+https://github.com/huggingface/transformers + +cd /fsw/Model-Optimizer/examples/deepseek + +python ptq.py \ + --model_path /fsw/models/glm-5-bf16 \ + --model_type hf \ + --quant_cfg NVFP4_DEFAULT_CFG \ + --output_path /fsw/models/glm-5-nvfp4-amax \ + --trust_remote_code \ + --batch_size 8 \ + --calib_size 512 +' From 418ee37f8c26eefb9d03c29957d49f1183dac8c0 Mon Sep 17 00:00:00 2001 From: William Arnold Date: Wed, 11 Feb 2026 22:30:13 -0800 Subject: [PATCH 2/3] Add NVFP4 quantization pipeline for GLM-5 via DeepSeek V3.2 code path Fix _remap_key to use component-level matching, add kernel.py stubs, MTP head extraction script, and GLM-5 documentation. 
--- examples/deepseek/kernel.py | 73 +++++++ examples/deepseek/ptq.py | 218 +++------------------ examples/deepseek/quantize_fp8_to_nvfp4.sh | 10 +- examples/deepseek/quantize_to_nvfp4.py | 63 +++--- examples/deepseek/run_glm5_ptq.sh | 34 +++- examples/glm5/README.md | 168 ++++++++++++++++ examples/glm5/extract_mtp_head.py | 84 ++++++++ 7 files changed, 416 insertions(+), 234 deletions(-) create mode 100644 examples/deepseek/kernel.py create mode 100644 examples/glm5/README.md create mode 100644 examples/glm5/extract_mtp_head.py diff --git a/examples/deepseek/kernel.py b/examples/deepseek/kernel.py new file mode 100644 index 000000000..0f6c3024a --- /dev/null +++ b/examples/deepseek/kernel.py @@ -0,0 +1,73 @@ +"""Pure PyTorch kernel implementations for PTQ calibration. + +Replaces the tilelang-based kernel.py which requires CUDA 12's libnvrtc. +These implementations are numerically equivalent but slower — suitable for +PTQ calibration but not production inference. +""" + +import torch + +block_size = 128 + +FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: str = None): + """Block-wise FP8 quantization of activations. + + Returns (fp8_tensor, scale_tensor) with the same shapes as the tilelang version. 
+ """ + orig_shape = x.shape + # Flatten to 2D: (num_elements / block_size, block_size) + x_flat = x.reshape(-1, block_size) + # Compute per-block scale + amax = x_flat.float().abs().amax(dim=-1) + scale = amax / FP8_MAX + scale = scale.clamp(min=1e-12) + # Quantize + x_scaled = x_flat.float() / scale.unsqueeze(-1) + x_fp8 = x_scaled.clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.reshape(orig_shape) + # Scale shape: match what the caller expects + scale = scale.reshape(*orig_shape[:-1], orig_shape[-1] // block_size) + return x_fp8, scale + + +def fp8_gemm(a: torch.Tensor, a_scale: torch.Tensor, b: torch.Tensor, b_scale: torch.Tensor): + """FP8 matrix multiply with block-wise dequantization.""" + # Dequantize and do regular matmul + a_f = a.float() * a_scale.unsqueeze(-1).repeat_interleave(block_size, dim=-1)[..., :a.shape[-1]] + b_f = b.float() * b_scale.unsqueeze(-1).repeat_interleave(block_size, dim=-1)[..., :b.shape[-1]] + return torch.matmul(a_f, b_f.t()).to(torch.bfloat16) + + +def fp8_index(q_fp8: torch.Tensor, weights: torch.Tensor, k_cache: torch.Tensor, k_scale_cache: torch.Tensor): + """Compute sparse attention index scores using FP8 Q and cached K. 
+ + Args: + q_fp8: (bsz, seqlen, n_heads, head_dim) float8_e4m3fn + weights: (bsz, seqlen, n_heads, 1) weighting factors + k_cache: (bsz, cache_len, head_dim) float8_e4m3fn + k_scale_cache: (bsz, cache_len, head_dim // block_size) float32 + + Returns: + index_score: (bsz, seqlen, cache_len) attention-like scores + """ + bsz, seqlen, n_heads, head_dim = q_fp8.shape + cache_len = k_cache.shape[1] + n_blocks = head_dim // block_size + + # Dequant K cache: (bsz, cache_len, head_dim) + k_f = k_cache.float().reshape(bsz, cache_len, n_blocks, block_size) + k_f = k_f * k_scale_cache.unsqueeze(-1) + k_f = k_f.reshape(bsz, cache_len, head_dim) + + # Dequant Q: we don't have q_scale here, just use the fp8 values directly + q_f = q_fp8.float() # (bsz, seqlen, n_heads, head_dim) + + # weights: (bsz, seqlen, n_heads, 1) — absorb into q + q_weighted = (q_f * weights).sum(dim=2) # (bsz, seqlen, head_dim) + + # Score: (bsz, seqlen, cache_len) + index_score = torch.bmm(q_weighted, k_f.transpose(1, 2)) + return index_score diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index c5386baf9..b5ec2f7c1 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -64,43 +64,24 @@ from modelopt.torch.utils.dataset_utils import get_dataset_dataloader from modelopt.torch.utils.distributed import ParallelState -deekseep_model = None -weight_dequant = None -act_quant = None -fp8_gemm = None - - -def _import_deepseek_deps(): - """Lazily import DeepSeek-specific dependencies (only needed for --model_type deepseek).""" - global deekseep_model, weight_dequant, act_quant, fp8_gemm - if deekseep_model is not None: - return - - DS_V3_PATH = Path(__file__).resolve().parent / "DeepSeek-V3/inference" - DS_V3_2_PATH = Path(__file__).resolve().parent / "DeepSeek-V3.2-Exp/inference" - - if DS_V3_2_PATH.exists(): - sys.path.append(str(DS_V3_2_PATH)) - elif DS_V3_PATH.exists(): - sys.path.append(str(DS_V3_PATH)) - else: - raise ValueError( - f"DeepSeek-V3 or DeepSeek-V3.2-Exp not 
found in {Path(__file__).resolve().parent}" - ) - - import model as _model # noqa: E402 - from ds_kernel import weight_dequant as _weight_dequant # noqa: E402 - from kernel import act_quant as _act_quant, fp8_gemm as _fp8_gemm # noqa: E402 +DS_V3_PATH = Path(__file__).resolve().parent / "DeepSeek-V3/inference" +DS_V3_2_PATH = Path(__file__).resolve().parent / "DeepSeek-V3.2-Exp/inference" + +if DS_V3_2_PATH.exists(): + sys.path.append(str(DS_V3_2_PATH)) +elif DS_V3_PATH.exists(): + sys.path.append(str(DS_V3_PATH)) +else: + raise ValueError( + f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in {Path(__file__).resolve().parent}" + ) - deekseep_model = _model - weight_dequant = _weight_dequant - act_quant = _act_quant - fp8_gemm = _fp8_gemm +import model as deekseep_model # noqa: E402 +from ds_kernel import weight_dequant # noqa: E402 +from kernel import act_quant, fp8_gemm # noqa: E402 def monkey_patch_deepseek_model(): - _import_deepseek_deps() - gemm_impl: Literal["bf16", "fp8"] = "bf16" block_size = 128 @@ -250,106 +231,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: mtq.register(original_cls=deekseep_model.MoE, quantized_cls=CalibMoe) -def _register_hf_moe_for_calibration(model): - """Discover and register the MoE module class from an HF model for calibration. - - This dynamically finds the MoE class (has `experts` + `gate` attributes) and registers it - with modelopt's _QuantSparseMoe so all experts see tokens during calibration. 
- """ - from modelopt.torch.quantization.nn import QuantModuleRegistry - from modelopt.torch.quantization.plugins.huggingface import _QuantSparseMoe - - moe_cls = None - for _name, module in model.named_modules(): - cls = type(module) - if hasattr(module, "experts") and hasattr(module, "gate") and cls not in QuantModuleRegistry: - moe_cls = cls - break - - if moe_cls is None: - print("Warning: No unregistered MoE module class found in model.") - return - - # Check if the MoE class uses standard attribute names (top_k, num_experts) - sample = next(m for m in model.modules() if isinstance(m, moe_cls)) - has_top_k = hasattr(sample, "top_k") - has_num_experts = hasattr(sample, "num_experts") - - if has_top_k and has_num_experts: - # Standard attribute names - register directly - QuantModuleRegistry.register({moe_cls: f"hf.{moe_cls.__name__}"})(_QuantSparseMoe) - else: - # Need property adapters for non-standard attribute names - # Discover the actual attribute names - top_k_attr = "top_k" if has_top_k else None - num_experts_attr = "num_experts" if has_num_experts else None - - if top_k_attr is None: - for attr in ["num_experts_per_tok", "top_k", "topk"]: - if hasattr(sample, attr): - top_k_attr = attr - break - if num_experts_attr is None: - for attr in ["num_experts", "n_routed_experts", "num_local_experts"]: - if hasattr(sample, attr): - num_experts_attr = attr - break - - if top_k_attr is None or num_experts_attr is None: - print( - f"Warning: Could not find top_k/num_experts attributes on {moe_cls.__name__}. " - f"Skipping MoE calibration registration." 
- ) - return - - # Create adapter subclass - _top_k_attr = top_k_attr - _num_experts_attr = num_experts_attr - - class _QuantAdaptedSparseMoe(_QuantSparseMoe): - @property - def top_k(self): - return getattr(self, _top_k_attr) - - @top_k.setter - def top_k(self, value): - setattr(self, _top_k_attr, value) - - @property - def num_experts(self): - return getattr(self, _num_experts_attr) - - QuantModuleRegistry.register({moe_cls: f"hf.{moe_cls.__name__}"})(_QuantAdaptedSparseMoe) - - print(f"Registered MoE class {moe_cls.__name__} for calibration.") - - -def load_hf_model(model_path: str): - """Load a HuggingFace model (e.g. GLM-5 bf16 checkpoint). - - Uses device_map="auto" to shard the model across all visible GPUs (single process). - """ - from transformers import AutoModelForCausalLM - - torch.set_default_dtype(torch.bfloat16) - - model = AutoModelForCausalLM.from_pretrained( - model_path, - device_map="auto", - trust_remote_code=True, - torch_dtype=torch.bfloat16, - ) - model.eval() - - _register_hf_moe_for_calibration(model) - - return model - - def load_deepseek_model(model_config: str, model_path: str, batch_size: int): """Loads the deepseek model to memory.""" - _import_deepseek_deps() - # get distributed info world_size = int(os.getenv("WORLD_SIZE", "1")) rank = int(os.getenv("RANK", "0")) @@ -397,11 +280,9 @@ def ptq( batch_size: int, calib_size: int, mla_quant: str | None = None, - model_type: str = "deepseek", disable_wo_quant: bool = False, ): - """Runs Deepseek/HF model PTQ and returns the quantized model.""" - is_hf = model_type != "deepseek" + """Runs Deepseek model PTQ and returns the quantized model.""" # quantize the model ## create dataset @@ -423,8 +304,7 @@ def calibrate_loop(model): transformer = model.model if hasattr(model, "model") else model # make sure all processes are ready before starting the calibration - if dist.is_initialized(): - dist.barrier() + dist.barrier() ## quant config mtq_cfg = getattr(mtq, quant_cfg) @@ -432,14 +312,6 @@ 
def calibrate_loop(model): # disable head that corresponds to lm_head (for the huggingface checkpoint) mtq_cfg["quant_cfg"]["*head*"] = {"enable": False} - if is_hf: - # Disable GLM-5 / HF-specific layers that should not be quantized - mtq_cfg["quant_cfg"]["*indexer*"] = {"enable": False} - mtq_cfg["quant_cfg"]["*eh_proj*"] = {"enable": False} - mtq_cfg["quant_cfg"]["*enorm*"] = {"enable": False} - mtq_cfg["quant_cfg"]["*hnorm*"] = {"enable": False} - mtq_cfg["quant_cfg"]["*shared_head*"] = {"enable": False} - allowed_mla_quant = [None, "per_tensor_fp8", "nvfp4"] assert mla_quant in allowed_mla_quant, f"mla_quant must be {allowed_mla_quant}" @@ -449,14 +321,11 @@ def calibrate_loop(model): mtq_cfg["quant_cfg"]["*attn*weight_quantizer"] = {"num_bits": (4, 3), "axis": None} mtq_cfg["quant_cfg"]["*attn*input_quantizer"] = {"num_bits": (4, 3), "axis": None} elif mla_quant == "nvfp4": # for DeepSeek-R1-0528-NVFP4-Turbo - if is_hf: - mla_linear_layers = ["*q_a_proj*", "*q_b_proj*", "*kv_a_proj*", "*kv_b_proj*", "*o_proj*"] - mla_nvfp4_linear_layers = ["*q_a_proj*", "*kv_a_proj*", "*q_b_proj*", "*o_proj*"] - else: - mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"] - mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"] + mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"] + mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"] for layer in mla_linear_layers: if layer in mla_nvfp4_linear_layers: + # wq_a, wkv_a, wq_b, wo use NVFP4 quantization mtq_cfg["quant_cfg"][layer + "_quantizer"] = { "num_bits": (2, 1), "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, @@ -471,14 +340,13 @@ def calibrate_loop(model): mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False} if not disable_wo_quant and "FP4" in quant_cfg: - wo_pattern = "*o_proj*" if is_hf else "*wo*" - mtq_cfg["quant_cfg"][wo_pattern + "weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"] - 
mtq_cfg["quant_cfg"][wo_pattern + "input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"] + mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"] + mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"] ## ptq transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop) - if int(os.environ.get("LOCAL_RANK", "0")) == 0: + if int(os.environ["LOCAL_RANK"]) == 0: mtq.print_quant_summary(transformer) return model @@ -486,15 +354,13 @@ def calibrate_loop(model): def save_amax_and_quant_config(model, output_path: str, enable_fp8_kvcache: bool = True): """Saves the amax values of the model to the output path.""" - is_distributed = dist.is_initialized() - world_size = int(os.getenv("WORLD_SIZE", "1")) if is_distributed else 1 - rank = int(os.getenv("RANK", "0")) if is_distributed else 0 + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) if rank == 0 and not os.path.exists(output_path): os.mkdir(output_path) - if is_distributed: - dist.barrier() + dist.barrier() # save amax def state_dict_filter(state_dict): @@ -511,11 +377,8 @@ def state_dict_filter(state_dict): if enable_fp8_kvcache: quant_config["quantization"]["kv_cache_quant_algo"] = KV_CACHE_FP8 - if is_distributed: - all_quant_configs = [None] * dist.get_world_size() - dist.all_gather_object(all_quant_configs, quant_config) - else: - all_quant_configs = [quant_config] + all_quant_configs = [None] * dist.get_world_size() + dist.all_gather_object(all_quant_configs, quant_config) if rank == 0: exclude_modules = set() @@ -531,8 +394,7 @@ def state_dict_filter(state_dict): if exclude_modules: quant_config["quantization"]["exclude_modules"] = sorted(exclude_modules) # add the last layer to the exclude module as the mtp is not loaded in the quantized model - layers = model.layers if hasattr(model, "layers") else model.model.layers - quant_config["quantization"]["exclude_modules"].append(f"layers.{len(layers)}*") + 
quant_config["quantization"]["exclude_modules"].append(f"layers.{len(model.layers)}*") if quantized_layers: quant_config["quantization"]["quantized_layers"] = quantized_layers @@ -565,32 +427,14 @@ def state_dict_filter(state_dict): default=None, help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4", ) - parser.add_argument( - "--model_type", - type=str, - choices=["deepseek", "hf"], - default="deepseek", - help="Model type: 'deepseek' for DeepSeek FP8 ckpt, 'hf' for standard HF bf16 ckpt (e.g. GLM-5).", - ) args = parser.parse_args() - - if args.model_type == "hf": - model = load_hf_model(args.model_path) - else: - model = load_deepseek_model(args.config, args.model_path, args.batch_size) - + model = load_deepseek_model(args.config, args.model_path, args.batch_size) tokenizer = AutoTokenizer.from_pretrained( args.model_path, trust_remote_code=args.trust_remote_code ) model = ptq( - model, - tokenizer, - args.quant_cfg, - args.batch_size, - args.calib_size, - args.mla_quant, - model_type=args.model_type, - disable_wo_quant=args.disable_wo_quant, + model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, + args.mla_quant, disable_wo_quant=args.disable_wo_quant, ) save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache) diff --git a/examples/deepseek/quantize_fp8_to_nvfp4.sh b/examples/deepseek/quantize_fp8_to_nvfp4.sh index b6e78e1c9..ae24e2bfd 100755 --- a/examples/deepseek/quantize_fp8_to_nvfp4.sh +++ b/examples/deepseek/quantize_fp8_to_nvfp4.sh @@ -17,7 +17,7 @@ set -e # Exit immediately if any command fails usage() { - echo "Usage: $0 --amax_path --fp4_output_path --fp8_hf_path [--world_size ] [--model_type ]" + echo "Usage: $0 --amax_path --fp4_output_path --fp8_hf_path [--world_size ]" exit 1 } @@ -26,7 +26,6 @@ AMAX_PATH="" FP4_PATH="" FP8_HF_PATH="" WORLD_SIZE=8 -MODEL_TYPE="deepseek" # Parse command-line arguments while [[ $# -gt 0 ]]; do @@ -48,10 +47,6 @@ while [[ $# -gt 0 ]]; do WORLD_SIZE="$2" shift 2 ;; 
- --model_type) - MODEL_TYPE="$2" - shift 2 - ;; *) echo "Unknown argument: $1" usage @@ -93,7 +88,6 @@ python quantize_to_nvfp4.py \ --amax_path "$AMAX_PATH" \ --fp4_path "$FP4_PATH" \ --fp8_hf_path "$FP8_HF_PATH" \ - --world_size "$WORLD_SIZE" \ - --model_type "$MODEL_TYPE" + --world_size "$WORLD_SIZE" echo "Quantization command completed successfully." diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py index 0a51c34ff..73e1a81b6 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -45,20 +45,21 @@ import torch from safetensors.torch import load_file, save_file +from tqdm import tqdm try: from ds_kernel import weight_dequant except ImportError: weight_dequant = None -from tqdm import tqdm from modelopt.torch.quantization.qtensor import NVFP4QTensor def _remap_key(key_dict: dict[str, Any]): # renaming the module to match HF modeling - # The order matters here. - mappig = { + # Uses component-level replacement (split on ".") to avoid partial matches. + # Keys inside "indexer.*" are NOT remapped (they use the same names in HF). 
+ mapping = { "ffn": "mlp", "w1": "gate_proj", "w2": "down_proj", @@ -72,15 +73,25 @@ def _remap_key(key_dict: dict[str, Any]): "wo": "o_proj", "head": "lm_head", } + # These keys appear inside indexer.* and must NOT be remapped + indexer_passthrough = {"wq_b", "wk", "k_norm", "weights_proj"} new_dict = {} for k, v in key_dict.items(): - new_key = k.replace("layers", "model.layers") - - for original_pattern, replace_pattern in mappig.items(): - new_key = new_key.replace(original_pattern, replace_pattern) - - new_dict[new_key] = v + parts = k.split(".") + new_parts = [] + in_indexer = False + for part in parts: + if part == "indexer": + in_indexer = True + if part == "layers" and not new_parts: + new_parts.append("model") + new_parts.append("layers") + elif part in mapping and not (in_indexer and part in indexer_passthrough): + new_parts.append(mapping[part]) + else: + new_parts.append(part) + new_dict[".".join(new_parts)] = v key_dict.clear() key_dict.update(new_dict) @@ -102,7 +113,7 @@ def remove_quantization_config_from_original_config(export_dir: str) -> None: f.write("\n") -def load_and_preprocess_state_dict(modelopt_state_root, world_size=8, skip_remap=False): +def load_and_preprocess_state_dict(modelopt_state_root, world_size=8): state_dict_list = [ torch.load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt") for rank in range(world_size) @@ -116,8 +127,7 @@ def load_and_preprocess_state_dict(modelopt_state_root, world_size=8, skip_remap amax = torch.max(amax, merged_state_dict[key].to(amax.device)) merged_state_dict[key] = amax - if not skip_remap: - _remap_key(merged_state_dict) + _remap_key(merged_state_dict) # set amax for modules to be fused and make sure they share the same input for key, amax in merged_state_dict.items(): @@ -135,20 +145,18 @@ def load_and_preprocess_state_dict(modelopt_state_root, world_size=8, skip_remap return merged_state_dict -def process_quant_config(quant_config_path: str, save_path: str, skip_remap: bool = 
False) -> dict[str, Any]: +def process_quant_config(quant_config_path: str, save_path: str) -> dict[str, Any]: with open(quant_config_path) as f: quant_config = json.load(f) if "exclude_modules" in quant_config["quantization"]: exclude_dict = dict.fromkeys(quant_config["quantization"]["exclude_modules"]) - if not skip_remap: - _remap_key(exclude_dict) + _remap_key(exclude_dict) quant_config["quantization"]["exclude_modules"] = list(exclude_dict.keys()) per_layer_quant_config = {} if "quantized_layers" in quant_config["quantization"]: - if not skip_remap: - _remap_key(quant_config["quantization"]["quantized_layers"]) + _remap_key(quant_config["quantization"]["quantized_layers"]) per_layer_quant_config = quant_config["quantization"]["quantized_layers"] with open(save_path, "w") as f: @@ -211,15 +219,15 @@ def get_tensor(tensor_name): if key.endswith("_scale_inv"): continue elif item.element_size() == 1: # FP8 weight + assert weight_dequant is not None, ( + "ds_kernel is required to dequantize FP8 weights. " + "Install it or use a bf16 source checkpoint." + ) scale_inv_name = f"{key}_scale_inv" try: # Get scale_inv from the correct file scale_inv = get_tensor(scale_inv_name) fp8_weight_names.append(key) - assert weight_dequant is not None, ( - "ds_kernel.weight_dequant is required for FP8 checkpoint conversion. " - "Install ds_kernel or use --model_type hf for bf16 checkpoints." - ) bf16_state_dict[key] = weight_dequant(item, scale_inv) except KeyError: print(f"Warning: Missing scale_inv tensor for {key}, skipping conversion") @@ -310,29 +318,20 @@ def get_tensor(tensor_name): parser.add_argument( "--fp4_path", type=str, required=True, help="path to save the fp4 checkpoint." 
) - parser.add_argument("--fp8_hf_path", type=str, required=True, help="fp8 hf ckpt.") + parser.add_argument("--fp8_hf_path", "--hf_path", type=str, required=True, help="Source HF checkpoint (fp8 or bf16).") parser.add_argument("--world_size", type=int, required=True, help="world size used by ptq.") - parser.add_argument( - "--model_type", - type=str, - choices=["deepseek", "hf"], - default="deepseek", - help="Model type: 'deepseek' for DeepSeek FP8 ckpt, 'hf' for standard HF bf16 ckpt (e.g. GLM-5).", - ) args = parser.parse_args() - skip_remap = args.model_type == "hf" + os.makedirs(args.fp4_path, exist_ok=True) per_layer_quant_config = process_quant_config( quant_config_path=os.path.join(args.amax_path, "hf_quant_config.json"), save_path=os.path.join(args.fp4_path, "hf_quant_config.json"), - skip_remap=skip_remap, ) renamed_state_dict = load_and_preprocess_state_dict( modelopt_state_root=args.amax_path, world_size=args.world_size, - skip_remap=skip_remap, ) convert_fp8_ckpt_to_nvfp4( renamed_state_dict, diff --git a/examples/deepseek/run_glm5_ptq.sh b/examples/deepseek/run_glm5_ptq.sh index 4035f6601..beaec3bda 100755 --- a/examples/deepseek/run_glm5_ptq.sh +++ b/examples/deepseek/run_glm5_ptq.sh @@ -1,26 +1,46 @@ #!/bin/bash set -e -JOBID=${1:?Usage: $0 } +JOBID=${1:?Usage: $0 [--skip-convert]} +SKIP_CONVERT=false +[[ "${2}" == "--skip-convert" ]] && SKIP_CONVERT=true -CONTAINER_IMAGE=$(readlink -f ~/fsw/containers/modelopt.sqsh) +CONTAINER_IMAGE=$(readlink -f ~/fsw/containers/modelopt-v2.sqsh) CONTAINER_MOUNTS=$(readlink -f ~/fsw):/fsw +HF_CKPT=/fsw/models/glm-5-bf16 +DS_CKPT=/fsw/models/glm-5-ds +AMAX_PATH=/fsw/models/glm-5-nvfp4-amax +DS_V3_2_DIR=/fsw/Model-Optimizer/examples/deepseek/DeepSeek-V3.2-Exp +GLM5_CONFIG=${DS_V3_2_DIR}/inference/config_glm5.json + srun --overlap --jobid=${JOBID} --nodes=1 --ntasks=1 \ --container-image="${CONTAINER_IMAGE}" \ --container-mounts="${CONTAINER_MOUNTS}" \ + --export="ALL,HF_TOKEN=${HF_TOKEN:?Set HF_TOKEN env var}" \ 
bash -c ' set -e pip install --no-deps -e /fsw/Model-Optimizer -pip install git+https://github.com/huggingface/transformers cd /fsw/Model-Optimizer/examples/deepseek -python ptq.py \ - --model_path /fsw/models/glm-5-bf16 \ - --model_type hf \ +# Step 1: Convert HF bf16 checkpoint to DeepSeek sharded format +if [ "'"${SKIP_CONVERT}"'" != "true" ]; then + python '"${DS_V3_2_DIR}"'/inference/convert.py \ + --hf-ckpt-path '"${HF_CKPT}"' \ + --save-path '"${DS_CKPT}"' \ + --n-experts 256 \ + --model-parallel 8 +else + echo "Skipping conversion (--skip-convert)" +fi + +# Step 2: Run PTQ calibration +torchrun --nproc-per-node 8 --master_port=12346 ptq.py \ + --model_path '"${DS_CKPT}"' \ + --config '"${GLM5_CONFIG}"' \ --quant_cfg NVFP4_DEFAULT_CFG \ - --output_path /fsw/models/glm-5-nvfp4-amax \ + --output_path '"${AMAX_PATH}"' \ --trust_remote_code \ --batch_size 8 \ --calib_size 512 diff --git a/examples/glm5/README.md b/examples/glm5/README.md new file mode 100644 index 000000000..394dbc664 --- /dev/null +++ b/examples/glm5/README.md @@ -0,0 +1,168 @@ +# GLM-5 NVFP4 Quantization + +This guide describes how to quantize the GLM-5 model (bf16) to NVFP4 using the +DeepSeek V3.2 PTQ pipeline. GLM-5 shares the same MoE + MLA architecture as +DeepSeek V3 (256 routed experts, MLA attention with DSA indexer), so we reuse +the DeepSeek inference code with a GLM-5-specific config. + +## Prerequisites + +- SLURM cluster with 8 GPUs (H100 80GB recommended) +- Container image with Model-Optimizer, PyTorch, and `fast_hadamard_transform` + pre-installed (e.g. `modelopt-v2.sqsh`) +- HuggingFace bf16 checkpoint of GLM-5 +- (Optional) FP8 checkpoint of GLM-5 for the MTP head + +## Overview + +The pipeline has four steps: + +1. **Convert** the HF bf16 checkpoint to DeepSeek's sharded format (8-way TP) +2. **PTQ calibration** to compute per-layer amax statistics +3. **Quantize** bf16 weights to NVFP4 using the amax values +4. 
**(Optional) Extract MTP head** from FP8 checkpoint and add to the output + +## Step 1: Convert HF checkpoint to DeepSeek format + +The DeepSeek V3.2 model code uses a different weight naming convention and +shards weights across tensor-parallel ranks. This step converts and shards the +HF checkpoint using parallel subprocesses (one per rank) with `safetensors` +`get_slice` for memory-efficient reading. + +```bash +srun --overlap --jobid=${JOBID} --nodes=1 --ntasks=1 \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${CONTAINER_MOUNTS}" \ + bash -c ' +pip install --no-deps -e /path/to/Model-Optimizer +cd /path/to/Model-Optimizer/examples/deepseek + +python DeepSeek-V3.2-Exp/inference/convert.py \ + --hf-ckpt-path /path/to/glm-5-bf16 \ + --save-path /path/to/glm-5-ds \ + --n-experts 256 \ + --model-parallel 8 +' +``` + +The conversion: +- Auto-detects the MTP layer (layer 78) via `config.json` and skips it +- Patches `tokenizer_config.json` to remove incompatible fields + (`tokenizer_class`, `is_local`, `extra_special_tokens`) +- Produces 8 shards of ~177GB each + +## Step 2: PTQ calibration + +Run PTQ calibration across 8 GPUs using `torchrun`. This inserts quantizers +into the model and calibrates amax values on sample data. 
+ +```bash +srun --overlap --jobid=${JOBID} --nodes=1 --ntasks=1 \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${CONTAINER_MOUNTS}" \ + --export="ALL,HF_TOKEN=${HF_TOKEN}" \ + bash -c ' +pip install --no-deps -e /path/to/Model-Optimizer +cd /path/to/Model-Optimizer/examples/deepseek + +torchrun --nproc-per-node 8 --master_port=12346 ptq.py \ + --model_path /path/to/glm-5-ds \ + --config DeepSeek-V3.2-Exp/inference/config_glm5.json \ + --quant_cfg NVFP4_DEFAULT_CFG \ + --output_path /path/to/glm-5-nvfp4-amax \ + --trust_remote_code \ + --batch_size 8 \ + --calib_size 512 +' +``` + +Notes: +- An `HF_TOKEN` is required because the calibration dataset + (`nvidia/Nemotron-Post-Training-Dataset-v2`) is gated +- The `kernel.py` stub in `examples/deepseek/` provides pure PyTorch + implementations of `act_quant`, `fp8_gemm`, and `fp8_index` to replace the + tilelang-based kernels (which require CUDA 12's `libnvrtc`) +- Output: 8 amax files + `hf_quant_config.json` + +## Step 3: Quantize to NVFP4 + +Apply the calibrated amax values to quantize weights from the original HF +checkpoint to NVFP4 format. + +```bash +# First copy config/tokenizer files to the output directory +mkdir -p /path/to/glm-5-nvfp4 +cp /path/to/glm-5-bf16/*.json /path/to/glm-5-nvfp4/ +cp /path/to/glm-5-bf16/*token* /path/to/glm-5-nvfp4/ + +# Run quantization (requires 1 GPU) +python examples/deepseek/quantize_to_nvfp4.py \ + --amax_path /path/to/glm-5-nvfp4-amax \ + --hf_path /path/to/glm-5-bf16 \ + --fp4_path /path/to/glm-5-nvfp4 \ + --world_size 8 +``` + +This iterates through all 282 safetensors shards, quantizing weights that have +amax entries to NVFP4 and passing through non-quantized weights (norms, embed, +gate) as bf16. Takes ~35 minutes on a single GPU. + +## Step 4 (Optional): Add MTP head + +The MTP (Multi-Token Prediction) head at layer 78 is excluded from +quantization. 
If you have an FP8 checkpoint with the MTP head, extract it: + +```bash +python examples/glm5/extract_mtp_head.py \ + --fp8_index /path/to/glm-5-fp8/model.safetensors.index.json \ + --fp8_dir /path/to/glm-5-fp8 \ + --nvfp4_dir /path/to/glm-5-nvfp4 \ + --mtp_layer 78 +``` + +This extracts all layer-78 tensors into `mtp-fp8.safetensors` and updates +`model.safetensors.index.json` to include them. + +## Convenience script + +`examples/deepseek/run_glm5_ptq.sh` combines steps 1 and 2 into a single SLURM +job. Usage: + +```bash +# Allocate a SLURM job (3-4 hours recommended) +salloc --nodes=1 --time=3:59:00 --account= --partition=batch --no-shell + +# Run conversion + PTQ +bash examples/deepseek/run_glm5_ptq.sh + +# Skip conversion if already done +bash examples/deepseek/run_glm5_ptq.sh --skip-convert +``` + +## Key modifications to DeepSeek V3.2 code + +The following changes were made to support GLM-5: + +- **`convert.py`**: Parallel subprocess conversion with `get_slice`, MTP layer + auto-detection, tokenizer config patching +- **`model.py`**: Added `gate_bias` config flag to `ModelArgs` (GLM-5 has gate + bias, DSv3 gate bias was hardcoded to `dim == 7168`) +- **`config_glm5.json`**: GLM-5 model config in DeepSeek V3.2 format +- **`kernel.py`**: Pure PyTorch stubs for `act_quant`, `fp8_gemm`, `fp8_index` + (replaces tilelang kernels that need CUDA 12) +- **`quantize_to_nvfp4.py`**: Conditional `ds_kernel` import, component-level + `_remap_key` to avoid corrupting indexer keys, bf16 source support + +## GLM-5 model config + +| Parameter | Value | +|---|---| +| Hidden dim | 6144 | +| Layers | 78 + 1 MTP | +| Attention heads | 64 | +| Routed experts | 256 (8 activated) | +| Shared experts | 1 | +| Dense layers | 3 | +| Q LoRA rank | 2048 | +| KV LoRA rank | 512 | +| Vocab size | 154880 | diff --git a/examples/glm5/extract_mtp_head.py b/examples/glm5/extract_mtp_head.py new file mode 100644 index 000000000..62ef0a343 --- /dev/null +++ b/examples/glm5/extract_mtp_head.py 
@@ -0,0 +1,84 @@ +"""Extract MTP (Multi-Token Prediction) head weights from an FP8 checkpoint. + +The GLM-5 model has 78 transformer layers (0-77) plus one MTP layer (layer 78) +used for speculative decoding. During NVFP4 quantization, only layers 0-77 are +quantized. This script extracts the MTP head from a separate FP8 checkpoint and +adds it to the NVFP4 output so the final checkpoint is complete. + +Usage: + python extract_mtp_head.py \ + --fp8_index /path/to/glm-5-fp8/model.safetensors.index.json \ + --fp8_dir /path/to/glm-5-fp8 \ + --nvfp4_dir /path/to/glm-5-nvfp4 \ + --mtp_layer 78 +""" + +import argparse +import json + +from safetensors.torch import load_file, save_file + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--fp8_index", type=str, required=True, + help="Path to the FP8 model.safetensors.index.json") + parser.add_argument("--fp8_dir", type=str, required=True, + help="Directory containing FP8 safetensors files") + parser.add_argument("--nvfp4_dir", type=str, required=True, + help="NVFP4 output directory to add MTP head to") + parser.add_argument("--mtp_layer", type=int, default=78, + help="Layer index of the MTP head (default: 78)") + parser.add_argument("--output_name", type=str, default="mtp-fp8.safetensors", + help="Filename for the MTP safetensors file") + args = parser.parse_args() + + mtp_prefix = f"model.layers.{args.mtp_layer}" + + # Find MTP keys in the FP8 index + with open(args.fp8_index) as f: + idx = json.load(f) + + mtp_keys_by_file = {} + for key, filename in idx["weight_map"].items(): + if key.startswith(mtp_prefix): + mtp_keys_by_file.setdefault(filename, []).append(key) + + if not mtp_keys_by_file: + print(f"No MTP keys found with prefix '{mtp_prefix}'") + return + + total_keys = sum(len(v) for v in mtp_keys_by_file.values()) + print(f"Found {total_keys} MTP keys across {len(mtp_keys_by_file)} files") + + # Extract MTP tensors + mtp_tensors = {} + for filename, keys in 
sorted(mtp_keys_by_file.items()):
+        filepath = f"{args.fp8_dir}/{filename}"
+        print(f"  Loading {len(keys)} keys from {filename}...")
+        data = load_file(filepath, device="cpu")
+        for k in keys:
+            mtp_tensors[k] = data[k]
+        del data
+
+    # Save as single file
+    out_path = f"{args.nvfp4_dir}/{args.output_name}"
+    save_file(mtp_tensors, out_path)
+    print(f"Saved {len(mtp_tensors)} MTP tensors to {out_path}")
+
+    # Update the NVFP4 index
+    nvfp4_index_path = f"{args.nvfp4_dir}/model.safetensors.index.json"
+    with open(nvfp4_index_path) as f:
+        nvfp4_idx = json.load(f)
+
+    for k in mtp_tensors:
+        nvfp4_idx["weight_map"][k] = args.output_name
+
+    with open(nvfp4_index_path, "w") as f:
+        json.dump(nvfp4_idx, f, indent=2)
+
+    print(f"Updated {nvfp4_index_path} with MTP keys -> {args.output_name}")
+
+
+if __name__ == "__main__":
+    main()

From 19ed1e4b93fc27b8c1eb602b23e4744608e579b3 Mon Sep 17 00:00:00 2001
From: William Arnold
Date: Thu, 26 Feb 2026 05:40:28 -0800
Subject: [PATCH 3/3] Add GLM-5 quantization fixes and transformers 5.x
 compatibility

- Fix TokenizersBackend compatibility in dataset_utils.py: transformers 5.x
  TokenizersBackend lacks batch_encode_plus, added fallback using _encode_plus
  per-sample with manual padding support (left/right)
- Fix quantize_to_nvfp4.py: skip quantization for layers not listed in
  per_layer_quant_config when using MIXED_PRECISION mode
- Add GLM MLA and DSA Indexer exclusions to hf_ptq build_quant_cfg
- Improve run_glm5_ptq.sh CLI: add --amax-path and --mla-quant flags
- Add glm5 dequant_nvfp4.py utility
---
 examples/deepseek/quantize_to_nvfp4.py |   6 +
 examples/deepseek/run_glm5_ptq.sh      |  21 +-
 examples/glm5/dequant_nvfp4.py         | 288 +++++++++++++++++++++++++
 examples/llm_ptq/example_utils.py      |   6 +-
 modelopt/torch/utils/dataset_utils.py  |  42 +++-
 5 files changed, 351 insertions(+), 12 deletions(-)
 create mode 100644 examples/glm5/dequant_nvfp4.py

diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py 
index 73e1a81b6..42772c26f 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -244,6 +244,12 @@ def get_tensor(tensor_name): input_scale_key = layer_name + ".input_quantizer._amax" if amax_key in renamed_state_dict: + # If per_layer_quant_config is non-empty (MIXED_PRECISION) but this layer is + # not listed, it was excluded from quantization (e.g. lm_head). Skip quant. + if per_layer_quant_config and layer_name not in per_layer_quant_config: + new_dict[key] = item + continue + # default quant is NVFP4 is_nvfp4 = ( not per_layer_quant_config diff --git a/examples/deepseek/run_glm5_ptq.sh b/examples/deepseek/run_glm5_ptq.sh index beaec3bda..0b3e060dd 100755 --- a/examples/deepseek/run_glm5_ptq.sh +++ b/examples/deepseek/run_glm5_ptq.sh @@ -1,16 +1,28 @@ #!/bin/bash set -e -JOBID=${1:?Usage: $0 [--skip-convert]} +JOBID=${1:?Usage: $0 [--skip-convert] [--amax-path ] [--mla-quant ]} SKIP_CONVERT=false -[[ "${2}" == "--skip-convert" ]] && SKIP_CONVERT=true +AMAX_PATH=/fsw/models/glm-5-nvfp4-amax +MLA_QUANT="" + +# Parse optional flags +shift +while [[ $# -gt 0 ]]; do + case "$1" in + --skip-convert) SKIP_CONVERT=true ;; + --amax-path) AMAX_PATH="$2"; shift ;; + --mla-quant) MLA_QUANT="$2"; shift ;; + *) echo "Unknown argument: $1"; exit 1 ;; + esac + shift +done CONTAINER_IMAGE=$(readlink -f ~/fsw/containers/modelopt-v2.sqsh) CONTAINER_MOUNTS=$(readlink -f ~/fsw):/fsw HF_CKPT=/fsw/models/glm-5-bf16 DS_CKPT=/fsw/models/glm-5-ds -AMAX_PATH=/fsw/models/glm-5-nvfp4-amax DS_V3_2_DIR=/fsw/Model-Optimizer/examples/deepseek/DeepSeek-V3.2-Exp GLM5_CONFIG=${DS_V3_2_DIR}/inference/config_glm5.json @@ -43,5 +55,6 @@ torchrun --nproc-per-node 8 --master_port=12346 ptq.py \ --output_path '"${AMAX_PATH}"' \ --trust_remote_code \ --batch_size 8 \ - --calib_size 512 + --calib_size 512 \ + '"${MLA_QUANT:+--mla_quant $MLA_QUANT}"' ' diff --git a/examples/glm5/dequant_nvfp4.py b/examples/glm5/dequant_nvfp4.py new file mode 100644 index 
000000000..8a99ba77a --- /dev/null +++ b/examples/glm5/dequant_nvfp4.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +"""Dequantize NVFP4/FP8 checkpoint back to bf16 for HuggingFace inference testing. + +Usage: + python dequant_nvfp4.py dequant --nvfp4_dir /path/to/nvfp4 --output_dir /path/to/output + python dequant_nvfp4.py generate --model_dir /path/to/dequanted +""" + +import argparse +import json +import math +import os +import shutil +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path + +import torch +from safetensors.torch import load_file, save_file + +# e2m1 lookup table for NVFP4: indices 0-7 positive, 8-15 negative +E2M1_LUT = torch.tensor( + [0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6], + dtype=torch.float32, +) + +NVFP4_BLOCK_SIZE = 16 # 16 FP4 values per scale +FP8_BLOCK_SIZE = 128 + + +def dequant_nvfp4(weight_uint8, weight_scale, weight_scale_2): + """Dequantize NVFP4 packed uint8 weight to bf16. 
+ + Args: + weight_uint8: [out, packed_in] uint8 tensor (2 FP4 values per byte) + weight_scale: [out, packed_in // 8] float8_e4m3fn per-block scale + weight_scale_2: scalar float32 double-scale + Returns: + [out, packed_in * 2] bf16 tensor + """ + out_dim, packed_in = weight_uint8.shape + unpacked_in = packed_in * 2 + + # Unpack uint8 to two 4-bit indices + unpacked = torch.empty(out_dim, unpacked_in, dtype=torch.long) + unpacked[:, 0::2] = (weight_uint8 & 0x0F).long() # low nibble + unpacked[:, 1::2] = (weight_uint8 >> 4).long() # high nibble + + # Lookup e2m1 values + fp_values = E2M1_LUT[unpacked] # [out, unpacked_in] f32 + + # Compute per-block scales: scale * scale_2 + per_block_scale = weight_scale.float() * weight_scale_2.float() # [out, num_blocks] + + # Reshape to [out, num_blocks, block_size] and apply scale + num_blocks = per_block_scale.shape[1] + fp_values = fp_values.view(out_dim, num_blocks, NVFP4_BLOCK_SIZE) + result = fp_values * per_block_scale.unsqueeze(-1) + + return result.reshape(out_dim, unpacked_in).to(torch.bfloat16) + + +def dequant_fp8(weight_fp8, scale_inv): + """Dequantize FP8 weight with block-128 scale_inv to bf16. 
+
+    Args:
+        weight_fp8: [M, N] float8_e4m3fn tensor
+        scale_inv: [ceil(M/128), ceil(N/128)] float32 tensor
+    Returns:
+        [M, N] bf16 tensor
+    """
+    M, N = weight_fp8.shape
+    scale_M, scale_N = scale_inv.shape
+    bs = FP8_BLOCK_SIZE
+
+    # Convert to float for computation
+    w = weight_fp8.float()
+
+    # Pad if dimensions not divisible by block size
+    padded_M = scale_M * bs
+    padded_N = scale_N * bs
+    if padded_M != M or padded_N != N:
+        w_padded = torch.zeros(padded_M, padded_N, dtype=torch.float32)
+        w_padded[:M, :N] = w
+        w = w_padded
+
+    # Reshape to blocks and apply scale
+    w = w.view(scale_M, bs, scale_N, bs)
+    w = w * scale_inv[:, None, :, None]
+    w = w.reshape(padded_M, padded_N)
+
+    # Trim padding
+    return w[:M, :N].to(torch.bfloat16)
+
+
+def is_scale_key(name):
+    """Check if tensor name is a scale/companion tensor that should be dropped."""
+    return (
+        name.endswith(".weight_scale")
+        or name.endswith(".weight_scale_2")
+        or name.endswith(".weight_scale_inv")
+        or name.endswith(".input_scale")
+    )
+
+
+def process_shard(src_path, dst_path):
+    """Process a single safetensors shard: dequantize quantized weights, pass through others."""
+    tensors = load_file(src_path, device="cpu")
+
+    output = {}
+    processed_nvfp4 = 0
+    processed_fp8 = 0
+    passthrough = 0
+
+    # First pass: identify weight keys and their companions
+    for name, tensor in tensors.items():
+        if is_scale_key(name):
+            continue  # skip scale tensors
+
+        if tensor.dtype == torch.uint8:
+            # NVFP4 quantized weight
+            scale_key = name + "_scale"
+            scale2_key = name + "_scale_2"
+            weight_scale = tensors[scale_key]
+            weight_scale_2 = tensors[scale2_key]
+            output[name] = dequant_nvfp4(tensor, weight_scale, weight_scale_2)
+            processed_nvfp4 += 1
+
+        elif tensor.dtype == torch.float8_e4m3fn:
+            # FP8 quantized weight
+            scale_inv_key = name + "_scale_inv"
+            scale_inv = tensors[scale_inv_key]
+            output[name] = dequant_fp8(tensor, scale_inv)
+            processed_fp8 += 1
+
+        else:
+            # bf16, f32, int — pass through as-is
+            output[name] = tensor
+            passthrough += 1
+
+    save_file(output, dst_path)
+    return src_path.name, processed_nvfp4, processed_fp8, passthrough
+
+
+def cmd_dequant(args):
+    nvfp4_dir = Path(args.nvfp4_dir)
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Find all safetensors files
+    st_files = sorted(nvfp4_dir.glob("*.safetensors"))
+    print(f"Found {len(st_files)} safetensors files", flush=True)
+
+    # Process shards in parallel
+    t0 = time.time()
+    futures = {}
+    with ProcessPoolExecutor(max_workers=args.workers) as pool:
+        for sf in st_files:
+            dst = output_dir / sf.name
+            fut = pool.submit(process_shard, sf, dst)
+            futures[fut] = sf.name
+
+        for i, fut in enumerate(as_completed(futures), 1):
+            name, nvfp4, fp8, pt = fut.result()
+            elapsed = time.time() - t0
+            print(
+                f"[{i}/{len(st_files)}] {name}: "
+                f"nvfp4={nvfp4}, fp8={fp8}, passthrough={pt} "
+                f"({elapsed:.0f}s elapsed)",
+                flush=True,
+            )
+
+    elapsed = time.time() - t0
+    print(f"\nAll shards processed in {elapsed:.0f}s")
+
+    # Build new index.json with scale keys removed
+    index_path = nvfp4_dir / "model.safetensors.index.json"
+    with open(index_path) as f:
+        index = json.load(f)
+
+    new_weight_map = {}
+    for key, filename in index["weight_map"].items():
+        if not is_scale_key(key):
+            new_weight_map[key] = filename
+
+    # Recompute total_size from output files
+    total_size = 0
+    for sf in output_dir.glob("*.safetensors"):
+        total_size += sf.stat().st_size
+
+    new_index = {
+        "metadata": {"total_size": total_size},
+        "weight_map": new_weight_map,
+    }
+    with open(output_dir / "model.safetensors.index.json", "w") as f:
+        json.dump(new_index, f, indent=2)
+    print(f"Wrote model.safetensors.index.json ({len(new_weight_map)} keys)")
+
+    # Copy config files
+    for pattern in ["config.json", "generation_config.json", "tokenizer*", "special_tokens*"]:
+        for src in nvfp4_dir.glob(pattern):
+            dst = output_dir / src.name
+            if not dst.exists():
+                shutil.copy2(src, dst)
+                print(f"Copied {src.name}")
+
+    # Patch tokenizer_config.json to remove fields that break HF loading
+    tok_cfg_path = output_dir / "tokenizer_config.json"
+    if tok_cfg_path.exists():
+        with open(tok_cfg_path) as f:
+            tok_cfg = json.load(f)
+        changed = False
+        for key in ["tokenizer_class", "is_local", "extra_special_tokens"]:
+            if key in tok_cfg:
+                del tok_cfg[key]
+                changed = True
+        if changed:
+            with open(tok_cfg_path, "w") as f:
+                json.dump(tok_cfg, f, indent=2)
+            print("Patched tokenizer_config.json")
+
+    print("\nDone! Output directory:", output_dir)
+
+
+def cmd_generate(args):
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    model_dir = args.model_dir
+    print(f"Loading model from {model_dir}...")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_dir,
+        device_map="auto",
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+    )
+    model.eval()
+
+    prompts = [
+        "What is the capital of France?",
+        "Write a short poem about the ocean.",
+        "Explain quantum computing in one sentence.",
+    ]
+
+    for prompt in prompts:
+        messages = [{"role": "user", "content": prompt}]
+        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(text, return_tensors="pt").to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=256,
+                do_sample=False,
+            )
+
+        response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+        print(f"\n{'='*60}")
+        print(f"Prompt: {prompt}")
+        print(f"Response: {response}")
+
+    print(f"\n{'='*60}")
+    print("Generation test complete.")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Dequantize NVFP4/FP8 checkpoint to bf16")
+    sub = parser.add_subparsers(dest="command", required=True)
+
+    p_dequant = sub.add_parser("dequant", help="Dequantize checkpoint")
+    p_dequant.add_argument("--nvfp4_dir", required=True, help="Path to NVFP4 checkpoint")
+    p_dequant.add_argument("--output_dir", required=True, help="Path to output bf16 checkpoint")
+    p_dequant.add_argument("--workers", type=int, default=32, help="Number of parallel workers")
+
+    p_gen = sub.add_parser("generate", help="Run generation test")
+    p_gen.add_argument("--model_dir", required=True, help="Path to dequanted model")
+
+    args = parser.parse_args()
+    if args.command == "dequant":
+        cmd_dequant(args)
+    elif args.command == "generate":
+        cmd_generate(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 93687a8d0..ebe0444ce 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -237,11 +237,15 @@ def build_quant_cfg(
         quant_cfg["quant_cfg"]["model*.*attn*k_proj*"] = {"enable": False}
         quant_cfg["quant_cfg"]["model*.*attn*v_proj*"] = {"enable": False}
 
-    if model_type == "deepseek":
+    if model_type in ("deepseek", "glm"):
         # Disable MLA quantization for accuracy.
         quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False}
 
+    if model_type == "glm":
+        # Disable DSA Indexer linear quantization for accuracy.
+        quant_cfg["quant_cfg"]["*self_attn.indexer*"] = {"enable": False}
+
     return quant_cfg
diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py
index 16bff49c2..f2df24d5f 100644
--- a/modelopt/torch/utils/dataset_utils.py
+++ b/modelopt/torch/utils/dataset_utils.py
@@ -247,13 +247,41 @@ def get_dataset_dataloader(
         samples = get_dataset_samples(ds_name, num_sample)
         all_samples.extend(samples)
 
-    batch_encoded = tokenizer.batch_encode_plus(
-        all_samples,
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-        max_length=max_sample_length,
-    )
+    if hasattr(tokenizer, "batch_encode_plus"):
+        batch_encoded = tokenizer.batch_encode_plus(
+            all_samples,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=max_sample_length,
+        )
+    else:
+        # Fallback for tokenizers (e.g. TokenizersBackend in transformers 5.x) that only
+        # implement _encode_plus and lack _batch_encode_plus / batch_encode_plus.
+        from torch.nn.utils.rnn import pad_sequence
+
+        from transformers import BatchEncoding
+
+        pad_id = tokenizer.pad_token_id or 0
+        left_pad = getattr(tokenizer, "padding_side", "right") == "left"
+        encoded_list = [
+            tokenizer._encode_plus(s, return_tensors="pt", truncation=True, max_length=max_sample_length)
+            for s in all_samples
+        ]
+        input_ids = [e["input_ids"].squeeze(0) for e in encoded_list]
+        attention_mask = [e["attention_mask"].squeeze(0) for e in encoded_list]
+        if left_pad:
+            # pad_sequence pads on the right; flip, pad, flip back for left padding
+            input_ids = pad_sequence(
+                [t.flip(0) for t in input_ids], batch_first=True, padding_value=pad_id
+            ).flip(1)
+            attention_mask = pad_sequence(
+                [t.flip(0) for t in attention_mask], batch_first=True, padding_value=0
+            ).flip(1)
+        else:
+            input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
+            attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
+        batch_encoded = BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})
 
     if device:
         batch_encoded = batch_encoded.to(device)