Skip to content

Commit a02796a

Browse files
Zgc/diopi ascend fix copy (DeepLink-org#850)
* fix copy bug.
* enhance copy

---------

Co-authored-by: wangxing <131418410+POI-WX@users.noreply.github.com>
1 parent 86e951a commit a02796a

File tree

3 files changed

+60
-27
lines changed

3 files changed

+60
-27
lines changed

impl/ascend/device_configs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,7 +1356,6 @@
13561356
args=[
13571357
{
13581358
"ins": ["input"],
1359-
"shape": [Skip((6, 5, 384)), Skip((2, 4, 38, 45))],
13601359
"dtype": [Skip(np.complex128), Skip(np.complex64)],
13611360
},
13621361
{
@@ -1374,7 +1373,6 @@
13741373
args=[
13751374
{
13761375
"ins": ["input"],
1377-
"shape": [Skip((192, 147)), Skip((192, 147, 2)), Skip((2, 12, 38, 45, 3))],
13781376
"dtype": [Skip(np.complex128), Skip(np.complex64)],
13791377
},
13801378
{

impl/ascend_npu/torch_npu/csrc/CopyKernel.cpp

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ bool try_to_optimize_copy_with_any_format(at::Tensor& self, const at::Tensor& sr
4646
namespace {
4747

4848
std::vector<int64_t> inferOriginShape(at::IntArrayRef sizes, at::IntArrayRef strides) {
49+
if (sizes.size() <= 0) {
50+
return std::vector<int64_t>();
51+
}
4952
std::vector<int64_t> originSizes(sizes.size(), 1);
5053
originSizes[0] = sizes[0] * strides[0];
5154
for (size_t i = 1; i < sizes.size(); i++) {
@@ -57,6 +60,30 @@ std::vector<int64_t> inferOriginShape(at::IntArrayRef sizes, at::IntArrayRef str
5760
return originSizes;
5861
}
5962

63+
// Build a view of `tensor` with the same rank as `destShape`.
// Each dest dimension whose size matches an (unconsumed) origin dimension keeps
// that dimension's stride; every other dest dimension gets stride 0, i.e. it is
// broadcast. The matching is done right-to-left and each origin dimension is
// consumed at most once.
//
// Params:
//   tensor    - source tensor whose storage is re-viewed (not copied).
//   destShape - target sizes for the returned view.
// Returns: a storage view of `tensor` shaped as `destShape` with the computed
// strides (via impl::aten::viewStorage).
at::Tensor viewToSameDim(const at::Tensor& tensor, const at::IntArrayRef destShape) {
    const auto originShape = tensor.sizes();
    // Work in signed 64-bit throughout: the original mixed size_t and int,
    // and `originShape.size() - 1 - matchedCount` wraps in unsigned arithmetic
    // when originShape is empty.
    const auto originRank = static_cast<int64_t>(originShape.size());
    const auto destRank = static_cast<int64_t>(destShape.size());
    std::vector<int64_t> strides(destShape.size(), 0);
    if (originRank < destRank) {
        // Count of origin dims already matched; only the count matters, so no
        // need to record which dest indices matched.
        int64_t matched = 0;
        for (int64_t i = destRank - 1; i >= 0; --i) {
            // Start past the origin dims consumed by previous matches.
            for (int64_t j = originRank - 1 - matched; j >= 0; --j) {
                if (destShape[i] == originShape[j]) {
                    ++matched;
                    strides[i] = tensor.strides()[j];
                    break;
                }
            }
        }
    } else if (originRank == destRank) {
        // Same rank: carry over the stride wherever the sizes agree.
        for (int64_t i = 0; i < destRank; ++i) {
            if (destShape[i] == originShape[i]) {
                strides[i] = tensor.stride(i);
            }
        }
    }
    // NOTE(review): when originRank > destRank every stride stays 0, matching
    // the original behavior — confirm callers never reach this case.
    return impl::aten::viewStorage(tensor, destShape, strides);
}
86+
6087
} // namespace
6188

6289
bool isPartOfOther(const at::Tensor& tensor) {
@@ -78,16 +105,18 @@ at::Tensor& npu_view_copy(at::Tensor& self, const at::Tensor& src, bool non_bloc
78105
auto self_stride = self.strides();
79106
auto src_size = src.sizes();
80107
auto src_stride = src.strides();
81-
auto originShape = inferOriginShape(self.sizes(), self.strides());
82-
auto originSizeTensor = at_npu::native::empty_npu(originShape, self.options());
108+
auto originSelfShape = inferOriginShape(self.sizes(), self.strides());
109+
auto originSizeTensor = at_npu::native::empty_npu(originSelfShape, self.options());
110+
111+
auto originSrcShape = inferOriginShape(src.sizes(), src.strides());
83112

84113
at_npu::native::OpCommand cmd;
85114
cmd.Name("ViewCopy")
86-
.InputWithoutContiguous(impl::aten::viewStorage(self, originShape))
115+
.InputWithoutContiguous(impl::aten::viewStorage(self, originSelfShape))
87116
.Input(self_size, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
88117
.Input(self_stride, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
89118
.Input(at::Scalar(0), at::kLong)
90-
.InputWithoutContiguous(src)
119+
.InputWithoutContiguous(impl::aten::viewStorage(src, originSrcShape))
91120
.Input(src_size, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
92121
.Input(src_stride, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
93122
.Input(at::Scalar(0), at::kLong)
@@ -102,9 +131,8 @@ at::Tensor& npu_view_copy(at::Tensor& self, const at::Tensor& src, bool non_bloc
102131
void copy_d2d_last_method(at::Tensor& self, const at::Tensor& src, bool same_type, bool non_blocking) {
103132
// general copy method but Low performance
104133
RECORD_FUNCTION("contiguous_d_ViewCopy", std::vector<c10::IValue>({src}));
105-
if (isPartOfOther(self)) {
134+
if (1 || isPartOfOther(self)) {
106135
npu_view_copy(self, src, non_blocking);
107-
// custom_ops::npu_view_copy(self, src, non_blocking);
108136
} else {
109137
custom_ops::npu_view_copy(self, src, non_blocking);
110138
}
@@ -314,16 +342,20 @@ void copy_d2d_dtype_baseformat(at::Tensor& self, const at::Tensor& src, bool non
314342
// Optimized trans-contiguous method
315343
return;
316344
} else {
317-
// General trans-contiguous method
318-
RECORD_FUNCTION("contiguous_d_AsStrided", std::vector<c10::IValue>({src}));
319-
#if 0
320-
custom_ops::npu_stride_copy_out(src, src.sizes(), src.strides(), src.storage_offset(), self);
321-
#else
322-
std::vector<int64_t> shape(src.sizes().size(), 1);
323-
shape[0] = at::detail::computeStorageNbytes(src.sizes(), src.strides(), src.itemsize()) / src.itemsize();
324-
custom_ops::npu_stride_copy_out(impl::aten::viewStorage(src, shape), src.sizes(), src.strides(), src.storage_offset(), self);
325-
#endif
326-
return;
345+
// AsStride not support double
346+
if (src.scalar_type() != at::kDouble) {
347+
// General trans-contiguous method
348+
RECORD_FUNCTION("contiguous_d_AsStrided", std::vector<c10::IValue>({src}));
349+
at::Tensor source = src;
350+
if (self.sizes() != src.sizes()) {
351+
source = viewToSameDim(source, self.sizes());
352+
}
353+
354+
// custom_ops::npu_stride_copy_out(src, src.sizes(), src.strides(), src.storage_offset(), self);
355+
auto shape = inferOriginShape(source.sizes(), source.strides());
356+
custom_ops::npu_stride_copy_out(impl::aten::viewStorage(source, shape), source.sizes(), source.strides(), source.storage_offset(), self);
357+
return;
358+
}
327359
}
328360
} else {
329361
// Contiguous source tensor copy to contiguous self tensor
@@ -459,7 +491,7 @@ class BroadcastContiguousOpt : public ContiguousOpt {
459491
}
460492
}; // class BroadcastContiguousOpt
461493

462-
REGISTER_COPY_OPT(broadcast, BroadcastContiguousOpt)
494+
// REGISTER_COPY_OPT(broadcast, BroadcastContiguousOpt)
463495

464496
constexpr int MaxCombinedCasesNum = 2;
465497
constexpr int ViewAndBaseInfoStackNum = 2;
@@ -831,7 +863,7 @@ class CombinedContiguousOpt : public ContiguousOpt {
831863
}
832864
}; // class combinedContiguousOpt
833865

834-
REGISTER_COPY_OPT(combined, CombinedContiguousOpt)
866+
// REGISTER_COPY_OPT(combined, CombinedContiguousOpt)
835867

836868
class IndexingContiguousOpt : public ContiguousOpt {
837869
public:
@@ -945,7 +977,7 @@ class IndexingContiguousOpt : public ContiguousOpt {
945977
}
946978
}; // class IndexingContiguousOpt
947979

948-
REGISTER_COPY_OPT(indexing, IndexingContiguousOpt)
980+
// REGISTER_COPY_OPT(indexing, IndexingContiguousOpt)
949981

950982
class PermuteContiguousOpt : public ContiguousOpt {
951983
public:
@@ -1106,7 +1138,7 @@ class PermuteContiguousOpt : public ContiguousOpt {
11061138
}
11071139
}; // class PermuteContiguousOpt
11081140

1109-
REGISTER_COPY_OPT(permute, PermuteContiguousOpt)
1141+
// REGISTER_COPY_OPT(permute, PermuteContiguousOpt)
11101142

11111143
bool can_use_memecpy_for_NZ_format(const ContiguousTensorDesc& tensor_desc) {
11121144
int64_t tensor_shape_size = static_cast<int64_t>(tensor_desc.sizes_.size());
@@ -1193,7 +1225,7 @@ class ReshapeContiguousOpt : public ContiguousOpt {
11931225
bool CanOptimizer(const ContiguousTensorDesc& src_desc) override { return check_reshape_match(src_desc); }
11941226
}; // class ReshapeContiguousOpt
11951227

1196-
REGISTER_COPY_OPT(reshape, ReshapeContiguousOpt)
1228+
// REGISTER_COPY_OPT(reshape, ReshapeContiguousOpt)
11971229

11981230
class ReshapeV2ContiguousOpt : public ContiguousOpt {
11991231
public:
@@ -1269,7 +1301,7 @@ class ReshapeV2ContiguousOpt : public ContiguousOpt {
12691301
}
12701302
}; // class ReshapeV2ContiguousOpt
12711303

1272-
REGISTER_COPY_OPT(reshapeV2, ReshapeV2ContiguousOpt)
1304+
// REGISTER_COPY_OPT(reshapeV2, ReshapeV2ContiguousOpt)
12731305

12741306
class SelectContiguousOpt : public ContiguousOpt {
12751307
public:
@@ -1381,7 +1413,7 @@ class SelectContiguousOpt : public ContiguousOpt {
13811413
}
13821414
}; // class SelectContiguousOpt
13831415

1384-
REGISTER_COPY_OPT(select, SelectContiguousOpt)
1416+
// REGISTER_COPY_OPT(select, SelectContiguousOpt)
13851417

13861418
class SliceContiguousOpt : public ContiguousOpt {
13871419
public:
@@ -1489,7 +1521,7 @@ class SliceContiguousOpt : public ContiguousOpt {
14891521
}
14901522
}; // class SliceContiguousOpt
14911523

1492-
REGISTER_COPY_OPT(slice, SliceContiguousOpt)
1524+
// REGISTER_COPY_OPT(slice, SliceContiguousOpt)
14931525

14941526
} // namespace native
14951527
} // namespace at_npu

impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp

100644100755
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2400,7 +2400,10 @@ std::tuple<aclTensorDesc*, aclDataBuffer*> CovertToAclOutput(const at::Tensor& t
24002400
aclDataType aclDataType = CalcuOpUtil::ConvertToAclDataType(tensor.scalar_type(), forceDataType);
24012401
const auto& npuDesc = torch_npu::NPUBridge::GetNpuStorageImplDesc(tensor);
24022402
const auto& dims = tensor.sizes();
2403-
auto& storageDims = npuDesc.storage_sizes_;
2403+
auto storageDims = npuDesc.storage_sizes_;
2404+
if (storageDims.size() == 0 && tensor.numel() > 0) {
2405+
storageDims.push_back(1);
2406+
}
24042407
AclTensorDescMaker desc;
24052408
auto aclDesc = desc.Create(aclDataType, dims, npuDesc.origin_format_).SetFormat(npuDesc.npu_format_).SetShape(storageDims).Get();
24062409
auto numel = c10::multiply_integers(storageDims);

0 commit comments

Comments (0)