diff --git a/examples/deepseek/kernel.py b/examples/deepseek/kernel.py new file mode 100644 index 000000000..0f6c3024a --- /dev/null +++ b/examples/deepseek/kernel.py @@ -0,0 +1,73 @@ +"""Pure PyTorch kernel implementations for PTQ calibration. + +Replaces the tilelang-based kernel.py which requires CUDA 12's libnvrtc. +These implementations are numerically equivalent but slower — suitable for +PTQ calibration but not production inference. +""" + +import torch + +block_size = 128 + +FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: str = None): + """Block-wise FP8 quantization of activations. + + Returns (fp8_tensor, scale_tensor) with the same shapes as the tilelang version. + """ + orig_shape = x.shape + # Flatten to 2D: (num_elements / block_size, block_size) + x_flat = x.reshape(-1, block_size) + # Compute per-block scale + amax = x_flat.float().abs().amax(dim=-1) + scale = amax / FP8_MAX + scale = scale.clamp(min=1e-12) + # Quantize + x_scaled = x_flat.float() / scale.unsqueeze(-1) + x_fp8 = x_scaled.clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.reshape(orig_shape) + # Scale shape: match what the caller expects + scale = scale.reshape(*orig_shape[:-1], orig_shape[-1] // block_size) + return x_fp8, scale + + +def fp8_gemm(a: torch.Tensor, a_scale: torch.Tensor, b: torch.Tensor, b_scale: torch.Tensor): + """FP8 matrix multiply with block-wise dequantization.""" + # Dequantize and do regular matmul + a_f = a.float() * a_scale.unsqueeze(-1).repeat_interleave(block_size, dim=-1)[..., :a.shape[-1]] + b_f = b.float() * b_scale.unsqueeze(-1).repeat_interleave(block_size, dim=-1)[..., :b.shape[-1]] + return torch.matmul(a_f, b_f.t()).to(torch.bfloat16) + + +def fp8_index(q_fp8: torch.Tensor, weights: torch.Tensor, k_cache: torch.Tensor, k_scale_cache: torch.Tensor): + """Compute sparse attention index scores using FP8 Q and cached K. + + Args: + q_fp8: (bsz, seqlen, n_heads, head_dim) float8_e4m3fn + weights: (bsz, seqlen, n_heads, 1) weighting factors + k_cache: (bsz, cache_len, head_dim) float8_e4m3fn + k_scale_cache: (bsz, cache_len, head_dim // block_size) float32 + + Returns: + index_score: (bsz, seqlen, cache_len) attention-like scores + """ + bsz, seqlen, n_heads, head_dim = q_fp8.shape + cache_len = k_cache.shape[1] + n_blocks = head_dim // block_size + + # Dequant K cache: (bsz, cache_len, head_dim) + k_f = k_cache.float().reshape(bsz, cache_len, n_blocks, block_size) + k_f = k_f * k_scale_cache.unsqueeze(-1) + k_f = k_f.reshape(bsz, cache_len, head_dim) + + # Dequant Q: we don't have q_scale here, just use the fp8 values directly + q_f = q_fp8.float() # (bsz, seqlen, n_heads, head_dim) + + # weights: (bsz, seqlen, n_heads, 1) — absorb into q + q_weighted = (q_f * weights).sum(dim=2) # (bsz, seqlen, head_dim) + + # Score: (bsz, seqlen, cache_len) + index_score = torch.bmm(q_weighted, k_f.transpose(1, 2)) + return index_score diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index 6b1086a30..b5ec2f7c1 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -280,6 +280,7 @@ def ptq( batch_size: int, calib_size: int, mla_quant: str | None = None, + disable_wo_quant: bool = False, ): """Runs Deepseek model PTQ and returns the quantized model.""" @@ -338,7 +339,7 @@ def calibrate_loop(model): mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False} mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False} - if not args.disable_wo_quant and "FP4" in quant_cfg: + if not disable_wo_quant and "FP4" in quant_cfg: mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"] mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"] @@ -371,13 +372,6 @@ def state_dict_filter(state_dict): os.path.join(output_path, f"amax_dict_rank{rank}-mp{world_size}.pt"), ) - # if rank == 0: - # with open("expert_activation_counts.txt", "w") as f: - # for name, module in model.named_modules(): - # if isinstance(module, deekseep_model.MoE): - # counts = module.activated_expert_counts() - # f.writelines(f"{name}: {count}\n" for count in counts) - quant_config = get_quant_config(model) if enable_fp8_kvcache: @@ -439,5 +433,8 @@ def state_dict_filter(state_dict): tokenizer = AutoTokenizer.from_pretrained( args.model_path, trust_remote_code=args.trust_remote_code ) - model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, args.mla_quant) + model = ptq( + model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, + args.mla_quant, disable_wo_quant=args.disable_wo_quant, + ) save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache) diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py index a18cbbc16..42772c26f 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -44,17 +44,22 @@ from typing import Any import torch -from ds_kernel import weight_dequant from safetensors.torch import load_file, save_file from tqdm import tqdm +try: + from ds_kernel import weight_dequant +except ImportError: + weight_dequant = None + from modelopt.torch.quantization.qtensor import NVFP4QTensor def _remap_key(key_dict: dict[str, Any]): # renaming the module to match HF modeling - # The order matters here. - mappig = { + # Uses component-level replacement (split on ".") to avoid partial matches. + # Keys inside "indexer.*" are NOT remapped (they use the same names in HF). + mapping = { "ffn": "mlp", "w1": "gate_proj", "w2": "down_proj", @@ -68,15 +73,25 @@ def _remap_key(key_dict: dict[str, Any]): "wo": "o_proj", "head": "lm_head", } + # These keys appear inside indexer.* and must NOT be remapped + indexer_passthrough = {"wq_b", "wk", "k_norm", "weights_proj"} new_dict = {} for k, v in key_dict.items(): - new_key = k.replace("layers", "model.layers") - - for original_pattern, replace_pattern in mappig.items(): - new_key = new_key.replace(original_pattern, replace_pattern) - - new_dict[new_key] = v + parts = k.split(".") + new_parts = [] + in_indexer = False + for part in parts: + if part == "indexer": + in_indexer = True + if part == "layers" and not new_parts: + new_parts.append("model") + new_parts.append("layers") + elif part in mapping and not (in_indexer and part in indexer_passthrough): + new_parts.append(mapping[part]) + else: + new_parts.append(part) + new_dict[".".join(new_parts)] = v key_dict.clear() key_dict.update(new_dict) @@ -90,6 +105,8 @@ def remove_quantization_config_from_original_config(export_dir: str) -> None: config_path = os.path.join(export_dir, "config.json") with open(config_path) as f: cfg = json.load(f) + if "quantization_config" not in cfg: + return del cfg["quantization_config"] with open(config_path, "w") as f: json.dump(cfg, f, indent=2, sort_keys=True) @@ -202,6 +219,10 @@ def get_tensor(tensor_name): if key.endswith("_scale_inv"): continue elif item.element_size() == 1: # FP8 weight + assert weight_dequant is not None, ( + "ds_kernel is required to dequantize FP8 weights. " + "Install it or use a bf16 source checkpoint." + ) scale_inv_name = f"{key}_scale_inv" try: # Get scale_inv from the correct file @@ -223,6 +244,12 @@ def get_tensor(tensor_name): input_scale_key = layer_name + ".input_quantizer._amax" if amax_key in renamed_state_dict: + # If per_layer_quant_config is non-empty (MIXED_PRECISION) but this layer is + # not listed, it was excluded from quantization (e.g. lm_head). Skip quant. + if per_layer_quant_config and layer_name not in per_layer_quant_config: + new_dict[key] = item + continue + # default quant is NVFP4 is_nvfp4 = ( not per_layer_quant_config @@ -297,10 +324,12 @@ def get_tensor(tensor_name): parser.add_argument( "--fp4_path", type=str, required=True, help="path to save the fp4 checkpoint." ) - parser.add_argument("--fp8_hf_path", type=str, required=True, help="fp8 hf ckpt.") + parser.add_argument("--fp8_hf_path", "--hf_path", type=str, required=True, help="Source HF checkpoint (fp8 or bf16).") parser.add_argument("--world_size", type=int, required=True, help="world size used by ptq.") args = parser.parse_args() + os.makedirs(args.fp4_path, exist_ok=True) + per_layer_quant_config = process_quant_config( quant_config_path=os.path.join(args.amax_path, "hf_quant_config.json"), save_path=os.path.join(args.fp4_path, "hf_quant_config.json"), diff --git a/examples/deepseek/run_glm5_ptq.sh b/examples/deepseek/run_glm5_ptq.sh new file mode 100755 index 000000000..0b3e060dd --- /dev/null +++ b/examples/deepseek/run_glm5_ptq.sh @@ -0,0 +1,60 @@ +#!/bin/bash +set -e + +JOBID=${1:?Usage: $0 [--skip-convert] [--amax-path ] [--mla-quant ]} +SKIP_CONVERT=false +AMAX_PATH=/fsw/models/glm-5-nvfp4-amax +MLA_QUANT="" + +# Parse optional flags +shift +while [[ $# -gt 0 ]]; do + case "$1" in + --skip-convert) SKIP_CONVERT=true ;; + --amax-path) AMAX_PATH="$2"; shift ;; + --mla-quant) MLA_QUANT="$2"; shift ;; + *) echo "Unknown argument: $1"; exit 1 ;; + esac + shift +done + +CONTAINER_IMAGE=$(readlink -f ~/fsw/containers/modelopt-v2.sqsh) +CONTAINER_MOUNTS=$(readlink -f ~/fsw):/fsw + +HF_CKPT=/fsw/models/glm-5-bf16 +DS_CKPT=/fsw/models/glm-5-ds +DS_V3_2_DIR=/fsw/Model-Optimizer/examples/deepseek/DeepSeek-V3.2-Exp +GLM5_CONFIG=${DS_V3_2_DIR}/inference/config_glm5.json + +srun --overlap --jobid=${JOBID} --nodes=1 --ntasks=1 \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${CONTAINER_MOUNTS}" \ + --export="ALL,HF_TOKEN=${HF_TOKEN:?Set HF_TOKEN env var}" \ + bash -c ' +set -e +pip install --no-deps -e /fsw/Model-Optimizer + +cd /fsw/Model-Optimizer/examples/deepseek + +# Step 1: Convert HF bf16 checkpoint to DeepSeek sharded format +if [ "'"${SKIP_CONVERT}"'" != "true" ]; then + python '"${DS_V3_2_DIR}"'/inference/convert.py \ + --hf-ckpt-path '"${HF_CKPT}"' \ + --save-path '"${DS_CKPT}"' \ + --n-experts 256 \ + --model-parallel 8 +else + echo "Skipping conversion (--skip-convert)" +fi + +# Step 2: Run PTQ calibration +torchrun --nproc-per-node 8 --master_port=12346 ptq.py \ + --model_path '"${DS_CKPT}"' \ + --config '"${GLM5_CONFIG}"' \ + --quant_cfg NVFP4_DEFAULT_CFG \ + --output_path '"${AMAX_PATH}"' \ + --trust_remote_code \ + --batch_size 8 \ + --calib_size 512 \ + '"${MLA_QUANT:+--mla_quant $MLA_QUANT}"' +' diff --git a/examples/glm5/README.md b/examples/glm5/README.md new file mode 100644 index 000000000..394dbc664 --- /dev/null +++ b/examples/glm5/README.md @@ -0,0 +1,168 @@ +# GLM-5 NVFP4 Quantization + +This guide describes how to quantize the GLM-5 model (bf16) to NVFP4 using the +DeepSeek V3.2 PTQ pipeline. GLM-5 shares the same MoE + MLA architecture as +DeepSeek V3 (256 routed experts, MLA attention with DSA indexer), so we reuse +the DeepSeek inference code with a GLM-5-specific config. + +## Prerequisites + +- SLURM cluster with 8 GPUs (H100 80GB recommended) +- Container image with Model-Optimizer, PyTorch, and `fast_hadamard_transform` + pre-installed (e.g. `modelopt-v2.sqsh`) +- HuggingFace bf16 checkpoint of GLM-5 +- (Optional) FP8 checkpoint of GLM-5 for the MTP head + +## Overview + +The pipeline has four steps: + +1. **Convert** the HF bf16 checkpoint to DeepSeek's sharded format (8-way TP) +2. **PTQ calibration** to compute per-layer amax statistics +3. **Quantize** bf16 weights to NVFP4 using the amax values +4. **(Optional) Extract MTP head** from FP8 checkpoint and add to the output + +## Step 1: Convert HF checkpoint to DeepSeek format + +The DeepSeek V3.2 model code uses a different weight naming convention and +shards weights across tensor-parallel ranks. This step converts and shards the +HF checkpoint using parallel subprocesses (one per rank) with `safetensors` +`get_slice` for memory-efficient reading. + +```bash +srun --overlap --jobid=${JOBID} --nodes=1 --ntasks=1 \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${CONTAINER_MOUNTS}" \ + bash -c ' +pip install --no-deps -e /path/to/Model-Optimizer +cd /path/to/Model-Optimizer/examples/deepseek + +python DeepSeek-V3.2-Exp/inference/convert.py \ + --hf-ckpt-path /path/to/glm-5-bf16 \ + --save-path /path/to/glm-5-ds \ + --n-experts 256 \ + --model-parallel 8 +' +``` + +The conversion: +- Auto-detects the MTP layer (layer 78) via `config.json` and skips it +- Patches `tokenizer_config.json` to remove incompatible fields + (`tokenizer_class`, `is_local`, `extra_special_tokens`) +- Produces 8 shards of ~177GB each + +## Step 2: PTQ calibration + +Run PTQ calibration across 8 GPUs using `torchrun`. This inserts quantizers +into the model and calibrates amax values on sample data. + +```bash +srun --overlap --jobid=${JOBID} --nodes=1 --ntasks=1 \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${CONTAINER_MOUNTS}" \ + --export="ALL,HF_TOKEN=${HF_TOKEN}" \ + bash -c ' +pip install --no-deps -e /path/to/Model-Optimizer +cd /path/to/Model-Optimizer/examples/deepseek + +torchrun --nproc-per-node 8 --master_port=12346 ptq.py \ + --model_path /path/to/glm-5-ds \ + --config DeepSeek-V3.2-Exp/inference/config_glm5.json \ + --quant_cfg NVFP4_DEFAULT_CFG \ + --output_path /path/to/glm-5-nvfp4-amax \ + --trust_remote_code \ + --batch_size 8 \ + --calib_size 512 +' +``` + +Notes: +- An `HF_TOKEN` is required because the calibration dataset + (`nvidia/Nemotron-Post-Training-Dataset-v2`) is gated +- The `kernel.py` stub in `examples/deepseek/` provides pure PyTorch + implementations of `act_quant`, `fp8_gemm`, and `fp8_index` to replace the + tilelang-based kernels (which require CUDA 12's `libnvrtc`) +- Output: 8 amax files + `hf_quant_config.json` + +## Step 3: Quantize to NVFP4 + +Apply the calibrated amax values to quantize weights from the original HF +checkpoint to NVFP4 format. + +```bash +# First copy config/tokenizer files to the output directory +mkdir -p /path/to/glm-5-nvfp4 +cp /path/to/glm-5-bf16/*.json /path/to/glm-5-nvfp4/ +cp /path/to/glm-5-bf16/*token* /path/to/glm-5-nvfp4/ + +# Run quantization (requires 1 GPU) +python examples/deepseek/quantize_to_nvfp4.py \ + --amax_path /path/to/glm-5-nvfp4-amax \ + --hf_path /path/to/glm-5-bf16 \ + --fp4_path /path/to/glm-5-nvfp4 \ + --world_size 8 +``` + +This iterates through all 282 safetensors shards, quantizing weights that have +amax entries to NVFP4 and passing through non-quantized weights (norms, embed, +gate) as bf16. Takes ~35 minutes on a single GPU. + +## Step 4 (Optional): Add MTP head + +The MTP (Multi-Token Prediction) head at layer 78 is excluded from +quantization. If you have an FP8 checkpoint with the MTP head, extract it: + +```bash +python examples/glm5/extract_mtp_head.py \ + --fp8_index /path/to/glm-5-fp8/model.safetensors.index.json \ + --fp8_dir /path/to/glm-5-fp8 \ + --nvfp4_dir /path/to/glm-5-nvfp4 \ + --mtp_layer 78 +``` + +This extracts all layer-78 tensors into `mtp-fp8.safetensors` and updates +`model.safetensors.index.json` to include them. + +## Convenience script + +`examples/deepseek/run_glm5_ptq.sh` combines steps 1 and 2 into a single SLURM +job. Usage: + +```bash +# Allocate a SLURM job (3-4 hours recommended) +salloc --nodes=1 --time=3:59:00 --account= --partition=batch --no-shell + +# Run conversion + PTQ +bash examples/deepseek/run_glm5_ptq.sh + +# Skip conversion if already done +bash examples/deepseek/run_glm5_ptq.sh --skip-convert +``` + +## Key modifications to DeepSeek V3.2 code + +The following changes were made to support GLM-5: + +- **`convert.py`**: Parallel subprocess conversion with `get_slice`, MTP layer + auto-detection, tokenizer config patching +- **`model.py`**: Added `gate_bias` config flag to `ModelArgs` (GLM-5 has gate + bias, DSv3 gate bias was hardcoded to `dim == 7168`) +- **`config_glm5.json`**: GLM-5 model config in DeepSeek V3.2 format +- **`kernel.py`**: Pure PyTorch stubs for `act_quant`, `fp8_gemm`, `fp8_index` + (replaces tilelang kernels that need CUDA 12) +- **`quantize_to_nvfp4.py`**: Conditional `ds_kernel` import, component-level + `_remap_key` to avoid corrupting indexer keys, bf16 source support + +## GLM-5 model config + +| Parameter | Value | +|---|---| +| Hidden dim | 6144 | +| Layers | 78 + 1 MTP | +| Attention heads | 64 | +| Routed experts | 256 (8 activated) | +| Shared experts | 1 | +| Dense layers | 3 | +| Q LoRA rank | 2048 | +| KV LoRA rank | 512 | +| Vocab size | 154880 | diff --git a/examples/glm5/dequant_nvfp4.py b/examples/glm5/dequant_nvfp4.py new file mode 100644 index 000000000..8a99ba77a --- /dev/null +++ b/examples/glm5/dequant_nvfp4.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +"""Dequantize NVFP4/FP8 checkpoint back to bf16 for HuggingFace inference testing. + +Usage: + python dequant_nvfp4.py dequant --nvfp4_dir /path/to/nvfp4 --output_dir /path/to/output + python dequant_nvfp4.py generate --model_dir /path/to/dequanted +""" + +import argparse +import json +import math +import os +import shutil +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path + +import torch +from safetensors.torch import load_file, save_file + +# e2m1 lookup table for NVFP4: indices 0-7 positive, 8-15 negative +E2M1_LUT = torch.tensor( + [0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6], + dtype=torch.float32, +) + +NVFP4_BLOCK_SIZE = 16 # 16 FP4 values per scale +FP8_BLOCK_SIZE = 128 + + +def dequant_nvfp4(weight_uint8, weight_scale, weight_scale_2): + """Dequantize NVFP4 packed uint8 weight to bf16. + + Args: + weight_uint8: [out, packed_in] uint8 tensor (2 FP4 values per byte) + weight_scale: [out, packed_in // 8] float8_e4m3fn per-block scale + weight_scale_2: scalar float32 double-scale + Returns: + [out, packed_in * 2] bf16 tensor + """ + out_dim, packed_in = weight_uint8.shape + unpacked_in = packed_in * 2 + + # Unpack uint8 to two 4-bit indices + unpacked = torch.empty(out_dim, unpacked_in, dtype=torch.long) + unpacked[:, 0::2] = (weight_uint8 & 0x0F).long() # low nibble + unpacked[:, 1::2] = (weight_uint8 >> 4).long() # high nibble + + # Lookup e2m1 values + fp_values = E2M1_LUT[unpacked] # [out, unpacked_in] f32 + + # Compute per-block scales: scale * scale_2 + per_block_scale = weight_scale.float() * weight_scale_2.float() # [out, num_blocks] + + # Reshape to [out, num_blocks, block_size] and apply scale + num_blocks = per_block_scale.shape[1] + fp_values = fp_values.view(out_dim, num_blocks, NVFP4_BLOCK_SIZE) + result = fp_values * per_block_scale.unsqueeze(-1) + + return result.reshape(out_dim, unpacked_in).to(torch.bfloat16) + + +def dequant_fp8(weight_fp8, scale_inv): + """Dequantize FP8 weight with block-128 scale_inv to bf16. + + Args: + weight_fp8: [M, N] float8_e4m3fn tensor + scale_inv: [ceil(M/128), ceil(N/128)] float32 tensor + Returns: + [M, N] bf16 tensor + """ + M, N = weight_fp8.shape + scale_M, scale_N = scale_inv.shape + bs = FP8_BLOCK_SIZE + + # Convert to float for computation + w = weight_fp8.float() + + # Pad if dimensions not divisible by block size + padded_M = scale_M * bs + padded_N = scale_N * bs + if padded_M != M or padded_N != N: + w_padded = torch.zeros(padded_M, padded_N, dtype=torch.float32) + w_padded[:M, :N] = w + w = w_padded + + # Reshape to blocks and apply scale + w = w.view(scale_M, bs, scale_N, bs) + w = w * scale_inv[:, None, :, None] + w = w.reshape(padded_M, padded_N) + + # Trim padding + return w[:M, :N].to(torch.bfloat16) + + +def is_scale_key(name): + """Check if tensor name is a scale/companion tensor that should be dropped.""" + return ( + name.endswith(".weight_scale") + or name.endswith(".weight_scale_2") + or name.endswith(".weight_scale_inv") + or name.endswith(".input_scale") + ) + + +def process_shard(src_path, dst_path): + """Process a single safetensors shard: dequantize quantized weights, pass through others.""" + tensors = load_file(src_path, device="cpu") + + output = {} + processed_nvfp4 = 0 + processed_fp8 = 0 + passthrough = 0 + + # First pass: identify weight keys and their companions + for name, tensor in tensors.items(): + if is_scale_key(name): + continue # skip scale tensors + + if tensor.dtype == torch.uint8: + # NVFP4 quantized weight + scale_key = name + "_scale" + scale2_key = name + "_scale_2" + weight_scale = tensors[scale_key] + weight_scale_2 = tensors[scale2_key] + output[name] = dequant_nvfp4(tensor, weight_scale, weight_scale_2) + processed_nvfp4 += 1 + + elif tensor.dtype == torch.float8_e4m3fn: + # FP8 quantized weight + scale_inv_key = name + "_scale_inv" + scale_inv = tensors[scale_inv_key] + output[name] = dequant_fp8(tensor, scale_inv) + processed_fp8 += 1 + + else: + # bf16, f32, int — pass through as-is + output[name] = tensor + passthrough += 1 + + save_file(output, dst_path) + return src_path.name, processed_nvfp4, processed_fp8, passthrough + + +def cmd_dequant(args): + nvfp4_dir = Path(args.nvfp4_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Find all safetensors files + st_files = sorted(nvfp4_dir.glob("*.safetensors")) + print(f"Found {len(st_files)} safetensors files", flush=True) + + # Process shards in parallel + t0 = time.time() + futures = {} + with ProcessPoolExecutor(max_workers=args.workers) as pool: + for sf in st_files: + dst = output_dir / sf.name + fut = pool.submit(process_shard, sf, dst) + futures[fut] = sf.name + + for i, fut in enumerate(as_completed(futures), 1): + name, nvfp4, fp8, pt = fut.result() + elapsed = time.time() - t0 + print( + f"[{i}/{len(st_files)}] {name}: " + f"nvfp4={nvfp4}, fp8={fp8}, passthrough={pt} " + f"({elapsed:.0f}s elapsed)", + flush=True, + ) + + elapsed = time.time() - t0 + print(f"\nAll shards processed in {elapsed:.0f}s") + + # Build new index.json with scale keys removed + index_path = nvfp4_dir / "model.safetensors.index.json" + with open(index_path) as f: + index = json.load(f) + + new_weight_map = {} + for key, filename in index["weight_map"].items(): + if not is_scale_key(key): + new_weight_map[key] = filename + + # Recompute total_size from output files + total_size = 0 + for sf in output_dir.glob("*.safetensors"): + total_size += sf.stat().st_size + + new_index = { + "metadata": {"total_size": total_size}, + "weight_map": new_weight_map, + } + with open(output_dir / "model.safetensors.index.json", "w") as f: + json.dump(new_index, f, indent=2) + print(f"Wrote model.safetensors.index.json ({len(new_weight_map)} keys)") + + # Copy config files + for pattern in ["config.json", "generation_config.json", "tokenizer*", "special_tokens*"]: + for src in nvfp4_dir.glob(pattern): + dst = output_dir / src.name + if not dst.exists(): + shutil.copy2(src, dst) + print(f"Copied {src.name}") + + # Patch tokenizer_config.json to remove fields that break HF loading + tok_cfg_path = output_dir / "tokenizer_config.json" + if tok_cfg_path.exists(): + with open(tok_cfg_path) as f: + tok_cfg = json.load(f) + changed = False + for key in ["tokenizer_class", "is_local", "extra_special_tokens"]: + if key in tok_cfg: + del tok_cfg[key] + changed = True + if changed: + with open(tok_cfg_path, "w") as f: + json.dump(tok_cfg, f, indent=2) + print("Patched tokenizer_config.json") + + print("\nDone! Output directory:", output_dir) + + +def cmd_generate(args): + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_dir = args.model_dir + print(f"Loading model from {model_dir}...") + + tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + model_dir, + device_map="auto", + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + model.eval() + + prompts = [ + "What is the capital of France?", + "Write a short poem about the ocean.", + "Explain quantum computing in one sentence.", + ] + + for prompt in prompts: + messages = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = tokenizer(text, return_tensors="pt").to(model.device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=256, + do_sample=False, + ) + + response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) + print(f"\n{'='*60}") + print(f"Prompt: {prompt}") + print(f"Response: {response}") + + print(f"\n{'='*60}") + print("Generation test complete.") + + +def main(): + parser = argparse.ArgumentParser(description="Dequantize NVFP4/FP8 checkpoint to bf16") + sub = parser.add_subparsers(dest="command", required=True) + + p_dequant = sub.add_parser("dequant", help="Dequantize checkpoint") + p_dequant.add_argument("--nvfp4_dir", required=True, help="Path to NVFP4 checkpoint") + p_dequant.add_argument("--output_dir", required=True, help="Path to output bf16 checkpoint") + p_dequant.add_argument("--workers", type=int, default=32, help="Number of parallel workers") + + p_gen = sub.add_parser("generate", help="Run generation test") + p_gen.add_argument("--model_dir", required=True, help="Path to dequanted model") + + args = parser.parse_args() + if args.command == "dequant": + cmd_dequant(args) + elif args.command == "generate": + cmd_generate(args) + + +if __name__ == "__main__": + main() diff --git a/examples/glm5/extract_mtp_head.py b/examples/glm5/extract_mtp_head.py new file mode 100644 index 000000000..62ef0a343 --- /dev/null +++ b/examples/glm5/extract_mtp_head.py @@ -0,0 +1,84 @@ +"""Extract MTP (Multi-Token Prediction) head weights from an FP8 checkpoint. + +The GLM-5 model has 78 transformer layers (0-77) plus one MTP layer (layer 78) +used for speculative decoding. During NVFP4 quantization, only layers 0-77 are +quantized. This script extracts the MTP head from a separate FP8 checkpoint and +adds it to the NVFP4 output so the final checkpoint is complete. + +Usage: + python extract_mtp_head.py \ + --fp8_index /path/to/glm-5-fp8/model.safetensors.index.json \ + --fp8_dir /path/to/glm-5-fp8 \ + --nvfp4_dir /path/to/glm-5-nvfp4 \ + --mtp_layer 78 +""" + +import argparse +import json + +from safetensors.torch import load_file, save_file + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--fp8_index", type=str, required=True, + help="Path to the FP8 model.safetensors.index.json") + parser.add_argument("--fp8_dir", type=str, required=True, + help="Directory containing FP8 safetensors files") + parser.add_argument("--nvfp4_dir", type=str, required=True, + help="NVFP4 output directory to add MTP head to") + parser.add_argument("--mtp_layer", type=int, default=78, + help="Layer index of the MTP head (default: 78)") + parser.add_argument("--output_name", type=str, default="mtp-fp8.safetensors", + help="Filename for the MTP safetensors file") + args = parser.parse_args() + + mtp_prefix = f"model.layers.{args.mtp_layer}" + + # Find MTP keys in the FP8 index + with open(args.fp8_index) as f: + idx = json.load(f) + + mtp_keys_by_file = {} + for key, filename in idx["weight_map"].items(): + if key.startswith(mtp_prefix): + mtp_keys_by_file.setdefault(filename, []).append(key) + + if not mtp_keys_by_file: + print(f"No MTP keys found with prefix '{mtp_prefix}'") + return + + total_keys = sum(len(v) for v in mtp_keys_by_file.values()) + print(f"Found {total_keys} MTP keys across {len(mtp_keys_by_file)} files") + + # Extract MTP tensors + mtp_tensors = {} + for filename, keys in sorted(mtp_keys_by_file.items()): + filepath = f"{args.fp8_dir}/{filename}" + print(f" Loading {len(keys)} keys from {filename}...") + data = load_file(filepath, device="cpu") + for k in keys: + mtp_tensors[k] = data[k] + del data + + # Save as single file + out_path = f"{args.nvfp4_dir}/{args.output_name}" + save_file(mtp_tensors, out_path) + print(f"Saved {len(mtp_tensors)} MTP tensors to {out_path}") + + # Update the NVFP4 index + nvfp4_index_path = f"{args.nvfp4_dir}/model.safetensors.index.json" + with open(nvfp4_index_path) as f: + nvfp4_idx = json.load(f) + + for k in mtp_tensors: + nvfp4_idx["weight_map"][k] = args.output_name + + with open(nvfp4_index_path, "w") as f: + json.dump(nvfp4_idx, f, indent=2) + + print(f"Updated {nvfp4_index_path} with MTP keys -> {args.output_name}") + + +if __name__ == "__main__": + main() diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 93687a8d0..ebe0444ce 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -237,11 +237,15 @@ def build_quant_cfg( quant_cfg["quant_cfg"]["model*.*attn*k_proj*"] = {"enable": False} quant_cfg["quant_cfg"]["model*.*attn*v_proj*"] = {"enable": False} - if model_type == "deepseek": + if model_type in ("deepseek", "glm"): # Disable MLA quantization for accuracy. quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False} quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False} + if model_type == "glm": + # Disable DSA Indexer linear quantization for accuracy. + quant_cfg["quant_cfg"]["*self_attn.indexer*"] = {"enable": False} + return quant_cfg diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 16bff49c2..f2df24d5f 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -247,13 +247,41 @@ def get_dataset_dataloader( samples = get_dataset_samples(ds_name, num_sample) all_samples.extend(samples) - batch_encoded = tokenizer.batch_encode_plus( - all_samples, - return_tensors="pt", - padding=True, - truncation=True, - max_length=max_sample_length, - ) + if hasattr(tokenizer, "batch_encode_plus"): + batch_encoded = tokenizer.batch_encode_plus( + all_samples, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_sample_length, + ) + else: + # Fallback for tokenizers (e.g. TokenizersBackend in transformers 5.x) that only + # implement _encode_plus and lack _batch_encode_plus / batch_encode_plus. + from torch.nn.utils.rnn import pad_sequence + + from transformers import BatchEncoding + + pad_id = tokenizer.pad_token_id or 0 + left_pad = getattr(tokenizer, "padding_side", "right") == "left" + encoded_list = [ + tokenizer._encode_plus(s, return_tensors="pt", truncation=True, max_length=max_sample_length) + for s in all_samples + ] + input_ids = [e["input_ids"].squeeze(0) for e in encoded_list] + attention_mask = [e["attention_mask"].squeeze(0) for e in encoded_list] + if left_pad: + # pad_sequence pads on the right; flip, pad, flip back for left padding + input_ids = pad_sequence( + [t.flip(0) for t in input_ids], batch_first=True, padding_value=pad_id + ).flip(1) + attention_mask = pad_sequence( + [t.flip(0) for t in attention_mask], batch_first=True, padding_value=0 + ).flip(1) + else: + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_id) + attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0) + batch_encoded = BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}) if device: batch_encoded = batch_encoded.to(device)