@@ -126,9 +126,9 @@ def __init__(self, config: MiniMaxM2Config) -> None:
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         gate = self.act_fn(self.w1(hidden_states))
         up = self.w3(hidden_states)
-        hidden_states = gate * up
-        hidden_states = self.w2(hidden_states)
-        return hidden_states
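+        # multiply gate by up in place to avoid allocating a separate product tensor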
+        gate.mul_(up)
+        del up
+        return self.w2(gate)


 class MiniMaxM2SparseMoeBlock(nn.Module):
@@ -168,7 +168,8 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
                 1.0 - self.jitter_noise,
                 1.0 + self.jitter_noise,
             )
-            hidden_states = hidden_states * noise
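+            # apply the jitter in place and free the noise tensor immediately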
+            hidden_states.mul_(noise)
+            del noise

         hidden_states = hidden_states.view(-1, hidden_dim)
         gate_dtype = self.gate.weight.dtype
@@ -188,7 +189,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens

         if correction_bias is not None:
             original_scores = scores
-            scores = scores + correction_bias
+            # kept out of place on purpose: `original_scores` aliases `scores`, so an
+            # in-place add_ here would also overwrite the unbiased scores gathered below
+            scores = scores + correction_bias
         else:
             original_scores = scores
         topk_scores: torch.Tensor
@@ -216,24 +217,42 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
             routing_weights = original_scores.gather(1, selected_experts)
         else:
             routing_weights = topk_scores
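+        # the raw router score tensors are no longer needed once routing_weights is chosen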
+        del scores, original_scores, topk_scores

-        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
+        routing_weights.div_(routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12))
         if self.routed_scaling_factor != 1.0:
-            routing_weights = routing_weights * self.routed_scaling_factor
+            routing_weights.mul_(self.routed_scaling_factor)
         routing_weights = routing_weights.to(hidden_states.dtype)
         selected_experts = selected_experts.to(torch.long)

         final_hidden_states = torch.zeros_like(hidden_states)
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        del selected_experts
         expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()

+        # To further reduce memory, process tokens routed to each expert in chunks
+        # instead of all at once. A chunk size of 1024 is a reasonable default.
+        EXPERT_CHUNK_SIZE = 1024
+
         for expert_idx in expert_hit.tolist():
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            token_states = hidden_states.index_select(0, top_x)
-            expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
-            final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
+            idx_full, top_x_full = torch.where(expert_mask[expert_idx].squeeze(0))
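+            # `idx_full` holds the top-k slot and `top_x_full` the flattened token index for this expert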
+
+            for i in range(0, top_x_full.size(0), EXPERT_CHUNK_SIZE):
+                top_x = top_x_full[i : i + EXPERT_CHUNK_SIZE]
+                idx = idx_full[i : i + EXPERT_CHUNK_SIZE]
+
+                token_states = hidden_states.index_select(0, top_x)
+                expert_output = expert_layer(token_states)
+
+                weights = routing_weights[top_x, idx].unsqueeze(-1)
+                expert_output.mul_(weights)
+
+                final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
+                del expert_output, token_states, idx, top_x, weights

+            del idx_full, top_x_full
+        del hidden_states, routing_weights, expert_mask, expert_hit
         final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
         return final_hidden_states, router_logits

@@ -302,11 +321,15 @@ def forward(
         output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         bsz, q_len, _ = hidden_states.size()
+        device = hidden_states.device

+        # projections
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        del hidden_states

+        # optional QK normalization
         if self.use_qk_norm:
             q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
             k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
@@ -315,6 +338,7 @@ def forward(
             query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

+        # rotary embeddings
         if position_embeddings is None:
             cos, sin = self.rotary_emb(value_states, position_ids)
         else:
@@ -326,34 +350,88 @@ def forward(
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)

+        # handle cache
         if past_key_values is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        query_dtype = query_states.dtype
+        key_len = key_states.shape[-2]

+        # precompute sliding-window mask
+        window_mask = None
         if self.sliding_window is not None and past_key_values is None:
-            query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1)
-            key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1)
-            window_mask = key_positions < (query_positions - self.sliding_window)
-            if window_mask.any():
-                attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))
+            q_pos = torch.arange(q_len, device=device).view(1, 1, q_len, 1)
+            k_pos = torch.arange(key_len, device=device).view(1, 1, 1, key_len)
+            wm = k_pos < (q_pos - self.sliding_window)
+            if wm.any():
+                window_mask = wm.squeeze(1)  # (1, q_len, key_len)
+            del q_pos, k_pos, wm
+
+        attn_output_parts = []
+        attn_weights_list = [] if output_attentions else None
+
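+        # process one head at a time so that only a single head's score chunks are live at once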
+        for h in range(self.num_heads):
+            # per-head slices of shape (bsz, q_len, head_dim)
+            q = query_states[:, h, :, :]
+            k = key_states[:, h, :, :]
+            v = value_states[:, h, :, :]
+
+            # Chunked attention computation to reduce peak memory usage
+            out_parts = []
+            attn_parts = [] if output_attentions else None
+
+            # A smaller chunk size reduces memory but may be slightly slower
+            chunk_size = 1024
+            for i in range(0, q.size(1), chunk_size):
+                q_chunk = q[:, i : i + chunk_size, :]
+
+                # attn_chunk has shape (bsz, chunk_size, key_len)
+                attn_chunk = torch.matmul(q_chunk, k.transpose(-2, -1))
+                attn_chunk.mul_(self.scaling)
+
+                # Apply masks to the chunk
+                if attention_mask is not None:
+                    attn_chunk.add_(attention_mask.squeeze(1)[:, i : i + chunk_size, :])
+
+                if window_mask is not None:
+                    attn_chunk.masked_fill_(window_mask[:, i : i + chunk_size, :], float("-inf"))
+
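+                # softmax in float32 for numerical stability, then cast back to the query dtype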
+                attn_chunk = torch.softmax(attn_chunk, dim=-1, dtype=torch.float32).to(query_dtype)
+
+                if self.training and self.attention_dropout > 0:
+                    attn_chunk = F.dropout(attn_chunk, p=self.attention_dropout, training=True)
+
+                if output_attentions:
+                    attn_parts.append(attn_chunk)
+
+                # out_chunk has shape (bsz, chunk_size, head_dim)
+                out_chunk = torch.matmul(attn_chunk, v)
+                out_parts.append(out_chunk)
+
+                del q_chunk, attn_chunk, out_chunk
+
+            out = torch.cat(out_parts, dim=1)
+            attn_output_parts.append(out)
+
+            if output_attentions:
+                attn = torch.cat(attn_parts, dim=1)
+                attn_weights_list.append(attn)
+                del attn, attn_parts
+
+            del q, k, v, out, out_parts
+
+        attn_output = torch.stack(attn_output_parts, dim=1)
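+        # attn_output now has shape (bsz, num_heads, q_len, head_dim)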
+        del attn_output_parts, query_states, key_states, value_states

-        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        if self.training and self.attention_dropout > 0:
-            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)
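+        # full per-head attention maps are only materialized when the caller requests them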
+        attn_weights = torch.stack(attn_weights_list, dim=1) if output_attentions else None

-        attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)

-        if not output_attentions:
-            attn_weights = None
         return attn_output, attn_weights
