/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlu_ops_api.h"
#include "torch_mlu_ops.h"

19+ namespace xllm ::kernel::mlu {
20+
21+ void reshape_paged_cache (torch::Tensor& key,
22+ torch::Tensor& value,
23+ torch::Tensor& k_cache,
24+ torch::Tensor& v_cache,
25+ const torch::Tensor& slot_mapping,
26+ bool direction) {
27+ tmo::torch_api::reshape_paged_cache (
28+ key, value, k_cache, v_cache, slot_mapping, direction);
29+ }
30+
31+ void batch_prefill (const torch::Tensor& query,
32+ const torch::Tensor& key,
33+ const torch::Tensor& value,
34+ torch::Tensor& output,
35+ std::optional<torch::Tensor>& output_lse,
36+ const std::optional<torch::Tensor>& query_start_loc,
37+ const std::optional<torch::Tensor>& seq_start_loc,
38+ const std::optional<torch::Tensor>& alibi_slope,
39+ const std::optional<torch::Tensor>& attn_bias,
40+ const std::optional<torch::Tensor>& q_quant_scale,
41+ const std::optional<torch::Tensor>& k_quant_scale,
42+ const std::optional<torch::Tensor>& v_quant_scale,
43+ const std::optional<torch::Tensor>& out_quant_scale,
44+ const std::optional<torch::Tensor>& block_table,
45+ int max_query_len,
46+ int max_seq_len,
47+ float scale,
48+ bool is_causal,
49+ int window_size_left,
50+ int window_size_right,
51+ const std::string& compute_dtype,
52+ bool return_lse) {
53+ tmo::torch_api::flash_attention (query,
54+ key,
55+ value,
56+ output,
57+ output_lse,
58+ query_start_loc,
59+ seq_start_loc,
60+ alibi_slope,
61+ attn_bias,
62+ q_quant_scale,
63+ k_quant_scale,
64+ v_quant_scale,
65+ out_quant_scale,
66+ block_table,
67+ max_query_len,
68+ max_seq_len,
69+ scale,
70+ is_causal,
71+ window_size_left,
72+ window_size_right,
73+ compute_dtype,
74+ return_lse);
75+ }
76+
77+ void batch_decode (const torch::Tensor& query,
78+ const torch::Tensor& k_cache,
79+ torch::Tensor& output,
80+ const torch::Tensor& block_table,
81+ const torch::Tensor& seq_lens,
82+ const torch::Tensor& v_cache,
83+ std::optional<torch::Tensor>& output_lse,
84+ const std::optional<torch::Tensor>& q_quant_scale,
85+ const std::optional<torch::Tensor>& k_cache_quant_scale,
86+ const std::optional<torch::Tensor>& v_cache_quant_scale,
87+ const std::optional<torch::Tensor>& out_quant_scale,
88+ const std::optional<torch::Tensor>& alibi_slope,
89+ const std::optional<torch::Tensor>& mask,
90+ const std::string& compute_dtype,
91+ int max_seq_len,
92+ int window_size_left,
93+ int window_size_right,
94+ float scale,
95+ bool return_lse,
96+ int kv_cache_quant_bit_size) {
97+ tmo::torch_api::single_query_cached_kv_attn (query,
98+ k_cache,
99+ output,
100+ block_table,
101+ seq_lens,
102+ v_cache,
103+ output_lse,
104+ q_quant_scale,
105+ k_cache_quant_scale,
106+ v_cache_quant_scale,
107+ out_quant_scale,
108+ alibi_slope,
109+ mask,
110+ compute_dtype,
111+ max_seq_len,
112+ window_size_left,
113+ window_size_right,
114+ scale,
115+ return_lse,
116+ kv_cache_quant_bit_size);
117+ }
118+
119+ } // namespace xllm::kernel::mlu
0 commit comments