From b5bc285bc59b7e4e552ea3e8249488a8d4405c12 Mon Sep 17 00:00:00 2001 From: GinShio Date: Tue, 4 Nov 2025 21:52:07 +0800 Subject: [PATCH] Support enum class in struct-backed type We may use enumeration class in struct backed type, which makes struct type more semantic. However, we need to force these values to be cast from integers to the target type. --- example/ExampleDialect.td | 2 +- example/ExampleMain.cpp | 2 +- lib/TableGen/DialectType.cpp | 24 +++++++++++++------ test/example/generated/ExampleDialect.cpp.inc | 8 +++---- test/example/generated/ExampleDialect.h.inc | 12 +++++----- 5 files changed, 29 insertions(+), 19 deletions(-) diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index 8f6d172..1c7d7fd 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -367,7 +367,7 @@ def StructBackedType : DialectType { let description = [{ Test that a struct-backed type works correctly. }]; - let typeArguments = (args AttrI32:$field0, AttrI32:$field1, AttrI32:$field2); + let typeArguments = (args AttrI32:$field0, AttrI8:$field1, AttrVectorKind:$field2); let representation = (repr_struct (IntegerType 41)); let defaultGetterHasExplicitContextArgument = 1; diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index b78c601..763a9ab 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -149,7 +149,7 @@ void createFunctionExample(Module &module, const Twine &name) { b.create("Hello world!"); - xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, 2); + xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, xd::cpp::VectorKind::BigEndian); auto *structBackedVal = b.create(structBackedTy, b.getInt32(42), "gen.struct.backed.val"); b.create(structBackedVal, "consume.struct.backed.val"); diff --git a/lib/TableGen/DialectType.cpp b/lib/TableGen/DialectType.cpp index 17c2916..fc0b53c 100644 --- a/lib/TableGen/DialectType.cpp +++ b/lib/TableGen/DialectType.cpp @@ -201,17 +201,24 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const { out << " static bool classof(const ::llvm::Type *t);\n\n"; unsigned fieldIdx = 1; // sentinel + auto getCastExpr = [&fmt](const NamedValue &argument, + llvm::StringRef expr) -> std::string { + return tgfmt(cast(argument.type)->getFromUnsigned(), &fmt, expr); + }; for (const auto &argument : typeArguments()) { std::string camel = convertToCamelFromSnakeCase(argument.name, true); out << tgfmt( - R"( unsigned get$0() const { - ::llvm::Type *elt = getElementType($1); + R"( $0 get$1() const { + ::llvm::Type *elt = getElementType($2); if (elt->isStructTy()) - return 0; - return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); + return $3; + return $4; } )", - &fmt, camel, fieldIdx++); + &fmt, argument.type->getCppType(), camel, fieldIdx++, + getCastExpr(argument, "0"), + getCastExpr(argument, + "::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth()")); } out << " };\n\n"; @@ -307,14 +314,17 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const { " $fields.push_back(::llvm::IntegerType::get($_context, $0));\n", &fmt, Twine(m_structSentinelBitWidth)); - for (const auto &getterArg : getterArgs) { + for (const auto &[argument, getterArg] : + llvm::zip(typeArguments(), getterArgs)) { + std::string castExpr = tgfmt(cast(argument.type)->getToUnsigned(), + &fmt, getterArg.name); out << tgfmt(R"( if ($0 == 0) $fields.push_back(::llvm::StructType::get($_context)); else $fields.push_back(::llvm::IntegerType::get($_context, $0)); )", - &fmt, getterArg.name); + &fmt, castExpr); } out << tgfmt(" auto *$st = ::llvm::StructType::create($_context, " "$fields, $os.str(), /*isPacked=*/false);\n", diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 2a39b5b..902dc84 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -258,10 +258,10 @@ m_attributeLists[6] = argAttrList.addFnAttributes(context, attrBuilder); } } -StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2) { - +StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint8_t field1, VectorKind field2) { +static_assert(sizeof(field2) <= sizeof(unsigned)); std::string name; ::llvm::raw_string_ostream os(name); os << "struct.backed"; os << '.' << (uint64_t)field0; @@ -280,10 +280,10 @@ StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t fiel else fields.push_back(::llvm::IntegerType::get(ctx, field1)); - if (field2 == 0) + if (static_cast(field2) == 0) fields.push_back(::llvm::StructType::get(ctx)); else - fields.push_back(::llvm::IntegerType::get(ctx, field2)); + fields.push_back(::llvm::IntegerType::get(ctx, static_cast(field2))); auto *st = ::llvm::StructType::create(ctx, fields, os.str(), /*isPacked=*/false); return static_cast(st); } diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index 46189b9..8530a3a 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -58,27 +58,27 @@ namespace xd::cpp { using ::llvm::StructType::getElementType; static StructBackedType *get( - ::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2); + ::llvm::LLVMContext & ctx, uint32_t field0, uint8_t field1, VectorKind field2); static bool classof(const ::llvm::Type *t); - unsigned getField0() const { + uint32_t getField0() const { ::llvm::Type *elt = getElementType(1); if (elt->isStructTy()) return 0; return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); } - unsigned getField1() const { + uint8_t getField1() const { ::llvm::Type *elt = getElementType(2); if (elt->isStructTy()) return 0; return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); } - unsigned getField2() const { + VectorKind getField2() const { ::llvm::Type *elt = getElementType(3); if (elt->isStructTy()) - return 0; - return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); + return static_cast(0); + return static_cast(::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth()); } };