diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index fa3dfa5b3ec2..287a26351728 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -23,7 +23,7 @@ * This module enables declaration of named attributes * which support default value setup and bound checking. * - * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD + * \sa BaseAttrsNode, AttrsWithDefaultValues */ #ifndef TVM_IR_ATTRS_H_ #define TVM_IR_ATTRS_H_ @@ -32,36 +32,17 @@ #include #include #include -#include #include #include #include -#include #include #include #include #include -#include namespace tvm { -/*! - * \brief Create a NodeRef type that represents null. - * \tparam TNodeRef the type to be created. - * \return A instance that will represent None. - */ -template -inline TObjectRef NullValue() { - static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); - return TObjectRef(ffi::ObjectPtr(nullptr)); -} - -template <> -inline DataType NullValue() { - return DataType(DataType::kHandle, 0, 0); -} - /*! * \brief Information about attribute fields in string representations. */ @@ -103,22 +84,6 @@ class BaseAttrsNode : public ffi::Object { public: /*! \brief virtual destructor */ virtual ~BaseAttrsNode() {} - /*! - * \brief Initialize the attributes by sequence of arguments - * \param args The positional arguments in the form - * [key0, value0, key1, value1, ..., key_n, value_n] - */ - template - inline void InitBySeq(Args&&... args); - /*! - * \brief Initialize the attributes by arguments. - * \param kwargs The key value pairs for initialization. - * [key0, value0, key1, value1, ..., key_n, value_n] - * \param allow_unknown Whether allow additional unknown fields. - * \note This function throws when the required field is not present. - */ - TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs, - bool allow_unknown = false) = 0; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object); @@ -149,8 +114,6 @@ class DictAttrsNode : public BaseAttrsNode { rfl::ObjectDef().def_ro("__dict__", &DictAttrsNode::dict); } - void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final; - // type info TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, BaseAttrsNode); }; @@ -380,48 +343,20 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { } /*! - * \brief Adapter for AttrsNode with the new reflection API. - * - * We will phaseout the old AttrsNode in future in favor of the new reflection API. - * This adapter allows us to gradually migrate to the new reflection API. - * - * \tparam DerivedType The final attribute type. + * \brief Create an object with all default values, using the reflection defaults. + * \tparam TObj the ObjectRef type to be created. + * \return An instance with all reflection-defined default values applied. */ -template -class AttrsNodeReflAdapter : public BaseAttrsNode { - public: - void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final { - TVM_FFI_THROW(InternalError) << "`" << DerivedType::_type_key - << "` uses new reflection mechanism for init"; - } - - private: - DerivedType* self() const { - return const_cast(static_cast(this)); - } -}; - -/*! - * \brief Create an Attr object with all default values. - * \tparam TAttrNode the type to be created. - * \return A instance that will represent None. - */ -template -inline TAttrs AttrsWithDefaultValues() { - static_assert(std::is_base_of_v, "Can only take attr nodes"); - using ContainerType = typename TAttrs::ContainerType; - if constexpr (std::is_base_of_v, ContainerType>) { - static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); - AnyView packed_args[1]; - packed_args[0] = ContainerType::RuntimeTypeIndex(); - ffi::Any rv; - finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv); - return rv.cast(); - } else { - auto n = ffi::make_object(); - n->InitByPackedArgs(ffi::PackedArgs(nullptr, 0), false); - return TAttrs(n); - } +template +inline TObj AttrsWithDefaultValues() { + static_assert(std::is_base_of_v, "Can only create ObjectRef-derived types"); + using ContainerType = typename TObj::ContainerType; + static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); + AnyView packed_args[1]; + packed_args[0] = ContainerType::RuntimeTypeIndex(); + ffi::Any rv; + finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv); + return rv.cast(); } } // namespace tvm diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h index 09d40b4ed98e..7e0624706b0c 100644 --- a/include/tvm/relax/attrs/ccl.h +++ b/include/tvm/relax/attrs/ccl.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in allreduce operators */ -struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter { +struct AllReduceAttrs : public tvm::BaseAttrsNode { ffi::String op_type; bool in_group; @@ -49,7 +49,7 @@ struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter { }; // struct AllReduceAttrs /*! \brief Attributes used in allgather operators */ -struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter { +struct AllGatherAttrs : public tvm::BaseAttrsNode { int num_workers; bool in_group; @@ -67,7 +67,7 @@ struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter { }; // struct AllGatherAttrs /*! \brief Attributes used in scatter operators */ -struct ScatterCollectiveAttrs : public tvm::AttrsNodeReflAdapter { +struct ScatterCollectiveAttrs : public tvm::BaseAttrsNode { int num_workers; int axis; diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h index c631fd3b4e3d..9a9e453263a0 100644 --- a/include/tvm/relax/attrs/create.h +++ b/include/tvm/relax/attrs/create.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */ -struct InitAttrs : public AttrsNodeReflAdapter { +struct InitAttrs : public BaseAttrsNode { DataType dtype; static void RegisterReflection() { @@ -42,7 +42,7 @@ struct InitAttrs : public AttrsNodeReflAdapter { }; // struct InitAttrs /*! \brief Attributes used in tril and triu operator */ -struct TriluAttrs : public AttrsNodeReflAdapter { +struct TriluAttrs : public BaseAttrsNode { int k; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h index dd07e3b54851..a1870597033e 100644 --- a/include/tvm/relax/attrs/datatype.h +++ b/include/tvm/relax/attrs/datatype.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in astype operator */ -struct AstypeAttrs : public AttrsNodeReflAdapter { +struct AstypeAttrs : public BaseAttrsNode { DataType dtype; static void RegisterReflection() { @@ -41,7 +41,7 @@ struct AstypeAttrs : public AttrsNodeReflAdapter { }; // struct AstypeAttrs. /*! \brief Attributes used in wrap_param operator */ -struct WrapParamAttrs : public AttrsNodeReflAdapter { +struct WrapParamAttrs : public BaseAttrsNode { DataType dtype; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/distributed.h b/include/tvm/relax/attrs/distributed.h index 356a248ba220..cce508ef1d50 100644 --- a/include/tvm/relax/attrs/distributed.h +++ b/include/tvm/relax/attrs/distributed.h @@ -32,7 +32,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for redistribute and annotate_sharding operator */ -struct DistributionAttrs : public AttrsNodeReflAdapter { +struct DistributionAttrs : public BaseAttrsNode { distributed::DeviceMesh device_mesh; distributed::Placement placement; diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h index 52aac58dcde9..8cc5e36734b6 100644 --- a/include/tvm/relax/attrs/image.h +++ b/include/tvm/relax/attrs/image.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in image resize2d operator */ -struct Resize2DAttrs : public AttrsNodeReflAdapter { +struct Resize2DAttrs : public BaseAttrsNode { ffi::Array roi; ffi::String layout; ffi::String method; @@ -79,7 +79,7 @@ struct Resize2DAttrs : public AttrsNodeReflAdapter { }; // struct Resize2dAttrs /*! \brief Attributes used in image resize3d operator */ -struct Resize3DAttrs : public AttrsNodeReflAdapter { +struct Resize3DAttrs : public BaseAttrsNode { ffi::Array roi; ffi::String layout; ffi::String method; @@ -128,7 +128,7 @@ struct Resize3DAttrs : public AttrsNodeReflAdapter { }; // struct Resize3DAttrs /*! \brief Attributes used in image grid_sample operator */ -struct GridSampleAttrs : public AttrsNodeReflAdapter { +struct GridSampleAttrs : public BaseAttrsNode { ffi::String method; ffi::String layout; ffi::String padding_mode; diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h index 0ea7c06bacc0..7b4c446bb80c 100644 --- a/include/tvm/relax/attrs/index.h +++ b/include/tvm/relax/attrs/index.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in take operator */ -struct TakeAttrs : public AttrsNodeReflAdapter { +struct TakeAttrs : public BaseAttrsNode { ffi::Optional axis; ffi::String mode; @@ -45,7 +45,7 @@ struct TakeAttrs : public AttrsNodeReflAdapter { }; // struct TakeAttrs /*! \brief Attributes used in strided_slice operator */ -struct StridedSliceAttrs : public AttrsNodeReflAdapter { +struct StridedSliceAttrs : public BaseAttrsNode { bool assume_inbound; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h index f95d817f1e4d..2627dafcf6b3 100644 --- a/include/tvm/relax/attrs/linear_algebra.h +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for matmul operator */ -struct MatmulAttrs : public AttrsNodeReflAdapter { +struct MatmulAttrs : public BaseAttrsNode { DataType out_dtype; static void RegisterReflection() { @@ -42,7 +42,7 @@ struct MatmulAttrs : public AttrsNodeReflAdapter { }; // struct MatmulAttrs /*! \brief Attributes used in einsum operator */ -struct EinsumAttrs : public AttrsNodeReflAdapter { +struct EinsumAttrs : public BaseAttrsNode { ffi::String subscripts; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index f2ba7af0d9fb..71fb7b0b95ef 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in concat operators */ -struct ConcatAttrs : public AttrsNodeReflAdapter { +struct ConcatAttrs : public BaseAttrsNode { ffi::Optional axis; static void RegisterReflection() { @@ -44,7 +44,7 @@ struct ConcatAttrs : public AttrsNodeReflAdapter { }; // struct ConcatAttrs /*! \brief Attributes used in expand_dims operators */ -struct ExpandDimsAttrs : public AttrsNodeReflAdapter { +struct ExpandDimsAttrs : public BaseAttrsNode { ffi::Array axis; static void RegisterReflection() { @@ -59,7 +59,7 @@ struct ExpandDimsAttrs : public AttrsNodeReflAdapter { }; // struct ExpandDimsAttrs /*! \brief Attributes used in layout_transform operator */ -struct LayoutTransformAttrs : public AttrsNodeReflAdapter { +struct LayoutTransformAttrs : public BaseAttrsNode { tirx::IndexMap index_map; // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future. @@ -97,7 +97,7 @@ struct LayoutTransformAttrs : public AttrsNodeReflAdapter }; // struct LayoutTransformAttrs /*! \brief Attributes used in permute_dims operator */ -struct PermuteDimsAttrs : public AttrsNodeReflAdapter { +struct PermuteDimsAttrs : public BaseAttrsNode { ffi::Optional> axes; static void RegisterReflection() { @@ -110,7 +110,7 @@ struct PermuteDimsAttrs : public AttrsNodeReflAdapter { }; // struct PermuteDimsAttrs /*! \brief Attributes used in split operator */ -struct SplitAttrs : public AttrsNodeReflAdapter { +struct SplitAttrs : public BaseAttrsNode { ffi::ObjectRef indices_or_sections; int axis; @@ -125,7 +125,7 @@ struct SplitAttrs : public AttrsNodeReflAdapter { }; // struct SplitAttrs /*! \brief Attributes used in squeeze operators */ -struct SqueezeAttrs : public AttrsNodeReflAdapter { +struct SqueezeAttrs : public BaseAttrsNode { ffi::Optional> axis; static void RegisterReflection() { @@ -140,7 +140,7 @@ struct SqueezeAttrs : public AttrsNodeReflAdapter { }; // struct SqueezeAttrs /*! \brief Attributes used in stack operators */ -struct StackAttrs : public AttrsNodeReflAdapter { +struct StackAttrs : public BaseAttrsNode { ffi::Optional axis; static void RegisterReflection() { @@ -156,7 +156,7 @@ struct StackAttrs : public AttrsNodeReflAdapter { }; // struct StackAttrs /*! \brief Attributes used in repeat operators */ -struct RepeatAttrs : public AttrsNodeReflAdapter { +struct RepeatAttrs : public BaseAttrsNode { int repeats; ffi::Optional axis; @@ -173,7 +173,7 @@ struct RepeatAttrs : public AttrsNodeReflAdapter { }; // struct RepeatAttrs /*! \brief Attributes used in tile operators */ -struct TileAttrs : public AttrsNodeReflAdapter { +struct TileAttrs : public BaseAttrsNode { ffi::Array repeats; static void RegisterReflection() { @@ -185,20 +185,19 @@ struct TileAttrs : public AttrsNodeReflAdapter { }; // struct TileAttrs /*! \brief Attributes used in flip operators */ -struct FlipAttrs : public AttrsNodeReflAdapter { - Integer axis; +struct FlipAttrs : public BaseAttrsNode { + int64_t axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("axis", &FlipAttrs::axis, - "The axis along which to flip over.", - refl::DefaultValue(NullValue())); + "The axis along which to flip over."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, BaseAttrsNode); }; // struct FlipAttrs /*! \brief Attributes used in gather_elements operators */ -struct GatherElementsAttrs : public AttrsNodeReflAdapter { +struct GatherElementsAttrs : public BaseAttrsNode { Integer axis; static void RegisterReflection() { @@ -212,7 +211,7 @@ struct GatherElementsAttrs : public AttrsNodeReflAdapter { }; // struct GatherElementsAttrs /*! \brief Attributes used in gather_nd operators */ -struct GatherNDAttrs : public AttrsNodeReflAdapter { +struct GatherNDAttrs : public BaseAttrsNode { Integer batch_dims; static void RegisterReflection() { @@ -224,7 +223,7 @@ struct GatherNDAttrs : public AttrsNodeReflAdapter { }; // struct GatherNDAttrs /*! \brief Attributes used in index_put operator */ -struct IndexPutAttrs : public AttrsNodeReflAdapter { +struct IndexPutAttrs : public BaseAttrsNode { bool accumulate; static void RegisterReflection() { @@ -240,7 +239,7 @@ struct IndexPutAttrs : public AttrsNodeReflAdapter { }; // struct IndexPutAttrs /*! \brief Attribute used in meshgrid operator */ -struct MeshgridAttrs : public AttrsNodeReflAdapter { +struct MeshgridAttrs : public BaseAttrsNode { ffi::Optional indexing; static void RegisterReflection() { @@ -252,7 +251,7 @@ struct MeshgridAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in scatter_elements operators */ -struct ScatterElementsAttrs : public AttrsNodeReflAdapter { +struct ScatterElementsAttrs : public BaseAttrsNode { Integer axis; ffi::String reduction; @@ -271,7 +270,7 @@ struct ScatterElementsAttrs : public AttrsNodeReflAdapter }; // struct ScatterElementsAttrs /*! \brief Attributes used in scatter_nd operators */ -struct ScatterNDAttrs : public AttrsNodeReflAdapter { +struct ScatterNDAttrs : public BaseAttrsNode { ffi::String reduction; static void RegisterReflection() { @@ -286,7 +285,7 @@ struct ScatterNDAttrs : public AttrsNodeReflAdapter { }; // struct ScatterNDAttrs /*! \brief Attributes used in slice_scatter operator */ -struct SliceScatterAttrs : public AttrsNodeReflAdapter { +struct SliceScatterAttrs : public BaseAttrsNode { int axis; static void RegisterReflection() { @@ -300,7 +299,7 @@ struct SliceScatterAttrs : public AttrsNodeReflAdapter { }; // struct SliceScatterAttrs /*! \brief Attributes used in one_hot operator */ -struct OneHotAttrs : public AttrsNodeReflAdapter { +struct OneHotAttrs : public BaseAttrsNode { int depth; int axis; diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 45abeb9d5b7e..bfc85dfd5a13 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in Conv1d operator */ -struct Conv1DAttrs : public AttrsNodeReflAdapter { +struct Conv1DAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -74,7 +74,7 @@ struct Conv1DAttrs : public AttrsNodeReflAdapter { }; // struct Conv1dAttrs /*! \brief Attributes used in Conv2d operator */ -struct Conv2DAttrs : public AttrsNodeReflAdapter { +struct Conv2DAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -120,7 +120,7 @@ struct Conv2DAttrs : public AttrsNodeReflAdapter { }; // struct Conv2dAttrs /*! \brief Attributes used in Conv3d operator */ -struct Conv3DAttrs : public AttrsNodeReflAdapter { +struct Conv3DAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -168,7 +168,7 @@ struct Conv3DAttrs : public AttrsNodeReflAdapter { }; // struct Conv3dAttrs /*! \brief Attributes used in Conv1DTranspose operator */ -struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter { +struct Conv1DTransposeAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -217,7 +217,7 @@ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter }; // struct Conv1DTransposeAttrs /*! \brief Attributes used in Conv2d operator */ -struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter { +struct Conv2DTransposeAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -268,7 +268,7 @@ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter }; // struct Conv2DTransposeAttrs /*! \brief Attributes used in Conv3dTranspose operator */ -struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter { +struct Conv3DTransposeAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -321,7 +321,7 @@ struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter }; // struct Conv3DTransposeAttrs /*! \brief Attributes used in max_pool1d and avg_pool1d operator */ -struct Pool1DAttrs : public AttrsNodeReflAdapter { +struct Pool1DAttrs : public BaseAttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -362,7 +362,7 @@ struct Pool1DAttrs : public AttrsNodeReflAdapter { }; // struct Pool1dAttrs /*! \brief Attributes used in max_pool2d and avg_pool2d operator */ -struct Pool2DAttrs : public AttrsNodeReflAdapter { +struct Pool2DAttrs : public BaseAttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -405,7 +405,7 @@ struct Pool2DAttrs : public AttrsNodeReflAdapter { }; // struct Pool2dAttrs /*! \brief Attributes used in max_pool3d and avg_pool3d operator */ -struct Pool3DAttrs : public AttrsNodeReflAdapter { +struct Pool3DAttrs : public BaseAttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -448,7 +448,7 @@ struct Pool3DAttrs : public AttrsNodeReflAdapter { }; // struct Pool3dAttrs /*! \brief Attributes for 1d adaptive pool operator */ -struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { +struct AdaptivePool1DAttrs : public BaseAttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -473,7 +473,7 @@ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { }; // struct AdaptivePool1DAttrs /*! \brief Attributes for 2d adaptive pool operator */ -struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { +struct AdaptivePool2DAttrs : public BaseAttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -498,7 +498,7 @@ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { }; // struct AdaptivePool2DAttrs /*! \brief Attributes for 3d adaptive pool operator */ -struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter { +struct AdaptivePool3DAttrs : public BaseAttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -523,7 +523,7 @@ struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter { }; // struct AdaptivePool3DAttrs /*! \brief Attributes used in softmax operators */ -struct SoftmaxAttrs : public AttrsNodeReflAdapter { +struct SoftmaxAttrs : public BaseAttrsNode { int axis; static void RegisterReflection() { @@ -535,7 +535,7 @@ struct SoftmaxAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in softmax operators */ -struct LeakyReluAttrs : public AttrsNodeReflAdapter { +struct LeakyReluAttrs : public BaseAttrsNode { double alpha; static void RegisterReflection() { @@ -547,7 +547,7 @@ struct LeakyReluAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in softplus operators */ -struct SoftplusAttrs : public AttrsNodeReflAdapter { +struct SoftplusAttrs : public BaseAttrsNode { double beta; double threshold; @@ -563,7 +563,7 @@ struct SoftplusAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in PReLU operator */ -struct PReluAttrs : public AttrsNodeReflAdapter { +struct PReluAttrs : public BaseAttrsNode { int axis; static void RegisterReflection() { @@ -575,7 +575,7 @@ struct PReluAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in batch_norm operator */ -struct BatchNormAttrs : public AttrsNodeReflAdapter { +struct BatchNormAttrs : public BaseAttrsNode { int axis; double epsilon; bool center; @@ -602,7 +602,7 @@ struct BatchNormAttrs : public AttrsNodeReflAdapter { }; // struct BatchNormAttrs /*! \brief Attributes used in layer_norm operator */ -struct LayerNormAttrs : public AttrsNodeReflAdapter { +struct LayerNormAttrs : public BaseAttrsNode { ffi::Array axes; double epsilon; bool center; @@ -624,7 +624,7 @@ struct LayerNormAttrs : public AttrsNodeReflAdapter { }; // struct LayerNormAttrs /*! \brief Attributes used in group_norm operator */ -struct GroupNormAttrs : public AttrsNodeReflAdapter { +struct GroupNormAttrs : public BaseAttrsNode { int num_groups; int channel_axis; ffi::Array axes; @@ -653,7 +653,7 @@ struct GroupNormAttrs : public AttrsNodeReflAdapter { }; // struct GroupNormAttrs /*! \brief Attributes used in instance_norm operator */ -struct InstanceNormAttrs : public AttrsNodeReflAdapter { +struct InstanceNormAttrs : public BaseAttrsNode { int channel_axis; ffi::Array axes; double epsilon; @@ -679,7 +679,7 @@ struct InstanceNormAttrs : public AttrsNodeReflAdapter { }; // struct InstanceNormAttrs /*! \brief Attributes used in rms_norm operator */ -struct RMSNormAttrs : public AttrsNodeReflAdapter { +struct RMSNormAttrs : public BaseAttrsNode { ffi::Array axes; double epsilon; @@ -695,7 +695,7 @@ struct RMSNormAttrs : public AttrsNodeReflAdapter { }; // struct RMSNormAttrs /*! \brief Attributes used in nll_loss operator */ -struct NLLLossAttrs : public AttrsNodeReflAdapter { +struct NLLLossAttrs : public BaseAttrsNode { ffi::String reduction; int ignore_index; @@ -712,7 +712,7 @@ struct NLLLossAttrs : public AttrsNodeReflAdapter { }; // struct NLLLossAttrs /*! \brief Attributes used in dropout operator */ -struct DropoutAttrs : public AttrsNodeReflAdapter { +struct DropoutAttrs : public BaseAttrsNode { double rate; static void RegisterReflection() { @@ -725,7 +725,7 @@ struct DropoutAttrs : public AttrsNodeReflAdapter { }; // struct DropoutAttrs /*! \brief Attributes used in Attention operator */ -struct AttentionAttrs : public AttrsNodeReflAdapter { +struct AttentionAttrs : public BaseAttrsNode { ffi::Optional scale; ffi::Optional causal_mask; ffi::Optional window_size; @@ -745,7 +745,7 @@ struct AttentionAttrs : public AttrsNodeReflAdapter { }; // struct AttentionAttrs /*! \brief Attributes used for the padding operator */ -struct PadAttrs : public AttrsNodeReflAdapter { +struct PadAttrs : public BaseAttrsNode { ffi::Array pad_width; double pad_value = 0.0; tvm::ffi::String pad_mode; @@ -768,7 +768,7 @@ struct PadAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used for the pixel shuffle operator */ -struct PixelShuffleAttrs : public AttrsNodeReflAdapter { +struct PixelShuffleAttrs : public BaseAttrsNode { int upscale_factor; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 54640901ff53..79e00d590abe 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in call_tir_with_grad */ -struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter { +struct CallTIRWithGradAttrs : public BaseAttrsNode { ffi::String te_grad_name; ffi::Map te_grad_kwargs; @@ -49,7 +49,7 @@ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter }; // struct CallTIRAttrs /*! \brief Attributes used in call_tir_inplace */ -struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { +struct CallTIRInplaceAttrs : public BaseAttrsNode { /*! * \brief Indices that describe which input corresponds to which output. * @@ -69,7 +69,7 @@ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { }; // struct CallTIRInplaceAttrs /*! \brief Attributes used in call_inplace_packed */ -struct CallInplacePackedAttrs : public AttrsNodeReflAdapter { +struct CallInplacePackedAttrs : public BaseAttrsNode { /*! * \brief Indices that describe which input corresponds to which output. * @@ -89,7 +89,7 @@ struct CallInplacePackedAttrs : public AttrsNodeReflAdapter { +struct ToVDeviceAttrs : public BaseAttrsNode { VDevice dst_vdevice; static void RegisterReflection() { @@ -101,7 +101,7 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter { }; // struct ToVDeviceAttrs /*! \brief Attributes used in hint_on_device */ -struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { +struct HintOnDeviceAttrs : public BaseAttrsNode { int32_t device_type; int32_t index; MemoryScope memory_scope; diff --git a/include/tvm/relax/attrs/qdq.h b/include/tvm/relax/attrs/qdq.h index ffb554994f98..08bc054dc54f 100644 --- a/include/tvm/relax/attrs/qdq.h +++ b/include/tvm/relax/attrs/qdq.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for relax.quantize/relax.dequantize operator */ -struct QuantizeAttrs : public AttrsNodeReflAdapter { +struct QuantizeAttrs : public BaseAttrsNode { DataType out_dtype; int axis; diff --git a/include/tvm/relax/attrs/sampling.h b/include/tvm/relax/attrs/sampling.h index 53fd3a140497..2d7421cc20e8 100644 --- a/include/tvm/relax/attrs/sampling.h +++ b/include/tvm/relax/attrs/sampling.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in multinomial_from_uniform operator */ -struct MultinomialFromUniformAttrs : public AttrsNodeReflAdapter { +struct MultinomialFromUniformAttrs : public BaseAttrsNode { DataType dtype; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h index 32327c160d1d..015e5d8edc1c 100644 --- a/include/tvm/relax/attrs/search.h +++ b/include/tvm/relax/attrs/search.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for search operators */ -struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter { +struct ArgmaxArgminAttrs : public BaseAttrsNode { ffi::Optional axis; bool keepdims; @@ -49,7 +49,7 @@ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter { }; // struct ArgmaxArgminAttrs /*! \brief Attributes for bucketize operator */ -struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter { +struct BucketizeAttrs : public tvm::BaseAttrsNode { bool out_int32; bool right; diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index 354b77047272..e32d47239f35 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in sort operator */ -struct SortAttrs : public AttrsNodeReflAdapter { +struct SortAttrs : public BaseAttrsNode { int axis; bool descending; @@ -51,7 +51,7 @@ struct SortAttrs : public AttrsNodeReflAdapter { }; // struct SortAttrs /*! \brief Attributes used in argsort operator */ -struct ArgsortAttrs : public AttrsNodeReflAdapter { +struct ArgsortAttrs : public BaseAttrsNode { int axis; bool descending; DataType dtype; @@ -68,13 +68,13 @@ struct ArgsortAttrs : public AttrsNodeReflAdapter { "If it is not specified, it defaults to the ascending order.", refl::DefaultValue(false)) .def_ro("dtype", &ArgsortAttrs::dtype, "DType of the output indices.", - refl::DefaultValue(NullValue())); + refl::DefaultValue(DataType::Void())); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs, BaseAttrsNode); }; // struct ArgsortAttrs /*! \brief Attributes used in topk operator */ -struct TopKAttrs : public AttrsNodeReflAdapter { +struct TopKAttrs : public BaseAttrsNode { int k; int axis; bool largest; @@ -98,7 +98,7 @@ struct TopKAttrs : public AttrsNodeReflAdapter { "By default, return the largest k elements.", refl::DefaultValue(true)) .def_ro("dtype", &TopKAttrs::dtype, "Data type of the output indices.", - refl::DefaultValue(NullValue())); + refl::DefaultValue(DataType::Void())); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs, BaseAttrsNode); }; // struct TopKAttrs diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h index 433524116d3c..367869f1ab11 100644 --- a/include/tvm/relax/attrs/statistical.h +++ b/include/tvm/relax/attrs/statistical.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for statistical operators */ -struct StatisticalAttrs : public AttrsNodeReflAdapter { +struct StatisticalAttrs : public BaseAttrsNode { ffi::Optional> axis; bool keepdims; @@ -49,7 +49,7 @@ struct StatisticalAttrs : public AttrsNodeReflAdapter { }; // struct StatisticalAttrs /*! \brief Attributes used in scan operators like cumsum, cumprod */ -struct ScanopAttrs : public AttrsNodeReflAdapter { +struct ScanopAttrs : public BaseAttrsNode { ffi::Optional axis; DataType dtype; Bool exclusive = Bool(false); diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index 55ed162674e2..37ec77cbbff6 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -32,8 +32,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in AllClassNonMaximumSuppression operator */ -struct AllClassNonMaximumSuppressionAttrs - : public AttrsNodeReflAdapter { +struct AllClassNonMaximumSuppressionAttrs : public BaseAttrsNode { ffi::String output_format; static void RegisterReflection() { @@ -48,7 +47,7 @@ struct AllClassNonMaximumSuppressionAttrs }; // struct AllClassNonMaximumSuppressionAttrs /*! \brief Attributes used in ROIAlign operator */ -struct ROIAlignAttrs : public AttrsNodeReflAdapter { +struct ROIAlignAttrs : public BaseAttrsNode { ffi::Array pooled_size; double spatial_scale; int sample_ratio; @@ -73,7 +72,7 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter { }; // struct ROIAlignAttrs /*! \brief Attributes used in ROIPool operator */ -struct ROIPoolAttrs : public AttrsNodeReflAdapter { +struct ROIPoolAttrs : public BaseAttrsNode { ffi::Array pooled_size; double spatial_scale; ffi::String layout; @@ -90,7 +89,7 @@ struct ROIPoolAttrs : public AttrsNodeReflAdapter { }; // struct ROIPoolAttrs /*! \brief Attributes used in GetValidCounts operator */ -struct GetValidCountsAttrs : public AttrsNodeReflAdapter { +struct GetValidCountsAttrs : public BaseAttrsNode { double score_threshold; int id_index; int score_index; @@ -110,7 +109,7 @@ struct GetValidCountsAttrs : public AttrsNodeReflAdapter { }; // struct GetValidCountsAttrs /*! \brief Attributes used in NonMaximumSuppression operator */ -struct NonMaximumSuppressionAttrs : public AttrsNodeReflAdapter { +struct NonMaximumSuppressionAttrs : public BaseAttrsNode { int max_output_size; double iou_threshold; bool force_suppress; @@ -154,7 +153,7 @@ struct NonMaximumSuppressionAttrs : public AttrsNodeReflAdapter { +struct MultiboxTransformLocAttrs : public BaseAttrsNode { bool clip; double threshold; ffi::Array variances; diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index 5ff282adb68b..79475262c4a4 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -169,7 +169,7 @@ constexpr int kInvalidDeviceType = -1; * These operations are needed during device planning. */ -class VirtualDeviceNode : public AttrsNodeReflAdapter { +class VirtualDeviceNode : public BaseAttrsNode { private: /*! * \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 3ce70fc545fb..21fd7b565c4e 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -441,7 +441,7 @@ def flip(data, axis): The input data to the operator. axis: int - axis to flip on + The axis along which to flip over. Returns ------- diff --git a/python/tvm/s_tir/transform/transform.py b/python/tvm/s_tir/transform/transform.py index e8d14171b331..af4ec493cc14 100644 --- a/python/tvm/s_tir/transform/transform.py +++ b/python/tvm/s_tir/transform/transform.py @@ -18,7 +18,6 @@ # pylint: disable=invalid-name, unsupported-binary-operation from ... import ffi as _ffi -from ... import ir as _ir from . import _ffi_api @@ -213,7 +212,7 @@ def AnnotateIrregularLoop(): @_ffi.register_object("s_tir.transform.LoopPartitionConfig") -class LoopPartitionConfig(_ir.Attrs): +class LoopPartitionConfig(_ffi.Object): """Config for loop partition pass""" @@ -240,7 +239,7 @@ def InjectVirtualThread(): @_ffi.register_object("s_tir.transform.InjectDoubleBufferConfig") -class InjectDoubleBufferConfig(_ir.Attrs): +class InjectDoubleBufferConfig(_ffi.Object): """Config for inject double buffer pass""" diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index 8082d864c1e9..fbf07b5f4897 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -21,7 +21,6 @@ from collections.abc import Callable from ... import ffi as _ffi -from ... import ir as _ir from . import _ffi_api from . import function_pass as _fpass @@ -107,7 +106,7 @@ def PointerValueTypeRewrite(): @_ffi.register_object("tirx.transform.UnrollLoopConfig") -class UnrollLoopConfig(_ir.Attrs): +class UnrollLoopConfig(_ffi.Object): """Config for unroll loop pass""" @@ -125,7 +124,7 @@ def UnrollLoop(): @_ffi.register_object("tirx.transform.RemoveNoOpConfig") -class RemoveNoOpConfig(_ir.Attrs): +class RemoveNoOpConfig(_ffi.Object): """Config for remove no op pass""" @@ -212,7 +211,7 @@ def CommonSubexprElim(): @_ffi.register_object("tirx.transform.SimplifyConfig") -class SimplifyConfig(_ir.Attrs): +class SimplifyConfig(_ffi.Object): """Config for simplify pass""" @@ -429,7 +428,7 @@ def VerifyMemory(): @_ffi.register_object("s_tir.transform.HoistIfThenElseConfig") -class HoistIfThenElseConfig(_ir.Attrs): +class HoistIfThenElseConfig(_ffi.Object): """Config for hoist if then else pass""" @@ -483,7 +482,7 @@ class HoistedLetBindings(enum.Flag): @_ffi.register_object("s_tir.transform.HoistExpressionConfig") -class HoistExpressionConfig(_ir.Attrs): +class HoistExpressionConfig(_ffi.Object): """Config for hoist expression pass""" diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index cfe269e4eba6..e7d9b9082809 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -53,14 +53,6 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { return attrs; } -void DictAttrsNode::InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) { - for (int i = 0; i < args.size(); i += 2) { - ffi::String key = args[i].cast(); - ffi::AnyView val = args[i + 1]; - dict.Set(key, val); - } -} - DictAttrs::DictAttrs(ffi::Map dict) { ffi::ObjectPtr n = ffi::make_object(); n->dict = std::move(dict); diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index eaa57f8315e4..dd71e8a68a51 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -41,7 +41,7 @@ namespace relax { namespace contrib { /*! \brief Attributes to store the compiler options for OpenCLML. */ -struct OpenCLMLCompilerConfigNode : public AttrsNodeReflAdapter { +struct OpenCLMLCompilerConfigNode : public ffi::Object { Integer clml_version; static void RegisterReflection() { @@ -51,12 +51,12 @@ struct OpenCLMLCompilerConfigNode : public AttrsNodeReflAdapter { +struct TensorRTCompilerConfigNode : public ffi::Object { ffi::Array tensorrt_version; bool use_implicit_batch; size_t max_workspace_size; @@ -72,12 +72,12 @@ struct TensorRTCompilerConfigNode : public AttrsNodeReflAdapter(); - attrs->axis = std::move(axis); + attrs->axis = axis; static const Op& op = Op::Get("relax.flip"); return Call(op, {std::move(data)}, Attrs{attrs}, {}); } @@ -2043,7 +2043,7 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { } TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - int axis = attrs->axis.IntValue(); + int axis = static_cast(attrs->axis); if (!data_sinfo->IsUnknownNdim()) { int ndim = data_sinfo->ndim; if (axis < -ndim || axis >= ndim) { @@ -2073,7 +2073,7 @@ InferLayoutOutput InferLayoutFlip( existing_layout = LayoutDecision(InitialLayout(ndim)); } - int axis = attrs->axis.IntValue(); + int axis = static_cast(attrs->axis); if (axis < 0) { axis += ndim; } @@ -2082,7 +2082,7 @@ InferLayoutOutput InferLayoutFlip( TVM_FFI_ICHECK_GE(new_axis, 0) << "Failed to find transformed axis"; ffi::ObjectPtr new_attrs = ffi::make_object(*attrs); - new_attrs->axis = Integer(new_axis); + new_attrs->axis = static_cast(new_axis); return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs)); } diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 260d27f1ef1d..a6efffff4673 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -179,7 +179,7 @@ Expr tile(Expr data, ffi::Array repeats); * \param axis The axis to flip on * \return The computed result. */ -Expr flip(Expr data, Integer axis); +Expr flip(Expr data, int64_t axis); /*! * \brief Gather elements from a tensor using indices. diff --git a/src/s_tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc index 21f5454040a6..94402c44d75e 100644 --- a/src/s_tir/schedule/concrete_schedule.cc +++ b/src/s_tir/schedule/concrete_schedule.cc @@ -35,7 +35,7 @@ Schedule Schedule::Concrete(IRModule mod, LinearCongruentialEngine::TRandState s n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->Seed(seed); - GlobalVar gv = NullValue(); + GlobalVar gv; if (FindEntryFunc(mod, &gv) != nullptr) { n->func_working_on_ = gv; } else { @@ -316,7 +316,7 @@ SBlockRV ConcreteScheduleNode::GetSBlock(const ffi::String& name, IRModule mod_; ffi::Array blocks_; }; - GlobalVar gv = NullValue(); + GlobalVar gv; if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); } else if (func_working_on_.has_value()) { diff --git a/src/s_tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc index e12cdd69de3f..6357e1ae19d3 100644 --- a/src/s_tir/schedule/traced_schedule.cc +++ b/src/s_tir/schedule/traced_schedule.cc @@ -31,7 +31,7 @@ Schedule Schedule::Traced(IRModule mod, LinearCongruentialEngine::TRandState see n->analyzer_ = std::make_unique(); n->trace_ = Trace(); n->Seed(seed); - GlobalVar gv = NullValue(); + GlobalVar gv; if (FindEntryFunc(mod, &gv) != nullptr) { n->func_working_on_ = gv; } else { @@ -118,7 +118,7 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const SBlockRV& block_rv, SBlockRV TracedScheduleNode::GetSBlock(const ffi::String& name, const ffi::Optional& func_name) { - GlobalVar gv = NullValue(); + GlobalVar gv; if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); } else if (func_working_on_.defined()) { diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index ac3987b6a09a..dbe389e84a63 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -58,7 +58,7 @@ enum class HoistedLetBindings : int { kLetExpr = (1 << 2), }; -struct HoistExpressionConfigNode : public AttrsNodeReflAdapter { +struct HoistExpressionConfigNode : public ffi::Object { int hoisted_conditionals; int hoisted_let_bindings; @@ -87,7 +87,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter(); @@ -95,7 +95,7 @@ class HoistExpressionConfig : public Attrs { node->hoisted_let_bindings = hoisted_let_bindings; data_ = std::move(node); } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistExpressionConfig, Attrs, + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistExpressionConfig, ffi::ObjectRef, HoistExpressionConfigNode); }; @@ -103,7 +103,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { HoistExpressionConfigNode::RegisterReflection(); } TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.HoistExpression", HoistExpressionConfig); -struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter { +struct HoistIfThenElseConfigNode : public ffi::Object { bool support_block_scope_hoisting; static void RegisterReflection() { @@ -116,9 +116,9 @@ struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter { +struct InjectDoubleBufferConfigNode : public ffi::Object { int split_loop; static void RegisterReflection() { @@ -46,12 +46,12 @@ struct InjectDoubleBufferConfigNode : public AttrsNodeReflAdapter { +struct LoopPartitionConfigNode : public ffi::Object { bool partition_const_loop; bool no_unroll_loop_with_extent_one; bool unroll_loop_with_partition_hint_no_interval; @@ -64,14 +64,14 @@ struct LoopPartitionConfigNode : public AttrsNodeReflAdapter(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex, + IterVar(Range(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex, "threadIdx." + dim_index), /*annotations=*/{}, /*step=*/std::nullopt); diff --git a/src/s_tir/transform/storage_access.h b/src/s_tir/transform/storage_access.h index 2aa3850774f9..d85dc5a3c3ae 100644 --- a/src/s_tir/transform/storage_access.h +++ b/src/s_tir/transform/storage_access.h @@ -59,7 +59,7 @@ class StorageAccessVisitor : public StmtExprVisitor { /*! \brief The thread index that access this entry */ ffi::Array threads; /*! \brief The buffer variable, if any */ - Var buffer = NullValue(); + Var buffer = Var(ffi::ObjectPtr(nullptr)); /*! \brief The access data type */ DataType dtype; /*! \brief The touched access range diff --git a/src/s_tir/transform/unify_thread_binding.cc b/src/s_tir/transform/unify_thread_binding.cc index 3ee465223ab8..85333b6efcaf 100644 --- a/src/s_tir/transform/unify_thread_binding.cc +++ b/src/s_tir/transform/unify_thread_binding.cc @@ -159,8 +159,7 @@ class ThreadBindingUnifier : public StmtExprMutator { // necessary for unit tests. result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent, ForKind::kThreadBinding, result, - IterVar(NullValue(), Var(""), IterVarType::kThreadIndex, - thread_binding->thread_tag), + IterVar(Range(), Var(""), IterVarType::kThreadIndex, thread_binding->thread_tag), {}, std::nullopt); launch_threads_.pop_back(); } diff --git a/src/tirx/analysis/stmt_finding.cc b/src/tirx/analysis/stmt_finding.cc index 0ba6146213cc..6dc3d07b4f07 100644 --- a/src/tirx/analysis/stmt_finding.cc +++ b/src/tirx/analysis/stmt_finding.cc @@ -24,7 +24,7 @@ namespace tvm { namespace tirx { const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) { - GlobalVar result = NullValue(); + GlobalVar result; // Priority 1: PrimFunc marked as `tirx::attr::kIsEntryFunc` int num_prim_func = 0; const tirx::PrimFuncNode* main_func = nullptr; diff --git a/src/tirx/script/builder/frame.cc b/src/tirx/script/builder/frame.cc index 5e971d736113..e57b794cf31b 100644 --- a/src/tirx/script/builder/frame.cc +++ b/src/tirx/script/builder/frame.cc @@ -145,7 +145,7 @@ void PrimFuncFrameNode::ExitWithScope() { /*body=*/body, /*ret_type=*/ret_type.value_or(TupleType::Empty()), /*buffer_map=*/effective_buffer_map, - /*attrs=*/attrs.defined() ? DictAttrs(attrs) : NullValue(), + /*attrs=*/attrs.defined() ? DictAttrs(attrs) : DictAttrs(), /*span=*/tvm::Span()); func = tvm::tirx::ScriptComplete(func, effective_root_alloc_buffers, s_tir); IRBuilder builder = IRBuilder::Current(); diff --git a/src/tirx/transform/remove_no_op.cc b/src/tirx/transform/remove_no_op.cc index fcc7519334d0..4bdb5c083c01 100644 --- a/src/tirx/transform/remove_no_op.cc +++ b/src/tirx/transform/remove_no_op.cc @@ -44,7 +44,7 @@ namespace tvm { namespace tirx { -struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter { +struct RemoveNoOpConfigNode : public ffi::Object { bool use_dataflow_analysis; int64_t max_simplification_steps; bool ignore_profiler_call; @@ -65,12 +65,13 @@ struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter "If true, profiler calls are rendered as no-ops.", refl::DefaultValue(false)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.RemoveNoOpConfig", RemoveNoOpConfigNode, - BaseAttrsNode); + ffi::Object); }; -class RemoveNoOpConfig : public Attrs { +class RemoveNoOpConfig : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RemoveNoOpConfig, ffi::ObjectRef, + RemoveNoOpConfigNode); }; TVM_FFI_STATIC_INIT_BLOCK() { RemoveNoOpConfigNode::RegisterReflection(); } diff --git a/src/tirx/transform/simplify.cc b/src/tirx/transform/simplify.cc index f193fb502da9..bf80ad00a455 100644 --- a/src/tirx/transform/simplify.cc +++ b/src/tirx/transform/simplify.cc @@ -44,7 +44,7 @@ namespace arith { using namespace tirx; -struct SimplifyConfigNode : public AttrsNodeReflAdapter { +struct SimplifyConfigNode : public ffi::Object { bool transitively_prove_inequalities; bool propagate_knowns_to_prove_conditional; bool propagate_knowns_to_simplify_expressions; @@ -78,7 +78,7 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { refl::DefaultValue(false)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.SimplifyConfig", SimplifyConfigNode, - BaseAttrsNode); + ffi::Object); RewriteSimplifier::Extension GetEnabledExtensions() const { RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; @@ -97,11 +97,15 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { } }; -class SimplifyConfig : public Attrs { +class SimplifyConfig : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs, SimplifyConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, ffi::ObjectRef, SimplifyConfigNode); }; +static SimplifyConfig MakeDefaultSimplifyConfig() { + return AttrsWithDefaultValues(); +} + TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); } TVM_REGISTER_PASS_CONFIG_OPTION("tirx.Simplify", SimplifyConfig); @@ -110,7 +114,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { public: static PrimFunc Apply(PrimFunc func, Analyzer* analyzer, ffi::Optional config_opt = std::nullopt) { - auto config = config_opt.value_or(AttrsWithDefaultValues()); + auto config = config_opt.value_or(MakeDefaultSimplifyConfig()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); std::optional touch_pattern = std::nullopt; diff --git a/src/tirx/transform/unroll_loop.cc b/src/tirx/transform/unroll_loop.cc index 3aea9ddd04c9..faf1ec2d677d 100644 --- a/src/tirx/transform/unroll_loop.cc +++ b/src/tirx/transform/unroll_loop.cc @@ -39,7 +39,7 @@ namespace tvm { namespace tirx { -struct UnrollLoopConfigNode : public AttrsNodeReflAdapter { +struct UnrollLoopConfigNode : public ffi::Object { int auto_max_step; int auto_max_depth; int auto_max_extent; @@ -64,12 +64,13 @@ struct UnrollLoopConfigNode : public AttrsNodeReflAdapter "Whether to always unroll local access", refl::DefaultValue(false)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.UnrollLoopConfig", UnrollLoopConfigNode, - BaseAttrsNode); + ffi::Object); }; -class UnrollLoopConfig : public Attrs { +class UnrollLoopConfig : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, Attrs, UnrollLoopConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, ffi::ObjectRef, + UnrollLoopConfigNode); }; TVM_FFI_STATIC_INIT_BLOCK() { UnrollLoopConfigNode::RegisterReflection(); } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 0743c3db686e..e7a1715cc7bf 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -338,7 +338,7 @@ TEST(IRF, Substitute) { /*dtype=*/DataType::Float(32), /*shape=*/{n}, /*strides=*/{}, - /*elem_offset=*/NullValue(), + /*elem_offset=*/PrimExpr(), /*name=*/"buf", /*data_alignment=*/1, /*offset_factor=*/1,