
Commit a386df3

Optimize minimax m2 modelling forward pass (#2176)
* first 5 iterations
* opt.6
* opt.7
* opt.8-14 glm-4.6 + gpt
* opt.15 chunks for attn
1 parent 927de72 commit a386df3


gptqmodel/hf_minimax_m2/modeling_minimax_m2.py

Lines changed: 103 additions & 25 deletions
@@ -126,9 +126,9 @@ def __init__(self, config: MiniMaxM2Config) -> None:
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         gate = self.act_fn(self.w1(hidden_states))
         up = self.w3(hidden_states)
-        hidden_states = gate * up
-        hidden_states = self.w2(hidden_states)
-        return hidden_states
+        gate.mul_(up)
+        del up
+        return self.w2(gate)


 class MiniMaxM2SparseMoeBlock(nn.Module):
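
The rewritten MLP avoids materializing a third activation tensor: the out-of-place `gate * up` allocates a new buffer, while `gate.mul_(up)` reuses `gate`'s storage, which is safe because `gate` is not read again afterwards. A minimal standalone sketch of the same pattern (shapes and module names are illustrative, not taken from this file):

    import torch
    import torch.nn as nn

    w2 = nn.Linear(4096, 1024)
    gate = torch.randn(8, 4096)   # stands in for act_fn(w1(x))
    up = torch.randn(8, 4096)     # stands in for w3(x)

    # out-of-place: allocates an extra (8, 4096) buffer for the product
    hidden = gate * up
    y_ref = w2(hidden)

    # in-place: writes the product into gate's existing storage
    gate.mul_(up)
    del up
    y_opt = w2(gate)

    assert torch.allclose(y_ref, y_opt)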
@@ -168,7 +168,8 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
                 1.0 - self.jitter_noise,
                 1.0 + self.jitter_noise,
             )
-            hidden_states = hidden_states * noise
+            hidden_states.mul_(noise)
+            del noise

         hidden_states = hidden_states.view(-1, hidden_dim)
         gate_dtype = self.gate.weight.dtype
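
Since the jitter-noise branch only executes in training mode, one general caveat applies to switching it to an in-place `mul_`: autograd raises an error if an in-place op overwrites a tensor it saved for the backward pass. Whether that can happen here depends on how `hidden_states` was produced, which this note does not try to verify; the snippet below is only a generic illustration of the failure mode:

    import torch

    x = torch.randn(4, requires_grad=True)
    y = x.exp()        # autograd saves exp's output for its backward formula
    y.mul_(2.0)        # in-place update of a tensor autograd still needs
    try:
        y.sum().backward()
    except RuntimeError as err:
        print("in-place op invalidated a saved tensor:", err)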
@@ -188,7 +189,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens

         if correction_bias is not None:
             original_scores = scores
-            scores = scores + correction_bias
+            scores.add_(correction_bias)
         else:
             original_scores = scores
         topk_scores: torch.Tensor
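
One behavioural detail of the `scores.add_(correction_bias)` form: an in-place add is visible through every alias of the tensor, so a name bound to `scores` beforehand (as `original_scores` is in this branch) also observes the biased values, whereas the out-of-place `scores = scores + correction_bias` left the alias untouched. A tiny standalone illustration (not this model's code):

    import torch

    bias = torch.tensor([1.0, 1.0, 1.0])

    # out-of-place: rebinding `scores` leaves the alias unchanged
    scores = torch.tensor([0.1, 0.2, 0.3])
    original_scores = scores
    scores = scores + bias
    print(original_scores)   # tensor([0.1000, 0.2000, 0.3000])

    # in-place: the alias shares storage and sees the update
    scores = torch.tensor([0.1, 0.2, 0.3])
    original_scores = scores
    scores.add_(bias)
    print(original_scores)   # tensor([1.1000, 1.2000, 1.3000])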
@@ -216,24 +217,42 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
             routing_weights = original_scores.gather(1, selected_experts)
         else:
             routing_weights = topk_scores
+        del scores, original_scores, topk_scores

-        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
+        routing_weights.div_(routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12))
         if self.routed_scaling_factor != 1.0:
-            routing_weights = routing_weights * self.routed_scaling_factor
+            routing_weights.mul_(self.routed_scaling_factor)
         routing_weights = routing_weights.to(hidden_states.dtype)
         selected_experts = selected_experts.to(torch.long)

         final_hidden_states = torch.zeros_like(hidden_states)
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        del selected_experts
         expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()

+        # To further reduce memory, process tokens routed to each expert in chunks
+        # instead of all at once. A chunk size of 1024 is a reasonable default.
+        EXPERT_CHUNK_SIZE = 1024
+
         for expert_idx in expert_hit.tolist():
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            token_states = hidden_states.index_select(0, top_x)
-            expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
-            final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
+            idx_full, top_x_full = torch.where(expert_mask[expert_idx].squeeze(0))
+
+            for i in range(0, top_x_full.size(0), EXPERT_CHUNK_SIZE):
+                top_x = top_x_full[i : i + EXPERT_CHUNK_SIZE]
+                idx = idx_full[i : i + EXPERT_CHUNK_SIZE]
+
+                token_states = hidden_states.index_select(0, top_x)
+                expert_output = expert_layer(token_states)
+
+                weights = routing_weights[top_x, idx].unsqueeze(-1)
+                expert_output.mul_(weights)
+
+                final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
+                del expert_output, token_states, idx, top_x, weights

+            del idx_full, top_x_full
+        del hidden_states, routing_weights, expert_mask, expert_hit
         final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
         return final_hidden_states, router_logits

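
The expert loop now bounds its working set: rather than pushing every routed token through an expert at once, it slices the routed-token indices into blocks of EXPERT_CHUNK_SIZE and accumulates each block into final_hidden_states with index_add_, so the largest temporary per expert call is roughly chunk_size x hidden_dim. A condensed standalone sketch of that dispatch pattern (the expert is stubbed with a plain nn.Linear; all names here are illustrative):

    import torch
    import torch.nn as nn

    def dispatch_in_chunks(
        tokens: torch.Tensor,      # (num_tokens, hidden_dim)
        token_idx: torch.Tensor,   # indices of tokens routed to this expert
        weights: torch.Tensor,     # routing weights, same length as token_idx
        expert: nn.Module,
        out: torch.Tensor,         # (num_tokens, hidden_dim), accumulated in place
        chunk_size: int = 1024,
    ) -> None:
        for i in range(0, token_idx.numel(), chunk_size):
            sel = token_idx[i : i + chunk_size]
            chunk = expert(tokens.index_select(0, sel))
            chunk.mul_(weights[i : i + chunk_size].unsqueeze(-1))
            out.index_add_(0, sel, chunk.to(out.dtype))

    tokens = torch.randn(4096, 64)
    out = torch.zeros_like(tokens)
    expert = nn.Linear(64, 64)
    routed = torch.randint(0, 4096, (3000,))
    w = torch.rand(3000)
    with torch.no_grad():
        dispatch_in_chunks(tokens, routed, w, expert, out)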

@@ -302,11 +321,15 @@ def forward(
         output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         bsz, q_len, _ = hidden_states.size()
+        device = hidden_states.device

+        # projections
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        del hidden_states

+        # optional QK normalization
         if self.use_qk_norm:
             q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
             k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
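
The `del hidden_states` after the projections drops the last Python reference to the input activations, letting PyTorch's caching allocator reuse that block before the attention matmuls run. One rough way to check whether changes like these actually lower the footprint is to track the allocator's peak around a forward pass (CUDA only; `model` and `input_ids` below are placeholders, not objects defined in this file):

    import torch

    def measure_peak_mib(fn, *args, **kwargs):
        # peak CUDA memory (MiB) observed while running fn once
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            fn(*args, **kwargs)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / 1024 ** 2

    # e.g. peak = measure_peak_mib(model, input_ids=input_ids)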
@@ -315,6 +338,7 @@ def forward(
             query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

+        # rotary embeddings
         if position_embeddings is None:
             cos, sin = self.rotary_emb(value_states, position_ids)
         else:
@@ -326,34 +350,88 @@ def forward(
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)

+        # handle cache
         if past_key_values is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        query_dtype = query_states.dtype
+        key_len = key_states.shape[-2]

+        # precompute sliding-window mask
+        window_mask = None
         if self.sliding_window is not None and past_key_values is None:
-            query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1)
-            key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1)
-            window_mask = key_positions < (query_positions - self.sliding_window)
-            if window_mask.any():
-                attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))
+            q_pos = torch.arange(q_len, device=device).view(1, 1, q_len, 1)
+            k_pos = torch.arange(key_len, device=device).view(1, 1, 1, key_len)
+            wm = k_pos < (q_pos - self.sliding_window)
+            if wm.any():
+                window_mask = wm.squeeze(1) # (1, q_len, key_len)
+            del q_pos, k_pos, wm
+
+        attn_output_parts = []
+        attn_weights_list = [] if output_attentions else None
+
+        for h in range(self.num_heads):
+            # (bsz, q_len, key_len)
+            q = query_states[:, h, :, :]
+            k = key_states[:, h, :, :]
+            v = value_states[:, h, :, :]
+
+            # Chunked attention computation to reduce peak memory usage
+            out_parts = []
+            attn_parts = [] if output_attentions else None
+
+            # A smaller chunk size reduces memory but may be slightly slower
+            chunk_size = 1024
+            for i in range(0, q.size(1), chunk_size):
+                q_chunk = q[:, i:i + chunk_size, :]
+
+                # attn_chunk has shape (bsz, chunk_size, key_len)
+                attn_chunk = torch.matmul(q_chunk, k.transpose(-2, -1))
+                attn_chunk.mul_(self.scaling)
+
+                # Apply masks to the chunk
+                if attention_mask is not None:
+                    attn_chunk.add_(attention_mask.squeeze(1)[:, i:i + chunk_size, :])
+
+                if window_mask is not None:
+                    attn_chunk.masked_fill_(window_mask[:, i:i + chunk_size, :], float("-inf"))
+
+                attn_chunk = torch.softmax(attn_chunk, dim=-1, dtype=torch.float32).to(query_dtype)
+
+                if self.training and self.attention_dropout > 0:
+                    attn_chunk = F.dropout(attn_chunk, p=self.attention_dropout, training=True)
+
+                if output_attentions:
+                    attn_parts.append(attn_chunk)
+
+                # output_chunk has shape (bsz, chunk_size, head_dim)
+                out_chunk = torch.matmul(attn_chunk, v)
+                out_parts.append(out_chunk)
+
+                del q_chunk, attn_chunk, out_chunk
+
+            out = torch.cat(out_parts, dim=1)
+            attn_output_parts.append(out)
+
+            if output_attentions:
+                attn = torch.cat(attn_parts, dim=1)
+                attn_weights_list.append(attn)
+                del attn, attn_parts
+
+            del q, k, v, out, out_parts
+
+        attn_output = torch.stack(attn_output_parts, dim=1)
+        del attn_output_parts, query_states, key_states, value_states

-        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        if self.training and self.attention_dropout > 0:
-            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)
+        attn_weights = torch.stack(attn_weights_list, dim=1) if output_attentions else None

-        attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)

-        if not output_attentions:
-            attn_weights = None
         return attn_output, attn_weights

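
The rewritten attention path replaces the single (bsz, num_heads, q_len, key_len) score tensor with per-head, per-query-chunk tiles, so the largest temporary is on the order of (bsz, chunk_size, key_len). A condensed single-head sketch of that chunking scheme, with masks and dropout omitted (names are illustrative, not this module's API):

    import torch

    def chunked_attention(q, k, v, scaling, chunk_size=1024):
        # q: (bsz, q_len, head_dim), k/v: (bsz, key_len, head_dim)
        out_parts = []
        for i in range(0, q.size(1), chunk_size):
            q_chunk = q[:, i:i + chunk_size, :]
            scores = torch.matmul(q_chunk, k.transpose(-2, -1))  # (bsz, chunk, key_len)
            scores.mul_(scaling)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
            out_parts.append(torch.matmul(probs, v))              # (bsz, chunk, head_dim)
            del scores, probs
        return torch.cat(out_parts, dim=1)                        # (bsz, q_len, head_dim)

    # agrees (up to floating-point error) with the one-shot computation
    scaling = 1.0 / 64 ** 0.5
    q, k, v = (torch.randn(2, 2048, 64) for _ in range(3))
    ref = torch.softmax(q @ k.transpose(-2, -1) * scaling, dim=-1) @ v
    out = chunked_attention(q, k, v, scaling, chunk_size=512)
    assert torch.allclose(ref, out, atol=1e-5)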
