11#include " rearrange_cpu.h"
22#include " ../../utils.h"
3+ #include < cstdint>
34#include < cstring>
45#include < numeric>
56
@@ -13,11 +14,16 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
1314 if (dst->ndim != src->ndim || dst->ndim < 2 ) {
1415 return STATUS_BAD_TENSOR_SHAPE;
1516 }
17+ std::vector<uint64_t > shape;
18+ std::vector<int64_t > strides_dst, strides_src;
1619 auto ndim = dst->ndim ;
1720 for (int i = 0 ; i < ndim; ++i) {
1821 if (dst->shape [i] != src->shape [i]) {
1922 return STATUS_BAD_TENSOR_SHAPE;
2023 }
24+ shape.push_back (dst->shape [i]);
25+ strides_dst.push_back (dst->strides [i]);
26+ strides_src.push_back (src->strides [i]);
2127 }
2228 if (dst->strides [ndim - 1 ] != 1 || src->strides [ndim - 1 ] != 1 ) {
2329 return STATUS_BAD_TENSOR_STRIDES;
@@ -40,8 +46,10 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
4046 dst->dt ,
4147 r,
4248 ndim,
43- dst->shape , src->shape ,
44- dst->strides , src->strides };
49+ shape,
50+ strides_dst,
51+ strides_src,
52+ };
4553 return STATUS_SUCCESS;
4654}
4755
@@ -50,7 +58,7 @@ infiniopStatus_t cpuDestroyRearrangeDescriptor(RearrangeCpuDescriptor_t desc) {
5058 return STATUS_SUCCESS;
5159}
5260
53- inline int indices (uint64_t i, uint64_t ndim, int64_t * strides, uint64_t * shape) {
61+ inline int indices (uint64_t i, uint64_t ndim, std::vector< int64_t > strides, std::vector< uint64_t > shape) {
5462 uint64_t ans = 0 ;
5563 for (int j = ndim - 2 ; j >= 0 ; --j) {
5664 ans += (i % shape[j]) * strides[j];
@@ -62,11 +70,11 @@ inline int indices(uint64_t i, uint64_t ndim, int64_t *strides, uint64_t *shape)
// reform_cpu: copies `desc->r` rows from `src` to `dst`, one std::memcpy per row.
//
// Each row is `shape[ndim - 1] * dt.size` bytes long. For row `i`, the source
// and destination offsets are obtained from `indices(...)` using the
// descriptor's per-tensor strides; the returned offsets are multiplied by
// `dt.size` before being added to the byte pointers, i.e. they are treated as
// element counts rather than byte counts.
//
// Parameters:
//   desc - descriptor supplying ndim, r, dt, the shape and both stride sets.
//   dst  - destination buffer, addressed as raw bytes.
//   src  - source buffer, addressed as raw bytes (read-only).
//
// NOTE(review): this listing appears to be a diff rendering; lines marked `-`
// look like the pre-change revision (separate shape_dst/shape_src fields) and
// lines marked `+` the post-change revision (a single shared `shape`).
6270void reform_cpu (RearrangeCpuDescriptor_t desc, void *dst, void const *src) {
// View both buffers as bytes so offsets can be applied with pointer arithmetic.
6371 auto dst_ptr = reinterpret_cast <uint8_t *>(dst);
6472 auto src_ptr = reinterpret_cast <const uint8_t *>(src);
// Byte length of one row: extent of the last dimension times the element size.
// NOTE(review): the product is stored in an `int`; presumably the shape entries
// are 64-bit unsigned, so very large rows could be truncated - confirm.
65- int bytes_size = desc->shape_dst [desc->ndim - 1 ] * desc->dt .size ;
73+ int bytes_size = desc->shape [desc->ndim - 1 ] * desc->dt .size ;
// Loop iterations are distributed across OpenMP threads.
// NOTE(review): this assumes no two iterations write overlapping destination
// rows; that depends on the destination strides - verify against the caller.
6674#pragma omp parallel for
6775 for (uint64_t i = 0 ; i < desc->r ; ++i) {
// Per-row offsets for destination and source.
// NOTE(review): in the `+` revision `indices` is declared with std::vector
// parameters taken by value, so each call here would copy both vectors -
// consider passing by const reference.
68- auto dst_offset = indices (i, desc->ndim , desc->strides_dst , desc->shape_dst );
69- auto src_offset = indices (i, desc->ndim , desc->strides_src , desc->shape_src );
76+ auto dst_offset = indices (i, desc->ndim , desc->strides_dst , desc->shape );
77+ auto src_offset = indices (i, desc->ndim , desc->strides_src , desc->shape );
// Scale the offsets by the element size and copy one row of bytes.
7078 std::memcpy (dst_ptr + dst_offset * desc->dt .size , src_ptr + src_offset * desc->dt .size , bytes_size);
7179 }
7280}
0 commit comments