From b5ad1f38ba4e63e45249492db2281e718659da64 Mon Sep 17 00:00:00 2001
From: avtc
Date: Tue, 4 Nov 2025 20:28:52 +0200
Subject: [PATCH 1/5] first 5 iterations

---
 .../hf_minimax_m2/modeling_minimax_m2.py | 57 ++++++++++++++-----
 1 file changed, 43 insertions(+), 14 deletions(-)

diff --git a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
index b847276da..de5313481 100644
--- a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
+++ b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
@@ -126,9 +126,9 @@ def __init__(self, config: MiniMaxM2Config) -> None:
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         gate = self.act_fn(self.w1(hidden_states))
         up = self.w3(hidden_states)
-        hidden_states = gate * up
-        hidden_states = self.w2(hidden_states)
-        return hidden_states
+        gate.mul_(up)
+        del up
+        return self.w2(gate)
 
 
 class MiniMaxM2SparseMoeBlock(nn.Module):
@@ -233,7 +233,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
             token_states = hidden_states.index_select(0, top_x)
             expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
             final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
-
+            del expert_output, token_states, idx, top_x
         final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
         return final_hidden_states, router_logits
 
@@ -302,10 +302,12 @@ def forward(
         output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         bsz, q_len, _ = hidden_states.size()
+        device = hidden_states.device
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        del hidden_states
 
         if self.use_qk_norm:
             q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
@@ -333,28 +335,55 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
+        query_dtype = query_states.dtype
+        key_states_shape_2 = key_states.shape[-2]
+
+        attn_weights = torch.empty(
+            (bsz, self.num_heads, q_len, key_states_shape_2), device=device, dtype=query_dtype
+        )
+        for i in range(self.num_heads):
+            attn_weights[:, i, :, :] = torch.matmul(
+                query_states[:, i, :, :], key_states[:, i, :, :].transpose(-2, -1)
+            )
+
+        attn_weights *= self.scaling
+        del query_states, key_states
+
         if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+            attn_weights.add_(attention_mask)
 
         if self.sliding_window is not None and past_key_values is None:
-            query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1)
-            key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1)
+            query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1)
+            key_positions = torch.arange(key_states_shape_2, device=device).view(1, 1, 1, -1)
             window_mask = key_positions < (query_positions - self.sliding_window)
             if window_mask.any():
-                attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))
+                attn_weights.masked_fill_(window_mask, float("-inf"))
+            del query_positions, key_positions, window_mask
+
+        for i in range(self.num_heads):
+            attn_weights[:, i, :, :] = torch.softmax(
+                attn_weights[:, i, :, :], dim=-1, dtype=torch.float32
+            ).to(query_dtype)
 
-        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         if self.training and self.attention_dropout > 0:
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)
 
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.empty(
+            (bsz, self.num_heads, q_len, self.head_dim), device=attn_weights.device, dtype=attn_weights.dtype
+        )
+        for i in range(self.num_heads):
+            attn_output[:, i, :, :] = torch.matmul(attn_weights[:, i, :, :], value_states[:, i, :, :])
+
+        del value_states
+
         attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
 
-        if not output_attentions:
-            attn_weights = None
-        return attn_output, attn_weights
+        if output_attentions:
+            return attn_output, attn_weights
+
+        del attn_weights
+        return attn_output, None
 
 
 class MiniMaxM2LogitsProcessor(nn.Module):
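The per-head loops introduced above keep the full attn_weights buffer, but they avoid the extra full-size temporaries of the one-shot path: the out-of-place copy from `* self.scaling`, and the float32 tensor that `torch.softmax(..., dtype=torch.float32)` would otherwise materialize for all heads at once. A minimal standalone sketch of the idea (made-up names and shapes, not the model code):

    import torch

    def per_head_scores(q, k, scaling):
        # q, k: (bsz, num_heads, q_len, head_dim)
        # The float32 softmax copy is only ever one head wide.
        bsz, num_heads, q_len, _ = q.shape
        key_len = k.shape[-2]
        weights = torch.empty(bsz, num_heads, q_len, key_len, device=q.device, dtype=q.dtype)
        for h in range(num_heads):
            s = torch.matmul(q[:, h], k[:, h].transpose(-2, -1))
            s.mul_(scaling)
            weights[:, h] = torch.softmax(s, dim=-1, dtype=torch.float32).to(q.dtype)
            del s
        return weights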
From 1fb87231a3496a3a7e5ba8e6f91cd420f001e0f8 Mon Sep 17 00:00:00 2001
From: avtc
Date: Tue, 4 Nov 2025 21:41:37 +0200
Subject: [PATCH 2/5] opt.6

---
 gptqmodel/hf_minimax_m2/modeling_minimax_m2.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
index de5313481..1434c87cb 100644
--- a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
+++ b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
@@ -231,9 +231,11 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
             token_states = hidden_states.index_select(0, top_x)
-            expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
+            expert_output = expert_layer(token_states)
+            weights = routing_weights[top_x, idx].unsqueeze(-1)
+            expert_output.mul_(weights)
             final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
-            del expert_output, token_states, idx, top_x
+            del expert_output, token_states, idx, top_x, weights
         final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
         return final_hidden_states, router_logits
 
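This commit splits the fused `expert_layer(token_states) * routing_weights[...]` into a separate forward pass plus an in-place `mul_`, so the scaled result reuses the expert_output buffer instead of allocating a second tokens-by-hidden tensor. A small sketch of the gather, scale in place, scatter-add pattern (toy sizes, hypothetical names):

    import torch

    hidden = torch.randn(10, 16)                             # all tokens
    final = torch.zeros_like(hidden)
    top_x = torch.tensor([1, 3, 4, 7])                       # rows routed to this expert
    weights = torch.rand(4, 1)                               # their routing weights
    expert_out = torch.tanh(hidden.index_select(0, top_x))   # stand-in for expert_layer(...)
    expert_out.mul_(weights)                                 # broadcasts (4, 1) over (4, 16) in place
    final.index_add_(0, top_x, expert_out)                   # accumulate back into the routed rows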
From eed748a548ee64df7abc108cbaf56672131370b6 Mon Sep 17 00:00:00 2001
From: avtc
Date: Tue, 4 Nov 2025 21:46:00 +0200
Subject: [PATCH 3/5] opt.7

---
 .../hf_minimax_m2/modeling_minimax_m2.py | 39 +++++++++++++------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
index 1434c87cb..4cabd6c5c 100644
--- a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
+++ b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
@@ -168,7 +168,8 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
                 1.0 - self.jitter_noise,
                 1.0 + self.jitter_noise,
             )
-            hidden_states = hidden_states * noise
+            hidden_states.mul_(noise)
+            del noise
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         gate_dtype = self.gate.weight.dtype
@@ -188,7 +189,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
 
         if correction_bias is not None:
             original_scores = scores
-            scores = scores + correction_bias
+            scores.add_(correction_bias)
         else:
             original_scores = scores
         topk_scores: torch.Tensor
@@ -216,26 +217,42 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
                 routing_weights = original_scores.gather(1, selected_experts)
             else:
                 routing_weights = topk_scores
+            del scores, original_scores, topk_scores
 
-        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
+        routing_weights.div_(routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12))
         if self.routed_scaling_factor != 1.0:
-            routing_weights = routing_weights * self.routed_scaling_factor
+            routing_weights.mul_(self.routed_scaling_factor)
         routing_weights = routing_weights.to(hidden_states.dtype)
         selected_experts = selected_experts.to(torch.long)
 
         final_hidden_states = torch.zeros_like(hidden_states)
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        del selected_experts
         expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()
 
+        # To further reduce memory, process tokens routed to each expert in chunks
+        # instead of all at once. A chunk size of 1024 is a reasonable default.
+        EXPERT_CHUNK_SIZE = 1024
+
         for expert_idx in expert_hit.tolist():
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            token_states = hidden_states.index_select(0, top_x)
-            expert_output = expert_layer(token_states)
-            weights = routing_weights[top_x, idx].unsqueeze(-1)
-            expert_output.mul_(weights)
-            final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
-            del expert_output, token_states, idx, top_x, weights
+            idx_full, top_x_full = torch.where(expert_mask[expert_idx].squeeze(0))
+
+            for i in range(0, top_x_full.size(0), EXPERT_CHUNK_SIZE):
+                top_x = top_x_full[i : i + EXPERT_CHUNK_SIZE]
+                idx = idx_full[i : i + EXPERT_CHUNK_SIZE]
+
+                token_states = hidden_states.index_select(0, top_x)
+                expert_output = expert_layer(token_states)
+
+                weights = routing_weights[top_x, idx].unsqueeze(-1)
+                expert_output.mul_(weights)
+
+                final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
+                del expert_output, token_states, idx, top_x, weights
+
+            del idx_full, top_x_full
+
+        del hidden_states, routing_weights, expert_mask, expert_hit
         final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
         return final_hidden_states, router_logits
 
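One point worth double-checking in the @@ -188,7 +189,7 @@ hunk above: `original_scores = scores` only binds a second name to the same tensor, so the new in-place `scores.add_(correction_bias)` also mutates `original_scores`, whereas the out-of-place line it replaces left the pre-bias scores untouched for the later `original_scores.gather(1, selected_experts)`. If the unbiased scores are still meant to feed that gather, one possible adjustment (a sketch, not the author's code) is to snapshot before the in-place add:

    import torch

    scores = torch.rand(4, 8)
    correction_bias = torch.rand(8)

    original_scores = scores.clone()   # keep the unbiased values for the later gather()
    scores.add_(correction_bias)       # biased copy used only for top-k selection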
From 339e38b008141fd4c93c6858b1377f0a4fe8983d Mon Sep 17 00:00:00 2001
From: avtc
Date: Wed, 5 Nov 2025 14:47:57 +0200
Subject: [PATCH 4/5] opt.8-14 glm-4.6 + gpt

---
 .../hf_minimax_m2/modeling_minimax_m2.py | 87 ++++++++++---------
 1 file changed, 48 insertions(+), 39 deletions(-)

diff --git a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
index 4cabd6c5c..68aa02764 100644
--- a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
+++ b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
@@ -323,11 +323,13 @@ def forward(
         bsz, q_len, _ = hidden_states.size()
         device = hidden_states.device
 
+        # projections
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         del hidden_states
 
+        # optional QK normalization
         if self.use_qk_norm:
             q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
             k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
@@ -336,6 +338,7 @@ def forward(
             query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        # rotary embeddings
         if position_embeddings is None:
             cos, sin = self.rotary_emb(value_states, position_ids)
         else:
@@ -347,6 +350,7 @@ def forward(
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
 
+        # handle cache
        if past_key_values is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -355,54 +359,59 @@ def forward(
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         query_dtype = query_states.dtype
-        key_states_shape_2 = key_states.shape[-2]
+        key_len = key_states.shape[-2]
 
-        attn_weights = torch.empty(
-            (bsz, self.num_heads, q_len, key_states_shape_2), device=device, dtype=query_dtype
-        )
-        for i in range(self.num_heads):
-            attn_weights[:, i, :, :] = torch.matmul(
-                query_states[:, i, :, :], key_states[:, i, :, :].transpose(-2, -1)
-            )
+        # precompute sliding-window mask
+        window_mask = None
+        if self.sliding_window is not None and past_key_values is None:
+            q_pos = torch.arange(q_len, device=device).view(1, 1, q_len, 1)
+            k_pos = torch.arange(key_len, device=device).view(1, 1, 1, key_len)
+            wm = k_pos < (q_pos - self.sliding_window)
+            if wm.any():
+                window_mask = wm.squeeze(1)  # (1, q_len, key_len)
+            del q_pos, k_pos, wm
+
+        attn_output_parts = []
+        attn_weights_list = [] if output_attentions else None
+
+        for h in range(self.num_heads):
+            # (bsz, q_len, key_len)
+            q = query_states[:, h, :, :]
+            k = key_states[:, h, :, :]
+            v = value_states[:, h, :, :]
+
+            attn = torch.matmul(q, k.transpose(-2, -1))
+            attn.mul_(self.scaling)
+
+            # shape (bsz, 1, q_len, key_len) -> (bsz, q_len, key_len)
+            if attention_mask is not None:
+                attn.add_(attention_mask.squeeze(1))
 
-        attn_weights *= self.scaling
-        del query_states, key_states
+            if window_mask is not None:
+                attn.masked_fill_(window_mask, float("-inf"))
 
-        if attention_mask is not None:
-            attn_weights.add_(attention_mask)
+            attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(query_dtype)
 
-        if self.sliding_window is not None and past_key_values is None:
-            query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1)
-            key_positions = torch.arange(key_states_shape_2, device=device).view(1, 1, 1, -1)
-            window_mask = key_positions < (query_positions - self.sliding_window)
-            if window_mask.any():
-                attn_weights.masked_fill_(window_mask, float("-inf"))
-            del query_positions, key_positions, window_mask
-
-        for i in range(self.num_heads):
-            attn_weights[:, i, :, :] = torch.softmax(
-                attn_weights[:, i, :, :], dim=-1, dtype=torch.float32
-            ).to(query_dtype)
-
-        if self.training and self.attention_dropout > 0:
-            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)
-
-        attn_output = torch.empty(
-            (bsz, self.num_heads, q_len, self.head_dim), device=attn_weights.device, dtype=attn_weights.dtype
-        )
-        for i in range(self.num_heads):
-            attn_output[:, i, :, :] = torch.matmul(attn_weights[:, i, :, :], value_states[:, i, :, :])
+            if self.training and self.attention_dropout > 0:
+                attn = F.dropout(attn, p=self.attention_dropout, training=True)
 
-        del value_states
+            if output_attentions:
+                attn_weights_list.append(attn)
+
+            out = torch.matmul(attn, v)
+            attn_output_parts.append(out)
+
+            del q, k, v, attn, out
+
+        attn_output = torch.stack(attn_output_parts, dim=1)
+        del attn_output_parts, query_states, key_states, value_states
+
+        attn_weights = torch.stack(attn_weights_list, dim=1) if output_attentions else None
 
         attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
 
-        if output_attentions:
-            return attn_output, attn_weights
-
-        del attn_weights
-        return attn_output, None
+        return attn_output, attn_weights
 
 
 class MiniMaxM2LogitsProcessor(nn.Module):
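The rewrite above relies on `attention_mask` arriving as (bsz, 1, q_len, key_len) so that `squeeze(1)` lines up with the per-head (bsz, q_len, key_len) scores, and on the precomputed `window_mask` broadcasting over the batch dimension. A minimal sketch of one head under those assumptions (toy shapes, not the model code):

    import torch

    bsz, q_len, key_len, head_dim = 2, 5, 5, 8
    q = torch.randn(bsz, q_len, head_dim)          # one head's queries
    k = torch.randn(bsz, key_len, head_dim)
    v = torch.randn(bsz, key_len, head_dim)
    mask = torch.zeros(bsz, 1, q_len, key_len)     # additive mask, 0 = keep

    attn = torch.matmul(q, k.transpose(-2, -1)) * head_dim ** -0.5
    attn.add_(mask.squeeze(1))                     # (bsz, q_len, key_len) after the squeeze
    attn = torch.softmax(attn, dim=-1)
    out = torch.matmul(attn, v)                    # (bsz, q_len, head_dim)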
From 75b9d13706b6b5696e9e69da4e736dc84533c421 Mon Sep 17 00:00:00 2001
From: avtc
Date: Wed, 5 Nov 2025 18:01:32 +0200
Subject: [PATCH 5/5] opt.15 chunks for attn

---
 .../hf_minimax_m2/modeling_minimax_m2.py | 57 +++++++++++++------
 1 file changed, 39 insertions(+), 18 deletions(-)

diff --git a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
index 68aa02764..b4bb0407e 100644
--- a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
+++ b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
@@ -380,28 +380,49 @@ def forward(
             k = key_states[:, h, :, :]
             v = value_states[:, h, :, :]
 
-            attn = torch.matmul(q, k.transpose(-2, -1))
-            attn.mul_(self.scaling)
-
-            # shape (bsz, 1, q_len, key_len) -> (bsz, q_len, key_len)
-            if attention_mask is not None:
-                attn.add_(attention_mask.squeeze(1))
-
-            if window_mask is not None:
-                attn.masked_fill_(window_mask, float("-inf"))
-
-            attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(query_dtype)
-
-            if self.training and self.attention_dropout > 0:
-                attn = F.dropout(attn, p=self.attention_dropout, training=True)
+            # Chunked attention computation to reduce peak memory usage
+            out_parts = []
+            attn_parts = [] if output_attentions else None
+
+            # A smaller chunk size reduces memory but may be slightly slower
+            chunk_size = 1024
+            for i in range(0, q.size(1), chunk_size):
+                q_chunk = q[:, i:i + chunk_size, :]
+
+                # attn_chunk has shape (bsz, chunk_size, key_len)
+                attn_chunk = torch.matmul(q_chunk, k.transpose(-2, -1))
+                attn_chunk.mul_(self.scaling)
+
+                # Apply masks to the chunk
+                if attention_mask is not None:
+                    attn_chunk.add_(attention_mask.squeeze(1)[:, i:i + chunk_size, :])
+
+                if window_mask is not None:
+                    attn_chunk.masked_fill_(window_mask[:, i:i + chunk_size, :], float("-inf"))
+
+                attn_chunk = torch.softmax(attn_chunk, dim=-1, dtype=torch.float32).to(query_dtype)
+
+                if self.training and self.attention_dropout > 0:
+                    attn_chunk = F.dropout(attn_chunk, p=self.attention_dropout, training=True)
+
+                if output_attentions:
+                    attn_parts.append(attn_chunk)
+
+                # output_chunk has shape (bsz, chunk_size, head_dim)
+                out_chunk = torch.matmul(attn_chunk, v)
+                out_parts.append(out_chunk)
+
+                del q_chunk, attn_chunk, out_chunk
+
+            out = torch.cat(out_parts, dim=1)
+            attn_output_parts.append(out)
 
             if output_attentions:
+                attn = torch.cat(attn_parts, dim=1)
                 attn_weights_list.append(attn)
+                del attn, attn_parts
 
-            out = torch.matmul(attn, v)
-            attn_output_parts.append(out)
-
-            del q, k, v, attn, out
+            del q, k, v, out, out_parts
 
         attn_output = torch.stack(attn_output_parts, dim=1)
         del attn_output_parts, query_states, key_states, value_states
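Chunking the query dimension as in this last commit caps the per-head score buffer at roughly bsz x chunk_size x key_len instead of bsz x q_len x key_len, at the cost of one matmul launch per chunk. A standalone sketch of the same idea for a single head (hypothetical helper, not the model code):

    import torch

    def chunked_attention_one_head(q, k, v, scaling, chunk_size=1024):
        # q: (bsz, q_len, d); k, v: (bsz, key_len, d)
        # Scores are materialized only for chunk_size query rows at a time.
        outs = []
        for i in range(0, q.size(1), chunk_size):
            s = torch.matmul(q[:, i:i + chunk_size], k.transpose(-2, -1))
            s.mul_(scaling)
            s = torch.softmax(s, dim=-1, dtype=torch.float32).to(q.dtype)
            outs.append(torch.matmul(s, v))
            del s
        return torch.cat(outs, dim=1)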