
Commit 0340724

[Performance] [Utils] Refactor embeddings utils (#2080)
## Purpose ##

* Prerequisite for detecting if the lm head can be skipped (which is necessary to support large batch calibration)
* Breaking up the embeddings utils in this way makes it easier to implement `disable_lm_head` later

## Changes ##

* Generalize embedding utils
  * `_get_embeddings_or_warn` -> `get_embeddings`
  * Callers no longer have to use try-catch
  * Callers are responsible for warning
  * `untie_word_embeddings` is largely unchanged, slight clarity changes
* Update modifiers
  * `untie_if_target_shared_embedding(...)` -> `if targets_embeddings(...): untie_word_embeddings(...)`

## Testing ##

* Rename `test_model_shared_tensors` to `test_untie_word_embeddings`
* Add `test_targets_embeddings` to test targeting embeddings

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 588d1e5 commit 0340724
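
The commit message above refers to a new `get_embeddings` helper that is not visible in this diff; the `from .transformers import *` re-export added to `llmcompressor/utils/__init__.py` below suggests it lives in `llmcompressor/utils/transformers.py`. A minimal sketch of the described behavior (the helper no longer warns internally, and callers decide whether a lookup failure deserves a warning) might look like the following. The body, return convention, and the `maybe_untie` caller are assumptions for illustration, not the committed implementation.

```python
# Hypothetical sketch only: the committed get_embeddings is not shown in this diff
# and may differ in signature and failure behavior.
from typing import Optional

import torch
from loguru import logger


def get_embeddings(
    model: torch.nn.Module,
) -> tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]:
    """Best-effort lookup of (input_embeddings, output_embeddings); emits no warnings."""
    try:
        input_embeddings = model.get_input_embeddings()
        output_embeddings = model.get_output_embeddings()
    except (AttributeError, NotImplementedError):
        # swallow the lookup failure; the caller decides whether it is worth a warning
        return None, None

    return input_embeddings, output_embeddings


# Caller-side pattern described by the PR: no try-catch, the caller owns the warning
def maybe_untie(model: torch.nn.Module) -> None:
    input_embeddings, output_embeddings = get_embeddings(model)
    if input_embeddings is None or output_embeddings is None:
        logger.warning(f"Could not resolve embeddings for {model.__class__.__name__}")
        return
    # ... untie_word_embeddings(model) or similar would follow here
```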

11 files changed: +240, -218 lines


src/llmcompressor/entrypoints/utils.py

Lines changed: 1 addition & 1 deletion
@@ -29,12 +29,12 @@
 from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.compression.compressed_tensors_utils import (
     modify_save_pretrained,
-    untie_word_embeddings,
 )
 from llmcompressor.transformers.utils.helpers import (
     is_model_ct_quantized_from_path,
 )
 from llmcompressor.typing import Processor
+from llmcompressor.utils import untie_word_embeddings
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model

src/llmcompressor/modifiers/autoround/base.py

Lines changed: 5 additions & 10 deletions
@@ -20,10 +20,8 @@
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.quantization.calibration import apply_calibration_status
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
-from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_if_target_shared_embedding,
-)
-from llmcompressor.utils.pytorch.module import get_no_split_params
+from llmcompressor.utils import targets_embeddings, untie_word_embeddings
+from llmcompressor.utils.pytorch import get_no_split_params

 __all__ = ["AutoRoundModifier"]
@@ -110,7 +108,6 @@ class AutoRoundModifier(Modifier, QuantizationMixin):
     batch_size: int = 8

     # private variables
-    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
     _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)
     _q_input: Optional[torch.Tensor] = PrivateAttr(default=None)
@@ -125,10 +122,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         QuantizationMixin.initialize_quantization(self, state.model)

         # prepare module names
-        self._module_names = {
-            m: name
-            for name, m in match_named_modules(state.model, self.targets, self.ignore)
-        }
         self._add_temporary_names(state.model)
         # freeze all model parameters
         for _, param in state.model.named_parameters():
@@ -143,7 +136,9 @@ def start_calibration(self, model: torch.nn.Module):

         :param model: model to prepare for calibration
         """
-        untie_if_target_shared_embedding(model, self._module_names.values())
+        targets = match_named_modules(model, self.targets, self.ignore)
+        if targets_embeddings(model, targets):
+            untie_word_embeddings(model)

         for _, module in match_named_modules(model, self.targets, self.ignore):
             # Note: No need to register observers for auto-round

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 4 additions & 8 deletions
@@ -34,9 +34,7 @@
     reset_quantization_status,
 )
 from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_if_target_shared_embedding,
-)
+from llmcompressor.utils import targets_embeddings, untie_word_embeddings

 __all__ = ["QuantizationMixin"]
@@ -182,11 +180,9 @@ def start_calibration(self, model: torch.nn.Module):

         :param model: model to prepare for calibration
         """
-
-        matched_module_generator = (
-            x[1] for x in match_named_modules(model, self.resolved_targets, self.ignore)
-        )
-        untie_if_target_shared_embedding(model, matched_module_generator)
+        targets = match_named_modules(model, self.resolved_targets, self.ignore)
+        if targets_embeddings(model, targets):
+            untie_word_embeddings(model)

         for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 17 additions & 12 deletions
@@ -12,9 +12,8 @@

 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
-from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_if_target_shared_embedding,
-)
+from llmcompressor.typing import NamedModules
+from llmcompressor.utils import targets_embeddings, untie_word_embeddings

 __all__ = ["QuIPModifier"]
@@ -102,18 +101,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:

     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
-
-        def matched_module_generator():
-            for scheme in self.transform_config.config_groups.values():
-                for arg in scheme.apply:
-                    gen = match_named_modules(state.model, arg.targets, arg.ignore)
-                    for _, module in gen:
-                        yield module
+        model = state.model

         # Untie embeddings if they will be targeted by transforms
-        untie_if_target_shared_embedding(state.model, matched_module_generator())
+        if targets_embeddings(model, self._get_targets(model)):
+            untie_word_embeddings(model)

-        apply_transform_config(state.model, self.transform_config)
+        apply_transform_config(model, self.transform_config)

     def on_event(self, state: State, event: Event, **kwargs):
         if event.type_ == EventType.CALIBRATION_EPOCH_START:
@@ -136,6 +130,17 @@ def on_finalize(self, state: State, **kwargs) -> bool:

         return True

+    def _get_targets(self, model: torch.nn.Module) -> NamedModules:
+        if not self.initialized_:
+            raise ValueError("Cannot get targets before modifier has been initialized")
+
+        return [
+            (name, module)
+            for scheme in self.transform_config.config_groups.values()
+            for arg in scheme.apply
+            for name, module in match_named_modules(model, arg.targets, arg.ignore)
+        ]
+
     def _create_config(self) -> TransformConfig:
         config_groups = dict()
         if "v" in self.rotations:

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 20 additions & 8 deletions
@@ -16,9 +16,8 @@
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modeling import center_embeddings, fuse_norm_linears
 from llmcompressor.modifiers import Modifier
-from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_word_embeddings,
-)
+from llmcompressor.typing import NamedModules
+from llmcompressor.utils import untie_word_embeddings

 from .mappings import SpinQuantMapping, infer_mapping_from_model
 from .norm_mappings import NormMapping, infer_norm_mapping_from_model
@@ -151,14 +150,16 @@ def on_initialize(self, state: State, **kwargs) -> bool:
     @torch.no_grad()
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
+        model = state.model
+
+        # untie embeddings to avoid unintended effects of `_center_embeddings`
+        untie_word_embeddings(model)

-        # needed any time embeddings/lm_head is modified
-        untie_word_embeddings(state.model)
         # needs to happen after the model has been hooked to execute on the GPU
         # otherwise we're applying weight transforms on CPU
-        self._center_embeddings(state.model)
-        self._fuse_norms(state.model)
-        apply_transform_config(state.model, self.transform_config)
+        self._center_embeddings(model)
+        self._fuse_norms(model)
+        apply_transform_config(model, self.transform_config)

     def on_event(self, state: State, event: Event, **kwargs):
         if event.type_ == EventType.CALIBRATION_EPOCH_START:
@@ -181,6 +182,17 @@ def on_finalize(self, state: State, **kwargs) -> bool:

         return True

+    def _get_targets(self, model: torch.nn.Module) -> NamedModules:
+        if not self.initialized_:
+            raise ValueError("Cannot get targets before modifier has been initialized")
+
+        return [
+            (name, module)
+            for scheme in self.transform_config.config_groups.values()
+            for arg in scheme.apply
+            for name, module in match_named_modules(model, arg.targets, arg.ignore)
+        ]
+
     def _center_embeddings(self, model: PreTrainedModel):
         for _, embedding in match_named_modules(
             model, [self.mappings.embedding], warn_on_fail=True

src/llmcompressor/transformers/compression/compressed_tensors_utils.py

Lines changed: 1 addition & 118 deletions
@@ -1,6 +1,5 @@
 import os
 import weakref
-from collections.abc import Generator
 from functools import wraps
 from typing import Optional
@@ -9,9 +8,6 @@
 from compressed_tensors import (
     ModelCompressor,
     SparsityCompressionConfig,
-    delete_offload_parameter,
-    has_offloaded_params,
-    register_offload_parameter,
 )
 from compressed_tensors.config import CompressionFormat
 from loguru import logger
@@ -25,7 +21,7 @@
 from llmcompressor.transformers.utils import RECIPE_FILE_NAME
 from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path

-__all__ = ["modify_save_pretrained", "untie_word_embeddings"]
+__all__ = ["modify_save_pretrained"]


 def modify_save_pretrained(model: PreTrainedModel):
@@ -118,119 +114,6 @@ def save_pretrained_wrapper(
     model.save_pretrained = save_pretrained_compressed(model.save_pretrained)


-def untie_word_embeddings(model: PreTrainedModel):
-    """
-    Patches bug where HF transformers will fail to untie weights under specific
-    circumstances (https://github.com/huggingface/transformers/issues/33689).
-
-    This function detects those cases and unties the tensors if applicable
-
-    :param model: model to fix
-    """
-    try:
-        input_embed = model.get_input_embeddings()
-        output_embed = model.get_output_embeddings()
-    except NotImplementedError as e:
-        logger.warning(
-            f"cannot untie model of type {model.__class__} which doesn't have "
-            f"get_input_embeddings and get_output_embeddings implmented\n{e}"
-        )
-        return
-
-    for module in (input_embed, output_embed):
-        if module is None or not hasattr(module, "weight"):
-            logger.warning(f"Cannot untie {module} which does not have weight param")
-            continue
-
-        # this could be replaced by a `get_offloaded_parameter` util
-        if not has_offloaded_params(module):
-            untied_data = module.weight.data.clone()
-        else:
-            untied_data = module._hf_hook.weights_map["weight"].clone()
-
-        requires_grad = module.weight.requires_grad
-        new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad)
-        delete_offload_parameter(module, "weight")
-        register_offload_parameter(module, "weight", new_parameter)
-
-    if hasattr(model.config, "tie_word_embeddings"):
-        model.config.tie_word_embeddings = False
-
-
-def _get_embeddings_or_warn(
-    model: torch.nn.Module,
-) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
-    if not (
-        hasattr(model, "get_input_embeddings")
-        and hasattr(model, "get_output_embeddings")
-    ):
-        logger.warning(
-            f"{model.__class__} doesn't have attribute get_input_embeddings and"
-            " get_output_embeddings implemented."
-            "\nThis can cause"
-            " problems when quantizing layers with shared weights"
-        )
-        return None, None
-
-    try:
-        input_embeddings, output_embeddings = (
-            model.get_input_embeddings(),
-            model.get_output_embeddings(),
-        )
-    except NotImplementedError as e:
-        logger.warning(
-            f"{model.__class__} doesn't have get_input_embeddings and "
-            "get_output_embeddings implemented."
-            "\nThis can cause"
-            " problems when quantizing layers with shared weights"
-            f"\n{e}"
-        )
-        return None, None
-
-    if not (
-        isinstance(input_embeddings, torch.nn.Module)
-        and isinstance(output_embeddings, torch.nn.Module)
-    ):
-        logger.warning(
-            f"expected modules from {model.__class__} get_input_embeddings and"
-            f" get_output_embeddings but got {type(input_embeddings)}"
-            f" and {type(output_embeddings)}."
-            "\nThis can cause"
-            " problems when quantizing layers with shared weights"
-        )
-        return None, None
-    return input_embeddings, output_embeddings
-
-
-def untie_if_target_shared_embedding(
-    model: torch.nn.Module, matched_module_generator: Generator[torch.nn.Module]
-):
-    """
-    Helper method that checks for shared input/output embedding and unties them
-    if either shows up in the matched_module_generator
-
-    :param model: model to untie if embeddings are shared and targeted by
-        matched_module_generator
-    :param matched_module_generator: Generator of all modules (not names) which
-        will be modified by quantization or transformation
-    """
-    input_embeddings, output_embeddings = _get_embeddings_or_warn(model)
-
-    if None in (input_embeddings, output_embeddings):  # if couldn't find embeddings
-        return
-
-    if (
-        input_embeddings.weight is not output_embeddings.weight
-    ):  # if not shared, can ignore
-        return
-
-    # if shared, check if either is targeted
-    for module in matched_module_generator:
-        if module in (input_embeddings, output_embeddings):
-            untie_word_embeddings(model)
-            return
-
-
 def get_model_compressor(
     model: torch.nn.Module,
     sparsity_config: Optional[SparsityCompressionConfig] = None,
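
The `targets_embeddings(model, targets)` check that the modifiers now call is added elsewhere in this PR (under `llmcompressor/utils/`) and does not appear in this file's diff. A minimal sketch, reconstructed from the deleted `untie_if_target_shared_embedding` above and assuming it consumes the `(name, module)` pairs returned by `match_named_modules`, could look like this; it is an illustration, not the committed code.

```python
# Hypothetical sketch: mirrors the deleted untie_if_target_shared_embedding logic,
# but returns a bool so the caller decides whether to call untie_word_embeddings(model).
from typing import Iterable

import torch


def targets_embeddings(
    model: torch.nn.Module, targets: Iterable[tuple[str, torch.nn.Module]]
) -> bool:
    """Return True if the model's tied input/output embeddings are among the targets."""
    # assumes a HF-style model exposing get_input_embeddings/get_output_embeddings
    input_embeddings = model.get_input_embeddings()
    output_embeddings = model.get_output_embeddings()

    # only shared (tied) embeddings need to be untied before calibration
    if input_embeddings is None or output_embeddings is None:
        return False
    if input_embeddings.weight is not output_embeddings.weight:
        return False

    return any(module in (input_embeddings, output_embeddings) for _, module in targets)
```

Callers then keep the two-line pattern shown in the modifier diffs above: `if targets_embeddings(model, targets): untie_word_embeddings(model)`.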

src/llmcompressor/typing.py

Lines changed: 9 additions & 5 deletions
@@ -2,8 +2,9 @@
 Defines type aliases for the llm-compressor library.
 """

-from typing import Union
+from typing import Iterable

+import torch
 from datasets import Dataset, DatasetDict, IterableDataset
 from transformers import (
     BaseImageProcessor,
@@ -13,9 +14,12 @@
 )

 # Tokenizer or Processor. Processors do not inherit from a unified base class
-Processor = Union[
-    PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin
-]
+Processor = (
+    PreTrainedTokenizer | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin
+)

 # Supported dataset types, IterableDataset is a streamed dataset
-DatasetType = Union[Dataset, DatasetDict, IterableDataset]
+DatasetType = Dataset | DatasetDict | IterableDataset
+
+# Torch types
+NamedModules = Iterable[tuple[str, torch.nn.Module]]
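
The new `NamedModules` alias names the `(name, module)` pairs yielded by `match_named_modules` and `torch.nn.Module.named_modules()`, which the `_get_targets` helpers above return. A small illustrative usage, with a made-up helper name:

```python
# Illustrative only: count Linear layers in an iterable of (name, module) pairs.
import torch

from llmcompressor.typing import NamedModules  # alias added in this commit


def count_linear(modules: NamedModules) -> int:
    return sum(isinstance(module, torch.nn.Linear) for _, module in modules)


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
assert count_linear(model.named_modules()) == 1
```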

src/llmcompressor/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,5 +4,6 @@

 # ruff: noqa

+from .transformers import *
 from .dev import *
 from .helpers import *
