 #include "random_sample.cuh"
 #include <cub/block/block_reduce.cuh>
 #include <cub/cub.cuh>
-
+template <class T, int BLOCK_DIM>
+__global__ void argmaxKernel(T *val_out, int voc, uint64_t *result) {
+    float localM = -__FLT_MAX__;
+    uint64_t index = threadIdx.x;
+    // each thread scans a BLOCK_DIM-strided slice of the vocabulary
+    for (int i = threadIdx.x; i < voc; i += BLOCK_DIM) {
+        if (localM < static_cast<float>(val_out[i])) {
+            localM = static_cast<float>(val_out[i]);
+            index = i;
+        }
+    }
+    __shared__ uint64_t globalInd[BLOCK_DIM];
+    __shared__ float globalM[BLOCK_DIM];
+    globalInd[threadIdx.x] = index;
+    globalM[threadIdx.x] = localM;
+    __syncthreads();// publish every thread's candidate before the tree reduction
+    for (int strip = BLOCK_DIM / 2; strip > 0; strip /= 2) {
+        if (threadIdx.x < strip) {
+            if (globalM[threadIdx.x] < globalM[threadIdx.x + strip]) {
+                globalM[threadIdx.x] = globalM[threadIdx.x + strip];
+                globalInd[threadIdx.x] = globalInd[threadIdx.x + strip];
+            }
+        }
+        __syncthreads();
+    }
+    if (threadIdx.x == 0) {// only one thread writes the final index
+        result[0] = globalInd[0];
+    }
+}
 template <class T, int BLOCK_DIM>
 __global__ void softmax(
         T *val_out,
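
The hand-rolled reduction in argmaxKernel can also be written with cub::BlockReduce, which this file already includes. A minimal sketch, not part of this commit (argmaxCub is a hypothetical name), assuming a CUB version that ships cub::ArgMax and cub::KeyValuePair:

#include <cfloat>
#include <cub/block/block_reduce.cuh>

template <class T, int BLOCK_DIM>
__global__ void argmaxCub(const T *val, int voc, uint64_t *result) {
    using KVP = cub::KeyValuePair<int, float>;
    using BlockReduce = cub::BlockReduce<KVP, BLOCK_DIM>;
    __shared__ typename BlockReduce::TempStorage temp;

    // each thread keeps the best (index, value) pair of its strided slice
    KVP local(0, -FLT_MAX);
    for (int i = threadIdx.x; i < voc; i += BLOCK_DIM) {
        float v = static_cast<float>(val[i]);
        if (v > local.value) local = KVP(i, v);
    }
    // block-wide argmax; the aggregate is only valid in thread 0
    KVP best = BlockReduce(temp).Reduce(local, cub::ArgMax());
    if (threadIdx.x == 0) result[0] = best.key;
}
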
@@ -132,25 +156,26 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
                                 void *stream) {
     int voc = desc->voc;
     // the code below sorts the probabilities
-    char *origin = reinterpret_cast<char *>(workspace);
-    char *keyTmp = origin + voc * sizeof(half);
-    half *val_out = (half *) origin;
 
-    uint64_t *key_in = (uint64_t *) keyTmp;
-    uint64_t *key_out = key_in + voc;
+    if (topp > 0 && topk > 1) {
+        char *origin = reinterpret_cast<char *>(workspace);
+        char *keyTmp = origin + voc * sizeof(half);
+        half *val_out = (half *) origin;
 
-    index<<<(voc + 1023) / 1024, 1024, 0, (cudaStream_t) stream>>>(key_in, voc);
-    // next, carve the extra workspace out of the buffer
+        uint64_t *key_in = (uint64_t *) keyTmp;
+        uint64_t *key_out = key_in + voc;
 
-    void *workspace_extra = reinterpret_cast<char *>(workspace) + desc->step;
-    uint64_t workspace_len = workspace_size - desc->step;
-    sort_pairs_descending<half, uint64_t>(
-            workspace_extra, workspace_len,
-            (half *) probs, val_out,
-            key_in, key_out,
-            voc, (cudaStream_t) stream);// stores the sorted values and their original indices in val_out and key_out
-    // sorting done; now apply the softmax transform
-    if (topp > 0 && topk > 1) {
+        index<<<(voc + 1023) / 1024, 1024, 0, (cudaStream_t) stream>>>(key_in, voc);
+        // next, carve the extra workspace out of the buffer
+
+        void *workspace_extra = reinterpret_cast<char *>(workspace) + desc->step;
+        uint64_t workspace_len = workspace_size - desc->step;
+        sort_pairs_descending<half, uint64_t>(
+                workspace_extra, workspace_len,
+                (half *) probs, val_out,
+                key_in, key_out,
+                voc, (cudaStream_t) stream);// stores the sorted values and their original indices in val_out and key_out
+        // sorting done; now apply the softmax transform
         int BLOCK_DIM = 1024;
         int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
         softmax<half, 1024><<<num_blocks, BLOCK_DIM, 0, (cudaStream_t) stream>>>(val_out, topk,
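
sort_pairs_descending is defined elsewhere in this repo; for reference, such a wrapper is typically a thin layer over cub::DeviceRadixSort::SortPairsDescending using the standard two-call pattern, where the first call only queries the temporary-storage size. A sketch under that assumption (the actual helper may differ):

#include <cassert>
#include <cub/cub.cuh>

template <class KeyT, class ValueT>
void sort_pairs_descending(void *workspace, uint64_t workspace_len,
                           const KeyT *key_in, KeyT *key_out,
                           const ValueT *val_in, ValueT *val_out,
                           int n, cudaStream_t stream) {
    size_t temp_bytes = 0;
    // first call: null temp pointer, only computes the required size
    cub::DeviceRadixSort::SortPairsDescending(
            nullptr, temp_bytes, key_in, key_out, val_in, val_out,
            n, 0, sizeof(KeyT) * 8, stream);
    assert(temp_bytes <= workspace_len);// the caller's spare workspace must cover it
    // second call: same arguments plus real storage, performs the descending sort
    cub::DeviceRadixSort::SortPairsDescending(
            workspace, temp_bytes, key_in, key_out, val_in, val_out,
            n, 0, sizeof(KeyT) * 8, stream);
}

This matches the call site above: the keys are the fp16 probabilities and the values are the uint64_t indices produced by the index kernel, so key_out ends up holding the original positions in descending-probability order.
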
@@ -169,8 +194,9 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
                                                      key_out);
 
     } else {
-        random_sample_kernel<<<1, 1, 0, (cudaStream_t) stream>>>((uint64_t *) result,
-                                                                 key_out);
+        int BLOCK_DIM = 1024;
+        // one block suffices: argmaxKernel itself strides over the whole vocabulary
+        argmaxKernel<half, 1024><<<1, BLOCK_DIM, 0, (cudaStream_t) stream>>>((half *) probs, voc, (uint64_t *) result);
     }
 }
 
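For completeness, a stand-alone harness for the new greedy path (hypothetical, not part of the commit): it plants a known maximum in an fp16 buffer, launches argmaxKernel with a single block, and reads back the winning index.

#include <cuda_fp16.h>
#include <cstdio>
#include <vector>

int main() {
    const int voc = 32000;
    std::vector<half> h_probs(voc, __float2half(0.0f));
    h_probs[1234] = __float2half(5.0f);// plant a known maximum

    half *d_probs;
    uint64_t *d_result;
    cudaMalloc(&d_probs, voc * sizeof(half));
    cudaMalloc(&d_result, sizeof(uint64_t));
    cudaMemcpy(d_probs, h_probs.data(), voc * sizeof(half), cudaMemcpyHostToDevice);

    // one block is enough: the kernel strides over the whole vocabulary
    argmaxKernel<half, 1024><<<1, 1024>>>(d_probs, voc, d_result);

    uint64_t h_result = 0;
    cudaMemcpy(&h_result, d_result, sizeof(uint64_t), cudaMemcpyDeviceToHost);
    printf("argmax index = %llu\n", (unsigned long long) h_result);// expect 1234

    cudaFree(d_probs);
    cudaFree(d_result);
    return 0;
}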