@@ -60,8 +60,8 @@ uint16_t f32_to_f16(float val) {
6060}
6161
6262size_t indexToReducedOffset (size_t flat_index, size_t ndim,
63- int64_t const *broadcasted_strides,
64- int64_t const *target_strides) {
63+ ptrdiff_t const *broadcasted_strides,
64+ ptrdiff_t const *target_strides) {
6565 size_t res = 0 ;
6666 for (size_t i = 0 ; i < ndim; ++i) {
6767 res += flat_index / broadcasted_strides[i] * target_strides[i];
@@ -71,7 +71,7 @@ size_t indexToReducedOffset(size_t flat_index, size_t ndim,
7171}
7272
7373size_t indexToOffset (size_t flat_index, size_t ndim, size_t const *shape,
74- int64_t const *strides) {
74+ ptrdiff_t const *strides) {
7575 size_t res = 0 ;
7676 for (size_t i = ndim; i-- > 0 ;) {
7777 res += (flat_index % shape[i]) * strides[i];
@@ -81,7 +81,7 @@ size_t indexToOffset(size_t flat_index, size_t ndim, size_t const *shape,
8181}
8282
8383size_t getPaddedSize (size_t ndim, size_t *shape, size_t const *pads) {
84- uint64_t total_size = 1 ;
84+ size_t total_size = 1 ;
8585 for (size_t i = 0 ; i < ndim; ++i) {
8686 total_size *= shape[i] + (i < 2 ? 0 : 2 * pads[i - 2 ]);
8787 }
0 commit comments