Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions examples/deepseek/kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Pure PyTorch kernel implementations for PTQ calibration.

Replaces the tilelang-based kernel.py which requires CUDA 12's libnvrtc.
These implementations are numerically equivalent but slower — suitable for
PTQ calibration but not production inference.
"""

import torch

# Default per-block quantization width: one scale factor per 128 elements.
block_size = 128

# Largest finite value representable by FP8 e4m3fn (448.0); used as the
# quantization range limit in act_quant.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: str = None):
"""Block-wise FP8 quantization of activations.

Returns (fp8_tensor, scale_tensor) with the same shapes as the tilelang version.
"""
orig_shape = x.shape
# Flatten to 2D: (num_elements / block_size, block_size)
x_flat = x.reshape(-1, block_size)
# Compute per-block scale
amax = x_flat.float().abs().amax(dim=-1)
scale = amax / FP8_MAX
scale = scale.clamp(min=1e-12)
# Quantize
x_scaled = x_flat.float() / scale.unsqueeze(-1)
x_fp8 = x_scaled.clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
x_fp8 = x_fp8.reshape(orig_shape)
# Scale shape: match what the caller expects
scale = scale.reshape(*orig_shape[:-1], orig_shape[-1] // block_size)
return x_fp8, scale


def fp8_gemm(a: torch.Tensor, a_scale: torch.Tensor, b: torch.Tensor, b_scale: torch.Tensor):
"""FP8 matrix multiply with block-wise dequantization."""
# Dequantize and do regular matmul
a_f = a.float() * a_scale.unsqueeze(-1).repeat_interleave(block_size, dim=-1)[..., :a.shape[-1]]
b_f = b.float() * b_scale.unsqueeze(-1).repeat_interleave(block_size, dim=-1)[..., :b.shape[-1]]
return torch.matmul(a_f, b_f.t()).to(torch.bfloat16)


def fp8_index(q_fp8: torch.Tensor, weights: torch.Tensor, k_cache: torch.Tensor, k_scale_cache: torch.Tensor):
"""Compute sparse attention index scores using FP8 Q and cached K.

Args:
q_fp8: (bsz, seqlen, n_heads, head_dim) float8_e4m3fn
weights: (bsz, seqlen, n_heads, 1) weighting factors
k_cache: (bsz, cache_len, head_dim) float8_e4m3fn
k_scale_cache: (bsz, cache_len, head_dim // block_size) float32

Returns:
index_score: (bsz, seqlen, cache_len) attention-like scores
"""
bsz, seqlen, n_heads, head_dim = q_fp8.shape
cache_len = k_cache.shape[1]
n_blocks = head_dim // block_size

# Dequant K cache: (bsz, cache_len, head_dim)
k_f = k_cache.float().reshape(bsz, cache_len, n_blocks, block_size)
k_f = k_f * k_scale_cache.unsqueeze(-1)
k_f = k_f.reshape(bsz, cache_len, head_dim)

# Dequant Q: we don't have q_scale here, just use the fp8 values directly
q_f = q_fp8.float() # (bsz, seqlen, n_heads, head_dim)

# weights: (bsz, seqlen, n_heads, 1) — absorb into q
q_weighted = (q_f * weights).sum(dim=2) # (bsz, seqlen, head_dim)

# Score: (bsz, seqlen, cache_len)
index_score = torch.bmm(q_weighted, k_f.transpose(1, 2))
return index_score
15 changes: 6 additions & 9 deletions examples/deepseek/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def ptq(
batch_size: int,
calib_size: int,
mla_quant: str | None = None,
disable_wo_quant: bool = False,
):
"""Runs Deepseek model PTQ and returns the quantized model."""

Expand Down Expand Up @@ -338,7 +339,7 @@ def calibrate_loop(model):
mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False}
mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False}

if not args.disable_wo_quant and "FP4" in quant_cfg:
if not disable_wo_quant and "FP4" in quant_cfg:
mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"]
mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"]

Expand Down Expand Up @@ -371,13 +372,6 @@ def state_dict_filter(state_dict):
os.path.join(output_path, f"amax_dict_rank{rank}-mp{world_size}.pt"),
)

# if rank == 0:
# with open("expert_activation_counts.txt", "w") as f:
# for name, module in model.named_modules():
# if isinstance(module, deekseep_model.MoE):
# counts = module.activated_expert_counts()
# f.writelines(f"{name}: {count}\n" for count in counts)

quant_config = get_quant_config(model)

if enable_fp8_kvcache:
Expand Down Expand Up @@ -439,5 +433,8 @@ def state_dict_filter(state_dict):
tokenizer = AutoTokenizer.from_pretrained(
args.model_path, trust_remote_code=args.trust_remote_code
)
model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, args.mla_quant)
model = ptq(
model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size,
args.mla_quant, disable_wo_quant=args.disable_wo_quant,
)
save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)
49 changes: 39 additions & 10 deletions examples/deepseek/quantize_to_nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,22 @@
from typing import Any

import torch
from ds_kernel import weight_dequant
from safetensors.torch import load_file, save_file
from tqdm import tqdm

try:
from ds_kernel import weight_dequant
except ImportError:
weight_dequant = None

from modelopt.torch.quantization.qtensor import NVFP4QTensor


def _remap_key(key_dict: dict[str, Any]):
# renaming the module to match HF modeling
# The order matters here.
mappig = {
# Uses component-level replacement (split on ".") to avoid partial matches.
# Keys inside "indexer.*" are NOT remapped (they use the same names in HF).
mapping = {
"ffn": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
Expand All @@ -68,15 +73,25 @@ def _remap_key(key_dict: dict[str, Any]):
"wo": "o_proj",
"head": "lm_head",
}
# These keys appear inside indexer.* and must NOT be remapped
indexer_passthrough = {"wq_b", "wk", "k_norm", "weights_proj"}

new_dict = {}
for k, v in key_dict.items():
new_key = k.replace("layers", "model.layers")

for original_pattern, replace_pattern in mappig.items():
new_key = new_key.replace(original_pattern, replace_pattern)

new_dict[new_key] = v
parts = k.split(".")
new_parts = []
in_indexer = False
for part in parts:
if part == "indexer":
in_indexer = True
if part == "layers" and not new_parts:
new_parts.append("model")
new_parts.append("layers")
elif part in mapping and not (in_indexer and part in indexer_passthrough):
new_parts.append(mapping[part])
else:
new_parts.append(part)
new_dict[".".join(new_parts)] = v

key_dict.clear()
key_dict.update(new_dict)
Expand All @@ -90,6 +105,8 @@ def remove_quantization_config_from_original_config(export_dir: str) -> None:
config_path = os.path.join(export_dir, "config.json")
with open(config_path) as f:
cfg = json.load(f)
if "quantization_config" not in cfg:
return
del cfg["quantization_config"]
with open(config_path, "w") as f:
json.dump(cfg, f, indent=2, sort_keys=True)
Expand Down Expand Up @@ -202,6 +219,10 @@ def get_tensor(tensor_name):
if key.endswith("_scale_inv"):
continue
elif item.element_size() == 1: # FP8 weight
assert weight_dequant is not None, (
"ds_kernel is required to dequantize FP8 weights. "
"Install it or use a bf16 source checkpoint."
)
scale_inv_name = f"{key}_scale_inv"
try:
# Get scale_inv from the correct file
Expand All @@ -223,6 +244,12 @@ def get_tensor(tensor_name):
input_scale_key = layer_name + ".input_quantizer._amax"

if amax_key in renamed_state_dict:
# If per_layer_quant_config is non-empty (MIXED_PRECISION) but this layer is
# not listed, it was excluded from quantization (e.g. lm_head). Skip quant.
if per_layer_quant_config and layer_name not in per_layer_quant_config:
new_dict[key] = item
continue

# default quant is NVFP4
is_nvfp4 = (
not per_layer_quant_config
Expand Down Expand Up @@ -297,10 +324,12 @@ def get_tensor(tensor_name):
parser.add_argument(
"--fp4_path", type=str, required=True, help="path to save the fp4 checkpoint."
)
parser.add_argument("--fp8_hf_path", type=str, required=True, help="fp8 hf ckpt.")
parser.add_argument("--fp8_hf_path", "--hf_path", type=str, required=True, help="Source HF checkpoint (fp8 or bf16).")
parser.add_argument("--world_size", type=int, required=True, help="world size used by ptq.")
args = parser.parse_args()

os.makedirs(args.fp4_path, exist_ok=True)

per_layer_quant_config = process_quant_config(
quant_config_path=os.path.join(args.amax_path, "hf_quant_config.json"),
save_path=os.path.join(args.fp4_path, "hf_quant_config.json"),
Expand Down
60 changes: 60 additions & 0 deletions examples/deepseek/run_glm5_ptq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/bin/bash
# Run GLM-5 PTQ (NVFP4 calibration) inside a Slurm job: optionally convert the
# HF bf16 checkpoint to the DeepSeek sharded format, then run ptq.py with
# torchrun across 8 GPUs. Amax results land in AMAX_PATH.
set -e

# First positional arg is a running Slurm job id to attach to via srun --overlap.
JOBID=${1:?Usage: $0 <job_id> [--skip-convert] [--amax-path <path>] [--mla-quant <type>]}
SKIP_CONVERT=false
AMAX_PATH=/fsw/models/glm-5-nvfp4-amax
MLA_QUANT=""

# Parse optional flags
shift
while [[ $# -gt 0 ]]; do
case "$1" in
--skip-convert) SKIP_CONVERT=true ;;
--amax-path) AMAX_PATH="$2"; shift ;;
--mla-quant) MLA_QUANT="$2"; shift ;;
*) echo "Unknown argument: $1"; exit 1 ;;
esac
shift
done

# Resolve container image and mounts on the submit host (readlink -f follows
# symlinks so the container runtime gets absolute paths).
CONTAINER_IMAGE=$(readlink -f ~/fsw/containers/modelopt-v2.sqsh)
CONTAINER_MOUNTS=$(readlink -f ~/fsw):/fsw

# Checkpoint and config locations as seen INSIDE the container (/fsw mount).
HF_CKPT=/fsw/models/glm-5-bf16
DS_CKPT=/fsw/models/glm-5-ds
DS_V3_2_DIR=/fsw/Model-Optimizer/examples/deepseek/DeepSeek-V3.2-Exp
GLM5_CONFIG=${DS_V3_2_DIR}/inference/config_glm5.json

# NOTE: the payload below is a single-quoted string; '"${VAR}"' sequences close
# the quote, splice in the submit-host value, and reopen it — so SKIP_CONVERT,
# paths, and MLA_QUANT are expanded locally before the remote shell runs.
srun --overlap --jobid=${JOBID} --nodes=1 --ntasks=1 \
--container-image="${CONTAINER_IMAGE}" \
--container-mounts="${CONTAINER_MOUNTS}" \
--export="ALL,HF_TOKEN=${HF_TOKEN:?Set HF_TOKEN env var}" \
bash -c '
set -e
pip install --no-deps -e /fsw/Model-Optimizer

cd /fsw/Model-Optimizer/examples/deepseek

# Step 1: Convert HF bf16 checkpoint to DeepSeek sharded format
if [ "'"${SKIP_CONVERT}"'" != "true" ]; then
python '"${DS_V3_2_DIR}"'/inference/convert.py \
--hf-ckpt-path '"${HF_CKPT}"' \
--save-path '"${DS_CKPT}"' \
--n-experts 256 \
--model-parallel 8
else
echo "Skipping conversion (--skip-convert)"
fi

# Step 2: Run PTQ calibration
torchrun --nproc-per-node 8 --master_port=12346 ptq.py \
--model_path '"${DS_CKPT}"' \
--config '"${GLM5_CONFIG}"' \
--quant_cfg NVFP4_DEFAULT_CFG \
--output_path '"${AMAX_PATH}"' \
--trust_remote_code \
--batch_size 8 \
--calib_size 512 \
'"${MLA_QUANT:+--mla_quant $MLA_QUANT}"'
'
Loading