diff --git a/Makefile b/Makefile
index 30b7fb9d7916..ba3094ec459c 100644
--- a/Makefile
+++ b/Makefile
@@ -781,6 +781,7 @@ RUNTIME_CPP_COMPONENTS = \
   hexagon_dma \
   hexagon_host \
   ios_io \
+  linux_amx \
   linux_clock \
   linux_host_cpu_count \
   linux_yield \
@@ -1008,12 +1009,25 @@ RUNTIME_TRIPLE_WIN_ARM_32 = "arm-unknown-windows-unknown"
 RUNTIME_TRIPLE_WIN_ARM_64 = "aarch64-unknown-windows-unknown"
 RUNTIME_TRIPLE_WIN_GENERIC_64 = "le64-unknown-windows-unknown"
 
+RUNTIME_TRIPLE_AMX_X86_32 = "i386-unknown-unknown-unknown"
+RUNTIME_TRIPLE_AMX_X86_64 = "x86_64-unknown-unknown-unknown"
+
 # `-fno-threadsafe-statics` is very important here (note that it allows us to use a 'modern' C++
 # standard but still skip threadsafe guards for static initialization in our runtime code)
 #
 # `-fno-rtti` is necessary to allow us to use classes with virtual functions in the runtime code
 RUNTIME_CXX_FLAGS = -std=c++17 -O3 -fno-vectorize -ffreestanding -fno-blocks -fno-exceptions -fno-unwind-tables -fno-threadsafe-statics -fno-rtti
 
+# linux_amx.cpp calls the xgetbv builtin, so it is built for a concrete x86 triple with -mxsave.
+# NOTE(review): if the runtime also has *_debug.ll initmod variants, they presumably need matching rules -- confirm.
+$(BUILD_DIR)/initmod.linux_amx_32.ll: $(SRC_DIR)/runtime/linux_amx.cpp $(BUILD_DIR)/clang_ok
+	@mkdir -p $(@D)
+	$(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -fpic -m32 -target $(RUNTIME_TRIPLE_AMX_X86_32) -mxsave -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/linux_amx.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.linux_amx_32.d
+
+$(BUILD_DIR)/initmod.linux_amx_64.ll: $(SRC_DIR)/runtime/linux_amx.cpp $(BUILD_DIR)/clang_ok
+	@mkdir -p $(@D)
+	$(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -fpic -m64 -target $(RUNTIME_TRIPLE_AMX_X86_64) -mxsave -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/linux_amx.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.linux_amx_64.d
+
 $(BUILD_DIR)/initmod.windows_%_x86_32.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp $(BUILD_DIR)/clang_ok
 	@mkdir -p $(@D)
 	$(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_X86_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_x86.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_x86_32.d
diff --git a/src/LLVM_Runtime_Linker.cpp b/src/LLVM_Runtime_Linker.cpp
index 4995b197b433..c3912867c8fd 100644
--- a/src/LLVM_Runtime_Linker.cpp
+++ b/src/LLVM_Runtime_Linker.cpp
@@ -229,6 +229,7 @@ DECLARE_NO_INITMOD(windows_d3d12compute_arm)
 #endif  // WITH_D3D12
 
 #ifdef WITH_X86
+DECLARE_CPP_INITMOD(linux_amx)
 DECLARE_LL_INITMOD(x86_amx)
 DECLARE_LL_INITMOD(x86_avx512)
 DECLARE_LL_INITMOD(x86_avx2)
@@ -237,6 +238,7 @@ DECLARE_LL_INITMOD(x86)
 DECLARE_LL_INITMOD(x86_sse41)
 DECLARE_CPP_INITMOD(x86_cpu_features)
 #else
+DECLARE_NO_INITMOD(linux_amx)
 DECLARE_NO_INITMOD(x86_amx)
 DECLARE_NO_INITMOD(x86_avx512)
 DECLARE_NO_INITMOD(x86_avx2)
@@ -1087,6 +1089,10 @@ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVM
                 modules.push_back(get_initmod_x86_avx512_ll(c));
             }
             if (t.has_feature(Target::AVX512_SapphireRapids)) {
+                // halide_enable_amx() is implemented with a Linux-only syscall
+                if (t.os == Target::Linux) {
+                    modules.push_back(get_initmod_linux_amx(c, bits_64, debug));
+                }
                 modules.push_back(get_initmod_x86_amx_ll(c));
             }
             if (t.has_feature(Target::Profile)) {
diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt
index 946784f662d5..d2b89de0fcb2 100644
--- a/src/runtime/CMakeLists.txt
+++ b/src/runtime/CMakeLists.txt
@@ -30,6 +30,7 @@ set(RUNTIME_CPP
     hexagon_dma_pool
     hexagon_host
     ios_io
+    linux_amx
     linux_clock
     linux_host_cpu_count
     linux_yield
@@ -172,6 +173,8 @@ foreach (i IN LISTS RUNTIME_CPP)
         set(fpic -fpic)
         # for the generic windows 64-bit target, we need -fshort-wchar
         set(fshort-wchar "")
+        # -mxsave is only needed by the AMX module (it uses the xgetbv builtin)
+        set(xsave "")
         # Windows
         if (i MATCHES "windows_.*")
             # must omit -fpic, otherwise clang will complain with the following:
@@ -208,6 +211,16 @@ foreach (i IN LISTS RUNTIME_CPP)
                     set(TARGET "le64-unknown-windows-unknown")
                 endif ()
             endif()
+        # AMX only works on x86
+        elseif (i MATCHES "amx")
+            set(xsave "-mxsave")
+            if (j EQUAL 32)
+                set(TARGET "i386-unknown-unknown-unknown")
+            elseif (j EQUAL 64)
+                set(TARGET "x86_64-unknown-unknown-unknown")
+            else ()
+                continue()
+            endif ()
         # Everything else
         else()
             if (j EQUAL 32)
@@ -231,7 +244,7 @@ foreach (i IN LISTS RUNTIME_CPP)
             set(INITMOD "_initmod_${i}_${j}${SUFFIX}.cpp")
             set(SYMBOL "halide_internal_initmod_${i}_${j}${SUFFIX}")
 
-            set(clang_flags ${RUNTIME_CXX_FLAGS} ${fpic} ${fshort-wchar} ${RUNTIME_DEFINES${SUFFIX}} -m${j} -target ${TARGET} -emit-llvm -S -MD -MF "${basename}.d")
+            set(clang_flags ${RUNTIME_CXX_FLAGS} ${fpic} ${fshort-wchar} ${xsave} ${RUNTIME_DEFINES${SUFFIX}} -m${j} -target ${TARGET} -emit-llvm -S -MD -MF "${basename}.d")
 
             if (Halide_CLANG_TIDY_BUILD)
                 # Create a 'fake' entry just so that clang-tidy will see a C++ compilation command
diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h
index 62fc35640eb2..6e66f0ea5473 100644
--- a/src/runtime/HalideRuntime.h
+++ b/src/runtime/HalideRuntime.h
@@ -1406,6 +1406,17 @@ extern halide_can_use_target_features_t halide_set_custom_can_use_target_feature
  */
 extern int halide_default_can_use_target_features(int count, const uint64_t *features);
 
+/**
+ * @brief Enable AMX instructions
+ *
+ * This function needs to be called by the user to enable the usage of AMX instructions on Linux.
+ * Only a single call is required to enable the instructions for the entire process.
+ *
+ * @return int 0 on success, -2 if the XTILECFG or XTILEDATA state component is not
+ * enabled in XCR0, -1 if the permission request to the kernel fails
+ */
+extern int halide_enable_amx();
+
 typedef struct halide_dimension_t {
 #if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
     int32_t min = 0, extent = 0, stride = 0;
diff --git a/src/runtime/linux_amx.cpp b/src/runtime/linux_amx.cpp
new file mode 100644
index 000000000000..2b5723f6bb9e
--- /dev/null
+++ b/src/runtime/linux_amx.cpp
@@ -0,0 +1,50 @@
+#include "HalideRuntime.h"
+#include "printer.h"
+#include "runtime_internal.h"
+
+// arch_prctl has a different syscall number on each x86 ABI (158 on x86-64,
+// 384 on i386). Using the 64-bit number in a 32-bit build would invoke an
+// unrelated syscall.
+#if defined(__x86_64__)
+#define SYS_arch_prctl 158
+#else
+#define SYS_arch_prctl 384
+#endif
+
+extern "C" int syscall(long sysno, ...) throw();
+
+/** Ask the Linux kernel for permission to use AMX tile data in this process.
+ *
+ * Returns 0 on success, -2 if XCR0 does not have the XTILECFG or XTILEDATA
+ * bit set, and -1 if the arch_prctl(ARCH_REQ_XCOMP_PERM) request fails. */
+extern "C" WEAK int halide_enable_amx() {
+    constexpr int XFEATURE_XTILECFG = 17;
+    constexpr int XFEATURE_XTILEDATA = 18;
+    constexpr int ARCH_REQ_XCOMP_PERM = 0x1023;
+
+    // xgetbv instruction should always be present on CPUs with AMX
+    const uint64_t xcr0 = __builtin_ia32_xgetbv(0);
+
+    // if AMX is not supported by the OS these bits are not set
+    if (!(xcr0 & (1ull << XFEATURE_XTILECFG))) {
+        error(nullptr) << "XTILECFG not available for AMX instructions.\n";
+        return -2;
+    }
+
+    if (!(xcr0 & (1ull << XFEATURE_XTILEDATA))) {
+        error(nullptr) << "XTILEDATA not available for AMX instruction.\n";
+        return -2;
+    }
+
+    // we must request permission to use AMX
+    auto ret = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
+
+    if (ret) {
+        error(nullptr) << "Failed to enable AMX instructions: " << ret << '\n';
+        return -1;
+    }
+
+    debug(nullptr) << "AMX permissions acquired\n";
+
+    return 0;
+}