Skip to content

Commit d1292ce

Browse files
authored
Merge pull request #168 from InfiniTensor/rope-teco
teco-rope: 太初平台rope算子重构 (refactor the RoPE operator for the Tecorigin "Teco" platform)
2 parents 6f6019a + ec9544c commit d1292ce

File tree

4 files changed

+243
-0
lines changed

4 files changed

+243
-0
lines changed

operatorspy/tests/rotary_embedding.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,14 @@ def test_ascend(lib, test_cases) :
165165
test(lib, handle, "npu", shape, strides, dtype)
166166
destroy_handle(lib, handle)
167167

168+
def test_teco(lib, test_cases):
    """Run every rotary-embedding test case on the Teco SDAA backend.

    Args:
        lib: loaded operator library, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: iterable of ``(shape, strides, dtype)`` tuples.
    """
    # Imported lazily so the script still loads on machines without the
    # Teco stack. NOTE(review): presumably this import is what registers the
    # "sdaa" device string with torch -- confirm.
    import torch_sdaa

    device = DeviceEnum.DEVICE_TECO
    handle = create_handle(lib, device)
    try:
        for shape, strides, dtype in test_cases:
            test(lib, handle, "sdaa", shape, strides, dtype)
    finally:
        # Release the handle even when a test case raises, so a failing
        # case does not leak the device handle.
        destroy_handle(lib, handle)
175+
168176
if __name__ == "__main__":
169177
test_cases = [
170178
((1, 32, 128), None, torch.float16),
@@ -215,5 +223,7 @@ def test_ascend(lib, test_cases) :
215223
test_bang(lib, test_cases)
216224
if args.ascend:
    test_ascend(lib, test_cases)
if args.teco:
    test_teco(lib, test_cases)
# Fall back to the CPU test only when no device flag was given at all.
# The teco flag must be part of this check, otherwise running with only
# the teco flag would also trigger the CPU test.
if not (args.cpu or args.cuda or args.bang or args.ascend or args.teco):
    test_cpu(lib, test_cases)

src/ops/rotary_embedding/operator.cc

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,9 @@
1515
#ifdef ENABLE_ASCEND_NPU
1616
#include "ascend/rotary_embedding.h"
1717
#endif
18+
#ifdef ENABLE_TECO_SDAA
19+
#include "teco/rotary_embedding_sdaa.h"
20+
#endif
1821

1922
struct RoPEDescriptor {
2023
Device device;
@@ -52,6 +55,15 @@ __C infiniopStatus_t infiniopCreateRoPEDescriptor(infiniopHandle_t handle,
5255
sin_table,
5356
cos_table);
5457
}
58+
#endif
59+
#ifdef ENABLE_TECO_SDAA
60+
case DevTecoSDAA:
61+
return tecoCreateRoPEDescriptor((TecoHandle_t) handle,
62+
(RoPETecoDescriptor_t *) desc_ptr,
63+
t,
64+
pos_ids,
65+
sin_table,
66+
cos_table);
5567
#endif
5668
}
5769
return STATUS_BAD_DEVICE;
@@ -79,6 +91,11 @@ __C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
7991
return ascendGetRoPEWorkspaceSize((RoPEAscendDescriptor_t) desc,
8092
size);
8193
}
94+
#endif
95+
#ifdef ENABLE_TECO_SDAA
96+
case DevTecoSDAA:
97+
return tecoGetRoPEWorkspaceSize((RoPETecoDescriptor_t) desc,
98+
size);
8299
#endif
83100
}
84101
return STATUS_BAD_DEVICE;
@@ -119,6 +136,16 @@ __C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
119136
cos_table,
120137
stream);
121138
}
139+
#endif
140+
#ifdef ENABLE_TECO_SDAA
141+
case DevTecoSDAA:
142+
return tecoRoPE((RoPETecoDescriptor_t) desc, workspace,
143+
workspace_size,
144+
t,
145+
pos_ids,
146+
sin_table,
147+
cos_table,
148+
stream);
122149
#endif
123150
}
124151
return STATUS_BAD_DEVICE;
@@ -145,6 +172,10 @@ __C infiniopStatus_t infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc
145172
case DevAscendNpu: {
146173
return ascendDestroyRoPEDescriptor((RoPEAscendDescriptor_t) desc);
147174
}
175+
#endif
176+
#ifdef ENABLE_TECO_SDAA
177+
case DevTecoSDAA:
178+
return tecoDestroyRoPEDescriptor((RoPETecoDescriptor_t) desc);
148179
#endif
149180
}
150181
return STATUS_BAD_DEVICE;
Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,43 @@
1+
#ifndef __SDAA_ROPE_H__
2+
#define __SDAA_ROPE_H__
3+
#include "../../../devices/teco/teco_handle.h"
4+
#include "../../utils.h"
5+
#include "operators.h"
6+
#include <sdaa_runtime.h>
7+
struct RoPETecoDescriptor {
8+
Device device;
9+
int device_id;
10+
DT dtype;
11+
uint64_t seqlen;
12+
uint64_t nhead;
13+
uint64_t dhead;
14+
uint64_t total_seqlen;
15+
int x_stride_seqlen;
16+
int x_stride_nhead;
17+
};
18+
19+
typedef struct RoPETecoDescriptor *RoPETecoDescriptor_t;
20+
21+
22+
// Validates the tensor descriptors and, on success, stores a newly allocated
// descriptor in *desc_ptr. Returns a STATUS_BAD_TENSOR_* code when a shape,
// stride or dtype requirement is not met.
infiniopStatus_t tecoCreateRoPEDescriptor(
    TecoHandle_t handle,
    RoPETecoDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t t,
    infiniopTensorDescriptor_t pos_ids,
    infiniopTensorDescriptor_t sin_table,
    infiniopTensorDescriptor_t cos_table);

// Reports the workspace size required by tecoRoPE through *size.
infiniopStatus_t tecoGetRoPEWorkspaceSize(RoPETecoDescriptor_t desc,
                                          uint64_t *size);

// Applies the rotary embedding to `t` in place, reading positions from
// `pos_ids` and coefficients from `sin_table` / `cos_table`. The kernel is
// launched on `stream`.
infiniopStatus_t tecoRoPE(
    RoPETecoDescriptor_t desc,
    void *workspace,
    uint64_t workspace_size,
    void *t,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream);

// Frees a descriptor created by tecoCreateRoPEDescriptor.
infiniopStatus_t tecoDestroyRoPEDescriptor(RoPETecoDescriptor_t desc);
41+
42+
43+
#endif
Lines changed: 159 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,159 @@
1+
#include "rotary_embedding_sdaa.h"
2+
3+
// Vector staging buffers used by the RoPE kernel below.
// NOTE(review): `__local__` is presumably the SDAA qualifier for per-thread
// local memory, and the v16 types presumably hold 16 lanes (the kernel
// processes chunks of 16 elements) -- confirm against the SDAA docs.
__local__ halfv16 x_local;     // chunk of the input tensor, as loaded
__local__ halfv16 y_local;     // rotated chunk, converted back to half
__local__ floatv16 sin_local;  // matching chunk of the sin table
__local__ floatv16 cos_local;  // matching chunk of the cos table
__local__ floatv16 tmp_local;  // float working copy of the input chunk
5+
6+
// Creates the Teco SDAA RoPE descriptor after validating all tensor layouts.
//
// Requirements checked here:
//   - t is 3-D (seqlen, nhead, dhead) with an even dhead and unit stride on
//     the last dimension; dtype F16.
//   - pos_ids is 1-D with seqlen entries, unit stride, dtype U64.
//   - sin_table / cos_table are 2-D (total_seqlen, dhead), dtype F32, with
//     unit stride on the last dimension and densely packed rows.
//
// Returns STATUS_MEMORY_NOT_ALLOCATED when desc_ptr is null,
// STATUS_BAD_TENSOR_SHAPE / _STRIDES / _DTYPE when a check fails, and
// STATUS_SUCCESS after storing the new descriptor in *desc_ptr.
infiniopStatus_t tecoCreateRoPEDescriptor(TecoHandle_t handle,
                                          RoPETecoDescriptor_t *desc_ptr,
                                          infiniopTensorDescriptor_t t,
                                          infiniopTensorDescriptor_t pos_ids,
                                          infiniopTensorDescriptor_t sin_table,
                                          infiniopTensorDescriptor_t cos_table) {
    if (desc_ptr == nullptr) {
        return STATUS_MEMORY_NOT_ALLOCATED;
    }

    if (t->ndim != 3 ||
        pos_ids->ndim != 1 ||
        sin_table->ndim != 2 ||
        cos_table->ndim != 2) {
        return STATUS_BAD_TENSOR_SHAPE;
    }

    auto seqlen = t->shape[0];
    auto nhead = t->shape[1];
    auto dhead = t->shape[2];
    auto total_seqlen = sin_table->shape[0];

    // Elements are rotated in (even, odd) pairs, so dhead must be even.
    if (dhead % 2 != 0) {
        return STATUS_BAD_TENSOR_SHAPE;
    }

    if (pos_ids->shape[0] != seqlen ||
        sin_table->shape[1] != dhead ||
        cos_table->shape[1] != dhead ||
        sin_table->shape[0] != cos_table->shape[0]) {
        return STATUS_BAD_TENSOR_SHAPE;
    }

    if (t->strides[2] != 1 ||
        pos_ids->strides[0] != 1 ||
        sin_table->strides[1] != 1 ||
        cos_table->strides[1] != 1) {
        return STATUS_BAD_TENSOR_STRIDES;
    }

    // The kernel addresses table rows as `pos * dhead` and never receives the
    // table strides, so rows must be densely packed. With a single row the
    // row stride is never used, so it is not checked.
    if (total_seqlen > 1 &&
        (static_cast<uint64_t>(sin_table->strides[0]) != dhead ||
         static_cast<uint64_t>(cos_table->strides[0]) != dhead)) {
        return STATUS_BAD_TENSOR_STRIDES;
    }

    if (!dtype_eq(t->dt, F16)) {
        return STATUS_BAD_TENSOR_DTYPE;
    }

    if (!dtype_eq(sin_table->dt, F32) || !dtype_eq(cos_table->dt, F32)) {
        return STATUS_BAD_TENSOR_DTYPE;
    }

    if (!dtype_eq(pos_ids->dt, U64)) {
        return STATUS_BAD_TENSOR_DTYPE;
    }

    // The kernel takes the strides of `t` as int.
    int x_stride_seqlen = static_cast<int>(t->strides[0]);
    int x_stride_nhead = static_cast<int>(t->strides[1]);

    *desc_ptr = new RoPETecoDescriptor{
        handle->device,
        handle->device_id,
        t->dt,
        seqlen,
        nhead,
        dhead,
        total_seqlen,
        x_stride_seqlen,
        x_stride_nhead};

    return STATUS_SUCCESS;
}
64+
65+
// Reports the workspace size needed by tecoRoPE, which is always 0.
//
// Returns STATUS_MEMORY_NOT_ALLOCATED when `size` is null (the same status
// tecoCreateRoPEDescriptor uses for a null output pointer), otherwise
// STATUS_SUCCESS. `desc` is not inspected.
infiniopStatus_t tecoGetRoPEWorkspaceSize(RoPETecoDescriptor_t desc, uint64_t *size) {
    if (size == nullptr) {
        return STATUS_MEMORY_NOT_ALLOCATED;
    }
    *size = 0;
    return STATUS_SUCCESS;
}
69+
70+
// In-place rotary-embedding kernel.
//
// `destination` holds a (seqlen, nhead, dhead) half tensor whose last
// dimension is contiguous; x_stride_seqlen / x_stride_nhead are the element
// strides of the first two dimensions. For each (sequence position, head)
// row, the table row is selected by pos_ids[sequence position] and every
// (even, odd) element pair (a, b) is replaced by
//     (a * cos[2k] - b * sin[2k],  a * sin[2k+1] + b * cos[2k+1]).
// sin_table / cos_table rows are addressed as `pos * dhead`, i.e. they are
// assumed to be densely packed.
__global__ void RoPE(half *destination,
                     const uint64_t *pos_ids,
                     const float *sin_table, const float *cos_table,
                     int x_stride_seqlen, int x_stride_nhead,
                     int seqlen, int nhead, int dhead) {
    // Split the seqlen * nhead rows across threads: the first `remain`
    // threads take one extra row so the whole range is covered.
    int other_size = seqlen * nhead;
    int remain = other_size % threadDim;
    int step_easy = (other_size - remain) / threadDim;
    int step_hard = step_easy + 1;
    int step = (threadIdx < remain ? step_hard : step_easy);
    int ind_start = (threadIdx < remain
                         ? threadIdx * step_hard
                         : remain * step_hard + (threadIdx - remain) * step_easy);

    // Each row is processed in vector chunks of buf_size elements, followed
    // by a scalar tail for the remaining dhead % buf_size elements.
    int buf_size = 16;
    int remain_dhead = dhead % buf_size;
    int repeat = (dhead - remain_dhead) / buf_size;

    for (int i = ind_start; i < ind_start + step; i++) {
        // Decompose the flat row index into (sequence position, head) and
        // build the element offset of the row inside `destination`.
        int ind_i = i;
        int ind_s = 0;

        ind_s += (ind_i % nhead) * x_stride_nhead;
        ind_i /= nhead;
        ind_s += (ind_i % seqlen) * x_stride_seqlen;

        // Start of the sin/cos table row selected by this position id.
        int index = static_cast<int>(pos_ids[ind_i % seqlen]) * dhead;

        for (int r = 0; r < repeat; r++) {
            int start_s = ind_s + r * buf_size;
            int sin_cos_index = index + r * buf_size;

            simd_load(x_local, destination + start_s);
            simd_load(sin_local, sin_table + sin_cos_index);
            simd_load(cos_local, cos_table + sin_cos_index);

            // Do the arithmetic in float, then convert back to half.
            tmp_local = simd_cvt_h2f(x_local);

            for (int k = 0; k < buf_size / 2; k++) {
                float a = tmp_local[2 * k];
                float b = tmp_local[2 * k + 1];
                float sin0 = sin_local[2 * k], cos0 = cos_local[2 * k];
                float sin1 = sin_local[2 * k + 1], cos1 = cos_local[2 * k + 1];
                tmp_local[2 * k] = a * cos0 - b * sin0;
                tmp_local[2 * k + 1] = a * sin1 + b * cos1;
            }
            y_local = simd_cvt_f2h(tmp_local);
            simd_store(y_local, destination + start_s);
        }

        if (remain_dhead) {
            int start_s = ind_s + repeat * buf_size;
            int sin_cos_index = index + repeat * buf_size;
            for (int k = 0; k < remain_dhead / 2; k++) {
                float a = static_cast<float>(destination[start_s + 2 * k]);
                float b = static_cast<float>(destination[start_s + 2 * k + 1]);
                // The tail reads its coefficients straight from the global
                // tables. sin_cos_index is a table offset, so it must not be
                // used to index the 16-lane vector buffers sin_local and
                // cos_local.
                float sin0 = sin_table[sin_cos_index + 2 * k], cos0 = cos_table[sin_cos_index + 2 * k];
                float sin1 = sin_table[sin_cos_index + 2 * k + 1], cos1 = cos_table[sin_cos_index + 2 * k + 1];
                destination[start_s + 2 * k] = static_cast<half>(a * cos0 - b * sin0);
                destination[start_s + 2 * k + 1] = static_cast<half>(a * sin1 + b * cos1);
            }
        }
    }
}
132+
133+
// Launches the RoPE kernel on `stream`, rotating the tensor `t` in place.
//
// `workspace` and `workspace_size` are accepted for API symmetry but are not
// used here. Always returns STATUS_SUCCESS; the launch result is not checked.
infiniopStatus_t tecoRoPE(RoPETecoDescriptor_t desc,
                          void *workspace,
                          uint64_t workspace_size,
                          void *t,
                          void const *pos_ids,
                          void const *sin_table,
                          void const *cos_table,
                          void *stream) {
    // Give the untyped buffers the element types the kernel expects.
    half *const data = reinterpret_cast<half *>(t);
    const uint64_t *const positions = reinterpret_cast<const uint64_t *>(pos_ids);
    const float *const sines = reinterpret_cast<const float *>(sin_table);
    const float *const cosines = reinterpret_cast<const float *>(cos_table);

    // The kernel takes its extents as int.
    const int seq_len = static_cast<int>(desc->seqlen);
    const int head_count = static_cast<int>(desc->nhead);
    const int head_dim = static_cast<int>(desc->dhead);

    RoPE<<<1, (sdaaStream_t) stream>>>(data,
                                       positions,
                                       sines,
                                       cosines,
                                       desc->x_stride_seqlen,
                                       desc->x_stride_nhead,
                                       seq_len,
                                       head_count,
                                       head_dim);
    return STATUS_SUCCESS;
}
155+
156+
// Frees a descriptor allocated by tecoCreateRoPEDescriptor and reports
// STATUS_SUCCESS.
infiniopStatus_t tecoDestroyRoPEDescriptor(RoPETecoDescriptor_t desc) {
    // `delete` on a null pointer is a no-op, so no guard is needed.
    delete desc;
    return STATUS_SUCCESS;
}

0 commit comments

Comments
 (0)