Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions vllm/distributed/eplb/rebalance_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
on how the EPLB algorithm works.
"""

import numpy as np
import torch


Expand All @@ -34,29 +35,44 @@ def balanced_packing(
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs

device = weight.device

if groups_per_pack == 1:
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=weight.device
weight.size(-1), dtype=torch.int64, device=device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
return pack_index, rank_in_pack

indices = weight.float().sort(-1, descending=True).indices.cpu()
pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
weight_np = weight.cpu().numpy()

    # Sort and get indices in descending order
indices_np = np.argsort(-weight_np, axis=-1)

pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)

# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0] * num_packs
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for group in indices[i]:

for group in indices_np[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(i for i in range(num_packs) if pack_items[i] < groups_per_pack),
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)

assert pack_items[pack] < groups_per_pack
pack_index[i, group] = pack
rank_in_pack[i, group] = pack_items[pack]
pack_weights[pack] += weight[i, group]
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_items[pack] += 1

pack_index = torch.from_numpy(pack_index_np).to(device)
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)

return pack_index, rank_in_pack


Expand Down Expand Up @@ -212,7 +228,7 @@ def rebalance_experts(
replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.float().cpu()
weight = weight.float()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
Expand Down
14 changes: 9 additions & 5 deletions vllm/distributed/eplb/rebalance_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,19 @@ def rearrange_expert_weights_inplace(
)
return

old_global_expert_indices_cpu = old_global_expert_indices.cpu()
new_global_expert_indices_cpu = new_global_expert_indices.cpu()

# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()

for layer in range(num_moe_layers):
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
shuffle_layer(
num_local_physical_experts,
ep_rank,
old_global_expert_indices[layer].tolist(),
new_global_expert_indices[layer].tolist(),
old_global_expert_indices_cpu[layer].tolist(),
new_global_expert_indices_cpu[layer].tolist(),
expert_weights[layer],
expert_weights_buffer,
ep_group,
Expand Down