@@ -31,47 +31,32 @@ class RandomSampleDescriptor(Structure):
3131
3232def random_sample (data , random_val , topp , topk , voc , temperature , torch_device ):
3333 indices = torch .zeros ([topk ], dtype = torch .int64 )
34- dataNp = data .clone ().detach ()
35- sorted_indices = torch .arange (voc )
36-
37- for i in range (topk ):
38- for j in range (i + 1 , voc ):
39- if (dataNp [i ] < dataNp [j ]):
40- tmp = dataNp [i ].clone ().detach ()
41- dataNp [i ] = dataNp [j ].clone ().detach ()
42- dataNp [j ] = tmp
43-
44- tmpInd = sorted_indices [i ].clone ().detach ()
45- sorted_indices [i ] = sorted_indices [j ].clone ().detach ()
46- sorted_indices [j ] = tmpInd
34+ dataNp = data .clone ()
4735
48- # sorted_indices = torch.argsort(dataNp, descending=True)
36+ sorted_indices = torch .argsort (dataNp , descending = True )
4937 indices = sorted_indices [:topk ]
5038
5139 dataNp = dataNp [sorted_indices ]
5240
5341 globalM = dataNp [0 ]
5442 dataNp = (dataNp - globalM ) / temperature
5543 dataNp = torch .softmax (dataNp .float (), dim = 0 )
56- sum_s = 0
44+
45+ for i in range (1 , topk ):
46+ dataNp [i ] = dataNp [i ] + dataNp [i - 1 ]
47+
5748 for end in range (topk ):
58- sum_s += dataNp [end ]
59- if (sum_s >= topp ):
49+ if (dataNp [end ] >= topp ):
6050 break
6151 if (end < topk - 1 ):
6252 end += 1
6353 else :
6454 end = topk
6555
66- sum_s = 0
67- for i in range (end ):
68- sum_s += dataNp [i ]
69- random_val *= sum_s
56+ random_val *= dataNp [end - 1 ]
7057
71- sum_s = 0
7258 for i in range (end ):
73- sum_s += dataNp [i ]
74- if (random_val < sum_s ):
59+ if (random_val < dataNp [i ]):
7560 return indices [i ]
7661
7762def random_sample_0 (data ):
@@ -129,7 +114,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
129114 )
130115 if torch_device == "npu" :
131116 torch .npu .synchronize ()
132-
117+
133118 assert indices [0 ].type (ans .dtype ) == ans or data [ans ] == data [indices [0 ]]
134119 check_error (lib .infiniopDestroyRandomSampleDescriptor (descriptor ))
135120 print ("Test passed!" )
@@ -168,7 +153,13 @@ def test_ascend(lib, test_cases):
168153 test (lib , handle , "npu" , voc , random_val , topp , topk , temperature )
169154 destroy_handle (lib , handle )
170155
171-
156+ def test_teco (lib , test_cases ):
157+ import torch_sdaa
158+ device = DeviceEnum .DEVICE_TECO
159+ handle = create_handle (lib , device )
160+ for (voc , random_val , topp , topk , temperature ) in test_cases :
161+ test (lib , handle , "sdaa" , voc , random_val , topp , topk , temperature )
162+ destroy_handle (lib , handle )
172163
173164if __name__ == "__main__" :
174165 test_cases = [
@@ -224,6 +215,9 @@ def test_ascend(lib, test_cases):
224215 test_bang (lib , test_cases )
225216 if args .ascend :
226217 test_ascend (lib , test_cases )
227- if not (args .cpu or args .cuda or args .bang or args .ascend ):
218+ if args .teco :
219+ test_teco (lib , test_cases )
220+
221+ if not (args .cpu or args .cuda or args .bang or args .ascend or args .teco ):
228222 test_cpu (lib , test_cases )
229223 print ("\033 [92mTest passed!\033 [0m" )
0 commit comments