Skip to content

Commit de3b474

Browse files
authored
Merge pull request #57 from InfiniTensor/dev
Dev
2 parents 1791d53 + a0ce722 commit de3b474

File tree

11 files changed

+143
-149
lines changed

11 files changed

+143
-149
lines changed

src/04kernel/cuda/include/kernel/cuda/slice.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
namespace refactor::kernel::cuda {
77

88
struct DimInfo {
9-
unsigned int countStride, sizeStart;
10-
int sizeStride;
9+
unsigned int strideO, skip;
10+
int strideI;
1111
};
1212

1313
void launchSlice(

src/04kernel/cuda/src/slice.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace refactor::kernel::cuda {
66

77
__global__ static void sliceKernel(
88
unsigned long long n,
9-
uint8_t const *src, DimInfo const *dims, uint8_t *output,
9+
uint8_t const *src, DimInfo const *dims, uint8_t *dst,
1010
unsigned int rank,
1111
unsigned int blockSize) {
1212
extern __shared__ DimInfo dimInfo[];
@@ -18,15 +18,13 @@ namespace refactor::kernel::cuda {
1818
step = blockDim.x * gridDim.x;
1919
tid < n;
2020
tid += step) {
21-
long rem = tid;
22-
auto src_ = src;
23-
auto dst_ = output + rem * blockSize;
21+
long rem = tid, j = 0;
2422
for (auto i = 0; i < rank; ++i) {
2523
auto const &dim = dimInfo[i];
26-
src_ += rem / dim.countStride * dim.sizeStride + dim.sizeStart;
27-
rem %= dim.countStride;
24+
j += rem / dim.strideO * dim.strideI + dim.skip;
25+
rem %= dim.strideO;
2826
}
29-
optimizedMemcpy(dst_, src_, blockSize);
27+
optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
3028
}
3129
}
3230

src/04kernel/include/kernel/attributes/slice_info.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,19 @@ namespace refactor::kernel {
1515
/// @brief 优化用于计算的 Slice 描述。
1616
struct SliceInfo {
1717
struct Dim {
18-
dim_t countStride, sizeStart;
19-
sdim_t sizeStride;
18+
dim_t strideO, skip;
19+
sdim_t strideI;
2020

2121
bool operator==(Dim const &) const noexcept;
2222
bool operator!=(Dim const &) const noexcept;
2323
};
2424
std::vector<Dim> dims;
25-
dim_t blockCount, blockSize, baseOffset;
25+
dim_t blockCount, blockSize;
2626

2727
SliceInfo(decltype(dims),
2828
decltype(blockCount),
29-
decltype(blockSize),
30-
decltype(baseOffset)) noexcept;
31-
SliceInfo(Dimensions const &, Tensor const &);
29+
decltype(blockSize)) noexcept;
30+
SliceInfo(Dimensions, Tensor const &);
3231
SliceInfo reform(dim_t maxblockSize) const noexcept;
3332
void reformAssign(dim_t maxblockSize) noexcept;
3433
};

src/04kernel/src/attributes/slice_info.cc

Lines changed: 58 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
namespace refactor::kernel {
55

66
bool SliceInfo::Dim::operator==(Dim const &rhs) const noexcept {
7-
return countStride == rhs.countStride &&
8-
sizeStart == rhs.sizeStart &&
9-
sizeStride == rhs.sizeStride;
7+
return strideO == rhs.strideO &&
8+
strideI == rhs.strideI &&
9+
skip == rhs.skip;
1010
}
1111
bool SliceInfo::Dim::operator!=(Dim const &rhs) const noexcept {
1212
return !operator==(rhs);
@@ -15,74 +15,70 @@ namespace refactor::kernel {
1515
SliceInfo::SliceInfo(
1616
std::vector<Dim> dims_,
1717
dim_t blockCount_,
18-
dim_t blockSize_,
19-
dim_t baseOffset_) noexcept
18+
dim_t blockSize_) noexcept
2019
: dims(std::move(dims_)),
2120
blockCount(blockCount_),
22-
blockSize(blockSize_),
23-
baseOffset(baseOffset_) {}
21+
blockSize(blockSize_) {}
2422

25-
SliceInfo::SliceInfo(Dimensions const &dims_, Tensor const &input)
26-
: dims(1),
23+
SliceInfo::SliceInfo(Dimensions dims_, Tensor const &input)
24+
: dims{},
2725
blockCount(1),
28-
blockSize(input.dataType.size()),
29-
baseOffset(0) {
30-
ASSERT(dims_.size() == static_cast<size_t>(input.rank()), "Unreachable");
26+
blockSize(input.dataType.size()) {
27+
size_t rank = input.rank();
28+
if (!rank) { return; }// scalar input
29+
ASSERT(dims_.size() == rank, "unreachable");
3130

32-
auto continuous = true;
33-
auto stride = blockSize;
34-
dims[0] = {1, 0, static_cast<sdim_t>(stride)};
35-
for (auto i : range0_(input.rank()).rev()) {
36-
auto l = input.shape[i];
37-
auto const &d = dims_[i];
38-
if (auto &it = dims.back(); continuous && d.step == 1) {
39-
it.countStride *= d.length;
40-
it.sizeStart = d.start * stride;
41-
it.sizeStride *= l;
42-
} else {
43-
dims.push_back(Dim{
44-
static_cast<dim_t>(it.countStride * d.length),
45-
static_cast<dim_t>(d.start * stride),
46-
static_cast<sdim_t>(d.step * stride),
47-
});
31+
std::vector<dim_t> shape;
32+
{// 去除形状里的 1
33+
shape.reserve(rank);
34+
for (auto i : range0_(rank)) {
35+
if (auto l = input.shape[i]; l != 1) {
36+
if (auto j = shape.size(); j < i) { dims_[j] = dims_[i]; }
37+
shape.push_back(l);
38+
}
4839
}
49-
continuous = d.length == l;
50-
stride *= l;
40+
dims_.resize(rank = shape.size());
5141
}
52-
baseOffset = dims[0].sizeStart;
53-
auto elementCount = dims[0].countStride;
54-
blockSize *= elementCount;
55-
for (auto &d : dims) {
56-
d.countStride /= elementCount;
42+
dims.reserve(rank);
43+
dim_t strideI = 1;
44+
for (auto i : range0_(rank).rev()) {
45+
auto const &dim = dims_[i];
46+
dims.push_back({
47+
.strideO = blockCount,
48+
.skip = static_cast<dim_t>(strideI * dim.start),
49+
.strideI = static_cast<sdim_t>(strideI * dim.step),
50+
});
51+
blockCount *= dim.length;
52+
strideI *= shape[i];
5753
}
5854
std::reverse(dims.begin(), dims.end());
59-
blockCount = dims[0].countStride;
60-
for (auto i : range(1ul, dims.size())) {
61-
dims[i - 1].countStride = dims[i].countStride;
55+
56+
while (!dims.empty()) {
57+
auto const &dim = dims.back();
58+
if (dim.strideI == static_cast<sdim_t>(dim.strideO) && !dim.skip) {
59+
dims.pop_back();
60+
} else {
61+
long times = std::gcd(std::gcd(dim.strideI, dim.strideO), dim.skip);
62+
blockCount /= times;
63+
blockSize *= times;
64+
if (!dims.empty()) {
65+
for (auto &dim : dims) {
66+
dim.strideO /= times;
67+
dim.skip /= times;
68+
dim.strideI /= times;
69+
}
70+
if (dims.back().strideO != 1) {
71+
dims.push_back({1, 0, 1});
72+
}
73+
}
74+
break;
75+
}
6276
}
63-
dims.pop_back();
64-
dims.shrink_to_fit();
6577
}
6678

6779
SliceInfo SliceInfo::reform(dim_t maxblockSize) const noexcept {
68-
auto blockSize_ = std::gcd(blockSize, maxblockSize);
69-
if (blockSize_ == blockSize) { return *this; }
70-
auto times = blockSize / blockSize_;
71-
SliceInfo ans{
72-
std::vector<Dim>(dims.size() + 1),
73-
blockCount * times,
74-
blockSize_,
75-
baseOffset,
76-
};
77-
for (auto i : range0_(dims.size())) {
78-
auto const &d = dims[i];
79-
ans.dims[i] = {
80-
d.countStride * times,
81-
d.sizeStart,
82-
d.sizeStride,
83-
};
84-
}
85-
ans.dims.back() = {1, 0, static_cast<sdim_t>(blockSize_)};
80+
auto ans = *this;
81+
ans.reformAssign(maxblockSize);
8682
return ans;
8783
}
8884

@@ -93,10 +89,12 @@ namespace refactor::kernel {
9389
blockCount *= times;
9490
blockSize = blockSize_;
9591
for (auto &d : dims) {
96-
d.countStride *= times;
92+
d.strideO *= times;
93+
d.strideI *= times;
94+
d.skip *= times;
9795
}
9896
dims.resize(dims.size() + 1);
99-
dims.back() = {1, 0, static_cast<sdim_t>(blockSize_)};
97+
dims.back() = {1, 0, 1};
10098
}
10199

102100

src/04kernel/src/generator/nvrtc_repo.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace refactor::kernel::nvrtc {
2929
}
3030
NVRTC_ASSERT(nvrtcCreateProgram(&prog, code.data(), name.data(), 0, nullptr, nullptr));
3131

32-
std::vector<std::string> opts{"--std=c++20", "--gpu-architecture=compute_80"};
32+
std::vector<std::string> opts{"--std=c++17", "--gpu-architecture=compute_80"};
3333
#ifdef CUDA_INCLUDE_PATH
3434
opts.emplace_back(fmt::format("-I{}", CUDA_INCLUDE_PATH));
3535
#endif
@@ -42,9 +42,11 @@ namespace refactor::kernel::nvrtc {
4242
{
4343
size_t logSize;
4444
NVRTC_ASSERT(nvrtcGetProgramLogSize(prog, &logSize));
45-
std::vector<char> log(logSize);
46-
NVRTC_ASSERT(nvrtcGetProgramLog(prog, log.data()));
47-
fmt::println("{}", log.data());
45+
if (logSize > 1) {
46+
std::vector<char> log(logSize);
47+
NVRTC_ASSERT(nvrtcGetProgramLog(prog, log.data()));
48+
fmt::println("{}", log.data());
49+
}
4850
}
4951
if (compileResult != NVRTC_SUCCESS) {
5052
fmt::println("wrong code:");

src/04kernel/src/kernels/slice/cpu_kernel.cc

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,18 @@ namespace refactor::kernel {
2525
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
2626
using namespace runtime;
2727
return [info = this->info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
28-
auto src = reinterpret_cast<uint8_t const *>(inputs[0]) + info.baseOffset;
28+
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
2929
auto dst = reinterpret_cast<uint8_t *>(outputs[0]);
3030
std::for_each_n(std::execution::par_unseq,
3131
natural_t(0), info.blockCount,
3232
[=, &info](auto i) {
33-
long rem = i;
34-
auto src_ = src;
35-
auto dst_ = dst + rem * info.blockSize;
33+
long rem = i, j = 0;
3634
for (auto const &dim : info.dims) {
37-
auto d = std::div(rem, dim.countStride);
38-
src_ += d.quot * dim.sizeStride + dim.sizeStart;
35+
auto d = std::div(rem, dim.strideO);
36+
j += d.quot * dim.strideI + dim.skip;
3937
rem = d.rem;
4038
}
41-
std::memcpy(dst_, src_, info.blockSize);
39+
std::memcpy(dst + i * info.blockSize, src + j * info.blockSize, info.blockSize);
4240
});
4341
};
4442
}

src/04kernel/src/kernels/slice/cuda_kernel.cu

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@ namespace refactor::kernel {
1212
dims.begin(),
1313
[](auto const &d) {
1414
return cuda::DimInfo{
15-
d.countStride,
16-
d.sizeStart,
17-
d.sizeStride,
15+
d.strideO,
16+
d.skip,
17+
d.strideI,
1818
};
1919
});
2020
return [dims = thrust::device_vector<cuda::DimInfo>(dims),
2121
params = cuda::ThreadsDistributer()(info.blockCount),
22-
blockSize = info.blockSize,
23-
baseOffset = info.baseOffset](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
24-
auto src = reinterpret_cast<uint8_t const *>(inputs[0]) + baseOffset;
22+
blockSize = info.blockSize](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
23+
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
2524
cuda::launchSlice(params, src, dims.data().get(), outputs[0],
2625
dims.size(),
2726
blockSize);

src/04kernel/test/attributes/test_slice_info.cpp

Lines changed: 0 additions & 43 deletions
This file was deleted.

src/04kernel/test/kernels/slice/test_cpu.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,48 @@
55
using namespace refactor;
66
using namespace kernel;
77

8+
TEST(kernel, SliceCpu_with1) {
9+
// build routine
10+
Dimensions dims{
11+
{0, 1, 1},
12+
{0, 1, 2},
13+
{0, 1, 1},
14+
{2, 1, 2},
15+
};
16+
auto input = Tensor::share(DataType::F32, Shape{1, 2, 1, 4}),
17+
output = Tensor::share(DataType::F32, Shape{1, 2, 1, 2});
18+
auto kernel = SliceCpu::build(SliceInfo(dims, *input));
19+
ASSERT_TRUE(kernel);
20+
auto res = runtime::Resources();
21+
auto routine = kernel->lower(res).routine;
22+
// put input data
23+
std::vector<float>
24+
data(input->elementsSize()),
25+
result(output->elementsSize());
26+
std::iota(data.begin(), data.end(), 0);
27+
// inference
28+
{
29+
void const *inputs[]{data.data()};
30+
void *outputs[]{result.data()};
31+
routine(res, nullptr, inputs, outputs);
32+
}
33+
// check
34+
std::vector<float> ans{2, 3, 6, 7};
35+
EXPECT_EQ(result, ans);
36+
// test reform
37+
auto kernelReformed = SliceCpu::build(SliceInfo(dims, *input).reform(16));
38+
ASSERT_TRUE(kernelReformed);
39+
auto routineReformed = kernelReformed->lower(res).routine;
40+
std::vector<float> resultReformed(result.size());
41+
{
42+
void const *inputs[]{data.data()};
43+
void *outputs[]{resultReformed.data()};
44+
routineReformed(res, nullptr, inputs, outputs);
45+
}
46+
// check
47+
EXPECT_EQ(resultReformed, ans);
48+
}
49+
850
TEST(kernel, SliceCpu) {
951
// build routine
1052
Dimensions dims{

0 commit comments

Comments
 (0)