@@ -45,7 +45,7 @@ def rotary_embedding(t, pos, theta, torch_device):
     )
     freqs = torch.outer(pos, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-
+
     t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
     freqs_cis = reshape_for_broadcast(freqs_cis, t_)
     t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
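The hunk above is the reference computation: consecutive element pairs are viewed as complex numbers and rotated by the unit phasors torch.polar builds from position-dependent angles. A self-contained sketch of the same idea (the function name and frequency layout here are illustrative, not the repo's helpers):

import torch

def rope_reference(t, pos, theta):
    # t: (seq_len, n_heads, head_dim), head_dim must be even
    head_dim = t.shape[-1]
    # per-pair frequencies: theta^(-2k/d) for k = 0 .. d/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(pos.float(), freqs)                  # (seq_len, d/2)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)  # unit phasors e^{i*angle}
    # pair consecutive elements as complex numbers, rotate, unpack
    t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
    return torch.view_as_real(t_ * freqs_cis[:, None, :]).flatten(2).to(t.dtype)

t = torch.rand(4, 1, 32, dtype=torch.float16)
assert rope_reference(t, torch.arange(4), 1e4).shape == t.shape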
@@ -69,19 +69,31 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
     print(
         f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
     )
-    t = torch.rand(shape, dtype=dtype, device=torch.device(torch_device))
+
+    t = torch.rand(shape, dtype=dtype)
     if strides is not None:
         t = rearrange_tensor(t, strides)
-    pos = torch.arange(0, t.shape[0], device=torch.device(torch_device))
+    pos = torch.arange(0, t.shape[0])
     theta = 1e4
-    ans = rotary_embedding(t, pos, theta, torch_device)
-    pos = pos.to(torch.int64)  # use int64 to support older versions of PyTorch
+
+    if torch_device == "mlu":  # MLU lacks complex ops: build the reference on CPU
+        ans = rotary_embedding(t, pos, theta, "cpu").to(torch_device)
+        pos = pos.to(torch.int64)
+        pos = pos.to(torch_device)
+        t = t.to(torch_device)
+    else:
+        t = t.to(torch_device)
+        pos = pos.to(torch_device)
+        ans = rotary_embedding(t, pos, theta, torch_device)
+        pos = pos.to(torch.uint64)
+
     descriptor = infiniopRoPEDescriptor_t()
     # 2x table length for test
     sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
     t_tensor = to_tensor(t, lib)
     pos_tensor = to_tensor(pos, lib)
-    pos_tensor.descriptor.contents.dt = U64  # treat int64 as uint64
+    if torch_device == "mlu":
+        pos_tensor.descriptor.contents.dt = U64  # treat int64 as uint64
     sin_table_tensor = to_tensor(sin_table, lib)
     cos_table_tensor = to_tensor(cos_table, lib)
     check_error(
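The pos dtype shuffling in this hunk exists because the library descriptor expects U64 positions while torch.uint64 is only available on newer PyTorch builds; for non-negative positions the int64 bit pattern is identical, so relabeling the descriptor dtype on the MLU path is lossless. A quick check of that assumption:

import torch

pos = torch.arange(0, 4, dtype=torch.int64)
# Non-negative int64 values reinterpret bit-for-bit as uint64, so a
# U64-only backend can consume the int64 buffer unchanged.
print(pos.numpy().view("uint64"))  # [0 1 2 3]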
@@ -111,7 +123,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
             None,
         )
     )
-
+
     assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
     check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
     print("Test passed!")
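The body of sin_cos_table is not part of this diff; a plausible sketch, assuming it precomputes one sin/cos entry per rotated pair at the same frequencies as the reference (the signature and table layout are guesses):

import torch

def sin_cos_table_sketch(max_pos, dim, device, theta):
    # angle[p, k] = p * theta^(-2k/d), one column per rotated pair
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    angles = torch.outer(torch.arange(max_pos, device=device).float(), freqs)
    return torch.sin(angles), torch.cos(angles)  # each (max_pos, dim // 2)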
@@ -135,32 +147,18 @@ def test_cuda(lib, test_cases):
 
 def test_bang(lib, test_cases):
     import torch_mlu
-
     device = DeviceEnum.DEVICE_BANG
-    config = None
-    descriptor = lib.createRotaryEmbeddingDescriptor(device, config)
-
-    # Note: BANG does not support complex calculation, compare with cpu results
-    t = torch.rand((1, 32, 128), dtype=torch.float16)
-    pos = torch.ones((1,), dtype=torch.int32)
-    theta = 1e4
-    ans = rotary_embedding(t, pos, theta, "cpu")
-
-    t = t.to("mlu")
-    pos = pos.to("mlu")
-    lib.rotaryEmbedding(
-        descriptor, to_tensor(t, lib), to_tensor(pos, lib), c_float(theta), None
-    )
-    assert torch.allclose(t.cpu(), ans, atol=1e-3, rtol=1e-3)
-    print("Test passed!")
-
-    lib.destroyRotaryEmbeddingDescriptor(descriptor)
+    handle = create_handle(lib, device)
+    for shape, strides, dtype in test_cases:
+        test(lib, handle, "mlu", shape, strides, dtype)
+    destroy_handle(lib, handle)
 
 
 if __name__ == "__main__":
     test_cases = [
-        ((1, 32, 128), None, torch.float16),
         ((4, 1, 32), None, torch.float16),
+        ((1, 32, 128), None, torch.float16),
+
         ((3, 32, 128), (8000, 200, 1), torch.float16),
     ]
     args = get_args()
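The strided test case depends on rearrange_tensor, whose body is also outside this diff. One plausible implementation, assuming it copies the data into a buffer viewed with the requested strides (helper name and approach are guesses):

import torch

def rearrange_tensor_sketch(t, strides):
    # flat buffer just large enough to hold the strided layout
    numel = sum((s - 1) * st for s, st in zip(t.shape, strides)) + 1
    buf = torch.zeros(numel, dtype=t.dtype, device=t.device)
    view = buf.as_strided(t.shape, strides)  # non-contiguous view
    view.copy_(t)
    return view

t = rearrange_tensor_sketch(torch.rand(3, 32, 128, dtype=torch.float16), (8000, 200, 1))
assert t.stride() == (8000, 200, 1)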