diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 98e53b4e0..420dadf2b 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -7,6 +7,7 @@ from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
+    match_modules_set,
     match_named_modules,
     update_offload_parameter,
 )
@@ -319,64 +320,48 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-        for mapping_idx, mapping in enumerate(self.mappings):
-            num_skipped_mappings = 0
-            for smooth_name, smooth_layer in (
-                pbar := tqdm(
-                    match_named_modules(model, [mapping.smooth_layer], self.ignore)
+        module_to_name = {}
+        for name, module in model.named_modules():
+            if module in module_to_name:
+                logger.info(
+                    f"Warning, {name} and {module_to_name[module]} both "
+                    "share the same module, which may cause trouble "
+                    "resolving mappings."
                 )
+            module_to_name[module] = name
+
+        for mapping in self.mappings:
+            target_patterns = (mapping.smooth_layer, *mapping.balance_layers)
+
+            for smooth_layer, *balance_layers in match_modules_set(
+                model, target_patterns, self.ignore
             ):
-                pbar.set_description(
-                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
-                    f" ({num_skipped_mappings} skipped)"
+                smooth_name = module_to_name.get(smooth_layer)
+                balance_names = [
+                    module_to_name.get(balance_layer)
+                    for balance_layer in balance_layers
+                ]
+
+                all_compatible = _check_layers_are_compatible(
+                    smooth_layer, smooth_name, balance_layers, balance_names
                 )
-                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-                smooth_parent = get_layer_by_name(smooth_parent_name, model)
-
-                balance_layers, balance_names = [], []
-                for balance_regex in mapping.balance_layers:
-                    # find the submodules that match the activation layer
-                    for balance_suffix, balance_layer in match_named_modules(
-                        smooth_parent, [balance_regex], self.ignore
-                    ):
-                        balance_name = f"{smooth_parent_name}.{balance_suffix}"
-
-                        # exclude v_proj->o_proj mappings whose shapes are incompatible
-                        # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
-                        if (
-                            isinstance(smooth_layer, torch.nn.Linear)
-                            and isinstance(balance_layer, torch.nn.Linear)
-                            and balance_name.endswith(".o_proj")
-                            and (
-                                (
-                                    smooth_name.endswith(".v_proj")
-                                    and smooth_layer.out_features
-                                    != balance_layer.in_features
-                                )
-                                or (
-                                    smooth_name.endswith(".qkv_proj")
-                                    and smooth_layer.out_features
-                                    != 3 * balance_layer.in_features
-                                )
-                            )
-                        ):
-                            num_skipped_mappings += 1
-                            continue
-
-                        balance_layers.append(balance_layer)
-                        balance_names.append(balance_name)
+                # skip mapping if any of the balance layers are incompatible
+                if not all_compatible or len(balance_layers) == 0:
+                    logger.info(
+                        f"skipping AWQ for {smooth_name} for mapping {mapping}"
+                        + (
+                            " because incompatible balance layers were found"
+                            if not all_compatible
+                            else " because no balance layers were found"
+                        )
+                    )
 
-                if len(balance_layers) == 0:
                     continue
-
-                elif len(balance_layers) == 1:
-                    # for single balance layer, parent is the balance layer
-                    parent_name, parent = balance_name, balance_layer
                 else:
                     # for multiple balance layers, find lowest common parent
-                    parent_name, parent = get_lowest_common_parent(balance_names, model)
+                    parent_name, parent = get_lowest_common_module(balance_names, model)
 
                 resolved_mappings.append(
                     ResolvedMapping(
@@ -721,6 +706,35 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
+def _check_layers_are_compatible(
+    smooth_layer, smooth_name, balance_layers, balance_names
+):
+    """
+    returns True if all smooth & balance layer pairs are compatible
+    returns False if any pair is incompatible
+    """
+    for balance_layer, balance_name in zip(balance_layers, balance_names):
+        # exclude v_proj->o_proj mappings whose shapes are incompatible
+        # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
+        if (
+            isinstance(smooth_layer, torch.nn.Linear)
+            and isinstance(balance_layer, torch.nn.Linear)
+            and balance_name.endswith(".o_proj")
+            and (
+                (
+                    smooth_name.endswith(".v_proj")
+                    and smooth_layer.out_features != balance_layer.in_features
+                )
+                or (
+                    smooth_name.endswith(".qkv_proj")
+                    and smooth_layer.out_features != 3 * balance_layer.in_features
+                )
+            )
+        ):
+            return False
+    return True
+
+
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
@@ -781,29 +795,41 @@ def _accumulate_mean(
     return (prev_sum + sum_added) / new_count, new_count
 
 
-def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
+def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]:
     """
-    Given a list of names, returns the lowest-scope common parent.
+    Given a list of names, returns the lowest-scope common module.
 
-    NOTE: function excludes parents of type ModuleList, which don't play
+    NOTE: function excludes modules of type ModuleList, which don't play
     nicely with hooks because their forward method is never directly
     called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts are
     selected based on router output and their forward method is called.
     https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
 
-    Returns name of parent and pointer to parent module
+    Returns name of module and pointer to module
 
     Implementation is a small alteration of os.path.commonprefix
     https://docs.python.org/3/library/os.path.html#os.path.commonprefix
     """
-    s1 = min(names)
-    s2 = max(names)
-    parent_name = ""
+    # adding "." before and after allows for handling a lot of corner
+    # cases which were previously mishandled ([case]->prefix->result)
+    # case 0: single module: [.abc.] -> .abc. -> abc
+    # case 1: substring modules: [.abc., .ab.] -> .ab -> ""
+    # case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab
+    s1 = min(names) + "."
+    s2 = max(names) + "."
+
+    # 1) find longest shared prefix
+    parent_name = "."
     for i, c in enumerate(s1):
         if c != s2[i]:
-            parent_name = s1[:i].rstrip(".")
             break
+        parent_name += c
+
+    # 2) throw away module name fragment and leading dot
+    # ".keep.thro" -> "keep"
+    parent_name = parent_name[1 : parent_name.rfind(".")]
 
+    # 3) return first common module that is not a module list
     while True:
         if parent_name == "":
             return "", module
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index 950ab0f51..e8103f9e3 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -2,9 +2,10 @@
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
 from pydantic import ValidationError
+from torch.nn import Linear
 
 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
-from llmcompressor.modifiers.awq.base import get_lowest_common_parent
+from llmcompressor.modifiers.awq.base import get_lowest_common_module
 from llmcompressor.modifiers.factory import ModifierFactory
@@ -40,16 +41,16 @@ def test_set_resolved_mappings():
     )
     self_attn = torch.nn.ModuleDict(
         {
-            "q_proj": torch.nn.Linear(4, 4),
-            "k_proj": torch.nn.Linear(4, 4),
-            "v_proj": torch.nn.Linear(4, 4),
-            "o_proj": torch.nn.Linear(4, 4),
+            "q_proj": Linear(4, 4),
+            "k_proj": Linear(4, 4),
+            "v_proj": Linear(4, 4),
+            "o_proj": Linear(4, 4),
         }
     )
     mlp = torch.nn.ModuleDict(
         {
-            "up_proj": torch.nn.Linear(4, 10),
-            "down_proj": torch.nn.Linear(10, 4),
+            "up_proj": Linear(4, 10),
+            "down_proj": Linear(10, 4),
         }
     )
     model = torch.nn.ModuleDict(
@@ -85,10 +86,12 @@
             assert set(mapping.balance_names) == {"decoder.mlp.down_proj"}
             assert mapping.parent_name == "decoder.mlp.down_proj"
 
-    # make sure we exclude case where o_proj/v_proj shapes are mismatched
     awq = AWQModifier(
         mappings=[
+            # make sure we exclude case where o_proj/v_proj shapes are mismatched
            AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
+            # make sure we exclude mapping if any balance layers are skipped
+            AWQMapping("re:.*v_proj", ["re:.*z_proj", "re:.*o_proj"]),
         ],
         scheme="W4A16_ASYM",
     )
@@ -98,10 +101,11 @@
                 {
                     "self_attn": torch.nn.ModuleDict(
                         {
-                            "q_proj": torch.nn.Linear(4, 2),
-                            "k_proj": torch.nn.Linear(4, 2),
-                            "v_proj": torch.nn.Linear(4, 2),
-                            "o_proj": torch.nn.Linear(4, 4),
+                            "q_proj": Linear(4, 2),
+                            "k_proj": Linear(4, 2),
+                            "v_proj": Linear(4, 2),
+                            "z_proj": Linear(2, 4),
+                            "o_proj": Linear(4, 4),
                         }
                     )
                 }
@@ -109,6 +113,16 @@
         }
     )
     awq._set_resolved_mappings(model)
+    if len(awq._resolved_mappings) > 0:
+        assert all(
+            "o_proj" not in name for name in awq._resolved_mappings[0].balance_names
+        ), "should have skipped v->o mapping because o is incompatible"
+        assert all(
+            "z_proj" not in name for name in awq._resolved_mappings[0].balance_names
+        ), (
+            "should have skipped v->[z,o] mapping because o is incompatible "
+            "even though z is compatible"
+        )
     assert len(awq._resolved_mappings) == 0
@@ -179,15 +193,15 @@ def test_validate():
 
 
 @pytest.mark.unit
-def test_get_lowest_common_parent():
+def test_get_lowest_common_module():
     mlp = torch.nn.ModuleDict(
         {
             "experts": torch.nn.ModuleList(
                 [
                     torch.nn.ModuleDict(
                         {
-                            "gate_proj": torch.nn.Linear(4, 2),
-                            "down_proj": torch.nn.Linear(4, 2),
+                            "gate_proj": Linear(4, 2),
+                            "down_proj": Linear(4, 2),
                         }
                     )
                     for _ in range(10)
                 ]
             )
         }
     )
@@ -197,15 +211,15 @@
     self_attn = torch.nn.ModuleDict(
         {
-            "q_proj": torch.nn.Linear(4, 2),
-            "k_proj": torch.nn.Linear(4, 2),
-            "v_proj": torch.nn.Linear(4, 2),
-            "o_proj": torch.nn.Linear(4, 4),
+            "q_proj": Linear(4, 2),
+            "k_proj": Linear(4, 2),
+            "v_proj": Linear(4, 2),
+            "o_proj": Linear(4, 4),
         }
     )
     model = torch.nn.ModuleDict(
         {
-            "embed_tokens": torch.nn.Linear(4, 2),
+            "embed_tokens": Linear(4, 2),
             "decoder": torch.nn.ModuleDict(
                 {
                     "self_attn": self_attn,
@@ -215,22 +229,36 @@
         }
     )
 
-    parent_name, parent = get_lowest_common_parent(
+    parent_name, parent = get_lowest_common_module(
         ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model
     )
     assert parent_name == "decoder.mlp" and parent == mlp
 
-    parent_name, parent = get_lowest_common_parent(
+    parent_name, parent = get_lowest_common_module(
         ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model
     )
     assert parent_name == "decoder.self_attn" and parent == self_attn
 
-    parent_name, parent = get_lowest_common_parent(
+    parent_name, parent = get_lowest_common_module(
         ["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model
     )
     assert parent_name == "decoder" and parent == model["decoder"]
 
-    parent_name, parent = get_lowest_common_parent(
+    parent_name, parent = get_lowest_common_module(
         ["embed_tokens", "decoder.self_attn.v_proj"], model
    )
     assert parent_name == "" and parent == model
+
+    m = torch.nn.ModuleDict(
+        {
+            "abc": Linear(3, 3),
+            "ab": torch.nn.ModuleDict({"a": Linear(3, 3)}),
+            "z": Linear(3, 3),
+        }
+    )
+    parent_name, parent = get_lowest_common_module(["abc", "ab"], m)
+    assert parent_name == ""
+    parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m)
+    assert parent_name == "ab"
+    parent_name, parent = get_lowest_common_module(["z"], m)
+    assert parent_name == "z"
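
Reviewer note (not part of the patch): the corner cases called out in the comments of get_lowest_common_module can be checked in isolation. The helper name common_module_prefix below is hypothetical; it is a minimal sketch that mirrors only steps 1 and 2 of the function above, assuming dot-separated module names like those in the tests.

    # standalone illustration of the "." sentinel prefix logic (steps 1-2 above)
    def common_module_prefix(names: list[str]) -> str:
        s1 = min(names) + "."
        s2 = max(names) + "."
        prefix = "."
        for i, c in enumerate(s1):
            # s2 is never a proper prefix of s1, since s2 = max(names) + "."
            if c != s2[i]:
                break
            prefix += c
        # drop the leading "." and the trailing name fragment
        return prefix[1 : prefix.rfind(".")]

    assert common_module_prefix(["abc"]) == "abc"        # case 0: single module
    assert common_module_prefix(["abc", "ab"]) == ""     # case 1: substring modules
    assert common_module_prefix(["ab", "ab.a"]) == "ab"  # case 2: parent & child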