diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py index ea680c57..d09f293f 100644 --- a/operatorspy/tests/random_sample.py +++ b/operatorspy/tests/random_sample.py @@ -30,7 +30,10 @@ class RandomSampleDescriptor(Structure): def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): - indices = torch.zeros([topk], dtype = torch.int64) + if(torch_device == "cuda"): + indices = torch.zeros([topk], dtype = torch.uint64) + else: + indices = torch.zeros([topk], dtype = torch.int64) dataNp = data.clone().detach() sorted_indices = torch.arange(voc) @@ -52,7 +55,7 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): globalM = dataNp[0] dataNp = (dataNp - globalM) / temperature - dataNp = torch.softmax(dataNp.float(), dim = 0) + dataNp = torch.softmax(dataNp, dim = 0) sum_s = 0 for end in range(topk): sum_s += dataNp[end] @@ -88,7 +91,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu") else: ans = random_sample_0(data) - if(torch_device == 'mlu' or torch_device == 'npu'): + if(torch_device != "cuda"): indices = torch.zeros([1], dtype = torch.int64).to(torch_device) else: @@ -96,7 +99,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ indices = torch.zeros([1], dtype = torch.uint64).to(torch_device) x_tensor = to_tensor(data, lib) indices_tensor = to_tensor(indices, lib) - if(torch_device == 'mlu' or torch_device == 'npu'): + if(torch_device != 'cuda'): indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64 diff --git a/src/ops/random_sample/cpu/random_sample.cc b/src/ops/random_sample/cpu/random_sample.cc index 28de5b93..4a58de8f 100644 --- a/src/ops/random_sample/cpu/random_sample.cc +++ b/src/ops/random_sample/cpu/random_sample.cc @@ -135,16 +135,11 @@ void random_sample_cpu_f16(RandomSampleCpuDescriptor_t desc, auto index_ = reinterpret_cast(result); auto source = reinterpret_cast(probs); - char *origin = reinterpret_cast(workspace); - uint16_t *logits_ = (uint16_t *) origin; - - std::copy(source, source + voc, logits_); - - float M = f16_to_f32(logits_[0]); + float M = f16_to_f32(source[0]); int index = 0; for (int j = 1; j < voc; j++) { - if (M < f16_to_f32(logits_[j])) { - M = f16_to_f32(logits_[j]); + if (M < f16_to_f32(source[j])) { + M = f16_to_f32(source[j]); index = j; } } diff --git a/src/ops/random_sample/cuda/random_sample.cu b/src/ops/random_sample/cuda/random_sample.cu index 40761e89..b0e8d2e2 100644 --- a/src/ops/random_sample/cuda/random_sample.cu +++ b/src/ops/random_sample/cuda/random_sample.cu @@ -3,14 +3,38 @@ #include "random_sample.cuh" #include #include - +template +__global__ void argmaxKernel(T *val_out, int voc, uint64_t *result) { + float localM = -__FLT_MAX__; + uint64_t index = threadIdx.x; + for (int i = threadIdx.x; i < voc; i += BLOCK_DIM) { + if (localM < static_cast(val_out[i])) { + localM = static_cast(val_out[i]); + index = i; + } + } + __shared__ uint64_t globalInd[BLOCK_DIM]; + __shared__ float globalM[BLOCK_DIM]; + globalInd[threadIdx.x] = index; + globalM[threadIdx.x] = localM; + for (int strip = BLOCK_DIM / 2; strip > 0; strip /= 2) { + if (threadIdx.x < strip) { + if (globalM[threadIdx.x] < globalM[threadIdx.x + strip]) { + globalM[threadIdx.x] = globalM[threadIdx.x + strip]; + globalInd[threadIdx.x] = globalInd[threadIdx.x + strip]; + } + } + __syncthreads(); + } + result[0] = globalInd[0]; +} template __global__ void softmax( T *val_out, int topk, float temperature, int voc) { float sum_s = 0.0f; - for (int i = threadIdx.x; i < topk; i += BLOCK_DIM) { + for (int i = threadIdx.x; i < voc; i += BLOCK_DIM) { sum_s += __expf(static_cast(val_out[i] - val_out[0]) / temperature); } __shared__ float sum_inverse_total; @@ -84,26 +108,51 @@ void inclusive_sum( data, data, voc, stream); } -template -void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan, - int voc, cudaStream_t stream) { +infiniopStatus_t random_sample_workspace(size_t &size_radix_sort, size_t &size_scan, + int voc, DT dtype) { + if (dtype_eq(dtype, F16)) { + sort_pairs_descending(nullptr, size_radix_sort, + nullptr, nullptr, + nullptr, nullptr, + voc, nullptr); - sort_pairs_descending(nullptr, size_radix_sort, - nullptr, nullptr, - nullptr, nullptr, - voc, stream); + inclusive_sum( + nullptr, size_scan, + nullptr, voc, + nullptr); + return STATUS_SUCCESS; + } else if (dtype_eq(dtype, F32)) { + sort_pairs_descending(nullptr, size_radix_sort, + nullptr, nullptr, + nullptr, nullptr, + voc, nullptr); - inclusive_sum( - nullptr, size_scan, - nullptr, voc, - stream); + inclusive_sum( + nullptr, size_scan, + nullptr, voc, + nullptr); + return STATUS_SUCCESS; + } else if (dtype_eq(dtype, F64)) { + sort_pairs_descending(nullptr, size_radix_sort, + nullptr, nullptr, + nullptr, nullptr, + voc, nullptr); + + inclusive_sum( + nullptr, size_scan, + nullptr, voc, + nullptr); + return STATUS_SUCCESS; + } else { + return STATUS_BAD_TENSOR_DTYPE; + } } __global__ void random_sample_kernel(uint64_t *result, uint64_t *key_out) { result[0] = key_out[0]; } -void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace, void *result, +void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace, uint64_t workspace_size, void *result, void const *probs, float random_val, float topp, @@ -112,28 +161,26 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace void *stream) { int voc = desc->voc; //下面这段代码在排序 - char *origin = reinterpret_cast(workspace); - char *keyTmp = origin + voc * sizeof(half); - half *val_out = (half *) origin; - - uint64_t *key_in = (uint64_t *) keyTmp; - uint64_t *key_out = key_in + voc; - - index<<<(voc + 1023) / 1024, 1024, 0, (cudaStream_t) stream>>>(key_in, voc); - //下面开始计算workspace空间 - size_t size_radix_sort; - size_t size_scan; - random_sample_workspace(size_radix_sort, size_scan, - voc, (cudaStream_t) stream); - void *workspace_extra; - cudaMalloc(&workspace_extra, size_radix_sort + size_scan); - sort_pairs_descending( - workspace_extra, size_radix_sort, - (half *) probs, val_out, - key_in, key_out, - voc, (cudaStream_t) stream);//该函数会把排序结果和对应索引保存在val_out和key_out上 - //排序结束,然后开始做softmax变换 + if (topp > 0 && topk > 1) { + char *origin = reinterpret_cast(workspace); + char *keyTmp = origin + voc * sizeof(half); + half *val_out = (half *) origin; + + uint64_t *key_in = (uint64_t *) keyTmp; + uint64_t *key_out = key_in + voc; + + index<<<(voc + 1023) / 1024, 1024, 0, (cudaStream_t) stream>>>(key_in, voc); + //下面开始计算workspace空间 + + void *workspace_extra = reinterpret_cast(workspace) + desc->step; + uint64_t workspace_len = workspace_size - desc->step; + sort_pairs_descending( + workspace_extra, workspace_len, + (half *) probs, val_out, + key_in, key_out, + voc, (cudaStream_t) stream);//该函数会把排序结果和对应索引保存在val_out和key_out上 + //排序结束,然后开始做softmax变换 int BLOCK_DIM = 1024; int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM; softmax<<>>(val_out, topk, @@ -141,7 +188,7 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace inclusive_sum( - workspace_extra, size_scan, + workspace_extra, workspace_len, val_out, voc, (cudaStream_t) stream);//该函数会实现scan功能不断累加结果 random_sample_kernel<<<1, 1, 0, (cudaStream_t) stream>>>((uint64_t *) result, @@ -152,10 +199,10 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace key_out); } else { - random_sample_kernel<<<1, 1, 0, (cudaStream_t) stream>>>((uint64_t *) result, - key_out); + int BLOCK_DIM = 1024; + int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM; + argmaxKernel<<>>((half *) probs, voc, (uint64_t *) result); } - cudaFree(workspace_extra); } infiniopStatus_t cudaRandomSample(RandomSampleCudaDescriptor_t desc, @@ -172,7 +219,7 @@ infiniopStatus_t cudaRandomSample(RandomSampleCudaDescriptor_t desc, return STATUS_BAD_DEVICE; } if (dtype_eq(desc->dtype, F16)) { - random_sample_nv_gpu_f16(desc, workspace, result, probs, random_val, topp, topk, temperature, stream); + random_sample_nv_gpu_f16(desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream); return STATUS_SUCCESS; } diff --git a/src/ops/random_sample/cuda/random_sample.cuh b/src/ops/random_sample/cuda/random_sample.cuh index d3fff76d..d99b034a 100644 --- a/src/ops/random_sample/cuda/random_sample.cuh +++ b/src/ops/random_sample/cuda/random_sample.cuh @@ -11,6 +11,7 @@ struct RandomSampleCudaDescriptor { int voc; DT rDtype; int rLength; + int step; }; typedef struct RandomSampleCudaDescriptor *RandomSampleCudaDescriptor_t; @@ -18,8 +19,9 @@ typedef struct RandomSampleCudaDescriptor *RandomSampleCudaDescriptor_t; infiniopStatus_t cudaCreateRandomSampleDescriptor(CudaHandle_t handle, RandomSampleCudaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, infiniopTensorDescriptor_t probs); - -infiniopStatus_t cudaGetRandomSampleWorkspaceSize(RandomSampleCudaDescriptor_t desc, uint64_t *size); +infiniopStatus_t random_sample_workspace(size_t &size_radix_sort, size_t &size_scan, + int voc, DT dtype); +infiniopStatus_t cudaGetRandomSampleWorkspaceSize(RandomSampleCudaDescriptor_t desc, unsigned long int *size); infiniopStatus_t cudaRandomSample(RandomSampleCudaDescriptor_t desc, void *workspace, diff --git a/src/ops/random_sample/cuda/random_sample_cuda.cc b/src/ops/random_sample/cuda/random_sample_cuda.cc index 022a113b..283808ba 100644 --- a/src/ops/random_sample/cuda/random_sample_cuda.cc +++ b/src/ops/random_sample/cuda/random_sample_cuda.cc @@ -8,26 +8,37 @@ infiniopStatus_t cudaCreateRandomSampleDescriptor(CudaHandle_t handle, if (probs->ndim != 1) { return STATUS_BAD_TENSOR_SHAPE; } - if (!dtype_eq(result->dt, U64)) + if (!dtype_eq(probs->dt, F16) && !dtype_eq(result->dt, U64)) { return STATUS_BAD_TENSOR_DTYPE; + } + int voc = probs->shape[0]; int rLength = result->shape[0]; if (result->ndim != 1 && rLength != 1) { return STATUS_BAD_TENSOR_SHAPE; } + int step = 2 * voc * sizeof(uint64_t) + voc * sizeof(probs->dt); *desc_ptr = new RandomSampleCudaDescriptor{ handle->device, handle->device_id, probs->dt, voc, result->dt, - rLength}; + rLength, + step}; return STATUS_SUCCESS; } -infiniopStatus_t cudaGetRandomSampleWorkspaceSize(RandomSampleCudaDescriptor_t desc, uint64_t *size) { - *size = desc->voc * (2 * sizeof(uint64_t) + sizeof(desc->dtype)); +infiniopStatus_t cudaGetRandomSampleWorkspaceSize(RandomSampleCudaDescriptor_t desc, unsigned long int *size) { + size_t size_radix_sort; + size_t size_scan; + infiniopStatus_t status = random_sample_workspace(size_radix_sort, size_scan, + desc->voc, desc->dtype); + if (status != STATUS_SUCCESS) { + return status; + } + *size = desc->step + std::max(size_radix_sort, size_scan); return STATUS_SUCCESS; } diff --git a/src/ops/utils.h b/src/ops/utils.h index b48cf419..8e0286a1 100644 --- a/src/ops/utils.h +++ b/src/ops/utils.h @@ -224,7 +224,7 @@ inline infiniopTensorDescriptor_t dim_merge(infiniopTensorDescriptor_t desc, uin // split the dimension dim of a tensor descriptor into multiple dimensions inline infiniopTensorDescriptor_t dim_split(infiniopTensorDescriptor_t desc, uint64_t dim, const std::vector &dims) { uint64_t ndim = desc->ndim; - if (desc->shape[dim] != std::accumulate(dims.begin(), dims.end(), (uint64_t)1, std::multiplies{})) { + if (desc->shape[dim] != std::accumulate(dims.begin(), dims.end(), (uint64_t) 1, std::multiplies{})) { return nullptr; } uint64_t new_ndim = ndim + dims.size() - 1;