diff --git a/paddle/fluid/pybind/torch_compat.h b/paddle/fluid/pybind/torch_compat.h index a487c1f9daffdb..9a7ab8b3b36498 100644 --- a/paddle/fluid/pybind/torch_compat.h +++ b/paddle/fluid/pybind/torch_compat.h @@ -138,7 +138,7 @@ inline torch::IValue OperationInvoker::to_ivalue(py::handle obj) { } else if (py::isinstance(obj)) { return torch::IValue(py::cast(obj)); } else if (py::isinstance(obj)) { - return torch::IValue(py::cast(obj)); + return torch::IValue(py::cast(obj)); } else if (py::isinstance(obj)) { return torch::IValue(py::cast(obj)); } else if (py::isinstance(obj)) { diff --git a/paddle/phi/api/include/compat/ATen/ATen.h b/paddle/phi/api/include/compat/ATen/ATen.h index b42595669de6ef..a0b12c8b82e27a 100644 --- a/paddle/phi/api/include/compat/ATen/ATen.h +++ b/paddle/phi/api/include/compat/ATen/ATen.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include diff --git a/paddle/phi/api/include/compat/ATen/Utils.h b/paddle/phi/api/include/compat/ATen/Utils.h index 30a417cd6f61ec..43e794c41d927e 100644 --- a/paddle/phi/api/include/compat/ATen/Utils.h +++ b/paddle/phi/api/include/compat/ATen/Utils.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include #include diff --git a/paddle/phi/api/include/compat/ATen/core/TensorBase.h b/paddle/phi/api/include/compat/ATen/core/TensorBase.h index f243112ba1c9e2..83592fd504ae05 100644 --- a/paddle/phi/api/include/compat/ATen/core/TensorBase.h +++ b/paddle/phi/api/include/compat/ATen/core/TensorBase.h @@ -34,6 +34,26 @@ class PADDLE_API TensorBase { public: TensorBase() = default; TensorBase(const PaddleTensor& tensor) : tensor_(tensor){}; // NOLINT + TensorBase(const TensorBase&) = default; + TensorBase(TensorBase&&) noexcept = default; + ~TensorBase() noexcept = default; + +#if defined(_MSC_VER) + TensorBase& operator=(const TensorBase& x) & { + tensor_ = x.tensor_; + return *this; + } + TensorBase& operator=(TensorBase&& x) & noexcept { + tensor_ = std::move(x.tensor_); + return *this; + } +#else + TensorBase& operator=(const TensorBase& x) & = default; + TensorBase& operator=(TensorBase&& x) & noexcept = default; +#endif + + TensorBase& operator=(const TensorBase&) && = delete; + TensorBase& operator=(TensorBase&&) && noexcept = delete; void* data_ptr() const { return const_cast(tensor_.data()); } template @@ -211,8 +231,9 @@ class PADDLE_API TensorBase { template TensorAccessor accessor() && = delete; - PaddleTensor _PD_GetInner() const { return tensor_; } - PaddleTensor& _PD_GetInner() { return tensor_; } + const PaddleTensor& _PD_GetInner() const& { return tensor_; } + PaddleTensor& _PD_GetInner() & { return tensor_; } + PaddleTensor&& _PD_GetInner() && { return std::move(tensor_); } protected: PaddleTensor tensor_; diff --git a/paddle/phi/api/include/compat/ATen/core/TensorBody.h b/paddle/phi/api/include/compat/ATen/core/TensorBody.h index 504da198533769..82478b9030fff2 100644 --- a/paddle/phi/api/include/compat/ATen/core/TensorBody.h +++ b/paddle/phi/api/include/compat/ATen/core/TensorBody.h @@ -24,6 +24,8 @@ class Tensor : public TensorBase { public: Tensor() = default; Tensor(const PaddleTensor& tensor) : TensorBase(tensor){}; // NOLINT + Tensor(const Tensor& tensor) = default; + Tensor(Tensor&& tensor) = default; // Implicitly move-constructible from TensorBase, but must be explicit to // increase refcount @@ -31,18 +33,33 @@ class Tensor : public TensorBase { /*implicit*/ Tensor(TensorBase&& base) // NOLINT : TensorBase(std::move(base)) {} - // TODO(dev): Implement assignment operators - // Tensor& operator=(const Tensor& x) & noexcept { - // return operator=(static_cast(x)); - // } - // Tensor& operator=(Tensor&& x) & noexcept { - // return operator=(static_cast(x)); - // } + Tensor& operator=(const PaddleTensor& x) & noexcept { + tensor_ = x; + return *this; + } + Tensor& operator=(const TensorBase& x) & noexcept { + const PaddleTensor& inner = x._PD_GetInner(); + tensor_ = inner; + return *this; + } + Tensor& operator=(PaddleTensor&& x) & noexcept { + tensor_ = std::move(x); + return *this; + } + Tensor& operator=(TensorBase&& x) & noexcept { + tensor_ = std::move(x)._PD_GetInner(); + return *this; + } + Tensor& operator=(const Tensor& x) & noexcept { + return operator=(static_cast(x)); + } + Tensor& operator=(Tensor&& x) & noexcept { + return operator=(static_cast(x)); + } Tensor& operator=(const Scalar& v) && { return fill_(v); } - // TODO(dev): Implement assignment operators - // Tensor& operator=(const Tensor& rhs) && { return copy_(rhs); } - // Tensor& operator=(Tensor&& rhs) && { return copy_(rhs); } + Tensor& operator=(const Tensor& rhs) && { return copy_(rhs); } + Tensor& operator=(Tensor&& rhs) && { return copy_(rhs); } void* data_ptr() const { return const_cast(tensor_.data()); } template diff --git a/paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp b/paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp index b452493b22aa3d..2ed682664a75b6 100644 --- a/paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp +++ b/paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp @@ -48,7 +48,7 @@ void check_type(const TensorBase& tensor, template <> \ PADDLE_API T* TensorBase::mutable_data_ptr() const { \ check_type(*this, ScalarType::name, #name); \ - return const_cast(tensor_).mutable_data(); \ + return const_cast(tensor_).data(); \ } \ \ template <> \