Skip to content
Merged
93 changes: 14 additions & 79 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
* This module enables declaration of named attributes
* which support default value setup and bound checking.
*
* \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
* \sa BaseAttrsNode, AttrsWithDefaultValues
*/
#ifndef TVM_IR_ATTRS_H_
#define TVM_IR_ATTRS_H_
Expand All @@ -32,36 +32,17 @@
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/cow.h>
#include <tvm/ir/expr.h>

#include <functional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {

/*!
* \brief Create a NodeRef type that represents null.
* \tparam TNodeRef the type to be created.
* \return A instance that will represent None.
*/
template <typename TObjectRef>
inline TObjectRef NullValue() {
static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
return TObjectRef(ffi::ObjectPtr<typename TObjectRef::ContainerType>(nullptr));
}

template <>
inline DataType NullValue<DataType>() {
return DataType(DataType::kHandle, 0, 0);
}

/*!
* \brief Information about attribute fields in string representations.
*/
Expand Down Expand Up @@ -103,22 +84,6 @@ class BaseAttrsNode : public ffi::Object {
public:
/*! \brief virtual destructor */
virtual ~BaseAttrsNode() {}
/*!
* \brief Initialize the attributes by sequence of arguments
* \param args The positional arguments in the form
* [key0, value0, key1, value1, ..., key_n, value_n]
*/
template <typename... Args>
inline void InitBySeq(Args&&... args);
/*!
* \brief Initialize the attributes by arguments.
* \param kwargs The key value pairs for initialization.
* [key0, value0, key1, value1, ..., key_n, value_n]
* \param allow_unknown Whether allow additional unknown fields.
* \note This function throws when the required field is not present.
*/
TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs,
bool allow_unknown = false) = 0;

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object);
Expand Down Expand Up @@ -149,8 +114,6 @@ class DictAttrsNode : public BaseAttrsNode {
rfl::ObjectDef<DictAttrsNode>().def_ro("__dict__", &DictAttrsNode::dict);
}

void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final;

// type info
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, BaseAttrsNode);
};
Expand Down Expand Up @@ -380,48 +343,20 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
}

/*!
* \brief Adapter for AttrsNode with the new reflection API.
*
* We will phaseout the old AttrsNode in future in favor of the new reflection API.
* This adapter allows us to gradually migrate to the new reflection API.
*
* \tparam DerivedType The final attribute type.
* \brief Create an object with all default values, using the reflection defaults.
* \tparam TObj the ObjectRef type to be created.
* \return An instance with all reflection-defined default values applied.
*/
template <typename DerivedType>
class AttrsNodeReflAdapter : public BaseAttrsNode {
public:
void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final {
TVM_FFI_THROW(InternalError) << "`" << DerivedType::_type_key
<< "` uses new reflection mechanism for init";
}

private:
DerivedType* self() const {
return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
}
};

/*!
* \brief Create an Attr object with all default values.
* \tparam TAttrNode the type to be created.
* \return A instance that will represent None.
*/
template <typename TAttrs>
inline TAttrs AttrsWithDefaultValues() {
static_assert(std::is_base_of_v<Attrs, TAttrs>, "Can only take attr nodes");
using ContainerType = typename TAttrs::ContainerType;
if constexpr (std::is_base_of_v<AttrsNodeReflAdapter<ContainerType>, ContainerType>) {
static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs");
AnyView packed_args[1];
packed_args[0] = ContainerType::RuntimeTypeIndex();
ffi::Any rv;
finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
return rv.cast<TAttrs>();
} else {
auto n = ffi::make_object<ContainerType>();
n->InitByPackedArgs(ffi::PackedArgs(nullptr, 0), false);
return TAttrs(n);
}
template <typename TObj>
inline TObj AttrsWithDefaultValues() {
static_assert(std::is_base_of_v<ffi::ObjectRef, TObj>, "Can only create ObjectRef-derived types");
using ContainerType = typename TObj::ContainerType;
static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs");
AnyView packed_args[1];
packed_args[0] = ContainerType::RuntimeTypeIndex();
ffi::Any rv;
finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
return rv.cast<TObj>();
}

} // namespace tvm
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relax/attrs/ccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace tvm {
namespace relax {

/*! \brief Attributes used in allreduce operators */
struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter<AllReduceAttrs> {
struct AllReduceAttrs : public tvm::BaseAttrsNode {
ffi::String op_type;
bool in_group;

Expand All @@ -49,7 +49,7 @@ struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter<AllReduceAttrs> {
}; // struct AllReduceAttrs

/*! \brief Attributes used in allgather operators */
struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter<AllGatherAttrs> {
struct AllGatherAttrs : public tvm::BaseAttrsNode {
int num_workers;
bool in_group;

Expand All @@ -67,7 +67,7 @@ struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter<AllGatherAttrs> {
}; // struct AllGatherAttrs

/*! \brief Attributes used in scatter operators */
struct ScatterCollectiveAttrs : public tvm::AttrsNodeReflAdapter<ScatterCollectiveAttrs> {
struct ScatterCollectiveAttrs : public tvm::BaseAttrsNode {
int num_workers;
int axis;

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relax/attrs/create.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tvm {
namespace relax {

/*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */
struct InitAttrs : public AttrsNodeReflAdapter<InitAttrs> {
struct InitAttrs : public BaseAttrsNode {
DataType dtype;

static void RegisterReflection() {
Expand All @@ -42,7 +42,7 @@ struct InitAttrs : public AttrsNodeReflAdapter<InitAttrs> {
}; // struct InitAttrs

/*! \brief Attributes used in tril and triu operator */
struct TriluAttrs : public AttrsNodeReflAdapter<TriluAttrs> {
struct TriluAttrs : public BaseAttrsNode {
int k;

static void RegisterReflection() {
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relax/attrs/datatype.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tvm {
namespace relax {

/*! \brief Attributes used in astype operator */
struct AstypeAttrs : public AttrsNodeReflAdapter<AstypeAttrs> {
struct AstypeAttrs : public BaseAttrsNode {
DataType dtype;

static void RegisterReflection() {
Expand All @@ -41,7 +41,7 @@ struct AstypeAttrs : public AttrsNodeReflAdapter<AstypeAttrs> {
}; // struct AstypeAttrs.

/*! \brief Attributes used in wrap_param operator */
struct WrapParamAttrs : public AttrsNodeReflAdapter<WrapParamAttrs> {
struct WrapParamAttrs : public BaseAttrsNode {
DataType dtype;

static void RegisterReflection() {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/attrs/distributed.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tvm {
namespace relax {

/*! \brief Attributes for redistribute and annotate_sharding operator */
struct DistributionAttrs : public AttrsNodeReflAdapter<DistributionAttrs> {
struct DistributionAttrs : public BaseAttrsNode {
distributed::DeviceMesh device_mesh;
distributed::Placement placement;

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relax/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tvm {
namespace relax {

/*! \brief Attributes used in image resize2d operator */
struct Resize2DAttrs : public AttrsNodeReflAdapter<Resize2DAttrs> {
struct Resize2DAttrs : public BaseAttrsNode {
ffi::Array<FloatImm> roi;
ffi::String layout;
ffi::String method;
Expand Down Expand Up @@ -79,7 +79,7 @@ struct Resize2DAttrs : public AttrsNodeReflAdapter<Resize2DAttrs> {
}; // struct Resize2dAttrs

/*! \brief Attributes used in image resize3d operator */
struct Resize3DAttrs : public AttrsNodeReflAdapter<Resize3DAttrs> {
struct Resize3DAttrs : public BaseAttrsNode {
ffi::Array<FloatImm> roi;
ffi::String layout;
ffi::String method;
Expand Down Expand Up @@ -128,7 +128,7 @@ struct Resize3DAttrs : public AttrsNodeReflAdapter<Resize3DAttrs> {
}; // struct Resize3DAttrs

/*! \brief Attributes used in image grid_sample operator */
struct GridSampleAttrs : public AttrsNodeReflAdapter<GridSampleAttrs> {
struct GridSampleAttrs : public BaseAttrsNode {
ffi::String method;
ffi::String layout;
ffi::String padding_mode;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relax/attrs/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tvm {
namespace relax {

/*! \brief Attributes used in take operator */
struct TakeAttrs : public AttrsNodeReflAdapter<TakeAttrs> {
struct TakeAttrs : public BaseAttrsNode {
ffi::Optional<int64_t> axis;
ffi::String mode;

Expand All @@ -45,7 +45,7 @@ struct TakeAttrs : public AttrsNodeReflAdapter<TakeAttrs> {
}; // struct TakeAttrs

/*! \brief Attributes used in strided_slice operator */
struct StridedSliceAttrs : public AttrsNodeReflAdapter<StridedSliceAttrs> {
struct StridedSliceAttrs : public BaseAttrsNode {
bool assume_inbound;

static void RegisterReflection() {
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relax/attrs/linear_algebra.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace tvm {
namespace relax {

/*! \brief Attributes for matmul operator */
struct MatmulAttrs : public AttrsNodeReflAdapter<MatmulAttrs> {
struct MatmulAttrs : public BaseAttrsNode {
DataType out_dtype;

static void RegisterReflection() {
Expand All @@ -42,7 +42,7 @@ struct MatmulAttrs : public AttrsNodeReflAdapter<MatmulAttrs> {
}; // struct MatmulAttrs

/*! \brief Attributes used in einsum operator */
struct EinsumAttrs : public AttrsNodeReflAdapter<EinsumAttrs> {
struct EinsumAttrs : public BaseAttrsNode {
ffi::String subscripts;

static void RegisterReflection() {
Expand Down
Loading