Skip to content

Commit 05eb473

Browse files
committed
destroy descriptor
1 parent e9ad7af commit 05eb473

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

operatorspy/tests/random_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
116116
torch.npu.synchronize()
117117
if torch_device == "sdaa":
118118
torch.sdaa.synchronize()
119-
print(indices[0], data[indices[0]], ans, data[ans])
119+
#print(indices[0], data[indices[0]], ans, data[ans])
120120
assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
121121
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
122122
print("Test passed!")

src/ops/random_sample/teco/random_sample_teco.scpp

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ void topkKernel(const void *probs, void *index, void *value, int topk, int voc,
7979
status = tecodnnTopkEx(handle, axis, topk, largest, sorted, input_desc_teco, probs, value_desc_teco, value,
8080
index_desc_teco, index, compute_workspace, workSpaceSizeInBytes);
8181
sdaaStreamSynchronize(stream);
82-
82+
sdaaFree(compute_workspace);
83+
tecodnnDestroyTensorDescriptor(input_desc_teco);
84+
tecodnnDestroyTensorDescriptor(value_desc_teco);
85+
tecodnnDestroyTensorDescriptor(index_desc_teco);
8386
if (status != TECODNN_STATUS_SUCCESS) {
8487
printf("topk %s\n", tecodnnGetErrorString(status));
8588
}
@@ -106,7 +109,8 @@ void softmaxKernel(const void *probs, void *destination, int voc, tecodnnHandle_
106109

107110
status = tecodnnSoftmaxForward(handle, algo, mode, &alpha, x_desc_teco, probs, &beta, y_desc_teco, destination);
108111
sdaaStreamSynchronize(stream);
109-
112+
tecodnnDestroyTensorDescriptor(x_desc_teco);
113+
tecodnnDestroyTensorDescriptor(y_desc_teco);
110114
if (status != TECODNN_STATUS_SUCCESS) {
111115
printf("softmax %s\n", tecodnnGetErrorString(status));
112116
}
@@ -145,7 +149,8 @@ void cumSumKernel(void *value, void *scan_value, int topk_, tecodnnHandle_t hand
145149

146150
status = tecodnnCumSum(handle, 3, a_desc_teco, value, c_desc_teco, scan_value);
147151
sdaaStreamSynchronize(stream);
148-
152+
tecodnnDestroyTensorDescriptor(a_desc_teco);
153+
tecodnnDestroyTensorDescriptor(c_desc_teco);
149154
if (status != TECODNN_STATUS_SUCCESS) {
150155
printf("scan %s\n", tecodnnGetErrorString(status));
151156
}
@@ -176,9 +181,14 @@ __global__ void sample(T *scan_value, int64_t *index, uint64_t *result, float ra
176181
}
177182
}
178183
}
179-
184+
180185
}
181186

187+
__global__ void randomSampleKernel(uint64_t *result, int64_t *index){ // Device kernel: copies the first entry of `index` into the first entry of `result`.
188+
if(threadIdx == 0){ // Only the thread whose threadIdx compares equal to 0 performs the write.
189+
result[0] = index[0]; // int64_t value stored into a uint64_t slot (implicit conversion). NOTE(review): presumably index[0] is the top-1 entry produced by topkKernel -- confirm against caller.
190+
}
191+
}
182192
infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
183193
void *workspace,
184194
uint64_t workspace_size,
@@ -204,26 +214,34 @@ infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
204214

205215
tecodnnMemset(desc->handle, value, 0, topk_);
206216
topkKernel<half>(probs, (void *)index, (void *)value, topk, desc->voc, desc->handle, desc->stream);
207-
softmaxKernel<half>(probs, (void *) destination, desc->voc, desc->handle, desc->stream);
208-
209217
sdaaEvent_t event;
210218
sdaaEventCreate(&event);
219+
if (topp > 0 && topk > 1){
220+
softmaxKernel<half>(probs, (void *) destination, desc->voc, desc->handle, desc->stream);
211221

212-
sdaaEventRecord(event, desc->stream);
213-
memKernel<half><<<1, desc->stream>>>(destination, value, index, topk);
214-
sdaaEventSynchronize(event);
215-
sdaaDeviceSynchronize();
222+
sdaaEventRecord(event, desc->stream);
223+
memKernel<half><<<1, desc->stream>>>(destination, value, index, topk);
224+
sdaaEventSynchronize(event);
225+
sdaaDeviceSynchronize();
216226

217-
cumSumKernel<half>((void *) value, (void *) scan_value, topk_, desc->handle, desc->stream);
227+
cumSumKernel<half>((void *) value, (void *) scan_value, topk_, desc->handle, desc->stream);
218228

219-
sdaaEventRecord(event, desc->stream);
220-
sample<half><<<1, desc->stream>>>(scan_value, index, (uint64_t *) result, random_val, topp, topk);
221-
222-
sdaaEventSynchronize(event);//必须使用这些阻塞,确保kernel完成
223-
sdaaDeviceSynchronize();
229+
sdaaEventRecord(event, desc->stream);
230+
sample<half><<<1, desc->stream>>>(scan_value, index, (uint64_t *) result, random_val, topp, topk);
231+
232+
sdaaEventSynchronize(event);//必须使用这些阻塞,确保kernel完成
233+
sdaaDeviceSynchronize();
234+
235+
236+
}
237+
else {
238+
sdaaEventRecord(event, desc->stream);
239+
randomSampleKernel<<<1, desc->stream>>>((uint64_t *)result, index);
240+
sdaaEventSynchronize(event);
241+
sdaaDeviceSynchronize();
224242

243+
}
225244
sdaaEventDestroy(event);
226-
227245

228246
return STATUS_SUCCESS;
229247
}

0 commit comments

Comments
 (0)