Commit 7b28d78

Add Exaone4 AWQ mapping (#2046)
EXAONE4 uses [QK-Reorder-Norm](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B#introduction).

Transformers: [Exaone4DecoderLayer](https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/exaone4/modeling_exaone4.py#L284-L314)

```python
class Exaone4DecoderLayer(GradientCheckpointingLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
```

Signed-off-by: lkm2835 <lkm2835@gmail.com>
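
To make the ordering in the quoted layer concrete, here is a toy sketch contrasting it with a pre-norm block. It is an illustration only: `rms_norm`, `pre_norm_block`, and `reorder_norm_block` are hypothetical helpers working on plain Python lists, not Transformers code. One reading, which the commit message does not state, is that because no norm feeds the attention or MLP projections directly, mappings that smooth from a layernorm into the following projections would not fit this layer.

```python
def rms_norm(x, eps=1e-6):
    """Unweighted RMS normalization of a vector (toy stand-in for a norm layer)."""
    rms = (sum(v * v for v in x) / len(x) + eps) ** 0.5
    return [v / rms for v in x]


def pre_norm_block(x, sublayer):
    """Pre-norm residual block: the norm output is the sublayer input."""
    out = sublayer(rms_norm(x))
    return [a + b for a, b in zip(x, out)]


def reorder_norm_block(x, sublayer):
    """Ordering as in the quoted Exaone4DecoderLayer.forward:
    the sublayer sees the residual stream as-is, and the norm is
    applied to the sublayer output before the residual add."""
    out = rms_norm(sublayer(x))
    return [a + b for a, b in zip(x, out)]


# Record what the sublayer receives in each arrangement.
seen = []


def probe(v):
    seen.append(list(v))
    return [2.0 * t for t in v]


x = [3.0, 4.0]
reorder_norm_block(x, probe)  # probe receives x unchanged
pre_norm_block(x, probe)      # probe receives rms_norm(x)
```

After running this, `seen[0]` is the raw input and `seen[1]` is its normalized version, which is the structural difference the commit message points to.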
1 parent b77175d commit 7b28d78

1 file changed: +11 -0 lines changed

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 11 additions & 0 deletions
@@ -142,12 +142,23 @@ class AWQMapping:
     # ["re:.*dense$"]
     # ),
 ]
+
+# Exaone4
+_exaone4_mappings = [
+    AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
+    AWQMapping(
+        "re:.*up_proj$",
+        ["re:.*down_proj$"],
+    ),
+]
+
 AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
     "BloomForCausalLM": _bloom_mappings,
     "CohereForCausalLM": _cohere_mappings,
     "Cohere2ForCausalLM": _cohere_mappings,
     "Cohere2VisionForConditionalGeneration": _cohere_mappings,
     "DeepseekV3ForCausalLM": _deepseek_mappings,
+    "Exaone4ForCausalLM": _exaone4_mappings,
     "Gemma2ForCausalLM": _gemma_mappings,
     "Gemma3ForCausalLM": _gemma_mappings,
     "Gemma3ForConditionalGeneration": _gemma_mappings,
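
The `re:` prefix marks a target as a regular expression over module names. As a rough illustration of how the two Exaone4 entries could pair layers, here is a minimal sketch. `Mapping`, `matches`, and `resolve` are hypothetical stand-ins and not the llm-compressor implementation; the module names assume a `model.layers.N.self_attn.*` / `model.layers.N.mlp.*` layout; and pairing layers by shared parent module is a simplification chosen for this example.

```python
import re
from dataclasses import dataclass


@dataclass
class Mapping:
    """Hypothetical stand-in for AWQMapping: one smooth-layer pattern
    and the patterns of the layers balanced against it."""
    smooth_layer: str
    balance_layers: list[str]


def matches(pattern: str, name: str) -> bool:
    """Treat a "re:" prefix as a regex; anything else as an exact name."""
    if pattern.startswith("re:"):
        return re.match(pattern[3:], name) is not None
    return pattern == name


def resolve(mapping: Mapping, names: list[str]) -> dict[str, list[str]]:
    """Pair each smooth layer with balance layers under the same parent module
    (a simplifying assumption made for this sketch)."""
    resolved = {}
    for smooth in names:
        if not matches(mapping.smooth_layer, smooth):
            continue
        # "model.layers.0.self_attn.v_proj" -> "model.layers.0.self_attn"
        parent = smooth.rsplit(".", 1)[0]
        resolved[smooth] = [
            n
            for n in names
            if n.rsplit(".", 1)[0] == parent
            and any(matches(p, n) for p in mapping.balance_layers)
        ]
    return resolved


# The two patterns added in this commit, expressed with the stand-in class.
exaone4_mappings = [
    Mapping("re:.*v_proj$", ["re:.*o_proj$"]),
    Mapping("re:.*up_proj$", ["re:.*down_proj$"]),
]

# Assumed module names for a two-layer model.
submodules = (
    "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
    "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
)
names = [f"model.layers.{i}.{sub}" for i in range(2) for sub in submodules]

for mapping in exaone4_mappings:
    for smooth, balance in resolve(mapping, names).items():
        print(smooth, "->", balance)
```

Under these assumptions each `v_proj` is paired with the `o_proj` of the same attention block and each `up_proj` with the `down_proj` of the same MLP; `q_proj`, `k_proj`, and `gate_proj` are not touched by either pattern.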
