Skip to content

Commit 21d46fd

Browse files
SigureMoCopilot
andauthored
[Compat] Add more constructors for compatible at::Tensor (#76307)
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 0091f01 commit 21d46fd

File tree

6 files changed

+53
-15
lines changed

6 files changed

+53
-15
lines changed

paddle/fluid/pybind/torch_compat.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ inline torch::IValue OperationInvoker::to_ivalue(py::handle obj) {
138138
} else if (py::isinstance<py::bool_>(obj)) {
139139
return torch::IValue(py::cast<bool>(obj));
140140
} else if (py::isinstance<py::int_>(obj)) {
141-
return torch::IValue(py::cast<int>(obj));
141+
return torch::IValue(py::cast<int64_t>(obj));
142142
} else if (py::isinstance<py::float_>(obj)) {
143143
return torch::IValue(py::cast<double>(obj));
144144
} else if (py::isinstance<py::str>(obj)) {

paddle/phi/api/include/compat/ATen/ATen.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <ATen/Device.h>
1818
#include <ATen/Functions.h>
1919
#include <ATen/Tensor.h>
20+
#include <ATen/Utils.h>
2021
#include <c10/core/Device.h>
2122
#include <c10/core/DeviceType.h>
2223
#include <c10/core/MemoryFormat.h>

paddle/phi/api/include/compat/ATen/Utils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#pragma once
1616

17-
#include <ATen/EmptyTensor.h>
1817
#include <c10/core/ScalarType.h>
1918
#include <c10/util/ArrayRef.h>
2019
#include <c10/util/Exception.h>

paddle/phi/api/include/compat/ATen/core/TensorBase.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,26 @@ class PADDLE_API TensorBase {
3434
public:
3535
TensorBase() = default;
3636
TensorBase(const PaddleTensor& tensor) : tensor_(tensor){}; // NOLINT
37+
TensorBase(const TensorBase&) = default;
38+
TensorBase(TensorBase&&) noexcept = default;
39+
~TensorBase() noexcept = default;
40+
41+
#if defined(_MSC_VER)
42+
TensorBase& operator=(const TensorBase& x) & {
43+
tensor_ = x.tensor_;
44+
return *this;
45+
}
46+
TensorBase& operator=(TensorBase&& x) & noexcept {
47+
tensor_ = std::move(x.tensor_);
48+
return *this;
49+
}
50+
#else
51+
TensorBase& operator=(const TensorBase& x) & = default;
52+
TensorBase& operator=(TensorBase&& x) & noexcept = default;
53+
#endif
54+
55+
TensorBase& operator=(const TensorBase&) && = delete;
56+
TensorBase& operator=(TensorBase&&) && noexcept = delete;
3757

3858
void* data_ptr() const { return const_cast<void*>(tensor_.data()); }
3959
template <typename T>
@@ -211,8 +231,9 @@ class PADDLE_API TensorBase {
211231
template <typename T, size_t N>
212232
TensorAccessor<T, N> accessor() && = delete;
213233

214-
PaddleTensor _PD_GetInner() const { return tensor_; }
215-
PaddleTensor& _PD_GetInner() { return tensor_; }
234+
const PaddleTensor& _PD_GetInner() const& { return tensor_; }
235+
PaddleTensor& _PD_GetInner() & { return tensor_; }
236+
PaddleTensor&& _PD_GetInner() && { return std::move(tensor_); }
216237

217238
protected:
218239
PaddleTensor tensor_;

paddle/phi/api/include/compat/ATen/core/TensorBody.h

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,42 @@ class Tensor : public TensorBase {
2424
public:
2525
Tensor() = default;
2626
Tensor(const PaddleTensor& tensor) : TensorBase(tensor){}; // NOLINT
27+
Tensor(const Tensor& tensor) = default;
28+
Tensor(Tensor&& tensor) = default;
2729

2830
// Implicitly move-constructible from TensorBase, but must be explicit to
2931
// increase refcount
3032
explicit Tensor(const TensorBase& base) : TensorBase(base) {} // NOLINT
3133
/*implicit*/ Tensor(TensorBase&& base) // NOLINT
3234
: TensorBase(std::move(base)) {}
3335

34-
// TODO(dev): Implement assignment operators
35-
// Tensor& operator=(const Tensor& x) & noexcept {
36-
// return operator=(static_cast<const TensorBase&>(x));
37-
// }
38-
// Tensor& operator=(Tensor&& x) & noexcept {
39-
// return operator=(static_cast<TensorBase&&>(x));
40-
// }
36+
Tensor& operator=(const PaddleTensor& x) & noexcept {
37+
tensor_ = x;
38+
return *this;
39+
}
40+
Tensor& operator=(const TensorBase& x) & noexcept {
41+
const PaddleTensor& inner = x._PD_GetInner();
42+
tensor_ = inner;
43+
return *this;
44+
}
45+
Tensor& operator=(PaddleTensor&& x) & noexcept {
46+
tensor_ = std::move(x);
47+
return *this;
48+
}
49+
Tensor& operator=(TensorBase&& x) & noexcept {
50+
tensor_ = std::move(x)._PD_GetInner();
51+
return *this;
52+
}
4153

54+
Tensor& operator=(const Tensor& x) & noexcept {
55+
return operator=(static_cast<const TensorBase&>(x));
56+
}
57+
Tensor& operator=(Tensor&& x) & noexcept {
58+
return operator=(static_cast<TensorBase&&>(x));
59+
}
4260
Tensor& operator=(const Scalar& v) && { return fill_(v); }
43-
// TODO(dev): Implement assignment operators
44-
// Tensor& operator=(const Tensor& rhs) && { return copy_(rhs); }
45-
// Tensor& operator=(Tensor&& rhs) && { return copy_(rhs); }
61+
Tensor& operator=(const Tensor& rhs) && { return copy_(rhs); }
62+
Tensor& operator=(Tensor&& rhs) && { return copy_(rhs); }
4663

4764
void* data_ptr() const { return const_cast<void*>(tensor_.data()); }
4865
template <typename T>

paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ void check_type(const TensorBase& tensor,
4848
template <> \
4949
PADDLE_API T* TensorBase::mutable_data_ptr() const { \
5050
check_type(*this, ScalarType::name, #name); \
51-
return const_cast<PaddleTensor&>(tensor_).mutable_data<T>(); \
51+
return const_cast<PaddleTensor&>(tensor_).data<T>(); \
5252
} \
5353
\
5454
template <> \

0 commit comments

Comments
 (0)