Commit 21b82f4 (parent a736e5f)

[Kernel] LoRA triton kernels support PDL (#27402)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

5 files changed, 68 insertions(+), 17 deletions(-)

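This commit wires programmatic dependent launch (PDL, exposed in Triton as grid dependency control, GDC) into the LoRA Triton kernels: a primary kernel calls tl.extra.cuda.gdc_launch_dependents() once its dependents may start launching, and the dependent kernel, launched on the same stream with launch_pdl=True, calls tl.extra.cuda.gdc_wait() before it reads anything the primary produced. The following is a minimal sketch of that pattern, not code from this commit; the kernel and buffer names are illustrative, it assumes a CUDA device and a Triton build that exposes the gdc_* intrinsics, and supports_pdl() is the helper added in utils.py below.

import torch

from vllm.triton_utils import tl, triton
from vllm.lora.ops.triton_ops.utils import supports_pdl


@triton.jit
def _producer_kernel(out_ptr, n, BLOCK: tl.constexpr, USE_GDC: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, offs.to(tl.float32), mask=offs < n)
    if USE_GDC:
        # Hint the runtime that dependent kernels may begin launching.
        tl.extra.cuda.gdc_launch_dependents()


@triton.jit
def _consumer_kernel(in_ptr, out_ptr, n, BLOCK: tl.constexpr, USE_GDC: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    if USE_GDC:
        # Block until all programs of the prior kernel have completed.
        tl.extra.cuda.gdc_wait()
    x = tl.load(in_ptr + offs, mask=offs < n, other=0.0)
    tl.store(out_ptr + offs, x * 2.0, mask=offs < n)


def run_pdl_pair(n: int = 4096, block: int = 256) -> torch.Tensor:
    buf = torch.empty(n, dtype=torch.float32, device="cuda")
    out = torch.empty_like(buf)
    # False on pre-SM90 devices; the launches then behave as plain back-to-back kernels.
    use_gdc = supports_pdl(buf.device)
    grid = (triton.cdiv(n, block),)
    # launch_pdl is Triton launch metadata; both kernels go to the same stream.
    _producer_kernel[grid](buf, n, BLOCK=block, USE_GDC=use_gdc, launch_pdl=use_gdc)
    _consumer_kernel[grid](buf, out, n, BLOCK=block, USE_GDC=use_gdc, launch_pdl=use_gdc)
    return out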
vllm/lora/ops/triton_ops/fused_moe_lora_op.py (22 additions & 7 deletions)

@@ -6,6 +6,8 @@
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op

+from .utils import supports_pdl
+
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}


@@ -82,6 +84,8 @@ def _fused_moe_lora_kernel(
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
     SPLIT_K: tl.constexpr,
+    USE_GDC: tl.constexpr,
+    IS_PRIMARY: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
@@ -110,13 +114,11 @@ def _fused_moe_lora_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-
     # get the expert_id to process curr shard
     ind = lora_id * stride_el + pid_m
     expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return
-
     # get a_ptr,b_ptr,c_ptr
     cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
@@ -149,12 +151,17 @@ def _fused_moe_lora_kernel(
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, grid_k):
         k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
+        # pre-fetch LoRA weight
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
+        # GDC wait waits for ALL programs in the prior kernel to complete
+        # before continuing.
+        if USE_GDC and not IS_PRIMARY:
+            tl.extra.cuda.gdc_wait()
         a = tl.load(
             a_ptrs,
             mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
             other=0.0,
         )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
         accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
@@ -163,12 +170,15 @@ def _fused_moe_lora_kernel(
     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
-
+    if USE_GDC and IS_PRIMARY:
+        # GDC launch_dependents hints the runtime to launch dependent kernels.
+        tl.extra.cuda.gdc_launch_dependents()
     accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+
     if SPLIT_K == 1:
         tl.store(c_ptrs, accumulator, mask=c_mask)
     else:
@@ -209,7 +219,7 @@ def _fused_moe_lora_shrink(
     mul_routed_weight: bool = False,
 ) -> None:
     w1_lora_a_stacked = lora_a_stacked[0]
-
+    use_gdc = supports_pdl(qcurr_hidden_states.device)
     shrink_config = {
         "BLOCK_SIZE_M": block_size_m,
         "BLOCK_SIZE_N": block_size_n,
@@ -218,6 +228,8 @@ def _fused_moe_lora_shrink(
         "num_warps": num_warps,
         "num_stages": num_stages,
         "SPLIT_K": split_k,
+        "USE_GDC": use_gdc,
+        "launch_pdl": use_gdc,  # triton kernel metadata
     }

     b_ptr = _get_ptr(lora_a_stacked, device)
@@ -229,7 +241,6 @@ def _fused_moe_lora_shrink(
         len(lora_a_stacked),
         lora_a_stacked[0].shape[0],
     )
-
     _fused_moe_lora_kernel[grid](
         qcurr_hidden_states,
         b_ptr,
@@ -261,6 +272,7 @@ def _fused_moe_lora_shrink(
         num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
         MUL_ROUTED_WEIGHT=False,
+        IS_PRIMARY=True,
         **shrink_config,
     )

@@ -314,7 +326,7 @@ def _fused_moe_lora_expand(
         dtype=output.dtype,
         device=device,
     )
-
+    use_gdc = supports_pdl(a_intermediate_cache1.device)
     expand_config = {
         "BLOCK_SIZE_M": block_size_m,
         "BLOCK_SIZE_N": block_size_n,
@@ -323,6 +335,8 @@ def _fused_moe_lora_expand(
         "num_warps": num_warps,
         "num_stages": num_stages,
         "SPLIT_K": split_k,  # Set split_k = 1 for expand calls
+        "USE_GDC": use_gdc,
+        "launch_pdl": use_gdc,  # triton kernel metadata
     }

     grid = lambda META: (
@@ -361,6 +375,7 @@ def _fused_moe_lora_expand(
         num_slice_c=num_slices,
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
+        IS_PRIMARY=False,
         **expand_config,
     )
     for i in range(num_slices):

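In the fused MoE path above, the same kernel serves both roles: the shrink launch passes IS_PRIMARY=True and signals with gdc_launch_dependents() once its accumulator is finished, while the expand launch passes IS_PRIMARY=False and waits with gdc_wait() before loading the activations that the shrink wrote. Both launches share two knobs derived from one supports_pdl() check: USE_GDC, a tl.constexpr the kernel branches on, and launch_pdl, metadata consumed by the Triton launcher. A hedged sketch of a config builder in that style follows; the helper itself is illustrative, not part of the commit.

from vllm.lora.ops.triton_ops.utils import supports_pdl


def pdl_launch_config(device, block_m, block_n, block_k, split_k, num_warps, num_stages):
    """Illustrative only: bundle the tile sizes with the two PDL knobs."""
    use_gdc = supports_pdl(device)
    return {
        "BLOCK_SIZE_M": block_m,
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_K": block_k,
        "SPLIT_K": split_k,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "USE_GDC": use_gdc,     # read inside the kernel as tl.constexpr
        "launch_pdl": use_gdc,  # consumed by the Triton launcher, not the kernel
    }

Because USE_GDC and IS_PRIMARY are constexprs, devices without PDL support simply compile a specialization with the gdc_* branches removed.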
vllm/lora/ops/triton_ops/kernel_utils.py (21 additions & 7 deletions)

@@ -22,6 +22,7 @@ def mm_k(
     SPLIT_K: tl.constexpr,
     CAST_TYPE: tl.constexpr,
     b_dtype: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     """
     Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -45,19 +46,25 @@ def mm_k(
     CAST_TYPE: if True, cast the values from the A matrix to the B
         matrix dtype.
     b_dtype: datatype of the B matrix
+    USE_GDC: whether to use PDL (grid dependency control). True indicates use.
     """
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
         if EVEN_K:
-            tiled_a = tl.load(a_ptr)
+            # pre-fetch LoRA weight
             tiled_b = tl.load(b_ptr)
+            if USE_GDC:
+                tl.extra.cuda.gdc_wait()
+            tiled_a = tl.load(a_ptr)
         else:
-            tiled_a = tl.load(
-                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
             tiled_b = tl.load(
                 b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
             )
+            if USE_GDC:
+                tl.extra.cuda.gdc_wait()
+            tiled_a = tl.load(
+                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
+            )
         if CAST_TYPE:
             tiled_a = tiled_a.to(b_dtype)
         accumulator += tl.dot(
@@ -102,6 +109,7 @@ def do_expand_kernel(
     EVEN_K: tl.constexpr,
     CAST_TYPE: tl.constexpr,
     ADD_INPUTS: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     """
     Given an array of integers that identifies the rows of A, ram,
@@ -154,6 +162,7 @@ def do_expand_kernel(

     # Compute the block matrix product.
     SPLIT_K = 1
+
     accumulator = mm_k(
         a_ptr,
         b_ptr,
@@ -168,6 +177,7 @@ def do_expand_kernel(
         SPLIT_K,
         CAST_TYPE,
         cur_lora_ptr.dtype.element_ty,
+        USE_GDC,
     )

     tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
@@ -223,6 +233,7 @@ def do_shrink_kernel(
     EVEN_K: tl.constexpr,
     SPLIT_K: tl.constexpr,
     SLICE_NUM: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     """
     Given an array of integers that identifies the rows of A, ram,
@@ -272,8 +283,11 @@ def do_shrink_kernel(
         SPLIT_K,
         False,
         cur_lora_ptr.dtype.element_ty,
+        False,  # USE_GDC is always False in the shrink kernel's mm_k
     )
-
+    # GDC launch_dependents hints the runtime to launch dependent kernels.
+    if USE_GDC:
+        tl.extra.cuda.gdc_launch_dependents()
     # Identify the C output pointers to store the results of the accumulator.
     offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
     offset_cm = tl.arange(0, BLOCK_M)
@@ -284,10 +298,10 @@ def do_shrink_kernel(
         + offset_cn[None, :] * output_d2_stride
     )
     c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
-
     accumulator *= scaling
+
     # handles write-back with reduction-splitting
     if SPLIT_K == 1:
         tl.store(c_ptr, accumulator, mask=c_mask)
     else:
-        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
+        tl.atomic_add(c_ptr, accumulator, mask=c_mask, sem="relaxed")

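The load reordering in mm_k above is the heart of the change: the LoRA weight tile (tiled_b) does not depend on the previous kernel, so it is fetched before gdc_wait() and can overlap the tail of the producer; only the input tile (tiled_a), which the producer writes, is loaded after the wait. A stand-alone sketch of that ordering follows; the kernel name and host wrapper are assumptions for illustration, not code from this commit.

import torch

from vllm.triton_utils import tl, triton
from vllm.lora.ops.triton_ops.utils import supports_pdl


@triton.jit
def _consumer_with_prefetch(
    dep_ptr,   # written by the prior kernel on the same stream
    w_ptr,     # resident weights, independent of the prior kernel
    out_ptr,
    n,
    BLOCK: tl.constexpr,
    USE_GDC: tl.constexpr,
):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # Independent data first: this load can overlap the producer's tail.
    w = tl.load(w_ptr + offs, mask=offs < n, other=0.0)
    if USE_GDC:
        # Only now synchronize with the producer.
        tl.extra.cuda.gdc_wait()
    x = tl.load(dep_ptr + offs, mask=offs < n, other=0.0)
    tl.store(out_ptr + offs, x * w, mask=offs < n)


def scale_by_weights(dep: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(dep)
    n = dep.numel()
    use_gdc = supports_pdl(dep.device)
    _consumer_with_prefetch[(triton.cdiv(n, 256),)](
        dep, w, out, n, BLOCK=256, USE_GDC=use_gdc, launch_pdl=use_gdc
    )
    return out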
vllm/lora/ops/triton_ops/lora_expand_op.py (7 additions & 1 deletion)

@@ -14,6 +14,8 @@
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op

+from .utils import supports_pdl
+

 @triton.jit
 def _lora_expand_kernel(
@@ -45,6 +47,7 @@ def _lora_expand_kernel(
     CAST_TYPE: tl.constexpr,
     SLICE_NUM: tl.constexpr,
     SAME_STRIDE: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     cta_n_num = tl.cdiv(N, BLOCK_N)
     cta_m_num = tl.cdiv(M, BLOCK_M)
@@ -121,6 +124,7 @@ def _lora_expand_kernel(
         EVEN_K,
         CAST_TYPE,
         ADD_INPUTS,
+        USE_GDC,
     )


@@ -236,7 +240,7 @@ def _lora_expand(
         # thread blocks simply exit.
         MAX_LORAS,
     )
-
+    use_gdc = supports_pdl(inputs.device)
    _lora_expand_kernel[grid](
         inputs,
         lora_ptr_tensor,
@@ -266,9 +270,11 @@ def _lora_expand(
         CAST_TYPE,
         NUM_SLICES,
         same_stride,
+        use_gdc,
         num_warps=NUM_WARPS,
         num_ctas=NUM_CTAS,
         num_stages=NUM_STAGES,
+        launch_pdl=use_gdc,
     )

     return

vllm/lora/ops/triton_ops/lora_shrink_op.py (7 additions & 2 deletions)

@@ -14,6 +14,8 @@
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op

+from .utils import supports_pdl
+

 @triton.jit
 def _lora_shrink_kernel(
@@ -43,6 +45,7 @@ def _lora_shrink_kernel(
     SPLIT_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
     SLICE_NUM: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     cta_n_num = tl.cdiv(N, BLOCK_N)
     cta_m_num = tl.cdiv(M, BLOCK_M)
@@ -83,7 +86,6 @@ def _lora_shrink_kernel(
     cta_lora_seq_indices = (
         token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
     )
-
     # Load all relevant row indices.
     offset_m = tl.arange(0, BLOCK_M) % cta_m_len
     ram = tl.load(cta_lora_seq_indices + offset_m)
@@ -118,6 +120,7 @@ def _lora_shrink_kernel(
         EVEN_K,
         SPLIT_K,
         SLICE_NUM,
+        USE_GDC,
     )


@@ -217,7 +220,7 @@ def _lora_shrink(
         # thread blocks exit early.
         MAX_LORAS,
     )
-
+    use_gdc = supports_pdl(inputs.device)
     _lora_shrink_kernel[grid](
         inputs,
         lora_ptr_tensor,
@@ -245,9 +248,11 @@ def _lora_shrink(
         SPLIT_K,
         GROUP_SIZE_M,
         NUM_SLICES,
+        use_gdc,
         num_warps=NUM_WARPS,
         num_ctas=NUM_CTAS,
         num_stages=NUM_STAGES,
+        launch_pdl=use_gdc,
     )

     return

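Only the shrink path runs with SPLIT_K > 1 (the expand configs above pin split_k to 1), so the split-K write-back in do_shrink_kernel is where the new tl.atomic_add(..., sem="relaxed") matters: the partial sums from different K-splits commute and only need atomicity, while ordering is provided by stream ordering or the dependent kernel's gdc_wait(), not by the atomics themselves. A small stand-alone sketch of that write-back pattern follows; the kernel is illustrative, assumes a float32 CUDA input, and is not code from this commit.

import torch

from vllm.triton_utils import tl, triton


@triton.jit
def _splitk_rowsum(x_ptr, out_ptr, K, BLOCK_K: tl.constexpr):
    row = tl.program_id(0)    # one row per program along axis 0
    split = tl.program_id(1)  # one K-chunk per program along axis 1
    offs = split * BLOCK_K + tl.arange(0, BLOCK_K)
    part = tl.sum(tl.load(x_ptr + row * K + offs, mask=offs < K, other=0.0))
    # Partial sums commute, so atomicity without ordering is enough.
    tl.atomic_add(out_ptr + row, part, sem="relaxed")


def rowsum_splitk(x: torch.Tensor, block_k: int = 256) -> torch.Tensor:
    m, k = x.shape
    out = torch.zeros(m, dtype=torch.float32, device=x.device)  # must start at zero
    grid = (m, triton.cdiv(k, block_k))
    _splitk_rowsum[grid](x, out, k, BLOCK_K=block_k)
    return out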
vllm/lora/ops/triton_ops/utils.py (11 additions & 0 deletions)

@@ -3,13 +3,15 @@

 import functools
 import json
+from functools import lru_cache
 from pathlib import Path
 from typing import Any

 import torch

 from vllm import envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -282,3 +284,12 @@ def get_lora_op_configs(

     assert config_data is not None
     return config_data
+
+
+@lru_cache
+def supports_pdl(device: torch.device | None = None) -> bool:
+    """
+    Refer to: https://github.com/triton-lang/triton/blob/v3.5.0/python/tutorials/11-programmatic-dependent-launch.py
+    """
+    # PDL requires compute capability SM90 or above
+    return current_platform.is_cuda() and current_platform.has_device_capability(90)

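A short usage sketch of the new helper (the surrounding lines are illustrative, not from the commit): the result is memoized by lru_cache, keyed on the device argument, and the check itself returns True only on CUDA platforms reporting compute capability 9.0 (Hopper) or newer.

import torch

from vllm.lora.ops.triton_ops.utils import supports_pdl

device = torch.device("cuda:0") if torch.cuda.is_available() else None
use_gdc = supports_pdl(device)  # False on CPU-only hosts and pre-Hopper GPUs

# Typical caller pattern from the diff above: one flag feeds both the kernel
# constexpr and the Triton launch metadata.
kernel_kwargs = {"USE_GDC": use_gdc, "launch_pdl": use_gdc}
print(f"PDL enabled: {use_gdc}")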