Skip to content

Commit 59939c6

Browse files
committed
refactor(kernel): 简化 nvrtc kernel 调用
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent d27db1e commit 59939c6

File tree

8 files changed

+60
-51
lines changed

8 files changed

+60
-51
lines changed

src/04kernel/src/generator/nvrtc_repo.cc

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
nvrtcGetErrorString(status))); \
1111
}
1212

13+
#define CUDA_ASSERT(CALL) \
14+
if (auto result = CALL; result != CUDA_SUCCESS) { \
15+
const char *msg; \
16+
cuGetErrorName(result, &msg); \
17+
RUNTIME_ERROR(fmt::format("cuda driver failed on \"" #CALL "\" with {} ({})", \
18+
msg, (int) result)); \
19+
}
20+
1321
namespace refactor::kernel::nvrtc {
1422

1523
Handler::Handler(std::string_view name,
@@ -85,8 +93,22 @@ namespace refactor::kernel::nvrtc {
8593
return it->second;
8694
}
8795

88-
CUfunction Handler::kernel() const {
89-
return _kernel;
96+
void Handler::launch(unsigned int gridDimX,
97+
unsigned int gridDimY,
98+
unsigned int gridDimZ,
99+
unsigned int blockDimX,
100+
unsigned int blockDimY,
101+
unsigned int blockDimZ,
102+
unsigned int sharedMemBytes,
103+
void **kernelParams) const {
104+
CUDA_ASSERT(cuLaunchKernel(
105+
_kernel,
106+
gridDimX, gridDimY, gridDimZ,
107+
blockDimX, blockDimY, blockDimZ,
108+
sharedMemBytes,
109+
nullptr,
110+
kernelParams,
111+
nullptr));
90112
}
91113

92114
std::string_view memCopyType(size_t size) {

src/04kernel/src/generator/nvrtc_repo.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@
44
#include "common.h"
55
#include <cuda.h>
66

7-
#define CUDA_ASSERT(CALL) \
8-
if (auto result = CALL; result != CUDA_SUCCESS) { \
9-
const char *msg; \
10-
cuGetErrorName(result, &msg); \
11-
RUNTIME_ERROR(fmt::format("cuda driver failed on \"" #CALL "\" with {} ({})", \
12-
msg, (int) result)); \
13-
}
14-
157
namespace refactor::kernel::nvrtc {
168

179
class Handler {
@@ -29,7 +21,14 @@ namespace refactor::kernel::nvrtc {
2921
std::string_view name,
3022
std::string_view code,
3123
std::string_view symbol);
32-
CUfunction kernel() const;
24+
void launch(unsigned int gridDimX,
25+
unsigned int gridDimY,
26+
unsigned int gridDimZ,
27+
unsigned int blockDimX,
28+
unsigned int blockDimY,
29+
unsigned int blockDimZ,
30+
unsigned int sharedMemBytes,
31+
void **kernelParams) const;
3332
};
3433

3534
std::string_view memCopyType(size_t);

src/04kernel/src/kernels/cast/cuda_kernel.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,9 @@ extern "C" __global__ void kernel(
115115
auto x = inputs[0];
116116
auto n = params.n;
117117
void *args[]{&y, &x, &n};
118-
CUDA_ASSERT(cuLaunchKernel(
119-
h->kernel(),
120-
params.gridSize, 1, 1,
121-
params.blockSize, 1, 1,
122-
0, nullptr, args, nullptr));
118+
h->launch(params.gridSize, 1, 1,
119+
params.blockSize, 1, 1,
120+
0, args);
123121
};
124122
}
125123

src/04kernel/src/kernels/concat/cuda_kernel.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,9 @@ extern "C" __global__ void kernel(
107107
return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
108108
params](Resources &, void *, void const *const *inputs, void *const *outputs) {
109109
void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs)};
110-
CUDA_ASSERT(cuLaunchKernel(
111-
h->kernel(),
112-
params.gridSize, 1, 1,
113-
params.blockSize, 1, 1,
114-
0, nullptr, args, nullptr));
110+
h->launch(params.gridSize, 1, 1,
111+
params.blockSize, 1, 1,
112+
0, args);
115113
};
116114
}
117115

src/04kernel/src/kernels/simple_binary/cuda_kernel.cc

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,13 @@ extern "C" __global__ void kernel(
164164
b = inputs[1];
165165
auto n = params.n;
166166
void *args[]{&c, &a, &b, &n};
167-
CUDA_ASSERT(cuLaunchKernel(
168-
h->kernel(),
169-
params.gridSize, 1, 1,
170-
params.blockSize, 1, 1,
171-
0, nullptr, args, nullptr));
167+
h->launch(params.gridSize, 1, 1,
168+
params.blockSize, 1, 1,
169+
0, args);
172170
};
171+
173172
} else if (auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1); rank == 1) {
174-
static std::vector<dim_t> S0{0, 1, 1}, S1{1, 0, 1};
173+
static const std::vector<dim_t> S0{0, 1, 1}, S1{1, 0, 1};
175174
auto name = fmt::format("binaryScalar{}", postfix);
176175
auto code = fmt::format(SCALAR, dt_, op_);
177176
return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
@@ -185,12 +184,11 @@ extern "C" __global__ void kernel(
185184
v = inputs[1 - scalar];
186185
auto n = params.n;
187186
void *args[]{&c, &v, &s, &n};
188-
CUDA_ASSERT(cuLaunchKernel(
189-
h->kernel(),
190-
params.gridSize, 1, 1,
191-
params.blockSize, 1, 1,
192-
0, nullptr, args, nullptr));
187+
h->launch(params.gridSize, 1, 1,
188+
params.blockSize, 1, 1,
189+
0, args);
193190
};
191+
194192
} else {
195193
auto name = fmt::format("binary{}{}", rank, postfix);
196194
auto code = fmt::format(BROADCAST, dt_, op_, rank);
@@ -202,11 +200,9 @@ extern "C" __global__ void kernel(
202200
b = inputs[1];
203201
auto n = params.n;
204202
void *args[]{&c, &a, &b, const_cast<dim_t *>(strides.data()), &n};
205-
CUDA_ASSERT(cuLaunchKernel(
206-
h->kernel(),
207-
params.gridSize, 1, 1,
208-
params.blockSize, 1, 1,
209-
0, nullptr, args, nullptr));
203+
h->launch(params.gridSize, 1, 1,
204+
params.blockSize, 1, 1,
205+
0, args);
210206
};
211207
}
212208
}

src/04kernel/src/kernels/simple_unary/cuda_kernel.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,9 @@ extern "C" __global__ void kernel(
152152
auto x = inputs[0];
153153
auto n = params.n;
154154
void *args[]{&y, &x, &n};
155-
CUDA_ASSERT(cuLaunchKernel(
156-
h->kernel(),
157-
params.gridSize, 1, 1,
158-
params.blockSize, 1, 1,
159-
0, nullptr, args, nullptr));
155+
h->launch(params.gridSize, 1, 1,
156+
params.blockSize, 1, 1,
157+
0, args);
160158
};
161159
}
162160

src/04kernel/src/kernels/split/cuda_kernel.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,9 @@ extern "C" __global__ void kernel(
107107
return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
108108
params](Resources &, void *, void const *const *inputs, void *const *outputs) {
109109
void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs)};
110-
CUDA_ASSERT(cuLaunchKernel(
111-
h->kernel(),
112-
params.gridSize, 1, 1,
113-
params.blockSize, 1, 1,
114-
0, nullptr, args, nullptr));
110+
h->launch(params.gridSize, 1, 1,
111+
params.blockSize, 1, 1,
112+
0, args);
115113
};
116114
}
117115

src/04kernel/test/generator/test_cuda.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ extern "C" __global__ void kernel() {
1414

1515
TEST(generator, nvrtc) {
1616
auto handler = nvrtc::Handler::compile("helloWorld.cu", code, "kernel");
17-
CUDA_ASSERT(cuLaunchKernel(handler->kernel(),
18-
1, 1, 1,
19-
1, 1, 1,
20-
0, nullptr, nullptr, nullptr));
17+
handler->launch(
18+
1, 1, 1,
19+
1, 1, 1,
20+
0, nullptr);
2121
}
2222

2323
#endif// USE_CUDA

0 commit comments

Comments
 (0)