diff --git a/infini_train/include/common/cuda/cub_compat.cuh b/infini_train/include/common/cuda/cub_compat.cuh
new file mode 100644
index 00000000..4bb43edb
--- /dev/null
+++ b/infini_train/include/common/cuda/cub_compat.cuh
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <cub/cub.cuh>
+
+namespace infini_train::kernels::cuda {
+
+#if defined(CUB_VERSION) && CUB_VERSION >= 200800
+using CubSumOp = ::cuda::std::plus<>;
+using CubMaxOp = ::cuda::maximum<>;
+using CubMinOp = ::cuda::minimum<>;
+#else
+using CubSumOp = cub::Sum;
+using CubMaxOp = cub::Max;
+using CubMinOp = cub::Min;
+#endif
+
+} // namespace infini_train::kernels::cuda
diff --git a/infini_train/src/kernels/cuda/cross_entropy.cu b/infini_train/src/kernels/cuda/cross_entropy.cu
index 223a556b..56fe7270 100644
--- a/infini_train/src/kernels/cuda/cross_entropy.cu
+++ b/infini_train/src/kernels/cuda/cross_entropy.cu
@@ -6,6 +6,7 @@
 #include <cub/block/block_reduce.cuh>
 
 #include "infini_train/include/common/cuda/common_cuda.h"
+#include "infini_train/include/common/cuda/cub_compat.cuh"
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
@@ -44,7 +45,7 @@ __global__ void CrossEntropyForwardKernel(const InputType *__restrict__ input_pt
     for (int i = tid; i < num_classes; i += BLOCK_SIZE) {
         thread_max = fmaxf(thread_max, common::cuda::Cast<float>(input_ptr[base + i]));
     }
-    const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, ::cuda::maximum<>());
+    const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, CubMaxOp());
     if (tid == 0) {
         shared.max_logit = block_max;
     }
@@ -139,7 +140,7 @@ __global__ void CrossEntropyBackwardKernel(const InputType *__restrict__ input_p
     for (int i = tid; i < num_classes; i += BLOCK_SIZE) {
         thread_max = fmaxf(thread_max, common::cuda::Cast<float>(input_ptr[idx_base + i]));
     }
-    const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, ::cuda::maximum<>());
+    const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, CubMaxOp());
     if (tid == 0) {
         shared.max_logit = block_max;
     }
diff --git a/infini_train/src/kernels/cuda/reduction.cu b/infini_train/src/kernels/cuda/reduction.cu
index 7fd8c2c0..9c7ff9d7 100644
--- a/infini_train/src/kernels/cuda/reduction.cu
+++ b/infini_train/src/kernels/cuda/reduction.cu
@@ -1,6 +1,7 @@
 #include <cub/block/block_reduce.cuh>
 
 #include "infini_train/include/common/cuda/common_cuda.h"
+#include "infini_train/include/common/cuda/cub_compat.cuh"
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
@@ -14,22 +15,22 @@ namespace {
 
 // Reduction operators
 template <typename T, typename Op> struct CubOp;
-template <typename T> struct CubOp<T, ::cuda::std::plus<>> {
+template <typename T> struct CubOp<T, CubSumOp> {
     __device__ static T Init() { return common::cuda::Cast<T>(0); }
     __device__ static T Reduce(T a, T b) { return common::cuda::Add(a, b); }
-    __device__ static ::cuda::std::plus<> Op() { return ::cuda::std::plus<>(); }
+    __device__ static CubSumOp Op() { return CubSumOp(); }
 };
 
-template <typename T> struct CubOp<T, ::cuda::maximum<>> {
+template <typename T> struct CubOp<T, CubMaxOp> {
     __device__ static T Init() { return common::cuda::Cast<T>(-kInfinity); }
     __device__ static T Reduce(T a, T b) { return common::cuda::Max(a, b); }
-    __device__ static ::cuda::maximum<> Op() { return ::cuda::maximum<>(); }
+    __device__ static CubMaxOp Op() { return CubMaxOp(); }
 };
 
-template <typename T> struct CubOp<T, ::cuda::minimum<>> {
+template <typename T> struct CubOp<T, CubMinOp> {
     __device__ static T Init() { return common::cuda::Cast<T>(kInfinity); }
     __device__ static T Reduce(T a, T b) { return common::cuda::Min(a, b); }
-    __device__ static ::cuda::minimum<> Op() { return ::cuda::minimum<>(); }
+    __device__ static CubMinOp Op() { return CubMinOp(); }
 };
 
 // Finalization strategies
@@ -179,19 +180,19 @@ std::shared_ptr<Tensor> ReduceOpBackward(const std::shared_ptr<Tensor> &grad_out
 }
 
 std::shared_ptr<Tensor> MeanForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
-    return ReduceOpForward<::cuda::std::plus<>, MeanFinalize>(input, dim, keep_dim);
+    return ReduceOpForward<CubSumOp, MeanFinalize>(input, dim, keep_dim);
 }
 
 std::shared_ptr<Tensor> SumForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
-    return ReduceOpForward<::cuda::std::plus<>, IdentityFinalize>(input, dim, keep_dim);
+    return ReduceOpForward<CubSumOp, IdentityFinalize>(input, dim, keep_dim);
 }
 
 std::shared_ptr<Tensor> MaxForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
-    return ReduceOpForward<::cuda::maximum<>, IdentityFinalize>(input, dim, keep_dim);
+    return ReduceOpForward<CubMaxOp, IdentityFinalize>(input, dim, keep_dim);
 }
 
 std::shared_ptr<Tensor> MinForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
-    return ReduceOpForward<::cuda::minimum<>, IdentityFinalize>(input, dim, keep_dim);
+    return ReduceOpForward<CubMinOp, IdentityFinalize>(input, dim, keep_dim);
 }
 
 std::shared_ptr<Tensor> MeanBackward(const std::shared_ptr<Tensor> &grad_output, const std::vector<int64_t> &input_dims,
diff --git a/infini_train/src/kernels/cuda/softmax.cu b/infini_train/src/kernels/cuda/softmax.cu
index 622cb8a4..98d47fae 100644
--- a/infini_train/src/kernels/cuda/softmax.cu
+++ b/infini_train/src/kernels/cuda/softmax.cu
@@ -6,6 +6,7 @@
 #include "glog/logging.h"
 
 #include "infini_train/include/common/cuda/common_cuda.h"
+#include "infini_train/include/common/cuda/cub_compat.cuh"
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
@@ -31,7 +32,7 @@ __global__ void SoftmaxForwardKernel(T *output, const T *input, int64_t outer_si
         int64_t idx = (group * axis_size + axis) * inner_size + inner_idx;
         thread_max = max(thread_max, common::cuda::Cast<float>(input[idx]));
     }
-    float block_max = BlockReduce(temp_storage_max).Reduce(thread_max, ::cuda::maximum<>());
+    float block_max = BlockReduce(temp_storage_max).Reduce(thread_max, CubMaxOp());
     if (tid == 0) {
         row_max = block_max;