|
| 1 | +#include "rotary_embedding_sdaa.h" |
| 2 | + |
| 3 | +__local__ halfv16 x_local, y_local; |
| 4 | +__local__ floatv16 sin_local, cos_local, tmp_local; |
| 5 | + |
| 6 | +infiniopStatus_t tecoCreateRoPEDescriptor(TecoHandle_t handle, |
| 7 | + RoPETecoDescriptor_t *desc_ptr, |
| 8 | + infiniopTensorDescriptor_t t, |
| 9 | + infiniopTensorDescriptor_t pos_ids, |
| 10 | + infiniopTensorDescriptor_t sin_table, |
| 11 | + infiniopTensorDescriptor_t cos_table){ |
| 12 | + if (desc_ptr == nullptr) |
| 13 | + return STATUS_MEMORY_NOT_ALLOCATED; |
| 14 | + |
| 15 | + if (t->ndim != 3 || |
| 16 | + pos_ids->ndim != 1 || |
| 17 | + sin_table->ndim != 2 || |
| 18 | + cos_table->ndim != 2) |
| 19 | + return STATUS_BAD_TENSOR_SHAPE; |
| 20 | + |
| 21 | + auto seqlen = t->shape[0]; |
| 22 | + auto nhead = t->shape[1]; |
| 23 | + auto dhead = t->shape[2]; |
| 24 | + auto total_seqlen = sin_table->shape[0]; |
| 25 | + |
| 26 | + if (dhead % 2 != 0) |
| 27 | + return STATUS_BAD_TENSOR_SHAPE; |
| 28 | + |
| 29 | + if (pos_ids->shape[0] != seqlen || |
| 30 | + sin_table->shape[1] != dhead || |
| 31 | + cos_table->shape[1] != dhead || |
| 32 | + sin_table->shape[0] != cos_table->shape[0]) |
| 33 | + return STATUS_BAD_TENSOR_SHAPE; |
| 34 | + |
| 35 | + if (t->strides[2] != 1 || |
| 36 | + pos_ids->strides[0] != 1 || |
| 37 | + sin_table->strides[1] != 1 || |
| 38 | + cos_table->strides[1] != 1) |
| 39 | + return STATUS_BAD_TENSOR_STRIDES; |
| 40 | + |
| 41 | + if (!dtype_eq(t->dt, F16)) |
| 42 | + return STATUS_BAD_TENSOR_DTYPE; |
| 43 | + |
| 44 | + if (!dtype_eq(sin_table->dt, F32) || !dtype_eq(cos_table->dt, F32)) |
| 45 | + return STATUS_BAD_TENSOR_DTYPE; |
| 46 | + |
| 47 | + if (!dtype_eq(pos_ids->dt, U64)) |
| 48 | + return STATUS_BAD_TENSOR_DTYPE; |
| 49 | + int x_stride_seqlen = static_cast<int>(t->strides[0]); |
| 50 | + int x_stride_nhead = static_cast<int>(t->strides[1]); |
| 51 | + *desc_ptr = new RoPETecoDescriptor{ |
| 52 | + handle->device, |
| 53 | + handle->device_id, |
| 54 | + t->dt, |
| 55 | + seqlen, |
| 56 | + nhead, |
| 57 | + dhead, |
| 58 | + total_seqlen, |
| 59 | + x_stride_seqlen, |
| 60 | + x_stride_nhead}; |
| 61 | + |
| 62 | + return STATUS_SUCCESS; |
| 63 | +} |
| 64 | + |
| 65 | +infiniopStatus_t tecoGetRoPEWorkspaceSize(RoPETecoDescriptor_t desc, uint64_t *size) { |
| 66 | + *size = 0; |
| 67 | + return STATUS_SUCCESS; |
| 68 | +} |
| 69 | + |
| 70 | +__global__ void RoPE(half *destination, |
| 71 | + const uint64_t *pos_ids, |
| 72 | + const float *sin_table, const float *cos_table, |
| 73 | + int x_stride_seqlen, int x_stride_nhead, |
| 74 | + int seqlen, int nhead, int dhead){ |
| 75 | + int other_size = seqlen * nhead; |
| 76 | + int remain = other_size % threadDim; |
| 77 | + int step_easy = (other_size - remain) / threadDim; |
| 78 | + int step_hard = step_easy + 1; |
| 79 | + int step = (threadIdx < remain ? step_hard : step_easy); |
| 80 | + int ind_start = (threadIdx < remain ? threadIdx * step_hard : remain * step_hard + (threadIdx - remain) * step_easy); |
| 81 | + |
| 82 | + int buf_size = 16; |
| 83 | + int remain_dhead = dhead % buf_size; |
| 84 | + int repeat = (dhead - remain_dhead) / buf_size; |
| 85 | + |
| 86 | + for(int i = ind_start; i < ind_start + step; i++){ |
| 87 | + int ind_i = i; |
| 88 | + int ind_s = 0; |
| 89 | + |
| 90 | + ind_s += (ind_i % nhead) * x_stride_nhead; |
| 91 | + ind_i /= nhead; |
| 92 | + ind_s += (ind_i % seqlen) * x_stride_seqlen; |
| 93 | + |
| 94 | + int index = static_cast<int>(pos_ids[ind_i % seqlen]) * dhead; |
| 95 | + |
| 96 | + for(int r = 0; r < repeat; r++){ |
| 97 | + int start_s = ind_s + r * buf_size; |
| 98 | + int sin_cos_index = index + r * buf_size; |
| 99 | + |
| 100 | + simd_load(x_local, destination + start_s); |
| 101 | + simd_load(sin_local, sin_table + sin_cos_index); |
| 102 | + simd_load(cos_local, cos_table + sin_cos_index); |
| 103 | + |
| 104 | + tmp_local = simd_cvt_h2f(x_local); |
| 105 | + |
| 106 | + for(int k = 0; k < buf_size / 2; k++){ |
| 107 | + float a = tmp_local[2 * k]; |
| 108 | + float b = tmp_local[2 * k + 1]; |
| 109 | + float sin0 = sin_local[2 * k], cos0 = cos_local[2 * k]; |
| 110 | + float sin1 = sin_local[2 * k + 1], cos1 = cos_local[2 * k + 1]; |
| 111 | + tmp_local[2 * k] = a * cos0 - b * sin0; |
| 112 | + tmp_local[2 * k + 1] = a * sin1 + b * cos1; |
| 113 | + } |
| 114 | + y_local = simd_cvt_f2h(tmp_local); |
| 115 | + simd_store(y_local, destination + start_s); |
| 116 | + |
| 117 | + } |
| 118 | + if(remain_dhead){ |
| 119 | + int start_s = ind_s + repeat * buf_size; |
| 120 | + int sin_cos_index = index + repeat * buf_size; |
| 121 | + for(int k = 0; k < remain_dhead / 2; k++){ |
| 122 | + float a = static_cast<float>(destination[start_s + 2 * k]); |
| 123 | + float b = static_cast<float>(destination[start_s + 2 * k + 1]); |
| 124 | + float sin0 = sin_table[sin_cos_index + 2 * k], cos0 = cos_local[sin_cos_index + 2 * k]; |
| 125 | + float sin1 = sin_local[sin_cos_index + 2 * k + 1], cos1 = cos_local[sin_cos_index + 2 * k + 1]; |
| 126 | + destination[start_s + 2 * k] = static_cast<half>(a * cos0 - b * sin0); |
| 127 | + destination[start_s + 2 * k + 1] = static_cast<half>(a * sin1 + b * cos1); |
| 128 | + } |
| 129 | + } |
| 130 | + } |
| 131 | +} |
| 132 | + |
| 133 | +infiniopStatus_t tecoRoPE(RoPETecoDescriptor_t desc, |
| 134 | + void *workspace, |
| 135 | + uint64_t workspace_size, |
| 136 | + void *t, |
| 137 | + void const *pos_ids, |
| 138 | + void const *sin_table, |
| 139 | + void const *cos_table, |
| 140 | + void *stream){ |
| 141 | + auto t_ptr = reinterpret_cast<half *>(t); |
| 142 | + auto sin_ptr = reinterpret_cast<const float *>(sin_table); |
| 143 | + auto cos_ptr = reinterpret_cast<const float *>(cos_table); |
| 144 | + auto pos_ptr = reinterpret_cast<const uint64_t *>(pos_ids); |
| 145 | + |
| 146 | + int seqlen = static_cast<int>(desc->seqlen); |
| 147 | + int nhead = static_cast<int>(desc->nhead); |
| 148 | + int dhead = static_cast<int>(desc->dhead); |
| 149 | + int x_stride_seqlen = desc->x_stride_seqlen; |
| 150 | + int x_stride_nhead = desc->x_stride_nhead; |
| 151 | + |
| 152 | + RoPE<<<1, (sdaaStream_t)stream>>>(t_ptr, pos_ptr, sin_ptr, cos_ptr, x_stride_seqlen, x_stride_nhead, seqlen, nhead, dhead); |
| 153 | + return STATUS_SUCCESS; |
| 154 | +} |
| 155 | + |
| 156 | +infiniopStatus_t tecoDestroyRoPEDescriptor(RoPETecoDescriptor_t desc){ |
| 157 | + delete desc; |
| 158 | + return STATUS_SUCCESS; |
| 159 | +} |
0 commit comments