From ff4060f0c6a4edf310e3f9880c3f1f7747fc3b9c Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 20:12:21 +0000 Subject: [PATCH 1/8] [REFACTOR][IR] Replace NullValue() call sites with default construction Optional is now fully supported throughout the FFI reflection, so NullValue() sentinels are no longer needed; this commit replaces every remaining call site with direct default-construction (T() / Range() / GlobalVar() / DataType::Void()) ahead of removing the NullValue declaration. Also migrates FlipAttrs::axis from Integer (boxed nullable IntImm) to ffi::Optional per the user's directive to prefer Optional over NullValue(). --- include/tvm/relax/attrs/manipulate.h | 5 ++--- include/tvm/relax/attrs/sorting.h | 4 ++-- src/relax/op/tensor/manipulate.cc | 8 ++++---- src/relax/op/tensor/manipulate.h | 2 +- src/s_tir/schedule/concrete_schedule.cc | 4 ++-- src/s_tir/schedule/traced_schedule.cc | 4 ++-- src/s_tir/transform/lower_cross_thread_reduction.cc | 2 +- src/s_tir/transform/storage_access.h | 2 +- src/s_tir/transform/unify_thread_binding.cc | 2 +- src/tirx/analysis/stmt_finding.cc | 2 +- src/tirx/script/builder/frame.cc | 2 +- tests/cpp/ir_functor_test.cc | 2 +- 12 files changed, 19 insertions(+), 20 deletions(-) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index f2ba7af0d9fb..cc8d28ec2639 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -186,13 +186,12 @@ struct TileAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in flip operators */ struct FlipAttrs : public AttrsNodeReflAdapter { - Integer axis; + ffi::Optional axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("axis", &FlipAttrs::axis, - "The axis along which to flip over.", - refl::DefaultValue(NullValue())); + "The axis along which to flip over."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, BaseAttrsNode); }; // struct FlipAttrs diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index 354b77047272..b77fa2ecc72c 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -68,7 +68,7 @@ struct ArgsortAttrs : public AttrsNodeReflAdapter { "If it is not specified, it defaults to the ascending order.", refl::DefaultValue(false)) .def_ro("dtype", &ArgsortAttrs::dtype, "DType of the output indices.", - refl::DefaultValue(NullValue())); + refl::DefaultValue(DataType::Void())); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs, BaseAttrsNode); }; // struct ArgsortAttrs @@ -98,7 +98,7 @@ struct TopKAttrs : public AttrsNodeReflAdapter { "By default, return the largest k elements.", refl::DefaultValue(true)) .def_ro("dtype", &TopKAttrs::dtype, "Data type of the output indices.", - refl::DefaultValue(NullValue())); + refl::DefaultValue(DataType::Void())); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs, BaseAttrsNode); }; // struct TopKAttrs diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 461faf3fba99..98b7405bc030 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -2025,7 +2025,7 @@ TVM_REGISTER_OP("relax.tile") /* relax.flip */ -Expr flip(Expr data, Integer axis) { +Expr flip(Expr data, ffi::Optional axis) { auto attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.flip"); @@ -2043,7 +2043,7 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { } TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - int axis = attrs->axis.IntValue(); + int axis = static_cast(attrs->axis.value()); if (!data_sinfo->IsUnknownNdim()) { int ndim = data_sinfo->ndim; if (axis < -ndim || axis >= ndim) { @@ -2073,7 +2073,7 @@ InferLayoutOutput InferLayoutFlip( existing_layout = LayoutDecision(InitialLayout(ndim)); } - int axis = attrs->axis.IntValue(); + int axis = static_cast(attrs->axis.value()); if (axis < 0) { axis += ndim; } @@ -2082,7 +2082,7 @@ InferLayoutOutput InferLayoutFlip( TVM_FFI_ICHECK_GE(new_axis, 0) << "Failed to find transformed axis"; ffi::ObjectPtr new_attrs = ffi::make_object(*attrs); - new_attrs->axis = Integer(new_axis); + new_attrs->axis = static_cast(new_axis); return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs)); } diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 260d27f1ef1d..0f278a0a008a 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -179,7 +179,7 @@ Expr tile(Expr data, ffi::Array repeats); * \param axis The axis to flip on * \return The computed result. */ -Expr flip(Expr data, Integer axis); +Expr flip(Expr data, ffi::Optional axis); /*! * \brief Gather elements from a tensor using indices. diff --git a/src/s_tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc index 21f5454040a6..94402c44d75e 100644 --- a/src/s_tir/schedule/concrete_schedule.cc +++ b/src/s_tir/schedule/concrete_schedule.cc @@ -35,7 +35,7 @@ Schedule Schedule::Concrete(IRModule mod, LinearCongruentialEngine::TRandState s n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->Seed(seed); - GlobalVar gv = NullValue(); + GlobalVar gv; if (FindEntryFunc(mod, &gv) != nullptr) { n->func_working_on_ = gv; } else { @@ -316,7 +316,7 @@ SBlockRV ConcreteScheduleNode::GetSBlock(const ffi::String& name, IRModule mod_; ffi::Array blocks_; }; - GlobalVar gv = NullValue(); + GlobalVar gv; if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); } else if (func_working_on_.has_value()) { diff --git a/src/s_tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc index e12cdd69de3f..6357e1ae19d3 100644 --- a/src/s_tir/schedule/traced_schedule.cc +++ b/src/s_tir/schedule/traced_schedule.cc @@ -31,7 +31,7 @@ Schedule Schedule::Traced(IRModule mod, LinearCongruentialEngine::TRandState see n->analyzer_ = std::make_unique(); n->trace_ = Trace(); n->Seed(seed); - GlobalVar gv = NullValue(); + GlobalVar gv; if (FindEntryFunc(mod, &gv) != nullptr) { n->func_working_on_ = gv; } else { @@ -118,7 +118,7 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const SBlockRV& block_rv, SBlockRV TracedScheduleNode::GetSBlock(const ffi::String& name, const ffi::Optional& func_name) { - GlobalVar gv = NullValue(); + GlobalVar gv; if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); } else if (func_working_on_.defined()) { diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc index 5bb0f6b7670f..ba7dd6962576 100644 --- a/src/s_tir/transform/lower_cross_thread_reduction.cc +++ b/src/s_tir/transform/lower_cross_thread_reduction.cc @@ -881,7 +881,7 @@ class CrossThreadReductionTransformer : public StmtMutator { /*kind=*/ForKind::kThreadBinding, // /*body=*/body, // /*thread_binding=*/ - IterVar(NullValue(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex, + IterVar(Range(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex, "threadIdx." + dim_index), /*annotations=*/{}, /*step=*/std::nullopt); diff --git a/src/s_tir/transform/storage_access.h b/src/s_tir/transform/storage_access.h index 2aa3850774f9..9bc1cf42b2e0 100644 --- a/src/s_tir/transform/storage_access.h +++ b/src/s_tir/transform/storage_access.h @@ -59,7 +59,7 @@ class StorageAccessVisitor : public StmtExprVisitor { /*! \brief The thread index that access this entry */ ffi::Array threads; /*! \brief The buffer variable, if any */ - Var buffer = NullValue(); + Var buffer; /*! \brief The access data type */ DataType dtype; /*! \brief The touched access range diff --git a/src/s_tir/transform/unify_thread_binding.cc b/src/s_tir/transform/unify_thread_binding.cc index 3ee465223ab8..5f5bea1f173b 100644 --- a/src/s_tir/transform/unify_thread_binding.cc +++ b/src/s_tir/transform/unify_thread_binding.cc @@ -159,7 +159,7 @@ class ThreadBindingUnifier : public StmtExprMutator { // necessary for unit tests. result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent, ForKind::kThreadBinding, result, - IterVar(NullValue(), Var(""), IterVarType::kThreadIndex, + IterVar(Range(), Var(""), IterVarType::kThreadIndex, thread_binding->thread_tag), {}, std::nullopt); launch_threads_.pop_back(); diff --git a/src/tirx/analysis/stmt_finding.cc b/src/tirx/analysis/stmt_finding.cc index 0ba6146213cc..6dc3d07b4f07 100644 --- a/src/tirx/analysis/stmt_finding.cc +++ b/src/tirx/analysis/stmt_finding.cc @@ -24,7 +24,7 @@ namespace tvm { namespace tirx { const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) { - GlobalVar result = NullValue(); + GlobalVar result; // Priority 1: PrimFunc marked as `tirx::attr::kIsEntryFunc` int num_prim_func = 0; const tirx::PrimFuncNode* main_func = nullptr; diff --git a/src/tirx/script/builder/frame.cc b/src/tirx/script/builder/frame.cc index 5e971d736113..e57b794cf31b 100644 --- a/src/tirx/script/builder/frame.cc +++ b/src/tirx/script/builder/frame.cc @@ -145,7 +145,7 @@ void PrimFuncFrameNode::ExitWithScope() { /*body=*/body, /*ret_type=*/ret_type.value_or(TupleType::Empty()), /*buffer_map=*/effective_buffer_map, - /*attrs=*/attrs.defined() ? DictAttrs(attrs) : NullValue(), + /*attrs=*/attrs.defined() ? DictAttrs(attrs) : DictAttrs(), /*span=*/tvm::Span()); func = tvm::tirx::ScriptComplete(func, effective_root_alloc_buffers, s_tir); IRBuilder builder = IRBuilder::Current(); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 0743c3db686e..e7a1715cc7bf 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -338,7 +338,7 @@ TEST(IRF, Substitute) { /*dtype=*/DataType::Float(32), /*shape=*/{n}, /*strides=*/{}, - /*elem_offset=*/NullValue(), + /*elem_offset=*/PrimExpr(), /*name=*/"buf", /*data_alignment=*/1, /*offset_factor=*/1, From b9ca1301509bf30f916b375f456d10c14a362100 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 20:39:42 +0000 Subject: [PATCH 2/8] [REFACTOR][IR] Drop NullValue declaration, AttrsNodeReflAdapter, BaseAttrsNode legacy methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BaseAttrsNode has only the reflection-based init path remaining. This commit removes all legacy scaffolding that was needed before the reflection migration: - Remove NullValue template and DataType specialization (all call sites replaced in the previous commit). - Remove InitBySeq (undeclared, dead) and InitByPackedArgs (pure virtual, only invoked on the dead else-branch) from BaseAttrsNode. - Remove DictAttrsNode::InitByPackedArgs override and its definition. - Remove AttrsNodeReflAdapter template — its sole role was providing the InitByPackedArgs stub that now signals an error. - Simplify AttrsWithDefaultValues to always use the FFI reflection path; broaden the static_assert to accept any ffi::ObjectRef subtype so pass-config classes can use it too. - Trim attrs.h includes: remove reflection/accessor.h, , (unused); keep structural_equal.h, structural_hash.h and as downstream files transitively depend on them. AttrFieldInfo / OpNode::arguments are kept: they are actively read by GetArgStructInfo in op_common.h for argument count validation. --- include/tvm/ir/attrs.h | 93 +++++++----------------------------------- src/ir/attrs.cc | 8 ---- 2 files changed, 14 insertions(+), 87 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index fa3dfa5b3ec2..287a26351728 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -23,7 +23,7 @@ * This module enables declaration of named attributes * which support default value setup and bound checking. * - * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD + * \sa BaseAttrsNode, AttrsWithDefaultValues */ #ifndef TVM_IR_ATTRS_H_ #define TVM_IR_ATTRS_H_ @@ -32,36 +32,17 @@ #include #include #include -#include #include #include #include -#include #include #include #include #include -#include namespace tvm { -/*! - * \brief Create a NodeRef type that represents null. - * \tparam TNodeRef the type to be created. - * \return A instance that will represent None. - */ -template -inline TObjectRef NullValue() { - static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); - return TObjectRef(ffi::ObjectPtr(nullptr)); -} - -template <> -inline DataType NullValue() { - return DataType(DataType::kHandle, 0, 0); -} - /*! * \brief Information about attribute fields in string representations. */ @@ -103,22 +84,6 @@ class BaseAttrsNode : public ffi::Object { public: /*! \brief virtual destructor */ virtual ~BaseAttrsNode() {} - /*! - * \brief Initialize the attributes by sequence of arguments - * \param args The positional arguments in the form - * [key0, value0, key1, value1, ..., key_n, value_n] - */ - template - inline void InitBySeq(Args&&... args); - /*! - * \brief Initialize the attributes by arguments. - * \param kwargs The key value pairs for initialization. - * [key0, value0, key1, value1, ..., key_n, value_n] - * \param allow_unknown Whether allow additional unknown fields. - * \note This function throws when the required field is not present. - */ - TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs, - bool allow_unknown = false) = 0; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object); @@ -149,8 +114,6 @@ class DictAttrsNode : public BaseAttrsNode { rfl::ObjectDef().def_ro("__dict__", &DictAttrsNode::dict); } - void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final; - // type info TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, BaseAttrsNode); }; @@ -380,48 +343,20 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { } /*! - * \brief Adapter for AttrsNode with the new reflection API. - * - * We will phaseout the old AttrsNode in future in favor of the new reflection API. - * This adapter allows us to gradually migrate to the new reflection API. - * - * \tparam DerivedType The final attribute type. + * \brief Create an object with all default values, using the reflection defaults. + * \tparam TObj the ObjectRef type to be created. + * \return An instance with all reflection-defined default values applied. */ -template -class AttrsNodeReflAdapter : public BaseAttrsNode { - public: - void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final { - TVM_FFI_THROW(InternalError) << "`" << DerivedType::_type_key - << "` uses new reflection mechanism for init"; - } - - private: - DerivedType* self() const { - return const_cast(static_cast(this)); - } -}; - -/*! - * \brief Create an Attr object with all default values. - * \tparam TAttrNode the type to be created. - * \return A instance that will represent None. - */ -template -inline TAttrs AttrsWithDefaultValues() { - static_assert(std::is_base_of_v, "Can only take attr nodes"); - using ContainerType = typename TAttrs::ContainerType; - if constexpr (std::is_base_of_v, ContainerType>) { - static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); - AnyView packed_args[1]; - packed_args[0] = ContainerType::RuntimeTypeIndex(); - ffi::Any rv; - finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv); - return rv.cast(); - } else { - auto n = ffi::make_object(); - n->InitByPackedArgs(ffi::PackedArgs(nullptr, 0), false); - return TAttrs(n); - } +template +inline TObj AttrsWithDefaultValues() { + static_assert(std::is_base_of_v, "Can only create ObjectRef-derived types"); + using ContainerType = typename TObj::ContainerType; + static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); + AnyView packed_args[1]; + packed_args[0] = ContainerType::RuntimeTypeIndex(); + ffi::Any rv; + finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv); + return rv.cast(); } } // namespace tvm diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index cfe269e4eba6..e7d9b9082809 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -53,14 +53,6 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { return attrs; } -void DictAttrsNode::InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) { - for (int i = 0; i < args.size(); i += 2) { - ffi::String key = args[i].cast(); - ffi::AnyView val = args[i + 1]; - dict.Set(key, val); - } -} - DictAttrs::DictAttrs(ffi::Map dict) { ffi::ObjectPtr n = ffi::make_object(); n->dict = std::move(dict); From 574663d2e87bfa500dd4a74ffc5aeba38d1a58cf Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 20:42:16 +0000 Subject: [PATCH 3/8] [REFACTOR][IR] Subclass BaseAttrsNode directly, drop AttrsNodeReflAdapter --- include/tvm/relax/attrs/ccl.h | 6 +-- include/tvm/relax/attrs/create.h | 4 +- include/tvm/relax/attrs/datatype.h | 4 +- include/tvm/relax/attrs/distributed.h | 2 +- include/tvm/relax/attrs/image.h | 6 +-- include/tvm/relax/attrs/index.h | 4 +- include/tvm/relax/attrs/linear_algebra.h | 4 +- include/tvm/relax/attrs/manipulate.h | 36 ++++++++-------- include/tvm/relax/attrs/nn.h | 52 ++++++++++++------------ include/tvm/relax/attrs/op.h | 10 ++--- include/tvm/relax/attrs/qdq.h | 2 +- include/tvm/relax/attrs/sampling.h | 2 +- include/tvm/relax/attrs/search.h | 4 +- include/tvm/relax/attrs/sorting.h | 6 +-- include/tvm/relax/attrs/statistical.h | 4 +- include/tvm/relax/attrs/vision.h | 13 +++--- include/tvm/target/virtual_device.h | 2 +- 17 files changed, 80 insertions(+), 81 deletions(-) diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h index 09d40b4ed98e..7e0624706b0c 100644 --- a/include/tvm/relax/attrs/ccl.h +++ b/include/tvm/relax/attrs/ccl.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in allreduce operators */ -struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter { +struct AllReduceAttrs : public tvm::BaseAttrsNode { ffi::String op_type; bool in_group; @@ -49,7 +49,7 @@ struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter { }; // struct AllReduceAttrs /*! \brief Attributes used in allgather operators */ -struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter { +struct AllGatherAttrs : public tvm::BaseAttrsNode { int num_workers; bool in_group; @@ -67,7 +67,7 @@ struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter { }; // struct AllGatherAttrs /*! \brief Attributes used in scatter operators */ -struct ScatterCollectiveAttrs : public tvm::AttrsNodeReflAdapter { +struct ScatterCollectiveAttrs : public tvm::BaseAttrsNode { int num_workers; int axis; diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h index c631fd3b4e3d..9a9e453263a0 100644 --- a/include/tvm/relax/attrs/create.h +++ b/include/tvm/relax/attrs/create.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */ -struct InitAttrs : public AttrsNodeReflAdapter { +struct InitAttrs : public BaseAttrsNode { DataType dtype; static void RegisterReflection() { @@ -42,7 +42,7 @@ struct InitAttrs : public AttrsNodeReflAdapter { }; // struct InitAttrs /*! \brief Attributes used in tril and triu operator */ -struct TriluAttrs : public AttrsNodeReflAdapter { +struct TriluAttrs : public BaseAttrsNode { int k; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h index dd07e3b54851..a1870597033e 100644 --- a/include/tvm/relax/attrs/datatype.h +++ b/include/tvm/relax/attrs/datatype.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in astype operator */ -struct AstypeAttrs : public AttrsNodeReflAdapter { +struct AstypeAttrs : public BaseAttrsNode { DataType dtype; static void RegisterReflection() { @@ -41,7 +41,7 @@ struct AstypeAttrs : public AttrsNodeReflAdapter { }; // struct AstypeAttrs. /*! \brief Attributes used in wrap_param operator */ -struct WrapParamAttrs : public AttrsNodeReflAdapter { +struct WrapParamAttrs : public BaseAttrsNode { DataType dtype; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/distributed.h b/include/tvm/relax/attrs/distributed.h index 356a248ba220..cce508ef1d50 100644 --- a/include/tvm/relax/attrs/distributed.h +++ b/include/tvm/relax/attrs/distributed.h @@ -32,7 +32,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for redistribute and annotate_sharding operator */ -struct DistributionAttrs : public AttrsNodeReflAdapter { +struct DistributionAttrs : public BaseAttrsNode { distributed::DeviceMesh device_mesh; distributed::Placement placement; diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h index 52aac58dcde9..8cc5e36734b6 100644 --- a/include/tvm/relax/attrs/image.h +++ b/include/tvm/relax/attrs/image.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in image resize2d operator */ -struct Resize2DAttrs : public AttrsNodeReflAdapter { +struct Resize2DAttrs : public BaseAttrsNode { ffi::Array roi; ffi::String layout; ffi::String method; @@ -79,7 +79,7 @@ struct Resize2DAttrs : public AttrsNodeReflAdapter { }; // struct Resize2dAttrs /*! \brief Attributes used in image resize3d operator */ -struct Resize3DAttrs : public AttrsNodeReflAdapter { +struct Resize3DAttrs : public BaseAttrsNode { ffi::Array roi; ffi::String layout; ffi::String method; @@ -128,7 +128,7 @@ struct Resize3DAttrs : public AttrsNodeReflAdapter { }; // struct Resize3DAttrs /*! \brief Attributes used in image grid_sample operator */ -struct GridSampleAttrs : public AttrsNodeReflAdapter { +struct GridSampleAttrs : public BaseAttrsNode { ffi::String method; ffi::String layout; ffi::String padding_mode; diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h index 0ea7c06bacc0..7b4c446bb80c 100644 --- a/include/tvm/relax/attrs/index.h +++ b/include/tvm/relax/attrs/index.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in take operator */ -struct TakeAttrs : public AttrsNodeReflAdapter { +struct TakeAttrs : public BaseAttrsNode { ffi::Optional axis; ffi::String mode; @@ -45,7 +45,7 @@ struct TakeAttrs : public AttrsNodeReflAdapter { }; // struct TakeAttrs /*! \brief Attributes used in strided_slice operator */ -struct StridedSliceAttrs : public AttrsNodeReflAdapter { +struct StridedSliceAttrs : public BaseAttrsNode { bool assume_inbound; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h index f95d817f1e4d..2627dafcf6b3 100644 --- a/include/tvm/relax/attrs/linear_algebra.h +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for matmul operator */ -struct MatmulAttrs : public AttrsNodeReflAdapter { +struct MatmulAttrs : public BaseAttrsNode { DataType out_dtype; static void RegisterReflection() { @@ -42,7 +42,7 @@ struct MatmulAttrs : public AttrsNodeReflAdapter { }; // struct MatmulAttrs /*! \brief Attributes used in einsum operator */ -struct EinsumAttrs : public AttrsNodeReflAdapter { +struct EinsumAttrs : public BaseAttrsNode { ffi::String subscripts; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index cc8d28ec2639..e43750b9a706 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in concat operators */ -struct ConcatAttrs : public AttrsNodeReflAdapter { +struct ConcatAttrs : public BaseAttrsNode { ffi::Optional axis; static void RegisterReflection() { @@ -44,7 +44,7 @@ struct ConcatAttrs : public AttrsNodeReflAdapter { }; // struct ConcatAttrs /*! \brief Attributes used in expand_dims operators */ -struct ExpandDimsAttrs : public AttrsNodeReflAdapter { +struct ExpandDimsAttrs : public BaseAttrsNode { ffi::Array axis; static void RegisterReflection() { @@ -59,7 +59,7 @@ struct ExpandDimsAttrs : public AttrsNodeReflAdapter { }; // struct ExpandDimsAttrs /*! \brief Attributes used in layout_transform operator */ -struct LayoutTransformAttrs : public AttrsNodeReflAdapter { +struct LayoutTransformAttrs : public BaseAttrsNode { tirx::IndexMap index_map; // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future. @@ -97,7 +97,7 @@ struct LayoutTransformAttrs : public AttrsNodeReflAdapter }; // struct LayoutTransformAttrs /*! \brief Attributes used in permute_dims operator */ -struct PermuteDimsAttrs : public AttrsNodeReflAdapter { +struct PermuteDimsAttrs : public BaseAttrsNode { ffi::Optional> axes; static void RegisterReflection() { @@ -110,7 +110,7 @@ struct PermuteDimsAttrs : public AttrsNodeReflAdapter { }; // struct PermuteDimsAttrs /*! \brief Attributes used in split operator */ -struct SplitAttrs : public AttrsNodeReflAdapter { +struct SplitAttrs : public BaseAttrsNode { ffi::ObjectRef indices_or_sections; int axis; @@ -125,7 +125,7 @@ struct SplitAttrs : public AttrsNodeReflAdapter { }; // struct SplitAttrs /*! \brief Attributes used in squeeze operators */ -struct SqueezeAttrs : public AttrsNodeReflAdapter { +struct SqueezeAttrs : public BaseAttrsNode { ffi::Optional> axis; static void RegisterReflection() { @@ -140,7 +140,7 @@ struct SqueezeAttrs : public AttrsNodeReflAdapter { }; // struct SqueezeAttrs /*! \brief Attributes used in stack operators */ -struct StackAttrs : public AttrsNodeReflAdapter { +struct StackAttrs : public BaseAttrsNode { ffi::Optional axis; static void RegisterReflection() { @@ -156,7 +156,7 @@ struct StackAttrs : public AttrsNodeReflAdapter { }; // struct StackAttrs /*! \brief Attributes used in repeat operators */ -struct RepeatAttrs : public AttrsNodeReflAdapter { +struct RepeatAttrs : public BaseAttrsNode { int repeats; ffi::Optional axis; @@ -173,7 +173,7 @@ struct RepeatAttrs : public AttrsNodeReflAdapter { }; // struct RepeatAttrs /*! \brief Attributes used in tile operators */ -struct TileAttrs : public AttrsNodeReflAdapter { +struct TileAttrs : public BaseAttrsNode { ffi::Array repeats; static void RegisterReflection() { @@ -185,7 +185,7 @@ struct TileAttrs : public AttrsNodeReflAdapter { }; // struct TileAttrs /*! \brief Attributes used in flip operators */ -struct FlipAttrs : public AttrsNodeReflAdapter { +struct FlipAttrs : public BaseAttrsNode { ffi::Optional axis; static void RegisterReflection() { @@ -197,7 +197,7 @@ struct FlipAttrs : public AttrsNodeReflAdapter { }; // struct FlipAttrs /*! \brief Attributes used in gather_elements operators */ -struct GatherElementsAttrs : public AttrsNodeReflAdapter { +struct GatherElementsAttrs : public BaseAttrsNode { Integer axis; static void RegisterReflection() { @@ -211,7 +211,7 @@ struct GatherElementsAttrs : public AttrsNodeReflAdapter { }; // struct GatherElementsAttrs /*! \brief Attributes used in gather_nd operators */ -struct GatherNDAttrs : public AttrsNodeReflAdapter { +struct GatherNDAttrs : public BaseAttrsNode { Integer batch_dims; static void RegisterReflection() { @@ -223,7 +223,7 @@ struct GatherNDAttrs : public AttrsNodeReflAdapter { }; // struct GatherNDAttrs /*! \brief Attributes used in index_put operator */ -struct IndexPutAttrs : public AttrsNodeReflAdapter { +struct IndexPutAttrs : public BaseAttrsNode { bool accumulate; static void RegisterReflection() { @@ -239,7 +239,7 @@ struct IndexPutAttrs : public AttrsNodeReflAdapter { }; // struct IndexPutAttrs /*! \brief Attribute used in meshgrid operator */ -struct MeshgridAttrs : public AttrsNodeReflAdapter { +struct MeshgridAttrs : public BaseAttrsNode { ffi::Optional indexing; static void RegisterReflection() { @@ -251,7 +251,7 @@ struct MeshgridAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in scatter_elements operators */ -struct ScatterElementsAttrs : public AttrsNodeReflAdapter { +struct ScatterElementsAttrs : public BaseAttrsNode { Integer axis; ffi::String reduction; @@ -270,7 +270,7 @@ struct ScatterElementsAttrs : public AttrsNodeReflAdapter }; // struct ScatterElementsAttrs /*! \brief Attributes used in scatter_nd operators */ -struct ScatterNDAttrs : public AttrsNodeReflAdapter { +struct ScatterNDAttrs : public BaseAttrsNode { ffi::String reduction; static void RegisterReflection() { @@ -285,7 +285,7 @@ struct ScatterNDAttrs : public AttrsNodeReflAdapter { }; // struct ScatterNDAttrs /*! \brief Attributes used in slice_scatter operator */ -struct SliceScatterAttrs : public AttrsNodeReflAdapter { +struct SliceScatterAttrs : public BaseAttrsNode { int axis; static void RegisterReflection() { @@ -299,7 +299,7 @@ struct SliceScatterAttrs : public AttrsNodeReflAdapter { }; // struct SliceScatterAttrs /*! \brief Attributes used in one_hot operator */ -struct OneHotAttrs : public AttrsNodeReflAdapter { +struct OneHotAttrs : public BaseAttrsNode { int depth; int axis; diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 45abeb9d5b7e..bfc85dfd5a13 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in Conv1d operator */ -struct Conv1DAttrs : public AttrsNodeReflAdapter { +struct Conv1DAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -74,7 +74,7 @@ struct Conv1DAttrs : public AttrsNodeReflAdapter { }; // struct Conv1dAttrs /*! \brief Attributes used in Conv2d operator */ -struct Conv2DAttrs : public AttrsNodeReflAdapter { +struct Conv2DAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -120,7 +120,7 @@ struct Conv2DAttrs : public AttrsNodeReflAdapter { }; // struct Conv2dAttrs /*! \brief Attributes used in Conv3d operator */ -struct Conv3DAttrs : public AttrsNodeReflAdapter { +struct Conv3DAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array dilation; @@ -168,7 +168,7 @@ struct Conv3DAttrs : public AttrsNodeReflAdapter { }; // struct Conv3dAttrs /*! \brief Attributes used in Conv1DTranspose operator */ -struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter { +struct Conv1DTransposeAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -217,7 +217,7 @@ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter }; // struct Conv1DTransposeAttrs /*! \brief Attributes used in Conv2d operator */ -struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter { +struct Conv2DTransposeAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -268,7 +268,7 @@ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter }; // struct Conv2DTransposeAttrs /*! \brief Attributes used in Conv3dTranspose operator */ -struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter { +struct Conv3DTransposeAttrs : public BaseAttrsNode { ffi::Array strides; ffi::Array padding; ffi::Array output_padding; @@ -321,7 +321,7 @@ struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter }; // struct Conv3DTransposeAttrs /*! \brief Attributes used in max_pool1d and avg_pool1d operator */ -struct Pool1DAttrs : public AttrsNodeReflAdapter { +struct Pool1DAttrs : public BaseAttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -362,7 +362,7 @@ struct Pool1DAttrs : public AttrsNodeReflAdapter { }; // struct Pool1dAttrs /*! \brief Attributes used in max_pool2d and avg_pool2d operator */ -struct Pool2DAttrs : public AttrsNodeReflAdapter { +struct Pool2DAttrs : public BaseAttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -405,7 +405,7 @@ struct Pool2DAttrs : public AttrsNodeReflAdapter { }; // struct Pool2dAttrs /*! \brief Attributes used in max_pool3d and avg_pool3d operator */ -struct Pool3DAttrs : public AttrsNodeReflAdapter { +struct Pool3DAttrs : public BaseAttrsNode { ffi::Array pool_size; ffi::Array strides; ffi::Array padding; @@ -448,7 +448,7 @@ struct Pool3DAttrs : public AttrsNodeReflAdapter { }; // struct Pool3dAttrs /*! \brief Attributes for 1d adaptive pool operator */ -struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { +struct AdaptivePool1DAttrs : public BaseAttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -473,7 +473,7 @@ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { }; // struct AdaptivePool1DAttrs /*! \brief Attributes for 2d adaptive pool operator */ -struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { +struct AdaptivePool2DAttrs : public BaseAttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -498,7 +498,7 @@ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { }; // struct AdaptivePool2DAttrs /*! \brief Attributes for 3d adaptive pool operator */ -struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter { +struct AdaptivePool3DAttrs : public BaseAttrsNode { ffi::Optional> output_size; ffi::String layout; ffi::String out_layout; @@ -523,7 +523,7 @@ struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter { }; // struct AdaptivePool3DAttrs /*! \brief Attributes used in softmax operators */ -struct SoftmaxAttrs : public AttrsNodeReflAdapter { +struct SoftmaxAttrs : public BaseAttrsNode { int axis; static void RegisterReflection() { @@ -535,7 +535,7 @@ struct SoftmaxAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in softmax operators */ -struct LeakyReluAttrs : public AttrsNodeReflAdapter { +struct LeakyReluAttrs : public BaseAttrsNode { double alpha; static void RegisterReflection() { @@ -547,7 +547,7 @@ struct LeakyReluAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in softplus operators */ -struct SoftplusAttrs : public AttrsNodeReflAdapter { +struct SoftplusAttrs : public BaseAttrsNode { double beta; double threshold; @@ -563,7 +563,7 @@ struct SoftplusAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in PReLU operator */ -struct PReluAttrs : public AttrsNodeReflAdapter { +struct PReluAttrs : public BaseAttrsNode { int axis; static void RegisterReflection() { @@ -575,7 +575,7 @@ struct PReluAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used in batch_norm operator */ -struct BatchNormAttrs : public AttrsNodeReflAdapter { +struct BatchNormAttrs : public BaseAttrsNode { int axis; double epsilon; bool center; @@ -602,7 +602,7 @@ struct BatchNormAttrs : public AttrsNodeReflAdapter { }; // struct BatchNormAttrs /*! \brief Attributes used in layer_norm operator */ -struct LayerNormAttrs : public AttrsNodeReflAdapter { +struct LayerNormAttrs : public BaseAttrsNode { ffi::Array axes; double epsilon; bool center; @@ -624,7 +624,7 @@ struct LayerNormAttrs : public AttrsNodeReflAdapter { }; // struct LayerNormAttrs /*! \brief Attributes used in group_norm operator */ -struct GroupNormAttrs : public AttrsNodeReflAdapter { +struct GroupNormAttrs : public BaseAttrsNode { int num_groups; int channel_axis; ffi::Array axes; @@ -653,7 +653,7 @@ struct GroupNormAttrs : public AttrsNodeReflAdapter { }; // struct GroupNormAttrs /*! \brief Attributes used in instance_norm operator */ -struct InstanceNormAttrs : public AttrsNodeReflAdapter { +struct InstanceNormAttrs : public BaseAttrsNode { int channel_axis; ffi::Array axes; double epsilon; @@ -679,7 +679,7 @@ struct InstanceNormAttrs : public AttrsNodeReflAdapter { }; // struct InstanceNormAttrs /*! \brief Attributes used in rms_norm operator */ -struct RMSNormAttrs : public AttrsNodeReflAdapter { +struct RMSNormAttrs : public BaseAttrsNode { ffi::Array axes; double epsilon; @@ -695,7 +695,7 @@ struct RMSNormAttrs : public AttrsNodeReflAdapter { }; // struct RMSNormAttrs /*! \brief Attributes used in nll_loss operator */ -struct NLLLossAttrs : public AttrsNodeReflAdapter { +struct NLLLossAttrs : public BaseAttrsNode { ffi::String reduction; int ignore_index; @@ -712,7 +712,7 @@ struct NLLLossAttrs : public AttrsNodeReflAdapter { }; // struct NLLLossAttrs /*! \brief Attributes used in dropout operator */ -struct DropoutAttrs : public AttrsNodeReflAdapter { +struct DropoutAttrs : public BaseAttrsNode { double rate; static void RegisterReflection() { @@ -725,7 +725,7 @@ struct DropoutAttrs : public AttrsNodeReflAdapter { }; // struct DropoutAttrs /*! \brief Attributes used in Attention operator */ -struct AttentionAttrs : public AttrsNodeReflAdapter { +struct AttentionAttrs : public BaseAttrsNode { ffi::Optional scale; ffi::Optional causal_mask; ffi::Optional window_size; @@ -745,7 +745,7 @@ struct AttentionAttrs : public AttrsNodeReflAdapter { }; // struct AttentionAttrs /*! \brief Attributes used for the padding operator */ -struct PadAttrs : public AttrsNodeReflAdapter { +struct PadAttrs : public BaseAttrsNode { ffi::Array pad_width; double pad_value = 0.0; tvm::ffi::String pad_mode; @@ -768,7 +768,7 @@ struct PadAttrs : public AttrsNodeReflAdapter { }; /*! \brief Attributes used for the pixel shuffle operator */ -struct PixelShuffleAttrs : public AttrsNodeReflAdapter { +struct PixelShuffleAttrs : public BaseAttrsNode { int upscale_factor; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 54640901ff53..79e00d590abe 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in call_tir_with_grad */ -struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter { +struct CallTIRWithGradAttrs : public BaseAttrsNode { ffi::String te_grad_name; ffi::Map te_grad_kwargs; @@ -49,7 +49,7 @@ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter }; // struct CallTIRAttrs /*! \brief Attributes used in call_tir_inplace */ -struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { +struct CallTIRInplaceAttrs : public BaseAttrsNode { /*! * \brief Indices that describe which input corresponds to which output. * @@ -69,7 +69,7 @@ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { }; // struct CallTIRInplaceAttrs /*! \brief Attributes used in call_inplace_packed */ -struct CallInplacePackedAttrs : public AttrsNodeReflAdapter { +struct CallInplacePackedAttrs : public BaseAttrsNode { /*! * \brief Indices that describe which input corresponds to which output. * @@ -89,7 +89,7 @@ struct CallInplacePackedAttrs : public AttrsNodeReflAdapter { +struct ToVDeviceAttrs : public BaseAttrsNode { VDevice dst_vdevice; static void RegisterReflection() { @@ -101,7 +101,7 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter { }; // struct ToVDeviceAttrs /*! \brief Attributes used in hint_on_device */ -struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { +struct HintOnDeviceAttrs : public BaseAttrsNode { int32_t device_type; int32_t index; MemoryScope memory_scope; diff --git a/include/tvm/relax/attrs/qdq.h b/include/tvm/relax/attrs/qdq.h index ffb554994f98..08bc054dc54f 100644 --- a/include/tvm/relax/attrs/qdq.h +++ b/include/tvm/relax/attrs/qdq.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for relax.quantize/relax.dequantize operator */ -struct QuantizeAttrs : public AttrsNodeReflAdapter { +struct QuantizeAttrs : public BaseAttrsNode { DataType out_dtype; int axis; diff --git a/include/tvm/relax/attrs/sampling.h b/include/tvm/relax/attrs/sampling.h index 53fd3a140497..2d7421cc20e8 100644 --- a/include/tvm/relax/attrs/sampling.h +++ b/include/tvm/relax/attrs/sampling.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in multinomial_from_uniform operator */ -struct MultinomialFromUniformAttrs : public AttrsNodeReflAdapter { +struct MultinomialFromUniformAttrs : public BaseAttrsNode { DataType dtype; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h index 32327c160d1d..015e5d8edc1c 100644 --- a/include/tvm/relax/attrs/search.h +++ b/include/tvm/relax/attrs/search.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for search operators */ -struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter { +struct ArgmaxArgminAttrs : public BaseAttrsNode { ffi::Optional axis; bool keepdims; @@ -49,7 +49,7 @@ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter { }; // struct ArgmaxArgminAttrs /*! \brief Attributes for bucketize operator */ -struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter { +struct BucketizeAttrs : public tvm::BaseAttrsNode { bool out_int32; bool right; diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index b77fa2ecc72c..e32d47239f35 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -31,7 +31,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in sort operator */ -struct SortAttrs : public AttrsNodeReflAdapter { +struct SortAttrs : public BaseAttrsNode { int axis; bool descending; @@ -51,7 +51,7 @@ struct SortAttrs : public AttrsNodeReflAdapter { }; // struct SortAttrs /*! \brief Attributes used in argsort operator */ -struct ArgsortAttrs : public AttrsNodeReflAdapter { +struct ArgsortAttrs : public BaseAttrsNode { int axis; bool descending; DataType dtype; @@ -74,7 +74,7 @@ struct ArgsortAttrs : public AttrsNodeReflAdapter { }; // struct ArgsortAttrs /*! \brief Attributes used in topk operator */ -struct TopKAttrs : public AttrsNodeReflAdapter { +struct TopKAttrs : public BaseAttrsNode { int k; int axis; bool largest; diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h index 433524116d3c..367869f1ab11 100644 --- a/include/tvm/relax/attrs/statistical.h +++ b/include/tvm/relax/attrs/statistical.h @@ -30,7 +30,7 @@ namespace tvm { namespace relax { /*! \brief Attributes for statistical operators */ -struct StatisticalAttrs : public AttrsNodeReflAdapter { +struct StatisticalAttrs : public BaseAttrsNode { ffi::Optional> axis; bool keepdims; @@ -49,7 +49,7 @@ struct StatisticalAttrs : public AttrsNodeReflAdapter { }; // struct StatisticalAttrs /*! \brief Attributes used in scan operators like cumsum, cumprod */ -struct ScanopAttrs : public AttrsNodeReflAdapter { +struct ScanopAttrs : public BaseAttrsNode { ffi::Optional axis; DataType dtype; Bool exclusive = Bool(false); diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index 55ed162674e2..37ec77cbbff6 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -32,8 +32,7 @@ namespace tvm { namespace relax { /*! \brief Attributes used in AllClassNonMaximumSuppression operator */ -struct AllClassNonMaximumSuppressionAttrs - : public AttrsNodeReflAdapter { +struct AllClassNonMaximumSuppressionAttrs : public BaseAttrsNode { ffi::String output_format; static void RegisterReflection() { @@ -48,7 +47,7 @@ struct AllClassNonMaximumSuppressionAttrs }; // struct AllClassNonMaximumSuppressionAttrs /*! \brief Attributes used in ROIAlign operator */ -struct ROIAlignAttrs : public AttrsNodeReflAdapter { +struct ROIAlignAttrs : public BaseAttrsNode { ffi::Array pooled_size; double spatial_scale; int sample_ratio; @@ -73,7 +72,7 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter { }; // struct ROIAlignAttrs /*! \brief Attributes used in ROIPool operator */ -struct ROIPoolAttrs : public AttrsNodeReflAdapter { +struct ROIPoolAttrs : public BaseAttrsNode { ffi::Array pooled_size; double spatial_scale; ffi::String layout; @@ -90,7 +89,7 @@ struct ROIPoolAttrs : public AttrsNodeReflAdapter { }; // struct ROIPoolAttrs /*! \brief Attributes used in GetValidCounts operator */ -struct GetValidCountsAttrs : public AttrsNodeReflAdapter { +struct GetValidCountsAttrs : public BaseAttrsNode { double score_threshold; int id_index; int score_index; @@ -110,7 +109,7 @@ struct GetValidCountsAttrs : public AttrsNodeReflAdapter { }; // struct GetValidCountsAttrs /*! \brief Attributes used in NonMaximumSuppression operator */ -struct NonMaximumSuppressionAttrs : public AttrsNodeReflAdapter { +struct NonMaximumSuppressionAttrs : public BaseAttrsNode { int max_output_size; double iou_threshold; bool force_suppress; @@ -154,7 +153,7 @@ struct NonMaximumSuppressionAttrs : public AttrsNodeReflAdapter { +struct MultiboxTransformLocAttrs : public BaseAttrsNode { bool clip; double threshold; ffi::Array variances; diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index 5ff282adb68b..79475262c4a4 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -169,7 +169,7 @@ constexpr int kInvalidDeviceType = -1; * These operations are needed during device planning. */ -class VirtualDeviceNode : public AttrsNodeReflAdapter { +class VirtualDeviceNode : public BaseAttrsNode { private: /*! * \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is From b08ef238ef5de8ced9d223e5edc62c6a21bb0bc9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 20:42:23 +0000 Subject: [PATCH 4/8] [REFACTOR] Migrate pass-config classes to subclass ffi::Object --- python/tvm/s_tir/transform/transform.py | 5 ++--- python/tvm/tirx/transform/transform.py | 11 +++++------ src/relax/backend/contrib/clml/codegen.cc | 8 ++++---- src/relax/backend/contrib/tensorrt/codegen.cc | 8 ++++---- src/s_tir/transform/hoist_expression.cc | 12 ++++++------ src/s_tir/transform/inject_double_buffer.cc | 8 ++++---- src/s_tir/transform/loop_partition.cc | 8 ++++---- src/s_tir/transform/unify_thread_binding.cc | 3 +-- src/tirx/transform/remove_no_op.cc | 9 +++++---- src/tirx/transform/simplify.cc | 14 +++++++++----- src/tirx/transform/unroll_loop.cc | 9 +++++---- 11 files changed, 49 insertions(+), 46 deletions(-) diff --git a/python/tvm/s_tir/transform/transform.py b/python/tvm/s_tir/transform/transform.py index e8d14171b331..af4ec493cc14 100644 --- a/python/tvm/s_tir/transform/transform.py +++ b/python/tvm/s_tir/transform/transform.py @@ -18,7 +18,6 @@ # pylint: disable=invalid-name, unsupported-binary-operation from ... import ffi as _ffi -from ... import ir as _ir from . import _ffi_api @@ -213,7 +212,7 @@ def AnnotateIrregularLoop(): @_ffi.register_object("s_tir.transform.LoopPartitionConfig") -class LoopPartitionConfig(_ir.Attrs): +class LoopPartitionConfig(_ffi.Object): """Config for loop partition pass""" @@ -240,7 +239,7 @@ def InjectVirtualThread(): @_ffi.register_object("s_tir.transform.InjectDoubleBufferConfig") -class InjectDoubleBufferConfig(_ir.Attrs): +class InjectDoubleBufferConfig(_ffi.Object): """Config for inject double buffer pass""" diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index 8082d864c1e9..fbf07b5f4897 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -21,7 +21,6 @@ from collections.abc import Callable from ... import ffi as _ffi -from ... import ir as _ir from . import _ffi_api from . import function_pass as _fpass @@ -107,7 +106,7 @@ def PointerValueTypeRewrite(): @_ffi.register_object("tirx.transform.UnrollLoopConfig") -class UnrollLoopConfig(_ir.Attrs): +class UnrollLoopConfig(_ffi.Object): """Config for unroll loop pass""" @@ -125,7 +124,7 @@ def UnrollLoop(): @_ffi.register_object("tirx.transform.RemoveNoOpConfig") -class RemoveNoOpConfig(_ir.Attrs): +class RemoveNoOpConfig(_ffi.Object): """Config for remove no op pass""" @@ -212,7 +211,7 @@ def CommonSubexprElim(): @_ffi.register_object("tirx.transform.SimplifyConfig") -class SimplifyConfig(_ir.Attrs): +class SimplifyConfig(_ffi.Object): """Config for simplify pass""" @@ -429,7 +428,7 @@ def VerifyMemory(): @_ffi.register_object("s_tir.transform.HoistIfThenElseConfig") -class HoistIfThenElseConfig(_ir.Attrs): +class HoistIfThenElseConfig(_ffi.Object): """Config for hoist if then else pass""" @@ -483,7 +482,7 @@ class HoistedLetBindings(enum.Flag): @_ffi.register_object("s_tir.transform.HoistExpressionConfig") -class HoistExpressionConfig(_ir.Attrs): +class HoistExpressionConfig(_ffi.Object): """Config for hoist expression pass""" diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index eaa57f8315e4..dd71e8a68a51 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -41,7 +41,7 @@ namespace relax { namespace contrib { /*! \brief Attributes to store the compiler options for OpenCLML. */ -struct OpenCLMLCompilerConfigNode : public AttrsNodeReflAdapter { +struct OpenCLMLCompilerConfigNode : public ffi::Object { Integer clml_version; static void RegisterReflection() { @@ -51,12 +51,12 @@ struct OpenCLMLCompilerConfigNode : public AttrsNodeReflAdapter { +struct TensorRTCompilerConfigNode : public ffi::Object { ffi::Array tensorrt_version; bool use_implicit_batch; size_t max_workspace_size; @@ -72,12 +72,12 @@ struct TensorRTCompilerConfigNode : public AttrsNodeReflAdapter { +struct HoistExpressionConfigNode : public ffi::Object { int hoisted_conditionals; int hoisted_let_bindings; @@ -87,7 +87,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter(); @@ -95,7 +95,7 @@ class HoistExpressionConfig : public Attrs { node->hoisted_let_bindings = hoisted_let_bindings; data_ = std::move(node); } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistExpressionConfig, Attrs, + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistExpressionConfig, ffi::ObjectRef, HoistExpressionConfigNode); }; @@ -103,7 +103,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { HoistExpressionConfigNode::RegisterReflection(); } TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.HoistExpression", HoistExpressionConfig); -struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter { +struct HoistIfThenElseConfigNode : public ffi::Object { bool support_block_scope_hoisting; static void RegisterReflection() { @@ -116,9 +116,9 @@ struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter { +struct InjectDoubleBufferConfigNode : public ffi::Object { int split_loop; static void RegisterReflection() { @@ -46,12 +46,12 @@ struct InjectDoubleBufferConfigNode : public AttrsNodeReflAdapter { +struct LoopPartitionConfigNode : public ffi::Object { bool partition_const_loop; bool no_unroll_loop_with_extent_one; bool unroll_loop_with_partition_hint_no_interval; @@ -64,14 +64,14 @@ struct LoopPartitionConfigNode : public AttrsNodeReflAdaptervar, thread_binding->dom->min, thread_binding->dom->extent, ForKind::kThreadBinding, result, - IterVar(Range(), Var(""), IterVarType::kThreadIndex, - thread_binding->thread_tag), + IterVar(Range(), Var(""), IterVarType::kThreadIndex, thread_binding->thread_tag), {}, std::nullopt); launch_threads_.pop_back(); } diff --git a/src/tirx/transform/remove_no_op.cc b/src/tirx/transform/remove_no_op.cc index fcc7519334d0..4bdb5c083c01 100644 --- a/src/tirx/transform/remove_no_op.cc +++ b/src/tirx/transform/remove_no_op.cc @@ -44,7 +44,7 @@ namespace tvm { namespace tirx { -struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter { +struct RemoveNoOpConfigNode : public ffi::Object { bool use_dataflow_analysis; int64_t max_simplification_steps; bool ignore_profiler_call; @@ -65,12 +65,13 @@ struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter "If true, profiler calls are rendered as no-ops.", refl::DefaultValue(false)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.RemoveNoOpConfig", RemoveNoOpConfigNode, - BaseAttrsNode); + ffi::Object); }; -class RemoveNoOpConfig : public Attrs { +class RemoveNoOpConfig : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RemoveNoOpConfig, ffi::ObjectRef, + RemoveNoOpConfigNode); }; TVM_FFI_STATIC_INIT_BLOCK() { RemoveNoOpConfigNode::RegisterReflection(); } diff --git a/src/tirx/transform/simplify.cc b/src/tirx/transform/simplify.cc index f193fb502da9..bf80ad00a455 100644 --- a/src/tirx/transform/simplify.cc +++ b/src/tirx/transform/simplify.cc @@ -44,7 +44,7 @@ namespace arith { using namespace tirx; -struct SimplifyConfigNode : public AttrsNodeReflAdapter { +struct SimplifyConfigNode : public ffi::Object { bool transitively_prove_inequalities; bool propagate_knowns_to_prove_conditional; bool propagate_knowns_to_simplify_expressions; @@ -78,7 +78,7 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { refl::DefaultValue(false)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.SimplifyConfig", SimplifyConfigNode, - BaseAttrsNode); + ffi::Object); RewriteSimplifier::Extension GetEnabledExtensions() const { RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; @@ -97,11 +97,15 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { } }; -class SimplifyConfig : public Attrs { +class SimplifyConfig : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs, SimplifyConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, ffi::ObjectRef, SimplifyConfigNode); }; +static SimplifyConfig MakeDefaultSimplifyConfig() { + return AttrsWithDefaultValues(); +} + TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); } TVM_REGISTER_PASS_CONFIG_OPTION("tirx.Simplify", SimplifyConfig); @@ -110,7 +114,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { public: static PrimFunc Apply(PrimFunc func, Analyzer* analyzer, ffi::Optional config_opt = std::nullopt) { - auto config = config_opt.value_or(AttrsWithDefaultValues()); + auto config = config_opt.value_or(MakeDefaultSimplifyConfig()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); std::optional touch_pattern = std::nullopt; diff --git a/src/tirx/transform/unroll_loop.cc b/src/tirx/transform/unroll_loop.cc index 3aea9ddd04c9..faf1ec2d677d 100644 --- a/src/tirx/transform/unroll_loop.cc +++ b/src/tirx/transform/unroll_loop.cc @@ -39,7 +39,7 @@ namespace tvm { namespace tirx { -struct UnrollLoopConfigNode : public AttrsNodeReflAdapter { +struct UnrollLoopConfigNode : public ffi::Object { int auto_max_step; int auto_max_depth; int auto_max_extent; @@ -64,12 +64,13 @@ struct UnrollLoopConfigNode : public AttrsNodeReflAdapter "Whether to always unroll local access", refl::DefaultValue(false)); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.UnrollLoopConfig", UnrollLoopConfigNode, - BaseAttrsNode); + ffi::Object); }; -class UnrollLoopConfig : public Attrs { +class UnrollLoopConfig : public ffi::ObjectRef { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, Attrs, UnrollLoopConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, ffi::ObjectRef, + UnrollLoopConfigNode); }; TVM_FFI_STATIC_INIT_BLOCK() { UnrollLoopConfigNode::RegisterReflection(); } From e33573bddc4f3639fbda1dfceb899cadf91cce67 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 25 May 2026 22:52:35 +0000 Subject: [PATCH 5/8] [REFACTOR][RELAX] Implement axis=None (flip all axes) for relax.flip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address Gemini high-priority review on #19607: after migrating FlipAttrs::axis to Optional, the existing call sites unconditionally call attrs->axis.value(), which throws when axis is nullopt. Implement the NumPy semantics — axis=None flips every axis — so the optional field has a well-defined meaning instead of being a landmine. - InferStructInfoFlip handles missing axis (shape unchanged). - InferLayoutFlip handles missing axis (layout preserved as-is). - flip lowering generates a per-axis sequence of topi.flip calls when axis is missing (Option A, matches the single-axis TE pattern). - Python wrapper defaults to axis=None. - New tests: test_flip_infer_struct_info_axis_none and test_flip_axis_none (end-to-end execution against np.flip). --- python/tvm/relax/op/manipulate.py | 9 +++++--- .../transform/legalize_ops/manipulate.py | 18 ++++++++++++++- src/relax/op/tensor/manipulate.cc | 22 +++++++++++++------ tests/python/relax/test_op_manipulate.py | 12 ++++++++++ .../test_transform_legalize_ops_manipulate.py | 21 ++++++++++++++++++ 5 files changed, 71 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 3ce70fc545fb..c1df820323a2 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -432,7 +432,7 @@ def tile(data: Expr, repeats: int | tuple[int] | list[int]) -> Expr: return _ffi_api.tile(data, repeats) # type: ignore -def flip(data, axis): +def flip(data, axis=None): """Reverses the order of elements along given axis while preserving array shape. Parameters @@ -440,8 +440,9 @@ def flip(data, axis): data : relax.Expr The input data to the operator. - axis: int - axis to flip on + axis: int, optional + The axis along which to flip. If ``None`` (default), flip all axes, + which is equivalent to NumPy's ``np.flip(data)`` with no axis argument. Returns ------- @@ -456,6 +457,8 @@ def flip(data, axis): relax.flip(x, axis=0) = [[3., 4.], [1., 2.]] relax.flip(x, axis=1) = [[2., 1.], [4., 3.]] + + relax.flip(x) = [[4., 3.], [2., 1.]] # flip all axes """ return _ffi_api.flip(data, axis) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index fc7ee0d12eb8..efe7c333380c 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -167,7 +167,23 @@ def _tile(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.flip") def _flip(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis)) + axis = call.attrs.axis + if axis is None: + # axis=None means flip all axes (NumPy semantics). + # Get the number of dimensions from the input struct info. + data_sinfo = call.args[0].struct_info + ndim = data_sinfo.ndim + if ndim < 0: + raise ValueError( + "relax.flip with axis=None requires static ndim to lower to TE; " + "ndim is unknown." + ) + # Apply topi.flip for each axis in sequence. + result = call.args[0] + for i in range(ndim): + result = bb.emit_te(topi.flip, result, i) + return result + return bb.call_te(topi.flip, call.args[0], int(axis)) @register_legalize("relax.gather_elements") diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 98b7405bc030..75ba8a24a44c 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -2043,13 +2043,16 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { } TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - int axis = static_cast(attrs->axis.value()); - if (!data_sinfo->IsUnknownNdim()) { - int ndim = data_sinfo->ndim; - if (axis < -ndim || axis >= ndim) { - ctx->ReportFatal(Diagnostic::Error(call) << "Flip requires the input axis belongs range " - "[-ndim, ndim - 1]. However, the input axis is " - << axis << ", while ndim is " << ndim); + // axis == nullopt means flip all axes (NumPy semantics); shape is unchanged. + if (attrs->axis.has_value()) { + int axis = static_cast(attrs->axis.value()); + if (!data_sinfo->IsUnknownNdim()) { + int ndim = data_sinfo->ndim; + if (axis < -ndim || axis >= ndim) { + ctx->ReportFatal(Diagnostic::Error(call) << "Flip requires the input axis belongs range " + "[-ndim, ndim - 1]. However, the input axis is " + << axis << ", while ndim is " << ndim); + } } } return data_sinfo; @@ -2073,6 +2076,11 @@ InferLayoutOutput InferLayoutFlip( existing_layout = LayoutDecision(InitialLayout(ndim)); } + // axis == nullopt means flip all axes; no layout remapping needed. + if (!attrs->axis.has_value()) { + return InferLayoutOutput({existing_layout}, {existing_layout}, call->attrs); + } + int axis = static_cast(attrs->axis.value()); if (axis < 0) { axis += ndim; diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 537bc9c06c04..f56088c2503d 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -3213,6 +3213,18 @@ def test_flip_infer_struct_info_wrong_inputs(): bb.normalize(relax.op.flip(x0, axis=3)) +def test_flip_infer_struct_info_axis_none(): + # axis=None (flip all axes) should produce the same struct info as the input. + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float16", ndim=3)) + x2 = relax.Var("x", R.Tensor("int32")) + + _check_inference(bb, relax.op.flip(x0), relax.TensorStructInfo((2, 10, 4), "float32")) + _check_inference(bb, relax.op.flip(x1), R.Tensor("float16", ndim=3)) + _check_inference(bb, relax.op.flip(x2), R.Tensor("int32")) + + def test_gather_elements_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 8734f76bbb37..30eabe4d0c73 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -1363,6 +1363,27 @@ def flip(var_rxplaceholder: T.handle, var_T_reverse_sequence: T.handle): tvm.ir.assert_structural_equal(mod, Expected) +def test_flip_axis_none(): + """axis=None flips all axes; verify by running on a concrete tensor via numpy.""" + import numpy as np + + @I.ir_module(s_tir=True) + class Flip: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.flip(x) + return gv + + mod = LegalizeOps()(Flip) + # After legalization the module should still be valid and runnable. + ex = tvm.compile(mod, target=tvm.target.Target("llvm", host="llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + data = np.arange(6, dtype="float32").reshape(2, 3) + out = vm["main"](tvm.runtime.tensor(data)).numpy() + expected = np.flip(data) + np.testing.assert_array_equal(out, expected) + + def test_scatter_elements(): # fmt: off @I.ir_module(s_tir=True) From 5aea5d54fd6e8ac4cb085d01aee0e520009c75a2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 00:38:45 +0000 Subject: [PATCH 6/8] [CI][LINT] Collapse multi-line ValueError message to single line for ruff-format Address ruff-format CI failure on #19607: the two-line string formed by adjacent string literals exceeds the single-line preferred form ruff-format wants when it fits. --- python/tvm/relax/transform/legalize_ops/manipulate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index efe7c333380c..ce69bf056267 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -175,8 +175,7 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr: ndim = data_sinfo.ndim if ndim < 0: raise ValueError( - "relax.flip with axis=None requires static ndim to lower to TE; " - "ndim is unknown." + "relax.flip with axis=None requires static ndim to lower to TE; ndim is unknown." ) # Apply topi.flip for each axis in sequence. result = call.args[0] From f506f284eb558cabf668e13712ced836b2868a48 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 11:26:07 +0000 Subject: [PATCH 7/8] [BUG][IR] Fix Var() default ctor regression in storage_access.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In commit ff4060f0c6, the NullValue() default initializer for StorageAccessVisitor::AccessEntry::buffer was replaced with bare `Var buffer;`. Unlike most ObjectRef subclasses, Var has an all-default-arg constructor (Var(name="v", dtype=Int(32))), so `Var()` does NOT produce a null Var — it produces a real Var named "v". This made every default-constructed AccessEntry look like it has a defined buffer, breaking the precondition of the ForNode visitor's range-relaxation step (touched.size() ICHECK). Surfaced as CUDA CI failures in paged-attention KV cache and prefill tests post-PR-19607. Restore the null semantics by using the explicit Var(ObjectPtr(nullptr)) constructor — equivalent to the deleted NullValue(). --- src/s_tir/transform/storage_access.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/s_tir/transform/storage_access.h b/src/s_tir/transform/storage_access.h index 9bc1cf42b2e0..d85dc5a3c3ae 100644 --- a/src/s_tir/transform/storage_access.h +++ b/src/s_tir/transform/storage_access.h @@ -59,7 +59,7 @@ class StorageAccessVisitor : public StmtExprVisitor { /*! \brief The thread index that access this entry */ ffi::Array threads; /*! \brief The buffer variable, if any */ - Var buffer; + Var buffer = Var(ffi::ObjectPtr(nullptr)); /*! \brief The access data type */ DataType dtype; /*! \brief The touched access range From 35178d5de9d07f306c0d6cef5c5c01650ce47528 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 26 May 2026 11:58:33 +0000 Subject: [PATCH 8/8] [REFACTOR][RELAX] Revert FlipAttrs::axis to non-optional int64_t MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Earlier in this PR the field was migrated from Integer to Optional on the rule "Integer fields become Optional". That rule doesn't fit here — relax.flip's axis is a required argument, not a nullable one, and the brief NumPy axis=None compatibility added on top complicates the surface without callers asking for it. Revert the field, signature, and Python wrapper to int64_t. Drop the axis-None legalize branch and the corresponding tests. --- include/tvm/relax/attrs/manipulate.h | 2 +- python/tvm/relax/op/manipulate.py | 9 ++---- .../transform/legalize_ops/manipulate.py | 17 +---------- src/relax/op/tensor/manipulate.cc | 28 +++++++------------ src/relax/op/tensor/manipulate.h | 2 +- tests/python/relax/test_op_manipulate.py | 12 -------- .../test_transform_legalize_ops_manipulate.py | 21 -------------- 7 files changed, 16 insertions(+), 75 deletions(-) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index e43750b9a706..71fb7b0b95ef 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -186,7 +186,7 @@ struct TileAttrs : public BaseAttrsNode { /*! \brief Attributes used in flip operators */ struct FlipAttrs : public BaseAttrsNode { - ffi::Optional axis; + int64_t axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index c1df820323a2..21fd7b565c4e 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -432,7 +432,7 @@ def tile(data: Expr, repeats: int | tuple[int] | list[int]) -> Expr: return _ffi_api.tile(data, repeats) # type: ignore -def flip(data, axis=None): +def flip(data, axis): """Reverses the order of elements along given axis while preserving array shape. Parameters @@ -440,9 +440,8 @@ def flip(data, axis=None): data : relax.Expr The input data to the operator. - axis: int, optional - The axis along which to flip. If ``None`` (default), flip all axes, - which is equivalent to NumPy's ``np.flip(data)`` with no axis argument. + axis: int + The axis along which to flip over. Returns ------- @@ -457,8 +456,6 @@ def flip(data, axis=None): relax.flip(x, axis=0) = [[3., 4.], [1., 2.]] relax.flip(x, axis=1) = [[2., 1.], [4., 3.]] - - relax.flip(x) = [[4., 3.], [2., 1.]] # flip all axes """ return _ffi_api.flip(data, axis) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index ce69bf056267..fc7ee0d12eb8 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -167,22 +167,7 @@ def _tile(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.flip") def _flip(bb: BlockBuilder, call: Call) -> Expr: - axis = call.attrs.axis - if axis is None: - # axis=None means flip all axes (NumPy semantics). - # Get the number of dimensions from the input struct info. - data_sinfo = call.args[0].struct_info - ndim = data_sinfo.ndim - if ndim < 0: - raise ValueError( - "relax.flip with axis=None requires static ndim to lower to TE; ndim is unknown." - ) - # Apply topi.flip for each axis in sequence. - result = call.args[0] - for i in range(ndim): - result = bb.emit_te(topi.flip, result, i) - return result - return bb.call_te(topi.flip, call.args[0], int(axis)) + return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis)) @register_legalize("relax.gather_elements") diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 75ba8a24a44c..f6fc45deaa39 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -2025,9 +2025,9 @@ TVM_REGISTER_OP("relax.tile") /* relax.flip */ -Expr flip(Expr data, ffi::Optional axis) { +Expr flip(Expr data, int64_t axis) { auto attrs = ffi::make_object(); - attrs->axis = std::move(axis); + attrs->axis = axis; static const Op& op = Op::Get("relax.flip"); return Call(op, {std::move(data)}, Attrs{attrs}, {}); } @@ -2043,16 +2043,13 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { } TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - // axis == nullopt means flip all axes (NumPy semantics); shape is unchanged. - if (attrs->axis.has_value()) { - int axis = static_cast(attrs->axis.value()); - if (!data_sinfo->IsUnknownNdim()) { - int ndim = data_sinfo->ndim; - if (axis < -ndim || axis >= ndim) { - ctx->ReportFatal(Diagnostic::Error(call) << "Flip requires the input axis belongs range " - "[-ndim, ndim - 1]. However, the input axis is " - << axis << ", while ndim is " << ndim); - } + int axis = static_cast(attrs->axis); + if (!data_sinfo->IsUnknownNdim()) { + int ndim = data_sinfo->ndim; + if (axis < -ndim || axis >= ndim) { + ctx->ReportFatal(Diagnostic::Error(call) << "Flip requires the input axis belongs range " + "[-ndim, ndim - 1]. However, the input axis is " + << axis << ", while ndim is " << ndim); } } return data_sinfo; @@ -2076,12 +2073,7 @@ InferLayoutOutput InferLayoutFlip( existing_layout = LayoutDecision(InitialLayout(ndim)); } - // axis == nullopt means flip all axes; no layout remapping needed. - if (!attrs->axis.has_value()) { - return InferLayoutOutput({existing_layout}, {existing_layout}, call->attrs); - } - - int axis = static_cast(attrs->axis.value()); + int axis = static_cast(attrs->axis); if (axis < 0) { axis += ndim; } diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 0f278a0a008a..a6efffff4673 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -179,7 +179,7 @@ Expr tile(Expr data, ffi::Array repeats); * \param axis The axis to flip on * \return The computed result. */ -Expr flip(Expr data, ffi::Optional axis); +Expr flip(Expr data, int64_t axis); /*! * \brief Gather elements from a tensor using indices. diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index f56088c2503d..537bc9c06c04 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -3213,18 +3213,6 @@ def test_flip_infer_struct_info_wrong_inputs(): bb.normalize(relax.op.flip(x0, axis=3)) -def test_flip_infer_struct_info_axis_none(): - # axis=None (flip all axes) should produce the same struct info as the input. - bb = relax.BlockBuilder() - x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) - x1 = relax.Var("x", R.Tensor("float16", ndim=3)) - x2 = relax.Var("x", R.Tensor("int32")) - - _check_inference(bb, relax.op.flip(x0), relax.TensorStructInfo((2, 10, 4), "float32")) - _check_inference(bb, relax.op.flip(x1), R.Tensor("float16", ndim=3)) - _check_inference(bb, relax.op.flip(x2), R.Tensor("int32")) - - def test_gather_elements_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 30eabe4d0c73..8734f76bbb37 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -1363,27 +1363,6 @@ def flip(var_rxplaceholder: T.handle, var_T_reverse_sequence: T.handle): tvm.ir.assert_structural_equal(mod, Expected) -def test_flip_axis_none(): - """axis=None flips all axes; verify by running on a concrete tensor via numpy.""" - import numpy as np - - @I.ir_module(s_tir=True) - class Flip: - @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): - gv = R.flip(x) - return gv - - mod = LegalizeOps()(Flip) - # After legalization the module should still be valid and runnable. - ex = tvm.compile(mod, target=tvm.target.Target("llvm", host="llvm")) - vm = relax.VirtualMachine(ex, tvm.cpu()) - data = np.arange(6, dtype="float32").reshape(2, 3) - out = vm["main"](tvm.runtime.tensor(data)).numpy() - expected = np.flip(data) - np.testing.assert_array_equal(out, expected) - - def test_scatter_elements(): # fmt: off @I.ir_module(s_tir=True)