diff --git a/Makefile b/Makefile index a87c9def123..13664f950ac 100644 --- a/Makefile +++ b/Makefile @@ -158,6 +158,11 @@ ifdef LLAMA_PERF CFLAGS += -DGGML_PERF CXXFLAGS += -DGGML_PERF endif +ifdef LLAMA_RPC +RPC_FLAGS = -DGGML_USE_RPC +else +RPC_FLAGS = +endif CCV := $(shell $(CC) --version | head -n 1) CXXV := $(shell $(CXX) --version | head -n 1) @@ -457,7 +462,14 @@ HIPBLAS_BUILD = $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $^ -shared -o $@.so $(HIPLDFLAGS endif ifdef LLAMA_VULKAN VULKAN_BUILD = $(CXX) $(CXXFLAGS) $^ -lvulkan -shared -o $@.so $(LDFLAGS) +ifdef LLAMA_RPC +RPC_BUILD = $(CXX) $(CXXFLAGS) $(RPC_FLAGS) $^ -lvulkan -shared -o $@.so $(LDFLAGS) +endif endif +ifdef LLAMA_RPC +RPC_BUILD_WIN = $(CXX) $(CXXFLAGS) $(RPC_FLAGS) $^ -shared -o $@.dll $(LDFLAGS) +endif + endif ifndef LLAMA_CUBLAS @@ -946,6 +958,8 @@ quantize_clip: tools/mtmd/clip.cpp tools/quantclip.cpp ggml_v3.o ggml.o ggml-cpu quantize_ace: otherarch/acestep/quantize-acestep.cpp tools/mtmd/clip.cpp ggml_v3.o ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) +# Include RPC build targets (rpc-full-all works without manual flags) +include Makefile.rpc #window simple clinfo simplecpuinfo: simplecpuinfo.cpp diff --git a/Makefile.rpc b/Makefile.rpc new file mode 100644 index 00000000000..c2326553866 --- /dev/null +++ b/Makefile.rpc @@ -0,0 +1,120 @@ +# RPC build targets +# Usage: make rpc-full-all (automatically builds for all available backends) + +.PHONY: rpc-full-all rpc-servers-all rpc-backends-all + +# Main target - builds RPC for all backends with a single command +# This automatically detects hardware and builds appropriate backends +# Also builds standard backends for koboldcpp.py GUI +rpc-full-all: + @echo "=== KoboldCpp RPC Auto-Detection Build ===" + @echo "" + @echo "Detecting hardware..." 
+ @HAS_AMD=0; HAS_NVIDIA=0; HAS_VULKAN=0; \ + if lspci -nn 2>/dev/null | grep -qi "1002:"; then HAS_AMD=1; echo " ✓ AMD GPU detected"; fi; \ + if lspci -nn 2>/dev/null | grep -qi "10de:"; then HAS_NVIDIA=1; echo " ✓ NVIDIA GPU detected"; fi; \ + if lspci -nn 2>/dev/null | grep -qi "8086:"; then echo " ✓ Intel GPU detected"; HAS_VULKAN=1; fi; \ + if [ $$HAS_AMD -eq 1 ] || [ $$HAS_NVIDIA -eq 1 ]; then HAS_VULKAN=1; fi; \ + if [ $$HAS_VULKAN -eq 1 ]; then echo " ✓ Vulkan support available"; fi; \ + echo "" + @echo "Building standard backends (required for koboldcpp.py GUI)..." + @echo "" + @echo "Building CPU backend (koboldcpp_default.so)..." + -$(MAKE) koboldcpp_default -j$(nproc) 2>&1 | tail -3 + @echo "" + @if pkg-config --exists vulkan 2>/dev/null || [ -f /usr/include/vulkan/vulkan.h ] || [ -f /usr/local/include/vulkan/vulkan.h ]; then \ + echo "Building Vulkan backend (koboldcpp_vulkan.so)..."; \ + $(MAKE) koboldcpp_vulkan -j$(nproc) LLAMA_VULKAN=1 2>&1 | tail -3; \ + else \ + echo "Vulkan headers not found, skipping Vulkan backend..."; \ + fi + @echo "" + @if lspci -nn 2>/dev/null | grep -qi "1002:" && (command -v hipcc &> /dev/null || pkg-config --exists rocm 2>/dev/null); then \ + echo "Building HIPBLAS backend (koboldcpp_hipblas.so)..."; \ + $(MAKE) koboldcpp_hipblas -j$(nproc) LLAMA_HIPBLAS=1 2>&1 | tail -3 || echo "HIPBLAS backend skipped"; \ + else \ + echo "AMD GPU or ROCm not available, skipping HIPBLAS backend..."; \ + fi + @echo "" + @echo "Building RPC backends..." 
+ @echo "" + @if pkg-config --exists vulkan 2>/dev/null || [ -f /usr/include/vulkan/vulkan.h ] || [ -f /usr/local/include/vulkan/vulkan.h ]; then \ + echo "Building Vulkan RPC..."; \ + $(MAKE) LLAMA_RPC=1 LLAMA_VULKAN=1 rpc-server-vulkan koboldcpp_rpc || echo "Vulkan RPC build failed, continuing..."; \ + else \ + echo "Vulkan headers not found, skipping Vulkan RPC..."; \ + fi + @echo "" + @if lspci -nn 2>/dev/null | grep -qi "1002:" && (command -v hipcc &> /dev/null || pkg-config --exists rocm 2>/dev/null); then \ + echo "Building HIPBLAS RPC (AMD detected)..."; \ + $(MAKE) LLAMA_RPC=1 LLAMA_HIPBLAS=1 rpc-server-hip koboldcpp_hipblas_rpc || echo "HIPBLAS RPC build failed, continuing..."; \ + else \ + echo "AMD GPU or ROCm not available, skipping HIPBLAS RPC..."; \ + fi + @echo "" + @if lspci -nn 2>/dev/null | grep -qi "10de:" && command -v nvcc &> /dev/null; then \ + echo "Building CUDA RPC (NVIDIA detected)..."; \ + $(MAKE) LLAMA_RPC=1 LLAMA_CUBLAS=1 rpc-server-cuda koboldcpp_cublas_rpc || echo "CUDA RPC build failed, continuing..."; \ + else \ + echo "NVIDIA GPU or CUDA not available, skipping CUDA RPC..."; \ + fi + @echo "" + @echo "=== Build Summary ===" + @echo "Standard backends (for koboldcpp.py GUI):" + @ls -lh koboldcpp_default.so 2>/dev/null && echo " ✓ CPU backend" || echo " ✗ CPU backend" + @ls -lh koboldcpp_vulkan.so 2>/dev/null && echo " ✓ Vulkan backend" || echo " ✗ Vulkan backend" + @ls -lh koboldcpp_hipblas.so 2>/dev/null && echo " ✓ HIPBLAS backend" || echo " ✗ HIPBLAS backend" + @ls -lh koboldcpp_cublas.so 2>/dev/null && echo " ✓ CUDA backend" || echo " ✗ CUDA backend" + @echo "" + @echo "RPC backends (for distributed inference):" + @ls -lh rpc-server-vulkan koboldcpp_rpc.so 2>/dev/null && echo " ✓ Vulkan RPC" || echo " ✗ Vulkan RPC" + @ls -lh rpc-server-hip koboldcpp_hipblas_rpc.so 2>/dev/null && echo " ✓ HIPBLAS RPC" || echo " ✗ HIPBLAS RPC" + @ls -lh rpc-server-cuda koboldcpp_cublas_rpc.so 2>/dev/null && echo " ✓ CUDA RPC" || echo " ✗ CUDA 
RPC" + @echo "" + @echo "Usage:" + @echo " GUI: python ./koboldcpp.py" + @echo " Vulkan RPC: ./rpc-server-vulkan -H 127.0.0.1 --port 50053" + @echo " HIPBLAS RPC: ./rpc-server-hip -H 127.0.0.1 --port 50053" + @echo " CUDA RPC: ./rpc-server-cuda -H 127.0.0.1 --port 50053" + +# Vulkan RPC server and client +ifdef LLAMA_VULKAN +rpc-server-vulkan: tools/rpc-server.cpp ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o ggml-rpc.o transport.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_vulkan.o ggml-repack.o ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o kcpp-quantmapper.o kcpp-repackmapper.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o llama-impl.o sampling.o budget.o kcpputils.o ggml-vulkan.o ggml-vulkan-shaders.o console.o + $(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -lvulkan + @echo "Built rpc-server-vulkan" + +koboldcpp_rpc: ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o sdcpp_vulkan.o whispercpp_vulkan.o tts_default.o music_default.o embeddings_default.o llavaclip_vulkan.o llava.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_vulkan.o ggml-repack.o $(OBJS_FULL) $(OBJS) ggml-rpc.o transport.o + $(RPC_BUILD) + @echo "Built koboldcpp_rpc (Vulkan RPC client)" +endif + +# HIPBLAS RPC server and client +ifdef LLAMA_HIPBLAS +rpc-server-hip: tools/rpc-server.cpp ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o ggml-rpc.o transport.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_cublas.o ggml-repack.o ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o kcpp-quantmapper.o kcpp-repackmapper.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o llama-impl.o sampling.o budget.o kcpputils.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o 
$(HIP_OBJS) + $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(HIPLDFLAGS) + @echo "Built rpc-server-hip" + +koboldcpp_hipblas_rpc: ggml_v4_cublas.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o tts_default.o music_default.o embeddings_default.o llavaclip_cublas.o llava.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_cublas.o ggml-repack.o $(HIP_OBJS) $(OBJS_FULL) $(OBJS) ggml-rpc.o transport.o + $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(filter-out %.h,$^) -shared -o $@.so $(LDFLAGS) $(HIPLDFLAGS) + @echo "Built koboldcpp_hipblas_rpc (HIPBLAS RPC client)" +endif + +# CUDA RPC server and client +ifdef LLAMA_CUBLAS +rpc-server-cuda: tools/rpc-server.cpp ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o ggml-rpc.o transport.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_cublas.o ggml-repack.o ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o kcpp-quantmapper.o kcpp-repackmapper.o unicode.o unicode-common.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o llama-impl.o sampling.o budget.o kcpputils.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o ggml-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o $(filter-out ggml_v3-cuda.o,$(CUBLAS_OBJS)) + $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CUBLASLD_FLAGS) -lcudart -lcublas -lcublasLt + @echo "Built rpc-server-cuda" + +koboldcpp_cublas_rpc: ggml_v4_cublas.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o tts_default.o music_default.o embeddings_default.o llavaclip_cublas.o llava.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_cublas.o ggml-repack.o $(CUBLAS_OBJS) $(OBJS_FULL) $(OBJS) ggml-rpc.o transport.o + $(RPC_BUILD) + @echo "Built koboldcpp_cublas_rpc 
(CUDA RPC client)" +endif + +# RPC object files +ifdef LLAMA_RPC +ggml-rpc.o: ggml/src/ggml-rpc/ggml-rpc.cpp ggml/include/ggml-rpc.h ggml/src/ggml-rpc/transport.h + $(CXX) $(CXXFLAGS) $(RPC_FLAGS) -c $< -o $@ + +transport.o: ggml/src/ggml-rpc/transport.cpp ggml/src/ggml-rpc/transport.h + $(CXX) $(CXXFLAGS) $(RPC_FLAGS) -c $< -o $@ +endif diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt new file mode 100644 index 00000000000..40e11fead63 --- /dev/null +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -0,0 +1,33 @@ +message(STATUS "Using RPC backend") + +ggml_add_backend_library(ggml-rpc + ggml-rpc.cpp + transport.cpp + ) + +if (WIN32) + target_link_libraries(ggml-rpc PRIVATE ws2_32) +endif() + +# RDMA auto-detection (Linux only, requires libibverbs) +if (NOT WIN32 AND NOT APPLE) + find_library(IBVERBS_LIB ibverbs) + if (IBVERBS_LIB) + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" ON) + else() + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" OFF) + endif() +else() + set(GGML_RPC_RDMA OFF CACHE BOOL "RDMA not available on this platform" FORCE) +endif() + +if (GGML_RPC_RDMA) + if (NOT IBVERBS_LIB) + find_library(IBVERBS_LIB ibverbs REQUIRED) + endif() + target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA) + target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB}) + message(STATUS " RDMA transport enabled (auto-detected)") +else() + message(STATUS " RDMA transport disabled") +endif() diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp new file mode 100644 index 00000000000..a3224c0e379 --- /dev/null +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -0,0 +1,2016 @@ +#include "ggml-rpc.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml-cpp.h" +#include "transport.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define 
LOG_DBG(...) \ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + + +namespace fs = std::filesystem; + +// macro for nicer error messages on server crash +#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") + +// all RPC structures must be packed +#pragma pack(push, 1) +// ggml_tensor is serialized into rpc_tensor +struct rpc_tensor { + uint64_t id; + uint32_t type; + uint64_t buffer; + uint32_t ne[GGML_MAX_DIMS]; + uint32_t nb[GGML_MAX_DIMS]; + uint32_t op; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + int32_t flags; + uint64_t src[GGML_MAX_SRC]; + uint64_t view_src; + uint64_t view_offs; + uint64_t data; + char name[GGML_MAX_NAME]; + + char padding[4]; +}; + +static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8"); + +// RPC commands +enum rpc_cmd { + RPC_CMD_ALLOC_BUFFER = 0, + RPC_CMD_GET_ALIGNMENT, + RPC_CMD_GET_MAX_SIZE, + RPC_CMD_BUFFER_GET_BASE, + RPC_CMD_FREE_BUFFER, + RPC_CMD_BUFFER_CLEAR, + RPC_CMD_SET_TENSOR, + RPC_CMD_SET_TENSOR_HASH, + RPC_CMD_GET_TENSOR, + RPC_CMD_COPY_TENSOR, + RPC_CMD_GRAPH_COMPUTE, + RPC_CMD_GET_DEVICE_MEMORY, + RPC_CMD_INIT_TENSOR, + RPC_CMD_GET_ALLOC_SIZE, + RPC_CMD_HELLO, + RPC_CMD_DEVICE_COUNT, + RPC_CMD_GRAPH_RECOMPUTE, + RPC_CMD_COUNT, +}; + +static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); + +// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold +const size_t HASH_THRESHOLD = 10 * 1024 * 1024; + +struct rpc_msg_hello_req { + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; +}; + +struct rpc_msg_hello_rsp { + uint8_t major; + uint8_t minor; + uint8_t patch; + uint8_t padding; + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; +}; + +struct rpc_msg_device_count_rsp { + uint32_t device_count; +}; + +struct rpc_msg_get_alloc_size_req { + uint32_t device; + rpc_tensor tensor; + rpc_tensor srcs[GGML_MAX_SRC]; +}; + +struct rpc_msg_get_alloc_size_rsp { + uint64_t alloc_size; +}; + +struct 
rpc_msg_init_tensor_req { + rpc_tensor tensor; +}; + +struct rpc_msg_alloc_buffer_req { + uint32_t device; + uint64_t size; +}; + +struct rpc_msg_alloc_buffer_rsp { + uint64_t remote_ptr; + uint64_t remote_size; +}; + +struct rpc_msg_get_alignment_req { + uint32_t device; +}; + +struct rpc_msg_get_alignment_rsp { + uint64_t alignment; +}; + +struct rpc_msg_get_max_size_req { + uint32_t device; +}; + +struct rpc_msg_get_max_size_rsp { + uint64_t max_size; +}; + +struct rpc_msg_buffer_get_base_req { + uint64_t remote_ptr; +}; + +struct rpc_msg_buffer_get_base_rsp { + uint64_t base_ptr; +}; + +struct rpc_msg_free_buffer_req { + uint64_t remote_ptr; +}; + +struct rpc_msg_buffer_clear_req { + uint64_t remote_ptr; + uint8_t value; +}; + +struct rpc_msg_set_tensor_hash_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t hash; +}; + +struct rpc_msg_set_tensor_hash_rsp { + uint8_t result; +}; + +struct rpc_msg_get_tensor_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t size; +}; + +struct rpc_msg_copy_tensor_req { + rpc_tensor src; + rpc_tensor dst; +}; + +struct rpc_msg_copy_tensor_rsp { + uint8_t result; +}; + +struct rpc_msg_get_device_memory_req { + uint32_t device; +}; + +struct rpc_msg_get_device_memory_rsp { + uint64_t free_mem; + uint64_t total_mem; +}; + +struct rpc_msg_graph_recompute_req { + uint32_t device; +}; + +#pragma pack(pop) + +// RPC data structures + +static ggml_guid_t ggml_backend_rpc_guid() { + static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03}; + return &guid; +} + +struct ggml_backend_rpc_buffer_type_context { + std::string endpoint; + uint32_t device; + std::string name; + size_t alignment; + size_t max_size; +}; + +struct graph_cache { + + bool is_cached(const ggml_cgraph * cgraph) { + if ((int)last_graph.size() != cgraph->n_nodes) { + return false; + } + for (int i = 0; i < cgraph->n_nodes; i++) { + if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) 
!= 0) { + return false; + } + } + return true; + } + + void add(const ggml_cgraph * cgraph) { + last_graph.resize(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)); + } + } + + std::vector last_graph; +}; + +struct ggml_backend_rpc_context { + std::string endpoint; + uint32_t device; + std::string name; + graph_cache gc; +}; + +struct ggml_backend_rpc_buffer_context { + std::shared_ptr sock; + void * base_ptr; + uint64_t remote_ptr; +}; + +// RPC helper functions + +// Computes FNV-1a hash of the data +static uint64_t fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return hash; +} + +static bool send_msg(socket_ptr sock, const void * msg, size_t msg_size) { + if (!sock->send_data(&msg_size, sizeof(msg_size))) { + return false; + } + return sock->send_data(msg, msg_size); +} + +static bool recv_msg(socket_ptr sock, void * msg, size_t msg_size) { + uint64_t size; + if (!sock->recv_data(&size, sizeof(size))) { + return false; + } + if (size != msg_size) { + return false; + } + return sock->recv_data(msg, msg_size); +} + +static bool recv_msg(socket_ptr sock, std::vector & input) { + uint64_t size; + if (!sock->recv_data(&size, sizeof(size))) { + return false; + } + try { + input.resize(size); + } catch (const std::bad_alloc & e) { + GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); + return false; + } + return sock->recv_data(input.data(), size); +} + +static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { + size_t pos = endpoint.find(':'); + if (pos == std::string::npos) { + return false; + } + host = endpoint.substr(0, pos); + try { + port = std::stoi(endpoint.substr(pos + 1)); + } catch (...) 
{ + return false; + } + return true; +} + +// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | +// No response +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size) { + uint8_t cmd_byte = cmd; + if (!sock->send_data(&cmd_byte, sizeof(cmd_byte))) { + return false; + } + if (!sock->send_data(&input_size, sizeof(input_size))) { + return false; + } + if (!sock->send_data(input, input_size)) { + return false; + } + return true; +} + +// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | +// RPC response: | response_size (8 bytes) | response_data (response_size bytes) | +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { + if (!send_rpc_cmd(sock, cmd, input, input_size)) { + return false; + } + uint64_t out_size; + if (!sock->recv_data(&out_size, sizeof(out_size))) { + return false; + } + if (out_size != output_size) { + return false; + } + if (!sock->recv_data(output, output_size)) { + return false; + } + return true; +} + +// RPC client-side implementation + +// Performs HELLO handshake with transport auto-negotiation. +// Advertises local capabilities via conn_caps; if the server responds with +// matching capabilities, the socket is upgraded transparently. 
+static bool negotiate_hello(const std::shared_ptr & sock) { + rpc_msg_hello_req request = {}; + rpc_msg_hello_rsp response = {}; + + sock->get_caps(request.conn_caps); + + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { + GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", + response.major, response.minor, response.patch); + return false; + } + + sock->update_caps(response.conn_caps); + return true; +} + +static std::shared_ptr get_socket(const std::string & endpoint) { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unordered_map> sockets; + + auto it = sockets.find(endpoint); + if (it != sockets.end()) { + if (auto sock = it->second.lock()) { + return sock; + } + } + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str()); + return nullptr; + } + + if (!rpc_transport_init()) { + return nullptr; + } + auto sock = socket_t::connect(host.c_str(), port); + if (sock == nullptr) { + return nullptr; + } + if (!negotiate_hello(sock)) { + return nullptr; + } + LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str()); + sockets[endpoint] = sock; + return sock; +} + +static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_free_buffer_req request = {ctx->remote_ptr}; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); + delete ctx; +} + +static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + if (ctx->base_ptr != nullptr) { + return ctx->base_ptr; + } + 
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; + rpc_msg_buffer_get_base_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + ctx->base_ptr = reinterpret_cast(response.base_ptr); + return ctx->base_ptr; +} + +static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer; +} + +static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { + rpc_tensor result; + if (!tensor) { + memset(&result, 0, sizeof(result)); + return result; + } + + result.id = reinterpret_cast(tensor); + result.type = tensor->type; + if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) { + ggml_backend_buffer_t buffer = tensor->buffer; + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + result.buffer = ctx != nullptr ? ctx->remote_ptr : 0; + result.data = reinterpret_cast(tensor->data); + } else { + result.buffer = 0; + result.data = 0; + } + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result.ne[i] = tensor->ne[i]; + result.nb[i] = tensor->nb[i]; + } + result.op = tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result.op_params[i] = tensor->op_params[i]; + } + result.flags = tensor->flags; + for (uint32_t i = 0; i < GGML_MAX_SRC; i++) { + result.src[i] = reinterpret_cast(tensor->src[i]); + } + result.view_src = reinterpret_cast(tensor->view_src); + result.view_offs = tensor->view_offs; + + // Avoid sending uninitialized data over the wire + memset(result.name, 0, sizeof(result.name)); + memset(result.padding, 0, sizeof(result.padding)); + + snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); + return result; +} + +static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_rpc_buffer_context * ctx = 
(ggml_backend_rpc_buffer_context *)buffer->context; + + // CUDA backend on the server pads everything to 512 due to CUDA limitations. + // Due to bandwidth constraints, we only call the server init tensor functions if necessary. + // In particular, only quantized tensors need padding + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + rpc_msg_init_tensor_req request; + + request.tensor = serialize_tensor(tensor); + + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_tensor rpc_tensor = serialize_tensor(tensor); + if (size > HASH_THRESHOLD) { + rpc_msg_set_tensor_hash_req request; + request.tensor = rpc_tensor; + request.offset = offset; + request.hash = fnv_hash((const uint8_t*)data, size); + rpc_msg_set_tensor_hash_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + if (response.result) { + // the server has the same data, no need to send it + return; + } + } + // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) + size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size; + std::vector input(input_size, 0); + memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); + memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); + memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size()); + RPC_STATUS_ASSERT(status); +} + +static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, 
const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_get_tensor_req request; + request.tensor = serialize_tensor(tensor); + request.offset = offset; + request.size = size; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size); + RPC_STATUS_ASSERT(status); +} + +static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_backend_buffer_is_rpc(src->buffer)) { + // check if src and dst are on the same server + ggml_backend_buffer_t src_buffer = src->buffer; + ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; + ggml_backend_buffer_t dst_buffer = dst->buffer; + ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; + if (src_ctx->sock != dst_ctx->sock) { + return false; + } + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_copy_tensor_req request; + request.src = serialize_tensor(src); + request.dst = serialize_tensor(dst); + rpc_msg_copy_tensor_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.result; + } + return false; +} + +static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value}; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); +} + +static void ggml_backend_rpc_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + 
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_tensor rpc_tensor = serialize_tensor(tensor); + size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t) + sizeof(uint8_t) + sizeof(size_t); + std::vector input(input_size, 0); + memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); + memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); + memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size)); + memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset) + sizeof(size), &value, sizeof(value)); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input.data(), input.size()); + RPC_STATUS_ASSERT(status); +} + +static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { + /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer, + /* .get_base = */ ggml_backend_rpc_buffer_get_base, + /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_rpc_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, + /* .clear = */ ggml_backend_rpc_buffer_clear, + /* .reset = */ NULL, +}; + +static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + rpc_msg_alloc_buffer_req request = {buft_ctx->device, size}; + rpc_msg_alloc_buffer_rsp response; + auto sock = get_socket(buft_ctx->endpoint); + bool status = send_rpc_cmd(sock, 
RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + if (response.remote_ptr != 0) { + ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, + ggml_backend_rpc_buffer_interface, + new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr}, + response.remote_size); + return buffer; + } else { + return nullptr; + } +} + +static size_t get_alignment(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_alignment_req request = {device}; + rpc_msg_get_alignment_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.alignment; +} + +static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->alignment; +} + +static size_t get_max_size(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_max_size_req request = {device}; + rpc_msg_get_max_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.max_size; +} + +static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->max_size; +} + +static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + // should we query the remote server for the actual size + bool rpc_get = false; + + // See comments in init_tensor. 
+ rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr); + + // ops that require additional memory for fleeting data on certain backends + // ref: https://github.com/ggml-org/llama.cpp/pull/15966 + rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT; + rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID; + + if (rpc_get) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + auto sock = get_socket(buft_ctx->endpoint); + + rpc_msg_get_alloc_size_req request = { + /*.device =*/ buft_ctx->device, + /*.tensor =*/ serialize_tensor(tensor), + /*.srcs =*/ {}, + }; + + // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well + for (int i = 0; i < GGML_MAX_SRC; i++) { + request.srcs[i] = serialize_tensor(tensor->src[i]); + } + + // TODO: cache the alloc responses to avoid extra RPC calls? + rpc_msg_get_alloc_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + + return response.alloc_size; + } + + return ggml_nbytes(tensor); +} + +static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { + /* .get_name = */ ggml_backend_rpc_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_rpc_get_max_size, + /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +static const char * ggml_backend_rpc_name(ggml_backend_t backend) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + + return rpc_ctx->name.c_str(); +} + +static void ggml_backend_rpc_free(ggml_backend_t backend) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + delete rpc_ctx; + delete backend; +} + +static 
void ggml_backend_rpc_synchronize(ggml_backend_t backend) { + GGML_UNUSED(backend); + // this is no-op because we don't have any async operations +} + +static void add_tensor(ggml_tensor * tensor, std::vector & tensors, std::unordered_set & visited) { + if (tensor == nullptr) { + return; + } + if (visited.find(tensor) != visited.end()) { + return; + } + visited.insert(tensor); + for (int i = 0; i < GGML_MAX_SRC; i++) { + add_tensor(tensor->src[i], tensors, visited); + } + add_tensor(tensor->view_src, tensors, visited); + tensors.push_back(serialize_tensor(tensor)); +} + +static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector & output) { + uint32_t n_nodes = cgraph->n_nodes; + std::vector tensors; + std::unordered_set visited; + for (uint32_t i = 0; i < n_nodes; i++) { + add_tensor(cgraph->nodes[i], tensors, visited); + } + // serialization format: + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + uint32_t n_tensors = tensors.size(); + int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); + output.resize(output_size, 0); + uint8_t * dest = output.data(); + memcpy(dest, &device, sizeof(device)); + dest += sizeof(device); + memcpy(dest, &n_nodes, sizeof(n_nodes)); + dest += sizeof(n_nodes); + for (uint32_t i = 0; i < n_nodes; i++) { + memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + } + dest += n_nodes * sizeof(uint64_t); + memcpy(dest, &n_tensors, sizeof(n_tensors)); + dest += sizeof(n_tensors); + rpc_tensor * out_tensors = (rpc_tensor *)dest; + memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); +} + +static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + + GGML_ASSERT(cgraph->n_nodes > 0); + bool 
reuse = rpc_ctx->gc.is_cached(cgraph); + if (reuse) { + rpc_msg_graph_recompute_req request; + request.device = rpc_ctx->device; + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); + RPC_STATUS_ASSERT(status); + } else { + rpc_ctx->gc.add(cgraph); + std::vector input; + serialize_graph(rpc_ctx->device, cgraph, input); + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size()); + RPC_STATUS_ASSERT(status); + } + return GGML_STATUS_SUCCESS; +} + +static ggml_backend_i ggml_backend_rpc_interface = { + /* .get_name = */ ggml_backend_rpc_name, + /* .free = */ ggml_backend_rpc_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ ggml_backend_rpc_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_rpc_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, +}; + +ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + auto reg = ggml_backend_rpc_add_server(endpoint); + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, device); + const char * dev_name = ggml_backend_dev_name(dev); + // NOTE: buffer types are allocated and never freed; this is by design + static std::unordered_map buft_map; + auto it = buft_map.find(dev_name); + if (it != buft_map.end()) { + return it->second; + } + auto sock = get_socket(endpoint); + if (sock == nullptr) { + GGML_LOG_ERROR("Failed to connect to %s\n", endpoint); + return nullptr; + } + size_t alignment = get_alignment(sock, device); + size_t max_size = 
get_max_size(sock, device); + ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ dev_name, + /* .alignment = */ alignment, + /* .max_size = */ max_size + }; + ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { + /* .iface = */ ggml_backend_rpc_buffer_type_interface, + /* .device = */ dev, + /* .context = */ buft_ctx + }; + buft_map[dev_name] = buft; + return buft; +} + +ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { + auto reg = ggml_backend_rpc_add_server(endpoint); + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, device); + const char * dev_name = ggml_backend_dev_name(dev); + ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ dev_name, + /* .gc = */ {}, + }; + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_rpc_guid(), + /* .iface = */ ggml_backend_rpc_interface, + /* .device = */ dev, + /* .context = */ ctx + }; + return backend; +} + +bool ggml_backend_is_rpc(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); +} + +static void get_device_memory(const std::shared_ptr & sock, uint32_t device, size_t * free, size_t * total) { + rpc_msg_get_device_memory_req request; + request.device = device; + rpc_msg_get_device_memory_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + *free = response.free_mem; + *total = response.total_mem; +} + +void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) { + auto sock = get_socket(endpoint); + if (sock == nullptr) { + *free = 0; + *total = 0; + return; + } + get_device_memory(sock, device, free, total); +} + +// RPC server-side 
implementation + +class rpc_server { +public: + rpc_server(std::vector all_backends, const char * cache_dir) + : backends(std::move(all_backends)), cache_dir(cache_dir) { + stored_graphs.resize(backends.size()); + } + ~rpc_server(); + + void hello(rpc_msg_hello_rsp & response); + bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); + bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response); + bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response); + bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response); + bool free_buffer(const rpc_msg_free_buffer_req & request); + bool buffer_clear(const rpc_msg_buffer_clear_req & request); + bool set_tensor(const std::vector & input); + bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response); + bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); + bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); + bool graph_compute(const std::vector & input); + bool graph_recompute(const rpc_msg_graph_recompute_req & request); + bool init_tensor(const rpc_msg_init_tensor_req & request); + bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); + bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); + + struct stored_graph { + std::vector buffer; + ggml_cgraph * graph; + }; + +private: + bool get_cached_file(uint64_t hash, std::vector & data); + ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); + ggml_tensor * create_node(uint64_t id, + struct ggml_context * ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map); + + + std::vector backends; + const char * cache_dir; + 
std::unordered_set buffers; + // store the last computed graph for each backend + std::vector stored_graphs; +}; + +void rpc_server::hello(rpc_msg_hello_rsp & response) { + response.major = RPC_PROTO_MAJOR_VERSION; + response.minor = RPC_PROTO_MINOR_VERSION; + response.patch = RPC_PROTO_PATCH_VERSION; + LOG_DBG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch); +} + +bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); + return false; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (request.srcs[i].id != 0) { + tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]); + } + } + + LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data); + if (tensor->buffer == nullptr) { + //No buffer allocated. 
+ buft = ggml_backend_get_default_buffer_type(backends[dev_id]); + } else { + buft = tensor->buffer->buft; + } + + response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor); + + return true; +} + +bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); + response.remote_ptr = 0; + response.remote_size = 0; + if (buffer != nullptr) { + response.remote_ptr = reinterpret_cast(buffer); + response.remote_size = buffer->size; + LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", + __func__, dev_id, request.size, response.remote_ptr, response.remote_size); + buffers.insert(buffer); + } else { + LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size); + } + return true; +} + +bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); + size_t alignment = ggml_backend_buft_get_alignment(buft); + LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment); + response.alignment = alignment; + return true; +} + +bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); + size_t max_size = ggml_backend_buft_get_max_size(buft); + LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, 
max_size); + response.max_size = max_size; + return true; +} + +bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) { + LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); + if (buffers.find(buffer) == buffers.end()) { + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); + return false; + } + void * base = ggml_backend_buffer_get_base(buffer); + response.base_ptr = reinterpret_cast(base); + return true; +} + +bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) { + LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); + if (buffers.find(buffer) == buffers.end()) { + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); + return false; + } + ggml_backend_buffer_free(buffer); + buffers.erase(buffer); + return true; +} + +bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) { + LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); + if (buffers.find(buffer) == buffers.end()) { + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); + return false; + } + ggml_backend_buffer_clear(buffer, request.value); + return true; +} + +ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) { + // Validate tensor type before using it + if (tensor->type >= GGML_TYPE_COUNT) { + GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type); + return nullptr; + } + + // Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types) + if (ggml_blck_size((enum ggml_type)tensor->type) == 0) { + GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type); + return nullptr; + } + + 
ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, + tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + + // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type + if (result == nullptr) { + GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type); + return nullptr; + } + + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = tensor->nb[i]; + } + result->buffer = reinterpret_cast(tensor->buffer); + if (result->buffer && buffers.find(result->buffer) == buffers.end()) { + result->buffer = nullptr; + } + + if (result->buffer) { + // require that the tensor data does not go beyond the buffer end + uint64_t tensor_size = (uint64_t) ggml_nbytes(result); + uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer); + uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer); + GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow + GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size); + } + + result->op = (ggml_op) tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result->op_params[i] = tensor->op_params[i]; + } + result->flags = tensor->flags; + result->data = reinterpret_cast(tensor->data); + ggml_set_name(result, tensor->name); + return result; +} + + +bool rpc_server::set_tensor(const std::vector & input) { + // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) | + if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) { + return false; + } + const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); + uint64_t offset; + memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); + const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset); + + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ 
NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); + if (tensor == nullptr || tensor->buffer == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); + return false; + } + LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); + + // sanitize tensor->data + { + const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); + const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); + + if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) { + GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, in_tensor->data, offset, size, p0, p1); + return false; + } + } + + const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset); + if (cache_dir && size > HASH_THRESHOLD) { + uint64_t hash = fnv_hash((const uint8_t*)data, size); + char hash_str[17]; + snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash); + // save to cache_dir/hash_str + fs::path cache_file = fs::path(cache_dir) / hash_str; + std::ofstream ofs(cache_file, std::ios::binary); + ofs.write((const char *)data, size); + GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str()); + } + ggml_backend_tensor_set(tensor, data, offset, size); + return true; +} + +bool rpc_server::get_cached_file(uint64_t hash, std::vector & data) { + if (!cache_dir) { + return false; + } + char hash_str[17]; + snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash); + fs::path cache_file = fs::path(cache_dir) / hash_str; + std::error_code ec; + if (!fs::exists(cache_file, ec)) { + return false; + } + std::ifstream ifs(cache_file, std::ios::binary); + ifs.seekg(0, std::ios::end); + 
size_t size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + data.resize(size); + ifs.read((char *)data.data(), size); + return true; +} + +bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response) +{ + std::vector cached_file; + if (!get_cached_file(request.hash, cached_file)) { + response.result = 0; + return true; + } + size_t size = cached_file.size(); + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr || tensor->buffer == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); + return false; + } + LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", + __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash); + + // sanitize tensor->data + { + const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); + const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); + + if (request.tensor.data + request.offset < p0 + || request.tensor.data + request.offset >= p1 + || size > (p1 - request.tensor.data - request.offset)) { + GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, request.tensor.data, request.offset, size, request.hash, p0, p1); + return false; + } + } + ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size); + response.result = 1; + return true; +} + +bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr 
ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n"); + return false; + } + LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data); + // Call the backend's buffer_init_tensor function + ggml_backend_buffer_t buffer = tensor->buffer; + if (buffer && buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } else { + if (!buffer) { + GGML_LOG_ERROR("Tensor with null buffer passed to init_tensor function\n"); + } + } + + if (tensor->extra != nullptr) { + // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. + // Currently unimplemented. + GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); + return false; + } + + return true; +} + +bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response) { + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr || tensor->buffer == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); + return false; + } + LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size); + + // sanitize tensor->data + { + const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); + const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); + + if (request.tensor.data + request.offset < p0 || + request.tensor.data + 
request.offset >= p1 || + request.size > (p1 - request.tensor.data - request.offset)) { + GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, request.tensor.data, request.offset, request.size, p0, p1); + return false; + } + } + + response.resize(request.size, 0); + ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size); + return true; +} + +bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) { + struct ggml_init_params params { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * src = deserialize_tensor(ctx, &request.src); + ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); + if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); + return false; + } + + uint64_t src_size = (uint64_t) ggml_nbytes(src); + uint64_t dst_data = (uint64_t) dst->data; + uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer); + uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer); + + if (dst_data + src_size > dst_base + dst_buf_sz) { + GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n" + " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n" + " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n", + __func__, + dst_data, + dst_data + src_size, + dst_base, + dst_base + dst_buf_sz); + return false; + } + + LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n", + __func__, (void*) src->buffer, (void*) dst->buffer); + + response.result = ggml_backend_buffer_copy_tensor(src, dst); + return true; +} + +ggml_tensor * rpc_server::create_node(uint64_t id, + struct ggml_context * 
ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map) { + if (tensor_map.find(id) != tensor_map.end()) { + return tensor_map[id]; + } + // Safely find the tensor pointer + auto it_ptr = tensor_ptrs.find(id); + if (it_ptr == tensor_ptrs.end()) { + return nullptr; + } + const rpc_tensor * tensor = it_ptr->second; + + struct ggml_tensor * result = deserialize_tensor(ctx, tensor); + if (result == nullptr) { + return nullptr; + } + if (result->buffer == nullptr && result->data != nullptr) { + GGML_LOG_ERROR("[%s] invalid data ptr", __func__); + return nullptr; + } + tensor_map[id] = result; + for (int i = 0; i < GGML_MAX_SRC; i++) { + // Check if the source ID is 0 before calling create_node recursively + if (tensor->src[i] == 0) { + result->src[i] = nullptr; + } else { + result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); + // If the recursive call failed for a non-zero ID, propagate the error + if (result->src[i] == nullptr) { + GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n", + __func__, i, tensor->src[i], id); + // Must return nullptr to signal failure up the call stack + return nullptr; + } + } + } + + // Handle view_src similarly + if (tensor->view_src == 0) { + result->view_src = nullptr; + } else { + result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); + // If the recursive call failed for a non-zero ID, propagate the error + if (result->view_src == nullptr) { + GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n", + __func__, tensor->view_src, id); + // Must return nullptr to signal failure up the call stack + return nullptr; + } + } + result->view_offs = tensor->view_offs; + return result; +} + +bool rpc_server::graph_compute(const std::vector & input) { + // serialization format: + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 
bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + if (input.size() < 2*sizeof(uint32_t)) { + return false; + } + const uint8_t * src = input.data(); + uint32_t device; + memcpy(&device, src, sizeof(device)); + src += sizeof(device); + if (device >= backends.size()) { + return false; + } + uint32_t n_nodes; + memcpy(&n_nodes, src, sizeof(n_nodes)); + src += sizeof(n_nodes); + if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { + return false; + } + const uint64_t * nodes = (const uint64_t *)src; + src += n_nodes*sizeof(uint64_t); + uint32_t n_tensors; + memcpy(&n_tensors, src, sizeof(n_tensors)); + src += sizeof(n_tensors); + if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { + return false; + } + const rpc_tensor * tensors = (const rpc_tensor *)src; + LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors); + + size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); + if (stored_graphs[device].buffer.size() < buf_size) { + stored_graphs[device].buffer.resize(buf_size); + } + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ stored_graphs[device].buffer.data(), + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); + graph->n_nodes = n_nodes; + std::unordered_map tensor_ptrs; + tensor_ptrs.reserve(n_tensors); + for (uint32_t i = 0; i < n_tensors; i++) { + tensor_ptrs.emplace(tensors[i].id, &tensors[i]); + } + std::unordered_map tensor_map; + tensor_map.reserve(n_nodes); + for (uint32_t i = 0; i < n_nodes; i++) { + int64_t id; + memcpy(&id, &nodes[i], sizeof(id)); + graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map); + + // Check if create_node failed for a 
*non-zero* ID. + // If id was 0, create_node returning nullptr is expected. + // If id was non-zero and create_node returned nullptr, it indicates a deserialization error. + if (graph->nodes[i] == nullptr && id != 0) { + GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id); + return false; + } + } + ggml_status status = ggml_backend_graph_compute(backends[device], graph); + GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); + stored_graphs[device].graph = graph; + return true; +} + +bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) { + uint32_t device = request.device; + if (device >= backends.size()) { + return false; + } + if (stored_graphs[device].graph == nullptr) { + return false; + } + ggml_cgraph * graph = stored_graphs[device].graph; + LOG_DBG("[%s] device: %u\n", __func__, device); + ggml_status status = ggml_backend_graph_compute(backends[device], graph); + GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); + return true; +} + +bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + size_t free, total; + ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]); + ggml_backend_dev_memory(dev, &free, &total); + response.free_mem = free; + response.total_mem = total; + LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem); + return true; +} + +rpc_server::~rpc_server() { + for (auto buffer : buffers) { + ggml_backend_buffer_free(buffer); + } +} + +static void rpc_serve_client(const std::vector & backends, const char * cache_dir, + socket_ptr sock) { + rpc_server server(backends, cache_dir); + uint8_t cmd; + if (!sock->recv_data(&cmd, 
1)) { + return; + } + if (cmd != RPC_CMD_HELLO) { + GGML_LOG_ERROR("Expected HELLO command, update client\n"); + return; + } + + // Read input_size and validate protocol version + uint64_t hello_input_size; + if (!sock->recv_data(&hello_input_size, sizeof(hello_input_size))) { + return; + } + + if (hello_input_size != sizeof(rpc_msg_hello_req)) { + GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu) — client needs upgrade to protocol v%d.x\n", + (size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION); + return; + } + + rpc_msg_hello_req req = {}; + if (!sock->recv_data(&req, sizeof(req))) { + return; + } + + rpc_msg_hello_rsp rsp = {}; + server.hello(rsp); + // Advertise server transport capabilities based on client's caps + sock->get_caps(rsp.conn_caps); + if (!send_msg(sock, &rsp, sizeof(rsp))) { + return; + } + + // Activate transport upgrade using client's caps + sock->update_caps(req.conn_caps); + while (true) { + if (!sock->recv_data(&cmd, 1)) { + break; + } + if (cmd >= RPC_CMD_COUNT) { + // fail fast if the command is invalid + GGML_LOG_ERROR("Unknown command: %d\n", cmd); + break; + } + switch (cmd) { + case RPC_CMD_HELLO: { + // HELLO command is handled above + return; + } + case RPC_CMD_DEVICE_COUNT: { + if (!recv_msg(sock, nullptr, 0)) { + return; + } + rpc_msg_device_count_rsp response; + response.device_count = backends.size(); + if (!send_msg(sock, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_ALLOC_BUFFER: { + rpc_msg_alloc_buffer_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + rpc_msg_alloc_buffer_rsp response; + if (!server.alloc_buffer(request, response)) { + return; + } + if (!send_msg(sock, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GET_ALLOC_SIZE: { + rpc_msg_get_alloc_size_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + rpc_msg_get_alloc_size_rsp response; + if (!server.get_alloc_size(request, 
response)) { + return; + } + if (!send_msg(sock, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GET_ALIGNMENT: { + rpc_msg_get_alignment_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + rpc_msg_get_alignment_rsp response; + if (!server.get_alignment(request, response)) { + return; + } + if (!send_msg(sock, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GET_MAX_SIZE: { + rpc_msg_get_max_size_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + rpc_msg_get_max_size_rsp response; + if (!server.get_max_size(request, response)) { + return; + } + if (!send_msg(sock, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_BUFFER_GET_BASE: { + rpc_msg_buffer_get_base_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + rpc_msg_buffer_get_base_rsp response; + if (!server.buffer_get_base(request, response)) { + return; + } + if (!send_msg(sock, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_FREE_BUFFER: { + rpc_msg_free_buffer_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + if (!server.free_buffer(request)) { + return; + } + if (!send_msg(sock, nullptr, 0)) { + return; + } + break; + } + case RPC_CMD_BUFFER_CLEAR: { + rpc_msg_buffer_clear_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + if (!server.buffer_clear(request)) { + return; + } + if (!send_msg(sock, nullptr, 0)) { + return; + } + break; + } + case RPC_CMD_SET_TENSOR: { + std::vector input; + if (!recv_msg(sock, input)) { + return; + } + if (!server.set_tensor(input)) { + return; + } + break; + } + case RPC_CMD_SET_TENSOR_HASH: { + rpc_msg_set_tensor_hash_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + rpc_msg_set_tensor_hash_rsp response; + if (!server.set_tensor_hash(request, response)) { + return; + } + if (!send_msg(sock, &response, 
sizeof(response))) { + return; + } + break; + } + case RPC_CMD_INIT_TENSOR: { + rpc_msg_init_tensor_req request; + if (!recv_msg(sock, &request,sizeof(request))) { + return; + } + if (!server.init_tensor(request)) { + return; + } + if (!send_msg(sock, nullptr, 0)) { + return; + } + break; + } + case RPC_CMD_GET_TENSOR: { + rpc_msg_get_tensor_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + std::vector response; + if (!server.get_tensor(request, response)) { + return; + } + if (!send_msg(sock, response.data(), response.size())) { + return; + } + break; + } + case RPC_CMD_COPY_TENSOR: { + rpc_msg_copy_tensor_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + rpc_msg_copy_tensor_rsp response; + if (!server.copy_tensor(request, response)) { + return; + } + if (!send_msg(sock, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GRAPH_COMPUTE: { + std::vector input; + if (!recv_msg(sock, input)) { + return; + } + if (!server.graph_compute(input)) { + return; + } + break; + } + case RPC_CMD_GRAPH_RECOMPUTE: { + rpc_msg_graph_recompute_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + if (!server.graph_recompute(request)) { + return; + } + break; + } + case RPC_CMD_GET_DEVICE_MEMORY: { + rpc_msg_get_device_memory_req request; + if (!recv_msg(sock, &request, sizeof(request))) { + return; + } + rpc_msg_get_device_memory_rsp response; + if (!server.get_device_memory(request, response)) { + return; + } + if (!send_msg(sock, &response, sizeof(response))) { + return; + } + break; + } + default: { + GGML_LOG_ERROR("Unknown command: %d\n", cmd); + return; + } + } + } +} + +void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) { + if (n_devices == 0 || devices == nullptr) { + fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n"); + return; + } + std::vector 
backends; + printf("Starting RPC server v%d.%d.%d\n", + RPC_PROTO_MAJOR_VERSION, + RPC_PROTO_MINOR_VERSION, + RPC_PROTO_PATCH_VERSION); + printf(" endpoint : %s\n", endpoint); + printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); + printf("Devices:\n"); + for (size_t i = 0; i < n_devices; i++) { + auto dev = devices[i]; + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + total / 1024 / 1024, free / 1024 / 1024); + auto backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev)); + return; + } + backends.push_back(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + if (reg) { + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend, n_threads); + } + } + } + + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + return; + } + +#ifdef GGML_RPC_RDMA + printf(" transport : TCP (RDMA auto-negotiate enabled)\n"); +#else + printf(" transport : TCP\n"); +#endif // GGML_RPC_RDMA + if (!rpc_transport_init()) { + fprintf(stderr, "Failed to initialize RPC transport\n"); + return; + } + auto server_socket = socket_t::create_server(host.c_str(), port); + if (server_socket == nullptr) { + fprintf(stderr, "Failed to create server socket\n"); + return; + } + while (true) { + auto client_socket = server_socket->accept(); + if (client_socket == nullptr) { + fprintf(stderr, "Failed to accept client connection\n"); + return; + } + printf("Accepted client connection\n"); + fflush(stdout); + rpc_serve_client(backends, cache_dir, client_socket); + printf("Client connection closed\n"); + fflush(stdout); + } + rpc_transport_shutdown(); + for 
(auto backend : backends) { + ggml_backend_free(backend); + } +} + +// device interface + +struct ggml_backend_rpc_device_context { + std::string endpoint; + uint32_t device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ctx->name.c_str(); +} + +static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ctx->description.c_str(); +} + +static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total); +} + +static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { + // TODO: obtain value from the server + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_rpc_device_get_name(dev); + props->description = ggml_backend_rpc_device_get_description(dev); + props->type = ggml_backend_rpc_device_get_type(dev); + ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device); + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t 
ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device); + + GGML_UNUSED(dev); +} + +static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + GGML_UNUSED(dev); + GGML_UNUSED(op); + //TODO: call the remote backend and cache the results + return true; +} + +static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { + return false; + } + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context; + return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device; +} + +static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { + /* .get_name = */ ggml_backend_rpc_device_get_name, + /* .get_description = */ ggml_backend_rpc_device_get_description, + /* .get_memory = */ ggml_backend_rpc_device_get_memory, + /* .get_type = */ ggml_backend_rpc_device_get_type, + /* .get_props = */ ggml_backend_rpc_device_get_props, + /* .init_backend = */ ggml_backend_rpc_device_init, + /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_rpc_device_supports_op, + /* .supports_buft = */ ggml_backend_rpc_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend reg interface + +struct ggml_backend_rpc_reg_context { + std::string name; + std::vector devices; +}; + +static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { + 
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + return ctx ? ctx->name.c_str() : "RPC"; +} + +static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + return ctx ? ctx->devices.size() : 0; +} + +static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + if (ctx == nullptr) { + GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead"); + } else { + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; + } +} + +static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) { + return (void *)ggml_backend_rpc_add_server; + } + if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) { + return (void *)ggml_backend_rpc_start_server; + } + return NULL; + + GGML_UNUSED(reg); +} + +static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = { + /* .get_name = */ ggml_backend_rpc_reg_get_name, + /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count, + /* .get_device = */ ggml_backend_rpc_reg_get_device, + /* .get_proc_address = */ ggml_backend_rpc_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_rpc_reg(void) { + static struct ggml_backend_reg ggml_backend_rpc_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_rpc_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_rpc_reg; +} + +static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) { + auto sock = get_socket(endpoint); + if (sock == nullptr) { + GGML_LOG_ERROR("Failed to connect to %s\n", endpoint); + return 0; + } + rpc_msg_device_count_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, 
nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.device_count; +} + +static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = { + /* .get_name = */ ggml_backend_rpc_reg_get_name, + /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count, + /* .get_device = */ ggml_backend_rpc_reg_get_device, + /* .get_proc_address = */ ggml_backend_rpc_get_proc_address, +}; + +namespace { + std::unordered_map g_rpc_reg_map; + std::mutex g_rpc_reg_mutex; + uint32_t g_rpc_dev_id = 0; +} + +ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { + std::lock_guard lock(g_rpc_reg_mutex); + if (g_rpc_reg_map.find(endpoint) != g_rpc_reg_map.end()) { + return g_rpc_reg_map[endpoint]; + } + uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint); + if (dev_count == 0) { + return nullptr; + } + ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context; + ctx->name = "RPC[" + std::string(endpoint) + "]"; + for (uint32_t ind = 0; ind < dev_count; ind++) { + std::string dev_name = "RPC" + std::to_string(g_rpc_dev_id); + std::string dev_desc = std::string(endpoint); + ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc + }; + g_rpc_dev_id++; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_rpc_device_i, + /* .reg = */ ggml_backend_rpc_reg(), + /* .context = */ dev_ctx, + }; + ctx->devices.push_back(dev); + } + ggml_backend_reg_t reg = new ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_rpc_reg_interface, + /* .context = */ ctx + }; + g_rpc_reg_map[endpoint] = reg; + return reg; +} + + +GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg) diff --git a/ggml/src/ggml-rpc/transport.cpp b/ggml/src/ggml-rpc/transport.cpp new file mode 100644 index 00000000000..a728152421f --- /dev/null +++ 
b/ggml/src/ggml-rpc/transport.cpp @@ -0,0 +1,683 @@ +#include "transport.h" +#include "ggml-impl.h" + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif +#include +#include +#include + +#ifdef GGML_RPC_RDMA +# include +# include +# ifndef _WIN32 +# include +# endif +#endif // GGML_RPC_RDMA + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define LOG_DBG(...) \ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + 
free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; + +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; + +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); + +#endif // GGML_RPC_RDMA + +struct socket_t::impl { + impl(sockfd_t fd) : use_rdma(false), fd(fd) {} + ~impl(); + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + +#ifdef GGML_RPC_RDMA + bool tcp_peer_closed(); + std::optional rdma_build_target_gid(); + bool rdma_probe(); + bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid); + bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc); + bool rdma_send(const void * data, size_t size); + bool rdma_recv(void * data, size_t size); + + std::unique_ptr rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA + bool use_rdma; + sockfd_t fd; +}; + +socket_t::impl::~impl() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA + LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); +#ifdef _WIN32 + if (fd != INVALID_SOCKET) closesocket(this->fd); +#else + if (fd >= 0) close(this->fd); +#endif +} + +#ifdef GGML_RPC_RDMA + +bool socket_t::impl::tcp_peer_closed() { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | 
POLLRDHUP)); +#else + return false; +#endif +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. +// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +std::optional socket_t::impl::rdma_build_target_gid() { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(fd, reinterpret_cast(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +bool socket_t::impl::rdma_probe() { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(); + if (!target_gid) { + return false; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return false; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? 
atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. 
+ ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + rdma_local.path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return false; + + rdma_local.ib_port = ib_port; + rdma_local.gid_idx = gid_idx; + + rdma = std::make_unique(); + rdma->ctx = ibctx; + + rdma->pd = ibv_alloc_pd(ibctx); + if (!rdma->pd) return false; + + rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!rdma->scq || !rdma->rcq) return false; + + ibv_qp_init_attr qia = {}; + qia.send_cq = rdma->scq; + qia.recv_cq = rdma->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + rdma->qp = ibv_create_qp(rdma->pd, &qia); + if (!rdma->qp) return false; + rdma->max_inline = qia.cap.max_inline_data; + + rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + rdma->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!rdma->tx_buf || !rdma->rx_buf) return false; + + rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!rdma->tx_mr || !rdma->rx_mr) return false; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false; + + rdma_local.qpn = rdma->qp->qp_num; + rdma_local.psn = rdma->qp->qp_num & 0xffffff; + memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == 
IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline); + return true; +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. +bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = rdma_local.ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!rdma->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = rdma_local.path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = rdma_local.gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = rdma_local.ib_port; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = rdma_local.psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u 
mtu=%d rx_depth=%d\n", + rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH); + return true; +} + +bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed()) { + return false; + } + } + } +} + +bool socket_t::impl::rdma_send(const void * data, size_t size) { + rdma_conn * c = rdma.get(); + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + +bool socket_t::impl::rdma_recv(void * data, size_t size) { + rdma_conn * c = rdma.get(); + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +#endif // GGML_RPC_RDMA + +bool socket_t::impl::send_data(const void * data, size_t size) { 
+#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_send(data, size); + } +#endif + size_t bytes_sent = 0; + while (bytes_sent < size) { + size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); + ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0); + if (n < 0) { + GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", + bytes_sent, size_to_send); + return false; + } + bytes_sent += (size_t)n; + } + return true; +} + +bool socket_t::impl::recv_data(void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_recv(data, size); + } +#endif + size_t bytes_recv = 0; + while (bytes_recv < size) { + size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); + ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0); + if (n < 0) { + GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", + bytes_recv, size_to_recv); + return false; + } + if (n == 0) { + LOG_DBG("recv returned 0 (peer closed?)\n"); + return false; + } + bytes_recv += (size_t)n; + } + return true; +} + +void socket_t::impl::get_caps(uint8_t * local_caps) { + memset(local_caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + if (rdma_probe()) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(local_caps, &rc, sizeof(rc)); + } else { + rdma.reset(); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::impl::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rc.qpn, rc.psn, rc.gid)) { + use_rdma = true; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + + +///////////////////////////////////////////////////////////////////////////// + 
+socket_t::socket_t(std::unique_ptr p) : pimpl(std::move(p)) {} + +socket_t::~socket_t() = default; + +bool socket_t::send_data(const void * data, size_t size) { + return pimpl->send_data(data, size); +} + +bool socket_t::recv_data(void * data, size_t size) { + return pimpl->recv_data(data, size); +} + +void socket_t::get_caps(uint8_t * local_caps) { + return pimpl->get_caps(local_caps); +} + +void socket_t::update_caps(const uint8_t * remote_caps) { + return pimpl->update_caps(remote_caps); +} + +static bool is_valid_fd(sockfd_t sockfd) { +#ifdef _WIN32 + return sockfd != INVALID_SOCKET; +#else + return sockfd >= 0; +#endif +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static bool set_reuse_addr(sockfd_t sockfd) { + int flag = 1; + int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); + return ret == 0; +} + +socket_ptr socket_t::accept() { + auto client_socket_fd = ::accept(pimpl->fd, NULL, NULL); + if (!is_valid_fd(client_socket_fd)) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(client_socket_fd))); +} + +socket_ptr socket_t::create_server(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + GGML_LOG_ERROR("Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return 
nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +socket_ptr socket_t::connect(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +#ifdef _WIN32 +static std::mutex g_rpc_transport_mu; +static bool g_rpc_transport_wsa_started = false; +#endif + +bool rpc_transport_init() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (g_rpc_transport_wsa_started) { + return true; + } + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return false; + } + g_rpc_transport_wsa_started = true; + return true; +#else + return true; +#endif +} + +void rpc_transport_shutdown() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (!g_rpc_transport_wsa_started) { + return; + } + WSACleanup(); + g_rpc_transport_wsa_started = false; +#endif +} diff --git a/ggml/src/ggml-rpc/transport.h b/ggml/src/ggml-rpc/transport.h new file mode 100644 index 00000000000..73b85cc530a --- /dev/null +++ b/ggml/src/ggml-rpc/transport.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +struct socket_t; +typedef std::shared_ptr socket_ptr; + +static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +struct socket_t { + ~socket_t(); + + bool 
send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + + socket_ptr accept(); + + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + + static socket_ptr create_server(const char * host, int port); + static socket_ptr connect(const char * host, int port); + +private: + struct impl; + explicit socket_t(std::unique_ptr p); + std::unique_ptr pimpl; +}; + +bool rpc_transport_init(); +void rpc_transport_shutdown(); diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 290b7c946e5..7852a943419 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include "model_adapter.h" @@ -51,6 +52,7 @@ #include "tools/mtmd/llava.h" #include "tools/mtmd/mtmd-audio.h" #include "common/common.h" +#include "ggml-backend.h" #if defined(GGML_USE_HIP) // for rocblas_initialize() @@ -2195,6 +2197,62 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in ggml_time_init(); kcpp_data = new kcpp_params(); //allocate on heap to avoid linux segfault. yes this leaks memory. 
+ // Parse device string and set HIP_VISIBLE_DEVICES/CUDA_VISIBLE_DEVICES to restrict local devices + // This prevents RPC server devices from being used locally when not specified + { + const char * device_str = inputs.devices_override; + if (device_str && strlen(device_str) > 0) { + std::string devices(device_str); + std::vector local_device_indices; + + // Parse device string to find local ROCm/CUDA devices (not RPC*) + size_t start = 0; + while (start < devices.length()) { + size_t end = devices.find(',', start); + if (end == std::string::npos) end = devices.length(); + std::string device = devices.substr(start, end - start); + + // Trim whitespace + size_t first = device.find_first_not_of(" \t"); + if (first != std::string::npos) { + size_t last = device.find_last_not_of(" \t"); + device = device.substr(first, last - first + 1); + } + + // Check if this is a local ROCm or CUDA device (not RPC*) + if (device.length() > 0) { + if (strncasecmp(device.c_str(), "ROCm", 4) == 0) { + int dev_num = atoi(device.c_str() + 4); + local_device_indices.push_back(dev_num); + } else if (strncasecmp(device.c_str(), "CUDA", 4) == 0) { + int dev_num = atoi(device.c_str() + 4); + local_device_indices.push_back(dev_num); + } else if (strncasecmp(device.c_str(), "HIP", 3) == 0) { + int dev_num = atoi(device.c_str() + 3); + local_device_indices.push_back(dev_num); + } + } + + start = end + 1; + } + + if (!local_device_indices.empty()) { + std::string visible_devices; + for (size_t i = 0; i < local_device_indices.size(); i++) { + if (i > 0) visible_devices += ","; + visible_devices += std::to_string(local_device_indices[i]); + } +#if defined(GGML_USE_HIP) + setenv("HIP_VISIBLE_DEVICES", visible_devices.c_str(), 1); + printf("HIP_VISIBLE_DEVICES=%s (restricting to specified local devices)\n", visible_devices.c_str()); +#elif defined(GGML_USE_CUDA) + setenv("CUDA_VISIBLE_DEVICES", visible_devices.c_str(), 1); + printf("CUDA_VISIBLE_DEVICES=%s (restricting to specified local 
devices)\n", visible_devices.c_str()); +#endif + } + } + } + file_format = in_file_format; file_format_meta = in_file_format_meta; kcpp_data->n_threads = inputs.threads; @@ -2443,6 +2501,28 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in model_params.use_direct_io = false; //no direct io for now until stable model_params.n_gpu_layers = inputs.gpulayers; + // Register RPC servers BEFORE parsing device list + const char * rpc_env = std::getenv("LLAMA_ARG_RPC"); + if (rpc_env && strlen(rpc_env) > 0) { + printf("RPC Client Mode: Pre-registering RPC servers: %s\n", rpc_env); + fflush(stdout); + auto rpc_servers = string_split(rpc_env, ','); + if (!rpc_servers.empty()) { + typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint); + ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) dlsym(RTLD_DEFAULT, "ggml_backend_rpc_add_server"); + if (ggml_backend_rpc_add_server_fn) { + for (const auto & server : rpc_servers) { + auto reg = ggml_backend_rpc_add_server_fn(server.c_str()); + if (reg) { + ggml_backend_register(reg); + printf("RPC Client Mode: Pre-registered server: %s\n", server.c_str()); + fflush(stdout); + } + } + } + } + } + //set device overrides if needed std::vector devices_override; std::string dev_override_str = inputs.devices_override; diff --git a/koboldcpp.py b/koboldcpp.py old mode 100644 new mode 100755 index 9accb847242..a9ae7867a60 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -825,27 +825,108 @@ def pick_existant_file(ntoption,nonntoption): lib_hipblas = pick_existant_file("koboldcpp_hipblas.dll","koboldcpp_hipblas.so") lib_vulkan = pick_existant_file("koboldcpp_vulkan.dll","koboldcpp_vulkan.so") lib_vulkan_noavx2 = pick_existant_file("koboldcpp_vulkan_noavx2.dll","koboldcpp_vulkan_noavx2.so") +# RPC libraries +lib_rpc = pick_existant_file("koboldcpp_rpc.dll","koboldcpp_rpc.so") +lib_hipblas_rpc = 
pick_existant_file("koboldcpp_hipblas_rpc.dll","koboldcpp_hipblas_rpc.so")
+lib_cublas_rpc = pick_existant_file("koboldcpp_cublas_rpc.dll","koboldcpp_cublas_rpc.so")
 libname = ""
 lib_option_pairs = [
     (lib_default, "Use CPU"),
     (lib_cublas, "Use CUDA"),
     (lib_hipblas, "Use hipBLAS (ROCm)"),
     (lib_vulkan, "Use Vulkan"),
+    (lib_cublas_rpc, "Use CUDA + RPC"),
+    (lib_hipblas_rpc, "Use hipBLAS + RPC"),
+    (lib_rpc, "Use Vulkan + RPC"),
     (lib_noavx2, "Use CPU (Old CPU)"),
     (lib_vulkan_noavx2, "Use Vulkan (Old CPU)"),
     (lib_vulkan_failsafe, "Use Vulkan (Older CPU)"),
     (lib_failsafe, "Failsafe Mode (Older CPU)")]
-default_option, cublas_option, hipblas_option, vulkan_option, noavx2_option, vulkan_noavx2_option, vulkan_failsafe_option, failsafe_option = (opt if file_exists(lib) or (os.name == 'nt' and file_exists(opt + ".dll")) else None for lib, opt in lib_option_pairs)
+default_option, cublas_option, hipblas_option, vulkan_option, cublas_rpc_option, hipblas_rpc_option, rpc_option, noavx2_option, vulkan_noavx2_option, vulkan_failsafe_option, failsafe_option = (opt if file_exists(lib) or (os.name == 'nt' and file_exists(opt + ".dll")) else None for lib, opt in lib_option_pairs)
 runopts = [opt for lib, opt in lib_option_pairs if file_exists(lib)]

 def init_library():
     global handle, args, libname
     global lib_default,lib_failsafe,lib_noavx2,lib_vulkan_failsafe,lib_cublas,lib_hipblas,lib_vulkan,lib_vulkan_noavx2
+    global lib_rpc, lib_hipblas_rpc, lib_cublas_rpc
+    global is_rpc_client, is_rpc_server
     libname = lib_default
+    # Check if RPC endpoint is specified (client mode)
+    is_rpc_client = args.rpc is not None and len(args.rpc) > 0
+
+    # Also check if RPC* devices are specified in the device string
+    if not is_rpc_client and hasattr(args, 'device') and args.device:
+        device_str = args.device
+        has_rpc_devices = any(
+            d.strip().upper().startswith('RPC')
+            for d in device_str.split(',')
+        )
+        if has_rpc_devices:
+            is_rpc_client = True
+            print("RPC Client Mode: Detected RPC* devices in device string")
+
+    # Check if RPC server mode is enabled
+    is_rpc_server = args.rpc_server
+
     if not args: # debug helper: koboldcpp.py loaded by external script
         pass
+    elif is_rpc_client:
+        # RPC client mode - use RPC libraries
+        # Choose RPC client based on local devices in device string
+        libname = None
+
+        # Parse device string to determine which local backends are needed
+        use_hipblas_local = False
+        use_cublas_local = False
+        use_vulkan_local = False
+
+        if args.device:
+            # Devices are upper-cased here, so every prefix test below must be
+            # upper-case too. FIX: 'ROCm' could never match an upper-cased name.
+            devices = [d.strip().upper() for d in args.device.split(',')]
+            for dev in devices:
+                if dev.startswith('ROCM') or dev.startswith('HIP'):
+                    use_hipblas_local = True
+                elif dev.startswith('CUDA'):
+                    use_cublas_local = True
+                elif dev.startswith('VULKAN'):
+                    use_vulkan_local = True
+
+        # Select RPC client library based on local device requirements
+        if use_hipblas_local and file_exists(lib_hipblas_rpc):
+            libname = lib_hipblas_rpc
+            print("Using HIPBLAS RPC backend (Client Mode)")
+        elif use_cublas_local and file_exists(lib_cublas_rpc):
+            libname = lib_cublas_rpc
+            print("Using CUDA RPC backend (Client Mode)")
+        elif use_vulkan_local and file_exists(lib_rpc):
+            libname = lib_rpc
+            print("Using Vulkan RPC backend (Client Mode)")
+        elif file_exists(lib_hipblas_rpc):
+            libname = lib_hipblas_rpc
+            print("Using HIPBLAS RPC backend (Client Mode - Auto-selected)")
+        elif file_exists(lib_cublas_rpc):
+            libname = lib_cublas_rpc
+            print("Using CUDA RPC backend (Client Mode - Auto-selected)")
+        elif file_exists(lib_rpc):
+            libname = lib_rpc
+            print("Using Vulkan RPC backend (Client Mode - Auto-selected)")
+        else:
+            exit_with_error(2, "RPC client mode requested but no RPC libraries found. 
Please build RPC libraries first.")
+    elif is_rpc_server:
+        # RPC server mode - use RPC libraries based on backend selection.
+        # FIX: honor an explicit CUDA/hipBLAS request first. The previous code
+        # ran the Vulkan branch whenever lib_rpc merely existed on disk
+        # ("args.usevulkan is not None or file_exists(lib_rpc)"), shadowing the
+        # CUDA/hipBLAS branches, and could silently keep lib_default if the
+        # Vulkan lib was missing.
+        if args.usecuda is not None and file_exists(lib_hipblas_rpc):
+            libname = lib_hipblas_rpc
+            print("Using HIPBLAS RPC backend (Server Mode)")
+        elif args.usecuda is not None and file_exists(lib_cublas_rpc):
+            libname = lib_cublas_rpc
+            print("Using CUDA RPC backend (Server Mode)")
+        elif file_exists(lib_rpc):
+            libname = lib_rpc
+            print("Using Vulkan RPC backend (Server Mode)")
     elif args.noavx2: #failsafe implies noavx2 always
         if args.failsafe and (args.usevulkan is not None) and file_exists(lib_vulkan_failsafe):
             libname = lib_vulkan_failsafe
@@ -1016,7 +1097,40 @@ def set_backend_props(inputs):
         handle.set_environment_variable("GGML_VK_VISIBLE_DEVICES".encode("UTF-8"),vulkangpus.encode("UTF-8"))

     # set universal flags
-    inputs.devices_override = (args.device if args.device else "").encode("UTF-8")
+    # Filter device string for RPC client mode
+    device_str = ""
+    if args.device:
+        # Check if RPC* devices are specified
+        devices = [d.strip() for d in args.device.split(',')]
+        has_rpc_devices = any(d.upper().startswith('RPC') for d in devices)
+
+        # If RPC* devices are specified but no RPC endpoints provided, show error
+        if has_rpc_devices and (args.rpc is None or len(args.rpc) == 0):
+            print("ERROR: RPC* devices specified in device string but no RPC endpoints provided via --rpc flag.")
+            print("RPC* devices (RPC0, RPC1, etc.) 
require RPC server endpoints to connect to.") + print("Please specify RPC endpoints with --rpc (e.g., --rpc 127.0.0.1:50053)") + print("Alternatively, remove RPC* devices from the device string.") + print("Proceeding without RPC devices...") + + if is_rpc_client: + # In RPC client mode, keep RPC* device strings if RPC endpoints are provided + # RPC endpoints are handled by the RPC library + if args.rpc is not None and len(args.rpc) > 0: + # Keep all devices including RPC* since RPC servers are registered + device_str = args.device + print(f"RPC Client Mode: Using devices (including RPC): {device_str}") + else: + # Filter out RPC* device strings if no RPC endpoints + valid_devices = [d for d in devices if not d.upper().startswith('RPC')] + if valid_devices: + device_str = ','.join(valid_devices) + print(f"RPC Client Mode: Using local devices: {device_str}") + else: + device_str = "" + print("RPC Client Mode: No local devices specified, using RPC servers only") + else: + device_str = args.device + inputs.devices_override = device_str.encode("UTF-8") inputs.quiet = args.quiet inputs.debugmode = args.debugmode inputs.executable_path = (getdirpath()+"/").encode("UTF-8") @@ -7476,7 +7590,7 @@ def hide_tooltip(event): tabs = ctk.CTkFrame(root, corner_radius = 0, width=windowwidth, height=windowheight-50) tabs.grid(row=0, stick="nsew") - tabnames= ["Quick Launch", "Hardware", "Context", "Loaded Files", "Network", "Horde Worker","Image Gen","Audio","Admin","Extra"] + tabnames= ["Quick Launch", "Hardware", "Context", "Loaded Files", "RPC Server", "Network", "Horde Worker","Image Gen","Audio","Admin","Extra"] navbuttons = {} navbuttonframe = ctk.CTkFrame(tabs, width=int(104), height=int(tabs.cget("height"))) navbuttonframe.grid(row=0, column=0, padx=2,pady=2) @@ -7644,6 +7758,17 @@ def hide_tooltip(event): autoswap_mode_var = ctk.IntVar(value=0) admin_unload_timeout_var = ctk.StringVar(value=str(0)) + # RPC Server variables + rpc_server_mode_var = ctk.IntVar(value=0) + 
rpc_server_backend_var = ctk.StringVar(value="Auto-detect") + rpc_host_var = ctk.StringVar(value="0.0.0.0") + rpc_port_var = ctk.StringVar(value="50053") + rpc_devices_var = ctk.StringVar(value="") + rpc_cache_layers_var = ctk.IntVar(value=0) + + # RPC endpoint for client mode + rpc_endpoint_var = ctk.StringVar(value="") + nozenity_var = ctk.IntVar(value=0) curr_tab_idx = 0 @@ -7906,7 +8031,20 @@ def auto_set_backend_gui(manual_select=False): #autopick cublas if suitable, requires at least 3.5GB VRAM to auto pick #we do not want to autoselect hip/cublas if the user has already changed their desired backend! - if eligible_cuda and exitcounter < 100 and MaxMemory[0]>3500000000 and (("Use CUDA" in runopts and CUDevicesNames[0]!="") or "Use hipBLAS (ROCm)" in runopts) and (any(CUDevicesNames)) and runmode_untouched: + # First check for RPC backends if RPC libraries are available + if "Use Vulkan + RPC" in runopts and rpc_endpoint_var.get() != "": + runopts_var.set("Use Vulkan + RPC") + print("Auto Selected Vulkan + RPC Backend\n") + found_new_backend = True + elif "Use hipBLAS + RPC" in runopts and rpc_endpoint_var.get() != "": + runopts_var.set("Use hipBLAS + RPC") + print("Auto Selected HIPBLAS + RPC Backend\n") + found_new_backend = True + elif "Use CUDA + RPC" in runopts and rpc_endpoint_var.get() != "": + runopts_var.set("Use CUDA + RPC") + print("Auto Selected CUDA + RPC Backend\n") + found_new_backend = True + elif eligible_cuda and exitcounter < 100 and MaxMemory[0]>3500000000 and (("Use CUDA" in runopts and CUDevicesNames[0]!="") or "Use hipBLAS (ROCm)" in runopts) and (any(CUDevicesNames)) and runmode_untouched: if "Use CUDA" in runopts: runopts_var.set("Use CUDA") gpu_choice_var.set("1") @@ -8005,14 +8143,20 @@ def changed_gpu_choice_var(*args): if v == "Use Vulkan" or v == "Use Vulkan (Old CPU)" or v == "Use Vulkan (Older CPU)": quick_gpuname_label.configure(text=VKDevicesNames[s]) gpuname_label.configure(text=VKDevicesNames[s]) + if 'rpc_gpuname_label' in 
globals(): + rpc_gpuname_label.configure(text=VKDevicesNames[s]) else: quick_gpuname_label.configure(text=CUDevicesNames[s]) gpuname_label.configure(text=CUDevicesNames[s]) + if 'rpc_gpuname_label' in globals(): + rpc_gpuname_label.configure(text=CUDevicesNames[s]) except Exception: pass else: quick_gpuname_label.configure(text="(dGPUs only, tensor split sets ratio)") gpuname_label.configure(text="(dGPUs only, tensor split sets ratio)") + if 'rpc_gpuname_label' in globals(): + rpc_gpuname_label.configure(text="(dGPUs only, tensor split sets ratio)") gpu_choice_var.trace_add("write", changed_gpu_choice_var) gpulayers_var.trace_add("write", changed_gpulayers_estimate) @@ -8069,7 +8213,8 @@ def changerunmode(a,b,c): global runmode_untouched runmode_untouched = False index = runopts_var.get() - if index == "Use Vulkan" or index == "Use Vulkan (Old CPU)" or index == "Use Vulkan (Older CPU)" or index == "Use CUDA" or index == "Use hipBLAS (ROCm)": + # Check for GPU or RPC backends + if index == "Use Vulkan" or index == "Use Vulkan (Old CPU)" or index == "Use Vulkan (Older CPU)" or index == "Use CUDA" or index == "Use hipBLAS (ROCm)" or index == "Use Vulkan + RPC" or index == "Use hipBLAS + RPC" or index == "Use CUDA + RPC": quick_gpuname_label.grid(row=3, column=1, padx=75, sticky="W") gpuname_label.grid(row=3, column=0, padx=230, sticky="W") gpu_selector_label.grid(row=3, column=0, padx = 8, pady=1, stick="nw") @@ -8094,7 +8239,7 @@ def changerunmode(a,b,c): splitmode_box.grid_remove() splitmode_box_label.grid_remove() - if index == "Use CUDA" or index == "Use hipBLAS (ROCm)": + if index == "Use CUDA" or index == "Use hipBLAS (ROCm)" or index == "Use CUDA + RPC" or index == "Use hipBLAS + RPC": mmq_box.grid(row=4, column=0, padx=340, pady=1, stick="nw") quick_mmq_box.grid(row=4, column=1, padx=8, pady=1, stick="nw") tensor_split_label.grid(row=8, column=0, padx = 8, pady=1, stick="nw") @@ -8105,11 +8250,11 @@ def changerunmode(a,b,c): tensor_split_label.grid_remove() 
tensor_split_entry.grid_remove() - if index == "Use Vulkan" or index == "Use Vulkan (Old CPU)": + if index == "Use Vulkan" or index == "Use Vulkan (Old CPU)" or index == "Use Vulkan + RPC": tensor_split_label.grid(row=8, column=0, padx = 8, pady=1, stick="nw") tensor_split_entry.grid(row=8, column=0, padx = 160, pady=1, stick="nw") - if index == "Use Vulkan" or index == "Use Vulkan (Old CPU)" or index == "Use Vulkan (Older CPU)" or index == "Use CUDA" or index == "Use hipBLAS (ROCm)": + if index == "Use Vulkan" or index == "Use Vulkan (Old CPU)" or index == "Use Vulkan (Older CPU)" or index == "Use CUDA" or index == "Use hipBLAS (ROCm)" or index == "Use Vulkan + RPC" or index == "Use hipBLAS + RPC" or index == "Use CUDA + RPC": gpu_layers_label.grid(row=6, column=0, padx=8, pady=1, stick="nw") gpu_layers_entry.grid(row=6, column=0, padx=160, pady=1, stick="nw") quick_gpu_layers_label.grid(row=6, column=0, padx = 8, pady=1, stick="nw") @@ -8147,6 +8292,18 @@ def changerunmode(a,b,c): changed_gpulayers_estimate() changed_gpu_choice_var() + # Show RPC endpoint field for RPC variant backends + if index == "Use CUDA + RPC" or index == "Use hipBLAS + RPC" or index == "Use Vulkan + RPC": + rpc_endpoint_label.grid(row=2, column=0, padx=8, pady=1, sticky="nw") + rpc_endpoint_entry.grid(row=2, column=1, padx=8, pady=1, sticky="nw") + rpc_endpoint_label_hw.grid(row=2, column=0, padx=160, pady=1, sticky="nw") + rpc_endpoint_entry_hw.grid(row=2, column=0, padx=160, pady=1, sticky="nw") + else: + rpc_endpoint_label.grid_remove() + rpc_endpoint_entry.grid_remove() + rpc_endpoint_label_hw.grid_remove() + rpc_endpoint_entry_hw.grid_remove() + # presets selector makelabel(quick_tab, "Backend:", 1,0,"Select a backend to use.\nCUDA runs on Nvidia GPUs, and is much faster.\nVulkan works on all GPUs but is somewhat slower.\nOtherwise, runs on CPU only.\nNoAVX2 and Failsafe modes support older PCs.") @@ -8154,6 +8311,11 @@ def changerunmode(a,b,c): runoptbox.grid(row=1, column=1,padx=8, 
stick="nw") runoptbox.set(runopts[0]) # Set to first available option + # RPC endpoint field for RPC variant backends + rpc_endpoint_label = makelabel(quick_tab, "RPC Endpoint:", 2, 0, "RPC server endpoint to connect to (e.g., 127.0.0.1:50053). Leave empty for local inference.") + rpc_endpoint_entry = ctk.CTkEntry(quick_tab, textvariable=rpc_endpoint_var, width=190) + rpc_endpoint_entry.grid(row=2, column=1, padx=8, pady=1, sticky="nw") + # gpu options quick_gpu_selector_label = makelabel(quick_tab, "GPU ID:", 3,0,"Which GPU ID to load the model with.\nNormally your main GPU is #1, but it can vary for multi GPU setups.",padx=8) CUDA_quick_gpu_selector_box = ctk.CTkComboBox(quick_tab, values=CUDevices, width=60, variable=gpu_choice_var, state="readonly") @@ -8200,6 +8362,11 @@ def changerunmode(a,b,c): runoptbox.grid(row=1, column=0,padx=160, stick="nw") runoptbox.set(runopts[0]) # Set to first available option + # RPC endpoint field for RPC variant backends (shared with Quick Launch) + rpc_endpoint_label_hw = makelabel(hardware_tab, "RPC Endpoint:", 2, 0, "RPC server endpoint to connect to (e.g., 127.0.0.1:50053). 
Leave empty for local inference.") + rpc_endpoint_entry_hw = ctk.CTkEntry(hardware_tab, textvariable=rpc_endpoint_var, width=180) + rpc_endpoint_entry_hw.grid(row=2, column=0, padx=160, pady=1, sticky="nw") + # gpu options gpu_selector_label = makelabel(hardware_tab, "GPU ID:", 3,0,"Which GPU ID to load the model with.\nNormally your main GPU is #1, but it can vary for multi GPU setups.") CUDA_gpu_selector_box = ctk.CTkComboBox(hardware_tab, values=CUDevices, width=60, variable=gpu_choice_var, state="readonly") @@ -8492,6 +8659,189 @@ def toggletaesd(a,b,c): makefileentry(audio_tab, "MusicVAE:", "Select music VAE model", musicvae_var, 36, width=280, singlerow=True, dialog_type=0, tooltiptxt="Select music VAE model") makecheckbox(audio_tab, "Music Low VRAM", musiclowvram_var, 38, 0,tooltiptxt="Unload music models when not in use.") +# RPC Server Tab + rpc_tab = tabcontent["RPC Server"] + + ctk.CTkLabel( + rpc_tab, + text="RPC Server Configuration", + fg_color="transparent", + text_color="#5DA5E5", + font=("Helvetica", 14, "bold"), + ).grid(row=0, column=0, columnspan=4, sticky="w", padx=0, pady=10) + + makecheckbox( + rpc_tab, + "Start RPC Server Mode", + rpc_server_mode_var, + 1, + 0, + tooltiptxt="Start KoboldCPP in RPC server mode. Exposes local GPUs for remote clients to use.", + ) + + # Backend selector (row 2-3, column 0) + makelabel( + rpc_tab, + "RPC Server Backend:", + 2, + 0, + padx=0, + ) + makelabel( + rpc_tab, + "GPU backend for RPC server. Must match device names below.", + 3, + 0, + padx=0, + ) + rpc_backend_options = ["Auto-detect", "Vulkan", "hipBLAS (ROCm)", "CUDA"] + rpc_server_backend_dropdown = ctk.CTkComboBox( + rpc_tab, + values=rpc_backend_options, + variable=rpc_server_backend_var, + width=200, + state="readonly", + ) + rpc_server_backend_dropdown.grid( + row=4, column=0, sticky="w", padx=0, pady=0 + ) + makelabel( + rpc_tab, + "Vulkan: Use VULKAN0,VULKAN1... | hipBLAS: Use ROCm0,ROCm1... 
| CUDA: Use CUDA0,CUDA1...", + 5, + 0, + padx=0, + ) + + # GPU ID selector (row 4, column 1) - synced with Hardware and Quick Launch + makelabel( + rpc_tab, + "GPU ID:", + 2, + 0, + padx=260, + ) + rpc_gpu_selector_box = ctk.CTkComboBox( + rpc_tab, + values=CUDevices, + variable=gpu_choice_var, + width=60, + state="readonly", + ) + rpc_gpu_selector_box.grid( + row=4, column=0, sticky="w", padx=260, pady=0 + ) + rpc_gpuname_label = ctk.CTkLabel(rpc_tab, text="") + rpc_gpuname_label.grid( + row=4, column=1, sticky="w", padx=85, pady=5 + ) + rpc_gpuname_label.configure(text_color="#ffff00") + + # Listening IP Address (row 6-7, column 0) + makelabel( + rpc_tab, + "Listening IP Address:", + 6, + 0, + padx=0, + ) + makelabel( + rpc_tab, + "IP address for RPC server to listen on. Use 0.0.0.0 for all interfaces, 127.0.0.1 for localhost only.", + 7, + 0, + padx=0, + ) + makelabelentry( + rpc_tab, + "", + rpc_host_var, + 8, + 150, + singleline=True, + ) + + makelabel( + rpc_tab, + "Listening Port number for RPC server connections (default: 50053).", + 9, + 0, + padx=0, + ) + makelabelentry( + rpc_tab, + "", + rpc_port_var, + 10, + 100, + singleline=True, + padx=7, + ) + + # RPC Devices (row 9-10, column 0) + makelabel( + rpc_tab, + "RPC Devices:", + 12, + 0, + padx=0, + ) + makelabel( + rpc_tab, + "Comma-separated list of GPUs to expose via RPC. Leave empty to auto-detect all devices.", + 13, + 0, + padx=0, + ) + makelabelentry( + rpc_tab, + "", + rpc_devices_var, + 14, + 200, + singleline=True, + ) + + # Allow Launch Without Models (row 12, column 0) + makecheckbox( + rpc_tab, + "Allow Launch Without Models", + nomodel, + 15, + 0, + tooltiptxt="Allows starting RPC server without loading a model file.", + ) + + # Cache Layers Locally checkbox (row 16, column 0) + makecheckbox( + rpc_tab, + "Cache Layers Locally (-c)", + rpc_cache_layers_var, + 16, + 0, + padx=7, + tooltiptxt="Save layers locally when using RPC server. 
Equivalent to the '-c' flag.", + ) + + # Launch Browser checkbox (row 17, column 0) + makecheckbox( + rpc_tab, + "Launch Browser", + launchbrowser, + 15, + 0, + padx=200, + tooltiptxt="Enable/Disable Browserlaunch", + ) + + ctk.CTkLabel( + rpc_tab, + text="WARNING: RPC Server mode replaces the WebUI API. Clients must connect via --rpc.", + fg_color="transparent", + text_color="red", + font=("Helvetica", 15, "bold"), + ).grid(row=18, column=0, columnspan=3, sticky="w", padx=0, pady=10) + admin_tab = tabcontent["Admin"] def toggleadmin(a,b,c): if admin_var.get()==1 and admin_dir_var.get()=="": @@ -8649,12 +8999,12 @@ def export_vars(): args.failsafe = False if gpu_choice_var.get()!="All": gpuchoiceidx = int(gpu_choice_var.get())-1 - if runopts_var.get() == "Use CUDA" or runopts_var.get() == "Use hipBLAS (ROCm)": + if runopts_var.get() == "Use CUDA" or runopts_var.get() == "Use hipBLAS (ROCm)" or runopts_var.get() == "Use CUDA + RPC" or runopts_var.get() == "Use hipBLAS + RPC": if gpu_choice_var.get()=="All": args.usecuda = ["normal"] else: args.usecuda = ["normal",str(gpuchoiceidx)] - if runopts_var.get() == "Use Vulkan" or runopts_var.get() == "Use Vulkan (Old CPU)" or runopts_var.get() == "Use Vulkan (Older CPU)": + if runopts_var.get() == "Use Vulkan" or runopts_var.get() == "Use Vulkan (Old CPU)" or runopts_var.get() == "Use Vulkan (Older CPU)" or runopts_var.get() == "Use Vulkan + RPC": if gpu_choice_var.get()=="All": args.usevulkan = [] else: @@ -8841,6 +9191,30 @@ def export_vars(): args.autoswapmode = (autoswap_mode_var.get()==1 and router_mode_var.get()==1 and admin_var.get()==1) args.baseconfig = baseconfig_var.get() args.adminunloadtimeout = (0 if admin_unload_timeout_var.get()=="" else int(admin_unload_timeout_var.get())) + args.rpc_server = (rpc_server_mode_var.get()==1) + args.rpc_host = rpc_host_var.get() + args.rpc_port = int(rpc_port_var.get()) + args.rpc_device = rpc_devices_var.get() + args.rpc_server_backend = rpc_server_backend_var.get() + 
args.rpc_cache_layers = (rpc_cache_layers_var.get()==1) + if rpc_endpoint_var.get() != "": + rpc_eps = rpc_endpoint_var.get() + if ',' in rpc_eps: + args.rpc = rpc_eps.split(',') + else: + args.rpc = [rpc_eps] + else: + args.rpc = None + + # Set backend flags based on RPC server backend selection + if args.rpc_server: + if args.rpc_server_backend == "hipBLAS (ROCm)": + args.usecuda = ["all"] # Use hipBLAS + elif args.rpc_server_backend == "CUDA": + args.usecuda = ["all"] # Use CUDA + elif args.rpc_server_backend == "Vulkan": + args.usevulkan = [] + args.showgui = False #prevent showgui from leaking into configs, its cli only def import_vars(mydict): @@ -8930,6 +9304,13 @@ def import_vars(mydict): elif "usecpu" in mydict and mydict["usecpu"]: if default_option is not None: runopts_var.set(default_option) + if "rpc_endpoint" in mydict and mydict["rpc_endpoint"]: + if rpc_option is not None: + runopts_var.set(rpc_option) + elif hipblas_rpc_option is not None: + runopts_var.set(hipblas_rpc_option) + elif cublas_rpc_option is not None: + runopts_var.set(cublas_rpc_option) if "gpulayers" in mydict and mydict["gpulayers"]: gpulayers_var.set(mydict["gpulayers"]) else: @@ -9129,6 +9510,25 @@ def import_vars(mydict): admin_unload_timeout_var.set(mydict["adminunloadtimeout"] if ("adminunloadtimeout" in mydict and mydict["adminunloadtimeout"]) else 0) singleinstance_var.set(mydict["singleinstance"] if ("singleinstance" in mydict) else 0) + rpc_server_mode_var.set(mydict["rpc_server"] if ("rpc_server" in mydict) else 0) + rpc_host_var.set(mydict["rpc_host"] if ("rpc_host" in mydict and mydict["rpc_host"]) else "0.0.0.0") + rpc_port_var.set(mydict["rpc_port"] if ("rpc_port" in mydict and mydict["rpc_port"]) else "50053") + rpc_devices_var.set(mydict["rpc_device"] if ("rpc_device" in mydict and mydict["rpc_device"]) else "") + rpc_server_backend_var.set(mydict["rpc_server_backend"] if ("rpc_server_backend" in mydict and mydict["rpc_server_backend"]) else "Auto-detect") + 
rpc_cache_layers_var.set(mydict["rpc_cache_layers"] if ("rpc_cache_layers" in mydict) else 0) + if "rpc_endpoint" in mydict and mydict["rpc_endpoint"]: + if isinstance(mydict["rpc_endpoint"], list): + rpc_endpoint_var.set(','.join(mydict["rpc_endpoint"])) + args.rpc = mydict["rpc_endpoint"] + else: + rpc_endpoint_var.set(str(mydict["rpc_endpoint"])) + if ',' in str(mydict["rpc_endpoint"]): + args.rpc = str(mydict["rpc_endpoint"]).split(',') + else: + args.rpc = [str(mydict["rpc_endpoint"])] + else: + args.rpc = None + importvars_in_progress = False gui_changed_modelfile() if "istemplate" in mydict and mydict["istemplate"]: @@ -9142,6 +9542,7 @@ def save_config_gui(): for key in deprecated_keys: savdict.pop(key, None) # avoids KeyError if missing savdict["istemplate"] = False + savdict["rpc_endpoint"] = rpc_endpoint_var.get() file_type = [("KoboldCpp Settings", "*.kcpps")] filename = zentk_asksaveasfilename(filetypes=file_type, defaultextension=".kcpps",title="Save kcpps settings config file") if not filename: @@ -9722,6 +10123,14 @@ def load_config_cli(filename): print(f"Overriding Config Value: {key}") else: setattr(args, key, value) + # Convert rpc_endpoint to args.rpc for RPC client mode + if hasattr(args, 'rpc_endpoint') and args.rpc_endpoint: + if isinstance(args.rpc_endpoint, list): + args.rpc = args.rpc_endpoint + elif ',' in str(args.rpc_endpoint): + args.rpc = str(args.rpc_endpoint).split(',') + else: + args.rpc = [str(args.rpc_endpoint)] if args.istemplate: print("\nA .kcppt template was selected from CLI...") if (args.usecuda is None) and (args.usevulkan is None): @@ -10861,7 +11270,106 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False): except Exception: print("Unable to determine available RAM") + # Set RPC endpoints environment variable for client mode + if args.rpc is not None and len(args.rpc) > 0: + rpc_endpoints = ','.join(args.rpc) + os.environ['LLAMA_ARG_RPC'] = rpc_endpoints + print(f"RPC Client Mode: Connecting to 
servers: {rpc_endpoints}")
+
     init_library() # Note: if blas does not exist and is enabled, program will crash.
+
+    # Start RPC server if requested
+    if args.rpc_server:
+        print("==========")
+        print("Starting RPC Server Mode...")
+        print(f"RPC Host: {args.rpc_host}")
+        print(f"RPC Port: {args.rpc_port}")
+        print(f"RPC Device: {args.rpc_device if args.rpc_device else 'Auto-detect'}")
+        print(f"RPC Library: {libname}")
+        print("==========")
+
+        # Determine RPC server binary based on library and user selection
+        rpc_server_bin = ""
+
+        # Check command line arguments for backend preference
+        # Also detect ROCm devices in rpc_device string
+        # FIX: the previous code computed use_hipblas and use_cuda from the
+        # IDENTICAL expression, so the two backends could never be told apart.
+        # Disambiguate via the library init_library() already selected.
+        cuda_requested = args.usecuda is not None and ('all' in args.usecuda or len(args.usecuda) == 0)
+        use_hipblas = cuda_requested and libname is not None and "hipblas" in libname.lower()
+        use_cuda = cuda_requested and not use_hipblas
+        use_vulkan = args.usevulkan is not None
+
+        # Check GUI RPC server backend selection
+        if hasattr(args, 'rpc_server_backend') and args.rpc_server_backend:
+            if args.rpc_server_backend == "hipBLAS (ROCm)":
+                use_hipblas = True
+                print("Using HIPBLAS backend from GUI selection")
+            elif args.rpc_server_backend == "CUDA":
+                use_cuda = True
+                print("Using CUDA backend from GUI selection")
+            elif args.rpc_server_backend == "Vulkan":
+                use_vulkan = True
+                print("Using Vulkan backend from GUI selection")
+
+        # Auto-detect backend from rpc_device string if not explicitly set
+        if args.rpc_device and not use_hipblas and not use_cuda and not use_vulkan:
+            if "ROCm" in args.rpc_device or "HIP" in args.rpc_device.upper():
+                use_hipblas = True
+                print("Auto-detected HIPBLAS backend from RPC device string")
+            elif "CUDA" in args.rpc_device.upper():
+                use_cuda = True
+                print("Auto-detected CUDA backend from RPC device string")
+            elif "VULKAN" in args.rpc_device.upper():
+                use_vulkan = True
+                print("Auto-detected Vulkan backend from RPC device string")
+
+        # Try to select appropriate RPC server
+        if use_hipblas and os.path.exists("./rpc-server-hip"):
+
rpc_server_bin = "./rpc-server-hip" + print("Using HIPBLAS RPC Server") + elif use_cuda and os.path.exists("./rpc-server-cuda"): + rpc_server_bin = "./rpc-server-cuda" + print("Using CUDA RPC Server") + elif use_vulkan and os.path.exists("./rpc-server-vulkan"): + rpc_server_bin = "./rpc-server-vulkan" + print("Using Vulkan RPC Server") + else: + # Auto-detect based on available binaries and system + if os.path.exists("./rpc-server-hip") and libname and "hipblas" in libname.lower(): + rpc_server_bin = "./rpc-server-hip" + print("Using HIPBLAS RPC Server (Auto-detected)") + elif os.path.exists("./rpc-server-cuda") and libname and "cublas" in libname.lower(): + rpc_server_bin = "./rpc-server-cuda" + print("Using CUDA RPC Server (Auto-detected)") + elif os.path.exists("./rpc-server-vulkan"): + rpc_server_bin = "./rpc-server-vulkan" + print("Using Vulkan RPC Server (Auto-detected)") + else: + exit_with_error(2, "No RPC server binary found. Please build RPC server with appropriate backend.") + + if not os.path.exists(rpc_server_bin): + exit_with_error(2, f"RPC server binary not found: {rpc_server_bin}\nPlease build RPC server with: make LLAMA_RPC=1 LLAMA_VULKAN=1 rpc-server-vulkan") + + # Build RPC server command + rpc_cmd = [rpc_server_bin, "-H", args.rpc_host, "--port", str(args.rpc_port)] + if args.rpc_device and args.rpc_device != "": + rpc_cmd.extend(["--device", args.rpc_device]) + if hasattr(args, 'rpc_cache_layers') and args.rpc_cache_layers: + rpc_cmd.append("-c") + + print(f"Starting RPC server: {' '.join(rpc_cmd)}") + + # Start RPC server as subprocess + try: + rpc_process = subprocess.Popen(rpc_cmd) + print(f"RPC server started with PID: {rpc_process.pid}") + print("Press Ctrl+C to stop RPC server") + + # Wait for RPC server to exit + rpc_process.wait() + except Exception as e: + exit_with_error(2, f"Failed to start RPC server: {e}") + + return # Exit after RPC server stops + print("==========") time.sleep(1) @@ -10872,11 +11380,20 @@ def 
kcpp_main_process(launch_args, g_memory=None, gui_launcher=False): print("==========") #handle loading text model + # RPC server mode doesn't require a model file + if args.rpc_server and not args.model_param: + print("RPC Server Mode: No model file required") + args.nomodel = True + if args.model_param: if not os.path.exists(args.model_param): if args.ignoremissing: print(f"Ignoring missing model file: {args.model_param}") args.model_param = None + elif args.rpc_server: + print("RPC Server Mode: Starting without model file") + args.model_param = None + args.nomodel = True else: exitcounter = 999 exit_with_error(2,f"Cannot find text model file: {args.model_param}") @@ -11474,6 +11991,15 @@ def range_checker(arg: str): compatgroup.add_argument("--usecuda", "--usecublas", "--usehipblas", help="Use CUDA for GPU Acceleration. Requires CUDA. Enter a number afterwards to select and use 1 GPU. Leaving no number will use all GPUs.", nargs='*',metavar=('[main GPU ID]'), choices=['0','1','2','3','all', 'mmq','nommq','normal','lowvram','rowsplit']) compatgroup.add_argument("--usevulkan", help="Use Vulkan for GPU Acceleration. Can optionally specify one or more GPU Device ID (e.g. --usevulkan 0), leave blank to autodetect.", metavar=('[Device IDs]'), nargs='*', type=int, default=None) compatgroup.add_argument("--usecpu", help="Do not use any GPU acceleration (CPU Only)", action='store_true') + + # RPC arguments + parser.add_argument("--rpc", help="Connect to RPC server endpoint (e.g., 127.0.0.1:50053). 
Multiple can be specified for multi-machine inference.", metavar=('[endpoint]'), nargs='+', default=None) + parser.add_argument("--rpc-server", help="Start RPC server mode instead of client mode", action='store_true') + parser.add_argument("--rpc-host", help="RPC server host to listen on (default: 0.0.0.0)", default="0.0.0.0") + parser.add_argument("--rpc-port", help="RPC server port (default: 50053)", type=int, default=50053) + parser.add_argument("--rpc-device", help="RPC server device string (e.g., Vulkan0,Vulkan1)", default="") + parser.add_argument("--rpc-cache-layers", help="Save layers locally when using RPC server (equivalent to -c flag)", action='store_true') + parser.add_argument("--contextsize","--ctx-size", "-c", help="Controls the memory allocated for maximum context size, only change if you need more RAM for big contexts. (default 8192).",metavar=('[256 to 262144]'), type=check_range(int,256,262144), default=8192) parser.add_argument("--gpulayers","--gpu-layers","--n-gpu-layers","-ngl", help="Set number of layers to offload to GPU when using GPU. Requires GPU. Set to -1 to try autodetect, set to 0 to disable GPU offload.",metavar=('[GPU layers]'), nargs='?', const=1, type=int, default=-1) parser.add_argument("--tensor_split","--tensorsplit","--tensor-split","-ts", help="For CUDA and Vulkan only, ratio to split tensors across multiple GPUs, space-separated list of proportions, e.g. 
7 3", metavar=('[Ratios]'), type=float, nargs='+') diff --git a/koboldcpp.so b/koboldcpp.so new file mode 120000 index 00000000000..8675c60a12d --- /dev/null +++ b/koboldcpp.so @@ -0,0 +1 @@ +koboldcpp_vulkan.so \ No newline at end of file diff --git a/koboldcpp_hipblas_rpc.so b/koboldcpp_hipblas_rpc.so new file mode 100755 index 00000000000..965fbb8d415 Binary files /dev/null and b/koboldcpp_hipblas_rpc.so differ diff --git a/koboldcpp_rpc.so b/koboldcpp_rpc.so new file mode 100755 index 00000000000..ef9103d53bd Binary files /dev/null and b/koboldcpp_rpc.so differ diff --git a/rpc-server-hip b/rpc-server-hip new file mode 100755 index 00000000000..f7dc1a14e8d Binary files /dev/null and b/rpc-server-hip differ diff --git a/rpc-server-vulkan b/rpc-server-vulkan new file mode 100755 index 00000000000..e06dcff9e1c Binary files /dev/null and b/rpc-server-vulkan differ diff --git a/smartbuild.sh b/smartbuild.sh new file mode 100755 index 00000000000..55c583ec8c9 --- /dev/null +++ b/smartbuild.sh @@ -0,0 +1,115 @@ +#!/bin/bash +# KoboldCpp Smart Build Script +# Auto-detects hardware and builds appropriate backends + +set -e + +echo "=== KoboldCpp Smart Build ===" +echo "" + +# Detect hardware +HAS_AMD=0 +HAS_NVIDIA=0 +HAS_VULKAN=0 +HAS_INTEL_GPU=0 + +# Check for AMD/ROCm +if lspci -nn 2>/dev/null | grep -qi "1002:"; then + HAS_AMD=1 + echo "✓ AMD GPU detected" +fi + +# Check for NVIDIA +if lspci -nn 2>/dev/null | grep -qi "10de:"; then + HAS_NVIDIA=1 + echo "✓ NVIDIA GPU detected" +fi + +# Check for Intel GPU +if lspci -nn 2>/dev/null | grep -qi "8086:"; then + HAS_INTEL_GPU=1 + echo "✓ Intel GPU detected" +fi + +# Vulkan is generally available on Linux with GPU +if [ $HAS_AMD -eq 1 ] || [ $HAS_NVIDIA -eq 1 ] || [ $HAS_INTEL_GPU -eq 1 ]; then + HAS_VULKAN=1 + echo "✓ Vulkan support available" +fi + +echo "" + +# Clean previous builds +echo "Cleaning previous builds..." 
+make clean >/dev/null 2>&1 || true + +# Build standard backends first (always needed for koboldcpp.py) +echo "" +echo "=== Building Standard Backends ===" + +# Build CPU default (always) +echo "Building CPU backend (koboldcpp_default.so)..." +make koboldcpp_default -j$(nproc) 2>&1 | tail -5 + +# Build Vulkan if available +if [ $HAS_VULKAN -eq 1 ]; then + echo "" + echo "Building Vulkan backend (koboldcpp_vulkan.so)..." + make koboldcpp_vulkan -j$(nproc) LLAMA_VULKAN=1 2>&1 | tail -5 || echo "Vulkan build skipped" +fi + +# Build HIPBLAS if AMD detected +if [ $HAS_AMD -eq 1 ]; then + echo "" + echo "Building HIPBLAS backend (koboldcpp_hipblas.so)..." + make koboldcpp_hipblas -j$(nproc) LLAMA_HIPBLAS=1 2>&1 | tail -5 || echo "HIPBLAS build skipped (ROCm may not be installed)" +fi + +# Build CUDA if NVIDIA detected +if [ $HAS_NVIDIA -eq 1 ]; then + echo "" + echo "Building CUDA backend (koboldcpp_cublas.so)..." + make koboldcpp_cublas -j$(nproc) LLAMA_CUBLAS=1 2>&1 | tail -5 || echo "CUDA build skipped (CUDA may not be installed)" +fi + +# Build RPC backends if requested +if [ "$1" == "--rpc" ]; then + echo "" + echo "=== Building RPC Backends ===" + + # RPC Vulkan + if [ $HAS_VULKAN -eq 1 ]; then + echo "Building RPC Vulkan..." + make rpc-server-vulkan koboldcpp_rpc -j$(nproc) LLAMA_RPC=1 LLAMA_VULKAN=1 2>&1 | tail -5 || echo "RPC Vulkan skipped" + fi + + # RPC HIPBLAS + if [ $HAS_AMD -eq 1 ]; then + echo "Building RPC HIPBLAS..." + make rpc-server-hip koboldcpp_hipblas_rpc -j$(nproc) LLAMA_RPC=1 LLAMA_HIPBLAS=1 2>&1 | tail -5 || echo "RPC HIPBLAS skipped" + fi + + # RPC CUDA + if [ $HAS_NVIDIA -eq 1 ]; then + echo "Building RPC CUDA..." 
+ make rpc-server-cuda koboldcpp_cublas_rpc -j$(nproc) LLAMA_RPC=1 LLAMA_CUBLAS=1 2>&1 | tail -5 || echo "RPC CUDA skipped" + fi +fi + +echo "" +echo "=== Build Summary ===" +echo "Standard backends:" +ls -lh koboldcpp_default.so koboldcpp_vulkan.so koboldcpp_hipblas.so koboldcpp_cublas.so 2>/dev/null || echo " (checking files...)" + +if [ "$1" == "--rpc" ]; then + echo "" + echo "RPC backends:" + ls -lh rpc-server-* koboldcpp_rpc.so 2>/dev/null || echo " (checking files...)" +fi + +echo "" +echo "=== Testing koboldcpp.py ===" +python ./koboldcpp.py --help >/dev/null 2>&1 && echo "✓ koboldcpp.py ready!" || echo "✗ koboldcpp.py has issues" + +echo "" +echo "Build complete!" diff --git a/start_rpc_servers.sh b/start_rpc_servers.sh new file mode 100755 index 00000000000..b2b97309333 --- /dev/null +++ b/start_rpc_servers.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# RPC Server Startup Script + +# Kill any existing RPC servers +echo "Stopping existing RPC servers..." +pkill -f rpc-server-vulkan +sleep 2 + +# Start RPC servers +echo "Starting RPC servers..." +nohup ./rpc-server-vulkan -H 0.0.0.0 --port 50053 --device VULKAN0 > /tmp/rpc1.log 2>&1 & +nohup ./rpc-server-vulkan -H 0.0.0.0 --port 50054 --device VULKAN1 > /tmp/rpc2.log 2>&1 & +sleep 3 + +# Verify servers are running +ps aux | grep rpc-server-vulkan | grep -v grep + +echo "" +echo "RPC servers started. 
You can now run:" +echo " python koboldcpp.py --config qwen08rpctest3.kcpps" diff --git a/tools/rpc-server.cpp b/tools/rpc-server.cpp new file mode 100644 index 00000000000..517eef7018c --- /dev/null +++ b/tools/rpc-server.cpp @@ -0,0 +1,349 @@ +#include "ggml-rpc.h" +#ifdef _WIN32 +# define NOMINMAX +# define DIRECTORY_SEPARATOR '\\' +# include +# include +# include +#else +# define DIRECTORY_SEPARATOR '/' +# include +# include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__linux__) +#include +#include +#endif + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +#ifdef _WIN32 +static std::wstring utf8_to_wstring(const std::string & str) { + if (str.empty()) { + return std::wstring(); + } + + int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); + + if (size <= 0) { + return std::wstring(); + } + + std::wstring wstr(size, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); + + return wstr; +} +#endif + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +// returns true if successful, false otherwise +static bool fs_create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring wpath = utf8_to_wstring(path); + + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; + } + + size_t pos_slash = 0; + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + + pos_slash += 1; + + // skip the drive letter, in some systems it can return an access denied error + if (subpath.length() == 2 && subpath[1] == ':') { + continue; + } + + const bool success = 
CreateDirectoryW(subpath.c_str(), NULL); + + if (!success) { + const DWORD error = GetLastError(); + + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; + } + } else { + return false; + } + } + } + + return true; +#else + // if the path already exists, check whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } + + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; + + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } + } + + pos_slash += 1; + } + + return true; +#endif // _WIN32 +} + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +static std::string fs_get_cache_directory() { + std::string cache_directory = ""; + auto ensure_trailing_slash = [](std::string p) { + // Make sure to add trailing slash + if (p.back() != DIRECTORY_SEPARATOR) { + p += DIRECTORY_SEPARATOR; + } + return p; + }; + if (getenv("LLAMA_CACHE")) { + cache_directory = std::getenv("LLAMA_CACHE"); + } else { +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \ + defined(__OpenBSD__) || defined(__NetBSD__) + if (std::getenv("XDG_CACHE_HOME")) { + cache_directory = std::getenv("XDG_CACHE_HOME"); + } else if (std::getenv("HOME")) { + cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } else { 
+#if defined(__linux__) + /* no $HOME is defined, fallback to getpwuid */ + struct passwd *pw = getpwuid(getuid()); + if ((!pw) || (!pw->pw_dir)) { + throw std::runtime_error("Failed to find $HOME directory"); + } + + cache_directory = std::string(pw->pw_dir) + std::string("/.cache/"); +#else /* defined(__linux__) */ + throw std::runtime_error("Failed to find $HOME directory"); +#endif /* defined(__linux__) */ + } +#elif defined(__APPLE__) + cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); +#elif defined(_WIN32) + cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); +#else +# error Unknown architecture +#endif + cache_directory = ensure_trailing_slash(cache_directory); + cache_directory += "llama.cpp"; + } + return ensure_trailing_slash(cache_directory); +} + +struct rpc_server_params { + std::string host = "127.0.0.1"; + int port = 50052; + bool use_cache = false; + int n_threads = std::max(1U, std::thread::hardware_concurrency()/2); + std::vector<std::string> devices; +}; + +static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { + fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads); + fprintf(stderr, " -d, --device comma-separated list of devices\n"); + fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str()); + fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port); + fprintf(stderr, " -c, --cache enable local file cache\n"); + fprintf(stderr, "\n"); +} + +static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) { + std::string arg; + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg == "-H" || arg == "--host") { + if (++i >= argc) { + return false; 
+ } + params.host = argv[i]; + } else if (arg == "-t" || arg == "--threads") { + if (++i >= argc) { + return false; + } + params.n_threads = std::stoi(argv[i]); + if (params.n_threads <= 0) { + fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads); + return false; + } + } else if (arg == "-d" || arg == "--device") { + if (++i >= argc) { + return false; + } + const std::regex regex{ R"([,/]+)" }; + std::string dev_str = argv[i]; + std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1); + std::sregex_token_iterator end; + for ( ; iter != end; ++iter) { + try { + params.devices.push_back(*iter); + } catch (const std::exception & ) { + fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str()); + return false; + } + } + } else if (arg == "-p" || arg == "--port") { + if (++i >= argc) { + return false; + } + params.port = std::stoi(argv[i]); + if (params.port <= 0 || params.port > 65535) { + return false; + } + } else if (arg == "-c" || arg == "--cache") { + params.use_cache = true; + } else if (arg == "-h" || arg == "--help") { + print_usage(argc, argv, params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + print_usage(argc, argv, params); + exit(0); + } + } + return true; +} + +static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) { + std::vector<ggml_backend_dev_t> devices; + if (!params.devices.empty()) { + for (auto device : params.devices) { + ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str()); + if (dev) { + devices.push_back(dev); + } else { + fprintf(stderr, "error: unknown device: %s\n", device.c_str()); + fprintf(stderr, "available devices:\n"); + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + auto * dev = ggml_backend_dev_get(i); + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + } 
+ return {}; + } + } + } + + // Try non-CPU devices first + if (devices.empty()) { + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) { + devices.push_back(dev); + } + } + } + + // If there are no accelerators, fallback to CPU device + if (devices.empty()) { + ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (dev) { + devices.push_back(dev); + } + } + + return devices; +} + +int main(int argc, char * argv[]) { + std::setlocale(LC_NUMERIC, "C"); + + ggml_backend_load_all(); + + rpc_server_params params; + if (!rpc_server_params_parse(argc, argv, params)) { + fprintf(stderr, "Invalid parameters\n"); + return 1; + } + + if (params.host != "127.0.0.1") { + fprintf(stderr, "\n"); + fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); + fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str()); + fprintf(stderr, " Never expose the RPC server to an open network!\n"); + fprintf(stderr, " This is an experimental feature and is not secure!\n"); + fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); + fprintf(stderr, "\n"); + } + + auto devices = get_devices(params); + if (devices.empty()) { + fprintf(stderr, "No devices found\n"); + return 1; + } + std::string endpoint = params.host + ":" + std::to_string(params.port); + const char * cache_dir = nullptr; + std::string cache_dir_str; + if (params.use_cache) { + cache_dir_str = fs_get_cache_directory() + "rpc/"; + if (!fs_create_directory_with_parents(cache_dir_str)) { + fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str()); + return 1; + } + cache_dir = cache_dir_str.c_str(); + } + + // Explicitly register RPC backend (for static linking) + // Store in static variable to prevent linker optimization + static ggml_backend_reg_t rpc_reg = NULL; + if (!rpc_reg) { + 
rpc_reg = ggml_backend_rpc_reg(); + } + + ggml_backend_reg_t reg = rpc_reg; + if (!reg) { + fprintf(stderr, "Failed to find RPC backend\n"); + return 1; + } + + auto start_server_fn = (decltype(ggml_backend_rpc_start_server)*) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_start_server"); + if (!start_server_fn) { + fprintf(stderr, "Failed to obtain RPC backend start server function\n"); + return 1; + } + + start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data()); + return 0; +} diff --git a/verify_rpc_merge.sh b/verify_rpc_merge.sh new file mode 100755 index 00000000000..cc3b2f88f74 --- /dev/null +++ b/verify_rpc_merge.sh @@ -0,0 +1,115 @@ +#!/bin/bash +# RPC Merge Verification Script + +echo "=== RPC Merge Verification ===" +echo "" + +# Check RPC source files +echo "1. Checking RPC source files..." +if [ -f "ggml/src/ggml-rpc/ggml-rpc.cpp" ]; then + echo " ✓ ggml-rpc.cpp exists" +else + echo " ✗ ggml-rpc.cpp MISSING" + exit 1 +fi + +if [ -f "ggml/src/ggml-rpc/transport.cpp" ]; then + echo " ✓ transport.cpp exists" +else + echo " ✗ transport.cpp MISSING" + exit 1 +fi + +if [ -f "ggml/src/ggml-rpc/transport.h" ]; then + echo " ✓ transport.h exists" +else + echo " ✗ transport.h MISSING" + exit 1 +fi + +if [ -f "ggml/src/ggml-rpc/CMakeLists.txt" ]; then + echo " ✓ CMakeLists.txt exists" +else + echo " ✗ CMakeLists.txt MISSING" + exit 1 +fi + +# Check RPC server +echo "" +echo "2. Checking RPC server tool..." +if [ -f "tools/rpc-server.cpp" ]; then + echo " ✓ rpc-server.cpp exists" +else + echo " ✗ rpc-server.cpp MISSING" + exit 1 +fi + +# Check Makefile RPC support +echo "" +echo "3. Checking Makefile RPC support..." 
+if grep -q "LLAMA_RPC" Makefile; then + echo " ✓ LLAMA_RPC flag defined" +else + echo " ✗ LLAMA_RPC flag MISSING" + exit 1 +fi + +if grep -q "rpc-server-vulkan" Makefile; then + echo " ✓ rpc-server-vulkan target defined" +else + echo " ✗ rpc-server-vulkan target MISSING" + exit 1 +fi + +if grep -q "rpc-full-all" Makefile; then + echo " ✓ rpc-full-all target defined" +else + echo " ✗ rpc-full-all target MISSING" + exit 1 +fi + +# Check RPC header +echo "" +echo "4. Checking RPC header..." +if [ -f "ggml/include/ggml-rpc.h" ]; then + if grep -q "ggml_backend_rpc_init" ggml/include/ggml-rpc.h; then + echo " ✓ ggml-rpc.h exists with RPC API" + else + echo " ✗ ggml-rpc.h MISSING RPC API" + exit 1 + fi +else + echo " ✗ ggml-rpc.h MISSING" + exit 1 +fi + +# Check backend registration +echo "" +echo "5. Checking backend registration..." +if grep -q "ggml_backend_rpc_reg()" ggml/src/ggml-backend-reg.cpp; then + echo " ✓ RPC backend registration found" +else + echo " ✗ RPC backend registration MISSING" + exit 1 +fi + +# Check command-line argument support +echo "" +echo "6. Checking command-line argument support..." +if grep -q "add_rpc_devices" common/arg.cpp; then + echo " ✓ RPC command-line argument support found" +else + echo " ✗ RPC command-line argument support MISSING" + exit 1 +fi + +echo "" +echo "=== All Checks Passed! ===" +echo "" +echo "RPC functionality has been successfully merged." +echo "" +echo "To build RPC components:" +echo " make rpc-full-all # Build all RPC components" +echo " make LLAMA_VULKAN=1 LLAMA_RPC=1 rpc-server-vulkan # Vulkan RPC server" +echo " make LLAMA_VULKAN=1 LLAMA_RPC=1 koboldcpp_rpc # Vulkan RPC client" +echo ""