Skip to content

Commit b26c50c

Browse files
authored
refactor: refactor kernel dir and add unified kernel api. (#271)
1 parent 23386ef commit b26c50c

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

45 files changed

+1107
-352
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ DEFINE_string(store_metadata_connstring,
343343
"",
344344
"The address of the kv cache store metadata service.");
345345

346-
// --- for computation communication parallel ---
346+
// --- computation communication parallel config ---
347347

348348
DEFINE_bool(
349349
enable_multi_stream_parallel,
@@ -355,7 +355,7 @@ DEFINE_int32(default_micro_batch_num,
355355
2,
356356
"Default use two micro batches for multi-stream parallel.");
357357

358-
// --- for dit ---
358+
// --- dit config ---
359359
DEFINE_int32(max_requests_per_batch, 1, "Max number of request per batch.");
360360

361361
// --- continuous kv cache config ---
@@ -377,4 +377,4 @@ DEFINE_int64(cache_size_per_token,
377377

378378
DEFINE_int64(buffer_size_per_seq,
379379
0,
380-
"Buffer size per sequence in bytes, default 0.");
380+
"Buffer size per sequence in bytes, default 0.");

xllm/core/common/global_flags.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,6 @@ DECLARE_int32(max_global_ttft_ms);
189189

190190
DECLARE_int32(max_global_tpot_ms);
191191

192-
// dit
193192
DECLARE_int32(max_requests_per_batch);
194193

195194
DECLARE_bool(enable_continuous_kvcache);
@@ -198,4 +197,4 @@ DECLARE_int64(granularity_size);
198197

199198
DECLARE_int64(cache_size_per_token);
200199

201-
DECLARE_int64(buffer_size_per_seq);
200+
DECLARE_int64(buffer_size_per_seq);

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ class BatchInputBuilder {
128128
uint32_t q_seq_len,
129129
BuilderState* state_ptr = nullptr,
130130
std::unordered_set<int32_t>* write_block_ids_ptr = nullptr);
131-
132131
void setup_continuous_kv_cache_info(Sequence* sequence,
133132
uint32_t n_kv_cache_tokens,
134133
uint32_t seq_len,

xllm/core/framework/model/model_input_params.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ struct ModelInputParams {
9393

9494
// Copy graph_buffer to device
9595
params.graph_buffer = safe_to(graph_buffer, device, true);
96+
9697
return params;
9798
}
9899

xllm/core/kernels/CMakeLists.txt

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
include(cc_library)
22

33
if(USE_NPU)
4-
include_directories(
5-
${CMAKE_SOURCE_DIR}/third_party/spdlog/include
6-
)
74
add_subdirectory(npu)
85
endif()
96

107
if(USE_MLU)
118
add_subdirectory(mlu)
129
endif()
10+
11+
12+
# Unified `kernels` library: builds ops_api.cpp and exposes param.h / ops_api.h.
# Always depends on torch. The two generator expressions add the backend
# library (npu_kernels / mlu_kernels) only when the matching USE_NPU / USE_MLU
# option evaluates to true; otherwise they expand to an empty string.
# NOTE(review): `mlu_kernels` matches the target renamed in mlu/CMakeLists.txt
# in this commit; `npu_kernels` is not visible here - confirm it is defined
# under npu/. cc_library is a project macro (include(cc_library)); confirm it
# passes DEPS through to a command that evaluates generator expressions.
cc_library(
13+
NAME
14+
kernels
15+
HDRS
16+
param.h
17+
ops_api.h
18+
SRCS
19+
ops_api.cpp
20+
DEPS
21+
torch
22+
$<$<BOOL:${USE_NPU}>:npu_kernels>
23+
$<$<BOOL:${USE_MLU}>:mlu_kernels>
24+
)

xllm/core/kernels/mlu/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ include(cc_library)
22

33
file(GLOB_RECURSE MLU_HEADER_FILES
44
"${CMAKE_CURRENT_LIST_DIR}/*.h"
5-
"${CMAKE_CURRENT_LIST_DIR}/*.hpp"
65
)
76

87
file(GLOB_RECURSE MLU_SOURCE_FILES
@@ -11,7 +10,7 @@ file(GLOB_RECURSE MLU_SOURCE_FILES
1110

1211
cc_library(
1312
NAME
14-
xllm_mlu_ops
13+
mlu_kernels
1514
HDRS
1615
${MLU_HEADER_FILES}
1716
SRCS

xllm/core/kernels/mlu/active.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "mlu_ops_api.h"
17+
#include "torch_mlu_ops.h"
18+
19+
namespace xllm::kernel::mlu {
20+
21+
// Activation kernel wrapper in the xllm::kernel::mlu namespace.
// Forwards all eight arguments, unchanged and in declaration order, to
// tmo::torch_api::active. No validation, conversion or allocation is done
// here; all behavior (including how `act_mode` and `is_gated` are
// interpreted) is defined by the callee.
// NOTE(review): `output` is the only non-const tensor parameter, so the
// result is presumably written into it in place - confirm against
// torch_mlu_ops.h.
void active(const torch::Tensor& input,
22+
torch::Tensor& output,
23+
const std::optional<torch::Tensor>& bias,
24+
const std::optional<torch::Tensor>& cusum_token_count,
25+
const std::string& act_mode,
26+
bool is_gated,
27+
int start_expert_id,
28+
int expert_size) {
29+
tmo::torch_api::active(input,
30+
output,
31+
bias,
32+
cusum_token_count,
33+
act_mode,
34+
is_gated,
35+
start_expert_id,
36+
expert_size);
37+
}
38+
} // namespace xllm::kernel::mlu
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "mlu_ops_api.h"
17+
#include "torch_mlu_ops.h"
18+
19+
namespace xllm::kernel::mlu {
20+
21+
// Paged-cache reshape wrapper in the xllm::kernel::mlu namespace.
// Passes key, value, k_cache, v_cache, slot_mapping and direction straight
// through to tmo::torch_api::reshape_paged_cache without modification.
// NOTE(review): key/value and both caches are taken by non-const reference,
// and `direction` presumably selects which side is the copy destination -
// the meaning of true/false is not visible here; confirm against
// torch_mlu_ops.h before relying on it.
void reshape_paged_cache(torch::Tensor& key,
22+
torch::Tensor& value,
23+
torch::Tensor& k_cache,
24+
torch::Tensor& v_cache,
25+
const torch::Tensor& slot_mapping,
26+
bool direction) {
27+
tmo::torch_api::reshape_paged_cache(
28+
key, value, k_cache, v_cache, slot_mapping, direction);
29+
}
30+
31+
// Prefill attention wrapper in the xllm::kernel::mlu namespace.
// Despite the name, this is a pure forwarder: every argument is handed,
// unchanged and in the same order, to tmo::torch_api::flash_attention.
// Mutable parameters are `output` and `output_lse` (a non-const optional);
// all other tensors are const or optional-const inputs.
// NOTE(review): `output_lse` is presumably only populated when `return_lse`
// is true, and the window_size_* / quant-scale semantics are defined by the
// callee - confirm against torch_mlu_ops.h.
void batch_prefill(const torch::Tensor& query,
32+
const torch::Tensor& key,
33+
const torch::Tensor& value,
34+
torch::Tensor& output,
35+
std::optional<torch::Tensor>& output_lse,
36+
const std::optional<torch::Tensor>& query_start_loc,
37+
const std::optional<torch::Tensor>& seq_start_loc,
38+
const std::optional<torch::Tensor>& alibi_slope,
39+
const std::optional<torch::Tensor>& attn_bias,
40+
const std::optional<torch::Tensor>& q_quant_scale,
41+
const std::optional<torch::Tensor>& k_quant_scale,
42+
const std::optional<torch::Tensor>& v_quant_scale,
43+
const std::optional<torch::Tensor>& out_quant_scale,
44+
const std::optional<torch::Tensor>& block_table,
45+
int max_query_len,
46+
int max_seq_len,
47+
float scale,
48+
bool is_causal,
49+
int window_size_left,
50+
int window_size_right,
51+
const std::string& compute_dtype,
52+
bool return_lse) {
53+
tmo::torch_api::flash_attention(query,
54+
key,
55+
value,
56+
output,
57+
output_lse,
58+
query_start_loc,
59+
seq_start_loc,
60+
alibi_slope,
61+
attn_bias,
62+
q_quant_scale,
63+
k_quant_scale,
64+
v_quant_scale,
65+
out_quant_scale,
66+
block_table,
67+
max_query_len,
68+
max_seq_len,
69+
scale,
70+
is_causal,
71+
window_size_left,
72+
window_size_right,
73+
compute_dtype,
74+
return_lse);
75+
}
76+
77+
// Decode attention wrapper in the xllm::kernel::mlu namespace.
// Pure forwarder: all arguments are passed, unchanged and in the same order,
// to tmo::torch_api::single_query_cached_kv_attn.
// Note the parameter order mirrors the callee rather than the usual q/k/v
// grouping: `v_cache` comes after `seq_lens`, not next to `k_cache`.
// Mutable parameters are `output` and `output_lse` (a non-const optional).
// NOTE(review): `output_lse` is presumably only populated when `return_lse`
// is true, and the accepted values of `kv_cache_quant_bit_size` are defined
// by the callee - confirm against torch_mlu_ops.h.
void batch_decode(const torch::Tensor& query,
78+
const torch::Tensor& k_cache,
79+
torch::Tensor& output,
80+
const torch::Tensor& block_table,
81+
const torch::Tensor& seq_lens,
82+
const torch::Tensor& v_cache,
83+
std::optional<torch::Tensor>& output_lse,
84+
const std::optional<torch::Tensor>& q_quant_scale,
85+
const std::optional<torch::Tensor>& k_cache_quant_scale,
86+
const std::optional<torch::Tensor>& v_cache_quant_scale,
87+
const std::optional<torch::Tensor>& out_quant_scale,
88+
const std::optional<torch::Tensor>& alibi_slope,
89+
const std::optional<torch::Tensor>& mask,
90+
const std::string& compute_dtype,
91+
int max_seq_len,
92+
int window_size_left,
93+
int window_size_right,
94+
float scale,
95+
bool return_lse,
96+
int kv_cache_quant_bit_size) {
97+
tmo::torch_api::single_query_cached_kv_attn(query,
98+
k_cache,
99+
output,
100+
block_table,
101+
seq_lens,
102+
v_cache,
103+
output_lse,
104+
q_quant_scale,
105+
k_cache_quant_scale,
106+
v_cache_quant_scale,
107+
out_quant_scale,
108+
alibi_slope,
109+
mask,
110+
compute_dtype,
111+
max_seq_len,
112+
window_size_left,
113+
window_size_right,
114+
scale,
115+
return_lse,
116+
kv_cache_quant_bit_size);
117+
}
118+
119+
} // namespace xllm::kernel::mlu
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "mlu_ops_api.h"
17+
#include "torch_mlu_ops.h"
18+
19+
namespace xllm::kernel::mlu {
20+
21+
// Fused layer-norm wrapper in the xllm::kernel::mlu namespace.
// Forwards all fifteen arguments, unchanged and in declaration order, to
// tmo::torch_api::fused_layernorm; no checks or conversions are done here.
// `output` is the only non-const tensor parameter. `residual_out` and
// `normed_out` are passed as const optionals even though their names suggest
// outputs.
// NOTE(review): presumably the callee writes through the tensor handles of
// `residual_out` / `normed_out` when store_output_before_norm /
// store_output_after_norm are set, and `mode` selects the norm variant -
// none of that is visible here; confirm against torch_mlu_ops.h.
void fused_layernorm(const torch::Tensor& input,
22+
torch::Tensor& output,
23+
const std::optional<torch::Tensor>& residual,
24+
const torch::Tensor& weight,
25+
const std::optional<torch::Tensor>& beta,
26+
const std::optional<torch::Tensor>& bias,
27+
const std::optional<torch::Tensor>& quant_scale,
28+
const std::optional<torch::Tensor>& residual_out,
29+
const std::optional<torch::Tensor>& smooth_quant_scale,
30+
const std::optional<torch::Tensor>& normed_out,
31+
const std::string& mode,
32+
double eps,
33+
bool store_output_before_norm,
34+
bool store_output_after_norm,
35+
bool dynamic_quant) {
36+
tmo::torch_api::fused_layernorm(input,
37+
output,
38+
residual,
39+
weight,
40+
beta,
41+
bias,
42+
quant_scale,
43+
residual_out,
44+
smooth_quant_scale,
45+
normed_out,
46+
mode,
47+
eps,
48+
store_output_before_norm,
49+
store_output_after_norm,
50+
dynamic_quant);
51+
}
52+
53+
} // namespace xllm::kernel::mlu

xllm/core/kernels/mlu/fused_moe.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#include "mlu_ops_api.h"
1617
#include "torch_mlu_ops.h"
17-
#include "torch_ops_api.h"
1818

1919
namespace {
2020
torch::Tensor create_group_gemm_output(const torch::Tensor& a,
@@ -27,7 +27,7 @@ torch::Tensor create_group_gemm_output(const torch::Tensor& a,
2727
}
2828
} // namespace
2929

30-
namespace xllm::mlu {
30+
namespace xllm::kernel::mlu {
3131
torch::Tensor fused_moe(
3232
const torch::Tensor& hidden_states,
3333
const torch::Tensor& gating_output,
@@ -175,4 +175,4 @@ torch::Tensor fused_moe(
175175
return output.reshape(ori_input_shape);
176176
}
177177

178-
} // namespace xllm::mlu
178+
} // namespace xllm::kernel::mlu

0 commit comments

Comments
 (0)