
Commit c5b8f43

Merge pull request #171 from InfiniTensor/causal_softmax_teco
causal_softmax operator for the Teco (太初) platform
2 parents d1292ce + d5ea800 commit c5b8f43

File tree

4 files changed: +310 −3 lines changed


operatorspy/tests/causal_softmax.py

Lines changed: 12 additions & 1 deletion
@@ -68,6 +68,7 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, x_dtype=torch.float1
             None,
         )
     )
+    #print(x.flatten()[0], ans.flatten()[0])
     assert torch.allclose(x, ans, atol=0, rtol=1e-2)
     check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))

@@ -106,6 +107,14 @@ def test_ascend(lib, test_cases):
         test(lib, handle, "npu", x_shape, x_stride)

     destroy_handle(lib, handle)
+def test_teco(lib, test_cases):
+    import torch_sdaa
+    device = DeviceEnum.DEVICE_TECO
+    handle = create_handle(lib, device)
+    for x_shape, x_stride in test_cases:
+        test(lib, handle, "sdaa", x_shape, x_stride)
+
+    destroy_handle(lib, handle)

 if __name__ == "__main__":
     test_cases = [
@@ -147,6 +156,8 @@ def test_ascend(lib, test_cases):
         test_bang(lib, test_cases)
     if args.ascend:
         test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
+    if args.teco:
+        test_teco(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.teco):
         test_cpu(lib, test_cases)
     print("Test passed!")

src/ops/causal_softmax/operator.cc

Lines changed: 21 additions & 2 deletions
@@ -7,8 +7,8 @@
 #endif
 #ifdef ENABLE_NV_GPU
 #include "../../devices/cuda/common_cuda.h"
-#include "cuda/causal_softmax.cuh"
 #include "../../devices/cuda/cuda_handle.h"
+#include "cuda/causal_softmax.cuh"
 #endif
 #ifdef ENABLE_CAMBRICON_MLU
 #include "../../devices/bang/bang_handle.h"
@@ -18,6 +18,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/causal_softmax_aclnn.h"
 #endif
+#ifdef ENABLE_TECO_SDAA
+#include "teco/causal_softmax_sdaa.h"
+#endif

 __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
     infiniopHandle_t handle,
@@ -30,7 +33,7 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
 #endif
 #ifdef ENABLE_NV_GPU
     case DevNvGpu: {
-        return cudaCreateCausalSoftmaxDescriptor((CudaHandle_t)handle, (CausalSoftmaxCudaDescriptor_t *) desc_ptr, y_desc);
+        return cudaCreateCausalSoftmaxDescriptor((CudaHandle_t) handle, (CausalSoftmaxCudaDescriptor_t *) desc_ptr, y_desc);
     }

 #endif
@@ -44,6 +47,10 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
     case DevAscendNpu: {
         return aclnnCreateCausalSoftmaxDescriptor((AscendHandle_t) handle, (CausalSoftmaxAclnnDescriptor_t *) desc_ptr, y_desc);
     }
+#endif
+#ifdef ENABLE_TECO_SDAA
+    case DevTecoSDAA:
+        return tecoCreateCausalSoftmaxDescriptor((TecoHandle_t) handle, (CausalSoftmaxTecoDescriptor_t *) desc_ptr, y_desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -72,6 +79,10 @@ __C infiniopStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmax
     case DevAscendNpu: {
         return aclnnGetCausalSoftmaxWorkspaceSize((CausalSoftmaxAclnnDescriptor_t) desc, size);
     }
+#endif
+#ifdef ENABLE_TECO_SDAA
+    case DevTecoSDAA:
+        return tecoGetCausalSoftmaxWorkspaceSize((CausalSoftmaxTecoDescriptor_t) desc, size);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -99,6 +110,10 @@ __C infiniopStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t des
     case DevAscendNpu: {
         return aclnnCausalSoftmax((CausalSoftmaxAclnnDescriptor_t) desc, workspace, workspace_size, data, stream);
     }
+#endif
+#ifdef ENABLE_TECO_SDAA
+    case DevTecoSDAA:
+        return tecoCausalSoftmax((CausalSoftmaxTecoDescriptor_t) desc, workspace, workspace_size, data, stream);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -126,6 +141,10 @@ __C infiniopStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftma
     case DevAscendNpu: {
         return aclnnDestroyCausalSoftmaxDescriptor((CausalSoftmaxAclnnDescriptor_t) desc);
     }
+#endif
+#ifdef ENABLE_TECO_SDAA
+    case DevTecoSDAA:
+        return tecoDestroyCausalSoftmaxDescriptor((CausalSoftmaxTecoDescriptor_t) desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
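For orientation, every backend dispatched above is driven through the same four-call sequence. A minimal caller sketch, assuming handle, y_desc, data, and stream were created elsewhere, that the SDAA runtime offers a CUDA-style sdaaMalloc, and with CHECK as a hypothetical status-checking macro (none of these are part of this commit):

    // Hypothetical driver; illustrates the API surface dispatched above.
    infiniopCausalSoftmaxDescriptor_t desc;
    CHECK(infiniopCreateCausalSoftmaxDescriptor(handle, &desc, y_desc));

    uint64_t workspace_size = 0;
    CHECK(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size));

    void *workspace = nullptr;
    sdaaMalloc(&workspace, workspace_size);   // assumed CUDA-style allocator

    // causal_softmax runs in place on `data`.
    CHECK(infiniopCausalSoftmax(desc, workspace, workspace_size, data, stream));
    CHECK(infiniopDestroyCausalSoftmaxDescriptor(desc));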
src/ops/causal_softmax/teco/causal_softmax_sdaa.h

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+#ifndef __SDAA_CAUSAL_SOFTMAX_H__
+#define __SDAA_CAUSAL_SOFTMAX_H__
+#include "../../../devices/teco/teco_handle.h"
+#include "../../utils.h"
+#include "operators.h"
+#include <sdaa_runtime.h>
+struct CausalSoftmaxTecoDescriptor {
+    Device device;
+    int device_id;
+    DT dtype;
+    int ndim;
+    int *stride;
+    int *shape;
+};
+
+typedef struct CausalSoftmaxTecoDescriptor *CausalSoftmaxTecoDescriptor_t;
+
+
+infiniopStatus_t tecoCreateCausalSoftmaxDescriptor(TecoHandle_t handle,
+                                                   CausalSoftmaxTecoDescriptor_t *desc_ptr,
+                                                   infiniopTensorDescriptor_t y_desc);
+
+infiniopStatus_t tecoGetCausalSoftmaxWorkspaceSize(CausalSoftmaxTecoDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t tecoCausalSoftmax(CausalSoftmaxTecoDescriptor_t desc,
+                                   void *workspace,
+                                   uint64_t workspace_size,
+                                   void *data,
+                                   void *stream);
+
+infiniopStatus_t tecoDestroyCausalSoftmaxDescriptor(CausalSoftmaxTecoDescriptor_t desc);
+
+
+#endif
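As declared here and implemented in the file that follows, the workspace contract is small: tecoGetCausalSoftmaxWorkspaceSize reports 2 * ndim * sizeof(int) bytes, which tecoCausalSoftmax fills with device-side copies of the descriptor's stride and shape arrays, laid out back to back:

    // Workspace layout (derived from the implementation below):
    // [ stride[0] .. stride[ndim-1] | shape[0] .. shape[ndim-1] ]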
Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
+#include "causal_softmax_sdaa.h"
+
+__local__ halfv16 h_local;
+__local__ floatv16 f_local;
+
+infiniopStatus_t tecoCreateCausalSoftmaxDescriptor(TecoHandle_t handle,
+                                                   CausalSoftmaxTecoDescriptor_t *desc_ptr,
+                                                   infiniopTensorDescriptor_t y_desc){
+    if (y_desc->ndim < 2 || y_desc->shape[y_desc->ndim - 1] < y_desc->shape[y_desc->ndim - 2]) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    int ndim = y_desc->ndim;
+    int *shape = (int *)malloc(ndim * sizeof(int));
+    int *stride = (int *)malloc(ndim * sizeof(int));
+
+
+    for (int i = 0; i < ndim; i++) {
+        stride[i] = static_cast<int>(y_desc->strides[i]);
+        shape[i] = static_cast<int>(y_desc->shape[i]);
+    }
+
+    *desc_ptr = new CausalSoftmaxTecoDescriptor{
+        handle->device,
+        handle->device_id,
+        y_desc->dt,
+        ndim,
+        stride,
+        shape};
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t tecoGetCausalSoftmaxWorkspaceSize(CausalSoftmaxTecoDescriptor_t desc, uint64_t *size) {
+    *size = desc->ndim * sizeof(int) * 2;
+    return STATUS_SUCCESS;
+}
+
+template<typename T>
+__global__ void causalSoftmax(T *destination, int *shape, int *stride, int ndim, int mask){
+    int othersize = 1;
+    for(int i = 0; i < ndim - 1; i++){
+        othersize *= shape[i];
+    }
+    int remain = othersize % threadDim;
+    int step_easy = (othersize - remain) / threadDim;
+    int step_hard = step_easy + 1;
+    int step = (threadIdx < remain ? step_hard : step_easy);
+    int ind_start = (threadIdx < remain ? threadIdx * step_hard : (remain * step_hard + (threadIdx - remain) * step_easy));
+
+    int dimsize = shape[ndim - 1];
+    int buf_size = 16;
+
+    for (int i = ind_start; i < ind_start + step; i++) {
+        int ind_d = 0;
+        int ind_i = i;
+        int lastI = ind_i % shape[ndim - 2];
+
+        int remain_dhead = (lastI + mask + 1) % buf_size;
+        int repeat = (lastI + mask + 1 - remain_dhead) / buf_size;// softmax is applied over this leading part
+
+        int length = dimsize - (lastI + mask + 1);
+        int remainI = length % buf_size;
+        int rI = (length - remainI) / buf_size;// this trailing part is set to zero
+
+        for (int j = ndim - 2; j >= 0; --j) {
+            ind_d += (ind_i % shape[j]) * stride[j];
+            ind_i /= shape[j];
+        }
+        // compute the running max and sum below
+
+        float new_max = destination[ind_d];
+        float old_max = new_max;
+        float sum_value = 0.0f;
+        for(int r = 0; r < repeat; r++){
+            int start = ind_d + r * buf_size;
+            if constexpr (std::is_same<T, half>::value){
+                simd_load(h_local, destination + start);
+                f_local = simd_cvt_h2f(h_local);
+            }
+            else if constexpr (std::is_same<T, float>::value){
+                simd_load(f_local, destination + start);
+            }
+            for(int k = 0; k < buf_size; k++){
+                if(new_max < f_local[k]){
+                    new_max = f_local[k];
+                }
+            }
+            for(int k = 0; k < buf_size; k++){
+                f_local[k] = expf(f_local[k] - new_max);
+            }
+            if(r > 0){
+                sum_value = sum_value * expf(old_max - new_max);
+            }
+            sum_value += simd_redsum(f_local);
+            old_max = new_max;
+        }
+        if(remain_dhead){
+            int start = ind_d + repeat * buf_size;
+            for(int k = 0; k < remain_dhead; k++){
+                if constexpr (std::is_same<T, half>::value){
+                    if (new_max < static_cast<float>(destination[start + k])){
+                        new_max = static_cast<float>(destination[start + k]);
+                    }
+                }
+                else if constexpr (std::is_same<T, float>::value){
+                    if (new_max < destination[start + k]){
+                        new_max = destination[start + k];
+                    }
+                }
+            }
+            if (repeat > 0){
+                sum_value = sum_value * expf(old_max - new_max);
+            }
+            for(int k = 0; k < remain_dhead; k++){
+                if constexpr (std::is_same<T, half>::value){
+                    sum_value += expf(static_cast<float>(destination[start + k]) - new_max);
+                }
+                else if constexpr (std::is_same<T, float>::value){
+                    sum_value += expf(destination[start + k] - new_max);
+                }
+            }
+        }
+
+        float sum_inv = 1.0f / sum_value;
+        // apply the softmax normalization below
+        for(int r = 0; r < repeat; r++){
+            int start = ind_d + r * buf_size;
+            if constexpr (std::is_same<T, half>::value){
+                simd_load(h_local, destination + start);
+                f_local = simd_cvt_h2f(h_local);
+            }
+            else if constexpr (std::is_same<T, float>::value){
+                simd_load(f_local, destination + start);
+            }
+
+            for(int k = 0; k < buf_size; k++){
+                f_local[k] = expf(f_local[k] - new_max) * sum_inv;
+            }
+            if constexpr (std::is_same<T, half>::value){
+                h_local = simd_cvt_f2h(f_local);
+                simd_store(h_local, destination + start);
+            }
+            else if constexpr (std::is_same<T, float>::value){
+                simd_store(f_local, destination + start);
+            }
+        }
+        if(remain_dhead){
+            int start = ind_d + repeat * buf_size;
+            for(int k = 0; k < remain_dhead; k++){
+                if constexpr (std::is_same<T, half>::value){
+                    destination[start + k] = static_cast<half>(expf(static_cast<float>(destination[start + k]) - new_max) * sum_inv);
+                }
+                else if constexpr (std::is_same<T, float>::value){
+                    destination[start + k] = expf(destination[start + k] - new_max) * sum_inv;
+                }
+            }
+
+        }
+
+        // set the remaining (masked) part to zero
+        for(int r = 0; r < rI; r++){
+            int start = ind_d + mask + 1 + lastI + r * buf_size;
+            if constexpr (std::is_same<T, half>::value){
+                for(int k = 0; k < buf_size; k++){
+                    destination[start + k] = static_cast<half>(0.0f);
+                }
+            }
+            else if constexpr (std::is_same<T, float>::value){
+                for(int k = 0; k < buf_size; k++){
+                    destination[start + k] = 0.0f;
+                }
+            }
+            /***
+            if constexpr (std::is_same<T, half>::value){
+                simd_load(h_local, destination + start);
+                for(int k = 0; k < buf_size; k++){
+                    h_local[k] = static_cast<half>(0.0f);
+                }
+                simd_store(h_local, destination + start);
+            }
+            else if constexpr (std::is_same<T, float>::value){
+                simd_load(f_local, destination + start);
+                for(int k = 0; k < buf_size; k++){
+                    f_local[k] = 0.0f;
+                }
+                simd_store(f_local, destination + start);
+            }
+            ***/
+        }
+
+        if (remainI){
+            int start = ind_d + mask + 1 + lastI + rI * buf_size;
+            if constexpr (std::is_same<T, half>::value){
+                for(int k = 0; k < remainI; k++){
+                    destination[start + k] = static_cast<half>(0.0f);
+                }
+            }
+            else if constexpr (std::is_same<T, float>::value){
+                for(int k = 0; k < remainI; k++){
+                    destination[start + k] = 0.0f;
+                }
+            }
+        }
+
+    }
+}
+
+infiniopStatus_t tecoCausalSoftmax(CausalSoftmaxTecoDescriptor_t desc,
+                                   void *workspace,
+                                   uint64_t workspace_size,
+                                   void *data,
+                                   void *stream){
+    int ndim = desc->ndim;
+    int mask = desc->shape[ndim - 1] - desc->shape[ndim - 2];
+
+    int *teco_stride = reinterpret_cast<int *>(workspace);
+    int *teco_shape = teco_stride + ndim;
+
+    sdaaMemcpy(teco_stride, desc->stride, ndim * sizeof(int), sdaaMemcpyHostToDevice);
+    sdaaMemcpy(teco_shape, desc->shape, ndim * sizeof(int), sdaaMemcpyHostToDevice);
+    sdaaDeviceSynchronize();
+    if(dtype_eq(desc->dtype, F16)){
+        auto destination = reinterpret_cast<half *>(data);
+        causalSoftmax<half><<<1, (sdaaStream_t)stream>>>(destination, teco_shape, teco_stride, ndim, mask);
+        sdaaDeviceSynchronize();
+        return STATUS_SUCCESS;
+    }
+    else if(dtype_eq(desc->dtype, F32)){
+        auto destination = reinterpret_cast<float *>(data);
+        causalSoftmax<float><<<1, (sdaaStream_t)stream>>>(destination, teco_shape, teco_stride, ndim, mask);
+        sdaaDeviceSynchronize();
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
+
+infiniopStatus_t tecoDestroyCausalSoftmaxDescriptor(CausalSoftmaxTecoDescriptor_t desc){
+    //free(desc->stride);
+    //free(desc->shape);
+    delete desc;
+    return STATUS_SUCCESS;
+}
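The kernel's semantics, in short: for a score matrix of shape H x W with W >= H (enforced at descriptor creation), row r keeps its first r + (W - H) + 1 entries, normalizes them with a numerically stable softmax (a running max with rescaled partial sums, processed in vectorized chunks of 16), and zeroes the masked tail. A single-threaded host reference with the same semantics, useful for eyeballing correctness, might look like this; the contiguous row-major layout and the function name are assumptions, not part of the commit:

    // Hypothetical host-side reference; mirrors the kernel's per-row semantics.
    #include <algorithm>
    #include <cmath>

    void causalSoftmaxRef(float *x, int H, int W) {
        int mask = W - H;                      // same definition as in tecoCausalSoftmax
        for (int r = 0; r < H; ++r) {
            float *row = x + r * W;
            int valid = r + mask + 1;          // unmasked entries in this row
            float max_v = row[0];
            for (int j = 1; j < valid; ++j) max_v = std::max(max_v, row[j]);
            float sum = 0.0f;
            for (int j = 0; j < valid; ++j) sum += std::exp(row[j] - max_v);
            for (int j = 0; j < valid; ++j) row[j] = std::exp(row[j] - max_v) / sum;
            for (int j = valid; j < W; ++j) row[j] = 0.0f;   // masked-out tail
        }
    }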
