Skip to content

Commit 4a7075f

Browse files
Merge pull request #80 from PanZezhong1725/bangRoPE
bangRoPE
2 parents 49ee9f2 + d629df9 commit 4a7075f

File tree

13 files changed

+603
-333
lines changed

13 files changed

+603
-333
lines changed

operatorspy/tests/rotary_embedding.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def rotary_embedding(t, pos, theta, torch_device):
4545
)
4646
freqs = torch.outer(pos, freqs)
4747
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
48-
48+
4949
t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
5050
freqs_cis = reshape_for_broadcast(freqs_cis, t_)
5151
t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
@@ -69,19 +69,31 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
6969
print(
7070
f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
7171
)
72-
t = torch.rand(shape, dtype=dtype, device=torch.device(torch_device))
72+
73+
t = torch.rand(shape, dtype=dtype)
7374
if strides is not None:
7475
t = rearrange_tensor(t, strides)
75-
pos = torch.arange(0, t.shape[0], device=torch.device(torch_device))
76+
pos = torch.arange(0, t.shape[0])
7677
theta = 1e4
77-
ans = rotary_embedding(t, pos, theta, torch_device)
78-
pos = pos.to(torch.int64) # use int64 to support older versions of PyTorch
78+
79+
if(torch_device == 'mlu'):
80+
ans = rotary_embedding(t, pos, theta, "cpu").to(torch_device)
81+
pos = pos.to(torch.int64)
82+
pos = pos.to(torch_device)
83+
t = t.to(torch_device)
84+
else:
85+
t = t.to(torch_device)
86+
pos = pos.to(torch_device)
87+
ans = rotary_embedding(t, pos, theta, torch_device)
88+
pos = pos.to(torch.uint64)
89+
7990
descriptor = infiniopRoPEDescriptor_t()
8091
# 2x table length for test
8192
sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
8293
t_tensor = to_tensor(t, lib)
8394
pos_tensor = to_tensor(pos, lib)
84-
pos_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
95+
if(torch_device == 'mlu'):
96+
pos_tensor.descriptor.contents.dt = U64
8597
sin_table_tensor = to_tensor(sin_table, lib)
8698
cos_table_tensor = to_tensor(cos_table, lib)
8799
check_error(
@@ -111,7 +123,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
111123
None,
112124
)
113125
)
114-
126+
115127
assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
116128
check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
117129
print("Test passed!")
@@ -135,32 +147,18 @@ def test_cuda(lib, test_cases):
135147

136148
def test_bang(lib, test_cases):
137149
import torch_mlu
138-
139150
device = DeviceEnum.DEVICE_BANG
140-
config = None
141-
descriptor = lib.createRotaryEmbeddingDescriptor(device, config)
142-
143-
# Note: BANG does not support complex calculation, compare with cpu results
144-
t = torch.rand((1, 32, 128), dtype=torch.float16)
145-
pos = torch.ones((1,), dtype=torch.int32)
146-
theta = 1e4
147-
ans = rotary_embedding(t, pos, theta, "cpu")
148-
149-
t = t.to("mlu")
150-
pos = pos.to("mlu")
151-
lib.rotaryEmbedding(
152-
descriptor, to_tensor(t, lib), to_tensor(pos, lib), c_float(theta), None
153-
)
154-
assert torch.allclose(t.cpu(), ans, atol=1e-3, rtol=1e-3)
155-
print("Test passed!")
156-
157-
lib.destroyRotaryEmbeddingDescriptor(descriptor)
151+
handle = create_handle(lib, device)
152+
for shape, strides, dtype in test_cases:
153+
test(lib, handle, "mlu", shape, strides, dtype)
154+
destroy_handle(lib, handle)
158155

159156

160157
if __name__ == "__main__":
161158
test_cases = [
162-
((1, 32, 128), None, torch.float16),
163159
((4, 1, 32), None, torch.float16),
160+
((1, 32, 128), None, torch.float16),
161+
164162
((3, 32, 128), (8000, 200, 1), torch.float16),
165163
]
166164
args = get_args()

src/devices/bang/handle_pool.cc

Lines changed: 0 additions & 23 deletions
This file was deleted.

src/devices/bang/handle_pool.h

Lines changed: 0 additions & 23 deletions
This file was deleted.

src/ops/matmul/bang/matmul_cnnl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "matmul_cnnl.h"
2+
#include "../../../devices/bang/bang_handle.h"
23
#include "../../../devices/bang/common_bang.h"
3-
#include "../../../devices/bang/handle_pool.h"
44
#include "../../utils.h"
55
#include "cnrt.h"
66
infiniopStatus_t bangCreateMatmulDescriptor(BangHandle_t handle,
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#include "rotary_embedding_bang.h"
2+
#include "../../utils.h"
3+
4+
5+
/**
 * @brief Validate the RoPE operands and build the BANG descriptor for them.
 *
 * Checks performed, in order:
 *   - desc_ptr is non-null;
 *   - ranks: t is 3-D, pos_ids is 1-D, sin_table and cos_table are 2-D;
 *   - t->shape[2] (dim) is even;
 *   - pos_ids has one entry per t->shape[0], both tables have `dim` columns
 *     and the same number of rows;
 *   - the innermost stride of every tensor is 1;
 *   - dtypes: t is F16, both tables are F32, pos_ids is U64;
 *   - t's two outer strides are representable in the descriptor's `int` fields.
 *
 * @param handle     BANG handle; its `device` and `device_id` are copied into the descriptor.
 * @param desc_ptr   Out-parameter that receives the newly allocated descriptor on success.
 * @param t          Descriptor of the tensor the embedding is applied to.
 * @param pos_ids    Descriptor of the position-id tensor.
 * @param sin_table  Descriptor of the sine table.
 * @param cos_table  Descriptor of the cosine table.
 * @return STATUS_SUCCESS on success; STATUS_MEMORY_NOT_ALLOCATED if desc_ptr is null;
 *         STATUS_BAD_TENSOR_SHAPE, STATUS_BAD_TENSOR_STRIDES or STATUS_BAD_TENSOR_DTYPE
 *         for the first failed check. *desc_ptr is left untouched on failure.
 */
infiniopStatus_t bangCreateRoPEDescriptor(BangHandle_t handle,
                                          RoPEBangDescriptor_t *desc_ptr,
                                          infiniopTensorDescriptor_t t,
                                          infiniopTensorDescriptor_t pos_ids,
                                          infiniopTensorDescriptor_t sin_table,
                                          infiniopTensorDescriptor_t cos_table) {

    if (desc_ptr == nullptr)
        return STATUS_MEMORY_NOT_ALLOCATED;

    if (t->ndim != 3 ||
        pos_ids->ndim != 1 ||
        sin_table->ndim != 2 ||
        cos_table->ndim != 2)
        return STATUS_BAD_TENSOR_SHAPE;

    auto seq_len = t->shape[0];
    auto nhead = t->shape[1];
    auto dim = t->shape[2];
    auto total_seq_len = sin_table->shape[0];

    // The last dimension must be even.
    if (dim % 2 != 0)
        return STATUS_BAD_TENSOR_SHAPE;

    if (pos_ids->shape[0] != seq_len ||
        sin_table->shape[1] != dim ||
        cos_table->shape[1] != dim ||
        sin_table->shape[0] != cos_table->shape[0])
        return STATUS_BAD_TENSOR_SHAPE;

    // Only tensors whose innermost dimension is contiguous are accepted.
    if (t->strides[2] != 1 ||
        pos_ids->strides[0] != 1 ||
        sin_table->strides[1] != 1 ||
        cos_table->strides[1] != 1)
        return STATUS_BAD_TENSOR_STRIDES;

    if (!dtype_eq(t->dt, F16))
        return STATUS_BAD_TENSOR_DTYPE;

    if (!dtype_eq(sin_table->dt, F32) || !dtype_eq(cos_table->dt, F32))
        return STATUS_BAD_TENSOR_DTYPE;

    if (!dtype_eq(pos_ids->dt, U64))
        return STATUS_BAD_TENSOR_DTYPE;

    int stride_0 = static_cast<int>(t->strides[0]);
    int stride_1 = static_cast<int>(t->strides[1]);

    // The descriptor keeps the outer strides as `int`. A stride that does not
    // survive the narrowing round-trip would be stored as a wrong value, so
    // report it instead of truncating silently.
    if (stride_0 != t->strides[0] || stride_1 != t->strides[1])
        return STATUS_BAD_TENSOR_STRIDES;

    *desc_ptr = new RoPEBangDescriptor{
        handle->device,
        handle->device_id,
        t->dt,
        seq_len,
        nhead,
        dim,
        total_seq_len,
        stride_0, stride_1};

    return STATUS_SUCCESS;
}
63+
64+
65+
/**
 * @brief Report the workspace size required by the BANG RoPE operator.
 *
 * This implementation needs no workspace, so it always reports 0.
 *
 * @param desc  RoPE descriptor (not read here).
 * @param size  Out-parameter that receives the workspace size.
 * @return STATUS_SUCCESS, or STATUS_MEMORY_NOT_ALLOCATED when `size` is null.
 */
infiniopStatus_t bangGetRoPEWorkspaceSize(RoPEBangDescriptor_t desc, uint64_t *size) {
    // Guard the out-parameter instead of dereferencing a null pointer; this is
    // the same status bangCreateRoPEDescriptor returns for a null out-pointer.
    if (size == nullptr)
        return STATUS_MEMORY_NOT_ALLOCATED;

    *size = 0;
    return STATUS_SUCCESS;
}
69+
70+
71+
/**
 * @brief Release a descriptor allocated by bangCreateRoPEDescriptor.
 *
 * @param desc  Descriptor to free; a null pointer is accepted and ignored.
 * @return Always STATUS_SUCCESS.
 */
infiniopStatus_t bangDestroyRoPEDescriptor(RoPEBangDescriptor_t desc) {
    if (desc != nullptr) {
        delete desc;
    }
    return STATUS_SUCCESS;
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#ifndef __BANG_ROTARY_EMBEDDING_H__
2+
#define __BANG_ROTARY_EMBEDDING_H__
3+
4+
#include "../../../devices/bang/bang_handle.h"
5+
#include "../../utils.h"
6+
#include "operators.h"
7+
8+
// Parameters recorded by bangCreateRoPEDescriptor once the operand
// descriptors have passed validation. Field order matters: the descriptor is
// filled with positional aggregate initialization.
struct RoPEBangDescriptor {
9+
// Device kind, copied from the handle.
Device device;
10+
// Device index, copied from the handle.
int device_id;
11+
// Data type of tensor t (creation only accepts F16).
DT dtype;
12+
// t->shape[0]; also the required length of pos_ids.
uint64_t seq_len;
13+
// t->shape[1].
uint64_t nhead;
14+
// t->shape[2]; must be even and equal to the column count of both tables.
uint64_t dim;
15+
// sin_table->shape[0] (cos_table must have the same number of rows).
uint64_t total_seq_len;
16+
// t->strides[0], narrowed to int.
int stride_0;
17+
// t->strides[1], narrowed to int.
int stride_1;
18+
};
19+
20+
21+
typedef struct RoPEBangDescriptor *RoPEBangDescriptor_t;
22+
23+
// Validates the rank, shape, stride and dtype of the four tensor descriptors
// and, on success, stores a newly allocated RoPEBangDescriptor in *desc_ptr.
// Returns STATUS_MEMORY_NOT_ALLOCATED when desc_ptr is null, or a
// STATUS_BAD_TENSOR_* code when a descriptor is rejected.
infiniopStatus_t bangCreateRoPEDescriptor(BangHandle_t handle,
24+
RoPEBangDescriptor_t *desc_ptr,
25+
infiniopTensorDescriptor_t t,
26+
infiniopTensorDescriptor_t pos_ids,
27+
infiniopTensorDescriptor_t sin_table,
28+
infiniopTensorDescriptor_t cos_table);
29+
30+
// Writes the workspace size required by bangRoPE to *size; the definition in
// this change always reports 0.
infiniopStatus_t bangGetRoPEWorkspaceSize(RoPEBangDescriptor_t desc, uint64_t *size);
31+
32+
// Runs the rotary positional embedding described by desc.
// NOTE(review): the definition is not visible here. Presumably t is updated in
// place (it is the only non-const data pointer), the pointers refer to device
// memory, and stream is a BANG/CNRT queue -- confirm against the implementation.
infiniopStatus_t bangRoPE(RoPEBangDescriptor_t desc,
33+
void *workspace,
34+
uint64_t workspace_size,
35+
void *t,
36+
void const *pos_ids,
37+
void const *sin_table,
38+
void const *cos_table,
39+
void *stream);
40+
41+
// Frees a descriptor obtained from bangCreateRoPEDescriptor.
infiniopStatus_t bangDestroyRoPEDescriptor(RoPEBangDescriptor_t desc);
42+
43+
44+
#endif// __BANG_ROTARY_EMBEDDING_H__

0 commit comments

Comments
 (0)