44namespace refactor ::kernel {
55
66 bool SliceInfo::Dim::operator ==(Dim const &rhs) const noexcept {
7- return countStride == rhs.countStride &&
8- sizeStart == rhs.sizeStart &&
9- sizeStride == rhs.sizeStride ;
7+ return strideO == rhs.strideO &&
8+ strideI == rhs.strideI &&
9+ skip == rhs.skip ;
1010 }
1111 bool SliceInfo::Dim::operator !=(Dim const &rhs) const noexcept {
1212 return !operator ==(rhs);
@@ -15,74 +15,70 @@ namespace refactor::kernel {
1515 SliceInfo::SliceInfo (
1616 std::vector<Dim> dims_,
1717 dim_t blockCount_,
18- dim_t blockSize_,
19- dim_t baseOffset_) noexcept
18+ dim_t blockSize_) noexcept
2019 : dims(std::move(dims_)),
2120 blockCount (blockCount_),
22- blockSize(blockSize_),
23- baseOffset(baseOffset_) {}
21+ blockSize(blockSize_) {}
2422
25- SliceInfo::SliceInfo (Dimensions const & dims_, Tensor const &input)
26- : dims( 1 ) ,
23+ SliceInfo::SliceInfo (Dimensions dims_, Tensor const &input)
24+ : dims{} ,
2725 blockCount (1 ),
28- blockSize(input.dataType.size()),
29- baseOffset(0 ) {
30- ASSERT (dims_.size () == static_cast <size_t >(input.rank ()), " Unreachable" );
26+ blockSize(input.dataType.size()) {
27+ size_t rank = input.rank ();
28+ if (!rank) { return ; }// scalar input
29+ ASSERT (dims_.size () == rank, " unreachable" );
3130
32- auto continuous = true ;
33- auto stride = blockSize;
34- dims[0 ] = {1 , 0 , static_cast <sdim_t >(stride)};
35- for (auto i : range0_ (input.rank ()).rev ()) {
36- auto l = input.shape [i];
37- auto const &d = dims_[i];
38- if (auto &it = dims.back (); continuous && d.step == 1 ) {
39- it.countStride *= d.length ;
40- it.sizeStart = d.start * stride;
41- it.sizeStride *= l;
42- } else {
43- dims.push_back (Dim{
44- static_cast <dim_t >(it.countStride * d.length ),
45- static_cast <dim_t >(d.start * stride),
46- static_cast <sdim_t >(d.step * stride),
47- });
31+ std::vector<dim_t > shape;
32+ {// 去除形状里的 1
33+ shape.reserve (rank);
34+ for (auto i : range0_ (rank)) {
35+ if (auto l = input.shape [i]; l != 1 ) {
36+ if (auto j = shape.size (); j < i) { dims_[j] = dims_[i]; }
37+ shape.push_back (l);
38+ }
4839 }
49- continuous = d.length == l;
50- stride *= l;
40+ dims_.resize (rank = shape.size ());
5141 }
52- baseOffset = dims[0 ].sizeStart ;
53- auto elementCount = dims[0 ].countStride ;
54- blockSize *= elementCount;
55- for (auto &d : dims) {
56- d.countStride /= elementCount;
42+ dims.reserve (rank);
43+ dim_t strideI = 1 ;
44+ for (auto i : range0_ (rank).rev ()) {
45+ auto const &dim = dims_[i];
46+ dims.push_back ({
47+ .strideO = blockCount,
48+ .skip = static_cast <dim_t >(strideI * dim.start ),
49+ .strideI = static_cast <sdim_t >(strideI * dim.step ),
50+ });
51+ blockCount *= dim.length ;
52+ strideI *= shape[i];
5753 }
5854 std::reverse (dims.begin (), dims.end ());
59- blockCount = dims[0 ].countStride ;
60- for (auto i : range (1ul , dims.size ())) {
61- dims[i - 1 ].countStride = dims[i].countStride ;
55+
56+ while (!dims.empty ()) {
57+ auto const &dim = dims.back ();
58+ if (dim.strideI == static_cast <sdim_t >(dim.strideO ) && !dim.skip ) {
59+ dims.pop_back ();
60+ } else {
61+ long times = std::gcd (std::gcd (dim.strideI , dim.strideO ), dim.skip );
62+ blockCount /= times;
63+ blockSize *= times;
64+ if (!dims.empty ()) {
65+ for (auto &dim : dims) {
66+ dim.strideO /= times;
67+ dim.skip /= times;
68+ dim.strideI /= times;
69+ }
70+ if (dims.back ().strideO != 1 ) {
71+ dims.push_back ({1 , 0 , 1 });
72+ }
73+ }
74+ break ;
75+ }
6276 }
63- dims.pop_back ();
64- dims.shrink_to_fit ();
6577 }
6678
6779 SliceInfo SliceInfo::reform (dim_t maxblockSize) const noexcept {
68- auto blockSize_ = std::gcd (blockSize, maxblockSize);
69- if (blockSize_ == blockSize) { return *this ; }
70- auto times = blockSize / blockSize_;
71- SliceInfo ans{
72- std::vector<Dim>(dims.size () + 1 ),
73- blockCount * times,
74- blockSize_,
75- baseOffset,
76- };
77- for (auto i : range0_ (dims.size ())) {
78- auto const &d = dims[i];
79- ans.dims [i] = {
80- d.countStride * times,
81- d.sizeStart ,
82- d.sizeStride ,
83- };
84- }
85- ans.dims .back () = {1 , 0 , static_cast <sdim_t >(blockSize_)};
80+ auto ans = *this ;
81+ ans.reformAssign (maxblockSize);
8682 return ans;
8783 }
8884
@@ -93,10 +89,12 @@ namespace refactor::kernel {
9389 blockCount *= times;
9490 blockSize = blockSize_;
9591 for (auto &d : dims) {
96- d.countStride *= times;
92+ d.strideO *= times;
93+ d.strideI *= times;
94+ d.skip *= times;
9795 }
9896 dims.resize (dims.size () + 1 );
99- dims.back () = {1 , 0 , static_cast < sdim_t >(blockSize_) };
97+ dims.back () = {1 , 0 , 1 };
10098 }
10199
102100
0 commit comments