diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 167be9f2e0d74e..84e41d2bc259a0 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -212,7 +212,9 @@ struct BroadcastDataLoader { using VecType = phi::kps::details::VectorType; VecType vec_temp; - int thread_offset = threadIdx.x + blockIdx.x * blockDim.x; + int64_t thread_offset = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); const VecType *__restrict__ vec_input = reinterpret_cast(ins[Index]); vec_temp = vec_input[thread_offset]; diff --git a/paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h b/paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h index b491cbe120d06f..6311e366578337 100644 --- a/paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h +++ b/paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h @@ -128,7 +128,9 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T c0 = 0.0f; T b0[Tiled_size]; - int COL = blockIdx.x * blockDim.x + threadIdx.x; + int64_t COL = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int Tiled_mask = ((1 << Tiled_size) - 1); // Tiled matrix multiply using register shift, faster than sm. if (prev_output_value) { @@ -185,7 +187,9 @@ __global__ void KeFastCollectiveGruOut(const T *gate_weight, int frame_size, ActivationType act_node, bool origin_mode) { - int COL = blockIdx.x * blockDim.x + threadIdx.x; + int64_t COL = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); T a0 = 0.0f; T b0[Tiled_size]; diff --git a/paddle/phi/kernels/funcs/fake_quantize_functor.cu b/paddle/phi/kernels/funcs/fake_quantize_functor.cu index be3c3de01d6590..d8da1e7ba480c7 100644 --- a/paddle/phi/kernels/funcs/fake_quantize_functor.cu +++ b/paddle/phi/kernels/funcs/fake_quantize_functor.cu @@ -29,7 +29,9 @@ struct QuantizeDataType { template __global__ void FindAbsMaxKernel(const T *in, const int64_t n, T *out) { - int bid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t bid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); int tid = threadIdx.x; extern __shared__ char *shared_max_data_tmp[]; @@ -70,7 +72,9 @@ __global__ void ClipAndQuantKernel(const T *in, const int round_type, const int64_t n, T *out) { - int bid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t bid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); int tid = threadIdx.x; using ComputeDataType = typename QuantizeDataType::type; @@ -155,7 +159,9 @@ __global__ void ClipAndQuantDequantKernel(const T *in, const int round_type, const int64_t n, T *out) { - int bid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t bid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); int tid = threadIdx.x; using ComputeDataType = typename QuantizeDataType::type; diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu index cb35feee328a75..fb65c985d15987 100644 --- a/paddle/phi/kernels/funcs/fc_functor.cu +++ b/paddle/phi/kernels/funcs/fc_functor.cu @@ -63,7 +63,9 @@ struct FcTypeTraits { template __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t tid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); if (tid < num) { int bias_idx = tid % K; const T bias_ptr = bias[bias_idx]; diff --git 
a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu index afd3e1cb25b214..2300d7e11335aa 100644 --- a/paddle/phi/kernels/funcs/math_function.cu +++ b/paddle/phi/kernels/funcs/math_function.cu @@ -209,7 +209,10 @@ DEFINE_GPU_TRANS(6); template __global__ void FillConstantKernel(const int N, T* a, const T val) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + for (int64_t i = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + i < N; i += blockDim.x * gridDim.x) { a[i] = val; } diff --git a/paddle/phi/kernels/funcs/norm_utils.cu.h b/paddle/phi/kernels/funcs/norm_utils.cu.h index 4b1ed6ddb9c9e6..e48e4fd116b4db 100644 --- a/paddle/phi/kernels/funcs/norm_utils.cu.h +++ b/paddle/phi/kernels/funcs/norm_utils.cu.h @@ -370,7 +370,9 @@ __global__ void DoubleGradComputeDXWithGlobal(const T *dy, const int sample_size, const int64_t num, T *dx) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; if (ddscale != nullptr) { for (int64_t i = gid; i < num; i += stride) { @@ -397,7 +399,9 @@ __global__ void DoubleGradComputeDDYWithGlobal(const T *ddx, const int sample_size, const int64_t num, T *ddy) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; if (ddx != nullptr) { diff --git a/paddle/phi/kernels/funcs/quant_dequant.h b/paddle/phi/kernels/funcs/quant_dequant.h index f11c29a6ef7e7d..18a489903a75d9 100644 --- a/paddle/phi/kernels/funcs/quant_dequant.h +++ b/paddle/phi/kernels/funcs/quant_dequant.h @@ -91,8 +91,13 @@ __global__ void QuantKernel(const T* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) + << 2; + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { @@ -118,8 +123,13 @@ __global__ void QuantKernelWithVecSize(const T* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) + << 2; + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { @@ -145,8 +155,13 @@ __global__ void QuantKernelWithVecSize(const T* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 3; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) * + 3; + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { @@ -170,8 +185,13 @@ __global__ void QuantKernelWithVecSize(const T* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 2; - int m_id = blockIdx.y * blockDim.y + 
threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) * + 2; + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { @@ -193,8 +213,12 @@ __global__ void QuantKernelWithVecSize(const T* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x); - int m_id = blockIdx.y * blockDim.y + threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)); + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { @@ -320,7 +344,10 @@ __global__ void DequantKernel(T* output, const float* dequant_out_scale_data) { int numel = m * n; int stride = blockDim.x * gridDim.x * VecSize; - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int64_t idx = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) * + VecSize; int col_id = idx % n; phi::AlignedVector in_vec; @@ -366,7 +393,10 @@ __global__ void DequantKernelWithScaleOfInputAndWeight( float quant_max_bound) { int numel = m * n; int stride = blockDim.x * gridDim.x * VecSize; - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int64_t idx = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) * + VecSize; int col_id = idx % n; phi::AlignedVector in_vec; diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h index defbcf23b0d9f9..e9670113d9c5eb 100644 --- a/paddle/phi/kernels/funcs/scatter.cu.h +++ b/paddle/phi/kernels/funcs/scatter.cu.h @@ -402,7 +402,8 @@ inline DenseTensor restride_dim(const phi::DenseTensor& src, template __global__ void scatter_gather_elementwise_kernel(int N, func_t f) { constexpr int nv = nt * vt; - int idx = nv * blockIdx.x + threadIdx.x; + int64_t idx = + nv * static_cast(blockIdx.x) + static_cast(threadIdx.x); #pragma unroll for (int i = 0; i < vt; ++i) { diff --git a/paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h b/paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h index 50f8342f6da14d..5cb521fbc07a2e 100644 --- a/paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h +++ b/paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h @@ -26,7 +26,9 @@ __global__ void FlattenIndicesKernel(const IntT* indices, const int64_t non_zero_num, const int64_t sparse_dim, IntT* out) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t tid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); phi::funcs::sparse::FlattenIndices(indices, sparse_offsets, non_zero_num, @@ -42,7 +44,9 @@ __global__ void IndexToCoordinateKernel(const IntT* index, const int64_t non_zero_num, const int64_t sparse_dim, IntT* indices) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t tid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); IndexToCoordinate(index, dims, non_zero_num, diff --git a/paddle/phi/kernels/funcs/sparse/scatter.cu.h b/paddle/phi/kernels/funcs/sparse/scatter.cu.h index f27174d5818186..119631b9d0aa94 100644 --- a/paddle/phi/kernels/funcs/sparse/scatter.cu.h +++ b/paddle/phi/kernels/funcs/sparse/scatter.cu.h @@ -41,7 +41,9 @@ __global__ void ScatterKernel(const T* input, const int rulebook_len, const int channels, T* out) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; + 
int64_t tid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); const int vec_channels = channels / VecSize; using LoadT = phi::AlignedVector; using StoreT = phi::AlignedVector; @@ -82,7 +84,9 @@ __global__ void ScatterKernelV2(const T* input, const int channels, const int buffer_counts, T* out) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t tid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); const int vec_channels = channels / VecSize; using LoadT = phi::AlignedVector; using StoreT = phi::AlignedVector; diff --git a/paddle/phi/kernels/funcs/sync_batch_norm_utils.h b/paddle/phi/kernels/funcs/sync_batch_norm_utils.h index 0715cec7fc8215..cc0c6a7ff30ea2 100644 --- a/paddle/phi/kernels/funcs/sync_batch_norm_utils.h +++ b/paddle/phi/kernels/funcs/sync_batch_norm_utils.h @@ -87,7 +87,9 @@ __global__ void KeSyncAndMovingStats(BatchNormParamType *means, BatchNormParamType *moving_means, BatchNormParamType *moving_variances) { // sync stats across multi-devices - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; for (int i = gid; i < C; i += stride) { auto mean = means[i] / (*num_dev); @@ -117,7 +119,9 @@ static __global__ void KeNormAffine(const T *x, const int M, const int64_t num, T *y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; for (int64_t i = gid; i < num; i += stride) { const int c = layout == DataLayout::kNCHW ? (i / M) % C : i % C; @@ -180,12 +184,18 @@ __global__ void KeBackwardLocalStats2D(const T *dy, BatchNormParamType *sum_dy_prod) { __shared__ BatchNormParamType smem_sum[BlockDim]; __shared__ BatchNormParamType smem_square_sum[BlockDim]; - for (int k = blockIdx.x * blockDim.x + threadIdx.x; k < C; + for (int64_t k = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + k < C; k += gridDim.x * blockDim.x) { BatchNormParamType sum1 = 0.; BatchNormParamType sum2 = 0.; auto mean = means[k]; - for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < N * M; + for (int64_t i = static_cast(blockIdx.y) * + static_cast(blockDim.y) + + static_cast(threadIdx.y); + i < N * M; i += gridDim.y * blockDim.y) { int id = layout == DataLayout::kNCHW ? (i / M) * C * M + k * M + i % M : i * C + k; @@ -287,7 +297,10 @@ static __global__ void KeBNBackwardScaleBias2D( __shared__ BatchNormParamType smem_sum[BlockDim]; __shared__ BatchNormParamType smem_square_sum[BlockDim]; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; + for (int64_t i = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + i < outer_size; i += gridDim.x * blockDim.x) { BatchNormParamType ds_sum = 0.; BatchNormParamType db_sum = 0.; @@ -341,7 +354,9 @@ static __global__ void KeBNRestoreData(T *x, int M, int64_t num, const T *y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; for (int64_t i = gid; i < num; i += stride) { const int64_t c = layout == DataLayout::kNCHW ? 
(i / M) % C : i % C; @@ -366,7 +381,9 @@ static __global__ void KeBNBackwardData( const int64_t HxW, const int64_t num, T *dx) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; auto scale = static_cast>(C) / num; auto dev_num = num_dev[0]; diff --git a/paddle/phi/kernels/funcs/weight_dequant_functor.h b/paddle/phi/kernels/funcs/weight_dequant_functor.h index 7377cab0ac2db5..7d63c92fd57814 100644 --- a/paddle/phi/kernels/funcs/weight_dequant_functor.h +++ b/paddle/phi/kernels/funcs/weight_dequant_functor.h @@ -128,7 +128,9 @@ __global__ void int8_weight_only_dequant(const uint8_t* weight, AlignedVector vec_out; int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; - int tile_id = blockIdx.x * blockDim.x / 32 + warp_id; + int64_t tile_id = + static_cast(blockIdx.x) * static_cast(blockDim.x) / 32 + + warp_id; // Every two rows of the original weights are interleaved into a row with // stride of 64, so if each thread processes 16 elements(for int8, we can use // ldg.128 to load weights), then every group of four adjacent threads will @@ -184,7 +186,9 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight, AlignedVector vec_out; int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; - int tile_id = blockIdx.x * blockDim.x / 32 + warp_id; + int64_t tile_id = + static_cast(blockIdx.x) * static_cast(blockDim.x) / 32 + + warp_id; // Every 4 rows of the original weights are interleaved into a row with // stride of 32, so if each thread processes 16 elements(for int8, we can use // ldg.128 to load weights), then every group of two adjacent threads will @@ -242,7 +246,9 @@ __global__ void int8_weight_only_dequant(const uint8_t* weight, AlignedVector vec_out; int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; - int tile_id = blockIdx.x * blockDim.x / 32 + warp_id; + int64_t tile_id = + static_cast(blockIdx.x) * static_cast(blockDim.x) / 32 + + warp_id; // Every two rows of the original weights are interleaved into a row with // stride of 64, so if each thread processes 16 elements(for int8, we can use // ldg.128 to load weights), then every group of four adjacent threads will @@ -302,7 +308,9 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight, AlignedVector vec_out; int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; - int tile_id = blockIdx.x * blockDim.x / 32 + warp_id; + int64_t tile_id = + static_cast(blockIdx.x) * static_cast(blockDim.x) / 32 + + warp_id; // Every two rows of the original weights are interleaved into a row with // stride of 64, so if each thread processes 16 elements(for int8, we can use // ldg.128 to load weights), then every group of four adjacent threads will diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu index 6aed60cf1c23b6..3bb9aa45ad2a4f 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu @@ -71,8 +71,12 @@ __global__ void naive_conv2d_kernel(const T *input, int N = oc; int kc = ic / groups; int K = kc * kh * kw; - int m_i = threadIdx.x + blockIdx.x * blockDim.x; - int n_i = threadIdx.y + blockIdx.y * blockDim.y; + int64_t m_i = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); + int64_t n_i = + static_cast(threadIdx.y) + + static_cast(blockIdx.y) * static_cast(blockDim.y); if (m_i >= M || n_i >= 
N) return; int batch_i = m_i / (oh * ow); diff --git a/paddle/phi/kernels/fusion/gpu/block_attn.h b/paddle/phi/kernels/fusion/gpu/block_attn.h index 9b27233f5dff1d..a04ab84b0c1a2e 100644 --- a/paddle/phi/kernels/fusion/gpu/block_attn.h +++ b/paddle/phi/kernels/fusion/gpu/block_attn.h @@ -3788,7 +3788,9 @@ __global__ void ShiftSmoothQuant(const T *input, phi::AlignedVector smooth_vec; phi::AlignedVector out_vec; - for (int linear_id = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t linear_id = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); linear_id * VecSize < num; linear_id += gridDim.x * blockDim.x) { int idx = linear_id * VecSize; @@ -3826,7 +3828,9 @@ __global__ void ShiftSmooth(const T *input, phi::AlignedVector smooth_vec; phi::AlignedVector out_vec; - for (int linear_id = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t linear_id = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); linear_id * VecSize < num; linear_id += gridDim.x * blockDim.x) { int idx = linear_id * VecSize; diff --git a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu index dc2d495f7bb18d..007f659b783a39 100644 --- a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu @@ -149,8 +149,13 @@ __global__ void QuantKernel(const data_t* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) + << 2; + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { @@ -177,8 +182,13 @@ __global__ void FP8QuantKernel(const data_t* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) + << 2; + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { @@ -207,8 +217,13 @@ __global__ void QuantKernel(const data_t* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) + << 2; + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu index dcedf010bad4b6..6cd10d8f935d59 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu @@ -54,15 +54,21 @@ __global__ void SoftmaxMaskFuseV1GPUKernel(const T* x_data, constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 
2 : 1; constexpr int kOneLoadingCounts = 4; - int data_first_idx = - (blockDim.y * - (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + - threadIdx.y) * - kLocalBatchSize; - - int mask_fist_idx = - (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * - kLocalBatchSize; + int64_t data_first_idx = (static_cast(blockDim.y) * + (static_cast(blockIdx.x) + + static_cast(gridDim.x) * + (static_cast(blockIdx.y) + + static_cast(gridDim.y) * + static_cast(blockIdx.z))) + + static_cast(threadIdx.y)) * + kLocalBatchSize; + + int64_t mask_fist_idx = (static_cast(blockDim.y) * + (static_cast(blockIdx.x) + + static_cast(gridDim.x) * + static_cast(blockIdx.z)) + + static_cast(threadIdx.y)) * + kLocalBatchSize; // batch_count might not be a multiple of kLocalBatchSize. Check how // many batches have to computed within this WARP. diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h index c1d60cbffee2fa..4672c4f3343457 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h @@ -131,12 +131,17 @@ __global__ void FusedSoftmaxMaskVecKernel(T* dst, // gridDim/blockIdx = (DIV_UP(seq_len, warps_per_block), batch_size, head_num) // every block processes 4(warps_per_block) sequences // seq_id = seq_id * 4 + warp_id, eg.seq_len=128, 127=31*4+3 - int seq_id = blockIdx.x * warps_per_block + threadIdx.y; + int64_t seq_id = static_cast(blockIdx.x) * warps_per_block + + static_cast(threadIdx.y); if (seq_id >= seq_len) return; // ((bid*head_num + hid)*seq_len + seq_id) * seq_len - int offset = - ((blockIdx.y * gridDim.z + blockIdx.z) * seq_len + seq_id) * seq_len; + int64_t offset = + ((static_cast(blockIdx.y) * static_cast(gridDim.z) + + static_cast(blockIdx.z)) * + seq_len + + seq_id) * + seq_len; // (bid * seq_len + seq_id) * seq_len int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len; src += offset; diff --git a/paddle/phi/kernels/fusion/gpu/quant_dequant_kernel.h b/paddle/phi/kernels/fusion/gpu/quant_dequant_kernel.h index 11e5eb072c474a..60aac8dae70507 100644 --- a/paddle/phi/kernels/fusion/gpu/quant_dequant_kernel.h +++ b/paddle/phi/kernels/fusion/gpu/quant_dequant_kernel.h @@ -56,8 +56,13 @@ __global__ void QuantKernel(const T* input, const int round_type, const float max_bound, const float min_bound) { - int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; - int m_id = blockIdx.y * blockDim.y + threadIdx.y; + int64_t n_id = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) + << 2; + int64_t m_id = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); bool check = ((m_id < m) && (n_id < n)); if (check) { @@ -112,7 +117,10 @@ __global__ void DequantKernel(T* output, const float* dequant_out_scale_data) { int numel = m * n; int stride = blockDim.x * gridDim.x * VecSize; - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int64_t idx = + (static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x)) * + VecSize; int col_id = idx % n; phi::AlignedVector in_vec; diff --git a/paddle/phi/kernels/gpu/adagrad_kernel.cu b/paddle/phi/kernels/gpu/adagrad_kernel.cu index 6a7d428d2a6a04..a98b1c263f6982 100644 --- a/paddle/phi/kernels/gpu/adagrad_kernel.cu +++ b/paddle/phi/kernels/gpu/adagrad_kernel.cu @@ -36,7 +36,9 @@ __global__ void AdagradGPUKernel(const T* param, MT* moment_out, MT* master_param_out, int64_t num) { - auto idx = blockDim.x * 
blockIdx.x + threadIdx.x; + int64_t idx = + static_cast(blockDim.x) * static_cast(blockIdx.x) + + static_cast(threadIdx.x); MT lr_data = static_cast(lr[0]); for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { diff --git a/paddle/phi/kernels/gpu/adamax_kernel.cu b/paddle/phi/kernels/gpu/adamax_kernel.cu index b4b8dff83b1e77..276d74973c9200 100644 --- a/paddle/phi/kernels/gpu/adamax_kernel.cu +++ b/paddle/phi/kernels/gpu/adamax_kernel.cu @@ -36,7 +36,9 @@ __global__ void AdamaxGPUKernel(const T* param, MT* moment_out, MT* inf_norm_out, MT* master_param_out) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t idx = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); MT lr = static_cast(learning_rate[0]); MT d_pow = static_cast(beta1_pow[0]); diff --git a/paddle/phi/kernels/gpu/affine_channel_grad_kernel.cu b/paddle/phi/kernels/gpu/affine_channel_grad_kernel.cu index 6fdcebde8e6d94..e0087ef8090fdc 100644 --- a/paddle/phi/kernels/gpu/affine_channel_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/affine_channel_grad_kernel.cu @@ -35,7 +35,9 @@ __global__ static inline void KeAffineChannelCUDA(const T* x, const int64_t HxW, const int64_t num, T* y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; for (int64_t i = gid; i < num; i += stride) { const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C; diff --git a/paddle/phi/kernels/gpu/affine_channel_kernel.cu b/paddle/phi/kernels/gpu/affine_channel_kernel.cu index dec4e1f5946d61..c551ce66c6b819 100644 --- a/paddle/phi/kernels/gpu/affine_channel_kernel.cu +++ b/paddle/phi/kernels/gpu/affine_channel_kernel.cu @@ -35,7 +35,9 @@ __global__ static inline void KeAffineChannelCUDA(const T* x, const int64_t HxW, const int64_t num, T* y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; for (int64_t i = gid; i < num; i += stride) { const int c = layout == phi::DataLayout::kNCHW ? 
i / HxW % C : i % C; diff --git a/paddle/phi/kernels/gpu/apply_per_channel_scale_kernel.cu b/paddle/phi/kernels/gpu/apply_per_channel_scale_kernel.cu index dce5a710f04bcd..c8d553ea20e089 100644 --- a/paddle/phi/kernels/gpu/apply_per_channel_scale_kernel.cu +++ b/paddle/phi/kernels/gpu/apply_per_channel_scale_kernel.cu @@ -93,7 +93,9 @@ __global__ void apply_per_channel_scale( using HALF_2_TYPE = typename CUDA_HALF_2_TYPE_TARIS::type; static constexpr int kElems = sizeof(AccessType) / sizeof(T); T scale[kElems], act_vec[kElems]; - int col_offset = blockIdx.x * blockDim.x + threadIdx.x; + int64_t col_offset = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int row_offset = blockIdx.y; if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows) return; act += row_offset * kProcessRows * cols; diff --git a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu index b6c0aa797b8015..864915b2c1f1d4 100644 --- a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu @@ -64,7 +64,9 @@ static __global__ void FillFlattenGrad(const T* dO, const IndType* indices, int64_t size, T* dX) { - int index = threadIdx.x + blockIdx.x * blockDim.x; + int64_t index = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); int stride = blockDim.x * gridDim.x; for (int64_t i = index; i < size; i += stride) { dX[indices[i]] = dO[i]; diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 7fc1c73f625cd0..b8e75eb08a8b28 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -97,7 +97,9 @@ static __global__ void KeBNBackwardData(const T *dy, const int64_t HxW, const int64_t num, T *dx) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; for (int64_t i = gid; i < num; i += stride) { const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C; @@ -119,7 +121,9 @@ static __global__ void KeBNRestoreData(const phi::DataLayout layout, int M, const int64_t num, const T *y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; for (int64_t i = gid; i < num; i += stride) { const int c = layout == phi::DataLayout::kNCHW ? 
(i / M) % C : i % C; @@ -272,7 +276,10 @@ static __global__ void BNBackward2DChannelLastStage1( int outer_loop_stride = gridDim.x * blockDim.x; int inner_loop_stride = gridDim.y * blockDim.y; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; + for (int64_t i = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + i < outer_size; i += outer_loop_stride) { BatchNormParamType x_sum = static_cast>(0); BatchNormParamType x_square_sum = static_cast>(0); @@ -347,7 +354,10 @@ static __global__ void BNBackward2DChannelLastStage2( int outer_loop_stride = gridDim.x * blockDim.x; int inner_loop_stride = gridDim.y * blockDim.y; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; + for (int64_t i = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + i < outer_size; i += outer_loop_stride) { BatchNormParamType ds_sum = static_cast>(0); BatchNormParamType db_sum = static_cast>(0); @@ -410,7 +420,10 @@ static __global__ void BNBackward2DChannelLastStage3( int outer_loop_stride = gridDim.x * blockDim.x; int inner_loop_stride = gridDim.y * blockDim.y; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; + for (int64_t i = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + i < outer_size; i += outer_loop_stride) { BatchNormParamType mean_val = means[i]; BatchNormParamType inv_var_val = variances[i]; diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index fc21a8b0ff1ea4..cc1c009e36309e 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -65,7 +65,9 @@ static __global__ void BNForwardInference(const T *x, const int64_t HxW, const double epsilon, T *y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; int64_t num = HxW * N * C; for (int64_t i = gid; i < num; i += stride) { @@ -82,7 +84,9 @@ static __global__ void InverseVariance(const BatchNormParamType *variance, const double epsilon, const int C, BatchNormParamType *inv_variance) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t tid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); if (tid < C) { inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon); } @@ -100,7 +104,9 @@ static __global__ void BN1DForwardInference( const int64_t HxW, const double epsilon, T *y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t gid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; int64_t num = static_cast(N) * C * HxW; for (int64_t i = gid; i < num; i += stride) { @@ -233,7 +239,10 @@ static __global__ void BNForwardTraining2DChannelLastCompStat( int outer_loop_stride = gridDim.x * blockDim.x; int inner_loop_stride = gridDim.y * blockDim.y; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; + for (int64_t i = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + i < outer_size; i += outer_loop_stride) { BatchNormParamType x_sum = static_cast>(0); BatchNormParamType x_square_sum = static_cast>(0); @@ -329,7 +338,10 @@ static __global__ void BNForwardTraining2DChannelLastWriteRes( int outer_loop_stride = gridDim.x * blockDim.x; int inner_loop_stride = gridDim.y * blockDim.y; - for (int i = blockIdx.x * 
blockDim.x + threadIdx.x; i < outer_size; + for (int64_t i = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + i < outer_size; i += outer_loop_stride) { BatchNormParamType mean_val = compute_mean[i]; BatchNormParamType inv_var_val = compute_inv_var[i]; @@ -374,7 +386,10 @@ static __global__ void BNForwardTraining2DCompStat( int outer_loop_stride = gridDim.y * blockDim.y; int inner_loop_stride = gridDim.x * blockDim.x; - for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < outer_size; + for (int64_t i = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); + i < outer_size; i += outer_loop_stride) { BatchNormParamType x_sum = static_cast>(0); BatchNormParamType x_square_sum = static_cast>(0); @@ -497,7 +512,10 @@ static __global__ void BNForwardTraining2DWriteRes( int outer_loop_stride = gridDim.y * blockDim.y; int inner_loop_stride = gridDim.x * blockDim.x; - for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < outer_size; + for (int64_t i = + static_cast(blockIdx.y) * static_cast(blockDim.y) + + static_cast(threadIdx.y); + i < outer_size; i += outer_loop_stride) { BatchNormParamType mean_val = compute_mean[i]; BatchNormParamType inv_var_val = compute_inv_var[i]; diff --git a/paddle/phi/kernels/gpu/correlation_grad_kernel.cu b/paddle/phi/kernels/gpu/correlation_grad_kernel.cu index 2a1f277d8e77f9..7f95b23b2d1cd2 100644 --- a/paddle/phi/kernels/gpu/correlation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/correlation_grad_kernel.cu @@ -36,7 +36,9 @@ __global__ void correlation_backward_input1(int64_t n, const int max_displacement, const int stride1, const int stride2) { - int thread_index = blockIdx.x * blockDim.x + threadIdx.x; + int64_t thread_index = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int64_t total_hw_c = input_channel * input_height * input_width; if (thread_index >= total_hw_c) return; @@ -117,7 +119,9 @@ __global__ void correlation_backward_input2(int64_t n, const int max_displacement, const int stride1, const int stride2) { - int thread_index = blockIdx.x * blockDim.x + threadIdx.x; + int64_t thread_index = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int64_t total_hw_c = input_channel * input_height * input_width; if (thread_index >= total_hw_c) return; diff --git a/paddle/phi/kernels/gpu/cross_entropy_bwd_w_downcast.cu b/paddle/phi/kernels/gpu/cross_entropy_bwd_w_downcast.cu index 88af9add2c9a36..9d22da2fcf9ec8 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_bwd_w_downcast.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_bwd_w_downcast.cu @@ -136,7 +136,8 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabelWarp( const int threads_per_warp = 32; const int threads_per_block = warps_per_block * threads_per_warp; - int tid = blockIdx.x * threads_per_block + threadIdx.x; + int64_t tid = static_cast(blockIdx.x) * threads_per_block + + static_cast(threadIdx.x); int warp_id = threadIdx.x / threads_per_warp; int lane_id = threadIdx.x % threads_per_warp; diff --git a/paddle/phi/kernels/gpu/ctc_align_kernel.cu b/paddle/phi/kernels/gpu/ctc_align_kernel.cu index 22b535d5dda648..21ff3221dd57b3 100644 --- a/paddle/phi/kernels/gpu/ctc_align_kernel.cu +++ b/paddle/phi/kernels/gpu/ctc_align_kernel.cu @@ -60,7 +60,9 @@ __global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token, const int64_t batch_size, T* output, T* output_length) { - int ind = blockIdx.x * blockDim.x + threadIdx.x; + int64_t ind = + static_cast(blockIdx.x) * 
static_cast(blockDim.x) + + static_cast(threadIdx.x); if (ind >= batch_size) return; int output_idx = ind * num_token; T prev_token = -1; diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h index 2edac5eba5d9ef..93cbee772797b8 100644 --- a/paddle/phi/kernels/gpu/depthwise_conv.h +++ b/paddle/phi/kernels/gpu/depthwise_conv.h @@ -223,7 +223,9 @@ __device__ __inline__ void KernelDepthwiseConvNCHW( ARG_DEFINE_KernelDepthwiseConv) { const int fw_size = c_filter != -1 ? c_filter : filter_width; const int fh_size = c_filter != -1 ? c_filter : filter_height; - int idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t idx = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); if (idx >= (output_channels * batch_size * output_height * output_width)) return; @@ -274,7 +276,9 @@ __device__ __inline__ void KernelDepthwiseConvNHWC( ARG_DEFINE_KernelDepthwiseConv) { const int fw_size = c_filter != -1 ? c_filter : filter_width; const int fh_size = c_filter != -1 ? c_filter : filter_height; - int idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t idx = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); if (idx >= (output_channels * batch_size * output_height * output_width)) { return; } @@ -371,7 +375,8 @@ template __device__ __inline__ void KernelDepthwiseConvCFilterNHWC( ARG_DEFINE_KernelDepthwiseConv) { const int batch = blockIdx.z; - int h_out = blockIdx.x * dilate_height + blockIdx.y; + int64_t h_out = static_cast(blockIdx.x) * dilate_height + + static_cast(blockIdx.y); if (h_out >= output_height) { return; } @@ -546,7 +551,9 @@ __device__ __inline__ void KernelDepthwiseConvInputGradNCHW( ARG_DEFINE_KernelDepthwiseConvInputGrad) { const int fw_size = c_filter != -1 ? c_filter : filter_width; const int fh_size = c_filter != -1 ? 
c_filter : filter_height; - int idx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t idx = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); if (idx >= batch_size * input_channels * input_height * input_width) { return; } @@ -603,7 +610,8 @@ template __device__ __inline__ void KernelDepthwiseConvInputGradNHWC( ARG_DEFINE_KernelDepthwiseConvInputGrad) { const int batch = blockIdx.z; - int h_in = blockIdx.x * dilate_height + blockIdx.y; + int64_t h_in = static_cast(blockIdx.x) * dilate_height + + static_cast(blockIdx.y); if (h_in >= input_height) { return; } @@ -724,7 +732,8 @@ template __device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC( ARG_DEFINE_KernelDepthwiseConvInputGrad) { - int h_in = blockIdx.x * dilate_height + blockIdx.y; + int64_t h_in = static_cast(blockIdx.x) * dilate_height + + static_cast(blockIdx.y); if (h_in >= input_height) { return; } @@ -937,7 +946,11 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW( int kh_id = blockIdx.y; int oc_id = blockIdx.z; int ic_id = oc_id / filter_multiplier; - int idx = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; + int64_t idx = + ((static_cast(blockIdx.z) * static_cast(gridDim.y)) + + static_cast(blockIdx.y)) * + static_cast(gridDim.x) + + static_cast(blockIdx.x); const int ohw = output_height * output_width; const int onhw = num * ohw; @@ -1168,7 +1181,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC( const int dilate_width, T* filter_grad_data) { const int bid = blockIdx.z; - int image_h = blockIdx.x * dilate_height + blockIdx.y; + int64_t image_h = static_cast(blockIdx.x) * dilate_height + + static_cast(blockIdx.y); if (image_h >= output_height) { return; } diff --git a/paddle/phi/kernels/gpu/determinant_kernel.cu b/paddle/phi/kernels/gpu/determinant_kernel.cu index 877a61fc902bee..203f611a2dd3d8 100644 --- a/paddle/phi/kernels/gpu/determinant_kernel.cu +++ b/paddle/phi/kernels/gpu/determinant_kernel.cu @@ -110,7 +110,9 @@ __global__ void GetDetFromLUComplex(const T* lu_data, int64_t n, int64_t batch_size, T* out_data) { - int idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t idx = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); if (idx < batch_size) { int offset_lu = idx * n * n; int offset_ipiv = idx * n; diff --git a/paddle/phi/kernels/gpu/edit_distance_kernel.cu b/paddle/phi/kernels/gpu/edit_distance_kernel.cu index 2e2f3dd127e9e4..ba477c9c0c6a21 100644 --- a/paddle/phi/kernels/gpu/edit_distance_kernel.cu +++ b/paddle/phi/kernels/gpu/edit_distance_kernel.cu @@ -30,7 +30,9 @@ using phi::PADDLE_CUDA_NUM_THREADS; template __global__ void FillFirstRow(T* dist, const int N) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; + int64_t idx = + static_cast(blockDim.x) * static_cast(blockIdx.x) + + static_cast(threadIdx.x); if (idx < N + 1) { dist[idx] = idx; } @@ -38,7 +40,9 @@ __global__ void FillFirstRow(T* dist, const int N) { template __global__ void FillFirstColumn(T* dist, const int M, const int N) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; + int64_t idx = + static_cast(blockDim.x) * static_cast(blockIdx.x) + + static_cast(threadIdx.x); if (idx < M + 1) { dist[idx * (N + 1)] = idx; } @@ -51,7 +55,9 @@ __global__ void Levenshtein(T* dist, const int M, const int N, const int start) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; + int64_t idx = + static_cast(blockDim.x) * static_cast(blockIdx.x) + + static_cast(threadIdx.x); int offset = N; int index = start + idx * offset; 
int row = index / (N + 1); @@ -68,7 +74,9 @@ __global__ void Levenshtein(T* dist, template __global__ void SetOutput( T* out, const T* dist, const int M, const int N, bool normalized) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; + int64_t idx = + static_cast(blockDim.x) * static_cast(blockIdx.x) + + static_cast(threadIdx.x); if (idx == 0) { out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; } diff --git a/paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu index aad460475aaec0..2f8a8600ba1e58 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu @@ -38,7 +38,9 @@ __global__ void EmbeddingGradAddTo(T* main_grad_out, const int64_t num_tokens, const int64_t token_length) { int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * gridDim.x; + int64_t idy = + static_cast(blockIdx.x) + + static_cast(threadIdx.y) * static_cast(gridDim.x); while (idy < num_tokens) { auto id = static_cast(token_indices[idy]); diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu index 7af60601ad00aa..e7c5892df5deac 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -48,7 +48,9 @@ __global__ void EmbeddingGrad(T* table, const int64_t K, const int64_t D) { int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * gridDim.x; + int64_t idy = + static_cast(blockIdx.x) + + static_cast(threadIdx.y) * static_cast(gridDim.x); while (idy < K) { auto id = static_cast(ids[idy]); diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu index 7e87af07220629..0e9b2ce4cb49fe 100644 --- a/paddle/phi/kernels/gpu/embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -30,7 +30,9 @@ __global__ void EmbeddingFW(T *output, const int64_t D, const int64_t padding_idx) { int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * gridDim.x; + int64_t idy = + static_cast(blockIdx.x) + + static_cast(threadIdx.y) * static_cast(gridDim.x); while (idy < K) { auto id = static_cast(ids[idy]); diff --git a/paddle/phi/kernels/gpu/embedding_with_scaled_gradient_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_with_scaled_gradient_grad_kernel.cu index 13d7d0fa879ab6..ef8c771ef5c843 100644 --- a/paddle/phi/kernels/gpu/embedding_with_scaled_gradient_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_with_scaled_gradient_grad_kernel.cu @@ -62,7 +62,9 @@ __global__ void EmbeddingGrad(T* table, const int64_t K, const int64_t D) { int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * gridDim.x; + int64_t idy = + static_cast(blockIdx.x) + + static_cast(threadIdx.y) * static_cast(gridDim.x); while (idy < K) { auto id = static_cast(ids[idy]); diff --git a/paddle/phi/kernels/gpu/fused_token_prune_kernel.cu b/paddle/phi/kernels/gpu/fused_token_prune_kernel.cu index 7d53bfb146c150..6749e462b18915 100644 --- a/paddle/phi/kernels/gpu/fused_token_prune_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_token_prune_kernel.cu @@ -79,7 +79,9 @@ __global__ void TakeAlongAxis(const T* src, template __global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) { int num_threads = num_raws; - int tid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t tid = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); int stride = blockDim.x * gridDim.x; for (; tid < 
num_threads; tid += stride) { diff --git a/paddle/phi/kernels/gpu/group_norm_kernel.cu b/paddle/phi/kernels/gpu/group_norm_kernel.cu index dcedf1873286a3..21dbd1242fbd1f 100644 --- a/paddle/phi/kernels/gpu/group_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/group_norm_kernel.cu @@ -216,7 +216,9 @@ __global__ void groupNormNDHWCSumSingerChannelKernel( // The instance in the batch. __shared__ float2 smem[THREADS_PER_BLOCK]; int32_t ni = blockIdx.z; - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x; + int64_t ci = static_cast(blockIdx.x) * + static_cast(params.cPerBlock) + + static_cast(threadIdx.x); if (ci >= params.c) { return; } @@ -265,8 +267,9 @@ __global__ void groupNormNDHWCSumKernel(const GroupNormNDHWCParams params) { // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = - blockIdx.x * params.cPerBlock + threadIdx.x * THREADS_PER_CHANNEL; + int64_t ci = static_cast(blockIdx.x) * + static_cast(params.cPerBlock) + + static_cast(threadIdx.x) * THREADS_PER_CHANNEL; if (ci >= params.c || threadIdx.x * THREADS_PER_CHANNEL >= params.cPerBlock) { return; } @@ -616,8 +619,9 @@ __global__ void groupNormNDHWCScaleKernel( // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = - blockIdx.x * params.cPerBlock + threadIdx.x * THREADS_PER_CHANNEL; + int64_t ci = static_cast(blockIdx.x) * + static_cast(params.cPerBlock) + + static_cast(threadIdx.x) * THREADS_PER_CHANNEL; // The group that thread works on and the channel in the group (modulus). int32_t gi = ci / params.cPerGroup; diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu index a51f8c1abfd75b..4872bd56e68361 100644 --- a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu +++ b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu @@ -122,7 +122,9 @@ __global__ void AddGumbelNoiseCUDAKernel(const T* input_data, MPType* noise, const float temperature, int64_t n) { - int index = threadIdx.x + blockIdx.x * blockDim.x; + int64_t index = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); int step = blockDim.x * gridDim.x; for (int64_t i = index; i < n; i += step) { MPType gumbel_noise = -log(-log(noise[i])); diff --git a/paddle/phi/kernels/gpu/lars_momentum_kernel.cu b/paddle/phi/kernels/gpu/lars_momentum_kernel.cu index 5e3dd03a2d5192..2cae6c2e487b77 100644 --- a/paddle/phi/kernels/gpu/lars_momentum_kernel.cu +++ b/paddle/phi/kernels/gpu/lars_momentum_kernel.cu @@ -309,7 +309,9 @@ __global__ void MergedMomentumLarsKernel(LarsParamWrapper lars_wrapper, const MT rescale_grad, const bool is_amp) { int grid_stride = gridDim.x * LARS_BLOCK_SIZE; - int tid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t tid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); for (int i = 0; i < op_num; ++i) { int numel = lars_wrapper.numel_arr[i]; @@ -369,7 +371,9 @@ __global__ void MomentumLarsKernel(const T* param, const int thresh, const int64_t numel, const bool is_amp) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; + int64_t tid = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); int grid_stride = gridDim.x * LARS_BLOCK_SIZE; #if CUDA_VERSION >= 11000 const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); diff --git 
a/paddle/phi/kernels/gpu/lookup_table_grad_kernel.cu b/paddle/phi/kernels/gpu/lookup_table_grad_kernel.cu index 608397aae5b963..f8d9a2ec569c9f 100644 --- a/paddle/phi/kernels/gpu/lookup_table_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/lookup_table_grad_kernel.cu @@ -30,7 +30,8 @@ __global__ void LookupTableGrad(T *table, const int64_t K, const int64_t D) { int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * GridDimX; + int64_t idy = static_cast(blockIdx.x) + + static_cast(threadIdx.y) * GridDimX; while (idy < K) { int64_t id = ids[idy]; diff --git a/paddle/phi/kernels/gpu/lookup_table_kernel.cu b/paddle/phi/kernels/gpu/lookup_table_kernel.cu index 613d97a3ae6b7c..61428d41bc2380 100644 --- a/paddle/phi/kernels/gpu/lookup_table_kernel.cu +++ b/paddle/phi/kernels/gpu/lookup_table_kernel.cu @@ -35,7 +35,8 @@ __global__ void LookupTable(T *output, const int64_t D, const int64_t padding_idx) { int idx = threadIdx.x; - int idy = blockIdx.x + threadIdx.y * GridDimX; + int64_t idy = static_cast(blockIdx.x) + + static_cast(threadIdx.y) * GridDimX; while (idy < K) { int64_t id = ids[idy]; diff --git a/paddle/phi/kernels/gpu/multiclass_nms3_kernel.cu b/paddle/phi/kernels/gpu/multiclass_nms3_kernel.cu index 5184edec460c6e..7ced1fdc177413 100644 --- a/paddle/phi/kernels/gpu/multiclass_nms3_kernel.cu +++ b/paddle/phi/kernels/gpu/multiclass_nms3_kernel.cu @@ -672,7 +672,8 @@ __launch_bounds__(nthds_per_cta) __global__ bool clip_boxes, const T_SCORE score_shift) { if (keep_top_k > top_k) return; - for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; + for (int64_t i = static_cast(blockIdx.x) * nthds_per_cta + + static_cast(threadIdx.x); i < num_images * keep_top_k; i += gridDim.x * nthds_per_cta) { const int imgId = i / keep_top_k; diff --git a/paddle/phi/kernels/gpu/multinomial_kernel.cu b/paddle/phi/kernels/gpu/multinomial_kernel.cu index 34c4a1391e3dfe..87c1bed8102a3a 100644 --- a/paddle/phi/kernels/gpu/multinomial_kernel.cu +++ b/paddle/phi/kernels/gpu/multinomial_kernel.cu @@ -45,8 +45,11 @@ __global__ void NormalizeProbability(MT* norm_probs, MT* sum_rows, int64_t num_distributions, int64_t num_categories) { - int id = threadIdx.x + blockIdx.x * blockDim.x + - blockIdx.y * gridDim.x * blockDim.x; + int64_t id = + static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(blockIdx.y) * static_cast(gridDim.x) * + static_cast(blockDim.x); if (id < num_distributions * num_categories) { PADDLE_ENFORCE( static_cast(in_data[id]) >= 0.0, @@ -112,7 +115,9 @@ __global__ void sampleMultinomialWithReplacement( hiprand_init(seed, idx, offset, &state); #endif - int sample = blockIdx.x * blockDim.x + threadIdx.x; + int64_t sample = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); for (int64_t dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) { if (sample < num_samples) { #if defined(__NVCC__) diff --git a/paddle/phi/kernels/gpu/nadam_kernel.cu b/paddle/phi/kernels/gpu/nadam_kernel.cu index 55f6dadab3c971..99441fb827335c 100644 --- a/paddle/phi/kernels/gpu/nadam_kernel.cu +++ b/paddle/phi/kernels/gpu/nadam_kernel.cu @@ -46,7 +46,9 @@ __global__ void NAdamGPUKernel(const T* param, MT* master_param_out) { MT lr_scalar = static_cast(learning_rate[0]); - int idx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t idx = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); for (int64_t index = idx; index < num; index += gridDim.x * blockDim.x) { // load and cast input to MT diff --git 
a/paddle/phi/kernels/gpu/partial_concat_grad_kernel.cu b/paddle/phi/kernels/gpu/partial_concat_grad_kernel.cu
index f385c99b79447c..af5a4a675198ff 100644
--- a/paddle/phi/kernels/gpu/partial_concat_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/partial_concat_grad_kernel.cu
@@ -33,7 +33,9 @@ __global__ void ConcatPartialGradCUDAKernel(T **in,
                                             int64_t start_index,
                                             int64_t out_batch_len,
                                             int64_t part_length) {
-  int id = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   while (id < all_length) {
     int64_t bs_id = id / out_batch_len;
     int64_t bs_index = id % out_batch_len;
diff --git a/paddle/phi/kernels/gpu/partial_concat_kernel.cu b/paddle/phi/kernels/gpu/partial_concat_kernel.cu
index 8059e109eb4d58..10c3b293e3cf9d 100644
--- a/paddle/phi/kernels/gpu/partial_concat_kernel.cu
+++ b/paddle/phi/kernels/gpu/partial_concat_kernel.cu
@@ -33,7 +33,9 @@ __global__ void ConcatPartialCUDAKernel(T **in,
                                         int64_t start_index,
                                         int64_t out_batch_len,
                                         int64_t part_length) {
-  int id = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   while (id < all_length) {
     int64_t bs_id = id / out_batch_len;
     int64_t bs_index = id % out_batch_len;
diff --git a/paddle/phi/kernels/gpu/partial_sum_kernel.cu b/paddle/phi/kernels/gpu/partial_sum_kernel.cu
index 32bee49d062fc2..de39b0228d6b2a 100644
--- a/paddle/phi/kernels/gpu/partial_sum_kernel.cu
+++ b/paddle/phi/kernels/gpu/partial_sum_kernel.cu
@@ -29,7 +29,9 @@ __global__ void SumArrayPartialCUDAKernel(T **in,
                                           int64_t start_index,
                                           int64_t length,
                                           int64_t row_length) {
-  int id = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   while (id < lod_length) {
     T total = static_cast<T>(0);
     int b_id = id / length;
@@ -54,7 +56,9 @@ __global__ void PartialSumGradCUDAKernel(T **res_grad,
                                          int64_t start_index,
                                          int64_t length,
                                          int64_t row_length) {
-  int id = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   while (id < lod_length) {
     T total = static_cast<T>(0);
     int b_id = id / length;
diff --git a/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu
index 13f0b12fa7e0d7..43800cb1199c74 100644
--- a/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/psroi_pool_grad_kernel.cu
@@ -47,7 +47,9 @@ __global__ void GPUPSROIPoolBackward(const int64_t nthreads,
                                      const int pooled_width,
                                      const int* rois_batch_id_data,
                                      T* dx_data) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t index =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int offset = blockDim.x * gridDim.x;
   for (int64_t i = index; i < nthreads; i += offset) {
     // The output is in order (n, c, ph, pw)
diff --git a/paddle/phi/kernels/gpu/psroi_pool_kernel.cu b/paddle/phi/kernels/gpu/psroi_pool_kernel.cu
index 1193c18131ce33..bbb1649ea9bde7 100644
--- a/paddle/phi/kernels/gpu/psroi_pool_kernel.cu
+++ b/paddle/phi/kernels/gpu/psroi_pool_kernel.cu
@@ -45,7 +45,9 @@ __global__ void GPUPSROIPoolForward(const int nthreads,
                                     const int pooled_width,
                                     const int* rois_batch_id_data,
                                     T* output_data) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t index =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int offset = blockDim.x * gridDim.x;
   for (size_t i = index; i < nthreads; i += offset) {
     // The output is in order (n, c, ph, pw)
diff --git a/paddle/phi/kernels/gpu/radam_kernel.cu b/paddle/phi/kernels/gpu/radam_kernel.cu
index bee2bb8492702f..388ec8c20621a2 100644
--- a/paddle/phi/kernels/gpu/radam_kernel.cu
+++ b/paddle/phi/kernels/gpu/radam_kernel.cu
@@ -46,7 +46,9 @@ __global__ void RAdamGPUKernel(const T* param,
                                MT* master_param_out) {
   MT lr_scalar = static_cast<MT>(learning_rate[0]);
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t idx =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   for (int64_t index = idx; index < num; index += gridDim.x * blockDim.x) {
     // load and cast input to MT
diff --git a/paddle/phi/kernels/gpu/row_conv_grad_kernel.cu b/paddle/phi/kernels/gpu/row_conv_grad_kernel.cu
index ac61f86fed3e19..839c87b79fea91 100644
--- a/paddle/phi/kernels/gpu/row_conv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/row_conv_grad_kernel.cu
@@ -76,7 +76,9 @@ __global__ void RowConvGradInput(const T *dout,
                                  int future_context,
                                  const size_t *batch_indices,
                                  T *din) {
-  int d = blockIdx.x * blockDim.x + threadIdx.x;  // index along input_dim
+  int64_t d =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);  // index along input_dim
   int bly = blockDim.y;
   int thy = threadIdx.y;
diff --git a/paddle/phi/kernels/gpu/row_conv_kernel.cu b/paddle/phi/kernels/gpu/row_conv_kernel.cu
index ab7c8254ec7bc2..5875ccc6659ec4 100644
--- a/paddle/phi/kernels/gpu/row_conv_kernel.cu
+++ b/paddle/phi/kernels/gpu/row_conv_kernel.cu
@@ -76,7 +76,9 @@ __global__ void RowConvForward(const T *in,
                                int future_context,
                                const size_t *batch_indices,
                                T *out) {
-  int d = blockIdx.x * blockDim.x + threadIdx.x;  // index along input_dim
+  int64_t d =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);  // index along input_dim
   int bly = blockDim.y;
   int thy = threadIdx.y;
diff --git a/paddle/phi/kernels/gpu/shuffle_channel.h b/paddle/phi/kernels/gpu/shuffle_channel.h
index 59e067374e113d..7766a731032f8f 100644
--- a/paddle/phi/kernels/gpu/shuffle_channel.h
+++ b/paddle/phi/kernels/gpu/shuffle_channel.h
@@ -34,7 +34,9 @@ __global__ void ShuffleChannel(const int nthreads,
                                int group_row,
                                int group_column,
                                int len) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t index =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int offset = blockDim.x * gridDim.x;
   for (size_t ii = index; ii < nthreads; ii += offset) {
     const int n = index / group_row / group_column / len;
diff --git a/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu b/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu
index fde94d4b70a188..02ce6c80ed8a64 100644
--- a/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu
+++ b/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu
@@ -266,7 +266,9 @@ __global__ void GetSlogDetV2FromLU(const T* lu_data,
                                    int64_t batch_size,
                                    T* sign_data,
                                    T* logdet_data) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t idx =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   if (idx < batch_size) {
     int offset_lu = idx * n * n;
     int offset_ipiv = idx * n;
diff --git a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
index d7df2581f9656e..2f17c9ed14cfdf 100644
--- a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
+++ b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
@@ -1035,7 +1035,9 @@ void DispatchTopPSampling(const Context& dev_ctx,
 __global__ void setup_kernel(GPU(randState_t) * state,
                              int64_t* seed,
                              const int bs) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t idx =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
     GPU(rand_init)(static_cast<uint64_t>(seed[i]), 0, 0, &state[i]);
   }
@@ -1046,7 +1048,9 @@ __global__ void setup_kernel(GPU(randState_t) * state,
                              const uint64_t offset,
                              const int bs,
                              const bool need_batch_random) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t idx =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
     if (need_batch_random) {
       GPU(rand_init)(seed, i, offset, &state[i]);
diff --git a/paddle/phi/kernels/gpu/tril_indices_kernel.cu b/paddle/phi/kernels/gpu/tril_indices_kernel.cu
index be83f28451166b..047b14c8d3973e 100644
--- a/paddle/phi/kernels/gpu/tril_indices_kernel.cu
+++ b/paddle/phi/kernels/gpu/tril_indices_kernel.cu
@@ -66,7 +66,9 @@ __global__ void tril_indices_kernel(T* out_data,
                                     int col,
                                     int trapezoid_size,
                                     int tril_size) {
-  int linear_index = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t linear_index =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   if (linear_index < tril_size) {
     T r, c;
diff --git a/paddle/phi/kernels/gpu/triu_indices_kernel.cu b/paddle/phi/kernels/gpu/triu_indices_kernel.cu
index cece4bee8a42c1..f9b7ebde2584e4 100644
--- a/paddle/phi/kernels/gpu/triu_indices_kernel.cu
+++ b/paddle/phi/kernels/gpu/triu_indices_kernel.cu
@@ -67,7 +67,9 @@ __global__ void triu_indices_kernel(T* out_data,
                                     int col,
                                     int rectangle_size,
                                     int triu_size) {
-  int linear_index = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t linear_index =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   if (linear_index < triu_size) {
     T r, c;
diff --git a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu
index af6169ba9cb7b1..d27539acc42e60 100644
--- a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu
+++ b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu
@@ -128,7 +128,9 @@ __global__ void ArgmaxCUDAKernel(const int64_t height,  // n * h
 }
 __global__ void ARangeKernel(int64_t* data, int num, int64_t scale) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t idx =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   for (int start = idx; idx < num; idx += gridDim.x) {
     data[idx] = idx * scale;
   }
diff --git a/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu
index ce22758e407862..29da196397a15d 100644
--- a/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu
+++ b/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu
@@ -63,7 +63,9 @@ __global__ void GetSampleCountAndNeighborCountKernel(const T* col_ptr,
                                                      int* neighbor_count,
                                                      int sample_size,
                                                      int n) {
-  int i = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t i =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   if (i >= n) return;
   T nid = input_nodes[i];
   int neighbor_size = static_cast<int>(col_ptr[nid + 1] - col_ptr[nid]);
diff --git a/paddle/phi/kernels/gpu/yolo_box_head_kernel.cu b/paddle/phi/kernels/gpu/yolo_box_head_kernel.cu
index a4821e6534463d..b0df9c5cee41ca 100644
--- a/paddle/phi/kernels/gpu/yolo_box_head_kernel.cu
+++ b/paddle/phi/kernels/gpu/yolo_box_head_kernel.cu
@@ -33,9 +33,15 @@ __global__ void YoloBoxHeadCudaKernel(const T* input,
                                       const int grid_size_y,
                                       const int class_num,
                                       const int anchors_num) {
-  int x_id = blockIdx.x * blockDim.x + threadIdx.x;
-  int y_id = blockIdx.y * blockDim.y + threadIdx.y;
-  int z_id = blockIdx.z * blockDim.z + threadIdx.z;
+  int64_t x_id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
+  int64_t y_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
+  int64_t z_id =
+      static_cast<int64_t>(blockIdx.z) * static_cast<int64_t>(blockDim.z) +
+      static_cast<int64_t>(threadIdx.z);
   if ((x_id >= grid_size_x) || (y_id >= grid_size_y) || (z_id >= anchors_num)) {
     return;
   }
diff --git a/paddle/phi/kernels/gpu/yolo_box_post_kernel.cu b/paddle/phi/kernels/gpu/yolo_box_post_kernel.cu
index 1e2613c5cab773..fc3492144e4e44 100644
--- a/paddle/phi/kernels/gpu/yolo_box_post_kernel.cu
+++ b/paddle/phi/kernels/gpu/yolo_box_post_kernel.cu
@@ -139,9 +139,15 @@ __global__ void YoloBoxNum(const float* input,
                            const int class_num,
                            const int anchors_num,
                            float prob_thresh) {
-  int x_id = blockIdx.x * blockDim.x + threadIdx.x;
-  int y_id = blockIdx.y * blockDim.y + threadIdx.y;
-  int z_id = blockIdx.z * blockDim.z + threadIdx.z;
+  int64_t x_id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
+  int64_t y_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
+  int64_t z_id =
+      static_cast<int64_t>(blockIdx.z) * static_cast<int64_t>(blockDim.z) +
+      static_cast<int64_t>(threadIdx.z);
   if ((x_id >= grid_size) || (y_id >= grid_size) || (z_id >= anchors_num)) {
     return;
   }
@@ -168,9 +174,15 @@ __global__ void YoloTensorParseKernel(const float* input,
                                       const int neth,
                                       int* biases,
                                       float prob_thresh) {
-  int x_id = blockIdx.x * blockDim.x + threadIdx.x;
-  int y_id = blockIdx.y * blockDim.y + threadIdx.y;
-  int z_id = blockIdx.z * blockDim.z + threadIdx.z;
+  int64_t x_id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
+  int64_t y_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
+  int64_t z_id =
+      static_cast<int64_t>(blockIdx.z) * static_cast<int64_t>(blockDim.z) +
+      static_cast<int64_t>(threadIdx.z);
   if ((x_id >= grid_size) || (y_id >= grid_size) || (z_id >= anchors_num)) {
     return;
   }
diff --git a/paddle/phi/kernels/gpudnn/mha_cudnn_frontend.cu b/paddle/phi/kernels/gpudnn/mha_cudnn_frontend.cu
index 22a12e7f577de3..5d34b941907cbb 100644
--- a/paddle/phi/kernels/gpudnn/mha_cudnn_frontend.cu
+++ b/paddle/phi/kernels/gpudnn/mha_cudnn_frontend.cu
@@ -311,7 +311,9 @@ __global__ void fill_cu_seqlen_with_constant(scalar_t *cu_seqlens_q,
                                              scalar_t q_seqlen,
                                              scalar_t kv_seqlen,
                                              size_t n) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   if (tid < n) {
     cu_seqlens_q[tid] = q_seqlen;
     cu_seqlens_kv[tid] = kv_seqlen;
diff --git a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
index 1a23e6d845781d..d784fa92ed1c29 100644
--- a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
+++ b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h
@@ -280,7 +280,10 @@ __global__ void QuantActKernel(const T* x,
   InVec in_vec;
   OutVec out_vec;
-  for (int linear_index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  for (int64_t linear_index = (static_cast<int64_t>(blockIdx.x) *
+                                   static_cast<int64_t>(blockDim.x) +
+                               static_cast<int64_t>(threadIdx.x)) *
+                              VecSize;
        linear_index < elem_cnt;
        linear_index += gridDim.x * blockDim.x * VecSize) {
     int row_idx = linear_index / cols;
@@ -339,7 +342,9 @@ __global__ void SplitKernel(const T* x,
   __syncthreads();
-  for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int64_t linear_idx =
+           static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+           static_cast<int64_t>(threadIdx.x);
        linear_idx < elem_cnt;
        linear_idx += blockDim.x * gridDim.x) {
     int32_t row_idx = linear_idx / kfp_num;  // n
@@ -395,7 +400,10 @@ __global__ void DequantActivationMergeKernel(const T* x,
   FpVec out_vec;
   FpVec x_vec;
-  for (int linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  for (int64_t linear_idx = (static_cast<int64_t>(blockIdx.x) *
+                                 static_cast<int64_t>(blockDim.x) +
+                             static_cast<int64_t>(threadIdx.x)) *
+                            VecSize;
        linear_idx < elem_cnt;
        linear_idx += gridDim.x * blockDim.x * VecSize) {
     phi::Load(x_fp + linear_idx, &x_fp_vec);
diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
index d2eef5f870a47f..98bc810acb2a12 100644
--- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
@@ -62,7 +62,9 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1,
                                         bool lazy_mode,
                                         int ndim,
                                         bool amsgrad) {
-  int id = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   MT lr = *lr_;
   for (; id < ndim; id += blockDim.x * gridDim.x) {
diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
index 942ba5d3da7374..0a413ef5e81923 100644
--- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
@@ -68,7 +68,9 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1,
                                          bool lazy_mode,
                                          int ndim,
                                          bool amsgrad) {
-  int id = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t id =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   MT lr = *lr_ * lr_ratio;
   for (; id < ndim; id += blockDim.x * gridDim.x) {
diff --git a/paddle/phi/kernels/selected_rows/gpu/lookup_table_grad_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/lookup_table_grad_kernel.cu
index 3b7f59315e472e..fd3b82acd0a30e 100644
--- a/paddle/phi/kernels/selected_rows/gpu/lookup_table_grad_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/gpu/lookup_table_grad_kernel.cu
@@ -30,7 +30,8 @@ __global__ void LookupTableGrad(T *table,
                                 const int64_t K,
                                 const int64_t D) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  int64_t idy = static_cast<int64_t>(blockIdx.x) +
+                static_cast<int64_t>(threadIdx.y) * GridDimX;
   while (idy < K) {
     int64_t id = ids[idy];
diff --git a/paddle/phi/kernels/selected_rows/gpu/lookup_table_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/lookup_table_kernel.cu
index a254cf4103f9bc..45e4f02aeb42e4 100644
--- a/paddle/phi/kernels/selected_rows/gpu/lookup_table_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/gpu/lookup_table_kernel.cu
@@ -34,7 +34,8 @@ __global__ void LookupTable(T *output,
                             const int64_t D,
                             const int64_t padding_idx) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  int64_t idy = static_cast<int64_t>(blockIdx.x) +
+                static_cast<int64_t>(threadIdx.y) * GridDimX;
   while (idy < K) {
     int64_t id = ids[idy];
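Note (reviewer sketch, not part of the patch): every hunk above applies the same pattern, widening the linear thread index to int64_t before the multiply so that launches covering more than INT_MAX elements cannot produce a negative or wrapped index. The toy program below illustrates the failure mode the casts guard against; the kernel name, element count, and launch configuration are hypothetical and chosen only for illustration.

// overflow_sketch.cu -- hypothetical standalone demo, not PaddlePaddle code.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void FillKernel(float* data, int64_t n, float val) {
  // Same pattern as the patch: promote each component to int64_t first.
  // With a plain `int`, blockIdx.x * blockDim.x + threadIdx.x exceeds
  // INT_MAX for the tail blocks of this launch, the stored index turns
  // negative, and the write goes out of bounds.
  int64_t i =
      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
      static_cast<int64_t>(threadIdx.x);
  if (i < n) {
    data[i] = val;
  }
}

int main() {
  const int64_t n = 3LL << 30;  // ~3.2e9 elements, > INT_MAX (assumed size)
  float* d_data = nullptr;
  if (cudaMalloc(&d_data, n * sizeof(float)) != cudaSuccess) {
    std::printf("not enough device memory for the demo\n");  // needs ~12 GB
    return 0;
  }
  const int block = 256;
  const int grid = static_cast<int>((n + block - 1) / block);  // ~12.6M blocks
  FillKernel<<<grid, block>>>(d_data, n, 1.0f);
  cudaDeviceSynchronize();
  cudaFree(d_data);
  return 0;
}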