From 5d1a968ae3b2e6ce021d96b2a3d7b7caa333217d Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Thu, 6 Nov 2025 20:45:55 +0000
Subject: [PATCH 1/5] init

Signed-off-by: Sage Moore
---
 vllm/distributed/eplb/eplb_state.py        | 32 +++++++++++-----------
 vllm/distributed/eplb/rebalance_algo.py    | 10 +++----
 vllm/distributed/eplb/rebalance_execute.py | 14 ++++++----
 3 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 17716e8a07ac..a0ed3a2aafc4 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -446,14 +446,14 @@ def rearrange(
         """
 
         ep_group = get_ep_group().device_group
-        ep_rank = ep_group.rank()
+        # ep_rank = ep_group.rank()
 
-        time_start = None
-        is_main_rank = ep_rank == 0
-        if is_main_rank:
-            torch.cuda.synchronize()
-            time_start = time.perf_counter()
-            logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
+        # time_start = None
+        # is_main_rank = ep_rank == 0
+        # if is_main_rank:
+        #     torch.cuda.synchronize()
+        #     time_start = time.perf_counter()
+        #     logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
 
         if global_expert_load is None:
             # Map the physical expert load to global logical experts
@@ -569,15 +569,15 @@ def rearrange(
         self.logical_to_physical_map.copy_(new_logical_to_physical_map)
         self.logical_replica_count.copy_(new_logical_replica_count)
 
-        if is_main_rank:
-            assert time_start is not None
-            torch.cuda.synchronize()
-            time_end = time.perf_counter()
-            logger.info(
-                "Rearranged experts%sin %.2f seconds.",
-                " (profile) " if is_profile else " ",
-                time_end - time_start,
-            )
+        # if is_main_rank:
+        #     assert time_start is not None
+        #     torch.cuda.synchronize()
+        #     time_end = time.perf_counter()
+        #     logger.info(
+        #         "Rearranged experts%sin %.2f seconds.",
+        #         " (profile) " if is_profile else " ",
+        #         time_end - time_start,
+        #     )
         return None
 
     @staticmethod
diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py
index c9d30d6481ab..487caa663bd4 100644
--- a/vllm/distributed/eplb/rebalance_algo.py
+++ b/vllm/distributed/eplb/rebalance_algo.py
@@ -38,12 +38,12 @@ def balanced_packing(
         pack_index = torch.arange(
             weight.size(-1), dtype=torch.int64, device=weight.device
         ).expand(weight.shape)
-        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
+        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=weight.device)
         return pack_index, rank_in_pack
 
-    indices = weight.float().sort(-1, descending=True).indices.cpu()
-    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
-    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
+    indices = weight.float().sort(-1, descending=True).indices
+    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device=weight.device)
+    rank_in_pack = torch.full_like(pack_index, fill_value=-1, device=weight.device)
     for i in range(num_layers):
         pack_weights = [0] * num_packs
         pack_items = [0] * num_packs
@@ -212,7 +212,7 @@ def rebalance_experts(
         replicas for each logical expert
     """
     num_layers, num_logical_experts = weight.shape
-    weight = weight.float().cpu()
+    weight = weight.float()
     if num_groups % num_nodes == 0:
         # use hierarchical load-balance policy
         phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py
index f8ec3e956401..5c1efbaf03ba 100644
--- a/vllm/distributed/eplb/rebalance_execute.py
+++ b/vllm/distributed/eplb/rebalance_execute.py
@@ -321,15 +321,19 @@ def rearrange_expert_weights_inplace(
         )
         return
 
+    old_global_expert_indices_cpu = old_global_expert_indices.cpu()
+    new_global_expert_indices_cpu = new_global_expert_indices.cpu()
+
+    # NOTE(bowen): We need this synchronize to run, but I don't know why.
+    # If you figure out the reason, please let me know -- thank you!
+    torch.cuda.synchronize()
+
     for layer in range(num_moe_layers):
-        # NOTE(bowen): We need this synchronize to run, but I don't know why.
-        # If you figure out the reason, please let me know -- thank you!
-        torch.cuda.synchronize()
         shuffle_layer(
             num_local_physical_experts,
             ep_rank,
-            old_global_expert_indices[layer].tolist(),
-            new_global_expert_indices[layer].tolist(),
+            old_global_expert_indices_cpu[layer].tolist(),
+            new_global_expert_indices_cpu[layer].tolist(),
             expert_weights[layer],
             expert_weights_buffer,
             ep_group,

From c9c5f29f501f12b8ae7a08f470b7110e5268d895 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Fri, 7 Nov 2025 18:05:04 +0000
Subject: [PATCH 2/5] init

Signed-off-by: Sage Moore
---
 vllm/distributed/eplb/rebalance_algo.py | 44 +++++++++++++++++--------
 1 file changed, 31 insertions(+), 13 deletions(-)

diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py
index 487caa663bd4..71fac5985895 100644
--- a/vllm/distributed/eplb/rebalance_algo.py
+++ b/vllm/distributed/eplb/rebalance_algo.py
@@ -12,6 +12,7 @@
 on how the EPLB algorithm works.
 """
 
+import numpy as np
 import torch
 
 
@@ -21,11 +22,9 @@ def balanced_packing(
     """
     Pack n weighted objects to m packs, such that each bin contains exactly
     n/m objects and the weights of all packs are as balanced as possible.
-
     Parameters:
         weight: [X, n], the weight of each item
         num_packs: number of packs
-
     Returns:
         pack_index: [X, n], the pack index of each item
         rank_in_pack: [X, n], the rank of the item in the pack
@@ -34,29 +33,48 @@ def balanced_packing(
     assert num_groups % num_packs == 0
     groups_per_pack = num_groups // num_packs
 
+    device = weight.device
+
+    # Handle trivial case before conversion
     if groups_per_pack == 1:
         pack_index = torch.arange(
-            weight.size(-1), dtype=torch.int64, device=weight.device
+            weight.size(-1), dtype=torch.int64, device=device
         ).expand(weight.shape)
-        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=weight.device)
+        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
         return pack_index, rank_in_pack
 
-    indices = weight.float().sort(-1, descending=True).indices
-    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device=weight.device)
-    rank_in_pack = torch.full_like(pack_index, fill_value=-1, device=weight.device)
+    # Convert to NumPy for CPU processing
+    weight_np = weight.cpu().numpy()
+
+    # Sort and get indices
+    indices_np = np.argsort(-weight_np, axis=-1)  # Descending order
+
+    # Initialize output arrays
+    pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
+    rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
+
+    # Run the packing algorithm
     for i in range(num_layers):
-        pack_weights = [0] * num_packs
+        pack_weights = [0.0] * num_packs
         pack_items = [0] * num_packs
-        for group in indices[i]:
+
+        for group in indices_np[i]:
+            # Find pack with minimum weight that still has capacity
             pack = min(
-                (i for i in range(num_packs) if pack_items[i] < groups_per_pack),
+                (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
                 key=pack_weights.__getitem__,
             )
+
             assert pack_items[pack] < groups_per_pack
-            pack_index[i, group] = pack
-            rank_in_pack[i, group] = pack_items[pack]
-            pack_weights[pack] += weight[i, group]
+            pack_index_np[i, group] = pack
+            rank_in_pack_np[i, group] = pack_items[pack]
+            pack_weights[pack] += weight_np[i, group]
             pack_items[pack] += 1
+
+    # Convert back to PyTorch tensors on original device
+    pack_index = torch.from_numpy(pack_index_np).to(device)
+    rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
+
     return pack_index, rank_in_pack
 
 

From 33170864c1a46c8096f50ea7edb78d6f0b18c491 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Fri, 7 Nov 2025 18:09:00 +0000
Subject: [PATCH 3/5] restore cuda syncs

Signed-off-by: Sage Moore
---
 vllm/distributed/eplb/eplb_state.py | 32 ++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index a0ed3a2aafc4..17716e8a07ac 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -446,14 +446,14 @@ def rearrange(
         """
 
         ep_group = get_ep_group().device_group
-        # ep_rank = ep_group.rank()
+        ep_rank = ep_group.rank()
 
-        # time_start = None
-        # is_main_rank = ep_rank == 0
-        # if is_main_rank:
-        #     torch.cuda.synchronize()
-        #     time_start = time.perf_counter()
-        #     logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
+        time_start = None
+        is_main_rank = ep_rank == 0
+        if is_main_rank:
+            torch.cuda.synchronize()
+            time_start = time.perf_counter()
+            logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
 
         if global_expert_load is None:
             # Map the physical expert load to global logical experts
@@ -569,15 +569,15 @@ def rearrange(
         self.logical_to_physical_map.copy_(new_logical_to_physical_map)
         self.logical_replica_count.copy_(new_logical_replica_count)
 
-        # if is_main_rank:
-        #     assert time_start is not None
-        #     torch.cuda.synchronize()
-        #     time_end = time.perf_counter()
-        #     logger.info(
-        #         "Rearranged experts%sin %.2f seconds.",
-        #         " (profile) " if is_profile else " ",
-        #         time_end - time_start,
-        #     )
+        if is_main_rank:
+            assert time_start is not None
+            torch.cuda.synchronize()
+            time_end = time.perf_counter()
+            logger.info(
+                "Rearranged experts%sin %.2f seconds.",
+                " (profile) " if is_profile else " ",
+                time_end - time_start,
+            )
         return None
 
     @staticmethod

From 07fb932a49cb8ee0762ce29009a11726051a22ab Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Fri, 7 Nov 2025 18:14:27 +0000
Subject: [PATCH 4/5] comments

Signed-off-by: Sage Moore
---
 vllm/distributed/eplb/rebalance_algo.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py
index 71fac5985895..441196649419 100644
--- a/vllm/distributed/eplb/rebalance_algo.py
+++ b/vllm/distributed/eplb/rebalance_algo.py
@@ -43,13 +43,11 @@ def balanced_packing(
         rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
         return pack_index, rank_in_pack
 
-    # Convert to NumPy for CPU processing
     weight_np = weight.cpu().numpy()
 
-    # Sort and get indices
-    indices_np = np.argsort(-weight_np, axis=-1)  # Descending order
+    # Sort and get indices in descending order
+    indices_np = np.argsort(-weight_np, axis=-1)
 
-    # Initialize output arrays
     pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
     rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
 
@@ -59,7 +57,7 @@ def balanced_packing(
         pack_items = [0] * num_packs
 
         for group in indices_np[i]:
-            # Find pack with minimum weight that still has capacity
+            # Find a pack with capacity that has the lowest weight
            pack = min(
                 (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
                 key=pack_weights.__getitem__,
@@ -71,7 +69,6 @@ def balanced_packing(
             pack_weights[pack] += weight_np[i, group]
             pack_items[pack] += 1
 
-    # Convert back to PyTorch tensors on original device
     pack_index = torch.from_numpy(pack_index_np).to(device)
     rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
 

From c20f9c0786d66f89a419a93851222dc9f9948400 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Sun, 9 Nov 2025 22:04:26 +0000
Subject: [PATCH 5/5] comments

Signed-off-by: Sage Moore
---
 vllm/distributed/eplb/rebalance_algo.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py
index 441196649419..e6645e524cc3 100644
--- a/vllm/distributed/eplb/rebalance_algo.py
+++ b/vllm/distributed/eplb/rebalance_algo.py
@@ -22,9 +22,11 @@ def balanced_packing(
     """
     Pack n weighted objects to m packs, such that each bin contains exactly
     n/m objects and the weights of all packs are as balanced as possible.
+
     Parameters:
         weight: [X, n], the weight of each item
         num_packs: number of packs
+
     Returns:
         pack_index: [X, n], the pack index of each item
         rank_in_pack: [X, n], the rank of the item in the pack
@@ -35,7 +37,6 @@ def balanced_packing(
 
     device = weight.device
 
-    # Handle trivial case before conversion
     if groups_per_pack == 1:
         pack_index = torch.arange(
             weight.size(-1), dtype=torch.int64, device=device