Commit 152f1a0

yiliu30 and XuehaoSun authored
Refactor input normalization by replaying inputs for consistent preprocessing (#1094)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Sun, Xuehao <xuehao.sun@intel.com>
1 parent ae4ac94 commit 152f1a0
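
The core of the change is a capture-and-replay pattern: record every (args, kwargs) pair the decoding layer receives during calibration, then replay those calls through the block-forward machinery so all inputs are preprocessed by one consistent code path. Below is a minimal, self-contained sketch of that idea; only the hook body mirrors the docstring added in this commit, while the toy module, names, and replay loop are illustrative stand-ins rather than the repository's actual wiring.

from collections import defaultdict

import torch

_all_module_input = defaultdict(list)

def input_capture_hook(module, *args, **kwargs):
    # With with_kwargs=True, PyTorch invokes this as hook(module, args, kwargs),
    # so each stored entry's first item is the (args, kwargs) pair of one call.
    _all_module_input[module._tmp_name].append((args, kwargs))

layer = torch.nn.Linear(4, 4)
layer._tmp_name = "model.layers.0"  # hypothetical name, for illustration only
handle = layer.register_forward_pre_hook(input_capture_hook, with_kwargs=True)

layer(torch.randn(2, 4))  # captured call 1
layer(torch.randn(3, 4))  # captured call 2
handle.remove()  # stop capturing before replaying, or the list keeps growing

# Replay every captured call through one consistent preprocessing path.
for step_input in _all_module_input["model.layers.0"]:
    call_args, call_kwargs = step_input[0]
    layer(*call_args, **call_kwargs)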

2 files changed: +68 −40 lines changed

auto_round/compressors/base.py

Lines changed: 68 additions & 23 deletions
@@ -20,6 +20,7 @@
 import traceback
 from collections import defaultdict
 from dataclasses import asdict, fields
+from functools import partial
 from typing import Any, Callable, Optional, Union
 
 import accelerate
@@ -96,7 +97,6 @@
     llm_load_model,
     memory_monitor,
     mv_module_from_gpu,
-    normalize_input,
     set_amax_for_all_moe_layers,
     set_module,
     to_device,
@@ -1918,6 +1918,45 @@ def _get_block_outputs(
 
         return output
 
+    def normalize_decoding_layer_inputs_(self, decoding_layer_inputs: list[tuple[tuple[Any, dict[str, Any]]]]):
+        """
+        Processes and stores decoding layer inputs for block quantization.
+
+        This function iterates through a list of captured decoding layer calls,
+        replaying them through a fake decoding layer to extract and store the
+        inputs required for the decoding block in `self.inputs`. This effectively
+        "normalizes" the inputs by making them accessible in a consistent format
+        for subsequent quantization steps.
+
+        Args:
+            decoding_layer_inputs:
+                A list of entries captured by a forward hook on the decoding layer.
+                Each element is expected to be a tuple whose first item is
+                `(args, kwargs)`, where `args` are the positional arguments and
+                `kwargs` are the keyword arguments seen during the original
+                forward pass.
+
+                The capture hook looks like:
+
+                    def input_capture_hook(module, *args, **kwargs):
+                        _all_module_input[module._tmp_name].append((args, kwargs))
+        """
+        first_block_name = self.quant_block_list[0][0]
+
+        class _FakeDecodingLayer(torch.nn.Module):
+            def forward(self, *args, **kwargs):
+                return args, kwargs
+
+        fake_layer = _FakeDecodingLayer()
+        fake_layer.orig_forward = fake_layer.forward
+        fake_layer.forward = partial(self._get_block_forward_func(first_block_name), fake_layer)
+
+        self.inputs = {}
+        self.last_cache_name = None
+        for step_input in decoding_layer_inputs:
+            args, kwargs = step_input[0]
+            fake_layer(*args, **kwargs)
+
     @torch.no_grad()
     def calib(self, nsamples, bs):
         """Perform calibration for quantization.
@@ -2346,7 +2385,6 @@ def _recover_forward(self):
 
     def _replace_forward(self):
         """Replaces the forward function."""
-        from functools import partial
 
         for n, m in self.model.named_modules():
             if n in self.to_cached_layers and type(m) not in self.supported_types: ##block
@@ -2652,7 +2690,10 @@ def quantize_block(
             "DiffusionCompressor",
             "MLLMCompressor",
         ], f"Currently, {self.__class__.__name__} does not support support quantize block with this function."
-        input_ids, input_others = normalize_input(inputs)
+        self.normalize_decoding_layer_inputs_(inputs)
+        block_inputs = self.inputs[self.quant_block_list[0][0]]
+        decoding_layer_first_input_name = "hidden_states"
+        input_ids, input_others = self._preprocess_block_inputs(block_inputs, decoding_layer_first_input_name)
         return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload)
 
     def _get_loss(
@@ -2959,12 +3000,32 @@ def _quantize_block(
 
         return None, output
 
-    def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]:
-        input_ids = inputs["input_ids"]
-        inputs.pop("input_ids", None)
+    def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[torch.Tensor, dict]:
+        input_ids = inputs[first_input_name]
+        inputs.pop(first_input_name, None)
         input_others = inputs
         return input_ids, input_others
 
+    def _preprocess_block_inputs(self, inputs, first_input_name="input_ids"):
+        input_ids, input_others = self._split_inputs(inputs, first_input_name)
+        clear_memory(device_list=self.device_list)
+        input_ids = to_device(input_ids, self.cache_device)
+        input_others = to_device(input_others, self.cache_device)
+        # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
+
+        tmp_dtype = self.amp_dtype if self.amp else torch.float32
+        input_ids = to_dtype(input_ids, tmp_dtype)
+
+        for key in input_others.keys():
+            if isinstance(input_others[key], torch.Tensor) and (
+                input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16
+            ):
+                input_others[key] = input_others[key].to(tmp_dtype)
+            elif isinstance(input_others[key], list):
+                for i in range(len(input_others[key])):
+                    to_dtype(input_others[key][i], tmp_dtype)
+        return input_ids, input_others
+
     def _quantize_blocks(
         self,
         model: torch.nn.Module,
@@ -2991,23 +3052,7 @@ def _quantize_blocks(
         for n, m in model.named_parameters():
             m.requires_grad_(False)
 
-        input_ids, input_others = self._split_inputs(inputs)
-        clear_memory(device_list=self.device_list)
-        input_ids = to_device(input_ids, self.cache_device)
-        input_others = to_device(input_others, self.cache_device)
-        # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
-
-        tmp_dtype = self.amp_dtype if self.amp else torch.float32
-        input_ids = to_dtype(input_ids, tmp_dtype)
-
-        for key in input_others.keys():
-            if isinstance(input_others[key], torch.Tensor) and (
-                input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16
-            ):
-                input_others[key] = input_others[key].to(tmp_dtype)
-            elif isinstance(input_others[key], list):
-                for i in range(len(input_others[key])):
-                    to_dtype(input_others[key][i], tmp_dtype)
+        input_ids, input_others = self._preprocess_block_inputs(inputs)
 
         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
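
Taken together, the two hunks above deduplicate the device/dtype normalization into _preprocess_block_inputs. A standalone sketch of what that normalization amounts to; to_device and to_dtype here are simplified stand-ins written for this example, not auto_round's actual helpers.

import torch

def to_device(obj, device):
    # Recursively move tensors (and containers of tensors) to a device.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, list):
        return [to_device(o, device) for o in obj]
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    return obj

def to_dtype(obj, dtype):
    # Recursively cast floating-point tensors; leave masks/ints untouched.
    if isinstance(obj, torch.Tensor) and obj.is_floating_point():
        return obj.to(dtype)
    if isinstance(obj, list):
        return [to_dtype(o, dtype) for o in obj]
    return obj

def preprocess_block_inputs(inputs, first_input_name="input_ids",
                            cache_device="cpu", tmp_dtype=torch.float32):
    # Split off the primary activations, then normalize device and dtype
    # for everything that rides along (masks, position embeddings, ...).
    input_ids = inputs.pop(first_input_name)
    input_ids = to_dtype(to_device(input_ids, cache_device), tmp_dtype)
    input_others = {k: to_dtype(to_device(v, cache_device), tmp_dtype)
                    for k, v in inputs.items()}
    return input_ids, input_others

ids, others = preprocess_block_inputs(
    {"hidden_states": [torch.randn(1, 8, 16, dtype=torch.bfloat16)],
     "attention_mask": torch.ones(1, 8)},
    first_input_name="hidden_states",
)
print(ids[0].dtype, list(others))  # torch.float32 ['attention_mask']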

auto_round/utils/common.py

Lines changed: 0 additions & 17 deletions
@@ -329,20 +329,3 @@ def get_reciprocal(tensor):
     recip[mask] = 0.0
 
     return recip
-
-
-def normalize_input(
-    decoding_layer_inputs: tuple[Union[list[torch.Tensor], dict, Any], Optional[dict]],
-) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
-    """Normalize the decoding layer inputs into input_ids and other inputs."""
-    input_ids = []
-    input_others = {"positional_inputs": []}
-    for cur_inp in decoding_layer_inputs:
-        input_ids.append(cur_inp[0][0][0])
-        for key, val in cur_inp[0][1].items():
-            input_others[key] = val
-    # Force 'use_cache' to be False
-    if "use_cache" in input_others and input_others["use_cache"] is True:
-        logger.warning_once("Forcing 'use_cache' to be False during calibration.")
-        input_others["use_cache"] = False
-    return input_ids, input_others
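
For readers decoding the removed helper's cur_inp[0][0][0] indexing: each captured entry has the shape ((args, kwargs),), so the first [0] selects the (args, kwargs) pair, the second [0] the positional args, and the last [0] the first positional tensor. A toy check with illustrative values:

import torch

hidden = torch.randn(1, 8, 16)
# One captured decoding-layer call, shaped ((args, kwargs),):
cur_inp = (((hidden,), {"use_cache": True}),)

assert cur_inp[0][0][0] is hidden             # first positional argument
assert cur_inp[0][1] == {"use_cache": True}   # the call's keyword arguments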
