Skip to content

Commit 8c4d1aa

Browse files
committed
fix: CpuRearrangeDescriptor
1 parent 0cb4203 commit 8c4d1aa

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

src/ops/rearrange/cpu/rearrange_cpu.cc

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#include "rearrange_cpu.h"
22
#include "../../utils.h"
3+
#include <cstdint>
34
#include <cstring>
45
#include <numeric>
56

@@ -13,11 +14,16 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
1314
if (dst->ndim != src->ndim || dst->ndim < 2) {
1415
return STATUS_BAD_TENSOR_SHAPE;
1516
}
17+
std::vector<uint64_t> shape;
18+
std::vector<int64_t> strides_dst, strides_src;
1619
auto ndim = dst->ndim;
1720
for (int i = 0; i < ndim; ++i) {
1821
if (dst->shape[i] != src->shape[i]) {
1922
return STATUS_BAD_TENSOR_SHAPE;
2023
}
24+
shape.push_back(dst->shape[i]);
25+
strides_dst.push_back(dst->strides[i]);
26+
strides_src.push_back(src->strides[i]);
2127
}
2228
if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) {
2329
return STATUS_BAD_TENSOR_STRIDES;
@@ -40,8 +46,10 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
4046
dst->dt,
4147
r,
4248
ndim,
43-
dst->shape, src->shape,
44-
dst->strides, src->strides};
49+
shape,
50+
strides_dst,
51+
strides_src,
52+
};
4553
return STATUS_SUCCESS;
4654
}
4755

@@ -50,7 +58,7 @@ infiniopStatus_t cpuDestroyRearrangeDescriptor(RearrangeCpuDescriptor_t desc) {
5058
return STATUS_SUCCESS;
5159
}
5260

53-
inline int indices(uint64_t i, uint64_t ndim, int64_t *strides, uint64_t *shape) {
61+
inline int indices(uint64_t i, uint64_t ndim, std::vector<int64_t> strides, std::vector<uint64_t> shape) {
5462
uint64_t ans = 0;
5563
for (int j = ndim - 2; j >= 0; --j) {
5664
ans += (i % shape[j]) * strides[j];
@@ -62,11 +70,11 @@ inline int indices(uint64_t i, uint64_t ndim, int64_t *strides, uint64_t *shape)
6270
void reform_cpu(RearrangeCpuDescriptor_t desc, void *dst, void const *src) {
6371
auto dst_ptr = reinterpret_cast<uint8_t *>(dst);
6472
auto src_ptr = reinterpret_cast<const uint8_t *>(src);
65-
int bytes_size = desc->shape_dst[desc->ndim - 1] * desc->dt.size;
73+
int bytes_size = desc->shape[desc->ndim - 1] * desc->dt.size;
6674
#pragma omp parallel for
6775
for (uint64_t i = 0; i < desc->r; ++i) {
68-
auto dst_offset = indices(i, desc->ndim, desc->strides_dst, desc->shape_dst);
69-
auto src_offset = indices(i, desc->ndim, desc->strides_src, desc->shape_src);
76+
auto dst_offset = indices(i, desc->ndim, desc->strides_dst, desc->shape);
77+
auto src_offset = indices(i, desc->ndim, desc->strides_src, desc->shape);
7078
std::memcpy(dst_ptr + dst_offset * desc->dt.size, src_ptr + src_offset * desc->dt.size, bytes_size);
7179
}
7280
}

src/ops/rearrange/cpu/rearrange_cpu.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,15 @@
22
#define __CPU_REARRANGE_H__
33

44
#include "operators.h"
5+
#include <vector>
56
struct RearrangeCpuDescriptor {
67
Device device;
78
DataLayout dt;
89
uint64_t r;
910
uint64_t ndim;
10-
uint64_t *shape_dst, *shape_src;
11-
int64_t *strides_dst, *strides_src;
11+
std::vector<uint64_t> shape;
12+
std::vector<int64_t> strides_dst;
13+
std::vector<int64_t> strides_src;
1214
};
1315

1416
typedef struct RearrangeCpuDescriptor *RearrangeCpuDescriptor_t;

0 commit comments

Comments
 (0)