@@ -46,6 +46,9 @@ bool try_to_optimize_copy_with_any_format(at::Tensor& self, const at::Tensor& sr
 namespace {
 
 std::vector<int64_t> inferOriginShape(at::IntArrayRef sizes, at::IntArrayRef strides) {
+    if (sizes.empty()) {
+        return std::vector<int64_t>();
+    }
     std::vector<int64_t> originSizes(sizes.size(), 1);
     originSizes[0] = sizes[0] * strides[0];
     for (size_t i = 1; i < sizes.size(); i++) {
@@ -57,6 +60,30 @@ std::vector<int64_t> inferOriginShape(at::IntArrayRef sizes, at::IntArrayRef str
     return originSizes;
 }
 
+at::Tensor viewToSameDim(const at::Tensor& tensor, const at::IntArrayRef destShape) {
+    const auto originShape = tensor.sizes();
+    std::vector<int64_t> strides(destShape.size(), 0);
+    if (originShape.size() < destShape.size()) {
+        std::vector<int64_t> sameDims;
+        for (int i = destShape.size() - 1; i >= 0; i--) {
+            for (int j = originShape.size() - 1 - sameDims.size(); j >= 0; j--) {
+                if (destShape[i] == originShape[j]) {
+                    sameDims.push_back(i);
+                    strides[i] = tensor.strides()[j];
+                    break;
+                }
+            }
+        }
+    } else if (originShape.size() == destShape.size()) {
+        for (size_t i = 0; i < destShape.size(); i++) {
+            if (destShape[i] == originShape[i]) {
+                strides[i] = tensor.stride(i);
+            }
+        }
+    }
+    return impl::aten::viewStorage(tensor, destShape, strides);
+}
+
 } // namespace
 
 bool isPartOfOther(const at::Tensor& tensor) {
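
Note on the new helper: viewToSameDim matches destination dimensions against the source sizes from the back; a matched dimension inherits the source stride, and every unmatched dimension keeps stride 0, i.e. it is broadcast. A minimal sketch, assuming the anonymous-namespace helper is visible and that impl::aten::viewStorage(tensor, sizes, strides) reinterprets the tensor's storage with the given geometry (illustrative values, not from this commit):

    at::Tensor t = at::arange(3, at::kFloat);  // sizes {3}, strides {1}
    at::Tensor b = viewToSameDim(t, {2, 3});   // sizes {2, 3}, strides {0, 1}
    // b shares t's storage; the unmatched leading dim is a stride-0
    // broadcast, so b reads as {{0, 1, 2}, {0, 1, 2}}.
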
@@ -78,16 +105,18 @@ at::Tensor& npu_view_copy(at::Tensor& self, const at::Tensor& src, bool non_bloc
     auto self_stride = self.strides();
     auto src_size = src.sizes();
     auto src_stride = src.strides();
-    auto originShape = inferOriginShape(self.sizes(), self.strides());
-    auto originSizeTensor = at_npu::native::empty_npu(originShape, self.options());
+    auto originSelfShape = inferOriginShape(self.sizes(), self.strides());
+    auto originSizeTensor = at_npu::native::empty_npu(originSelfShape, self.options());
+
+    auto originSrcShape = inferOriginShape(src.sizes(), src.strides());
 
     at_npu::native::OpCommand cmd;
     cmd.Name("ViewCopy")
-        .InputWithoutContiguous(impl::aten::viewStorage(self, originShape))
+        .InputWithoutContiguous(impl::aten::viewStorage(self, originSelfShape))
         .Input(self_size, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
         .Input(self_stride, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
         .Input(at::Scalar(0), at::kLong)
-        .InputWithoutContiguous(src)
+        .InputWithoutContiguous(impl::aten::viewStorage(src, originSrcShape))
         .Input(src_size, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
         .Input(src_stride, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
         .Input(at::Scalar(0), at::kLong)
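
The change above wraps both operands, not just self, in a view of their inferred origin storage shape before handing them to ViewCopy, so the op receives each full underlying buffer together with an explicit (size, stride, offset) description and can resolve strided views on either side itself. Condensed pattern (same names as in the hunk):

    auto originSelfShape = inferOriginShape(self.sizes(), self.strides());
    auto originSrcShape  = inferOriginShape(src.sizes(), src.strides());
    auto selfStorageView = impl::aten::viewStorage(self, originSelfShape);
    auto srcStorageView  = impl::aten::viewStorage(src, originSrcShape);
    // ViewCopy applies self_size/self_stride and src_size/src_stride on
    // top of these raw storage views.
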
@@ -102,9 +131,8 @@ at::Tensor& npu_view_copy(at::Tensor& self, const at::Tensor& src, bool non_bloc
 void copy_d2d_last_method(at::Tensor& self, const at::Tensor& src, bool same_type, bool non_blocking) {
     // general copy method but low performance
     RECORD_FUNCTION("contiguous_d_ViewCopy", std::vector<c10::IValue>({src}));
-    if (isPartOfOther(self)) {
+    if (1 || isPartOfOther(self)) {
         npu_view_copy(self, src, non_blocking);
-        // custom_ops::npu_view_copy(self, src, non_blocking);
     } else {
         custom_ops::npu_view_copy(self, src, non_blocking);
     }
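
Note that 1 || isPartOfOther(self) is always true, so the first branch is taken unconditionally and the custom_ops::npu_view_copy arm below it is dead code; the condition appears to be deliberately pinned while the local npu_view_copy path is exercised. Equivalent control flow while pinned:

    // the literal 1 short-circuits the ||, so this always runs:
    npu_view_copy(self, src, non_blocking);
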
@@ -314,16 +342,20 @@ void copy_d2d_dtype_baseformat(at::Tensor& self, const at::Tensor& src, bool non
             // Optimized trans-contiguous method
             return;
         } else {
-            // General trans-contiguous method
-            RECORD_FUNCTION("contiguous_d_AsStrided", std::vector<c10::IValue>({src}));
-#if 0
-            custom_ops::npu_stride_copy_out(src, src.sizes(), src.strides(), src.storage_offset(), self);
-#else
-            std::vector<int64_t> shape(src.sizes().size(), 1);
-            shape[0] = at::detail::computeStorageNbytes(src.sizes(), src.strides(), src.itemsize()) / src.itemsize();
-            custom_ops::npu_stride_copy_out(impl::aten::viewStorage(src, shape), src.sizes(), src.strides(), src.storage_offset(), self);
-#endif
-            return;
+            // AsStrided does not support double
+            if (src.scalar_type() != at::kDouble) {
+                // General trans-contiguous method
+                RECORD_FUNCTION("contiguous_d_AsStrided", std::vector<c10::IValue>({src}));
+                at::Tensor source = src;
+                if (self.sizes() != src.sizes()) {
+                    source = viewToSameDim(source, self.sizes());
+                }
+
+                // custom_ops::npu_stride_copy_out(src, src.sizes(), src.strides(), src.storage_offset(), self);
+                auto shape = inferOriginShape(source.sizes(), source.strides());
+                custom_ops::npu_stride_copy_out(impl::aten::viewStorage(source, shape), source.sizes(), source.strides(), source.storage_offset(), self);
+                return;
+            }
         }
     } else {
         // Contiguous source tensor copy to contiguous self tensor
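
Two behavioral changes land in this hunk: at::kDouble sources now skip the AsStrided route entirely (the kernel lacks double support) and fall through to the general copy path later in the function, and shape-mismatched (broadcast) sources are first expanded to self's shape via viewToSameDim so the sizes and strides handed to npu_stride_copy_out already describe the broadcast. A sketch of the broadcast case (illustrative values):

    // copying src with sizes {3} into self with sizes {2, 3}:
    at::Tensor source = viewToSameDim(src, self.sizes());  // sizes {2, 3}, strides {0, 1}
    auto shape = inferOriginShape(source.sizes(), source.strides());
    // npu_stride_copy_out receives the storage view plus the broadcast
    // geometry, so src is never materialized at the expanded shape.
    custom_ops::npu_stride_copy_out(impl::aten::viewStorage(source, shape),
                                    source.sizes(), source.strides(),
                                    source.storage_offset(), self);
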
@@ -459,7 +491,7 @@ class BroadcastContiguousOpt : public ContiguousOpt {
     }
 }; // class BroadcastContiguousOpt
 
-REGISTER_COPY_OPT(broadcast, BroadcastContiguousOpt)
+// REGISTER_COPY_OPT(broadcast, BroadcastContiguousOpt)
 
 constexpr int MaxCombinedCasesNum = 2;
 constexpr int ViewAndBaseInfoStackNum = 2;
@@ -831,7 +863,7 @@ class CombinedContiguousOpt : public ContiguousOpt {
     }
 }; // class CombinedContiguousOpt
 
-REGISTER_COPY_OPT(combined, CombinedContiguousOpt)
+// REGISTER_COPY_OPT(combined, CombinedContiguousOpt)
 
 class IndexingContiguousOpt : public ContiguousOpt {
 public:
@@ -945,7 +977,7 @@ class IndexingContiguousOpt : public ContiguousOpt {
     }
 }; // class IndexingContiguousOpt
 
-REGISTER_COPY_OPT(indexing, IndexingContiguousOpt)
+// REGISTER_COPY_OPT(indexing, IndexingContiguousOpt)
 
 class PermuteContiguousOpt : public ContiguousOpt {
 public:
@@ -1106,7 +1138,7 @@ class PermuteContiguousOpt : public ContiguousOpt {
     }
 }; // class PermuteContiguousOpt
 
-REGISTER_COPY_OPT(permute, PermuteContiguousOpt)
+// REGISTER_COPY_OPT(permute, PermuteContiguousOpt)
 
 bool can_use_memecpy_for_NZ_format(const ContiguousTensorDesc& tensor_desc) {
     int64_t tensor_shape_size = static_cast<int64_t>(tensor_desc.sizes_.size());
@@ -1193,7 +1225,7 @@ class ReshapeContiguousOpt : public ContiguousOpt {
     bool CanOptimizer(const ContiguousTensorDesc& src_desc) override { return check_reshape_match(src_desc); }
 }; // class ReshapeContiguousOpt
 
-REGISTER_COPY_OPT(reshape, ReshapeContiguousOpt)
+// REGISTER_COPY_OPT(reshape, ReshapeContiguousOpt)
 
 class ReshapeV2ContiguousOpt : public ContiguousOpt {
 public:
@@ -1269,7 +1301,7 @@ class ReshapeV2ContiguousOpt : public ContiguousOpt {
     }
 }; // class ReshapeV2ContiguousOpt
 
-REGISTER_COPY_OPT(reshapeV2, ReshapeV2ContiguousOpt)
+// REGISTER_COPY_OPT(reshapeV2, ReshapeV2ContiguousOpt)
 
 class SelectContiguousOpt : public ContiguousOpt {
 public:
@@ -1381,7 +1413,7 @@ class SelectContiguousOpt : public ContiguousOpt {
     }
 }; // class SelectContiguousOpt
 
-REGISTER_COPY_OPT(select, SelectContiguousOpt)
+// REGISTER_COPY_OPT(select, SelectContiguousOpt)
 
 class SliceContiguousOpt : public ContiguousOpt {
 public:
@@ -1489,7 +1521,7 @@ class SliceContiguousOpt : public ContiguousOpt {
     }
 }; // class SliceContiguousOpt
 
-REGISTER_COPY_OPT(slice, SliceContiguousOpt)
+// REGISTER_COPY_OPT(slice, SliceContiguousOpt)
 
 } // namespace native
 } // namespace at_npu
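
Taken together, the commented-out registrations disable the broadcast, combined, indexing, permute, reshape, reshapeV2, select, and slice trans-contiguous fast paths, so tensors matching those patterns now fall back to the general ViewCopy/AsStrided copies patched above. A hypothetical sketch of what such a registration macro typically expands to (not this repo's exact definition; CopyOptRegister is an assumed name):

    // hypothetical: each call statically registers an optimizer instance
    // under a pattern name that the copy dispatcher looks up at runtime.
    #define REGISTER_COPY_OPT(name, optimizer) \
        static CopyOptRegister g_##name##_copy_opt(#name, std::make_shared<optimizer>());
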