Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ RUNTIME_CPP_COMPONENTS = \
hexagon_dma \
hexagon_host \
ios_io \
linux_amx \
linux_clock \
linux_host_cpu_count \
linux_yield \
Expand Down Expand Up @@ -1008,12 +1009,23 @@ RUNTIME_TRIPLE_WIN_ARM_32 = "arm-unknown-windows-unknown"
RUNTIME_TRIPLE_WIN_ARM_64 = "aarch64-unknown-windows-unknown"
RUNTIME_TRIPLE_WIN_GENERIC_64 = "le64-unknown-windows-unknown"

# Target triples passed to clang (-target) by the initmod.amx_x86_32.ll and
# initmod.amx_x86_64.ll rules below when compiling the AMX runtime support.
RUNTIME_TRIPLE_AMX_X86_32 = "i386-unknown-unknown-unknown"
RUNTIME_TRIPLE_AMX_X86_64 = "x86_64-unknown-unknown-unknown"

# `-fno-threadsafe-statics` is very important here (note that it allows us to use a 'modern' C++
# standard but still skip threadsafe guards for static initialization in our runtime code)
#
# `-fno-rtti` is necessary to allow us to use classes with virtual functions in the runtime code
RUNTIME_CXX_FLAGS = -std=c++17 -O3 -fno-vectorize -ffreestanding -fno-blocks -fno-exceptions -fno-unwind-tables -fno-threadsafe-statics -fno-rtti

# Compile the AMX runtime source to 32-bit LLVM IR (-m32, -DBITS_32) for the
# i386 AMX triple, with -mxsave enabled, and write a dependency file next to it.
# NOTE(review): this rule reads runtime/amx.cpp, but the component added to
# RUNTIME_CPP_COMPONENTS and the source file added by this change are both named
# linux_amx -- confirm the intended source and target names.
$(BUILD_DIR)/initmod.amx_x86_32.ll: $(SRC_DIR)/runtime/amx.cpp $(BUILD_DIR)/clang_ok
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, these changes/additions will also need to be made in src/runtime/CMakeLists.txt -- we support CMake on all our platforms, whereas Make only on a small percentage.

@mkdir -p $(@D)
$(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_AMX_X86_32) -mxsave -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/amx.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.amx_x86_32.d

# Compile the AMX runtime source to 64-bit LLVM IR (-m64, -DBITS_64) for the
# x86_64 AMX triple, with -mxsave enabled, and write a dependency file next to it.
# NOTE(review): this rule reads runtime/amx.cpp, but the component added to
# RUNTIME_CPP_COMPONENTS and the source file added by this change are both named
# linux_amx -- confirm the intended source and target names.
$(BUILD_DIR)/initmod.amx_x86_64.ll: $(SRC_DIR)/runtime/amx.cpp $(BUILD_DIR)/clang_ok
@mkdir -p $(@D)
$(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m64 -target $(RUNTIME_TRIPLE_AMX_X86_64) -mxsave -DCOMPILING_HALIDE_RUNTIME -DBITS_64 -emit-llvm -S $(SRC_DIR)/runtime/amx.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.amx_x86_64.d

# Pattern rule: compile runtime/windows_<stem>_x86.cpp to 32-bit LLVM IR
# (-m32, -DBITS_32) using the Windows x86 32-bit triple. $* is the matched
# stem; a dependency file is emitted alongside the .ll output.
$(BUILD_DIR)/initmod.windows_%_x86_32.ll: $(SRC_DIR)/runtime/windows_%_x86.cpp $(BUILD_DIR)/clang_ok
@mkdir -p $(@D)
$(CLANG) $(CXX_WARNING_FLAGS) $(RUNTIME_CXX_FLAGS) -m32 -target $(RUNTIME_TRIPLE_WIN_X86_32) -DCOMPILING_HALIDE_RUNTIME -DBITS_32 -emit-llvm -S $(SRC_DIR)/runtime/windows_$*_x86.cpp -o $@ -MMD -MP -MF $(BUILD_DIR)/initmod.windows_$*_x86_32.d
Expand Down
5 changes: 5 additions & 0 deletions src/LLVM_Runtime_Linker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ DECLARE_NO_INITMOD(windows_d3d12compute_arm)
#endif // WITH_D3D12

#ifdef WITH_X86
DECLARE_CPP_INITMOD(linux_amx)
DECLARE_LL_INITMOD(x86_amx)
DECLARE_LL_INITMOD(x86_avx512)
DECLARE_LL_INITMOD(x86_avx2)
Expand All @@ -237,6 +238,7 @@ DECLARE_LL_INITMOD(x86)
DECLARE_LL_INITMOD(x86_sse41)
DECLARE_CPP_INITMOD(x86_cpu_features)
#else
DECLARE_NO_INITMOD(linux_amx)
DECLARE_NO_INITMOD(x86_amx)
DECLARE_NO_INITMOD(x86_avx512)
DECLARE_NO_INITMOD(x86_avx2)
Expand Down Expand Up @@ -1087,6 +1089,9 @@ std::unique_ptr<llvm::Module> get_initial_module_for_target(Target t, llvm::LLVM
modules.push_back(get_initmod_x86_avx512_ll(c));
}
if (t.has_feature(Target::AVX512_SapphireRapids)) {
if (t.os == Target::Linux) {
modules.push_back(get_initmod_linux_amx(c, bits_64, debug));
}
modules.push_back(get_initmod_x86_amx_ll(c));
}
if (t.has_feature(Target::Profile)) {
Expand Down
16 changes: 15 additions & 1 deletion src/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(RUNTIME_CPP
hexagon_dma_pool
hexagon_host
ios_io
linux_amx
linux_clock
linux_host_cpu_count
linux_yield
Expand Down Expand Up @@ -172,6 +173,8 @@ foreach (i IN LISTS RUNTIME_CPP)
set(fpic -fpic)
# for the generic windows 64-bit target, we need -fshort-wchar
set(fshort-wchar "")
# xsave is required for checking AMX
set(xsave "")
# Windows
if (i MATCHES "windows_.*")
# must omit -fpic, otherwise clang will complain with the following:
Expand Down Expand Up @@ -208,6 +211,17 @@ foreach (i IN LISTS RUNTIME_CPP)
set(TARGET "le64-unknown-windows-unknown")
endif ()
endif()
# AMX only works on x86
elseif (i MATCHES "amx")
    # The AMX runtime module needs the xsave target feature enabled in clang.
    set(xsave "-mxsave")
    if (j EQUAL 32)
        set(TARGET "i386-unknown-unknown-unknown")
    elseif (j EQUAL 64)
        set(TARGET "x86_64-unknown-unknown-unknown")
    else ()
        # No x86 triple is defined here for any other bit width; skip this
        # iteration instead of compiling with an unset TARGET.
        continue()
    endif ()
# Everything else
else()
if (j EQUAL 32)
Expand All @@ -231,7 +245,7 @@ foreach (i IN LISTS RUNTIME_CPP)
set(INITMOD "_initmod_${i}_${j}${SUFFIX}.cpp")
set(SYMBOL "halide_internal_initmod_${i}_${j}${SUFFIX}")

set(clang_flags ${RUNTIME_CXX_FLAGS} ${fpic} ${fshort-wchar} ${RUNTIME_DEFINES${SUFFIX}} -m${j} -target ${TARGET} -emit-llvm -S -MD -MF "${basename}.d")
set(clang_flags ${RUNTIME_CXX_FLAGS} ${fpic} ${fshort-wchar} ${xsave} ${RUNTIME_DEFINES${SUFFIX}} -m${j} -target ${TARGET} -emit-llvm -S -MD -MF "${basename}.d")

if (Halide_CLANG_TIDY_BUILD)
# Create a 'fake' entry just so that clang-tidy will see a C++ compilation command
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,16 @@ extern halide_can_use_target_features_t halide_set_custom_can_use_target_feature
*/
extern int halide_default_can_use_target_features(int count, const uint64_t *features);

/**
 * Enable AMX instructions for the calling process.
 *
 * This function needs to be called by the user to enable the usage of AMX
 * instructions on Linux. Only a single call is required to enable the
 * instructions for the entire process.
 *
 * @return 0 on success, non-zero on failure. The Linux implementation in
 * linux_amx.cpp returns -2 when the XTILECFG or XTILEDATA state component is
 * not reported as enabled, and -1 when the permission request to the kernel
 * fails.
 */
extern int halide_enable_amx();

typedef struct halide_dimension_t {
#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
int32_t min = 0, extent = 0, stride = 0;
Expand Down
39 changes: 39 additions & 0 deletions src/runtime/linux_amx.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "HalideRuntime.h"
#include "printer.h"
#include "runtime_internal.h"

// Linux syscall number for arch_prctl. Syscall numbers are ABI-specific:
// arch_prctl is 158 on x86-64 but 384 on i386 (158 is sched_yield there), so
// pick the number from the BITS_* define supplied by the runtime build.
#ifdef BITS_64
#define SYS_arch_prctl 158
#else
#define SYS_arch_prctl 384
#endif

// Declared by hand because the runtime is built freestanding and does not
// include the libc headers.
extern "C" int syscall(long sysno, ...) throw();

/**
 * Request permission from the Linux kernel to use AMX tile state in this
 * process.
 *
 * Reads XCR0 to check that the OS has enabled the XTILECFG and XTILEDATA
 * state components, then issues arch_prctl(ARCH_REQ_XCOMP_PERM, XTILEDATA).
 *
 * @return 0 on success, -2 if XTILECFG or XTILEDATA is not enabled in XCR0,
 *         -1 if the arch_prctl request fails.
 */
extern "C" WEAK int halide_enable_amx() {
    constexpr int XFEATURE_XTILECFG = 17;
    constexpr int XFEATURE_XTILEDATA = 18;
    constexpr long ARCH_REQ_XCOMP_PERM = 0x1023;

    // xgetbv instruction should always be present on CPUs with AMX
    const long long xcr0 = __builtin_ia32_xgetbv(0);

    // if AMX is not supported by the OS these bits are not set
    if (!(xcr0 & (1 << XFEATURE_XTILECFG))) {
        error(nullptr) << "XTILECFG not available for AMX instructions.\n";
        return -2;
    }

    if (!(xcr0 & (1 << XFEATURE_XTILEDATA))) {
        error(nullptr) << "XTILEDATA not available for AMX instruction.\n";
        return -2;
    }

    // we must request permission to use AMX. The arguments travel through a
    // variadic parameter list that is read back as longs, so pass them as long
    // explicitly rather than relying on int promotion.
    const int ret = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, (long)XFEATURE_XTILEDATA);

    if (ret) {
        // Use a string literal for the newline: a char argument would be
        // promoted to an integer and printed as the number 10.
        error(nullptr) << "Failed to enable AMX instructions: " << ret << "\n";
        return -1;
    }

    debug(nullptr) << "AMX permissions acquired\n";

    return 0;
}