 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op

+from .utils import supports_pdl
+
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}


@@ -82,6 +84,8 @@ def _fused_moe_lora_kernel(
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
     SPLIT_K: tl.constexpr,
+    USE_GDC: tl.constexpr,
+    IS_PRIMARY: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
@@ -110,13 +114,11 @@ def _fused_moe_lora_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-
     # get the expert_id to process curr shard
     ind = lora_id * stride_el + pid_m
     expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return
-
     # get a_ptr,b_ptr,c_ptr
     cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
@@ -149,12 +151,17 @@ def _fused_moe_lora_kernel(
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, grid_k):
         k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
+        # Pre-fetch the LoRA weight block before the GDC wait.
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
+        # gdc_wait waits for ALL programs of the prior (primary) kernel to
+        # complete before continuing.
+        if USE_GDC and not IS_PRIMARY:
+            tl.extra.cuda.gdc_wait()
         a = tl.load(
             a_ptrs,
             mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
             other=0.0,
         )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
         accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
@@ -163,12 +170,15 @@ def _fused_moe_lora_kernel(
     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
-
+    if USE_GDC and IS_PRIMARY:
+        # gdc_launch_dependents hints the runtime to launch dependent kernels
+        # as soon as possible.
+        tl.extra.cuda.gdc_launch_dependents()
     accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+
     if SPLIT_K == 1:
         tl.store(c_ptrs, accumulator, mask=c_mask)
     else:
@@ -209,7 +219,7 @@ def _fused_moe_lora_shrink(
     mul_routed_weight: bool = False,
 ) -> None:
     w1_lora_a_stacked = lora_a_stacked[0]
-
+    use_gdc = supports_pdl(qcurr_hidden_states.device)
     shrink_config = {
         "BLOCK_SIZE_M": block_size_m,
         "BLOCK_SIZE_N": block_size_n,
@@ -218,6 +228,8 @@ def _fused_moe_lora_shrink(
218228 "num_warps" : num_warps ,
219229 "num_stages" : num_stages ,
220230 "SPLIT_K" : split_k ,
231+ "USE_GDC" : use_gdc ,
232+ "launch_pdl" : use_gdc , # triton kernel metadata
221233 }
222234
223235 b_ptr = _get_ptr (lora_a_stacked , device )
@@ -229,7 +241,6 @@ def _fused_moe_lora_shrink(
         len(lora_a_stacked),
         lora_a_stacked[0].shape[0],
     )
-
     _fused_moe_lora_kernel[grid](
         qcurr_hidden_states,
         b_ptr,
@@ -261,6 +272,7 @@ def _fused_moe_lora_shrink(
         num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
         MUL_ROUTED_WEIGHT=False,
+        IS_PRIMARY=True,
         **shrink_config,
     )

@@ -314,7 +326,7 @@ def _fused_moe_lora_expand(
         dtype=output.dtype,
         device=device,
     )
-
+    use_gdc = supports_pdl(a_intermediate_cache1.device)
     expand_config = {
         "BLOCK_SIZE_M": block_size_m,
         "BLOCK_SIZE_N": block_size_n,
@@ -323,6 +335,8 @@ def _fused_moe_lora_expand(
323335 "num_warps" : num_warps ,
324336 "num_stages" : num_stages ,
325337 "SPLIT_K" : split_k , # Set split_k = 1 for expand calls
338+ "USE_GDC" : use_gdc ,
339+ "launch_pdl" : use_gdc , # triton kernel metadata
326340 }
327341
328342 grid = lambda META : (
@@ -361,6 +375,7 @@ def _fused_moe_lora_expand(
         num_slice_c=num_slices,
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
+        IS_PRIMARY=False,
         **expand_config,
     )
     for i in range(num_slices):