@@ -41,6 +41,7 @@ infiniopStatus_t tecoGetRandomSampleWorkspaceSize(RandomSampleTecoDescriptor_t d
4141}
4242
4343infiniopStatus_t tecoDestroyRandomSampleDescriptor(RandomSampleTecoDescriptor_t desc) {
44+ //sdaaStreamDestroy(desc->stream);
4445 delete desc;
4546 return STATUS_SUCCESS;
4647}
@@ -213,35 +214,58 @@ infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
213214
214215
215216 tecodnnMemset(desc->handle, value, 0, topk_);
216- topkKernel<half>(probs, (void *)index, (void *)value, topk, desc->voc, desc->handle, desc->stream);
217- sdaaEvent_t event;
218- sdaaEventCreate(&event);
217+
218+ int voc = desc->voc;
219+ tecodnnSetStream(desc->handle, desc->stream);
220+ tecodnnStatus_t status;
221+
222+ tecodnnTensorDescriptor_t input_desc_teco, value_desc_teco, index_desc_teco;
223+ tecodnnCreateTensorDescriptor(&input_desc_teco);
224+ tecodnnCreateTensorDescriptor(&value_desc_teco);
225+ tecodnnCreateTensorDescriptor(&index_desc_teco);
226+
227+ int32_t probsDims[2] = {1, voc}, probsStrides[2] = {voc, 1};
228+ int32_t resultDims[2] = {1, topk}, resultStrides[2] = {topk, 1};
229+ tecodnnSetTensorNdDescriptor(input_desc_teco, TECODNN_DATA_HALF, 2, probsDims, probsStrides);
230+ tecodnnSetTensorNdDescriptor(value_desc_teco, TECODNN_DATA_HALF, 2, resultDims, resultStrides);
231+
232+ tecodnnSetTensorNdDescriptor(index_desc_teco, TECODNN_DATA_INT64, 2, resultDims, resultStrides);
233+
234+ size_t workSpaceSizeInBytes;
235+ int axis = 1;
236+ bool largest = true;
237+ bool sorted = true;
238+ tecodnnGetTopkExWorkspaceSize(desc->handle, axis, topk, largest, sorted, input_desc_teco, value_desc_teco,
239+ index_desc_teco, &workSpaceSizeInBytes);
240+ void *compute_workspace;
241+ sdaaMalloc((void **) &compute_workspace, workSpaceSizeInBytes);
242+
243+ status = tecodnnTopkEx(desc->handle, axis, topk, largest, sorted, input_desc_teco, probs, value_desc_teco, value,
244+ index_desc_teco, index, compute_workspace, workSpaceSizeInBytes);
245+ sdaaStreamSynchronize(desc->stream);
246+ sdaaFree(compute_workspace);
247+ tecodnnDestroyTensorDescriptor(input_desc_teco);
248+ tecodnnDestroyTensorDescriptor(value_desc_teco);
249+ tecodnnDestroyTensorDescriptor(index_desc_teco);
250+ if (status != TECODNN_STATUS_SUCCESS) {
251+ printf("topk %s\n", tecodnnGetErrorString(status));
252+ }
253+
254+ //topkKernel<half>(probs, (void *)index, (void *)value, topk, desc->voc, desc->handle, desc->stream);
255+
219256 if (topp > 0 && topk > 1){
220257 softmaxKernel<half>(probs, (void *) destination, desc->voc, desc->handle, desc->stream);
221-
222- sdaaEventRecord(event, desc->stream);
223258 memKernel<half><<<1, desc->stream>>>(destination, value, index, topk);
224- sdaaEventSynchronize(event);
225259 sdaaDeviceSynchronize();
226-
227260 cumSumKernel<half>((void *) value, (void *) scan_value, topk_, desc->handle, desc->stream);
228-
229- sdaaEventRecord(event, desc->stream);
230261 sample<half><<<1, desc->stream>>>(scan_value, index, (uint64_t *) result, random_val, topp, topk);
231-
232- sdaaEventSynchronize(event);//必须使用这些阻塞,确保kernel完成
233- sdaaDeviceSynchronize();
234-
235-
262+ sdaaDeviceSynchronize();
236263 }
237264 else {
238- sdaaEventRecord(event, desc->stream);
239265 randomSampleKernel<<<1, desc->stream>>>((uint64_t *)result, index);
240- sdaaEventSynchronize(event);
241266 sdaaDeviceSynchronize();
242-
243267 }
244- sdaaEventDestroy(event);
268+
245269
246270 return STATUS_SUCCESS;
247271 }
0 commit comments