From d71749e61e061ed0a4924438b539c90cb4615f21 Mon Sep 17 00:00:00 2001 From: shuo-ouyang Date: Wed, 27 Oct 2021 22:08:40 +0800 Subject: [PATCH] support FInplaceIdentity for legacy operators --- include/mxnet/operator.h | 22 ++++++++++++++++++++++ src/nnvm/legacy_op_util.cc | 28 ++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index d813c74fa9b6..792c033270f0 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -356,6 +356,17 @@ class OperatorProperty { const std::vector &out_data) const { return std::vector >(); } + /*! + * \brief Get if the forward inplace option is an identity. + * This function enables inplace optimization even when input reference count + * is greater than one. + * + * \return list of bool indicating whether corresponding pair from ForwardInplaceOption + * is an identity. + */ + virtual std::vector ForwardInplaceIdentity() const { + return std::vector(); + } /*! * \brief Get possible backward inplace options. * This function enables optimization to reuse memory of inputs in output. @@ -389,6 +400,17 @@ class OperatorProperty { const std::vector &in_grad) const { return std::vector >(); } + /*! + * \brief Get if the backward inplace option is an identity. + * This function enables inplace optimization even when input reference count + * is greater than one. + * + * \return list of bool indicating whether corresponding pair from BackwardInplaceOption + * is an identity. + */ + virtual std::vector BackwardInplaceIdentity() const { + return std::vector(); + } /*! * \brief Get Backward Input Dependency for generic types of data. * Normally T can be pointer of Symbol::DataEntry, or NDArray. diff --git a/src/nnvm/legacy_op_util.cc b/src/nnvm/legacy_op_util.cc index 4bfbf7df82b3..6ff51ea1bbf3 100644 --- a/src/nnvm/legacy_op_util.cc +++ b/src/nnvm/legacy_op_util.cc @@ -280,6 +280,19 @@ std::vector> OpPropInplaceOption(const NodeAttrs& attrs) { return forward_inplace; } +std::vector OpPropInplaceIdentity(const NodeAttrs& attrs) { + auto& prop = nnvm::get(attrs.parsed); + auto forward_inplace = OpPropInplaceOption(attrs); + auto forward_inplace_identity = prop.ptr->ForwardInplaceIdentity(); + if (forward_inplace_identity.size() == 0UL) { + for (auto i = 0UL; i < forward_inplace.size(); ++i) { + forward_inplace_identity.push_back(false); + } + } + CHECK_EQ(forward_inplace.size(), forward_inplace_identity.size()); + return forward_inplace_identity; +} + std::vector OpPropResourceRequest(const NodeAttrs& attrs) { mxnet::ShapeVector ishape; auto& prop = nnvm::get(attrs.parsed); @@ -409,6 +422,19 @@ std::vector> OpBackInplaceOption(const NodeAttrs& attrs) { return remap; } +std::vector OpBackInplaceIdentity(const NodeAttrs& attrs) { + auto& prop = nnvm::get(attrs.parsed); + auto backward_inplace = OpBackInplaceOption(attrs); + auto backward_inplace_identity = prop.ptr->BackwardInplaceIdentity(); + if (backward_inplace_identity.size() == 0UL) { + for (auto i = 0UL; i < backward_inplace.size(); ++i) { + backward_inplace_identity.push_back(false); + } + } + CHECK_EQ(backward_inplace.size(), backward_inplace_identity.size()); + return backward_inplace_identity; +} + inline ExecType OpExecType(const NodeAttrs& attrs) { auto& prop = nnvm::get(attrs.parsed); return prop.ptr->exec_type(); @@ -442,6 +468,7 @@ void RegisterLegacyOpProp() { op.set_attr("FInferType", OpPropInferType); op.set_attr("FMutateInputs", OpPropMutateInputs); op.set_attr("FInplaceOption", OpPropInplaceOption); + op.set_attr("FInplaceIdentity", OpPropInplaceIdentity); op.set_attr("FResourceRequest", OpPropResourceRequest); op.set_attr("FExecType", OpExecType); op.set_attr("FCreateOpState", OpPropCreateLayerOp); @@ -463,6 +490,7 @@ void RegisterLegacyOpProp() { back_op.set_attr("FListOutputNames", OpBackListOutputNames); back_op.set_attr("FMutateInputs", OpBackMutateInputs); back_op.set_attr("FInplaceOption", OpBackInplaceOption); + back_op.set_attr("FInplaceIdentity", OpBackInplaceIdentity); back_op.set_attr("FResourceRequest", OpBackResourceRequest); back_op.set_attr("TIsLayerOpBackward", true); back_op.set_attr("TIsBackward", true);