Commit 89e49e3

Merge pull request #58 from PanZezhong1725/issue/48-test
issue/48/test: Refactor the RoPE test script
2 parents b3941ed + eb1ae65 commit 89e49e3

File tree

1 file changed: +80 -94 lines changed

test/infiniop/rotary_embedding.py

Lines changed: 80 additions & 94 deletions
@@ -1,27 +1,31 @@
 import ctypes
-from ctypes import c_float, POINTER, c_void_p, c_int32, c_uint64, Structure, byref
+from ctypes import POINTER, c_void_p, c_int32, c_uint64, Structure, byref
 import sys
 import os
 
-
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
-from operatorspy import (
-    open_lib,
-    to_tensor,
-    DeviceEnum,
+from libinfiniop import (
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
-    create_handle,
-    destroy_handle,
+    open_lib,
+    to_tensor,
+    get_test_devices,
     check_error,
-    rearrange_tensor,
+    rearrange_if_needed,
     create_workspace,
-    U64,
+    test_operator,
+    get_args,
+    debug,
+    profile_operation,
+    InfiniDtype,
 )
-
-from operatorspy.tests.test_utils import get_args
 import torch
 
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
 
 class RoPEDescriptor(Structure):
     _fields_ = [("device", c_int32)]
@@ -40,15 +44,21 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
 
 def rotary_embedding(t, pos, theta, torch_device):
     dh = t.shape[2]
-    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2)[: (dh // 2)].float() / dh))).to(
-        torch_device
-    )
-    freqs = torch.outer(pos, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    assert dh % 2 == 0, "Embedding dimension must be even."
+    t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
+    t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
+    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device)
+    freqs = torch.outer(pos, freqs)  # [seq_len, dh // 2]
+    cos = torch.cos(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
+    sin = torch.sin(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
+
+    t_out_even = t_even * cos - t_odd * sin
+    t_out_odd = t_even * sin + t_odd * cos
+
+    t_out = torch.empty_like(t)
+    t_out[..., 0::2] = t_out_even
+    t_out[..., 1::2] = t_out_odd
 
-    t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, t_)
-    t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
     return t_out
 
 
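The rewritten reference replaces the torch.polar / view_as_complex formulation with an explicit even/odd computation, which avoids requiring complex-tensor support from the backend. The two are algebraically identical: multiplying the pair (a, b), read as a + b·i, by cos θ + sin θ·i yields (a·cos θ - b·sin θ, a·sin θ + b·cos θ), exactly what the new code writes to the even and odd slots. A minimal standalone check of that equivalence; rope_complex is a hypothetical helper written here for illustration and is not part of the test script:

import torch

def rope_complex(t, pos, theta):
    # Old-style reference: rotate adjacent (even, odd) pairs as complex numbers.
    dh = t.shape[2]
    freqs = 1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))
    freqs_cis = torch.polar(torch.ones(pos.shape[0], dh // 2), torch.outer(pos, freqs))
    t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
    return torch.view_as_real(t_ * freqs_cis.unsqueeze(1)).flatten(2)

t = torch.rand(4, 2, 8)
pos = torch.arange(0, 4).float()
assert torch.allclose(
    rope_complex(t, pos, 1e4),
    rotary_embedding(t, pos, 1e4, "cpu"),  # the reference defined in the hunk above
    atol=1e-6,
)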
@@ -71,29 +81,23 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
     )
 
     t = torch.rand(shape, dtype=dtype)
-    if strides is not None:
-        t = rearrange_tensor(t, strides)
-    posTmp = torch.arange(0, t.shape[0])
+    t = rearrange_if_needed(t, strides).to(torch_device)
+    posTmp = torch.arange(0, t.shape[0]).to(torch_device)
     pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
     for i in range(posTmp.shape[0]):
         pos[2 * i] = posTmp[i]
         pos[2 * i + 1] = 0
+    pos = pos.to(torch_device)
     theta = 1e4
-    if torch_device == "mlu" or torch_device == "npu":
-        ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
-        pos = pos.to(torch_device)
-        t = t.to(torch_device)
-    else:
-        t = t.to(torch_device)
-        pos = pos.to(torch_device)
-        ans = rotary_embedding(t, posTmp.to(torch_device), theta, torch_device)
+
+    ans = rotary_embedding(t, posTmp, theta, torch_device)
 
     descriptor = infiniopRoPEDescriptor_t()
     # 2x table length for test
     sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
     t_tensor = to_tensor(t, lib)
     pos_tensor = to_tensor(pos[: t.shape[0]], lib)
-    pos_tensor.descriptor.contents.dt = U64
+    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
     sin_table_tensor = to_tensor(sin_table, lib)
     cos_table_tensor = to_tensor(cos_table, lib)
 
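Two details in this hunk are easy to miss. First, pos is built at twice the sequence length with a zero interleaved after each int32 position, and the descriptor dtype of pos_tensor is then set to InfiniDtype.U64; each (int32 position, int32 zero) pair is apparently meant to be read as a single little-endian 64-bit index. Second, the sin/cos tables come from sin_cos_table, whose definition lies outside this diff; below is a sketch of what it is presumably expected to compute, consistent with the freqs formula in the reference above (sin_cos_table_sketch is a hypothetical name):

import torch

def sin_cos_table_sketch(max_pos, dh, device, theta):
    # Hypothetical: angles[p, j] = p / theta**(2j / dh), one row per position.
    freqs = 1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))  # [dh // 2]
    angles = torch.outer(torch.arange(max_pos).float(), freqs)      # [max_pos, dh // 2]
    return torch.sin(angles).to(device), torch.cos(angles).to(device)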
@@ -122,69 +126,52 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
         lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
     )
     workspace = create_workspace(workspace_size.value, t.device)
-    check_error(
-        lib.infiniopRoPE(
-            descriptor,
-            workspace.data_ptr() if workspace is not None else None,
-            workspace_size.value,
-            t_tensor.data,
-            pos_tensor.data,
-            sin_table_tensor.data,
-            cos_table_tensor.data,
-            None,
+
+    def lib_rope():
+        check_error(
+            lib.infiniopRoPE(
+                descriptor,
+                workspace.data_ptr() if workspace is not None else None,
+                workspace_size.value,
+                t_tensor.data,
+                pos_tensor.data,
+                sin_table_tensor.data,
+                cos_table_tensor.data,
+                None,
+            )
         )
-    )
 
+    lib_rope()
+    if DEBUG:
+        debug(t, ans, atol=1e-4, rtol=1e-2)
     assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
-    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
-
-
-def test_cpu(lib, test_cases):
-    device = DeviceEnum.DEVICE_CPU
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "cpu", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_cuda(lib, test_cases):
-    device = DeviceEnum.DEVICE_CUDA
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "cuda", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_bang(lib, test_cases):
-    import torch_mlu
-
-    device = DeviceEnum.DEVICE_BANG
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "mlu", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_ascend(lib, test_cases):
-    import torch_npu
+    if PROFILE:
+        profile_operation(
+            "PyTorch",
+            lambda: rotary_embedding(t, posTmp, theta, torch_device),
+            torch_device,
+            NUM_PRERUN,
+            NUM_ITERATIONS,
+        )
+        profile_operation(
+            " lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
+        )
 
-    device = DeviceEnum.DEVICE_ASCEND
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "npu", shape, strides, dtype)
-    destroy_handle(lib, handle)
+    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
 
 
 if __name__ == "__main__":
     test_cases = [
-        ((1, 32, 128), None, torch.float16),
-        ((1, 32, 64), None, torch.float16),
+        # (t_shape, t_strides)
+        ((1, 32, 128), None),
+        ((1, 32, 64), None),
         # Ascend does not yet support this case: a last dimension <= 32 fails,
         # possibly related to the internal implementation of its core GatherMask interface; 48, 64, and 128 all work at present.
-        ((4, 1, 32), None, torch.float16),
-        ((1, 32, 128), None, torch.float16),
-        ((3, 32, 128), (8000, 200, 1), torch.float16),
+        ((4, 1, 32), None),
+        ((1, 32, 128), None),
+        ((3, 32, 128), (8000, 200, 1)),
     ]
+    test_dtypes = [torch.float16]
     args = get_args()
     lib = open_lib()
     lib.infiniopCreateRoPEDescriptor.restype = c_int32
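The four near-identical per-device drivers are collapsed into one loop over get_test_devices(args) and test_operator, both imported from libinfiniop. Neither definition appears in this diff; the sketch below shows a plausible shape for the driver, assuming libinfiniop still provides handle helpers like the removed create_handle and destroy_handle:

def test_operator_sketch(lib, device, test_fn, test_cases, test_dtypes):
    # Hypothetical driver: one handle per device, each case crossed with each dtype.
    handle = create_handle(lib, device)  # assumed helper inside libinfiniop
    try:
        for shape, strides in test_cases:
            for dtype in test_dtypes:
                test_fn(lib, handle, device, shape, strides, dtype)
    finally:
        destroy_handle(lib, handle)  # assumed helper inside libinfiniop

Factoring the dtype out of each test case into the separate test_dtypes list, as done in the hunk above, is what allows this case-by-dtype cross product.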
@@ -216,14 +203,13 @@ def test_ascend(lib, test_cases):
     lib.infiniopDestroyRoPEDescriptor.argtypes = [
         infiniopRoPEDescriptor_t,
     ]
-    if args.cpu:
-        test_cpu(lib, test_cases)
-    if args.cuda:
-        test_cuda(lib, test_cases)
-    if args.bang:
-        test_bang(lib, test_cases)
-    if args.ascend:
-        test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
-        test_cpu(lib, test_cases)
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    # Execute tests
+    for device in get_test_devices(args):
+        test_operator(lib, device, test, test_cases, test_dtypes)
     print("\033[92mTest passed!\033[0m")

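After the refactor, a typical invocation might look like the lines below, assuming get_args still accepts the per-device flags used by the old dispatch code and now also exposes the debug/profile options (flag spellings are inferred from args.debug, args.profile, args.num_prerun, and args.num_iterations and may differ):

python test/infiniop/rotary_embedding.py --cpu --debug
python test/infiniop/rotary_embedding.py --cuda --profile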