@@ -191,131 +191,119 @@ def select_gemm_impl(
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
-        available_devices = self.mesh.devices.flatten()
-        with jax.default_device(available_devices[0]):
-            w13_weight = t2j(layer.w13_weight, use_dlpack=False)
-            w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
-            if self.moe.has_bias:
-                w13_bias = t2j(layer.w13_bias, use_dlpack=False)
-                w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-            if layer.activation == "swigluoai":
-                # When using swigluoai, vLLM splits the gmm output in an interleaved way.
-                # However, an interleaved split is not performant on TPU. Therefore,
-                # we preprocess the weight so that splitting the gmm output down the
-                # middle still gives the same result.
-                w1_weight = w13_weight[:, ::2, :]
-                w3_weight = w13_weight[:, 1::2, :]
-                w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+        if self.moe.has_bias:
+            w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+            w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits the gmm output in an interleaved way.
+            # However, an interleaved split is not performant on TPU. Therefore,
+            # we preprocess the weight so that splitting the gmm output down the
+            # middle still gives the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
 
-                if self.moe.has_bias:
-                    w1_bias = w13_bias[:, ::2]
-                    w3_bias = w13_bias[:, 1::2]
-                    w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-            if self.use_kernel and layer.use_ep:
-                # Kernel expects:
-                # w13: (num_experts, 2, hidden_size, intermediate_size)
-                # w2: (num_experts, intermediate_size, hidden_size)
-                # Current format:
-                # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-                # w2_weight: (num_experts, hidden_size, intermediate_size)
-                num_experts = w13_weight.shape[0]
-                intermediate_size = w13_weight.shape[1] // 2
-                hidden_size = w13_weight.shape[2]
+            if self.moe.has_bias:
+                w1_bias = w13_bias[:, ::2]
+                w3_bias = w13_bias[:, 1::2]
+                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-                # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
-                w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                                  intermediate_size,
-                                                  hidden_size)
-                w13_weight_transposed = jnp.transpose(w13_reshaped,
-                                                      (0, 1, 3, 2))
+        if self.use_kernel and layer.use_ep:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
+
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+
+            # Apply EP sharding
+            w13_weight = jax.device_put(
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
+            w2_weight = jax.device_put(
+                w2_weight_transposed,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
 
-                # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-                w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
 
                 # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            # Original logic for non-kernel path
+            if layer.use_ep:
                 w13_weight = jax.device_put(
-                    w13_weight_transposed,
-                    Format(
-                        Layout((0, 1, 2, 3)),
-                        NamedSharding(self.mesh, P("model", None, None,
-                                                   None))))
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
                 w2_weight = jax.device_put(
-                    w2_weight_transposed,
+                    w2_weight,
                     Format(Layout((0, 1, 2)),
                            NamedSharding(self.mesh, P("model", None, None))))
 
                 if self.moe.has_bias:
-                    w13_bias = w13_bias.reshape(num_experts, 2,
-                                                intermediate_size)
-
-                    # Apply EP sharding
                     w13_bias = jax.device_put(
                         w13_bias,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
                     w2_bias = jax.device_put(
                         w2_bias,
                         Format(Layout((0, 1)),
                                NamedSharding(self.mesh, P("model", None))))
 
             else:
-                # Original logic for non-kernel path
-                if layer.use_ep:
-                    w13_weight = jax.device_put(
-                        w13_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
-                    w2_weight = jax.device_put(
-                        w2_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
-
-                    if self.moe.has_bias:
-                        w13_bias = jax.device_put(
-                            w13_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P("model", None))))
-                        w2_bias = jax.device_put(
-                            w2_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P("model", None))))
-
-                else:
-                    intermediate_size = w13_weight.shape[1] // 2
-                    assert intermediate_size == w2_weight.shape[-1]
-                    output_sizes = [intermediate_size, intermediate_size]
-                    n_shards = self.mesh.shape["model"]
-                    assert intermediate_size % n_shards == 0
-                    w13_weight = reorder_concatenated_tensor_for_sharding(
-                        w13_weight, output_sizes, n_shards, dim=1)
-                    w13_weight = jax.device_put(
-                        w13_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, "model", None))))
-                    w2_weight = jax.device_put(
-                        w2_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, None, "model"))))
-
-                    if self.moe.has_bias:
-                        w13_bias = reorder_concatenated_tensor_for_sharding(
-                            w13_bias, output_sizes, n_shards, dim=1)
-                        w13_bias = jax.device_put(
-                            w13_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P(None, "model"))))
-                        w2_bias = jax.device_put(
-                            w2_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P(None, None))))
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight, output_sizes, n_shards, dim=1)
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                if self.moe.has_bias:
+                    w13_bias = reorder_concatenated_tensor_for_sharding(
+                        w13_bias, output_sizes, n_shards, dim=1)
+                    w13_bias = jax.device_put(
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, "model"))))
+                    w2_bias = jax.device_put(
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
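
Side note, not part of the diff: a minimal, self-contained sketch of why the swigluoai preprocessing above is safe. It works on a made-up 2-D slice of `w13` (one expert), de-interleaves the gate/up rows once, and checks that splitting the matmul output down the middle matches the interleaved split vLLM would otherwise take. All names, shapes, and values here are illustrative only.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
hidden_size, intermediate_size = 8, 4
x = jnp.asarray(rng.normal(size=(3, hidden_size)), dtype=jnp.float32)
# w13 as loaded: gate/up rows interleaved along the output dimension.
w13 = jnp.asarray(
    rng.normal(size=(2 * intermediate_size, hidden_size)), dtype=jnp.float32)

# Path 1: multiply first, then take interleaved slices of the output.
out = x @ w13.T
gate_ref, up_ref = out[:, ::2], out[:, 1::2]

# Path 2: de-interleave the weight once (as process_weights_after_loading does),
# then split the output down the middle.
w13_reordered = jnp.concatenate([w13[::2, :], w13[1::2, :]], axis=0)
out2 = x @ w13_reordered.T
gate, up = out2[:, :intermediate_size], out2[:, intermediate_size:]

assert jnp.allclose(gate_ref, gate) and jnp.allclose(up_ref, up)
```

The diff applies the same reordering per expert along axis 1 of the 3-D `(num_experts, 2 * intermediate_size, hidden_size)` tensor, so the check carries over expert by expert.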
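Also not from the PR: a rough illustration of what the `P("model", ...)` partition specs used above mean, written against the public `jax.sharding` API only (the `Format`/`Layout` wrapper from the diff is left out). The leading expert axis is split across a 1-D "model" mesh axis and the remaining axes are replicated; the mesh construction and shapes below are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever devices are available; the real code builds its mesh elsewhere.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Illustrative sizes; assumes num_experts is divisible by the device count.
num_experts, hidden_size, intermediate_size = 8, 16, 4
w13 = jnp.zeros((num_experts, 2, hidden_size, intermediate_size))

# Expert parallelism: shard dim 0 (experts) over "model", replicate the rest,
# mirroring P("model", None, None, None) in the kernel path above.
w13_ep = jax.device_put(w13, NamedSharding(mesh, P("model", None, None, None)))
print(w13_ep.sharding)  # each device holds num_experts // mesh.size experts
```

With that placement, each device only materializes its local experts' slice of `w13`, which is what the expert-parallel kernel path relies on.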