@@ -20,6 +20,7 @@
 import traceback
 from collections import defaultdict
 from dataclasses import asdict, fields
+from functools import partial
 from typing import Any, Callable, Optional, Union
 
 import accelerate
@@ -96,7 +97,6 @@
     llm_load_model,
     memory_monitor,
     mv_module_from_gpu,
-    normalize_input,
     set_amax_for_all_moe_layers,
     set_module,
     to_device,
@@ -1918,6 +1918,45 @@ def _get_block_outputs(
 
         return output
 
+    def normalize_decoding_layer_inputs_(self, decoding_layer_inputs: list[tuple[tuple[Any, dict[str, Any]]]]):
+        """
+        Processes and stores decoding layer inputs for block quantization.
+
+        This function iterates through a list of captured decoding layer calls,
+        replaying them through a fake decoding layer to extract and store the
+        inputs required for the decoding block in `self.inputs`. This effectively
+        "normalizes" the inputs by making them accessible in a consistent format
+        for subsequent quantization steps.
+
+        Args:
+            decoding_layer_inputs:
+                A list of entries captured by a forward hook on the decoding layer.
+                Each element is expected to be a tuple whose first item is
+                `(args, kwargs)`, where `args` are the positional arguments and
+                `kwargs` are the keyword arguments seen during the original
+                forward pass.
+
+                The capture hook looks like:
+
+                    def input_capture_hook(module, *args, **kwargs):
+                        _all_module_input[module._tmp_name].append((args, kwargs))
+        """
+        first_block_name = self.quant_block_list[0][0]
+
+        class _FakeDecodingLayer(torch.nn.Module):
+            def forward(self, *args, **kwargs):
+                return args, kwargs
+
+        fake_layer = _FakeDecodingLayer()
+        fake_layer.orig_forward = fake_layer.forward
+        fake_layer.forward = partial(self._get_block_forward_func(first_block_name), fake_layer)
+
+        self.inputs = {}
+        self.last_cache_name = None
+        for step_input in decoding_layer_inputs:
+            args, kwargs = step_input[0]
+            fake_layer(*args, **kwargs)
+
     @torch.no_grad()
     def calib(self, nsamples, bs):
         """Perform calibration for quantization.
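
For context, the `decoding_layer_inputs` list consumed above can be produced by registering the docstring's capture hook as a keyword-aware forward pre-hook. A minimal, runnable sketch; the `Linear` stand-in layer and the `_tmp_name` tagging are assumptions for illustration, not code from this patch:

    import torch
    from collections import defaultdict

    _all_module_input = defaultdict(list)

    def input_capture_hook(module, *args, **kwargs):
        # With with_kwargs=True, PyTorch calls this hook as hook(module, args, kwargs),
        # so *args here binds to the (positional_args, keyword_args) pair.
        _all_module_input[module._tmp_name].append((args, kwargs))

    layer = torch.nn.Linear(4, 4)       # stand-in for a real decoding layer (assumption)
    layer._tmp_name = "model.layers.0"  # assumed naming scheme
    handle = layer.register_forward_pre_hook(input_capture_hook, with_kwargs=True)

    layer(torch.randn(2, 4))            # each forward call records one entry
    handle.remove()

    # Each entry's first item is (args, kwargs), matching the unpacking
    # `args, kwargs = step_input[0]` in normalize_decoding_layer_inputs_.
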
@@ -2346,7 +2385,6 @@ def _recover_forward(self):
 
     def _replace_forward(self):
         """Replaces the forward function."""
-        from functools import partial
 
         for n, m in self.model.named_modules():
             if n in self.to_cached_layers and type(m) not in self.supported_types:  ##block
@@ -2652,7 +2690,10 @@ def quantize_block(
             "DiffusionCompressor",
             "MLLMCompressor",
         ], f"Currently, {self.__class__.__name__} does not support quantize block with this function."
-        input_ids, input_others = normalize_input(inputs)
+        self.normalize_decoding_layer_inputs_(inputs)
+        block_inputs = self.inputs[self.quant_block_list[0][0]]
+        decoding_layer_first_input_name = "hidden_states"
+        input_ids, input_others = self._preprocess_block_inputs(block_inputs, decoding_layer_first_input_name)
         return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload)
 
     def _get_loss(
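
The hard-coded `decoding_layer_first_input_name = "hidden_states"` mirrors the usual transformer decoder-layer signature, where the block's first positional input is the hidden states rather than the model-level `input_ids`. A simplified sketch of such a signature, illustrative only and not this repository's code:

    import torch

    class DecoderLayerSketch(torch.nn.Module):
        # A typical decoder block receives hidden states first, plus optional
        # keyword inputs; only the top-level model consumes token ids.
        def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
            return hidden_states
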
@@ -2959,12 +3000,32 @@ def _quantize_block(
 
         return None, output
 
-    def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]:
-        input_ids = inputs["input_ids"]
-        inputs.pop("input_ids", None)
+    def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[torch.Tensor, dict]:
+        input_ids = inputs[first_input_name]
+        inputs.pop(first_input_name, None)
         input_others = inputs
         return input_ids, input_others
 
+    def _preprocess_block_inputs(self, inputs, first_input_name="input_ids"):
+        input_ids, input_others = self._split_inputs(inputs, first_input_name)
+        clear_memory(device_list=self.device_list)
+        input_ids = to_device(input_ids, self.cache_device)
+        input_others = to_device(input_others, self.cache_device)
+        # As in the calibration phase, bf16 may be used here to lower GPU memory usage
+
+        tmp_dtype = self.amp_dtype if self.amp else torch.float32
+        input_ids = to_dtype(input_ids, tmp_dtype)
+
+        for key in input_others.keys():
+            if isinstance(input_others[key], torch.Tensor) and (
+                input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16
+            ):
+                input_others[key] = input_others[key].to(tmp_dtype)
+            elif isinstance(input_others[key], list):
+                for i in range(len(input_others[key])):
+                    input_others[key][i] = to_dtype(input_others[key][i], tmp_dtype)
+        return input_ids, input_others
+
     def _quantize_blocks(
         self,
         model: torch.nn.Module,
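
To make the dtype normalization in `_preprocess_block_inputs` concrete: half-precision tensors are re-cast to the AMP dtype, while non-float tensors pass through untouched. A toy example with illustrative names and shapes:

    import torch

    others = {
        "attention_mask": torch.ones(1, 1, 8, 8, dtype=torch.float16),
        "position_ids": torch.arange(8).unsqueeze(0),  # int64, left unchanged
    }
    hidden = torch.randn(1, 8, 16, dtype=torch.float16)  # plays the role of input_ids

    tmp_dtype = torch.bfloat16  # stands in for self.amp_dtype
    hidden = hidden.to(tmp_dtype)
    for key, value in others.items():
        # Only fp16/bf16 tensors are converted, mirroring the loop above.
        if isinstance(value, torch.Tensor) and value.dtype in (torch.float16, torch.bfloat16):
            others[key] = value.to(tmp_dtype)
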
@@ -2991,23 +3052,7 @@ def _quantize_blocks(
         for n, m in model.named_parameters():
             m.requires_grad_(False)
 
-        input_ids, input_others = self._split_inputs(inputs)
-        clear_memory(device_list=self.device_list)
-        input_ids = to_device(input_ids, self.cache_device)
-        input_others = to_device(input_others, self.cache_device)
-        # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
-
-        tmp_dtype = self.amp_dtype if self.amp else torch.float32
-        input_ids = to_dtype(input_ids, tmp_dtype)
-
-        for key in input_others.keys():
-            if isinstance(input_others[key], torch.Tensor) and (
-                input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16
-            ):
-                input_others[key] = input_others[key].to(tmp_dtype)
-            elif isinstance(input_others[key], list):
-                for i in range(len(input_others[key])):
-                    to_dtype(input_others[key][i], tmp_dtype)
+        input_ids, input_others = self._preprocess_block_inputs(inputs)
 
         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))