diff --git a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
index b847276da..b4bb0407e 100644
--- a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
+++ b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
@@ -126,9 +126,9 @@ def __init__(self, config: MiniMaxM2Config) -> None:
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         gate = self.act_fn(self.w1(hidden_states))
         up = self.w3(hidden_states)
-        hidden_states = gate * up
-        hidden_states = self.w2(hidden_states)
-        return hidden_states
+        gate.mul_(up)
+        del up
+        return self.w2(gate)


 class MiniMaxM2SparseMoeBlock(nn.Module):
@@ -168,7 +168,8 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
                 1.0 - self.jitter_noise,
                 1.0 + self.jitter_noise,
             )
-            hidden_states = hidden_states * noise
+            hidden_states.mul_(noise)
+            del noise
         hidden_states = hidden_states.view(-1, hidden_dim)

         gate_dtype = self.gate.weight.dtype
@@ -188,7 +189,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens

         if correction_bias is not None:
             original_scores = scores
-            scores = scores + correction_bias
+            scores = scores + correction_bias  # stays out-of-place: original_scores must keep the pre-bias values
         else:
             original_scores = scores
         topk_scores: torch.Tensor
@@ -216,24 +217,42 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
             routing_weights = original_scores.gather(1, selected_experts)
         else:
             routing_weights = topk_scores
+        del scores, original_scores, topk_scores

-        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
+        routing_weights.div_(routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12))
         if self.routed_scaling_factor != 1.0:
-            routing_weights = routing_weights * self.routed_scaling_factor
+            routing_weights.mul_(self.routed_scaling_factor)
         routing_weights = routing_weights.to(hidden_states.dtype)
         selected_experts = selected_experts.to(torch.long)

         final_hidden_states = torch.zeros_like(hidden_states)
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        del selected_experts
         expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()
+        # To further reduce memory, process tokens routed to each expert in chunks
+        # instead of all at once. A chunk size of 1024 is a reasonable default.
+        EXPERT_CHUNK_SIZE = 1024
+
         for expert_idx in expert_hit.tolist():
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            token_states = hidden_states.index_select(0, top_x)
-            expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
-            final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
+            idx_full, top_x_full = torch.where(expert_mask[expert_idx].squeeze(0))
+
+            for i in range(0, top_x_full.size(0), EXPERT_CHUNK_SIZE):
+                top_x = top_x_full[i : i + EXPERT_CHUNK_SIZE]
+                idx = idx_full[i : i + EXPERT_CHUNK_SIZE]
+
+                token_states = hidden_states.index_select(0, top_x)
+                expert_output = expert_layer(token_states)
+
+                weights = routing_weights[top_x, idx].unsqueeze(-1)
+                expert_output.mul_(weights)
+
+                final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
+                del expert_output, token_states, idx, top_x, weights
+            del idx_full, top_x_full
+
+        del hidden_states, routing_weights, expert_mask, expert_hit
         final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
         return final_hidden_states, router_logits
@@ -302,11 +321,15 @@ def forward(
         output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         bsz, q_len, _ = hidden_states.size()
+        device = hidden_states.device
+
+        # projections
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        del hidden_states
+
+        # optional QK normalization
         if self.use_qk_norm:
             q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
             k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
@@ -315,6 +338,7 @@ def forward(
             query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

+        # rotary embeddings
         if position_embeddings is None:
             cos, sin = self.rotary_emb(value_states, position_ids)
         else:
@@ -326,6 +350,7 @@ def forward(
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)

+        # handle cache
         if past_key_values is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -333,27 +358,80 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        query_dtype = query_states.dtype
+        key_len = key_states.shape[-2]
+
+        # precompute sliding-window mask
+        window_mask = None
         if self.sliding_window is not None and past_key_values is None:
-            query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1)
-            key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1)
-            window_mask = key_positions < (query_positions - self.sliding_window)
-            if window_mask.any():
-                attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))
+            q_pos = torch.arange(q_len, device=device).view(1, 1, q_len, 1)
+            k_pos = torch.arange(key_len, device=device).view(1, 1, 1, key_len)
+            wm = k_pos < (q_pos - self.sliding_window)
+            if wm.any():
+                window_mask = wm.squeeze(1)  # (1, q_len, key_len)
+            del q_pos, k_pos, wm
+
+        attn_output_parts = []
+        attn_weights_list = [] if output_attentions else None
+
+        for h in range(self.num_heads):
+            # (bsz, q_len, key_len)
+            q = query_states[:, h, :, :]
+            k = key_states[:, h, :, :]
+            v = value_states[:, h, :, :]
+
+            # Chunked attention computation to reduce peak memory usage
+            out_parts = []
+            attn_parts = [] if output_attentions else None
+
+            # A smaller chunk size reduces memory but may be slightly slower
+            chunk_size = 1024
+            for i in range(0, q.size(1), chunk_size):
+                q_chunk = q[:, i:i + chunk_size, :]
+
+                # attn_chunk has shape (bsz, chunk_size, key_len)
+                attn_chunk = torch.matmul(q_chunk, k.transpose(-2, -1))
+                attn_chunk.mul_(self.scaling)
+
+                # Apply masks to the chunk
+                if attention_mask is not None:
+                    attn_chunk.add_(attention_mask.squeeze(1)[:, i:i + chunk_size, :])
+
+                if window_mask is not None:
+                    attn_chunk.masked_fill_(window_mask[:, i:i + chunk_size, :], float("-inf"))
+
+                attn_chunk = torch.softmax(attn_chunk, dim=-1, dtype=torch.float32).to(query_dtype)
+
+                if self.training and self.attention_dropout > 0:
+                    attn_chunk = F.dropout(attn_chunk, p=self.attention_dropout, training=True)
+
+                if output_attentions:
+                    attn_parts.append(attn_chunk)
+
+                # output_chunk has shape (bsz, chunk_size, head_dim)
+                out_chunk = torch.matmul(attn_chunk, v)
+                out_parts.append(out_chunk)
+
+                del q_chunk, attn_chunk, out_chunk
+
+            out = torch.cat(out_parts, dim=1)
+            attn_output_parts.append(out)
+
+            if output_attentions:
+                attn = torch.cat(attn_parts, dim=1)
+                attn_weights_list.append(attn)
+                del attn, attn_parts
+
+            del q, k, v, out, out_parts
+
+        attn_output = torch.stack(attn_output_parts, dim=1)
+        del attn_output_parts, query_states, key_states, value_states

-        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        if self.training and self.attention_dropout > 0:
-            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)
+        attn_weights = torch.stack(attn_weights_list, dim=1) if output_attentions else None

-        attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None

         return attn_output, attn_weights
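
Separate from the patch itself: a quick way to convince yourself that the chunked per-head attention path is numerically equivalent to the original single-matmul path is to run both on random tensors and compare. The sketch below is a self-contained check, not repo code; every name in it (bsz, chunk_size, sliding_window, and so on) is local to the sketch, the sliding-window mask is built the same way as in the hunk above, and the causal/padding attention_mask is omitted because it does not affect the equivalence argument.

# Standalone sanity check (assumed names only, not part of the patch):
# chunked per-head attention vs. full attention on random tensors.
import torch

torch.manual_seed(0)
bsz, num_heads, q_len, head_dim = 2, 4, 2048, 64
chunk_size, sliding_window = 1024, 512

q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)
scaling = head_dim ** -0.5

# Sliding-window mask built as in the patch: keys more than
# `sliding_window` positions behind the query are masked out.
q_pos = torch.arange(q_len).view(1, 1, q_len, 1)
k_pos = torch.arange(q_len).view(1, 1, 1, q_len)
window_mask = (k_pos < (q_pos - sliding_window)).squeeze(1)  # (1, q_len, key_len)

# Reference: full attention in one shot.
attn = torch.matmul(q, k.transpose(-2, -1)) * scaling
attn = attn.masked_fill(window_mask.unsqueeze(1), float("-inf"))
attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
ref = torch.matmul(attn, v)

# Chunked: one head at a time, `chunk_size` query rows at a time.
parts = []
for h in range(num_heads):
    out_parts = []
    for i in range(0, q_len, chunk_size):
        scores = torch.matmul(q[:, h, i:i + chunk_size, :], k[:, h].transpose(-2, -1)) * scaling
        scores = scores.masked_fill(window_mask[:, i:i + chunk_size, :], float("-inf"))
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        out_parts.append(torch.matmul(probs, v[:, h]))
    parts.append(torch.cat(out_parts, dim=1))
chunked = torch.stack(parts, dim=1)

print(torch.allclose(ref, chunked, atol=1e-5))  # expected: True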
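The same kind of check applies to the chunked expert dispatch in MiniMaxM2SparseMoeBlock: accumulating index_add_ over slices of the routed-token indices gives the same result as one call over all routed tokens. A minimal sketch under simplifying assumptions, with a single nn.Linear standing in for an expert MLP and a per-token weight in place of the routing_weights[top_x, idx] gather:

# Standalone sketch (assumed names only, not part of the patch):
# chunked index_select/index_add_ dispatch vs. the unchunked version.
import torch

torch.manual_seed(0)
num_tokens, hidden_dim, chunk_size = 5000, 64, 1024

hidden_states = torch.randn(num_tokens, hidden_dim)
routing_weights = torch.rand(num_tokens, 1)
top_x = torch.randperm(num_tokens)[:3000]          # tokens routed to this expert
expert = torch.nn.Linear(hidden_dim, hidden_dim)   # placeholder expert MLP

with torch.no_grad():
    # Unchunked reference.
    ref = torch.zeros_like(hidden_states)
    out = expert(hidden_states.index_select(0, top_x)) * routing_weights[top_x]
    ref.index_add_(0, top_x, out)

    # Chunked version, as in the patch.
    acc = torch.zeros_like(hidden_states)
    for i in range(0, top_x.size(0), chunk_size):
        tx = top_x[i:i + chunk_size]
        chunk_out = expert(hidden_states.index_select(0, tx)) * routing_weights[tx]
        acc.index_add_(0, tx, chunk_out)

print(torch.allclose(ref, acc, atol=1e-6))  # expected: True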