Skip to content

Commit 6f6019a

Browse files
committed
implement swiglu inplace and out of place operation
1 parent 817a2ba commit 6f6019a

File tree

8 files changed

+79
-104
lines changed

8 files changed

+79
-104
lines changed

operatorspy/tests/swiglu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def test_in_place1(
124124
descriptor, a_tensor.data, a_tensor.data, b_tensor.data, None
125125
)
126126
)
127-
128127
assert torch.allclose(a, ans, atol=1e-4, rtol=1e-2)
129128
print("in-place1 Test passed!")
130129

src/devices/teco/tensor_teco.cpp

Lines changed: 0 additions & 23 deletions
This file was deleted.

src/devices/teco/tensor_teco.h

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/ops/swiglu/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "ascend/swiglu.h"
1616
#endif
1717
#ifdef ENABLE_TECO_SDAA
18-
#include "teco/swiglu_tecodnn.h"
18+
#include "teco/swiglu_sdaa.h"
1919
#endif
2020

2121
__C infiniopStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle,
Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
1-
#ifndef __TECO_SWIGLU_H__
2-
#define __TECO_SWIGLU_H__
3-
1+
#ifndef __SDAA_SWIGLU_H__
2+
#define __SDAA_SWIGLU_H__
43
#include "operators.h"
54
#include <sdaa_runtime.h>
6-
#include <tecodnn.h>
75
#include "../../../devices/teco/teco_handle.h"
86
struct SwiGLUTecoDescriptor {
97
Device device;
108
int device_id;
119
sdaaStream_t stream;
12-
tecodnnHandle_t handle;
13-
tecodnnActivationDescriptor_t activationDesc;
14-
tecodnnTensorDescriptor_t aDesc,bDesc,cDesc;
10+
uint64_t rows,cols;
11+
int64_t lda,ldb,ldc;
1512
};
1613

1714
typedef struct SwiGLUTecoDescriptor *SwiGLUTecoDescriptor_t;
1815

16+
1917
infiniopStatus_t tecoCreateSwiGLUDescriptor(TecoHandle_t handle,
2018
SwiGLUTecoDescriptor_t *desc_ptr,
21-
infiniopTensorDescriptor_t c_desc,
19+
infiniopTensorDescriptor_t c_desc,
2220
infiniopTensorDescriptor_t a_desc,
2321
infiniopTensorDescriptor_t b_desc);
2422

@@ -30,4 +28,5 @@ infiniopStatus_t tecoSwiGLU(SwiGLUTecoDescriptor_t desc,
3028

3129
infiniopStatus_t tecoDestroySwiGLUDescriptor(SwiGLUTecoDescriptor_t desc);
3230

31+
3332
#endif
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#include "swiglu_sdaa.h"
2+
__local__ halfv16 tempa, tempb, tempc;
3+
4+
__device__ void silu_halfv16(halfv16 *c, halfv16 *a, halfv16 *b) {
5+
floatv16 one_v = simd_stretch(1.0f);
6+
floatv16 a_silu = simd_div(simd_cvt_h2f(*b), simd_add(one_v, simd_exp(0 - simd_cvt_h2f(*b))));
7+
halfv16 out = simd_cvt_f2h(simd_mul(simd_cvt_h2f(*a), a_silu));
8+
*c = out;
9+
}
10+
11+
__device__ void silu_half(half *c, const half *a, const half *b) {
12+
*c = (*b) * (*a)/ (1.0 + expf(0 - *b)) ;
13+
}
14+
15+
__global__ void swiglu(half *c, half const *a, half const *b, size_t rows, size_t cols, size_t lda, size_t ldb, size_t ldc) {
16+
int vector_size = 16;
17+
for (size_t i = 0; i < rows / threadDim + 1; i++) {
18+
if (threadIdx < rows - i * threadDim) {
19+
size_t j = 0;
20+
for (; j < cols / vector_size; j++) {
21+
simd_load(tempa, a + (threadIdx + i * threadDim) * lda + j * vector_size);
22+
simd_load(tempb, b + (threadIdx + i * threadDim) * ldb + j * vector_size);
23+
silu_halfv16(&tempc, &tempa, &tempb);
24+
simd_store(tempc, c + (threadIdx + i * threadDim) * ldc + j * vector_size);
25+
}
26+
for (size_t k = 0; k < cols - j * vector_size; k++)
27+
{
28+
silu_half(
29+
c + (threadIdx + i * threadDim) * ldc + j * vector_size + k,
30+
a + (threadIdx + i * threadDim) * lda + j * vector_size + k,
31+
b + (threadIdx + i * threadDim) * ldb + j * vector_size + k);
32+
}
33+
34+
}
35+
}
36+
}
37+
38+
infiniopStatus_t tecoCreateSwiGLUDescriptor(TecoHandle_t handle,
39+
SwiGLUTecoDescriptor_t *desc_ptr,
40+
infiniopTensorDescriptor_t c_desc,
41+
infiniopTensorDescriptor_t a_desc,
42+
infiniopTensorDescriptor_t b_desc) {
43+
*desc_ptr = new SwiGLUTecoDescriptor{
44+
handle->device,
45+
handle->device_id,
46+
handle->stream,
47+
a_desc->shape[0],
48+
a_desc->shape[1],
49+
a_desc->strides[0],
50+
b_desc->strides[0],
51+
c_desc->strides[0],
52+
};
53+
return STATUS_SUCCESS;
54+
}
55+
56+
infiniopStatus_t tecoSwiGLU(SwiGLUTecoDescriptor_t desc,
57+
void *c,
58+
void const *a,
59+
void const *b,
60+
void *stream) {
61+
auto a_ptr = reinterpret_cast<const half *>(a);
62+
auto b_ptr = reinterpret_cast<const half *>(b);
63+
auto c_ptr = reinterpret_cast<half *>(c);
64+
swiglu<<<1>>>(c_ptr, a_ptr, b_ptr, desc->rows, desc->cols, desc->lda, desc->ldb, desc->ldc);
65+
return STATUS_SUCCESS;
66+
}
67+
68+
infiniopStatus_t tecoDestroySwiGLUDescriptor(SwiGLUTecoDescriptor_t desc) {
69+
return STATUS_SUCCESS;
70+
}

src/ops/swiglu/teco/swiglu_tecodnn.cc

Lines changed: 0 additions & 44 deletions
This file was deleted.

xmake.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ if has_config("teco") then
243243
local cc = "/opt/tecoai/bin/tecocc"
244244

245245
local includedirs = table.concat(target:get("includedirs"), " ")
246-
local args = {sourcefile, "-o", objectfile}
246+
local args = {sourcefile, "-o", objectfile,"-O2", "-fPIC", "-Wall", "-Werror", "-std=c++17", "-pthread","-c"}
247247

248248
for _, includedir in ipairs(target:get("includedirs")) do
249249
table.insert(args, "-I" .. includedir)

0 commit comments

Comments
 (0)