21 changes: 21 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -274,6 +274,25 @@ def get_dtype(dtype):
    return dtype


def maybe_patch_deepseek_v3_config(hf_config):
    """Patch DeepSeek V3 config to add missing qk_head_dim attribute if needed.

    Args:
        hf_config: HuggingFace model config object

    Returns:
        The patched config object
    """
    if hf_config.model_type == "deepseek_v3" and not hasattr(hf_config, "qk_head_dim"):
        if hasattr(hf_config, "qk_nope_head_dim") and hasattr(hf_config, "qk_rope_head_dim"):
            hf_config.qk_head_dim = hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim
            print(
                f"Patched DeepSeek V3 config: qk_head_dim = {hf_config.qk_head_dim} "
                f"(qk_nope_head_dim={hf_config.qk_nope_head_dim} + qk_rope_head_dim={hf_config.qk_rope_head_dim})"
            )
    return hf_config


def get_model(
    ckpt_path,
    device="cuda",
@@ -301,6 +320,8 @@ def get_model(
    # Load config once and handle VL model detection
    try:
        hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
        hf_config = maybe_patch_deepseek_v3_config(hf_config)

        if is_nemotron_vl(hf_config):
            print(
                "Detected Nemotron VL model from config. "
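For reference, a minimal sketch of how the new helper behaves, using types.SimpleNamespace as a stand-in for a real HuggingFace config object (the head-dim values match DeepSeek V3's published config; the snippet itself is illustrative and not part of this PR):

# Illustrative only: SimpleNamespace mimics an HF config object.
# Assumes maybe_patch_deepseek_v3_config is importable from examples/llm_ptq.
from types import SimpleNamespace

from example_utils import maybe_patch_deepseek_v3_config

mock_config = SimpleNamespace(
    model_type="deepseek_v3",
    qk_nope_head_dim=128,  # DeepSeek V3's published value
    qk_rope_head_dim=64,   # DeepSeek V3's published value
)
patched = maybe_patch_deepseek_v3_config(mock_config)
assert patched.qk_head_dim == 192  # 128 + 64

# Non-DeepSeek configs (and configs that already define qk_head_dim)
# pass through unchanged.
llama_config = SimpleNamespace(model_type="llama")
assert maybe_patch_deepseek_v3_config(llama_config) is llama_config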
9 changes: 9 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -30,6 +30,7 @@
    get_tokenizer,
    is_enc_dec,
    is_nemotron_vl,
    maybe_patch_deepseek_v3_config,
    run_nemotron_vl_preview,
)
from torch.utils.data import DataLoader
@@ -270,12 +271,20 @@ def load_model(args: argparse.Namespace):
    )

    # Load and patch config for DeepSeek V3 before initializing the model
    from transformers import AutoConfig

    config_kwargs = {"trust_remote_code": args.trust_remote_code}
    hf_config = AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs)
    hf_config = maybe_patch_deepseek_v3_config(hf_config)

    # Do not use real quant GEMM so the calibration can be more accurate.
    with init_quantized_weights(
        quant_cfg, gpu_mem_percentage=args.gpu_max_mem_percentage, quant_gemm=False
    ):
        model_kwargs = {"trust_remote_code": args.trust_remote_code}
        if args.attn_implementation is not None:
            model_kwargs["attn_implementation"] = args.attn_implementation
        model_kwargs["config"] = hf_config
        full_model = AutoModelForCausalLM.from_pretrained(
            args.pyt_ckpt_path,
            **model_kwargs,
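Condensed, the load-patch-pass pattern above amounts to the following (load_patched_model is a hypothetical helper name, not part of this PR):

# Hypothetical helper illustrating the load-patch-pass pattern.
from transformers import AutoConfig, AutoModelForCausalLM

from example_utils import maybe_patch_deepseek_v3_config

def load_patched_model(ckpt_path, trust_remote_code=False):
    config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code)
    config = maybe_patch_deepseek_v3_config(config)
    # Passing the patched config keeps from_pretrained from rebuilding it
    # from config.json and losing the injected qk_head_dim attribute.
    return AutoModelForCausalLM.from_pretrained(
        ckpt_path,
        config=config,
        torch_dtype="auto",
        trust_remote_code=trust_remote_code,
    )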
10 changes: 9 additions & 1 deletion examples/llm_ptq/multinode_ptq.py
@@ -28,7 +28,7 @@
import torch
import torch.nn as nn
from accelerate import Accelerator
-from example_utils import build_quant_cfg, get_tokenizer
+from example_utils import build_quant_cfg, get_tokenizer, maybe_patch_deepseek_v3_config
from tqdm import tqdm
from transformers import AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -146,10 +146,18 @@ def load_and_prepare_model(
    Returns:
        Tuple of (prepared_model, model_type, original_architectures, calibration_dataloader)
    """
    # Load and patch config for DeepSeek V3 before initializing the model
    from transformers import AutoConfig

    config_kwargs = {"trust_remote_code": trust_remote_code}
    hf_config = AutoConfig.from_pretrained(model_path, **config_kwargs)
    hf_config = maybe_patch_deepseek_v3_config(hf_config)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        trust_remote_code=trust_remote_code,
        config=hf_config,
    )
    model.eval()
    model_type = get_model_type(model)
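A quick, weights-free way to check the patch against a real checkpoint (the repo id is illustrative; substitute your own DeepSeek V3 checkpoint path):

# Config-only check; no model weights are downloaded or loaded.
from transformers import AutoConfig

from example_utils import maybe_patch_deepseek_v3_config

cfg = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
cfg = maybe_patch_deepseek_v3_config(cfg)
print(cfg.qk_head_dim)  # expected: qk_nope_head_dim + qk_rope_head_dim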