Commit a5d52a7

[Misc] Change default device for vllm_get_model (#1116)
Parent: 74f70aa
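
This commit moves the jax.default_device scope out of the FusedMoE weight-processing path and up to the model-load call site, so arrays materialized while vllm_get_model builds the model already default to the first device of the mesh. Below is a minimal, self-contained sketch of the JAX mechanism being relied on (illustrative only; the device choice and array shape are not from this repository):

import jax
import jax.numpy as jnp

# Pick a device explicitly; in this commit it is self.mesh.devices.flatten()[0].
first_device = jax.devices()[0]

with jax.default_device(first_device):
    # Arrays created in this scope default to first_device instead of the
    # process-wide default; they can still be resharded afterwards with
    # jax.device_put, as process_weights_after_loading does for the MoE weights.
    w = jnp.zeros((8, 128))

print(w.devices())  # {first_device}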

2 files changed: +95 -106 lines

tpu_inference/layers/vllm/quantization/unquantized.py

Lines changed: 93 additions & 105 deletions
@@ -191,131 +191,119 @@ def select_gemm_impl(

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        assert isinstance(layer, FusedMoE)
-        available_devices = self.mesh.devices.flatten()
-        with jax.default_device(available_devices[0]):
-            w13_weight = t2j(layer.w13_weight, use_dlpack=False)
-            w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)

-            if self.moe.has_bias:
-                w13_bias = t2j(layer.w13_bias, use_dlpack=False)
-                w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-            if layer.activation == "swigluoai":
-                # When using swigluoai, vLLM splits gmm output in a interleaved way.
-                # However, interleaved split is not performant on TPU. Therefore,
-                # we preprocess the weight so that splitting gmm output by middle
-                # can still get the same result.
-                w1_weight = w13_weight[:, ::2, :]
-                w3_weight = w13_weight[:, 1::2, :]
-                w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+        if self.moe.has_bias:
+            w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+            w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits gmm output in a interleaved way.
+            # However, interleaved split is not performant on TPU. Therefore,
+            # we preprocess the weight so that splitting gmm output by middle
+            # can still get the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)

-                if self.moe.has_bias:
-                    w1_bias = w13_bias[:, ::2]
-                    w3_bias = w13_bias[:, 1::2]
-                    w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-            if self.use_kernel and layer.use_ep:
-                # Kernel expects:
-                # w13: (num_experts, 2, hidden_size, intermediate_size)
-                # w2: (num_experts, intermediate_size, hidden_size)
-                # Current format:
-                # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-                # w2_weight: (num_experts, hidden_size, intermediate_size)
-                num_experts = w13_weight.shape[0]
-                intermediate_size = w13_weight.shape[1] // 2
-                hidden_size = w13_weight.shape[2]
+            if self.moe.has_bias:
+                w1_bias = w13_bias[:, ::2]
+                w3_bias = w13_bias[:, 1::2]
+                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)

-                # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
-                w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                                  intermediate_size,
-                                                  hidden_size)
-                w13_weight_transposed = jnp.transpose(w13_reshaped,
-                                                      (0, 1, 3, 2))
+        if self.use_kernel and layer.use_ep:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
+
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+
+            # Apply EP sharding
+            w13_weight = jax.device_put(
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
+            w2_weight = jax.device_put(
+                w2_weight_transposed,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))

-                # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-                w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)

                # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            # Original logic for non-kernel path
+            if layer.use_ep:
                w13_weight = jax.device_put(
-                    w13_weight_transposed,
-                    Format(
-                        Layout((0, 1, 2, 3)),
-                        NamedSharding(self.mesh, P("model", None, None,
-                                                   None))))
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
                w2_weight = jax.device_put(
-                    w2_weight_transposed,
+                    w2_weight,
                    Format(Layout((0, 1, 2)),
                           NamedSharding(self.mesh, P("model", None, None))))

                if self.moe.has_bias:
-                    w13_bias = w13_bias.reshape(num_experts, 2,
-                                                intermediate_size)
-
-                    # Apply EP sharding
                    w13_bias = jax.device_put(
                        w13_bias,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
                    w2_bias = jax.device_put(
                        w2_bias,
                        Format(Layout((0, 1)),
                               NamedSharding(self.mesh, P("model", None))))

            else:
-                # Original logic for non-kernel path
-                if layer.use_ep:
-                    w13_weight = jax.device_put(
-                        w13_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
-                    w2_weight = jax.device_put(
-                        w2_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
-
-                    if self.moe.has_bias:
-                        w13_bias = jax.device_put(
-                            w13_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P("model", None))))
-                        w2_bias = jax.device_put(
-                            w2_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P("model", None))))
-
-                else:
-                    intermediate_size = w13_weight.shape[1] // 2
-                    assert intermediate_size == w2_weight.shape[-1]
-                    output_sizes = [intermediate_size, intermediate_size]
-                    n_shards = self.mesh.shape["model"]
-                    assert intermediate_size % n_shards == 0
-                    w13_weight = reorder_concatenated_tensor_for_sharding(
-                        w13_weight, output_sizes, n_shards, dim=1)
-                    w13_weight = jax.device_put(
-                        w13_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, "model", None))))
-                    w2_weight = jax.device_put(
-                        w2_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, None, "model"))))
-
-                    if self.moe.has_bias:
-                        w13_bias = reorder_concatenated_tensor_for_sharding(
-                            w13_bias, output_sizes, n_shards, dim=1)
-                        w13_bias = jax.device_put(
-                            w13_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P(None, "model"))))
-                        w2_bias = jax.device_put(
-                            w2_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P(None, None))))
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight, output_sizes, n_shards, dim=1)
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                if self.moe.has_bias:
+                    w13_bias = reorder_concatenated_tensor_for_sharding(
+                        w13_bias, output_sizes, n_shards, dim=1)
+                    w13_bias = jax.device_put(
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, "model"))))
+                    w2_bias = jax.device_put(
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, None))))

        layer.w13_weight = Parameter(torch_view(w13_weight),
                                     requires_grad=False)
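
Most of the hunk above is re-indentation: with the jax.default_device wrapper removed, the body of process_weights_after_loading moves out one level while the bias handling and sharding logic stay the same. A condensed view of the net change (context elided in comments):

# Before: weight conversion ran inside a per-call default_device scope.
available_devices = self.mesh.devices.flatten()
with jax.default_device(available_devices[0]):
    w13_weight = t2j(layer.w13_weight, use_dlpack=False)
    w2_weight = t2j(layer.w2_weight, use_dlpack=False)
    # ... bias handling and sharding, one level deeper

# After: the caller sets the default device (see vllm_model_wrapper.py below),
# and the same body runs one indentation level shallower.
w13_weight = t2j(layer.w13_weight, use_dlpack=False)
w2_weight = t2j(layer.w2_weight, use_dlpack=False)
# ... identical bias handling and sharding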

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 2 additions & 1 deletion
@@ -120,7 +120,8 @@ def load_weights(self):

        # Load the vLLM model and wrap it into a new model whose forward
        # function can calculate the hidden_state and logits.
-        with load_context:
+        available_devices = self.mesh.devices.flatten()
+        with load_context, jax.default_device(available_devices[0]):
            vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
        lora_manager = None
        if vllm_config_for_load.lora_config is not None:
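
The combined with statement added here enters both context managers left to right, so vllm_get_model runs with the load context active and with the first mesh device as the default placement for newly created arrays. An equivalent nested form, shown only to spell out the Python semantics (not code from the repository):

available_devices = self.mesh.devices.flatten()
with load_context:
    with jax.default_device(available_devices[0]):
        vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)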

0 commit comments
