From 8d8dfc1c3403f83ce529feb57f56389b6966192a Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 29 Apr 2022 15:38:05 +0100 Subject: [PATCH 01/12] add runtime code for requesting amx permissions --- src/runtime/CMakeLists.txt | 16 +++++++++++++++- src/runtime/amx.cpp | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 src/runtime/amx.cpp diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index 946784f662d5..a47f24cac8a7 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -5,6 +5,7 @@ set(RUNTIME_CPP alignment_32 alignment_64 allocation_cache + amx android_clock android_host_cpu_count android_io @@ -172,6 +173,8 @@ foreach (i IN LISTS RUNTIME_CPP) set(fpic -fpic) # for the generic windows 64-bit target, we need -fshort-wchar set(fshort-wchar "") + # xsave is required for checking AMX + set(xsave "") # Windows if (i MATCHES "windows_.*") # must omit -fpic, otherwise clang will complain with the following: @@ -208,6 +211,17 @@ foreach (i IN LISTS RUNTIME_CPP) set(TARGET "le64-unknown-windows-unknown") endif () endif() + # AMX only works on x86 + elseif (i MATCHES "amx") + message(STATUS "found the amx!") + set(xsave "-mxsave") + if (j EQUAL 32) + set(TARGET "i386-unknown-unknown-unknown") + elseif (j EQUAL 64) + set(TARGET "x86_64-unknown-unknown-unknown") + else () + continue() + endif () # Everything else else() if (j EQUAL 32) @@ -231,7 +245,7 @@ foreach (i IN LISTS RUNTIME_CPP) set(INITMOD "_initmod_${i}_${j}${SUFFIX}.cpp") set(SYMBOL "halide_internal_initmod_${i}_${j}${SUFFIX}") - set(clang_flags ${RUNTIME_CXX_FLAGS} ${fpic} ${fshort-wchar} ${RUNTIME_DEFINES${SUFFIX}} -m${j} -target ${TARGET} -emit-llvm -S -MD -MF "${basename}.d") + set(clang_flags ${RUNTIME_CXX_FLAGS} ${fpic} ${fshort-wchar} ${xsave} ${RUNTIME_DEFINES${SUFFIX}} -m${j} -target ${TARGET} -emit-llvm -S -MD -MF "${basename}.d") if (Halide_CLANG_TIDY_BUILD) # Create a 'fake' entry just so that clang-tidy will see a C++ compilation command diff --git a/src/runtime/amx.cpp b/src/runtime/amx.cpp new file mode 100644 index 000000000000..ff63b29667e6 --- /dev/null +++ b/src/runtime/amx.cpp @@ -0,0 +1,32 @@ +#include "runtime_internal.h" + +#define SYS_arch_prctl 158 + +extern "C" long syscall(long sysno, ...) throw(); + +extern "C" WEAK int halide_amx_req_perm(void *user_context) { + constexpr int XFEATURE_XTILECFG = 17; + constexpr int XFEATURE_XTILEDATA = 18; + constexpr int ARCH_REQ_XCOMP_PERM = 0x1023; + + // xgetbv instruction should always be present on CPUs with AMX + long long res = __builtin_ia32_xgetbv(0); + + // if AMX is not supported by the OS these bits are not set + if (!(res & (1 << XFEATURE_XTILECFG))) { + return -2; + } + + if (!(res & (1 << XFEATURE_XTILEDATA))) { + return -2; + } + + // we must request permission to use AMX + long ret = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + + if (ret) { + return -1; + } + + return 0; +} From 3ecebf4603e2cee4663853ca869c02f89436ffd7 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 29 Apr 2022 16:26:55 +0100 Subject: [PATCH 02/12] add amx module to cpp initmods --- src/LLVM_Runtime_Linker.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/LLVM_Runtime_Linker.cpp b/src/LLVM_Runtime_Linker.cpp index 4995b197b433..1f6674a684a3 100644 --- a/src/LLVM_Runtime_Linker.cpp +++ b/src/LLVM_Runtime_Linker.cpp @@ -229,6 +229,7 @@ DECLARE_NO_INITMOD(windows_d3d12compute_arm) #endif // WITH_D3D12 #ifdef WITH_X86 +DECLARE_CPP_INITMOD(amx) DECLARE_LL_INITMOD(x86_amx) DECLARE_LL_INITMOD(x86_avx512) DECLARE_LL_INITMOD(x86_avx2) @@ -237,6 +238,7 @@ DECLARE_LL_INITMOD(x86) DECLARE_LL_INITMOD(x86_sse41) DECLARE_CPP_INITMOD(x86_cpu_features) #else +DECLARE_NO_INITMOD(amx) DECLARE_NO_INITMOD(x86_amx) DECLARE_NO_INITMOD(x86_avx512) DECLARE_NO_INITMOD(x86_avx2) @@ -1087,6 +1089,7 @@ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVM modules.push_back(get_initmod_x86_avx512_ll(c)); } if (t.has_feature(Target::AVX512_SapphireRapids)) { + modules.push_back(get_initmod_amx(c, bits_64, debug)); modules.push_back(get_initmod_x86_amx_ll(c)); } if (t.has_feature(Target::Profile)) { From 320912e0022d69eab13ecd2bf0d83ad600272fde Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Tue, 3 May 2022 11:16:35 +0100 Subject: [PATCH 03/12] add error messages if support is not found --- src/runtime/amx.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/runtime/amx.cpp b/src/runtime/amx.cpp index ff63b29667e6..3ce2c2350433 100644 --- a/src/runtime/amx.cpp +++ b/src/runtime/amx.cpp @@ -4,7 +4,7 @@ extern "C" long syscall(long sysno, ...) throw(); -extern "C" WEAK int halide_amx_req_perm(void *user_context) { +extern "C" WEAK int halide_amx_req_perm() { constexpr int XFEATURE_XTILECFG = 17; constexpr int XFEATURE_XTILEDATA = 18; constexpr int ARCH_REQ_XCOMP_PERM = 0x1023; @@ -14,10 +14,12 @@ extern "C" WEAK int halide_amx_req_perm(void *user_context) { // if AMX is not supported by the OS these bits are not set if (!(res & (1 << XFEATURE_XTILECFG))) { + halide_error(nullptr, "XTILECFG not available for AMX instructions."); return -2; } if (!(res & (1 << XFEATURE_XTILEDATA))) { + halide_error(nullptr, "XTILEDATA not available for AMX instructions."); return -2; } @@ -25,6 +27,7 @@ extern "C" WEAK int halide_amx_req_perm(void *user_context) { long ret = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); if (ret) { + halide_error(nullptr, "Failed to enable AMX instructions."); return -1; } From fa14962311cf5b137240c4693a5f310f35c031d0 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Tue, 3 May 2022 14:05:14 +0100 Subject: [PATCH 04/12] add class which will inject amx permission request --- src/AMXReqPerm.cpp | 21 +++++++++++++++++++++ src/AMXReqPerm.h | 21 +++++++++++++++++++++ src/CMakeLists.txt | 2 ++ 3 files changed, 44 insertions(+) create mode 100644 src/AMXReqPerm.cpp create mode 100644 src/AMXReqPerm.h diff --git a/src/AMXReqPerm.cpp b/src/AMXReqPerm.cpp new file mode 100644 index 000000000000..8297f3f9b023 --- /dev/null +++ b/src/AMXReqPerm.cpp @@ -0,0 +1,21 @@ +#include "AMXReqPerm.h" + +#include "IR.h" + +namespace Halide { +namespace Internal { +void AMXReqPerm::enable_amx() { + requires_amx_ = true; +} + +Stmt AMXReqPerm::inject_request_amx(Stmt s) { + if (requires_amx_) { + return Block::make({Evaluate::make(Call::make(type_of(), "halide_amx_req_perm", {}, Call::Extern)), + s, + Evaluate::make(Call::make(type_of(), "halide_amx_free_perm", {}, Call::Extern))}); + } else { + return s; + } +} +} // namespace Internal +} // namespace Halide \ No newline at end of file diff --git a/src/AMXReqPerm.h b/src/AMXReqPerm.h new file mode 100644 index 000000000000..42c57af28254 --- /dev/null +++ b/src/AMXReqPerm.h @@ -0,0 +1,21 @@ +#ifndef HALIDE_AMX_REQ_PERM_H +#define HALIDE_AMX_REQ_PERM_H + +#include "Expr.h" + +namespace Halide { +namespace Internal { +class AMXReqPerm { + bool requires_amx_{false}; + +public: + AMXReqPerm() = default; + + void enable_amx(); + + Stmt inject_request_amx(Stmt s); +}; +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c31e37c32a20..4a9bfbf36391 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -12,6 +12,7 @@ set(HEADER_FILES AddParameterChecks.h AlignLoads.h AllocationBoundsInference.h + AMXReqPerm.h ApplySplit.h Argument.h AssociativeOpsTable.h @@ -175,6 +176,7 @@ set(SOURCE_FILES AddParameterChecks.cpp AlignLoads.cpp AllocationBoundsInference.cpp + AMXReqPerm.cpp ApplySplit.cpp Argument.cpp AssociativeOpsTable.cpp From fede3601f4eba08124e4be9587ed12ef7a42d199 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 4 May 2022 18:37:35 +0100 Subject: [PATCH 05/12] print the error code on failure --- src/runtime/amx.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/runtime/amx.cpp b/src/runtime/amx.cpp index 3ce2c2350433..28eb4888ffbf 100644 --- a/src/runtime/amx.cpp +++ b/src/runtime/amx.cpp @@ -1,8 +1,10 @@ +#include "HalideRuntime.h" +#include "printer.h" #include "runtime_internal.h" #define SYS_arch_prctl 158 -extern "C" long syscall(long sysno, ...) throw(); +extern "C" int syscall(long sysno, ...) throw(); extern "C" WEAK int halide_amx_req_perm() { constexpr int XFEATURE_XTILECFG = 17; @@ -14,22 +16,24 @@ extern "C" WEAK int halide_amx_req_perm() { // if AMX is not supported by the OS these bits are not set if (!(res & (1 << XFEATURE_XTILECFG))) { - halide_error(nullptr, "XTILECFG not available for AMX instructions."); + error(nullptr) << "XTILECFG not available for AMX instructions.\n"; return -2; } if (!(res & (1 << XFEATURE_XTILEDATA))) { - halide_error(nullptr, "XTILEDATA not available for AMX instructions."); + error(nullptr) << "XTILEDATA not available for AMX instruction.\n"; return -2; } // we must request permission to use AMX - long ret = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + auto ret = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); if (ret) { - halide_error(nullptr, "Failed to enable AMX instructions."); + error(nullptr) << "Failed to enable AMX instructions: " << ret << '\n'; return -1; } + debug(nullptr) << "AMX permissions acquired\n"; + return 0; } From 400651629c485535164e05e7c3bc97bd615ff0a6 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 5 May 2022 13:27:44 +0100 Subject: [PATCH 06/12] add build of amx module to Makefile --- Makefile | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/Makefile b/Makefile index 30b7fb9d7916..6eb64613c449 100644 --- a/Makefile +++ b/Makefile @@ -756,6 +756,7 @@ RUNTIME_CPP_COMPONENTS = \ alignment_32 \ allocation_cache \ alignment_64 \ + amx \ android_clock \ android_host_cpu_count \ android_io \ @@ -1008,12 +1009,23 @@ RUNTIME_TRIPLE_WIN_ARM_32 = "arm-unknown-windows-unknown" RUNTIME_TRIPLE_WIN_ARM_64 = "aarch64-unknown-windows-unknown" RUNTIME_TRIPLE_WIN_GENERIC_64 = "le64-unknown-windows-unknown" +RUNTIME_TRIPLE_AMX_X86_32 = "i386-unknown-unknown-unknown" +RUNTIME_TRIPLE_AMX_X86_64 = "x86_64-unknown-unknown-unknown" + # `-fno-threadsafe-statics` is very important here (note that it allows us to use a 'modern' C++ # standard but still skip threadsafe guards for static initialization in our runtime code) # # `-fno-rtti` is necessary to allow us to use classes with virtual functions in the runtime code RUNTIME_CXX_FLAGS = -std=c++17 -O3 -fno-vectorize -ffreestanding -fno-blocks -fno-exceptions -fno-unwind-tables -fno-threadsafe-statics -fno-rtti +$(BUILD_DIR)/initmod.amx_x86_32.ll: $(SRC_DIR)/runtime/amx.cpp $(BUILD_DIR)/clang_ok + @mkdir -p $(@D) + $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_AMX_X86_32) -mxsave -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/amx.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.amx_x86_32.d + +$(BUILD_DIR)/initmod.amx_x86_64.ll: $(SRC_DIR)/runtime/amx.cpp $(BUILD_DIR)/clang_ok + @mkdir -p $(@D) + $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_AMX_X86_64) -mxsave -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/amx.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.amx_x86_64.d + $(BUILD_DIR)/initmod.windows_%_x86_32.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp $(BUILD_DIR)/clang_ok @mkdir -p $(@D) $(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_X86_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_x86.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_x86_32.d From 750809207f742398fa8daf16f34c0c7f423e27bc Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Wed, 3 Aug 2022 16:27:43 +0100 Subject: [PATCH 07/12] Use AMXRequestPermission in ExtractTileOperations Whenever an AMX instruction is generated ReqPerm will be set to true, which as a result in a later lowering step will insert the appropriate calls to enable AMX instructions on a kernel which supports this. --- src/AMXReqPerm.cpp | 21 ----------------- src/AMXReqPerm.h | 21 ----------------- src/CMakeLists.txt | 2 -- src/ExtractTileOperations.cpp | 44 +++++++++++++++++++++++++++++------ src/ExtractTileOperations.h | 2 ++ 5 files changed, 39 insertions(+), 51 deletions(-) delete mode 100644 src/AMXReqPerm.cpp delete mode 100644 src/AMXReqPerm.h diff --git a/src/AMXReqPerm.cpp b/src/AMXReqPerm.cpp deleted file mode 100644 index 8297f3f9b023..000000000000 --- a/src/AMXReqPerm.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "AMXReqPerm.h" - -#include "IR.h" - -namespace Halide { -namespace Internal { -void AMXReqPerm::enable_amx() { - requires_amx_ = true; -} - -Stmt AMXReqPerm::inject_request_amx(Stmt s) { - if (requires_amx_) { - return Block::make({Evaluate::make(Call::make(type_of(), "halide_amx_req_perm", {}, Call::Extern)), - s, - Evaluate::make(Call::make(type_of(), "halide_amx_free_perm", {}, Call::Extern))}); - } else { - return s; - } -} -} // namespace Internal -} // namespace Halide \ No newline at end of file diff --git a/src/AMXReqPerm.h b/src/AMXReqPerm.h deleted file mode 100644 index 42c57af28254..000000000000 --- a/src/AMXReqPerm.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef HALIDE_AMX_REQ_PERM_H -#define HALIDE_AMX_REQ_PERM_H - -#include "Expr.h" - -namespace Halide { -namespace Internal { -class AMXReqPerm { - bool requires_amx_{false}; - -public: - AMXReqPerm() = default; - - void enable_amx(); - - Stmt inject_request_amx(Stmt s); -}; -} // namespace Internal -} // namespace Halide - -#endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4a9bfbf36391..c31e37c32a20 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -12,7 +12,6 @@ set(HEADER_FILES AddParameterChecks.h AlignLoads.h AllocationBoundsInference.h - AMXReqPerm.h ApplySplit.h Argument.h AssociativeOpsTable.h @@ -176,7 +175,6 @@ set(SOURCE_FILES AddParameterChecks.cpp AlignLoads.cpp AllocationBoundsInference.cpp - AMXReqPerm.cpp ApplySplit.cpp Argument.cpp AssociativeOpsTable.cpp diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 8fdcea73f34b..d81797515456 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -1,5 +1,6 @@ #include "ExtractTileOperations.h" +#include "IR.h" #include "IRMatch.h" #include "IRMutator.h" #include "IROperator.h" @@ -37,6 +38,22 @@ using std::string; using std::vector; namespace { +struct RequestPermission { + bool requires_amx_{false}; + + void enable_amx() { + requires_amx_ = true; + } + + /// Inject a call to enable amx through a syscall if amx was detected + Stmt inject_request_amx(Stmt s) { + if (requires_amx_) { + return Block::make({Evaluate::make(Call::make(type_of(), "halide_amx_req_perm", {}, Call::Extern)), std::move(s)}); + } else { + return s; + } + } +}; template struct Tile { @@ -379,7 +396,7 @@ struct Matmul { int tile_r; }; -Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_type) { +Matmul convert_to_matmul(const Store *op, RequestPermission &perm, const string &new_name, AMXOpType op_type) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); const auto wild_u8x = Variable::make(UInt(8, 0), "*"); @@ -522,7 +539,7 @@ Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_t return {true, std::move(store), tile_x, tile_y, tile_r}; } -Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) { +Stmt convert_to_zero(const Store *op, RequestPermission &perm, int tile_x, int tile_y, const string &new_name) { if (const auto *ramp = op->index.as()) { if (const auto *bcast = op->value.as()) { if (is_const_one(ramp->stride) && @@ -534,6 +551,7 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_ const auto &store_type = op->value.type(); // will be f32 or i32 auto tile_zero_type = store_type.with_lanes(1024 / store_type.bytes()); + perm.enable_amx(); auto val = Call::make(tile_zero_type, "tile_zero", {rows, colbytes}, Call::Intrinsic); auto store = Store::make(new_name, std::move(val), Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); return store; @@ -543,7 +561,7 @@ Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_ return {}; } -Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { +Stmt convert_to_tile_store(const Store *op, RequestPermission &perm, const string &amx_name, int tile_x, int tile_y) { auto tile = get_2d_tile_index(op->index); if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { auto out = Variable::make(Handle(), op->name); @@ -552,6 +570,7 @@ Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, auto bytes = op->value.type().bytes(); internal_assert(bytes == 4) << "AMX store only supported for int32 and float32 output, not for " << op->value.type() << "\n"; // {tile_x, tile_y, var, base, stride} + perm.enable_amx(); auto store = Call::make(Int(32), "tile_store", {tile_x, tile_y * bytes, std::move(out), tile.base * bytes, tile.stride[0] * bytes, std::move(tile_val)}, Call::Intrinsic); return Evaluate::make(std::move(store)); } @@ -561,6 +580,8 @@ Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, class ExtractTileOperations : public IRMutator { using IRMutator::visit; + RequestPermission &perm; + string tile_name; string amx_name; vector pending_stores; @@ -570,6 +591,12 @@ class ExtractTileOperations : public IRMutator { int found_tile_r = -1; AMXOpType op_type; +public: + ExtractTileOperations(RequestPermission &arp) + : perm(arp) { + } + +private: Stmt visit(const Allocate *op) override { if (op->memory_type == MemoryType::AMXTile) { user_assert( @@ -634,12 +661,12 @@ class ExtractTileOperations : public IRMutator { if (!load || load->name != tile_name) { return op; } - auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y); + auto store = convert_to_tile_store(op, perm, amx_name, found_tile_x, found_tile_y); user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value"; return store; } - auto matmul = convert_to_matmul(op, amx_name, op_type); + auto matmul = convert_to_matmul(op, perm, amx_name, op_type); if (matmul.result) { user_assert( (found_tile_x < 0 || matmul.tile_x == found_tile_x) && @@ -658,7 +685,7 @@ class ExtractTileOperations : public IRMutator { return op; } - auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name); + auto zero = convert_to_zero(op, perm, found_tile_x, found_tile_y, amx_name); if (zero.defined()) { return zero; } @@ -672,7 +699,10 @@ class ExtractTileOperations : public IRMutator { } // namespace Stmt extract_tile_operations(const Stmt &s) { - return ExtractTileOperations().mutate(s); + RequestPermission perm; + + Stmt s_extracted = ExtractTileOperations{perm}.mutate(s); + return perm.inject_request_amx(std::move(s_extracted)); } } // namespace Internal } // namespace Halide diff --git a/src/ExtractTileOperations.h b/src/ExtractTileOperations.h index 918e3b1b9940..a3345b2b96a2 100644 --- a/src/ExtractTileOperations.h +++ b/src/ExtractTileOperations.h @@ -11,6 +11,8 @@ namespace Halide { namespace Internal { +class AMXRequestPermission; + /** Rewrite any AMX tile operations that have been stored in the AMXTile memory * type as intrinsic calls, to be used in the X86 backend. */ Stmt extract_tile_operations(const Stmt &s); From bc161c29d445491381f23e936527e997746683b7 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 11 Aug 2022 16:48:38 +0100 Subject: [PATCH 08/12] rename amx.cpp -> linux_amx.cpp Rename the file to indicate that this is linux specific code. Additionally makes sure the module is only available for linux OS. --- Makefile | 2 +- src/LLVM_Runtime_Linker.cpp | 8 +++++--- src/runtime/CMakeLists.txt | 2 +- src/runtime/{amx.cpp => linux_amx.cpp} | 0 4 files changed, 7 insertions(+), 5 deletions(-) rename src/runtime/{amx.cpp => linux_amx.cpp} (100%) diff --git a/Makefile b/Makefile index 6eb64613c449..ba3094ec459c 100644 --- a/Makefile +++ b/Makefile @@ -756,7 +756,6 @@ RUNTIME_CPP_COMPONENTS = \ alignment_32 \ allocation_cache \ alignment_64 \ - amx \ android_clock \ android_host_cpu_count \ android_io \ @@ -782,6 +781,7 @@ RUNTIME_CPP_COMPONENTS = \ hexagon_dma \ hexagon_host \ ios_io \ + linux_amx \ linux_clock \ linux_host_cpu_count \ linux_yield \ diff --git a/src/LLVM_Runtime_Linker.cpp b/src/LLVM_Runtime_Linker.cpp index 1f6674a684a3..c3912867c8fd 100644 --- a/src/LLVM_Runtime_Linker.cpp +++ b/src/LLVM_Runtime_Linker.cpp @@ -229,7 +229,7 @@ DECLARE_NO_INITMOD(windows_d3d12compute_arm) #endif // WITH_D3D12 #ifdef WITH_X86 -DECLARE_CPP_INITMOD(amx) +DECLARE_CPP_INITMOD(linux_amx) DECLARE_LL_INITMOD(x86_amx) DECLARE_LL_INITMOD(x86_avx512) DECLARE_LL_INITMOD(x86_avx2) @@ -238,7 +238,7 @@ DECLARE_LL_INITMOD(x86) DECLARE_LL_INITMOD(x86_sse41) DECLARE_CPP_INITMOD(x86_cpu_features) #else -DECLARE_NO_INITMOD(amx) +DECLARE_NO_INITMOD(linux_amx) DECLARE_NO_INITMOD(x86_amx) DECLARE_NO_INITMOD(x86_avx512) DECLARE_NO_INITMOD(x86_avx2) @@ -1089,7 +1089,9 @@ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVM modules.push_back(get_initmod_x86_avx512_ll(c)); } if (t.has_feature(Target::AVX512_SapphireRapids)) { - modules.push_back(get_initmod_amx(c, bits_64, debug)); + if (t.os == Target::Linux) { + modules.push_back(get_initmod_linux_amx(c, bits_64, debug)); + } modules.push_back(get_initmod_x86_amx_ll(c)); } if (t.has_feature(Target::Profile)) { diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index a47f24cac8a7..d2b89de0fcb2 100644 --- a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -5,7 +5,6 @@ set(RUNTIME_CPP alignment_32 alignment_64 allocation_cache - amx android_clock android_host_cpu_count android_io @@ -31,6 +30,7 @@ set(RUNTIME_CPP hexagon_dma_pool hexagon_host ios_io + linux_amx linux_clock linux_host_cpu_count linux_yield diff --git a/src/runtime/amx.cpp b/src/runtime/linux_amx.cpp similarity index 100% rename from src/runtime/amx.cpp rename to src/runtime/linux_amx.cpp From d02e9bf5c0ff202cacb3159aa2aa747db5cd59f9 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 11 Aug 2022 16:50:03 +0100 Subject: [PATCH 09/12] rename to `halide_enable_amx` This name reflects the intention better --- src/runtime/linux_amx.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/linux_amx.cpp b/src/runtime/linux_amx.cpp index 28eb4888ffbf..2b5723f6bb9e 100644 --- a/src/runtime/linux_amx.cpp +++ b/src/runtime/linux_amx.cpp @@ -6,7 +6,7 @@ extern "C" int syscall(long sysno, ...) throw(); -extern "C" WEAK int halide_amx_req_perm() { +extern "C" WEAK int halide_enable_amx() { constexpr int XFEATURE_XTILECFG = 17; constexpr int XFEATURE_XTILEDATA = 18; constexpr int ARCH_REQ_XCOMP_PERM = 0x1023; From a3522176b99a647ace039b3bdb599d613c78ff91 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 11 Aug 2022 16:50:29 +0100 Subject: [PATCH 10/12] add a symbol for `halide_enable_amx` in the runtime header --- src/runtime/runtime_internal.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/runtime/runtime_internal.h b/src/runtime/runtime_internal.h index 2801f9bfedc5..01a22c507cdf 100644 --- a/src/runtime/runtime_internal.h +++ b/src/runtime/runtime_internal.h @@ -186,6 +186,8 @@ struct halide_pseudostack_slot_t { WEAK void halide_use_jit_module(); WEAK void halide_release_jit_module(); +WEAK int halide_enable_amx(); + WEAK_INLINE int halide_malloc_alignment(); void halide_thread_yield(); From 58dcd78cfe25c98b20772c97eff8259e5e185991 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Thu, 11 Aug 2022 17:39:09 +0100 Subject: [PATCH 11/12] Revert changes to ExtractTileOperations This reverts commit 69228f73354d819f04513089fefd72c2b4abf038. This reverts commit 29e765898359a74f77fbd531d41ce9d5fbe66045. --- src/ExtractTileOperations.cpp | 44 ++++++----------------------------- src/ExtractTileOperations.h | 2 -- 2 files changed, 7 insertions(+), 39 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index d81797515456..8fdcea73f34b 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -1,6 +1,5 @@ #include "ExtractTileOperations.h" -#include "IR.h" #include "IRMatch.h" #include "IRMutator.h" #include "IROperator.h" @@ -38,22 +37,6 @@ using std::string; using std::vector; namespace { -struct RequestPermission { - bool requires_amx_{false}; - - void enable_amx() { - requires_amx_ = true; - } - - /// Inject a call to enable amx through a syscall if amx was detected - Stmt inject_request_amx(Stmt s) { - if (requires_amx_) { - return Block::make({Evaluate::make(Call::make(type_of(), "halide_amx_req_perm", {}, Call::Extern)), std::move(s)}); - } else { - return s; - } - } -}; template struct Tile { @@ -396,7 +379,7 @@ struct Matmul { int tile_r; }; -Matmul convert_to_matmul(const Store *op, RequestPermission &perm, const string &new_name, AMXOpType op_type) { +Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_type) { // m[ramp(0, 1, S)] = VectorAdd(lhs[{XYR tile}] * xX(rhs[{YR tile}])) + m[ramp(0, 1, S)] const auto wild_i8x = Variable::make(Int(8, 0), "*"); const auto wild_u8x = Variable::make(UInt(8, 0), "*"); @@ -539,7 +522,7 @@ Matmul convert_to_matmul(const Store *op, RequestPermission &perm, const string return {true, std::move(store), tile_x, tile_y, tile_r}; } -Stmt convert_to_zero(const Store *op, RequestPermission &perm, int tile_x, int tile_y, const string &new_name) { +Stmt convert_to_zero(const Store *op, int tile_x, int tile_y, const string &new_name) { if (const auto *ramp = op->index.as()) { if (const auto *bcast = op->value.as()) { if (is_const_one(ramp->stride) && @@ -551,7 +534,6 @@ Stmt convert_to_zero(const Store *op, RequestPermission &perm, int tile_x, int t const auto &store_type = op->value.type(); // will be f32 or i32 auto tile_zero_type = store_type.with_lanes(1024 / store_type.bytes()); - perm.enable_amx(); auto val = Call::make(tile_zero_type, "tile_zero", {rows, colbytes}, Call::Intrinsic); auto store = Store::make(new_name, std::move(val), Ramp::make(0, 1, 256), Parameter(), const_true(256), ModulusRemainder()); return store; @@ -561,7 +543,7 @@ Stmt convert_to_zero(const Store *op, RequestPermission &perm, int tile_x, int t return {}; } -Stmt convert_to_tile_store(const Store *op, RequestPermission &perm, const string &amx_name, int tile_x, int tile_y) { +Stmt convert_to_tile_store(const Store *op, const string &amx_name, int tile_x, int tile_y) { auto tile = get_2d_tile_index(op->index); if (tile.result && tile.extent[0] == tile_x && tile.extent[1] == tile_y) { auto out = Variable::make(Handle(), op->name); @@ -570,7 +552,6 @@ Stmt convert_to_tile_store(const Store *op, RequestPermission &perm, const strin auto bytes = op->value.type().bytes(); internal_assert(bytes == 4) << "AMX store only supported for int32 and float32 output, not for " << op->value.type() << "\n"; // {tile_x, tile_y, var, base, stride} - perm.enable_amx(); auto store = Call::make(Int(32), "tile_store", {tile_x, tile_y * bytes, std::move(out), tile.base * bytes, tile.stride[0] * bytes, std::move(tile_val)}, Call::Intrinsic); return Evaluate::make(std::move(store)); } @@ -580,8 +561,6 @@ Stmt convert_to_tile_store(const Store *op, RequestPermission &perm, const strin class ExtractTileOperations : public IRMutator { using IRMutator::visit; - RequestPermission &perm; - string tile_name; string amx_name; vector pending_stores; @@ -591,12 +570,6 @@ class ExtractTileOperations : public IRMutator { int found_tile_r = -1; AMXOpType op_type; -public: - ExtractTileOperations(RequestPermission &arp) - : perm(arp) { - } - -private: Stmt visit(const Allocate *op) override { if (op->memory_type == MemoryType::AMXTile) { user_assert( @@ -661,12 +634,12 @@ class ExtractTileOperations : public IRMutator { if (!load || load->name != tile_name) { return op; } - auto store = convert_to_tile_store(op, perm, amx_name, found_tile_x, found_tile_y); + auto store = convert_to_tile_store(op, amx_name, found_tile_x, found_tile_y); user_assert(store.defined()) << "Store to AMX tile allocation of a non-tile value"; return store; } - auto matmul = convert_to_matmul(op, perm, amx_name, op_type); + auto matmul = convert_to_matmul(op, amx_name, op_type); if (matmul.result) { user_assert( (found_tile_x < 0 || matmul.tile_x == found_tile_x) && @@ -685,7 +658,7 @@ class ExtractTileOperations : public IRMutator { return op; } - auto zero = convert_to_zero(op, perm, found_tile_x, found_tile_y, amx_name); + auto zero = convert_to_zero(op, found_tile_x, found_tile_y, amx_name); if (zero.defined()) { return zero; } @@ -699,10 +672,7 @@ class ExtractTileOperations : public IRMutator { } // namespace Stmt extract_tile_operations(const Stmt &s) { - RequestPermission perm; - - Stmt s_extracted = ExtractTileOperations{perm}.mutate(s); - return perm.inject_request_amx(std::move(s_extracted)); + return ExtractTileOperations().mutate(s); } } // namespace Internal } // namespace Halide diff --git a/src/ExtractTileOperations.h b/src/ExtractTileOperations.h index a3345b2b96a2..918e3b1b9940 100644 --- a/src/ExtractTileOperations.h +++ b/src/ExtractTileOperations.h @@ -11,8 +11,6 @@ namespace Halide { namespace Internal { -class AMXRequestPermission; - /** Rewrite any AMX tile operations that have been stored in the AMXTile memory * type as intrinsic calls, to be used in the X86 backend. */ Stmt extract_tile_operations(const Stmt &s); From 497536e4aad5a6efc9fc70644c3039010acb11d4 Mon Sep 17 00:00:00 2001 From: Frederik Engels Date: Fri, 12 Aug 2022 16:13:24 +0100 Subject: [PATCH 12/12] move the `halide_enable_amx` symbol to HalideRuntime.h --- src/runtime/HalideRuntime.h | 10 ++++++++++ src/runtime/runtime_internal.h | 2 -- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 62fc35640eb2..6e66f0ea5473 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -1406,6 +1406,16 @@ extern halide_can_use_target_features_t halide_set_custom_can_use_target_feature */ extern int halide_default_can_use_target_features(int count, const uint64_t *features); +/** + * @brief Enable AMX instructions + * + * This function needs to be called by the user to enable the usage of AMX instructions on Linux. + * Only a single call is required to enable the instructions for the entire process. + * + * @return int 0 on success, error otherwise + */ +extern int halide_enable_amx(); + typedef struct halide_dimension_t { #if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L) int32_t min = 0, extent = 0, stride = 0; diff --git a/src/runtime/runtime_internal.h b/src/runtime/runtime_internal.h index 01a22c507cdf..2801f9bfedc5 100644 --- a/src/runtime/runtime_internal.h +++ b/src/runtime/runtime_internal.h @@ -186,8 +186,6 @@ struct halide_pseudostack_slot_t { WEAK void halide_use_jit_module(); WEAK void halide_release_jit_module(); -WEAK int halide_enable_amx(); - WEAK_INLINE int halide_malloc_alignment(); void halide_thread_yield();