 import ctypes
-from ctypes import c_float, POINTER, c_void_p, c_int32, c_uint64, Structure, byref
+from ctypes import POINTER, c_void_p, c_int32, c_uint64, Structure, byref
 import sys
 import os
 
-
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
-from operatorspy import (
-    open_lib,
-    to_tensor,
-    DeviceEnum,
+from libinfiniop import (
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
-    create_handle,
-    destroy_handle,
+    open_lib,
+    to_tensor,
+    get_test_devices,
     check_error,
-    rearrange_tensor,
+    rearrange_if_needed,
     create_workspace,
-    U64,
+    test_operator,
+    get_args,
+    debug,
+    profile_operation,
+    InfiniDtype,
 )
-
-from operatorspy.tests.test_utils import get_args
 import torch
 
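+# Test options; these defaults are overridden by the command-line arguments parsed in __main__.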
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
 
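+# ctypes stand-in for the library's opaque RoPE descriptor; only the device field is declared.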
 class RoPEDescriptor(Structure):
     _fields_ = [("device", c_int32)]
@@ -40,15 +44,21 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
 
 def rotary_embedding(t, pos, theta, torch_device):
     dh = t.shape[2]
-    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2)[: (dh // 2)].float() / dh))).to(
-        torch_device
-    )
-    freqs = torch.outer(pos, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    assert dh % 2 == 0, "Embedding dimension must be even."
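+    # Reference RoPE: rotate each (even, odd) channel pair i by the angle pos * theta**(-2 * i / dh).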
+    t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
+    t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
+    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device)
+    freqs = torch.outer(pos, freqs)  # [seq_len, dh // 2]
+    cos = torch.cos(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
+    sin = torch.sin(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
+
+    t_out_even = t_even * cos - t_odd * sin
+    t_out_odd = t_even * sin + t_odd * cos
+
+    t_out = torch.empty_like(t)
+    t_out[..., 0::2] = t_out_even
+    t_out[..., 1::2] = t_out_odd
 
-    t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, t_)
-    t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
     return t_out
 
 
@@ -71,29 +81,23 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
     )
 
     t = torch.rand(shape, dtype=dtype)
-    if strides is not None:
-        t = rearrange_tensor(t, strides)
-    posTmp = torch.arange(0, t.shape[0])
+    t = rearrange_if_needed(t, strides).to(torch_device)
+    posTmp = torch.arange(0, t.shape[0]).to(torch_device)
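+    # Pack each position into two int32 words (value, then 0); the descriptor dtype is
+    # overridden to U64 below so the library sees the buffer as 64-bit position indices.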
     pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
     for i in range(posTmp.shape[0]):
         pos[2 * i] = posTmp[i]
         pos[2 * i + 1] = 0
+    pos = pos.to(torch_device)
     theta = 1e4
-    if torch_device == "mlu" or torch_device == "npu":
-        ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
-        pos = pos.to(torch_device)
-        t = t.to(torch_device)
-    else:
-        t = t.to(torch_device)
-        pos = pos.to(torch_device)
-        ans = rotary_embedding(t, posTmp.to(torch_device), theta, torch_device)
+
+    ans = rotary_embedding(t, posTmp, theta, torch_device)
 
     descriptor = infiniopRoPEDescriptor_t()
     # 2x table length for test
     sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
     t_tensor = to_tensor(t, lib)
     pos_tensor = to_tensor(pos[: t.shape[0]], lib)
-    pos_tensor.descriptor.contents.dt = U64
+    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
     sin_table_tensor = to_tensor(sin_table, lib)
     cos_table_tensor = to_tensor(cos_table, lib)
 
@@ -122,69 +126,52 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
         lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
     )
     workspace = create_workspace(workspace_size.value, t.device)
-    check_error(
-        lib.infiniopRoPE(
-            descriptor,
-            workspace.data_ptr() if workspace is not None else None,
-            workspace_size.value,
-            t_tensor.data,
-            pos_tensor.data,
-            sin_table_tensor.data,
-            cos_table_tensor.data,
-            None,
+
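+    # Wrap the library call so the correctness check and the profiling loop can share it.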
+    def lib_rope():
+        check_error(
+            lib.infiniopRoPE(
+                descriptor,
+                workspace.data_ptr() if workspace is not None else None,
+                workspace_size.value,
+                t_tensor.data,
+                pos_tensor.data,
+                sin_table_tensor.data,
+                cos_table_tensor.data,
+                None,
+            )
         )
-    )
 
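+    # infiniopRoPE writes its result in place into t, so compare t against the PyTorch reference.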
+    lib_rope()
+    if DEBUG:
+        debug(t, ans, atol=1e-4, rtol=1e-2)
     assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
-    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
-
-
-def test_cpu(lib, test_cases):
-    device = DeviceEnum.DEVICE_CPU
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "cpu", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_cuda(lib, test_cases):
-    device = DeviceEnum.DEVICE_CUDA
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "cuda", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_bang(lib, test_cases):
-    import torch_mlu
-
-    device = DeviceEnum.DEVICE_BANG
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "mlu", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_ascend(lib, test_cases):
-    import torch_npu
+    if PROFILE:
+        profile_operation(
+            "PyTorch",
+            lambda: rotary_embedding(t, posTmp, theta, torch_device),
+            torch_device,
+            NUM_PRERUN,
+            NUM_ITERATIONS,
+        )
+        profile_operation(
+            "    lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
+        )
 
-    device = DeviceEnum.DEVICE_ASCEND
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "npu", shape, strides, dtype)
-    destroy_handle(lib, handle)
+    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
 
 
 if __name__ == "__main__":
     test_cases = [
-        ((1, 32, 128), None, torch.float16),
-        ((1, 32, 64), None, torch.float16),
+        # (t_shape, t_strides)
+        ((1, 32, 128), None),
+        ((1, 32, 64), None),
         # Ascend cannot handle this case yet: a last dimension <= 32 fails, possibly due to the
         # internal implementation of its core GatherMask interface; 48, 64, and 128 all work.
-        ((4, 1, 32), None, torch.float16),
-        ((1, 32, 128), None, torch.float16),
-        ((3, 32, 128), (8000, 200, 1), torch.float16),
+        ((4, 1, 32), None),
+        ((1, 32, 128), None),
+        ((3, 32, 128), (8000, 200, 1)),
     ]
+    test_dtypes = [torch.float16]
     args = get_args()
     lib = open_lib()
     lib.infiniopCreateRoPEDescriptor.restype = c_int32
@@ -216,14 +203,13 @@ def test_ascend(lib, test_cases):
     lib.infiniopDestroyRoPEDescriptor.argtypes = [
         infiniopRoPEDescriptor_t,
     ]
-    if args.cpu:
-        test_cpu(lib, test_cases)
-    if args.cuda:
-        test_cuda(lib, test_cases)
-    if args.bang:
-        test_bang(lib, test_cases)
-    if args.ascend:
-        test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
-        test_cpu(lib, test_cases)
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    # Execute tests
+    for device in get_test_devices(args):
+        test_operator(lib, device, test, test_cases, test_dtypes)
     print("\033[92mTest passed!\033[0m")