
Commit c462d7d

fix(kernel): pull in the latest cccl so that nvrtc can be used together with cub
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 5f518a3 commit c462d7d

File tree

5 files changed (+87, -64 lines)


.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -19,3 +19,6 @@
 [submodule "src/09python_ffi/pybind11"]
 	path = src/09python_ffi/pybind11
 	url = git@github.com:pybind/pybind11.git
+[submodule "3rd-party/cccl"]
+	path = 3rd-party/cccl
+	url = git@github.com:NVIDIA/cccl.git

3rd-party/cccl

Submodule cccl added at b7d4228

src/04kernel/src/generator/nvrtc_repo.cc

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,7 @@

 #include "nvrtc_repo.h"
 #include "hardware/device_manager.h"
+#include <filesystem>
 #include <nvrtc.h>

 #define NVRTC_ASSERT(CALL) \
@@ -38,9 +39,24 @@ namespace refactor::kernel::nvrtc {
         NVRTC_ASSERT(nvrtcCreateProgram(&prog, code.data(), name.data(), 0, nullptr, nullptr));

         std::vector<std::string> opts{"--std=c++17", "--gpu-architecture=compute_80"};
+        {
+            auto proj = std::filesystem::path(__FILE__)
+                            .parent_path()
+                            .parent_path()
+                            .parent_path()
+                            .parent_path()
+                            .parent_path();
+            auto cccl = proj / "3rd-party/cccl";
+            auto cudacxx = cccl / "libcudacxx/include";
+            auto cub = cccl / "cub";
+            ASSERT(std::filesystem::is_directory(cub), "cub not exist");
+            opts.emplace_back(fmt::format("-I{}", cudacxx.c_str()));
+            opts.emplace_back(fmt::format("-I{}", cub.c_str()));
+        }
 #ifdef CUDA_INCLUDE_PATH
         opts.emplace_back(fmt::format("-I{}", CUDA_INCLUDE_PATH));
 #endif
+
         std::vector<const char *> optsPtr(opts.size());
         std::transform(opts.begin(), opts.end(), optsPtr.begin(),
                        [](auto &s) { return s.c_str(); });
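The block added here walks up from `__FILE__` to the project root, locates the cccl checkout, and passes its `libcudacxx/include` and `cub` directories to NVRTC as `-I` options, so runtime-compiled kernels can `#include <cub/...>`. For context, below is a minimal standalone sketch, not the repository's `nvrtc_repo.cc`, of how such options reach an NVRTC compile call; the include paths and the trivial kernel source are placeholders.

```cpp
// Minimal sketch: feeding -I options (such as the cccl paths assembled above)
// into NVRTC. Paths and the kernel source are placeholders for illustration.
#include <nvrtc.h>
#include <cstdio>
#include <string>
#include <vector>

int main() {
    const char *src = R"(extern "C" __global__ void noop() {})";

    nvrtcProgram prog;
    nvrtcCreateProgram(&prog, src, "noop.cu", 0, nullptr, nullptr);

    // The commit builds these strings with std::filesystem + fmt::format;
    // here they are hard-coded placeholders.
    std::vector<std::string> opts{
        "--std=c++17",
        "--gpu-architecture=compute_80",
        "-I/path/to/3rd-party/cccl/libcudacxx/include",
        "-I/path/to/3rd-party/cccl/cub",
    };
    std::vector<const char *> optsPtr;
    for (auto &s : opts) optsPtr.push_back(s.c_str());

    auto status = nvrtcCompileProgram(prog, (int) optsPtr.size(), optsPtr.data());

    // Always fetch the log: a missing include path typically shows up here
    // as a "cannot open source file" error rather than failing earlier.
    size_t logSize;
    nvrtcGetProgramLogSize(prog, &logSize);
    std::string log(logSize, '\0');
    nvrtcGetProgramLog(prog, log.data());
    if (status != NVRTC_SUCCESS) std::fprintf(stderr, "%s\n", log.c_str());

    // On success, the PTX can be retrieved and loaded with the CUDA driver API.
    size_t ptxSize;
    nvrtcGetPTXSize(prog, &ptxSize);
    std::string ptx(ptxSize, '\0');
    nvrtcGetPTX(prog, ptx.data());

    nvrtcDestroyProgram(&prog);
    return status == NVRTC_SUCCESS ? 0 : 1;
}
```

Since NVRTC does not validate include directories up front, errors only appear when an `#include` cannot be resolved, which is presumably why the commit asserts that the `cub` directory exists before adding it.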

src/04kernel/src/kernels/rms_normalization/cuda_kernel.cc

Lines changed: 15 additions & 14 deletions
@@ -48,31 +48,27 @@ namespace refactor::kernel {

         // 0: data type
         // 1: block size
-        // 2: epsilon cast
+        // 2: T -> float
+        // 3: T <- float
         constexpr static const char *TEMPLATE = R"~(
-#include <cub/cub.cuh>
-
-static __device__ __forceinline__ {0:} squareSum({0:} a, {0:} b) {{
-    return a * a + b * b;
-}}
+#include <cub/block/block_reduce.cuh>

 extern "C" __global__ void kernel(
-    {0:} *__restrict__ const y,
-    {0:} const *__restrict__ const x,
-    {0:} const *__restrict__ const w,
-    float epsilon_) {{
+    {0:} *__restrict__ y,
+    {0:} const *__restrict__ x,
+    {0:} const *__restrict__ w,
+    float epsilon) {{

-    auto epsilon = {2:}(epsilon_);
     x += blockIdx.x * blockDim.x + threadIdx.x;
-    y += blockIdx.x * blockDim.x + threadIdx.x;;
+    y += blockIdx.x * blockDim.x + threadIdx.x;
     w += threadIdx.x;

     using BlockReduce = cub::BlockReduce<{0:}, {1:}>;
     __shared__ typename BlockReduce::TempStorage tempStorage;
     __shared__ {0:} rms;
-    auto acc = BlockReduce(tempStorage).Reduce(*x, squareSum);
+    auto acc = BlockReduce(tempStorage).Reduce(*x * *x, cub::Sum());
     if (threadIdx.x == 0) {{
-        rms = rsqrt(acc / blockDim.x + epsilon);
+        rms = {3:}(rsqrt({2:}(acc) / blockDim.x + epsilon));
     }}
     __syncthreads();
@@ -96,6 +92,11 @@ extern "C" __global__ void kernel(
                     : dataType == DataType::F64 ? "static_cast<float>"
                     : dataType == DataType::FP16 ? "__half2float"
                     : dataType == DataType::BF16 ? "__bfloat162float"
+                    : UNREACHABLEX(const char*, "unreachable"),
+                dataType == DataType::F32 ? ""
+                    : dataType == DataType::F64 ? ""
+                    : dataType == DataType::FP16 ? "__float2half"
+                    : dataType == DataType::BF16 ? "__float2bfloat16"
                     : UNREACHABLEX(const char*, "unreachable")
                 // clang-format on
             );
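This change drops the hand-written `squareSum` reduction operator in favour of `cub::BlockReduce(...).Reduce(*x * *x, cub::Sum())`, includes only `<cub/block/block_reduce.cuh>`, and converts the accumulator to `float` (placeholder `{2:}`) for the `rsqrt` before casting back (placeholder `{3:}`) into the shared `rms` value. As a rough illustration, the sketch below is approximately what the formatted template could instantiate to for FP16, assuming a hypothetical block size of 1024; the final write after `__syncthreads()` is not part of this hunk, so the scaling line here is only the usual RMSNorm update, added as an assumption for completeness.

```cuda
// Illustration only: roughly the FP16 instantiation of the template above,
// with an assumed block size of 1024. Not generated output.
#include <cub/block/block_reduce.cuh>
#include <cuda_fp16.h>

extern "C" __global__ void kernel(
    half *__restrict__ y,
    half const *__restrict__ x,
    half const *__restrict__ w,
    float epsilon) {

    x += blockIdx.x * blockDim.x + threadIdx.x;
    y += blockIdx.x * blockDim.x + threadIdx.x;
    w += threadIdx.x;

    using BlockReduce = cub::BlockReduce<half, 1024>;
    __shared__ typename BlockReduce::TempStorage tempStorage;
    __shared__ half rms;
    // Block-wide sum of squares, then 1/sqrt(mean + eps) computed on thread 0
    // in float and cast back to the storage type.
    auto acc = BlockReduce(tempStorage).Reduce(*x * *x, cub::Sum());
    if (threadIdx.x == 0) {
        rms = __float2half(rsqrtf(__half2float(acc) / blockDim.x + epsilon));
    }
    __syncthreads();

    // Assumed final step (not shown in the hunk): normalize and apply weight.
    *y = *x * rms * *w;
}
```

Routing the `rsqrt` through `float` matters mainly for FP16/BF16, which use the CUDA conversion intrinsics selected by the new ternary chains; F64 converts with `static_cast<float>`, and the new `{3:}` chain casts back (empty for F32/F64).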
Lines changed: 52 additions & 50 deletions
@@ -1,54 +1,56 @@
-// #ifdef USE_CUDA
+#ifdef USE_CUDA

-// #include "../../../src/kernels/rms_normalization/cpu_kernel.hh"
-// #include "../../../src/kernels/rms_normalization/cuda_kernel.hh"
-// #include "hardware/device_manager.h"
-// #include <gtest/gtest.h>
-// #include <numeric>
+#include "../../../src/kernels/rms_normalization/cpu_kernel.hh"
+#include "../../../src/kernels/rms_normalization/cuda_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+#include <numeric>

-// using namespace refactor;
-// using namespace kernel;
-// using namespace hardware;
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;

-// TEST(kernel, RmsNormalizationCuda) {
-//     // build routine
-//     auto y = Tensor::share(DataType::F32, Shape{2, 3, 4});
-//     auto x = Tensor::share(DataType::F32, Shape{2, 3, 4});
-//     auto w = Tensor::share(DataType::F32, Shape{4});
-//     auto kernel = RmsNormalizationCuda::build(0, *x),
-//          kCpu = RmsNormalizationCpu::build(0, *x);
-//     ASSERT_TRUE(kernel && kCpu);
-//     auto res = runtime::Resources();
-//     auto routine = kernel->lower(res).routine,
-//          rCpu = kCpu->lower(res).routine;
-//     // malloc
-//     auto &dev = *device::init(Device::Type::Nvidia, 0, "");
-//     auto yGpu = dev.malloc(y->bytesSize()),
-//          xGpu = dev.malloc(x->bytesSize()),
-//          wGpu = dev.malloc(w->bytesSize());
-//     // put input data
-//     std::vector<float> y_(y->elementsSize());
-//     std::vector<float> x_(x->elementsSize());
-//     std::vector<float> w_(w->elementsSize());
-//     std::iota(x_.begin(), x_.end(), 0);
-//     std::iota(w_.begin(), w_.end(), 1);
-//     xGpu->copyFromHost(x_.data(), x->bytesSize());
-//     wGpu->copyFromHost(w_.data(), w->bytesSize());
-//     // inference
-//     {
-//         void const *inputs[]{*xGpu, *wGpu};
-//         void *outputs[]{*yGpu};
-//         routine(res, nullptr, inputs, outputs);
-//     }
-//     {
-//         void const *inputs[]{x_.data(), w_.data()};
-//         void *outputs[]{y_.data()};
-//         rCpu(res, nullptr, inputs, outputs);
-//     }
-//     // check
-//     std::vector<float> result(y->elementsSize());
-//     yGpu->copyToHost(result.data(), y->bytesSize());
-//     EXPECT_EQ(result, y_);
-// }
+TEST(kernel, RmsNormalizationCuda) {
+    // build routine
+    auto y = Tensor::share(DataType::F32, Shape{2, 3, 4});
+    auto x = Tensor::share(DataType::F32, Shape{2, 3, 4});
+    auto w = Tensor::share(DataType::F32, Shape{4});
+    auto kernel = RmsNormalizationCuda::build(0, *x),
+         kCpu = RmsNormalizationCpu::build(0, *x);
+    ASSERT_TRUE(kernel && kCpu);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine,
+         rCpu = kCpu->lower(res).routine;
+    // malloc
+    auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+    auto yGpu = dev.malloc(y->bytesSize()),
+         xGpu = dev.malloc(x->bytesSize()),
+         wGpu = dev.malloc(w->bytesSize());
+    // put input data
+    std::vector<float> y_(y->elementsSize());
+    std::vector<float> x_(x->elementsSize());
+    std::vector<float> w_(w->elementsSize());
+    std::iota(x_.begin(), x_.end(), 0);
+    std::iota(w_.begin(), w_.end(), 1);
+    xGpu->copyFromHost(x_.data(), x->bytesSize());
+    wGpu->copyFromHost(w_.data(), w->bytesSize());
+    // inference
+    {
+        void const *inputs[]{*xGpu, *wGpu};
+        void *outputs[]{*yGpu};
+        routine(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{x_.data(), w_.data()};
+        void *outputs[]{y_.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    // check
+    std::vector<float> result(y->elementsSize());
+    yGpu->copyToHost(result.data(), y->bytesSize());
+    for (auto i : range0_(y_.size())) {
+        EXPECT_FLOAT_EQ(result[i], y_[i]);
+    }
+}

-// #endif
+#endif
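This test is now enabled: it runs the NVRTC-compiled CUDA kernel and the CPU kernel on the same inputs and compares the outputs element-wise with `EXPECT_FLOAT_EQ` instead of requiring bit-exact vector equality. For orientation, RMS normalization over the last axis computes `y[i] = x[i] * w[i] / sqrt(mean(x^2) + epsilon)`; the sketch below is a small host-side reference written for this note, an assumption about what `RmsNormalizationCpu` computes rather than its actual code.

```cpp
// Hypothetical host-side reference for RMS normalization over the last axis.
// Written for this note; not the repository's CPU kernel.
#include <cmath>
#include <cstddef>
#include <vector>

// Normalize each contiguous row of length `dim`:
//   y = x * w / sqrt(mean(x^2) + epsilon)
void rmsNormRef(std::vector<float> &y,
                std::vector<float> const &x,
                std::vector<float> const &w,
                std::size_t dim,
                float epsilon) {
    for (std::size_t row = 0; row < x.size() / dim; ++row) {
        float ss = 0;
        for (std::size_t i = 0; i < dim; ++i) {
            float v = x[row * dim + i];
            ss += v * v;  // sum of squares within the row
        }
        float rms = 1.f / std::sqrt(ss / dim + epsilon);
        for (std::size_t i = 0; i < dim; ++i) {
            y[row * dim + i] = x[row * dim + i] * rms * w[i];
        }
    }
}
```

For the shapes in the test, `dim` would be 4 (the trailing axis of `{2, 3, 4}`), and the first argument of `build(0, *x)` appears to be an epsilon of 0.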
