diff --git a/examples/multimodal_vision/gemma3_example.py b/examples/multimodal_vision/gemma3_example.py
index dce35b7b83..7f9b7cfabe 100644
--- a/examples/multimodal_vision/gemma3_example.py
+++ b/examples/multimodal_vision/gemma3_example.py
@@ -1,55 +1,69 @@
 import requests
-import torch
 from PIL import Image
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers import (
+    AutoProcessor,
+    DataCollatorWithPadding,
+    Gemma3ForConditionalGeneration,
+)
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.utils import dispatch_for_generation
 
 # Load model.
 model_id = "google/gemma-3-4b-it"
 model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+collator = DataCollatorWithPadding(processor.tokenizer)
 
 # Oneshot arguments
-DATASET_ID = "flickr30k"
-DATASET_SPLIT = {"calibration": "test[:512]"}
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
+BATCH_SIZE = 512
+DATASET_ID = "flickr30k"
+DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
+# Define a oneshot data collator for multimodal processors.
+# Remove the extra dim added by the vision processor, then pad with the tokenizer.
+def data_collator(features: list[dict[str, object]]):
+    features = [{key: feature[key][0] for key in feature} for feature in features]
+    return collator(features)
 
 
 # Recipe
 recipe = [
     GPTQModifier(
         targets="Linear",
         scheme="W4A16",
         ignore=[
             "lm_head",
             r"re:model\.vision_tower.*",
             r"re:model\.multi_modal_projector.*",
         ],
     ),
 ]
 
 # Perform oneshot
 oneshot(
     model=model,
     tokenizer=model_id,
     dataset=DATASET_ID,
     splits=DATASET_SPLIT,
     recipe=recipe,
+    batch_size=BATCH_SIZE,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    trust_remote_code_model=True,
     data_collator=data_collator,
+    trust_remote_code_model=True,
+    pipeline="sequential",
 )
 
 # Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============") diff --git a/examples/multimodal_vision/idefics3_example.py b/examples/multimodal_vision/idefics3_example.py index 2fdaeb1a4a..a3c75722d6 100644 --- a/examples/multimodal_vision/idefics3_example.py +++ b/examples/multimodal_vision/idefics3_example.py @@ -1,5 +1,4 @@ import requests -import torch from datasets import load_dataset from PIL import Image from transformers import AutoProcessor, Idefics3ForConditionalGeneration @@ -14,16 +13,11 @@ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) # Oneshot arguments -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:512]" NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here - - -# Define a oneshot data collator for multimodal inputs. -def data_collator(batch): - assert len(batch) == 1 - return {key: torch.tensor(value) for key, value in batch[0].items()} +MAX_SEQUENCE_LENGTH = 4096 +BATCH_SIZE = 512 +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = f"test[:{NUM_CALIBRATION_SAMPLES}]" # Recipe @@ -69,7 +63,7 @@ def preprocess(example): # Tokenize inputs. def tokenize(sample): - return processor( + features = processor( text=sample["text"], images=sample["images"], padding=False, @@ -77,6 +71,9 @@ def tokenize(sample): truncation=True, ) + # remove extra dim added by vision processor + return [{key: feature[key][0] for key in feature} for feature in features] + # avoid errors with writer_batch_size ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names) @@ -86,10 +83,9 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, - data_collator=data_collator, sequential_targets=["LlamaDecoderLayer"], ) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index b03aacee35..e970d41eb8 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -7,6 +7,7 @@ # Select model and load it. model_id = "meta-llama/Meta-Llama-3-8B-Instruct" +#model_id = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -16,8 +17,9 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 32 MAX_SEQUENCE_LENGTH = 2048 +BATCH_SIZE = 16 # Load dataset and preprocess. ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") @@ -58,9 +60,12 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, + pipeline="sequential", ) +exit(0) # Confirm generations of the quantized model look sane. 
print("\n\n") diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 2618b90197..b66fcfa275 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -8,9 +8,7 @@ """ from dataclasses import dataclass, field -from typing import Any, Callable - -from transformers import DefaultDataCollator +from typing import Callable, Optional @dataclass @@ -69,9 +67,25 @@ class CustomDatasetArguments(DVCDatasetArguments): }, ) - data_collator: Callable[[Any], Any] = field( - default_factory=lambda: DefaultDataCollator(), - metadata={"help": "The function to used to form a batch from the dataset"}, + data_collator: Optional[Callable] = field( + default=None, + metadata={ + "help": ( + "The function to used to form a batch from the dataset. Defaults to " + "`DataCollatorWithPadding(processor)`." + ) + }, + ) + + batch_size: int = field( + default=1, + metadata={ + "help": ( + "Calibration batch size. During calibration, LLM Compressor disables " + "lm_head output computations to reduce memory usage from large " + "calibration matches" + ) + }, ) diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 2b80b1ed9a..dc11088cd4 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -7,6 +7,7 @@ one-shot calibration workflows. """ +import math import multiprocessing import re from typing import Any, Callable @@ -14,8 +15,8 @@ import torch from datasets import Dataset from loguru import logger -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler -from transformers.data import default_data_collator +from torch.utils.data import DataLoader, SequentialSampler +from transformers.data import DataCollatorWithPadding from llmcompressor.args import DatasetArguments from llmcompressor.transformers.data import TextGenerationDataset @@ -115,44 +116,53 @@ def get_calibration_dataloader( ) calibration_dataset = datasets.get("calibration") + tokenizer = getattr(processor, "tokenizer", processor) + collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer) + if tokenizer.pad_token is None or tokenizer.pad_token_id < 0: + logger.debug("Could not find padding token. 
Setting PAD token to EOS token") + tokenizer.pad_token = tokenizer.eos_token return format_calibration_data( tokenized_dataset=calibration_dataset, + collate_fn=collate_fn, + batch_size=dataset_args.batch_size, num_calibration_samples=dataset_args.num_calibration_samples, do_shuffle=dataset_args.shuffle_calibration_samples, - collate_fn=dataset_args.data_collator, ) def format_calibration_data( tokenized_dataset: Dataset, + collate_fn: Callable, + batch_size: int = 1, num_calibration_samples: int | None = None, do_shuffle: bool = True, - collate_fn: Callable = default_data_collator, ) -> list[torch.Tensor]: """ Creates a dataloader out of the calibration dataset split, trimming it to the desired number of calibration samples :param tokenized_dataset: dataset to convert to dataloader - :param num_calibration_samples: number of data samples to convert + :param num_calibration_samples: number of batches to convert :param do_shuffle: whether to shuffle the dataset before selecting calibration samples, true by default :param collate_fn: optional custom collate function, or use default :return: list of trimmed calibration data tensors """ - safe_calibration_samples = len(tokenized_dataset) + # (1) shuffle dataset + if do_shuffle: + tokenized_dataset = tokenized_dataset.shuffle() + + # (2) truncate dataset if num_calibration_samples is not None: - safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) - if safe_calibration_samples != num_calibration_samples: + if num_calibration_samples > len(tokenized_dataset): logger.warning( - f"Requested {num_calibration_samples} calibration samples but " - f"the provided dataset only has {safe_calibration_samples}. " + f"Requested {num_calibration_samples} calibration samples but the " + f"provided dataset only has {len(tokenized_dataset)} samples." ) + num_calibration_samples = len(tokenized_dataset) + tokenized_dataset = tokenized_dataset.select(range(num_calibration_samples)) - if do_shuffle: - tokenized_dataset = tokenized_dataset.shuffle() - tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) - + # (3) infer number of workers MAX_DATALOADER_WORKERS = 8 try: num_workers = min(MAX_DATALOADER_WORKERS, multiprocessing.cpu_count() // 2) @@ -161,19 +171,16 @@ def format_calibration_data( "Could not determine number of CPUs, defaulting to 0 dataloader workers." 
) num_workers = 0 + + # (4) create dataloader dataloader_params = { - "batch_size": 1, - "sampler": RandomSampler(tokenized_calibration) - if do_shuffle - else SequentialSampler(tokenized_calibration), + "batch_size": batch_size, + "sampler": SequentialSampler(tokenized_dataset), "collate_fn": collate_fn, - "pin_memory": True, + "pin_memory": False, "num_workers": num_workers, } - - calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params) - - return calibration_dataloader + return DataLoader(tokenized_dataset, **dataloader_params) def make_dataset_splits( diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index c2b29aa97c..fba7dc6688 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -12,7 +12,7 @@ import os from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Optional from loguru import logger from torch.utils.data import DataLoader @@ -248,6 +248,8 @@ def oneshot( dataset_config_name: str | None = None, dataset_path: str | None = None, splits: str | list[str] | dict[str, str] | None = None, + batch_size: int = 1, + data_collator: Optional[Callable] = None, num_calibration_samples: int = 512, shuffle_calibration_samples: bool = True, max_seq_length: int = 384, diff --git a/src/llmcompressor/entrypoints/utils.py b/src/llmcompressor/entrypoints/utils.py index 3c1b354ce7..62f7e86de4 100644 --- a/src/llmcompressor/entrypoints/utils.py +++ b/src/llmcompressor/entrypoints/utils.py @@ -29,12 +29,12 @@ from llmcompressor.pytorch.model_load.helpers import parse_dtype from llmcompressor.transformers.compression.compressed_tensors_utils import ( modify_save_pretrained, - untie_word_embeddings, ) from llmcompressor.transformers.utils.helpers import ( is_model_ct_quantized_from_path, ) from llmcompressor.typing import Processor +from llmcompressor.utils import untie_word_embeddings from llmcompressor.utils.fsdp.helpers import is_fsdp_model diff --git a/src/llmcompressor/modifiers/autoround/base.py b/src/llmcompressor/modifiers/autoround/base.py index 2480751a9b..6b6d4fe3ed 100644 --- a/src/llmcompressor/modifiers/autoround/base.py +++ b/src/llmcompressor/modifiers/autoround/base.py @@ -20,10 +20,8 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization.calibration import apply_calibration_status from llmcompressor.modifiers.quantization.quantization import QuantizationMixin -from llmcompressor.transformers.compression.compressed_tensors_utils import ( - untie_if_target_shared_embedding, -) -from llmcompressor.utils.pytorch.module import get_no_split_params +from llmcompressor.utils import targets_embeddings, untie_word_embeddings +from llmcompressor.utils.pytorch import get_no_split_params __all__ = ["AutoRoundModifier"] @@ -109,7 +107,6 @@ class AutoRoundModifier(Modifier, QuantizationMixin): enable_torch_compile: bool = True # private variables - _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict) _q_input: Optional[torch.Tensor] = PrivateAttr(default=None) @@ -124,10 +121,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: QuantizationMixin.initialize_quantization(self, state.model) # prepare module names - self._module_names = { - m: name - for name, m in match_named_modules(state.model, self.targets, self.ignore) - } 
self._add_temporary_names(state.model) # freeze all model parameters for _, param in state.model.named_parameters(): @@ -142,7 +135,9 @@ def start_calibration(self, model: torch.nn.Module): :param model: model to prepare for calibration """ - untie_if_target_shared_embedding(model, self._module_names.values()) + targets = match_named_modules(model, self.targets, self.ignore) + if targets_embeddings(model, targets): + untie_word_embeddings(model) for _, module in match_named_modules(model, self.targets, self.ignore): # Note: No need to register observers for auto-round diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 42264af22e..4a63b3922b 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -34,9 +34,7 @@ reset_quantization_status, ) from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.transformers.compression.compressed_tensors_utils import ( - untie_if_target_shared_embedding, -) +from llmcompressor.utils import targets_embeddings, untie_word_embeddings __all__ = ["QuantizationMixin"] @@ -182,11 +180,9 @@ def start_calibration(self, model: torch.nn.Module): :param model: model to prepare for calibration """ - - matched_module_generator = ( - x[1] for x in match_named_modules(model, self.resolved_targets, self.ignore) - ) - untie_if_target_shared_embedding(model, matched_module_generator) + targets = match_named_modules(model, self.resolved_targets, self.ignore) + if targets_embeddings(model, targets): + untie_word_embeddings(model) for _, module in match_named_modules(model, self.resolved_targets, self.ignore): self._initialize_observers(module) diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index ace8d64fd4..4ac2e9c881 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -12,9 +12,8 @@ from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier -from llmcompressor.transformers.compression.compressed_tensors_utils import ( - untie_if_target_shared_embedding, -) +from llmcompressor.typing import NamedModules +from llmcompressor.utils import targets_embeddings, untie_word_embeddings __all__ = ["QuIPModifier"] @@ -102,18 +101,13 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): self.started_ = True - - def matched_module_generator(): - for scheme in self.transform_config.config_groups.values(): - for arg in scheme.apply: - gen = match_named_modules(state.model, arg.targets, arg.ignore) - for _, module in gen: - yield module + model = state.model # Untie embeddings if they will be targeted by transforms - untie_if_target_shared_embedding(state.model, matched_module_generator()) + if targets_embeddings(model, self._get_targets(model)): + untie_word_embeddings(model) - apply_transform_config(state.model, self.transform_config) + apply_transform_config(model, self.transform_config) def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: @@ -136,6 +130,17 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True + def _get_targets(self, model: torch.nn.Module) -> NamedModules: + if not self.initialized_: + raise ValueError("Cannot get targets before modifier has been 
initialized") + + return [ + (name, module) + for scheme in self.transform_config.config_groups.values() + for arg in scheme.apply + for name, module in match_named_modules(model, arg.targets, arg.ignore) + ] + def _create_config(self) -> TransformConfig: config_groups = dict() if "v" in self.rotations: diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 8d84e860f2..aac21d29c4 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -16,9 +16,8 @@ from llmcompressor.core import Event, EventType, State from llmcompressor.modeling import center_embeddings, fuse_norm_linears from llmcompressor.modifiers import Modifier -from llmcompressor.transformers.compression.compressed_tensors_utils import ( - untie_word_embeddings, -) +from llmcompressor.typing import NamedModules +from llmcompressor.utils import untie_word_embeddings from .mappings import SpinQuantMapping, infer_mapping_from_model from .norm_mappings import NormMapping, infer_norm_mapping_from_model @@ -151,14 +150,16 @@ def on_initialize(self, state: State, **kwargs) -> bool: @torch.no_grad() def on_start(self, state: State, event: Event, **kwargs): self.started_ = True + model = state.model + + # untie embeddings to avoid unintended effects of `_center_embeddings` + untie_word_embeddings(model) - # needed any time embeddings/lm_head is modified - untie_word_embeddings(state.model) # needs to happen after the model has been hooked to execute on the GPU # otherwise we're applying weight transforms on CPU - self._center_embeddings(state.model) - self._fuse_norms(state.model) - apply_transform_config(state.model, self.transform_config) + self._center_embeddings(model) + self._fuse_norms(model) + apply_transform_config(model, self.transform_config) def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: @@ -181,6 +182,17 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True + def _get_targets(self, model: torch.nn.Module) -> NamedModules: + if not self.initialized_: + raise ValueError("Cannot get targets before modifier has been initialized") + + return [ + (name, module) + for scheme in self.transform_config.config_groups.values() + for arg in scheme.apply + for name, module in match_named_modules(model, arg.targets, arg.ignore) + ] + def _center_embeddings(self, model: PreTrainedModel): for _, embedding in match_named_modules( model, [self.mappings.embedding], warn_on_fail=True diff --git a/src/llmcompressor/transformers/compression/compressed_tensors_utils.py b/src/llmcompressor/transformers/compression/compressed_tensors_utils.py index ddb5b97f43..8393c6ac7b 100644 --- a/src/llmcompressor/transformers/compression/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/compression/compressed_tensors_utils.py @@ -1,6 +1,5 @@ import os import weakref -from collections.abc import Generator from functools import wraps from typing import Optional @@ -9,9 +8,6 @@ from compressed_tensors import ( ModelCompressor, SparsityCompressionConfig, - delete_offload_parameter, - has_offloaded_params, - register_offload_parameter, ) from compressed_tensors.config import CompressionFormat from loguru import logger @@ -25,7 +21,7 @@ from llmcompressor.transformers.utils import RECIPE_FILE_NAME from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path -__all__ = ["modify_save_pretrained", 
"untie_word_embeddings"] +__all__ = ["modify_save_pretrained"] def modify_save_pretrained(model: PreTrainedModel): @@ -118,119 +114,6 @@ def save_pretrained_wrapper( model.save_pretrained = save_pretrained_compressed(model.save_pretrained) -def untie_word_embeddings(model: PreTrainedModel): - """ - Patches bug where HF transformers will fail to untie weights under specific - circumstances (https://github.com/huggingface/transformers/issues/33689). - - This function detects those cases and unties the tensors if applicable - - :param model: model to fix - """ - try: - input_embed = model.get_input_embeddings() - output_embed = model.get_output_embeddings() - except NotImplementedError as e: - logger.warning( - f"cannot untie model of type {model.__class__} which doesn't have " - f"get_input_embeddings and get_output_embeddings implmented\n{e}" - ) - return - - for module in (input_embed, output_embed): - if module is None or not hasattr(module, "weight"): - logger.warning(f"Cannot untie {module} which does not have weight param") - continue - - # this could be replaced by a `get_offloaded_parameter` util - if not has_offloaded_params(module): - untied_data = module.weight.data.clone() - else: - untied_data = module._hf_hook.weights_map["weight"].clone() - - requires_grad = module.weight.requires_grad - new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad) - delete_offload_parameter(module, "weight") - register_offload_parameter(module, "weight", new_parameter) - - if hasattr(model.config, "tie_word_embeddings"): - model.config.tie_word_embeddings = False - - -def _get_embeddings_or_warn( - model: torch.nn.Module, -) -> tuple[torch.nn.Module | None, torch.nn.Module | None]: - if not ( - hasattr(model, "get_input_embeddings") - and hasattr(model, "get_output_embeddings") - ): - logger.warning( - f"{model.__class__} doesn't have attribute get_input_embeddings and" - " get_output_embeddings implemented." - "\nThis can cause" - " problems when quantizing layers with shared weights" - ) - return None, None - - try: - input_embeddings, output_embeddings = ( - model.get_input_embeddings(), - model.get_output_embeddings(), - ) - except NotImplementedError as e: - logger.warning( - f"{model.__class__} doesn't have get_input_embeddings and " - "get_output_embeddings implemented." - "\nThis can cause" - " problems when quantizing layers with shared weights" - f"\n{e}" - ) - return None, None - - if not ( - isinstance(input_embeddings, torch.nn.Module) - and isinstance(output_embeddings, torch.nn.Module) - ): - logger.warning( - f"expected modules from {model.__class__} get_input_embeddings and" - f" get_output_embeddings but got {type(input_embeddings)}" - f" and {type(output_embeddings)}." 
- "\nThis can cause" - " problems when quantizing layers with shared weights" - ) - return None, None - return input_embeddings, output_embeddings - - -def untie_if_target_shared_embedding( - model: torch.nn.Module, matched_module_generator: Generator[torch.nn.Module] -): - """ - Helper method that checks for shared input/output embedding and unties them - if either shows up in the matched_module_generator - - :param model: model to untie if embeddings are shared and targeted by - matched_module_generator - :param matched_module_generator: Generator of all modules (not names) which - will be modified by quantization or transformation - """ - input_embeddings, output_embeddings = _get_embeddings_or_warn(model) - - if None in (input_embeddings, output_embeddings): # if couldn't find embeddings - return - - if ( - input_embeddings.weight is not output_embeddings.weight - ): # if not shared, can ignore - return - - # if shared, check if either is targeted - for module in matched_module_generator: - if module in (input_embeddings, output_embeddings): - untie_word_embeddings(model) - return - - def get_model_compressor( model: torch.nn.Module, sparsity_config: Optional[SparsityCompressionConfig] = None, diff --git a/src/llmcompressor/typing.py b/src/llmcompressor/typing.py index 1d2d195158..233f4df56a 100644 --- a/src/llmcompressor/typing.py +++ b/src/llmcompressor/typing.py @@ -2,8 +2,9 @@ Defines type aliases for the llm-compressor library. """ -from typing import Union +from typing import Iterable +import torch from datasets import Dataset, DatasetDict, IterableDataset from transformers import ( BaseImageProcessor, @@ -13,9 +14,12 @@ ) # Tokenizer or Processor. Processors do not inherit from a unified base class -Processor = Union[ - PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin -] +Processor = ( + PreTrainedTokenizer | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin +) # Supported dataset types, IterableDataset is a streamed dataset -DatasetType = Union[Dataset, DatasetDict, IterableDataset] +DatasetType = Dataset | DatasetDict | IterableDataset + +# Torch types +NamedModules = Iterable[tuple[str, torch.nn.Module]] diff --git a/src/llmcompressor/utils/__init__.py b/src/llmcompressor/utils/__init__.py index 42daa5bce6..4da185855b 100644 --- a/src/llmcompressor/utils/__init__.py +++ b/src/llmcompressor/utils/__init__.py @@ -4,5 +4,6 @@ # ruff: noqa +from .transformers import * from .dev import * from .helpers import * diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 0be09bd062..ecd63d4239 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -18,7 +18,7 @@ from collections import OrderedDict from io import BytesIO from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple, Union from urllib.parse import urlparse import numpy @@ -27,6 +27,11 @@ from loguru import logger from transformers import PreTrainedModel +from llmcompressor.utils import get_embeddings + +if TYPE_CHECKING: + pass + __all__ = [ "ALL_TOKEN", "ALL_PRUNABLE_TOKEN", @@ -65,6 +70,7 @@ "DisableQuantization", "eval_context", "calibration_forward_context", + "disable_lm_head", "patch_attr", "disable_hf_kernels", "DISABLE_QAC_MODIFIERS", @@ -1049,11 +1055,45 @@ def calibration_forward_context(model: torch.nn.Module): - Disable the KV cache - Disable train mode and enable eval mode - Disable hf 
kernels which could bypass hooks
+    - Disable lm head (input and weights can still be calibrated, output will be meta)
+    """
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(torch.no_grad())
+        stack.enter_context(disable_cache(model))
+        stack.enter_context(eval_context(model))
+        stack.enter_context(disable_hf_kernels(model))
+        stack.enter_context(disable_lm_head(model))
+        yield
+
+
+@contextlib.contextmanager
+def disable_lm_head(model: torch.nn.Module):
+    """
+    Disable the lm_head of a model by computing its forward pass on the meta device.
+    This function does not untie parameters and restores the original forward on exit
     """
-    with torch.no_grad(), disable_cache(model), eval_context(model), disable_hf_kernels(
-        model
-    ):
+    _, lm_head = get_embeddings(model)
+    if lm_head is None:
+        logger.warning(
+            f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
+            "but was unable to find the lm_head. This may lead to unexpected OOM."
+        )
         yield
+        return
+
+    elif not isinstance(lm_head, torch.nn.Linear):
+        logger.warning(f"Cannot disable LM head of type {lm_head.__class__.__name__}")
+        yield
+        return
+
+    else:
+        dummy_weight = lm_head.weight.to("meta")
+
+        def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
+            return input.to("meta") @ dummy_weight.T
+
+        with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)):
+            yield
 
 
 @contextlib.contextmanager
diff --git a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py
new file mode 100644
index 0000000000..9b4831a290
--- /dev/null
+++ b/src/llmcompressor/utils/transformers.py
@@ -0,0 +1,96 @@
+import torch
+from compressed_tensors import has_offloaded_params, register_offload_parameter
+from loguru import logger
+from torch.nn import Parameter
+from transformers import PreTrainedModel
+
+from llmcompressor.typing import NamedModules
+
+__all__ = ["untie_word_embeddings", "targets_embeddings", "get_embeddings"]
+
+
+def untie_word_embeddings(model: PreTrainedModel):
+    """
+    Untie word embeddings, if possible. This function logs a warning if
+    embeddings cannot be found in the model definition.
+
+    The model config will be updated to reflect that embeddings are now untied.
+
+    :param model: transformers model containing word embeddings
+    """
+    input_embed, output_embed = get_embeddings(model)
+    if input_embed is None or output_embed is None:
+        logger.warning(
+            "Cannot untie embeddings. If this model has word embeddings, please "
+            "implement `get_input_embeddings` and `get_output_embeddings`"
+        )
+        return
+
+    # clone data to untie
+    for module in (input_embed, output_embed):
+        if not has_offloaded_params(module):
+            data = module.weight.data
+        else:
+            data = module._hf_hook.weights_map["weight"]
+
+        requires_grad = module.weight.requires_grad
+        untied_param = Parameter(data.clone(), requires_grad=requires_grad)
+        register_offload_parameter(module, "weight", untied_param)
+
+    # modify model config
+    if hasattr(model.config, "tie_word_embeddings"):
+        model.config.tie_word_embeddings = False
+
+
+def targets_embeddings(
+    model: PreTrainedModel,
+    targets: NamedModules,
+    check_input: bool = True,
+    check_output: bool = True,
+) -> bool:
+    """
+    Returns True if the given targets target the word embeddings of the model
+
+    :param model: model containing word embeddings
+    :param targets: named modules to check
+    :param check_input: whether to check if input embeddings are targeted
+    :param check_output: whether to check if output embeddings are targeted
+    :return: True if embeddings are targeted, False otherwise
+    """
+    input_embed, output_embed = get_embeddings(model)
+    if (check_input and input_embed is None) or (
+        check_output and output_embed is None
+    ):
+        logger.warning(
+            "Cannot check embeddings. If this model has word embeddings, please "
+            "implement `get_input_embeddings` and `get_output_embeddings`"
+        )
+        return False
+
+    targets = set(module for _, module in targets)
+    return (check_input and input_embed in targets) or (
+        check_output and output_embed in targets
+    )
+
+
+def get_embeddings(
+    model: PreTrainedModel,
+) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
+    """
+    Returns input and output embeddings of a model. If `get_input_embeddings`/
+    `get_output_embeddings` is not implemented on the model, then None will be returned
+    instead.
+ + :param model: model to get embeddings from + :return: tuple of containing embedding modules or none + """ + try: + input_embed = model.get_input_embeddings() + + except (AttributeError, NotImplementedError): + input_embed = None + + try: + output_embed = model.get_output_embeddings() + except (AttributeError, NotImplementedError): + output_embed = None + + return input_embed, output_embed diff --git a/tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py index 96b5d6fce6..9963ee4c4a 100644 --- a/tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py @@ -14,7 +14,6 @@ QuantizationStatus, quantize, ) -from compressed_tensors.utils import align_module_device, update_offload_parameter from torch import nn from transformers import AutoConfig, AutoModelForCausalLM from transformers.utils.quantization_config import CompressedTensorsConfig @@ -25,11 +24,11 @@ from llmcompressor.transformers.compression.compressed_tensors_utils import ( get_model_compressor, modify_save_pretrained, - untie_word_embeddings, ) from llmcompressor.transformers.compression.sparsity_metadata_config import ( SparsityConfigMetadata, ) +from llmcompressor.utils import untie_word_embeddings from tests.testing_utils import requires_gpu @@ -283,60 +282,6 @@ def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path) -@pytest.mark.parametrize( - "offload,torch_dtype,tie_word_embeddings,device", - [ - (False, torch.float16, False, "cpu"), - (False, torch.float32, False, "cpu"), - (True, torch.float32, False, "cpu"), - (False, torch.float16, True, "cpu"), - (False, torch.float32, True, "cpu"), - (True, torch.float16, True, "cpu"), - (True, torch.float32, True, "cpu"), - ], -) -def test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device): - """ - Test whether model offloading breaks tied/untied embeddings - """ - # load model - model_path = "nm-testing/tinysmokellama-3.2" - model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) - if offload: - model = dispatch_model(model, {"": device}, force_hooks=True) - else: - model = model.to(device) - - if not tie_word_embeddings: - untie_word_embeddings(model) - - # modify lm head - with torch.no_grad(), align_module_device(model.lm_head): - update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1) - - with align_module_device(model.lm_head), align_module_device( - model.model.embed_tokens - ): - if tie_word_embeddings: - assert model.lm_head.weight is model.model.embed_tokens.weight - assert model.config.tie_word_embeddings - else: - assert model.lm_head.weight is not model.model.embed_tokens.weight - assert not model.config.tie_word_embeddings - - -@requires_gpu -@pytest.mark.parametrize( - "offload,torch_dtype,tie_word_embeddings,device", - [ - (False, torch.float32, False, "cuda:0"), - (False, torch.float32, True, "cuda:0"), - ], -) -def test_model_shared_tensors_gpu(offload, torch_dtype, tie_word_embeddings, device): - test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device) - - @requires_gpu @pytest.mark.parametrize( "model_stub, recipe, sparse_format, quant_format", diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py index cb474c1f55..5e47bebe31 100644 --- 
a/tests/llmcompressor/utils/test_helpers.py +++ b/tests/llmcompressor/utils/test_helpers.py @@ -5,23 +5,23 @@ from transformers import ( AutoModelForCausalLM, MllamaForConditionalGeneration, - PretrainedConfig, - PreTrainedModel, ) +from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential from llmcompressor.utils import ( ALL_TOKEN, DisableQuantization, calibration_forward_context, convert_to_bool, disable_cache, + disable_lm_head, flatten_iterable, getattr_chain, interpolate, patch_attr, validate_str_iterable, ) -from llmcompressor.utils.dev import skip_weights_download +from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download from tests.testing_utils import requires_gpu @@ -149,10 +149,8 @@ def test_DisableQuantization(): @pytest.mark.unit def test_calibration_forward_context(): - class DummyModel(PreTrainedModel): - config_class = PretrainedConfig - - model = DummyModel(PretrainedConfig()) + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2") model.config.use_cache = True model.train() @@ -160,9 +158,12 @@ class DummyModel(PreTrainedModel): assert not torch.is_grad_enabled() assert not model.config.use_cache assert not model.training + assert model.lm_head.forward.__name__ == "dummy_forward" + assert torch.is_grad_enabled() assert model.config.use_cache assert model.training + assert model.lm_head.forward.__name__ == "forward" @pytest.mark.unit @@ -203,3 +204,29 @@ def test_disable_cache(model_cls, model_stub): output = model(**inputs) assert output.past_key_values is not None + + +@requires_gpu +@pytest.mark.parametrize("offload", ["sequential", "basic", "none"]) +def test_disable_lm_head(offload): + model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2") + if offload == "sequential": + dispatch_for_sequential(model) + if offload == "basic": + dispatch_for_generation(model) + if offload == "none": + model = model.to("cuda") + + lm_input_device = None + + def hook(module, args): + nonlocal lm_input_device + lm_input_device = args[0].device + + model.lm_head.register_forward_pre_hook(hook) + + with disable_lm_head(model): + input = {key: value.to("cuda") for key, value in model.dummy_inputs.items()} + output = model(**input) + assert lm_input_device == torch.device("cuda:0") + assert output.logits.device == torch.device("meta") diff --git a/tests/llmcompressor/utils/test_transformers.py b/tests/llmcompressor/utils/test_transformers.py new file mode 100644 index 0000000000..deb9c4a074 --- /dev/null +++ b/tests/llmcompressor/utils/test_transformers.py @@ -0,0 +1,85 @@ +import pytest +import torch +from accelerate import dispatch_model +from compressed_tensors import align_module_device, update_offload_parameter +from transformers import AutoModelForCausalLM + +from llmcompressor.utils import targets_embeddings, untie_word_embeddings +from tests.testing_utils import requires_gpu + + +@pytest.mark.parametrize( + "offload,torch_dtype,tie_word_embeddings,device", + [ + (False, torch.float16, False, "cpu"), + (False, torch.float32, False, "cpu"), + (True, torch.float32, False, "cpu"), + (False, torch.float16, True, "cpu"), + (False, torch.float32, True, "cpu"), + (True, torch.float16, True, "cpu"), + (True, torch.float32, True, "cpu"), + ], +) +def test_untie_word_embeddings(offload, torch_dtype, tie_word_embeddings, device): + """ + Test whether model offloading breaks tied/untied embeddings + """ + # load model + model_path = "nm-testing/tinysmokellama-3.2" + model = 
AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) + if offload: + model = dispatch_model(model, {"": device}, force_hooks=True) + else: + model = model.to(device) + + if not tie_word_embeddings: + untie_word_embeddings(model) + + # modify lm head + with torch.no_grad(), align_module_device(model.lm_head): + update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1) + + with align_module_device(model.lm_head), align_module_device( + model.model.embed_tokens + ): + if tie_word_embeddings: + assert model.lm_head.weight is model.model.embed_tokens.weight + assert model.config.tie_word_embeddings + else: + assert model.lm_head.weight is not model.model.embed_tokens.weight + assert not model.config.tie_word_embeddings + + +@requires_gpu +@pytest.mark.parametrize( + "offload,torch_dtype,tie_word_embeddings,device", + [ + (False, torch.float32, False, "cuda:0"), + (False, torch.float32, True, "cuda:0"), + ], +) +def test_untie_word_embeddings_gpu(offload, torch_dtype, tie_word_embeddings, device): + test_untie_word_embeddings(offload, torch_dtype, tie_word_embeddings, device) + + +def test_targets_embeddings(): + model_path = "nm-testing/tinysmokellama-3.2" + model = AutoModelForCausalLM.from_pretrained(model_path) + + targets = {"embed_tokens": model.model.embed_tokens}.items() + assert targets_embeddings(model, targets, check_input=True, check_output=True) + assert targets_embeddings(model, targets, check_input=True, check_output=False) + assert not targets_embeddings(model, targets, check_input=False, check_output=True) + assert not targets_embeddings(model, targets, check_input=False, check_output=False) + + targets = {"lm_head": model.lm_head}.items() + assert targets_embeddings(model, targets, check_input=True, check_output=True) + assert not targets_embeddings(model, targets, check_input=True, check_output=False) + assert targets_embeddings(model, targets, check_input=False, check_output=True) + assert not targets_embeddings(model, targets, check_input=False, check_output=False) + + targets = {}.items() + assert not targets_embeddings(model, targets, check_input=True, check_output=True) + assert not targets_embeddings(model, targets, check_input=True, check_output=False) + assert not targets_embeddings(model, targets, check_input=False, check_output=True) + assert not targets_embeddings(model, targets, check_input=False, check_output=False)
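
For reference, a minimal usage sketch of the batched-calibration API introduced in this diff. It is illustrative only and not part of the patch: the model id, dataset name, and batch size below are assumed values, chosen to exercise the new `batch_size` argument, the `DataCollatorWithPadding` default collator, and the lm_head-disabling calibration context described above.

# Usage sketch (illustrative, not part of the patch). Identifiers below are
# assumptions, not values required by this change.
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

oneshot(
    model=model,
    tokenizer=model_id,
    dataset="open_platypus",
    recipe=[GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])],
    max_seq_length=2048,
    num_calibration_samples=512,
    batch_size=16,  # new argument: number of calibration samples per forward pass
    # data_collator is omitted, so it defaults to DataCollatorWithPadding(tokenizer);
    # if the tokenizer has no pad token, the EOS token is used (see datasets/utils.py).
)

# During calibration, calibration_forward_context (utils/helpers.py) also patches the
# lm_head forward to a meta-device matmul, so large batches never materialize logits:
#     with calibration_forward_context(model):
#         out = model(**batch)  # out.logits lives on the meta device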