Skip to content

Commit ffb63b4

Browse files
committed
feat: refactor kernel dir and add flashinfer for cuda kernel.
1 parent 762ee79 commit ffb63b4

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

61 files changed

+2316
-349
lines changed

.gitmodules

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,9 @@
2828
[submodule "third_party/Mooncake"]
2929
path = third_party/Mooncake
3030
url = https://gitcode.com/xLLM-AI/Mooncake.git
31+
[submodule "third_party/flashinfer"]
32+
path = third_party/flashinfer
33+
url = https://gitcode.com/xLLM-AI/flashinfer.git
34+
[submodule "third_party/cutlass"]
35+
path = third_party/cutlass
36+
url = https://gitcode.com/xLLM-AI/cutlass.git

third_party/CMakeLists.txt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,24 @@ target_include_directories(mooncake_store PUBLIC
2020
)
2121

2222
target_link_libraries(mooncake_store PUBLIC transfer_engine cachelib_memory_allocator)
23+
24+
25+
if(USE_CUDA)
26+
cc_library(
27+
NAME
28+
cutlass
29+
INCLUDES
30+
cutlass/include
31+
cutlass/tools/util/include
32+
DEPS
33+
torch # TODO: depends on CUDA instead of torch
34+
)
35+
cc_library(
36+
NAME
37+
flashinfer
38+
INCLUDES
39+
flashinfer/include
40+
DEPS
41+
cutlass
42+
)
43+
endif()

third_party/cutlass

Submodule cutlass added at e6e2cc2

third_party/flashinfer

Submodule flashinfer added at bd98dac

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
212212
state_.q_seq_lens.insert(state_.q_seq_lens.end(),
213213
state.q_seq_lens.begin(),
214214
state.q_seq_lens.end());
215-
#elif defined(USE_MLU)
215+
#elif defined(USE_MLU) || defined(USE_CUDA)
216216
int32_t seq_len_offset = state_.seq_lens.back();
217217
// skip the first element which is 0
218218
for (size_t i = 1; i < state.seq_lens.size(); ++i) {
@@ -284,7 +284,7 @@ void BatchInputBuilder::process_single_sequence(
284284
#if defined(USE_NPU)
285285
state.seq_lens.push_back(seq_len);
286286
state.q_seq_lens.push_back(q_seq_len);
287-
#elif defined(USE_MLU)
287+
#elif defined(USE_MLU) || defined(USE_CUDA)
288288
state.seq_lens.push_back(state.seq_lens.back() + seq_len);
289289
state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
290290
#endif
@@ -437,7 +437,12 @@ void BatchInputBuilder::setup_kv_cache_info(
437437
block_size = block.size();
438438
block_ids.push_back(block.id());
439439
u_block_ids.emplace_back(block.id());
440+
state.paged_kv_indices.push_back(block.id());
440441
}
442+
state.paged_kv_indptr.push_back(state.paged_kv_indptr.back() + blocks.size());
443+
int32_t last_page_len =
444+
(seq_len % block_size == 0) ? block_size : seq_len % block_size;
445+
state.paged_kv_last_page_len.push_back(last_page_len);
441446

442447
int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size;
443448
for (auto iter = block_ids.begin() + kv_cache_block_idx;
@@ -506,12 +511,15 @@ void BatchInputBuilder::padding_decode_batch_size(
506511
#if defined(USE_NPU)
507512
state_.seq_lens.push_back(num_decoding_tokens);
508513
state_.q_seq_lens.push_back(num_decoding_tokens);
509-
#elif defined(USE_MLU)
514+
#elif defined(USE_MLU) || defined(USE_CUDA)
510515
state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens);
511516
state_.q_seq_lens.push_back(state_.q_seq_lens.back() +
512517
num_decoding_tokens);
513518
#endif
514519
state_.block_tables_vec.emplace_back();
520+
state_.paged_kv_indices.push_back(0);
521+
state_.paged_kv_indptr.push_back(state_.paged_kv_indptr.back() + 1);
522+
state_.paged_kv_last_page_len.push_back(1);
515523
}
516524
}
517525
}

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ class BatchInputBuilder {
104104
// for continuous kvcache
105105
std::vector<int64_t> new_cache_slot_offsets; //[n_tokens]
106106
std::vector<int64_t> kv_cache_start_offsets; //[n_seq]
107+
108+
// for flashinfer
109+
std::vector<int32_t> paged_kv_indptr = {0};
110+
std::vector<int32_t> paged_kv_indices;
111+
std::vector<int32_t> paged_kv_last_page_len;
107112
};
108113

109114
// Helper methods for sequence processing
@@ -128,7 +133,6 @@ class BatchInputBuilder {
128133
uint32_t q_seq_len,
129134
BuilderState* state_ptr = nullptr,
130135
std::unordered_set<int32_t>* write_block_ids_ptr = nullptr);
131-
132136
void setup_continuous_kv_cache_info(Sequence* sequence,
133137
uint32_t n_kv_cache_tokens,
134138
uint32_t seq_len,

xllm/core/framework/model/model_input_params.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ struct ModelInputParams {
9393

9494
// Copy graph_buffer to device
9595
params.graph_buffer = safe_to(graph_buffer, device, true);
96+
97+
// params for flashinfer
98+
params.paged_kv_indptr = safe_to(paged_kv_indptr, device);
99+
params.paged_kv_indices = safe_to(paged_kv_indices, device);
100+
params.paged_kv_last_page_len = safe_to(paged_kv_last_page_len, device);
101+
96102
return params;
97103
}
98104

@@ -192,6 +198,21 @@ struct ModelInputParams {
192198
// Graph execution buffer for temporary tensor storage
193199
// Used by ACL Graph Executor to avoid repeated memory allocation
194200
torch::Tensor graph_buffer;
201+
202+
// the indptr of the paged kv-cache
203+
// used in flashinfer
204+
// IntTensor: [n_seq + 1]
205+
torch::Tensor paged_kv_indptr;
206+
207+
// the page indices of the paged kv cache
208+
// used in flashinfer
209+
torch::Tensor paged_kv_indices;
210+
211+
// the number of entries in the last page of each request in
212+
// the paged kv cache
213+
// used in flashinfer
214+
// IntTensor: [n_seq]
215+
torch::Tensor paged_kv_last_page_len;
195216
};
196217

197218
} // namespace xllm

xllm/core/kernels/CMakeLists.txt

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
include(cc_library)

if(USE_NPU)
  add_subdirectory(npu)
endif()

if(USE_MLU)
  add_subdirectory(mlu)
endif()

if(USE_CUDA)
  add_subdirectory(cuda)
endif()

# Device-agnostic kernel facade. The backend-specific kernel library is
# linked in via generator expressions keyed on the build flags.
cc_library(
  NAME
    kernels
  HDRS
    param.h
    ops_api.h
  SRCS
    ops_api.cpp
  DEPS
    torch
    $<$<BOOL:${USE_NPU}>:npu_kernels>
    $<$<BOOL:${USE_MLU}>:mlu_kernels>
    $<$<BOOL:${USE_CUDA}>:cuda_kernels>
)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
include(cc_library)
2+
3+
file(GLOB_RECURSE CUDA_HEADER_FILES
4+
"${CMAKE_CURRENT_LIST_DIR}/*.h"
5+
)
6+
7+
file(GLOB_RECURSE CUDA_SOURCE_FILES
8+
"${CMAKE_CURRENT_LIST_DIR}/*.cpp"
9+
)
10+
11+
cc_library(
12+
NAME
13+
cuda_kernels
14+
HDRS
15+
${CUDA_HEADER_FILES}
16+
SRCS
17+
${CUDA_SOURCE_FILES}
18+
DEPS
19+
flashinfer
20+
)

xllm/core/kernels/cuda/active.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <cuda_runtime.h>
17+
18+
#include <flashinfer/activation.cuh>
19+
20+
#include "cuda_ops_api.h"
21+
22+
using namespace flashinfer;
23+
24+
namespace xllm::kernel::cuda {

// Device-side activation functions. Each maps one float to one float and is
// used as the Activation template parameter of flashinfer's
// activation::act_and_mul_kernel.

// SiLU: x * sigmoid(x), using the fast __expf intrinsic.
__device__ __forceinline__ float silu(const float& val) {
  return val / (1.0f + __expf(-val));
}

// Exact GeLU: x * Phi(x), Phi being the standard normal CDF.
__device__ __forceinline__ float gelu(const float& val) {
  constexpr float kAlpha = M_SQRT1_2;
  // erff: single-precision erf. The original called ::erf, which is the
  // double-precision C function and forces a slow double-precision path
  // on device.
  return val * 0.5f * (1.0f + erff(val * kAlpha));
}

// Tanh-approximated GeLU (GPT-style approximation).
__device__ __forceinline__ float gelu_tanh(const float& val) {
  const float cdf =
      0.5f * (1.0f + math::tanh((0.7978845608028654f *
                                 (val + 0.044715f * val * val * val))));
  return val * cdf;
}

// Fused activation-and-multiply over the last dimension of `input`.
// The last dimension is split in half (d = last_dim / 2); one half is passed
// through the activation selected by `act_mode` ("silu", "gelu" or
// "gelu_tanh") and multiplied elementwise with the other half, writing d
// elements per token into `out`. (Which half is activated follows
// flashinfer's act_and_mul_kernel convention — presumably the first half;
// confirm against the flashinfer header.) `enable_pdl` toggles programmatic
// dependent launch via the stream-serialization launch attribute.
void act_and_mul(TensorView out,
                 TensorView input,
                 const std::string& act_mode,
                 bool enable_pdl) {
  int d = input->shape[input->ndim - 1] / 2;
  int64_t num_tokens = input.numel() / input->shape[input->ndim - 1];

  // Check the cudaSetDevice result; a sticky earlier error would otherwise
  // make every later CUDA call fail mysteriously.
  cudaError_t set_err = cudaSetDevice(out->device.device_id);
  TVM_FFI_ICHECK(set_err == cudaSuccess)
      << "Failed to set device: " << cudaGetErrorString(set_err);
  const cudaStream_t stream = get_stream(out->device);
  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);  // 16-byte vectorized accesses
    cudaLaunchConfig_t config;
    config.gridDim = num_tokens;
    // Clamp to [1, 1024]: d / vec_size can be 0 for tiny d, and a block
    // dimension of 0 is an invalid launch configuration.
    config.blockDim = std::min(std::max(d / vec_size, 1U), 1024U);
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
    config.numAttrs = 1;
    config.attrs = attrs;

    // The activation is a compile-time template parameter of the kernel,
    // while act_mode is a runtime string — the original passed act_mode
    // directly as a template argument, which cannot compile. Dispatch the
    // string to a concrete instantiation here.
    auto kernel = activation::act_and_mul_kernel<c_type, silu>;
    if (act_mode == "silu") {
      kernel = activation::act_and_mul_kernel<c_type, silu>;
    } else if (act_mode == "gelu") {
      kernel = activation::act_and_mul_kernel<c_type, gelu>;
    } else if (act_mode == "gelu_tanh") {
      kernel = activation::act_and_mul_kernel<c_type, gelu_tanh>;
    } else {
      TVM_FFI_ICHECK(false) << "Unsupported act_mode: " << act_mode;
    }

    cudaLaunchKernelEx(&config,
                       kernel,
                       static_cast<c_type*>(out->data),
                       static_cast<c_type*>(input->data),
                       d);

    // cudaLaunchKernelEx's config errors surface via cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TVM_FFI_ICHECK(err == cudaSuccess)
        << "Failed to launch kernel: " << cudaGetErrorString(err);

    return true;
  });
}

}  // namespace xllm::kernel::cuda

0 commit comments

Comments
 (0)