From 70d3676b8c370a8a374016586075474f8f2bb9bf Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Tue, 10 Mar 2026 10:57:08 +0800 Subject: [PATCH 1/6] feat: add wan2.2_t2v model and quantization config --- .../video_gen/wan2_2_t2v/awq_w_a.yaml | 49 +++++ llmc/eval/eval_video_generate.py | 45 ++-- llmc/models/__init__.py | 1 + llmc/models/base_model.py | 4 +- llmc/models/wan2_2_t2v.py | 193 ++++++++++++++++++ 5 files changed, 271 insertions(+), 21 deletions(-) create mode 100644 configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml create mode 100755 llmc/models/wan2_2_t2v.py diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml new file mode 100644 index 000000000..16c9e9929 --- /dev/null +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml @@ -0,0 +1,49 @@ +base: + seed: &seed 42 +model: + type: Wan2T2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B-Diffusers + torch_dtype: auto +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + save_lightx2v: True + save_path: ../lightx2v/wan2_2_t2v_awq_w_a/x2v/ diff --git a/llmc/eval/eval_video_generate.py b/llmc/eval/eval_video_generate.py index 0f99ff6c9..726187c0b 100755 --- a/llmc/eval/eval_video_generate.py +++ b/llmc/eval/eval_video_generate.py @@ 
-23,6 +23,7 @@ def __init__(self, model, config): self.target_width = self.eval_cfg.get('target_width', 832) self.num_frames = self.eval_cfg.get('num_frames', 81) self.guidance_scale = self.eval_cfg.get('guidance_scale', 5.0) + self.guidance_scale_2 = self.eval_cfg.get('guidance_scale_2', None) self.fps = self.eval_cfg.get('fps', 15) @torch.no_grad() @@ -56,14 +57,17 @@ def t2v_eval(self, model, testenc, bs, eval_pos): assert bs == 1, 'Only support eval bs=1' for i, data in enumerate(testenc): - output = model.Pipeline( - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=self.target_height, - width=self.target_width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': self.target_height, + 'width': self.target_width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, os.path.join(self.output_video_path, f'{eval_pos}_output_{i}.mp4'), @@ -77,15 +81,18 @@ def i2v_eval(self, model, testenc, bs, eval_pos): for i, data in enumerate(testenc): image, width, height = self.pre_process(model, data['image']) - output = model.Pipeline( - image=image, - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=height, - width=width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'image': image, + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': height, + 'width': width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, @@ -98,9 +105,9 @@ def 
i2v_eval(self, model, testenc, bs, eval_pos): @torch.no_grad() def eval_func(self, model, testenc, bs, eval_pos): assert bs == 1, 'Evaluation only supports batch size = 1.' - assert self.model_type in ['WanT2V', 'WanI2V'], ( + assert self.model_type in ['WanT2V', 'WanI2V', 'Wan2T2V'], ( f"Unsupported model type '{self.model_type}'.\n" - 'Only Wan2.1 video generation models (WanT2V, WanI2V) are supported.' + 'Only Wan video generation models (WanT2V, WanI2V, Wan2T2V) are supported.' ) if self.eval_dataset_name == 't2v': return self.t2v_eval(model, testenc, bs, eval_pos) diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py index 83d746254..7351995df 100755 --- a/llmc/models/__init__.py +++ b/llmc/models/__init__.py @@ -37,3 +37,4 @@ from .vit import Vit from .wan_i2v import WanI2V from .wan_t2v import WanT2V +from .wan2_2_t2v import Wan2T2V diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py index 4d7dda2ae..25393a871 100755 --- a/llmc/models/base_model.py +++ b/llmc/models/base_model.py @@ -119,7 +119,7 @@ def has_bias(self): pass def build_tokenizer(self): - if self.model_type not in ['Vit', 'WanT2V', 'WanI2V']: + if self.model_type not in ['Vit', 'WanT2V', 'WanI2V', 'Wan2T2V']: assert self.tokenizer_mode in ['fast', 'slow'] self.tokenizer = AutoTokenizer.from_pretrained( self.model_path, use_fast=self.tokenizer_mode, trust_remote_code=True @@ -129,7 +129,7 @@ def build_tokenizer(self): if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token else: - self.tokenizer = None + self.tokenizer = None def get_tokenizer(self): return self.tokenizer diff --git a/llmc/models/wan2_2_t2v.py b/llmc/models/wan2_2_t2v.py new file mode 100755 index 000000000..bf603c536 --- /dev/null +++ b/llmc/models/wan2_2_t2v.py @@ -0,0 +1,193 @@ +import inspect +from collections import defaultdict + +import torch +import torch.nn as nn +from diffusers import AutoencoderKLWan, WanPipeline +from loguru import logger + +from 
llmc.compression.quantization.module_utils import LlmcWanTransformerBlock +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .base_model import BaseModel + + +@MODEL_REGISTRY +class Wan2T2V(BaseModel): + """Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block structure as Wan2.1.""" + + def __init__(self, config, device_map=None, use_cache=False): + super().__init__(config, device_map, use_cache) + if 'calib' in config: + self.calib_bs = config.calib.bs + self.sample_steps = config.calib.sample_steps + self.target_height = config.calib.get('target_height', 480) + self.target_width = config.calib.get('target_width', 832) + self.num_frames = config.calib.get('num_frames', 81) + self.guidance_scale = config.calib.get('guidance_scale', 5.0) + self.guidance_scale_2 = config.calib.get('guidance_scale_2', 3.0) + else: + self.sample_steps = None + + def build_model(self): + vae = AutoencoderKLWan.from_pretrained( + self.model_path, + subfolder='vae', + torch_dtype=torch.float32, + use_safetensors=True, + ) + # Wan2.2: one pipeline, two transformer experts (transformer + transformer_2). + # Pipeline switches by SNR; both use WanTransformer3DModel with same block layout as Wan2.1. + self.Pipeline = WanPipeline.from_pretrained( + self.model_path, + vae=vae, + torch_dtype=torch.bfloat16, + use_safetensors=True, + ) + self.find_llmc_model() + # Wrap both experts with LlmcWanTransformerBlock (same as Wan2.1 per-block layout). 
+ for block_idx, block in enumerate(self.Pipeline.transformer.blocks): + new_block = LlmcWanTransformerBlock.new(block) + self.Pipeline.transformer.blocks[block_idx] = new_block + if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None: + for block_idx, block in enumerate(self.Pipeline.transformer_2.blocks): + new_block = LlmcWanTransformerBlock.new(block) + self.Pipeline.transformer_2.blocks[block_idx] = new_block + self.blocks = list(self.Pipeline.transformer.blocks) + list( + self.Pipeline.transformer_2.blocks + ) + logger.info( + 'Wan2.2 MoE: both experts wrapped (high-noise + low-noise, 80 blocks total).' + ) + else: + self.blocks = list(self.Pipeline.transformer.blocks) + logger.info('Wan2.2: single transformer wrapped (40 blocks).') + logger.info('Model: %s', self.model) + + def find_llmc_model(self): + self.model = self.Pipeline.transformer + + def find_blocks(self): + self.blocks = self.model.blocks + + def get_catcher(self, first_block_input): + sample_steps = self.sample_steps + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + self.signature = inspect.signature(module.forward) + self.step = 0 + + def forward(self, *args, **kwargs): + params = list(self.signature.parameters.keys()) + for i, arg in enumerate(args): + if i > 0: + kwargs[params[i]] = arg + first_block_input['data'].append(args[0]) + first_block_input['kwargs'].append(kwargs) + self.step += 1 + if self.step == sample_steps: + raise ValueError + else: + return self.module(*args) + + return Catcher + + @torch.no_grad() + def collect_first_block_input(self, calib_data, padding_mask=None): + first_block_input = defaultdict(list) + Catcher = self.get_catcher(first_block_input) + # Install Catcher on the pipeline's first block so forward passes go through it. 
+ first_block = self.Pipeline.transformer.blocks[0] + self.Pipeline.transformer.blocks[0] = Catcher(first_block) + self.Pipeline.to('cuda') + for data in calib_data: + self.Pipeline.transformer.blocks[0].step = 0 + try: + pipe_kw = { + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': self.target_height, + 'width': self.target_width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if hasattr(self, 'guidance_scale_2'): + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + self.Pipeline(**pipe_kw) + except ValueError: + pass + + self.first_block_input = first_block_input + assert len(self.first_block_input['data']) > 0, 'Catch input data failed.' + self.n_samples = len(self.first_block_input['data']) + logger.info('Retrieved %s calibration samples for Wan2.2 T2V.', self.n_samples) + self.Pipeline.transformer.blocks[0] = self.Pipeline.transformer.blocks[0].module + self.Pipeline.to('cpu') + + def get_padding_mask(self): + return None + + def has_bias(self): + return True + + def __str__(self): + return '\nWan2.2 MoE Model:\n%s\nTotal params: ~27B (14B active per step)' % ( + str(self.model), + ) + + def get_layernorms_in_block(self, block): + return { + 'affine_norm1': block.affine_norm1, + 'norm2': block.norm2, + 'affine_norm3': block.affine_norm3, + } + + def get_subsets_in_block(self, block): + return [ + { + 'layers': { + 'attn1.to_q': block.attn1.to_q, + 'attn1.to_k': block.attn1.to_k, + 'attn1.to_v': block.attn1.to_v, + }, + 'prev_op': [block.affine_norm1], + 'input': ['attn1.to_q'], + 'inspect': block.attn1, + 'has_kwargs': True, + 'sub_keys': {'rotary_emb': 'rotary_emb'}, + }, + { + 'layers': { + 'attn2.to_q': block.attn2.to_q, + }, + 'prev_op': [block.norm2], + 'input': ['attn2.to_q'], + 'inspect': block.attn2, + 'has_kwargs': True, + 'sub_keys': {'encoder_hidden_states': 'encoder_hidden_states'}, + }, + { + 'layers': { + 'ffn.net.0.proj': block.ffn.net[0].proj, + }, + 'prev_op': 
[block.affine_norm3], + 'input': ['ffn.net.0.proj'], + 'inspect': block.ffn, + 'has_kwargs': True, + }, + ] + + def find_embed_layers(self): + pass + + def get_embed_layers(self): + pass + + def get_layers_except_blocks(self): + pass + + def skip_layer_name(self): + pass From 84f89f9b19f4eb2b7f568a8cf570df8adeb33e21 Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Wed, 11 Mar 2026 21:29:59 +0800 Subject: [PATCH 2/6] feat: wan2.2-t2v quantization configs and model updates --- .gitignore | 2 + .../video_gen/wan2_2_t2v/awq_w_a.yaml | 4 + .../video_gen/wan_i2v/awq_w_a.yaml | 8 +- .../wan_i2v/smoothquant_w_a_fp8_example.yaml | 57 ++++ .../video_gen/wan_t2v/awq_w_a.yaml | 12 +- .../video_gen/wan_t2v/awq_w_a_s.yaml | 49 +++ .../video_gen/wan_t2v/rtn_w_a.yaml | 32 -- .../video_gen/wan_t2v/smoothquant_w_a.yaml | 22 +- docs/wan2.1_quantization_guide.md | 288 ++++++++++++++++++ llmc/compression/quantization/__init__.py | 2 +- .../base_blockwise_quantization.py | 16 +- llmc/compression/quantization/quant.py | 98 ++++++ llmc/models/wan_t2v.py | 7 +- scripts/run_llmc.sh | 37 +-- 14 files changed, 557 insertions(+), 77 deletions(-) create mode 100644 configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml create mode 100755 configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml delete mode 100755 configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml create mode 100644 docs/wan2.1_quantization_guide.md diff --git a/.gitignore b/.gitignore index 896b38a12..24b47eaa5 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ save* .log *.pid *.ipynb* +models/ +output_*HiFloat4/ diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml index 16c9e9929..1540e8b75 100644 --- a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml @@ -31,11 +31,15 @@ quant: video_gen: method: Awq weight: + # quant_type: int-quant + quant_type: 
hif4 bit: 4 symmetric: True granularity: per_channel group_size: -1 act: + # quant_type: int-qu + quant_type: hif4 bit: 4 symmetric: True granularity: per_token diff --git a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml index 680fab43b..1b1097ad7 100755 --- a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml @@ -2,7 +2,7 @@ base: seed: &seed 42 model: type: WanI2V - path: /path/to/model + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B/ torch_dtype: auto calib: name: i2v @@ -31,12 +31,12 @@ quant: video_gen: method: Awq weight: - bit: 8 + bit: 4 symmetric: True granularity: per_channel group_size: -1 act: - bit: 8 + bit: 4 symmetric: True granularity: per_token special: @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_i2v_awq_w_a/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml new file mode 100644 index 000000000..adba728d0 --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml @@ -0,0 +1,57 @@ +# Wan2.1 I2V FP8 量化配置示例 +# 这是一个快速开始的配置文件,请根据实际情况修改路径 + +base: + seed: &seed 42 + +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的 Wan2.1 I2V 模型路径 + torch_dtype: auto + +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为你的校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed + +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: /path/to/eval/data # 修改为你的评估数据路径 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_fp8/ + +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: 
float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数,范围 0.5-1.0 + +save: + save_lightx2v: True # 保存为 lightx2v 兼容格式 + save_path: /path/to/save/quantized/model # 修改为你的保存路径 diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml index 14d05479d..ec6d8714e 100755 --- a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml @@ -2,12 +2,12 @@ base: seed: &seed 42 model: type: WanT2V - path: /path/to/wan_t2v + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-14B-Diffusers torch_dtype: auto calib: name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ sample_steps: 20 bs: 1 target_height: 480 @@ -20,7 +20,7 @@ eval: type: video_gen name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ bs: 1 target_height: 480 target_width: 832 @@ -31,12 +31,12 @@ quant: video_gen: method: Awq weight: - bit: 6 + bit: 4 symmetric: True granularity: per_channel group_size: -1 act: - bit: 6 + bit: 4 symmetric: True granularity: per_token special: @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml new file mode 100755 index 000000000..f140839e3 --- /dev/null +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml @@ -0,0 +1,49 @@ +base: + seed: &seed 42 +model: + type: WanT2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-1.3B-Diffusers + torch_dtype: auto +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 
+ bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + save_lightx2v: True + save_path: ../lightx2v/wan_t2v_awq_w_a_s/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml deleted file mode 100755 index b6a53b0e0..000000000 --- a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml +++ /dev/null @@ -1,32 +0,0 @@ -base: - seed: &seed 42 -model: - type: WanT2V - path: /path/to/wan_t2v - torch_dtype: auto -eval: - eval_pos: [transformed, fake_quant] - type: video_gen - name: t2v - download: False - path: ../assets/wan_t2v/eval/ - bs: 1 - target_height: 480 - target_width: 832 - num_frames: 81 - guidance_scale: 5.0 - output_video_path: ./output_videos_rtn/ -quant: - video_gen: - method: RTN - weight: - bit: 6 - symmetric: True - granularity: per_channel - act: - bit: 6 - symmetric: True - granularity: per_token -save: - save_lightx2v: True - save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml index 7d65f31fc..f76edd294 100755 --- a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml @@ -2,12 +2,12 @@ base: seed: &seed 42 model: type: WanT2V - path: /path/to/wan_t2v + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-14B-Diffusers torch_dtype: auto 
calib: name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ sample_steps: 20 bs: 1 target_height: 480 @@ -20,26 +20,30 @@ eval: type: video_gen name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ bs: 1 target_height: 480 target_width: 832 num_frames: 81 guidance_scale: 5.0 - output_video_path: ./output_videos_sq/ + output_video_path: ./output_videos_awq/ quant: video_gen: - method: SmoothQuant + method: Awq weight: - bit: 6 + bit: 4 symmetric: True granularity: per_channel + group_size: -1 act: - bit: 6 + bit: 4 symmetric: True granularity: per_token special: - alpha: 0.7 + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/ diff --git a/docs/wan2.1_quantization_guide.md b/docs/wan2.1_quantization_guide.md new file mode 100644 index 000000000..eeef5ac63 --- /dev/null +++ b/docs/wan2.1_quantization_guide.md @@ -0,0 +1,288 @@ +# Wan2.1 视频生成模型量化指南 + +## 概述 + +llmc 框架现已全面支持 Wan2.1 系列视频生成模型的量化,并提供真正量化的 INT8/FP8 权重导出,与 lightx2v 推理框架兼容。 + +## 支持的模型类型 + +- **WanI2V**: Image-to-Video (图像到视频) +- **WanT2V**: Text-to-Video (文本到视频) + +## 支持的量化方法 + +### FP8 量化 (推荐) + +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml` + +**特点**: +- 使用 E4M3 FP8 格式 (8-bit 浮点数,4位指数,3位尾数) +- SmoothQuant 算法,平衡权重和激活的量化难度 +- 适合 GPU 推理,性能损失小 + +**量化配置**: +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数 +``` + +### INT8 量化 + +#### 1. RTN (Round-to-Nearest) +**配置文件**: `configs/quantization/video_gen/wan_i2v/rtn_w_a.yaml` + +**特点**: +- 最简单的量化方法 +- 直接四舍五入到最近的量化级别 +- 速度快,精度略低 + +#### 2. 
AWQ (Activation-aware Weight Quantization) +**配置文件**: `configs/quantization/video_gen/wan_i2v/awq_w_a.yaml` + +**特点**: +- 基于激活分布优化权重量化 +- 保护重要通道,减少精度损失 +- 需要校准数据 + +#### 3. SmoothQuant +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml` + +**特点**: +- 平衡权重和激活的量化难度 +- 数学上等价于平滑激活异常值 +- 通常提供最佳精度 + +### LoRA 模型量化 + +支持对 LoRA 适配器模型的量化: +- `smoothquant_w_a_int8_lora.yaml` +- `rtn_w_a_lora.yaml` + +## 运行步骤 + +### 1. 准备环境 + +```bash +# 设置 llmc 路径 +export llmc=/path/to/llmc +export PYTHONPATH=$llmc:$PYTHONPATH + +# 设置 GPU +export CUDA_VISIBLE_DEVICES=0 +``` + +### 2. 准备校准数据 + +为 I2V 模型准备校准数据: +``` +assets/wan_i2v/calib/ +├── image_1.jpg +├── image_2.jpg +└── ... +``` + +为 T2V 模型准备校准数据: +``` +assets/wan_t2v/calib/ +├── prompt_1.txt +├── prompt_2.txt +└── ... +``` + +### 3. 修改配置文件 + +编辑对应的 YAML 配置文件,设置: +- `model.path`: Wan2.1 模型路径 +- `calib.path`: 校准数据路径 +- `save.save_path`: 量化模型保存路径 + +**示例 (FP8 量化)**: +```yaml +base: + seed: 42 +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的模型路径 + torch_dtype: auto +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 +save: + save_lightx2v: True + save_path: /path/to/save/quantized/model # 修改为保存路径 +``` + +### 4. 
运行量化 + +#### 使用脚本运行 (推荐) + +```bash +# 运行 FP8 量化 (I2V) +./run_llmc.sh wan_i2v_fp8 + +# 运行 INT8 RTN 量化 (I2V) +./run_llmc.sh wan_i2v_int8_rtn + +# 运行 INT8 AWQ 量化 (I2V) +./run_llmc.sh wan_i2v_int8_awq + +# 运行 INT8 SmoothQuant 量化 (I2V) +./run_llmc.sh wan_i2v_int8_smoothquant + +# 运行 T2V 模型量化 +./run_llmc.sh wan_t2v_int8_rtn +./run_llmc.sh wan_t2v_int8_awq +./run_llmc.sh wan_t2v_int8_smoothquant +``` + +#### 直接运行命令 + +```bash +torchrun \ +--nnodes 1 \ +--nproc_per_node 1 \ +--rdzv_id $RANDOM \ +--rdzv_backend c10d \ +--rdzv_endpoint 127.0.0.1:29500 \ +${llmc}/llmc/__main__.py \ +--config configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml \ +--task_id my_quant_task +``` + +### 5. 监控进度 + +```bash +# 查看日志 +tail -f wan_i2v_fp8.log + +# 查看进程 +ps aux | grep __main__.py +``` + +### 6. 停止任务 + +```bash +# 使用保存的 PID 文件 +xargs kill -9 < wan_i2v_fp8.pid +``` + +## 配置参数说明 + +### 模型配置 +- `type`: 模型类型 (`WanI2V` 或 `WanT2V`) +- `path`: 模型权重路径 +- `torch_dtype`: 数据类型 (`auto`, `bfloat16`, `float32`) + +### 校准配置 +- `sample_steps`: 采样步数 (通常 20-40) +- `bs`: 批大小 (通常 1,视频生成显存占用大) +- `target_height`: 目标视频高度 (默认 480) +- `target_width`: 目标视频宽度 (默认 832) +- `num_frames`: 视频帧数 (默认 81) +- `guidance_scale`: CFG 引导强度 (默认 5.0) + +### 量化配置 +- `method`: 量化方法 (`RTN`, `Awq`, `SmoothQuant`) +- `weight.bit`: 权重位宽 (8, e4m3) +- `act.bit`: 激活位宽 (8, e4m3) +- `granularity`: 量化粒度 (`per_channel`, `per_token`) +- `special.alpha`: SmoothQuant 平衡参数 (0.5-1.0) + +## 在 lightx2v 中使用量化模型 + +### 1. 配置 lightx2v + +编辑 `lightx2v/configs/quantization/wan_i2v.json`: +```json +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "dit_quantized_ckpt": "/path/to/quantized/model", + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm" +} +``` + +对于 FP8 模型,设置 `"dit_quant_scheme": "fp8"`。 + +### 2. 
运行推理 + +```bash +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path /path/to/original/model \ +--config_json configs/quantization/wan_i2v.json \ +--prompt "Your prompt here" \ +--image_path /path/to/input/image.jpg \ +--save_result_path output.mp4 +``` + +## 性能建议 + +1. **FP8 vs INT8**: + - FP8: 精度更高,适合对质量要求高的场景 + - INT8: 压缩率更高,适合对速度要求高的场景 + +2. **量化方法选择**: + - 快速原型: RTN + - 平衡精度和速度: SmoothQuant + - 最高精度: AWQ + +3. **校准数据**: + - 使用 10-50 个样本 + - 覆盖典型使用场景 + - I2V: 使用多样化图像 + - T2V: 使用多样化文本描述 + +4. **资源需求**: + - GPU: 建议 24GB+ 显存 + - 校准时间: 30分钟 - 2小时 (取决于数据量) + - 存储空间: 量化后模型约原模型 25-50% 大小 + +## 故障排除 + +### 显存不足 +- 减小 `bs` 到 1 +- 减小 `num_frames` +- 减小 `target_height` 和 `target_width` + +### 量化精度损失过大 +- 尝试 SmoothQuant 方法 +- 增加校准数据数量 +- 调整 `alpha` 参数 (0.5-1.0) + +### lightx2v 兼容性问题 +- 确保使用 `save_lightx2v: True` +- 检查 `dit_quant_scheme` 设置 +- 确认量化模型路径正确 + +## 参考 + +- lightx2v 文档: [lightx2v 项目地址] +- llmc 框架: [llmc 项目地址] +- Wan2.1 模型: [模型地址] diff --git a/llmc/compression/quantization/__init__.py b/llmc/compression/quantization/__init__.py index 2c08343e2..07b4f5967 100644 --- a/llmc/compression/quantization/__init__.py +++ b/llmc/compression/quantization/__init__.py @@ -10,7 +10,7 @@ from .ntweak import NormTweaking from .omniq import OmniQuant from .osplus import OsPlus -from .quant import FloatQuantizer, IntegerQuantizer +from .quant import FloatQuantizer, HiFloat4Quantizer, IntegerQuantizer from .quarot import Quarot from .quik import QUIK from .rtn import RTN diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 5a2232699..0c3d5474f 100755 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -35,7 +35,12 @@ _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear, FakeQuantLinear, LlmcActFn, OriginFloatLinear, RotateLinear) -from .quant import FloatQuantizer, IntegerQuantizer, 
Weight48IntegerQuantizer +from .quant import ( + FloatQuantizer, + HiFloat4Quantizer, + IntegerQuantizer, + Weight48IntegerQuantizer, +) class BaseBlockwiseQuantization(BlockwiseOpt): @@ -157,6 +162,8 @@ def set_quant_config(self): self.weight_quant_module = IntegerQuantizer elif quant_type == 'float-quant': self.weight_quant_module = FloatQuantizer + elif quant_type == 'hif4': + self.weight_quant_module = HiFloat4Quantizer logger.info(f'The used Weight Quant Module is {self.weight_quant_module}') self.wquantizer = self.weight_quant_module(**self.quant_config['weight']) @@ -175,6 +182,13 @@ def set_quant_config(self): self.act_quant_module = IntegerQuantizer elif quant_type == 'float-quant': self.act_quant_module = FloatQuantizer + elif quant_type == 'hif4': + self.act_quant_module = HiFloat4Quantizer + else: + raise ValueError( + f"Unsupported act quant_type: {quant_type}. " + "Supported: int-quant, float-quant, hif4." + ) self.quant_config['act']['tp'] = self.tp self.aquantizer = self.act_quant_module(**self.quant_config['act']) self.act_static = self.quant_config['act'].get('static', False) diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py index 2c24c03a8..55cd791a1 100755 --- a/llmc/compression/quantization/quant.py +++ b/llmc/compression/quantization/quant.py @@ -1,4 +1,6 @@ import gc +import os +import sys import torch from loguru import logger @@ -1229,6 +1231,102 @@ def __repr__(self): ) +def _get_hif4_quant_cy(): + """Lazy import HiFloat4 quant_cy (QType, quant_dequant_float) from HiFloat4/hif4_gpu.""" + _repo_root = os.path.dirname( + os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + ) + _hif4_gpu = os.path.join(_repo_root, 'HiFloat4', 'hif4_gpu') + if _hif4_gpu not in sys.path: + sys.path.insert(0, _hif4_gpu) + try: + from quant_cy import QType, quant_dequant_float + return QType, quant_dequant_float + except Exception as e: + raise ImportError( + 'HiFloat4 4-bit quantization 
requires the HiFloat4/hif4_gpu package. ' + 'Ensure HiFloat4 is available at repo_root/HiFloat4/hif4_gpu and built.' + ) from e + + +class HiFloat4Quantizer(BaseQuantizer): + """4-bit HiFloat (hif4) simulation quantizer using HiFloat4 quant_dequant_float. + + Uses the HiFloat4 library's quant_dequant_float for block-wise float 4-bit + quantization. No scales/zeros; quantization is done per block along the last dim. + Only supports fake (simulation) quantization; real weight packing is not implemented. + """ + + def __init__(self, bit=4, symmetric=None, granularity=None, **kwargs): + super().__init__(bit, symmetric, granularity, **kwargs) + self.quant_type = 'hif4' + self.q_dim = kwargs.get('hif4_qdim', -1) + self.force_py = kwargs.get('force_py', False) + self.force_fp32 = kwargs.get('force_fp32', True) + self._QType = None + self._quant_dequant_float = None + + def _ensure_hif4(self): + if self._quant_dequant_float is None: + self._QType, self._quant_dequant_float = _get_hif4_quant_cy() + + def fake_quant_act_static(self, act, args={}): + self._ensure_hif4() + org_dtype = act.dtype + qtype = self._QType('hifx4').dim(self.q_dim) + out = self._quant_dequant_float( + act, qtype, force_py=self.force_py, force_fp32=self.force_fp32 + ) + return out.to(org_dtype) + + def fake_quant_act_dynamic(self, act, args={}): + self._ensure_hif4() + org_dtype = act.dtype + qtype = self._QType('hifx4').dim(self.q_dim) + out = self._quant_dequant_float( + act, qtype, force_py=self.force_py, force_fp32=self.force_fp32 + ) + return out.to(org_dtype) + + def fake_quant_weight_static(self, weight, args): + self._ensure_hif4() + org_dtype = weight.dtype + qtype = self._QType('hifx4').dim(self.q_dim) + out = self._quant_dequant_float( + weight, qtype, force_py=self.force_py, force_fp32=self.force_fp32 + ) + return out.to(org_dtype) + + def fake_quant_weight_dynamic(self, weight, args={}): + self._ensure_hif4() + org_dtype = weight.dtype + qtype = self._QType('hifx4').dim(self.q_dim) + out = 
self._quant_dequant_float( + weight, qtype, force_py=self.force_py, force_fp32=self.force_fp32 + ) + return out.to(org_dtype) + + def real_quant_weight_static(self, weight, args): + raise NotImplementedError( + 'HiFloat4 quantizer is simulation-only (fake quant). ' + 'real_quant_weight is not supported for hif4.' + ) + + def real_quant_weight_dynamic(self, weight, args={}): + raise NotImplementedError( + 'HiFloat4 quantizer is simulation-only (fake quant). ' + 'real_quant_weight is not supported for hif4.' + ) + + def __repr__(self): + return ( + f'HiFloat4Quantizer(quant_type=hif4, q_dim={self.q_dim}, ' + f'force_py={self.force_py}, force_fp32={self.force_fp32})' + ) + + class Weight48IntegerQuantizer(BaseQuantizer): # flake8: noqa def __init__(self, bit, bit4, bit8, **kwargs): diff --git a/llmc/models/wan_t2v.py b/llmc/models/wan_t2v.py index 885bccda3..ec1f0650c 100755 --- a/llmc/models/wan_t2v.py +++ b/llmc/models/wan_t2v.py @@ -31,10 +31,13 @@ def __init__(self, config, device_map=None, use_cache=False): def build_model(self): vae = AutoencoderKLWan.from_pretrained( - self.model_path, subfolder='vae', torch_dtype=torch.float32 + self.model_path, subfolder='vae', torch_dtype=torch.float32, use_safetensors=True ) + # self.Pipeline = WanPipeline.from_pretrained( + # self.model_path, vae=vae, torch_dtype=torch.bfloat16 + # ) self.Pipeline = WanPipeline.from_pretrained( - self.model_path, vae=vae, torch_dtype=torch.bfloat16 + self.model_path, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True ) self.find_llmc_model() self.find_blocks() diff --git a/scripts/run_llmc.sh b/scripts/run_llmc.sh index d90877f69..efc4141af 100755 --- a/scripts/run_llmc.sh +++ b/scripts/run_llmc.sh @@ -1,17 +1,20 @@ -#!/bin/bash - -# export CUDA_VISIBLE_DEVICES=0,1 - -llmc=/path/to/llmc +export PATH=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin:$PATH +export PYTHON=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin/python +export 
PIP=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin/pip +export HF_ENDPOINT=https://hf-mirror.com +cd /mnt/lm_data_afs/wangzining/charles/lab/llmc +# model_name=wan_t2v +model_name=wan2_2_t2v +task_name=awq_w_a +# task_name=awq_w_a_s +log_name=${model_name}_${task_name} +rm -rf ../lightx2v/${log_name}/x2v/lightx2v_quant_model +llmc=. export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=awq_w_only -config=${llmc}/configs/quantization/methods/Awq/awq_w_only.yml - +config=${llmc}/configs/quantization/video_gen/${model_name}/${task_name}.yaml nnodes=1 nproc_per_node=1 - find_unused_port() { while true; do port=$(shuf -i 10000-60000 -n 1) @@ -22,25 +25,15 @@ find_unused_port() { done } UNUSED_PORT=$(find_unused_port) - - MASTER_ADDR=127.0.0.1 MASTER_PORT=$UNUSED_PORT task_id=$UNUSED_PORT -nohup \ + torchrun \ --nnodes $nnodes \ --nproc_per_node $nproc_per_node \ --rdzv_id $task_id \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ -${llmc}/llmc/__main__.py --config $config --task_id $task_id \ -> ${task_name}.log 2>&1 & - -sleep 2 -ps aux | grep '__main__.py' | grep $task_id | awk '{print $2}' > ${task_name}.pid - -# You can kill this program by -# xargs kill -9 < xxx.pid -# xxx.pid is ${task_name}.pid file \ No newline at end of file +${llmc}/llmc/__main__.py --config $config --task_id $task_id |tee ${log_name}.log \ No newline at end of file From 715104fbb7c370601a344e7847a8f0303995817e Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Fri, 13 Mar 2026 23:21:38 +0800 Subject: [PATCH 3/6] Wan2.2: MoE calibration split, blockwise input, OOM fixes and config --- .gitignore | 4 +- .../video_gen/wan2_2_t2v/awq_w_a.yaml | 10 +- llmc/compression/blockwise_optimization.py | 4 + llmc/models/wan2_2_t2v.py | 97 ++++++++++++++++--- 4 files changed, 97 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 24b47eaa5..06eb95ea8 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,6 @@ save* *.pid *.ipynb* models/ 
-output_*HiFloat4/ +output_* +HiFloat4/ +datasets/ \ No newline at end of file diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml index 1540e8b75..75c8c61b3 100644 --- a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml @@ -4,15 +4,17 @@ model: type: Wan2T2V path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B-Diffusers torch_dtype: auto + # 显存不足时开启:校准阶段捕获的激活存到 CPU,量化时再按 block 搬到 GPU + use_cpu_to_save_cuda_mem_for_catcher: True calib: name: t2v download: False path: ./assets/wan_t2v/calib/ - sample_steps: 20 + sample_steps: 20 # OOM 时可减小,如 8 或 10 bs: 1 - target_height: 480 - target_width: 832 - num_frames: 81 + target_height: 480 # OOM 时可减小,如 320 + target_width: 832 # OOM 时可减小,如 576 + num_frames: 81 # OOM 时可减小,如 49 或 33 guidance_scale: 5.0 seed: *seed eval: diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py index 72823d1bd..380e8f42c 100644 --- a/llmc/compression/blockwise_optimization.py +++ b/llmc/compression/blockwise_optimization.py @@ -31,11 +31,15 @@ def __init__(self, model, compress_config, input, padding_mask, config): def run_block_loop(self): for i in range(len(self.blocks)): self.block_idx = i + if self.input and hasattr(self.model, 'get_blockwise_input'): + self.input = self.model.get_blockwise_input(self.block_idx, self.input) logger.info( f'\nblock index: {self.block_idx}/{len(self.blocks)} ' f'\nblock: {self.blocks[self.block_idx]}' ) self.block_opt(self.blocks[self.block_idx]) + if self.input and hasattr(self.model, 'set_blockwise_input'): + self.model.set_blockwise_input(self.block_idx, self.input) if hasattr(self, 'save_scale') and self.save_scale: os.makedirs(self.scale_path, exist_ok=True) diff --git a/llmc/models/wan2_2_t2v.py b/llmc/models/wan2_2_t2v.py index bf603c536..d799a4cec 100755 --- a/llmc/models/wan2_2_t2v.py +++ 
b/llmc/models/wan2_2_t2v.py @@ -1,3 +1,4 @@ +import gc import inspect from collections import defaultdict @@ -53,14 +54,14 @@ def build_model(self): for block_idx, block in enumerate(self.Pipeline.transformer_2.blocks): new_block = LlmcWanTransformerBlock.new(block) self.Pipeline.transformer_2.blocks[block_idx] = new_block - self.blocks = list(self.Pipeline.transformer.blocks) + list( - self.Pipeline.transformer_2.blocks - ) + self.num_transformer_blocks = len(self.Pipeline.transformer.blocks) + self.blocks = list(self.Pipeline.transformer.blocks) + list(self.Pipeline.transformer_2.blocks) logger.info( 'Wan2.2 MoE: both experts wrapped (high-noise + low-noise, 80 blocks total).' ) else: self.blocks = list(self.Pipeline.transformer.blocks) + self.num_transformer_blocks = len(self.blocks) logger.info('Wan2.2: single transformer wrapped (40 blocks).') logger.info('Model: %s', self.model) @@ -68,7 +69,25 @@ def find_llmc_model(self): self.model = self.Pipeline.transformer def find_blocks(self): - self.blocks = self.model.blocks + self.blocks = list(self.Pipeline.transformer.blocks) + self.num_transformer_blocks = len(self.blocks) + if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None: + self.blocks += list(self.Pipeline.transformer_2.blocks) + + def _expert_name_from_block_idx(self, block_idx): + if block_idx < self.num_transformer_blocks: + return 'transformer' + return 'transformer_2' + + def get_blockwise_input(self, block_idx, fallback_input): + if not hasattr(self, 'blockwise_inputs'): + return fallback_input + return self.blockwise_inputs[self._expert_name_from_block_idx(block_idx)] + + def set_blockwise_input(self, block_idx, block_input): + if not hasattr(self, 'blockwise_inputs'): + return + self.blockwise_inputs[self._expert_name_from_block_idx(block_idx)] = block_input def get_catcher(self, first_block_input): sample_steps = self.sample_steps @@ -97,14 +116,52 @@ def forward(self, *args, **kwargs): @torch.no_grad() def 
collect_first_block_input(self, calib_data, padding_mask=None): - first_block_input = defaultdict(list) - Catcher = self.get_catcher(first_block_input) - # Install Catcher on the pipeline's first block so forward passes go through it. + first_block_input = { + 'transformer': defaultdict(list), + 'transformer_2': defaultdict(list), + } + sample_steps = self.sample_steps + + class Catcher(nn.Module): + def __init__(self, module, expert_name): + super().__init__() + self.module = module + self.signature = inspect.signature(module.forward) + self.expert_name = expert_name + + def _to_cpu(self, x): + if torch.is_tensor(x): + return x.detach().cpu() + if isinstance(x, tuple): + return tuple(self._to_cpu(t) for t in x) + return x + + def forward(self, *args, **kwargs): + params = list(self.signature.parameters.keys()) + for i, arg in enumerate(args): + if i > 0: + kwargs[params[i]] = arg + cur_num = len(first_block_input[self.expert_name]['data']) + if cur_num < sample_steps: + first_block_input[self.expert_name]['data'].append( + args[0].detach().cpu() if torch.is_tensor(args[0]) else args[0] + ) + first_block_input[self.expert_name]['kwargs'].append( + {k: self._to_cpu(v) for k, v in kwargs.items()} + ) + if all(len(first_block_input[name]['data']) >= sample_steps for name in first_block_input): + raise ValueError + return self.module(*args) + first_block = self.Pipeline.transformer.blocks[0] - self.Pipeline.transformer.blocks[0] = Catcher(first_block) + self.Pipeline.transformer.blocks[0] = Catcher(first_block, 'transformer') + first_block_2 = None + if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None: + first_block_2 = self.Pipeline.transformer_2.blocks[0] + self.Pipeline.transformer_2.blocks[0] = Catcher(first_block_2, 'transformer_2') + self.Pipeline.to('cuda') for data in calib_data: - self.Pipeline.transformer.blocks[0].step = 0 try: pipe_kw = { 'prompt': data['prompt'], @@ -119,14 +176,28 @@ def collect_first_block_input(self, 
calib_data, padding_mask=None): self.Pipeline(**pipe_kw) except ValueError: pass + gc.collect() + torch.cuda.empty_cache() - self.first_block_input = first_block_input - assert len(self.first_block_input['data']) > 0, 'Catch input data failed.' - self.n_samples = len(self.first_block_input['data']) - logger.info('Retrieved %s calibration samples for Wan2.2 T2V.', self.n_samples) self.Pipeline.transformer.blocks[0] = self.Pipeline.transformer.blocks[0].module + if first_block_2 is not None: + self.Pipeline.transformer_2.blocks[0] = self.Pipeline.transformer_2.blocks[0].module self.Pipeline.to('cpu') + assert len(first_block_input['transformer']['data']) > 0, 'Catch transformer input data failed.' + if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None: + assert len(first_block_input['transformer_2']['data']) > 0, \ + 'Catch transformer_2 input data failed.' + + self.blockwise_inputs = first_block_input + self.first_block_input = self.blockwise_inputs['transformer'] + self.n_samples = sum(len(v['data']) for v in self.blockwise_inputs.values()) + logger.info( + 'Retrieved Wan2.2 calibration samples: transformer=%s, transformer_2=%s.', + len(self.blockwise_inputs['transformer']['data']), + len(self.blockwise_inputs['transformer_2']['data']), + ) + def get_padding_mask(self): return None From 02b4133c1c6ab6dc6e539f1199dd236fb7783617 Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Tue, 17 Mar 2026 23:02:58 +0800 Subject: [PATCH 4/6] update wan2.2 --- .gitignore | 4 ++-- configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml | 10 ++++++---- llmc/__main__.py | 8 ++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 06eb95ea8..0084a84f1 100644 --- a/.gitignore +++ b/.gitignore @@ -22,7 +22,7 @@ save* .log *.pid *.ipynb* -models/ +model/ output_* HiFloat4/ -datasets/ \ No newline at end of file +datasets/ diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml 
b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml index 75c8c61b3..5754a7321 100644 --- a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml @@ -2,7 +2,7 @@ base: seed: &seed 42 model: type: Wan2T2V - path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B-Diffusers + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/model/Wan2.2-T2V-A14B-Diffusers torch_dtype: auto # 显存不足时开启:校准阶段捕获的激活存到 CPU,量化时再按 block 搬到 GPU use_cpu_to_save_cuda_mem_for_catcher: True @@ -40,7 +40,7 @@ quant: granularity: per_channel group_size: -1 act: - # quant_type: int-qu + # quant_type: int-quant quant_type: hif4 bit: 4 symmetric: True @@ -51,5 +51,7 @@ quant: weight_clip: True clip_sym: True save: - save_lightx2v: True - save_path: ../lightx2v/wan2_2_t2v_awq_w_a/x2v/ + # save_lightx2v: True + # save_path: ./save_for_lightx2v/wan2_2_t2v/awq_w_a/original/ + save_fake: True + save_path: ./save_for_fake/wan2_2_t2v/awq_w_a/original/ diff --git a/llmc/__main__.py b/llmc/__main__.py index ec60c1492..abf4911cb 100755 --- a/llmc/__main__.py +++ b/llmc/__main__.py @@ -32,7 +32,7 @@ def main(config): logger.info(f'tokenizer: {model.get_tokenizer()}') eval_list = get_eval_list(model, config) - eval_model(model, None, eval_list, eval_pos='pretrain') + # eval_model(model, None, eval_list, eval_pos='pretrain') blockwise_opts = [] modalities, modality_configs = get_modality(config) @@ -70,7 +70,7 @@ def main(config): blockwise_opts.append(blockwise_opt) dist.barrier() - eval_model(model, blockwise_opts, eval_list, eval_pos='transformed') + # eval_model(model, blockwise_opts, eval_list, eval_pos='transformed') if int(os.environ['RANK']) == 0: if 'save' in config and config.save.get('save_trans', False): blockwise_opt.save_model(save_trans_path) @@ -85,8 +85,8 @@ def main(config): config.save.get('trtllm_cfg'), ) - eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant') - eval_model(model, blockwise_opts, 
eval_list, eval_pos='fake_quant_wo_kv') + # eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant') + # eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant_wo_kv') if 'save' in config and config.save.get('save_fake', False): deploy_all_modality(blockwise_opts, 'fake_quant') From 6e2dddbec30a6143ff238146586a536e0befa64a Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Tue, 17 Mar 2026 23:46:18 +0800 Subject: [PATCH 5/6] feat: add HF state_dict print tool Add a small test script to load sharded safetensors from a Hugging Face repo/local dir and print parameter keys with shapes. Made-with: Cursor --- requirements/runtime.txt | 1 + tools/print_state_dict_hf.py | 119 +++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 tools/print_state_dict_hf.py diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 5869fa8d0..8fd082be7 100755 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -6,6 +6,7 @@ loguru transformers>=4.45.2 lmms-eval==0.3.0 huggingface-hub +safetensors sentencepiece protobuf accelerate>=0.26.0 diff --git a/tools/print_state_dict_hf.py b/tools/print_state_dict_hf.py new file mode 100644 index 000000000..449aac32c --- /dev/null +++ b/tools/print_state_dict_hf.py @@ -0,0 +1,119 @@ +import argparse +import json +import os +from collections import defaultdict +from importlib.metadata import version + +from huggingface_hub import snapshot_download +from safetensors import safe_open + + +def _find_index_file(model_dir: str) -> str: + candidates = [ + "diffusion_pytorch_model.safetensors.index.json", + "model.safetensors.index.json", + "pytorch_model.bin.index.json", + ] + for name in candidates: + p = os.path.join(model_dir, name) + if os.path.isfile(p): + return p + raise FileNotFoundError( + f"Cannot find an index json in {model_dir}. 
Tried: {', '.join(candidates)}" + ) + + +def _iter_safetensors_index(index_path: str): + with open(index_path, "r", encoding="utf-8") as f: + index = json.load(f) + + if "weight_map" not in index: + raise ValueError(f"Index file missing 'weight_map': {index_path}") + + weight_map = index["weight_map"] + shard_to_keys = defaultdict(list) + for k, shard_rel in weight_map.items(): + shard_to_keys[shard_rel].append(k) + + for shard_rel, keys in shard_to_keys.items(): + yield shard_rel, keys + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--repo", + type=str, + default="charles2530/Wan2.2-T2V-A14B-Diffusion-AWQ-INT4", + help="Hugging Face repo id, e.g. charles2530/Wan2.2-T2V-A14B-Diffusion-AWQ-INT4", + ) + parser.add_argument( + "--local_dir", + type=str, + default=None, + help="If provided, read model files from this local directory instead of downloading.", + ) + parser.add_argument( + "--revision", + type=str, + default="main", + help="HF revision (branch/tag/commit). Default: main", + ) + parser.add_argument( + "--download", + action="store_true", + help="Force download snapshot (ignored if --local_dir is set).", + ) + parser.add_argument( + "--max_keys", + type=int, + default=200, + help="Max number of parameter keys to print (across all shards). Default: 200", + ) + parser.add_argument( + "--print_values", + action="store_true", + help="Also print tensor repr (VERY large output). 
Default: off", + ) + args = parser.parse_args() + + print(f"huggingface-hub : {version('huggingface-hub')}") + print(f"safetensors : {version('safetensors')}") + + if args.local_dir is not None: + model_dir = args.local_dir + else: + model_dir = snapshot_download( + repo_id=args.repo, + revision=args.revision, + local_files_only=not args.download, + ) + + index_path = _find_index_file(model_dir) + print(f"model_dir : {model_dir}") + print(f"index : {index_path}") + + printed = 0 + for shard_rel, keys in _iter_safetensors_index(index_path): + shard_path = os.path.join(model_dir, shard_rel) + if not os.path.isfile(shard_path): + raise FileNotFoundError( + f"Shard not found: {shard_path}\n" + "Tip: re-run with --download to fetch all shards." + ) + + with safe_open(shard_path, framework="pt", device="cpu") as f: + for k in keys: + t = f.get_tensor(k) + print(f"{k} shape={tuple(t.shape)} dtype={t.dtype}") + if args.print_values: + print(t) + printed += 1 + if args.max_keys is not None and printed >= args.max_keys: + print(f"Reached --max_keys={args.max_keys}, stopping.") + return + + +if __name__ == "__main__": + main() + From 57df6716efe673d4b866ca4b1fc337da42f10af9 Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Wed, 18 Mar 2026 18:02:25 +0800 Subject: [PATCH 6/6] debug by claude --- CLAUDE.md | 6 ++++ .../base_blockwise_quantization.py | 35 +++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..8cef71c7e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,6 @@ +我需要给wan2.2(https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)进行hifp4的模拟量化,使用的方法是AWQ,但是目前生成的权重有一些问题 +我的两个推测一是代码存在real quant和fake quant的糅合导致结果错误,二是模型本身没有完全保存(只保存了transformer部分) +请参考配置文件configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml和运行脚本scripts/run_llmc.sh,帮我解决这个问题 +请注意,我现在的电脑是本地主机而不是服务器,所以需要你从代码本身的逻辑去寻找错误而不能真的运行 +可以参考int4的real quant和llmc中本身的fake quant配置寻找原因 +你有权限修改本文件夹下所有文件 \ No
newline at end of file diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 0c3d5474f..832e47985 100755 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -4,6 +4,7 @@ import json import os import re +import shutil from collections import defaultdict from functools import partial @@ -1017,6 +1018,18 @@ def contiguous_params(self): if not param.is_contiguous(): param.data = param.data.contiguous() + if ( + self.config.model.type in ['Wan2T2V'] + and hasattr(self.model.Pipeline, 'transformer_2') + and self.model.Pipeline.transformer_2 is not None + ): + for name, param in self.model.Pipeline.transformer_2.named_parameters(): + if not param.is_contiguous(): + param.data = param.data.contiguous() + for name, param in self.model.Pipeline.transformer_2.named_buffers(): + if not param.is_contiguous(): + param.data = param.data.contiguous() + @torch.no_grad() def save_model(self, path): if int(os.environ['RANK']) != 0: @@ -1037,6 +1050,28 @@ def save_model(self, path): self.model.avlm_model.save_pretrained(path) logger.info('save model done --') self.copy_tokenizer(path) + elif self.config.model.type in ['Wan2T2V']: + # Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.) + # so that non-quantized components are preserved. + src = self.model.model_path + if os.path.abspath(src) != os.path.abspath(path): + if os.path.exists(path): + shutil.rmtree(path) + shutil.copytree(src, path) + logger.info(f'Copied original pipeline from {src} to {path}') + # Overwrite transformer subfolder with quantized weights. 
+ self.model.Pipeline.transformer.save_pretrained( + os.path.join(path, 'transformer') + ) + logger.info('save Wan2.2 transformer done --') + if ( + hasattr(self.model.Pipeline, 'transformer_2') + and self.model.Pipeline.transformer_2 is not None + ): + self.model.Pipeline.transformer_2.save_pretrained( + os.path.join(path, 'transformer_2') + ) + logger.info('save Wan2.2 transformer_2 done --') else: self.model.get_model().save_pretrained(path) logger.info('save model done --')