@@ -79,7 +79,10 @@ void topkKernel(const void *probs, void *index, void *value, int topk, int voc,
7979 status = tecodnnTopkEx(handle, axis, topk, largest, sorted, input_desc_teco, probs, value_desc_teco, value,
8080 index_desc_teco, index, compute_workspace, workSpaceSizeInBytes);
8181 sdaaStreamSynchronize(stream);
82-
82+ sdaaFree(compute_workspace);
83+ tecodnnDestroyTensorDescriptor(input_desc_teco);
84+ tecodnnDestroyTensorDescriptor(value_desc_teco);
85+ tecodnnDestroyTensorDescriptor(index_desc_teco);
8386 if (status != TECODNN_STATUS_SUCCESS) {
8487 printf("topk %s\n", tecodnnGetErrorString(status));
8588 }
@@ -106,7 +109,8 @@ void softmaxKernel(const void *probs, void *destination, int voc, tecodnnHandle_
106109
107110 status = tecodnnSoftmaxForward(handle, algo, mode, &alpha, x_desc_teco, probs, &beta, y_desc_teco, destination);
108111 sdaaStreamSynchronize(stream);
109-
112+ tecodnnDestroyTensorDescriptor(x_desc_teco);
113+ tecodnnDestroyTensorDescriptor(y_desc_teco);
110114 if (status != TECODNN_STATUS_SUCCESS) {
111115 printf("softmax %s\n", tecodnnGetErrorString(status));
112116 }
@@ -145,7 +149,8 @@ void cumSumKernel(void *value, void *scan_value, int topk_, tecodnnHandle_t hand
145149
146150 status = tecodnnCumSum(handle, 3, a_desc_teco, value, c_desc_teco, scan_value);
147151 sdaaStreamSynchronize(stream);
148-
152+ tecodnnDestroyTensorDescriptor(a_desc_teco);
153+ tecodnnDestroyTensorDescriptor(c_desc_teco);
149154 if (status != TECODNN_STATUS_SUCCESS) {
150155 printf("scan %s\n", tecodnnGetErrorString(status));
151156 }
@@ -176,9 +181,14 @@ __global__ void sample(T *scan_value, int64_t *index, uint64_t *result, float ra
176181 }
177182 }
178183 }
179-
184+
180185}
181186
187+ __global__ void randomSampleKernel(uint64_t *result, int64_t *index){ // Device kernel: copies the first element of `index` into the first element of `result`.
188+ if(threadIdx == 0){ // Only the thread whose threadIdx equals 0 performs the write. NOTE(review): threadIdx is compared as a scalar here (CUDA would use threadIdx.x) - confirm this matches the SDAA built-in.
189+ result[0] = index[0]; // presumably index[0] is the best-ranked entry written earlier by the top-k step - TODO confirm against the caller
190+ }
191+ }
182192infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
183193 void *workspace,
184194 uint64_t workspace_size,
@@ -204,26 +214,34 @@ infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
204214
205215 tecodnnMemset(desc->handle, value, 0, topk_);
206216 topkKernel<half>(probs, (void *)index, (void *)value, topk, desc->voc, desc->handle, desc->stream);
207- softmaxKernel<half>(probs, (void *) destination, desc->voc, desc->handle, desc->stream);
208-
209217 sdaaEvent_t event;
210218 sdaaEventCreate(&event);
219+ if (topp > 0 && topk > 1){
220+ softmaxKernel<half>(probs, (void *) destination, desc->voc, desc->handle, desc->stream);
211221
212- sdaaEventRecord(event, desc->stream);
213- memKernel<half><<<1, desc->stream>>>(destination, value, index, topk);
214- sdaaEventSynchronize(event);
215- sdaaDeviceSynchronize();
222+ sdaaEventRecord(event, desc->stream);
223+ memKernel<half><<<1, desc->stream>>>(destination, value, index, topk);
224+ sdaaEventSynchronize(event);
225+ sdaaDeviceSynchronize();
216226
217- cumSumKernel<half>((void *) value, (void *) scan_value, topk_, desc->handle, desc->stream);
227+ cumSumKernel<half>((void *) value, (void *) scan_value, topk_, desc->handle, desc->stream);
218228
219- sdaaEventRecord(event, desc->stream);
220- sample<half><<<1, desc->stream>>>(scan_value, index, (uint64_t *) result, random_val, topp, topk);
221-
222- sdaaEventSynchronize(event);//必须使用这些阻塞,确保kernel完成
223- sdaaDeviceSynchronize();
229+ sdaaEventRecord(event, desc->stream);
230+ sample<half><<<1, desc->stream>>>(scan_value, index, (uint64_t *) result, random_val, topp, topk);
231+
232+ sdaaEventSynchronize(event);//必须使用这些阻塞,确保kernel完成
233+ sdaaDeviceSynchronize();
234+
235+
236+ }
237+ else {
238+ sdaaEventRecord(event, desc->stream);
239+ randomSampleKernel<<<1, desc->stream>>>((uint64_t *)result, index);
240+ sdaaEventSynchronize(event);
241+ sdaaDeviceSynchronize();
224242
243+ }
225244 sdaaEventDestroy(event);
226-
227245
228246 return STATUS_SUCCESS;
229247 }
0 commit comments