Skip to content

Commit 0ca0004

Browse files
committed
success random sample
1 parent 05eb473 commit 0ca0004

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

operatorspy/tests/random_sample.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
114114
)
115115
if torch_device == "npu":
116116
torch.npu.synchronize()
117-
if torch_device == "sdaa":
118-
torch.sdaa.synchronize()
119-
#print(indices[0], data[indices[0]], ans, data[ans])
117+
120118
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
121119
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
122120
print("Test passed!")

src/ops/random_sample/teco/random_sample_teco.scpp

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ infiniopStatus_t tecoGetRandomSampleWorkspaceSize(RandomSampleTecoDescriptor_t d
4141
}
4242

4343
// Destroys a Teco random-sample descriptor: releases `desc` with `delete`
// and always returns STATUS_SUCCESS.
// The descriptor's stream is NOT destroyed here; the sdaaStreamDestroy call
// below is commented out.
// NOTE(review): presumably desc->stream is owned/released elsewhere — confirm,
// otherwise the stream is leaked when the descriptor is destroyed.
infiniopStatus_t tecoDestroyRandomSampleDescriptor(RandomSampleTecoDescriptor_t desc) {
44+
//sdaaStreamDestroy(desc->stream);
4445
delete desc; // frees only the descriptor object itself
4546
return STATUS_SUCCESS;
4647
}
@@ -213,35 +214,58 @@ infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
213214

214215

215216
tecodnnMemset(desc->handle, value, 0, topk_);
216-
topkKernel<half>(probs, (void *)index, (void *)value, topk, desc->voc, desc->handle, desc->stream);
217-
sdaaEvent_t event;
218-
sdaaEventCreate(&event);
217+
218+
int voc = desc->voc;
219+
tecodnnSetStream(desc->handle, desc->stream);
220+
tecodnnStatus_t status;
221+
222+
tecodnnTensorDescriptor_t input_desc_teco, value_desc_teco, index_desc_teco;
223+
tecodnnCreateTensorDescriptor(&input_desc_teco);
224+
tecodnnCreateTensorDescriptor(&value_desc_teco);
225+
tecodnnCreateTensorDescriptor(&index_desc_teco);
226+
227+
int32_t probsDims[2] = {1, voc}, probsStrides[2] = {voc, 1};
228+
int32_t resultDims[2] = {1, topk}, resultStrides[2] = {topk, 1};
229+
tecodnnSetTensorNdDescriptor(input_desc_teco, TECODNN_DATA_HALF, 2, probsDims, probsStrides);
230+
tecodnnSetTensorNdDescriptor(value_desc_teco, TECODNN_DATA_HALF, 2, resultDims, resultStrides);
231+
232+
tecodnnSetTensorNdDescriptor(index_desc_teco, TECODNN_DATA_INT64, 2, resultDims, resultStrides);
233+
234+
size_t workSpaceSizeInBytes;
235+
int axis = 1;
236+
bool largest = true;
237+
bool sorted = true;
238+
tecodnnGetTopkExWorkspaceSize(desc->handle, axis, topk, largest, sorted, input_desc_teco, value_desc_teco,
239+
index_desc_teco, &workSpaceSizeInBytes);
240+
void *compute_workspace;
241+
sdaaMalloc((void **) &compute_workspace, workSpaceSizeInBytes);
242+
243+
status = tecodnnTopkEx(desc->handle, axis, topk, largest, sorted, input_desc_teco, probs, value_desc_teco, value,
244+
index_desc_teco, index, compute_workspace, workSpaceSizeInBytes);
245+
sdaaStreamSynchronize(desc->stream);
246+
sdaaFree(compute_workspace);
247+
tecodnnDestroyTensorDescriptor(input_desc_teco);
248+
tecodnnDestroyTensorDescriptor(value_desc_teco);
249+
tecodnnDestroyTensorDescriptor(index_desc_teco);
250+
if (status != TECODNN_STATUS_SUCCESS) {
251+
printf("topk %s\n", tecodnnGetErrorString(status));
252+
}
253+
254+
//topkKernel<half>(probs, (void *)index, (void *)value, topk, desc->voc, desc->handle, desc->stream);
255+
219256
if (topp > 0 && topk > 1){
220257
softmaxKernel<half>(probs, (void *) destination, desc->voc, desc->handle, desc->stream);
221-
222-
sdaaEventRecord(event, desc->stream);
223258
memKernel<half><<<1, desc->stream>>>(destination, value, index, topk);
224-
sdaaEventSynchronize(event);
225259
sdaaDeviceSynchronize();
226-
227260
cumSumKernel<half>((void *) value, (void *) scan_value, topk_, desc->handle, desc->stream);
228-
229-
sdaaEventRecord(event, desc->stream);
230261
sample<half><<<1, desc->stream>>>(scan_value, index, (uint64_t *) result, random_val, topp, topk);
231-
232-
sdaaEventSynchronize(event);//必须使用这些阻塞,确保kernel完成
233-
sdaaDeviceSynchronize();
234-
235-
262+
sdaaDeviceSynchronize();
236263
}
237264
else {
238-
sdaaEventRecord(event, desc->stream);
239265
randomSampleKernel<<<1, desc->stream>>>((uint64_t *)result, index);
240-
sdaaEventSynchronize(event);
241266
sdaaDeviceSynchronize();
242-
243267
}
244-
sdaaEventDestroy(event);
268+
245269

246270
return STATUS_SUCCESS;
247271
}

0 commit comments

Comments
 (0)