Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions infini_train/include/common/cuda/cub_compat.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include <cub/version.cuh>

namespace infini_train::kernels::cuda {

#if defined(CUB_VERSION) && CUB_VERSION >= 200800
using CubSumOp = ::cuda::std::plus<>;
using CubMaxOp = ::cuda::maximum<>;
using CubMinOp = ::cuda::minimum<>;
#else
using CubSumOp = cub::Sum;
using CubMaxOp = cub::Max;
using CubMinOp = cub::Min;
#endif

} // namespace infini_train::kernels::cuda
5 changes: 3 additions & 2 deletions infini_train/src/kernels/cuda/cross_entropy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cuda_runtime.h>

#include "infini_train/include/common/cuda/common_cuda.h"
#include "infini_train/include/common/cuda/cub_compat.cuh"
#include "infini_train/include/common/cuda/kernel_helper.cuh"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"
Expand Down Expand Up @@ -44,7 +45,7 @@ __global__ void CrossEntropyForwardKernel(const InputType *__restrict__ input_pt
for (int i = tid; i < num_classes; i += BLOCK_SIZE) {
thread_max = fmaxf(thread_max, common::cuda::Cast<float>(input_ptr[base + i]));
}
const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, ::cuda::maximum<>());
const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, CubMaxOp());
if (tid == 0) {
shared.max_logit = block_max;
}
Expand Down Expand Up @@ -139,7 +140,7 @@ __global__ void CrossEntropyBackwardKernel(const InputType *__restrict__ input_p
for (int i = tid; i < num_classes; i += BLOCK_SIZE) {
thread_max = fmaxf(thread_max, common::cuda::Cast<float>(input_ptr[idx_base + i]));
}
const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, ::cuda::maximum<>());
const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, CubMaxOp());
if (tid == 0) {
shared.max_logit = block_max;
}
Expand Down
21 changes: 11 additions & 10 deletions infini_train/src/kernels/cuda/reduction.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <cub/cub.cuh>

#include "infini_train/include/common/cuda/common_cuda.h"
#include "infini_train/include/common/cuda/cub_compat.cuh"
#include "infini_train/include/common/cuda/kernel_helper.cuh"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"
Expand All @@ -14,22 +15,22 @@ namespace {
// Reduction operators
template <typename T, typename ReduceFunc> struct CubOp;

template <typename T> struct CubOp<T, ::cuda::std::plus<>> {
template <typename T> struct CubOp<T, CubSumOp> {
__device__ static T Init() { return common::cuda::Cast<T>(0); }
__device__ static T Reduce(T a, T b) { return common::cuda::Add<T>(a, b); }
__device__ static ::cuda::std::plus<> Op() { return ::cuda::std::plus<>(); }
__device__ static CubSumOp Op() { return CubSumOp(); }
};

template <typename T> struct CubOp<T, ::cuda::maximum<>> {
template <typename T> struct CubOp<T, CubMaxOp> {
__device__ static T Init() { return common::cuda::Cast<T>(-kInfinity); }
__device__ static T Reduce(T a, T b) { return common::cuda::Max<T>(a, b); }
__device__ static ::cuda::maximum<> Op() { return ::cuda::maximum<>(); }
__device__ static CubMaxOp Op() { return CubMaxOp(); }
};

template <typename T> struct CubOp<T, ::cuda::minimum<>> {
template <typename T> struct CubOp<T, CubMinOp> {
__device__ static T Init() { return common::cuda::Cast<T>(kInfinity); }
__device__ static T Reduce(T a, T b) { return common::cuda::Min<T>(a, b); }
__device__ static ::cuda::minimum<> Op() { return ::cuda::minimum<>(); }
__device__ static CubMinOp Op() { return CubMinOp(); }
};

// Finalization strategies
Expand Down Expand Up @@ -179,19 +180,19 @@ std::shared_ptr<Tensor> ReduceOpBackward(const std::shared_ptr<Tensor> &grad_out
}

std::shared_ptr<Tensor> MeanForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
return ReduceOpForward<::cuda::std::plus<>, MeanFinalize>(input, dim, keep_dim);
return ReduceOpForward<CubSumOp, MeanFinalize>(input, dim, keep_dim);
}

std::shared_ptr<Tensor> SumForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
return ReduceOpForward<::cuda::std::plus<>, IdentityFinalize>(input, dim, keep_dim);
return ReduceOpForward<CubSumOp, IdentityFinalize>(input, dim, keep_dim);
}

std::shared_ptr<Tensor> MaxForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
return ReduceOpForward<::cuda::maximum<>, IdentityFinalize>(input, dim, keep_dim);
return ReduceOpForward<CubMaxOp, IdentityFinalize>(input, dim, keep_dim);
}

std::shared_ptr<Tensor> MinForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
return ReduceOpForward<::cuda::minimum<>, IdentityFinalize>(input, dim, keep_dim);
return ReduceOpForward<CubMinOp, IdentityFinalize>(input, dim, keep_dim);
}

std::shared_ptr<Tensor> MeanBackward(const std::shared_ptr<Tensor> &grad_output, const std::vector<int64_t> &input_dims,
Expand Down
3 changes: 2 additions & 1 deletion infini_train/src/kernels/cuda/softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "glog/logging.h"

#include "infini_train/include/common/cuda/common_cuda.h"
#include "infini_train/include/common/cuda/cub_compat.cuh"
#include "infini_train/include/common/cuda/kernel_helper.cuh"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"
Expand All @@ -31,7 +32,7 @@ __global__ void SoftmaxForwardKernel(T *output, const T *input, int64_t outer_si
int64_t idx = (group * axis_size + axis) * inner_size + inner_idx;
thread_max = max(thread_max, common::cuda::Cast<float>(input[idx]));
}
float block_max = BlockReduce(temp_storage_max).Reduce(thread_max, ::cuda::maximum<>());
float block_max = BlockReduce(temp_storage_max).Reduce(thread_max, CubMaxOp());

if (tid == 0) {
row_max = block_max;
Expand Down