Skip to content

Commit eae024e

Browse files
committed
random_sample debug
1 parent c5b8f43 commit eae024e

File tree

4 files changed

+358
-28
lines changed

4 files changed

+358
-28
lines changed

operatorspy/tests/random_sample.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,47 +31,32 @@ class RandomSampleDescriptor(Structure):
3131

3232
def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
3333
indices = torch.zeros([topk], dtype = torch.int64)
34-
dataNp = data.clone().detach()
35-
sorted_indices = torch.arange(voc)
36-
37-
for i in range(topk):
38-
for j in range(i + 1, voc):
39-
if(dataNp[i] < dataNp[j]):
40-
tmp = dataNp[i].clone().detach()
41-
dataNp[i] = dataNp[j].clone().detach()
42-
dataNp[j] = tmp
43-
44-
tmpInd = sorted_indices[i].clone().detach()
45-
sorted_indices[i] = sorted_indices[j].clone().detach()
46-
sorted_indices[j] = tmpInd
34+
dataNp = data.clone()
4735

48-
#sorted_indices = torch.argsort(dataNp, descending=True)
36+
sorted_indices = torch.argsort(dataNp, descending=True)
4937
indices = sorted_indices[:topk]
5038

5139
dataNp = dataNp[sorted_indices]
5240

5341
globalM = dataNp[0]
5442
dataNp = (dataNp - globalM) / temperature
5543
dataNp = torch.softmax(dataNp.float(), dim = 0)
56-
sum_s = 0
44+
45+
for i in range(1, topk):
46+
dataNp[i] = dataNp[i] + dataNp[i - 1]
47+
5748
for end in range(topk):
58-
sum_s += dataNp[end]
59-
if(sum_s >= topp):
49+
if(dataNp[end] >= topp):
6050
break
6151
if(end < topk - 1):
6252
end += 1
6353
else:
6454
end = topk
6555

66-
sum_s = 0
67-
for i in range(end):
68-
sum_s += dataNp[i]
69-
random_val *= sum_s
56+
random_val *= dataNp[end - 1]
7057

71-
sum_s = 0
7258
for i in range(end):
73-
sum_s += dataNp[i]
74-
if(random_val < sum_s):
59+
if(random_val < dataNp[i]):
7560
return indices[i]
7661

7762
def random_sample_0(data):
@@ -129,7 +114,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
129114
)
130115
if torch_device == "npu":
131116
torch.npu.synchronize()
132-
117+
133118
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
134119
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
135120
print("Test passed!")
@@ -168,7 +153,13 @@ def test_ascend(lib, test_cases):
168153
test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
169154
destroy_handle(lib, handle)
170155

171-
156+
def test_teco(lib, test_cases):
157+
import torch_sdaa
158+
device = DeviceEnum.DEVICE_TECO
159+
handle = create_handle(lib, device)
160+
for (voc, random_val, topp, topk, temperature) in test_cases:
161+
test(lib, handle, "sdaa", voc, random_val, topp, topk, temperature)
162+
destroy_handle(lib, handle)
172163

173164
if __name__ == "__main__":
174165
test_cases = [
@@ -224,6 +215,9 @@ def test_ascend(lib, test_cases):
224215
test_bang(lib, test_cases)
225216
if args.ascend:
226217
test_ascend(lib, test_cases)
227-
if not (args.cpu or args.cuda or args.bang or args.ascend):
218+
if args.teco:
219+
test_teco(lib, test_cases)
220+
221+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.teco):
228222
test_cpu(lib, test_cases)
229223
print("\033[92mTest passed!\033[0m")

src/ops/random_sample/operator.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#ifdef ENABLE_ASCEND_NPU
1515
#include "ascend/random_sample.h"
1616
#endif
17+
#ifdef ENABLE_TECO_SDAA
18+
#include "teco/random_sample_teco.h"
19+
#endif
1720

1821
__C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handle, infiniopRandomSampleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, infiniopTensorDescriptor_t probs) {
1922
switch (handle->device) {
@@ -35,8 +38,14 @@ __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handl
3538
#ifdef ENABLE_ASCEND_NPU
3639
case DevAscendNpu: {
3740
return ascendCreateRandomSampleDescriptor((AscendHandle_t) handle,
38-
(RandomSampleAscendDescriptor_t *) desc_ptr, result, probs);
41+
(RandomSampleAscendDescriptor_t *) desc_ptr, result, probs);
3942
}
43+
#endif
44+
#ifdef ENABLE_TECO_SDAA
45+
case DevTecoSDAA:
46+
return tecoCreateRandomSampleDescriptor((TecoHandle_t) handle,
47+
(RandomSampleTecoDescriptor_t *) desc_ptr, result, probs);
48+
;
4049
#endif
4150
}
4251
return STATUS_BAD_DEVICE;
@@ -64,6 +73,10 @@ __C infiniopStatus_t infiniopGetRandomSampleWorkspaceSize(infiniopRandomSampleDe
6473
case DevAscendNpu: {
6574
return ascendGetRandomSampleWorkspaceSize((RandomSampleAscendDescriptor_t) desc, size);
6675
}
76+
#endif
77+
#ifdef ENABLE_TECO_SDAA
78+
case DevTecoSDAA:
79+
return tecoGetRandomSampleWorkspaceSize((RandomSampleTecoDescriptor_t) desc, size);
6780
#endif
6881
}
6982
return STATUS_BAD_DEVICE;
@@ -97,6 +110,10 @@ __C infiniopStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc,
97110
case DevAscendNpu: {
98111
return ascendRandomSample((RandomSampleAscendDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
99112
}
113+
#endif
114+
#ifdef ENABLE_TECO_SDAA
115+
case DevTecoSDAA:
116+
return tecoRandomSample((RandomSampleTecoDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
100117
#endif
101118
}
102119
return STATUS_BAD_DEVICE;
@@ -121,6 +138,10 @@ __C infiniopStatus_t infiniopDestroyRandomSampleDescriptor(infiniopRandomSampleD
121138
case DevAscendNpu: {
122139
return ascendDestroyRandomSampleDescriptor((RandomSampleAscendDescriptor_t) desc);
123140
}
141+
#endif
142+
#ifdef ENABLE_TECO_SDAA
143+
case DevTecoSDAA:
144+
return tecoDestroyRandomSampleDescriptor((RandomSampleTecoDescriptor_t) desc);
124145
#endif
125146
}
126147
return STATUS_BAD_DEVICE;
src/ops/random_sample/teco/random_sample_teco.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#ifndef __SDAA_RANDOM_SAMPLE_H__
2+
#define __SDAA_RANDOM_SAMPLE_H__
3+
4+
#include "../../../devices/teco/teco_handle.h"
5+
#include "../../utils.h"
6+
#include "operators.h"
7+
#include <sdaa_runtime.h>
8+
9+
struct RandomSampleTecoDescriptor {
10+
Device device;
11+
int device_id;
12+
tecodnnHandle_t handle;
13+
sdaaStream_t stream;
14+
DT dtype;
15+
int voc;
16+
DT rDtype;
17+
int rLength;
18+
};
19+
20+
typedef struct RandomSampleTecoDescriptor *RandomSampleTecoDescriptor_t;
21+
22+
infiniopStatus_t tecoCreateRandomSampleDescriptor(TecoHandle_t handle,
23+
RandomSampleTecoDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result,
24+
infiniopTensorDescriptor_t probs);
25+
26+
infiniopStatus_t tecoGetRandomSampleWorkspaceSize(RandomSampleTecoDescriptor_t desc, uint64_t *size);
27+
28+
infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
29+
void *workspace,
30+
uint64_t workspace_size,
31+
void *result,
32+
void const *probs,
33+
float random_val,
34+
float topp,
35+
int topk,
36+
float temperature,
37+
void *stream);
38+
39+
infiniopStatus_t tecoDestroyRandomSampleDescriptor(RandomSampleTecoDescriptor_t desc);
40+
41+
42+
#endif

0 commit comments

Comments
 (0)