Skip to content

Commit 5f518a3

Browse files
committed
fix(kernel): 修正 SliceInfo 优化计算
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 39785bd commit 5f518a3

File tree

2 files changed

+24
-29
lines changed

2 files changed

+24
-29
lines changed

src/04kernel/src/attributes/slice_info.cc

Lines changed: 21 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -39,41 +39,36 @@ namespace refactor::kernel {
3939
}
4040
dims_.resize(rank = shape.size());
4141
}
42-
dims.reserve(rank);
42+
// 合并末尾的连续维度
43+
for (auto i : range0_(rank).rev()) {
44+
if (auto d = shape[i]; dims_[i].length == d) {
45+
blockSize *= d;
46+
shape.pop_back();
47+
dims_.pop_back();
48+
} else {
49+
dims.resize(rank = shape.size());
50+
if (auto &dim = dims_[i]; dim.step == 1) {
51+
if (auto times = std::gcd(std::gcd(dim.start, dim.length), shape[i]); times > 1) {
52+
blockSize *= times;
53+
dim.start /= times;
54+
dim.length /= times;
55+
shape[i] /= times;
56+
}
57+
}
58+
break;
59+
}
60+
}
4361
dim_t strideI = 1;
4462
for (auto i : range0_(rank).rev()) {
4563
auto const &dim = dims_[i];
46-
dims.push_back({
64+
dims[i] = {
4765
.strideO = blockCount,
4866
.skip = static_cast<dim_t>(strideI * dim.start),
4967
.strideI = static_cast<sdim_t>(strideI * dim.step),
50-
});
68+
};
5169
blockCount *= dim.length;
5270
strideI *= shape[i];
5371
}
54-
std::reverse(dims.begin(), dims.end());
55-
56-
while (!dims.empty()) {
57-
auto const &dim = dims.back();
58-
if (dim.strideI == static_cast<sdim_t>(dim.strideO) && !dim.skip) {
59-
dims.pop_back();
60-
} else {
61-
long times = std::gcd(std::gcd(dim.strideI, dim.strideO), dim.skip);
62-
blockCount /= times;
63-
blockSize *= times;
64-
if (!dims.empty()) {
65-
for (auto &dim : dims) {
66-
dim.strideO /= times;
67-
dim.skip /= times;
68-
dim.strideI /= times;
69-
}
70-
if (dims.back().strideO != 1) {
71-
dims.push_back({1, 0, 1});
72-
}
73-
}
74-
break;
75-
}
76-
}
7772
}
7873

7974
SliceInfo SliceInfo::reform(dim_t maxblockSize) const noexcept {
@@ -97,5 +92,4 @@ namespace refactor::kernel {
9792
dims.back() = {1, 0, 1};
9893
}
9994

100-
10195
}// namespace refactor::kernel

src/04kernel/test/kernels/slice/test_cpu.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,8 @@ TEST(kernel, SliceCpu) {
5959
};
6060
auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}),
6161
output = Tensor::share(DataType::F32, Shape{3, 2, 3, 1, 2, 3});
62-
auto kernel = SliceCpu::build(SliceInfo(dims, *input));
62+
auto info = SliceInfo(dims, *input);
63+
auto kernel = SliceCpu::build(info);
6364
ASSERT_TRUE(kernel);
6465
auto res = runtime::Resources();
6566
auto routine = kernel->lower(res).routine;
@@ -94,7 +95,7 @@ TEST(kernel, SliceCpu) {
9495
}
9596
}
9697
// test reform
97-
auto kernelReformed = SliceCpu::build(SliceInfo(dims, *input).reform(16));
98+
auto kernelReformed = SliceCpu::build(info.reform(16));
9899
ASSERT_TRUE(kernelReformed);
99100
auto routineReformed = kernelReformed->lower(res).routine;
100101
std::vector<float> resultReformed(result.size());

0 commit comments

Comments
 (0)