From 7d2f3c4371f8689260061b4b98f0f15dff9d95f7 Mon Sep 17 00:00:00 2001 From: Ryan Ofsky Date: Wed, 30 Apr 2025 08:39:29 -0400 Subject: [PATCH] Add windows support Add support for running on windows. These changes make the libmultiprocess API more generic, using stream types instead of file descriptors. All features are supported, including spawning processes with socket connections to the parent process. These changes were originally made in https://github.com/bitcoin/bitcoin/pull/32387 --- example/calculator.cpp | 9 +-- example/example.cpp | 10 +-- example/printer.cpp | 9 +-- include/mp/proxy-io.h | 29 +++++--- include/mp/util.h | 44 ++++++++--- src/mp/proxy.cpp | 92 ++++++++++++++++++----- src/mp/util.cpp | 164 +++++++++++++++++++++++++++++++++++++---- 7 files changed, 286 insertions(+), 71 deletions(-) diff --git a/example/calculator.cpp b/example/calculator.cpp index 86ce388b..73dd56fb 100644 --- a/example/calculator.cpp +++ b/example/calculator.cpp @@ -51,14 +51,11 @@ int main(int argc, char** argv) std::cout << "Usage: mpcalculator \n"; return 1; } - int fd; - if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) { - std::cerr << argv[1] << " is not a number or is larger than an int\n"; - return 1; - } + mp::SocketId socket{mp::StartSpawned(argv[1])}; mp::EventLoop loop("mpcalculator", LogPrint); std::unique_ptr init = std::make_unique(); - mp::ServeStream(loop, fd, *init); + mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; + mp::ServeStream(loop, kj::mv(stream), *init); loop.loop(); return 0; } diff --git a/example/example.cpp b/example/example.cpp index 38313977..f6fe68ec 100644 --- a/example/example.cpp +++ b/example/example.cpp @@ -25,14 +25,14 @@ namespace fs = std::filesystem; static auto Spawn(mp::EventLoop& loop, const std::string& process_argv0, const std::string& new_exe_name) { - int pid; - const int fd = mp::SpawnProcess(pid, [&](int fd) -> std::vector { + auto pair{mp::SocketPair()}; + mp::ProcessId pid{mp::SpawnProcess(pair[0], [&](mp::ConnectInfo info) -> std::vector { fs::path path = process_argv0; path.remove_filename(); path.append(new_exe_name); - return {path.string(), std::to_string(fd)}; - }); - return std::make_tuple(mp::ConnectStream(loop, fd), pid); + return {path.string(), std::move(info)}; + })}; + return std::make_tuple(mp::ConnectStream(loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(pair[1])), pid); } static void LogPrint(mp::LogMessage log_data) diff --git a/example/printer.cpp b/example/printer.cpp index 9150d59b..03b67d3d 100644 --- a/example/printer.cpp +++ b/example/printer.cpp @@ -44,14 +44,11 @@ int main(int argc, char** argv) std::cout << "Usage: mpprinter \n"; return 1; } - int fd; - if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) { - std::cerr << argv[1] << " is not a number or is larger than an int\n"; - return 1; - } + mp::SocketId socket{mp::StartSpawned(argv[1])}; mp::EventLoop loop("mpprinter", LogPrint); std::unique_ptr init = std::make_unique(); - mp::ServeStream(loop, fd, *init); + mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; + mp::ServeStream(loop, std::move(stream), *init); loop.loop(); return 0; } diff --git a/include/mp/proxy-io.h b/include/mp/proxy-io.h index f7367468..89278081 100644 --- a/include/mp/proxy-io.h +++ b/include/mp/proxy-io.h @@ -185,6 +185,17 @@ class Logger std::string LongThreadName(const char* exe_name); +using Stream = kj::Own; + +inline SocketId StreamSocketId(const Stream& stream) +{ + if (stream) KJ_IF_MAYBE(fd, stream->getFd()) return *fd; +#ifdef WIN32 + if (stream) KJ_IF_MAYBE(handle, stream->getWin32Handle()) return reinterpret_cast(*handle); +#endif + throw std::logic_error("Stream socket unset"); +} + //! Event loop implementation. //! //! Cap'n Proto threading model is very simple: all I/O operations are @@ -283,11 +294,12 @@ class EventLoop //! Callback functions to run on async thread. std::optional m_async_fns MP_GUARDED_BY(m_mutex); - //! Pipe read handle used to wake up the event loop thread. - int m_wait_fd = -1; + //! Socket pair used to post and wait for wakeups to the event loop thread. + kj::Own m_wait_stream; + kj::Own m_post_stream; - //! Pipe write handle used to wake up the event loop thread. - int m_post_fd = -1; + //! Synchronous writer used to write to m_post_stream. + kj::Own m_post_writer; //! Number of clients holding references to ProxyServerBase objects that //! reference this event loop. @@ -679,13 +691,11 @@ struct ThreadContext //! over the stream. Also create a new Connection object embedded in the //! client that is freed when the client is closed. template -std::unique_ptr> ConnectStream(EventLoop& loop, int fd) +std::unique_ptr> ConnectStream(EventLoop& loop, kj::Own stream) { typename InitInterface::Client init_client(nullptr); std::unique_ptr connection; loop.sync([&] { - auto stream = - loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP); connection = std::make_unique(loop, kj::mv(stream)); init_client = connection->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs(); Connection* connection_ptr = connection.get(); @@ -735,10 +745,9 @@ void _Listen(EventLoop& loop, kj::Own&& listener, InitIm //! Given stream file descriptor and an init object, handle requests on the //! stream by calling methods on the Init object. template -void ServeStream(EventLoop& loop, int fd, InitImpl& init) +void ServeStream(EventLoop& loop, kj::Own stream, InitImpl& init) { - _Serve( - loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP), init); + _Serve(loop, kj::mv(stream), init); } //! Given listening socket file descriptor and an init object, handle incoming diff --git a/include/mp/util.h b/include/mp/util.h index e5b4dd19..e3f8f589 100644 --- a/include/mp/util.h +++ b/include/mp/util.h @@ -19,6 +19,10 @@ #include #include +#ifdef WIN32 +#include +#endif + namespace mp { //! Generic utility functions used by capnp code. @@ -216,22 +220,44 @@ std::string ThreadName(const char* exe_name); //! errors in python unit tests. std::string LogEscape(const kj::StringTree& string, size_t max_size); +#ifdef WIN32 +using ProcessId = uintptr_t; +using SocketId = uintptr_t; +constexpr SocketId SocketError{INVALID_SOCKET}; +#else +using ProcessId = int; +using SocketId = int; +constexpr SocketId SocketError{-1}; +#endif + +//! Information about parent process passed to child process. On unix this is +//! just the inherited int file descriptor formatted as a string. On windows, +//! this is a path to a named path pipe the parent process will write +//! WSADuplicateSocket info to. +using ConnectInfo = std::string; + //! Callback type used by SpawnProcess below. -using FdToArgsFn = std::function(int fd)>; +using ConnectInfoToArgsFn = std::function(const ConnectInfo&)>; + +//! Create a socket pair that can be used to communicate within a process or +//! between parent and child processes. +std::array SocketPair(); + +//! Spawn a new process that communicates with the current process over provided +//! socket argument. Calls connect_info_to_args callback with a connection +//! string that needs to be passed to the child process, and executes the +//! argv command line it returns. Returns child process id. +ProcessId SpawnProcess(SocketId socket, ConnectInfoToArgsFn&& connect_info_to_args); -//! Spawn a new process that communicates with the current process over a socket -//! pair. Returns pid through an output argument, and file descriptor for the -//! local side of the socket. Invokes fd_to_args callback with the remote file -//! descriptor number which returns the command line arguments that should be -//! used to execute the process, and which should have the remote file -//! descriptor embedded in whatever format the child process expects. -int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args); +//! Initialize spawned child process using the ConnectInfo string passed to it, +//! returning a socket id for communicating with the parent process. +SocketId StartSpawned(const ConnectInfo& connect_info); //! Call execvp with vector args. void ExecProcess(const std::vector& args); //! Wait for a process to exit and return its exit code. -int WaitProcess(int pid); +int WaitProcess(ProcessId pid); inline char* CharCast(char* c) { return c; } inline char* CharCast(unsigned char* c) { return (char*)c; } diff --git a/src/mp/proxy.cpp b/src/mp/proxy.cpp index 57545d37..8af57f85 100644 --- a/src/mp/proxy.cpp +++ b/src/mp/proxy.cpp @@ -30,12 +30,15 @@ #include #include #include -#include #include #include -#include #include +#ifndef WIN32 +#include +#include +#endif + namespace mp { thread_local ThreadContext g_thread_context; @@ -66,10 +69,9 @@ void EventLoopRef::reset(bool relock) MP_NO_TSA loop->m_num_clients -= 1; if (loop->done()) { loop->m_cv.notify_all(); - int post_fd{loop->m_post_fd}; loop_lock->unlock(); char buffer = 0; - KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon) + loop->m_post_writer->write(&buffer, 1); // By default, do not try to relock `loop_lock` after writing, // because the event loop could wake up and destroy itself and the // mutex might no longer exist. @@ -96,6 +98,20 @@ Connection::~Connection() // after the calls finish. m_rpc_system.reset(); + // shutdownWrite is needed on Windows so pending data in the m_stream socket + // will be sent instead of discarded when m_stream is destroyed. On unix, + // this doesn't seem to be needed because data is sent more reliably. + // + // Sending pending data is important if the connection is a socketpair + // because when one side of the socketpair is closed, the other side doesn't + // seem to receive any onDisconnect event. So it is important for the other + // side to instead receive Cap'n Proto "release" messages (see `struct + // Release` in capnp/rpc.capnp) from local Client objects being being + // destroyed so the remote side can free resources and shut down cleanly. + // Without this call, Server objects corresponding to the Client objects on + // the other side of the connection are not freed by Cap'n Proto. + m_stream->shutdownWrite(); + // ProxyClient cleanup handlers are in sync list, and ProxyServer cleanup // handlers are in the async list. // @@ -192,6 +208,40 @@ void EventLoop::addAsyncCleanup(std::function fn) startAsyncThread(); } +#ifdef WIN32 +//! Synchronous socket output stream. Cap'n Proto library only provides limited +//! support for synchronous IO. It provides `FdOutputStream` which wraps unix +//! file descriptors and calls write() internally, and `HandleOutStream` which +//! wraps windows HANDLE values and calls WriteFile() internally. This class +//! just provides analagous functionality wrapping SOCKET values and calls +//! send() internally. +class SocketOutputStream : public kj::OutputStream { +public: + explicit SocketOutputStream(SOCKET socket) : m_socket(socket) {} + + void write(const void* buffer, size_t size) override; + +private: + SOCKET m_socket; +}; + +static constexpr size_t WRITE_CLAMP_SIZE = 1u << 30; // 1GB clamp for Windows, like FdOutputStream + +void SocketOutputStream::write(const void* buffer, size_t size) { + const char* pos = reinterpret_cast(buffer); + + while (size > 0) { + int n = send(m_socket, pos, static_cast(kj::min(size, WRITE_CLAMP_SIZE)), 0); + + KJ_WIN32(n != SOCKET_ERROR, "send() failed"); + KJ_ASSERT(n > 0, "send() returned zero."); + + pos += n; + size -= n; + } +} +#endif + EventLoop::EventLoop(const char* exe_name, LogOptions log_opts, void* context) : m_exe_name(exe_name), m_io_context(kj::setupAsyncIo()), @@ -199,10 +249,18 @@ EventLoop::EventLoop(const char* exe_name, LogOptions log_opts, void* context) m_log_opts(std::move(log_opts)), m_context(context) { - int fds[2]; - KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); - m_wait_fd = fds[0]; - m_post_fd = fds[1]; + auto pipe = m_io_context.provider->newTwoWayPipe(); + m_wait_stream = kj::mv(pipe.ends[0]); + m_post_stream = kj::mv(pipe.ends[1]); + KJ_IF_MAYBE(fd, m_post_stream->getFd()) { + m_post_writer = kj::heap(*fd); +#ifdef WIN32 + } else KJ_IF_MAYBE(handle, m_post_stream->getWin32Handle()) { + m_post_writer = kj::heap(reinterpret_cast(*handle)); +#endif + } else { + throw std::logic_error("Could not get file descriptor for new pipe."); + } } EventLoop::~EventLoop() @@ -211,8 +269,8 @@ EventLoop::~EventLoop() const Lock lock(m_mutex); KJ_ASSERT(m_post_fn == nullptr); KJ_ASSERT(!m_async_fns); - KJ_ASSERT(m_wait_fd == -1); - KJ_ASSERT(m_post_fd == -1); + KJ_ASSERT(!m_wait_stream); + KJ_ASSERT(!m_post_stream); KJ_ASSERT(m_num_clients == 0); // Spin event loop. wait for any promises triggered by RPC shutdown. @@ -232,9 +290,7 @@ void EventLoop::loop() m_async_fns.emplace(); } - kj::Own wait_stream{ - m_io_context.lowLevelProvider->wrapSocketFd(m_wait_fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; - int post_fd{m_post_fd}; + kj::Own& wait_stream{m_wait_stream}; char buffer = 0; for (;;) { const size_t read_bytes = wait_stream->read(&buffer, 0, 1).wait(m_io_context.waitScope); @@ -246,7 +302,7 @@ void EventLoop::loop() m_cv.notify_all(); } else if (done()) { // Intentionally do not break if m_post_fn was set, even if done() - // would return true, to ensure that the EventLoopRef write(post_fd) + // would return true, to ensure that the EventLoopRef write(post_stream) // call always succeeds and the loop does not exit between the time // that the done condition is set and the write call is made. break; @@ -256,10 +312,9 @@ void EventLoop::loop() m_task_set.reset(); MP_LOG(*this, Log::Info) << "EventLoop::loop bye."; wait_stream = nullptr; - KJ_SYSCALL(::close(post_fd)); const Lock lock(m_mutex); - m_wait_fd = -1; - m_post_fd = -1; + m_wait_stream = nullptr; + m_post_stream = nullptr; m_async_fns.reset(); m_cv.notify_all(); } @@ -274,10 +329,9 @@ void EventLoop::post(kj::Function fn) EventLoopRef ref(*this, &lock); m_cv.wait(lock.m_lock, [this]() MP_REQUIRES(m_mutex) { return m_post_fn == nullptr; }); m_post_fn = &fn; - int post_fd{m_post_fd}; Unlock(lock, [&] { char buffer = 0; - KJ_SYSCALL(write(post_fd, &buffer, 1)); + m_post_writer->write(&buffer, 1); }); m_cv.wait(lock.m_lock, [this, &fn]() MP_REQUIRES(m_mutex) { return m_post_fn != &fn; }); } diff --git a/src/mp/util.cpp b/src/mp/util.cpp index 509913b8..2dbf248c 100644 --- a/src/mp/util.cpp +++ b/src/mp/util.cpp @@ -10,19 +10,27 @@ #include #include #include +#include #include #include #include #include -#include -#include -#include #include #include // NOLINT(misc-include-cleaner) // IWYU pragma: keep #include #include #include +#ifdef WIN32 +#include +#include +#else +#include +#include +#include +#include +#endif + #ifdef __linux__ #include #endif @@ -33,9 +41,15 @@ namespace fs = std::filesystem; +#ifdef WIN32 +// Forward-declare internal capnp function. +namespace kj { namespace _ { int win32Socketpair(SOCKET socks[2]); } } +#endif + namespace mp { namespace { +#ifndef WIN32 //! Return highest possible file descriptor. size_t MaxFd() { @@ -46,6 +60,7 @@ size_t MaxFd() return 1023; } } +#endif } // namespace @@ -67,6 +82,8 @@ std::string ThreadName(const char* exe_name) // the former are shorter and are the same as what gdb prints "LWP ...". #ifdef __linux__ buffer << syscall(SYS_gettid); +#elif defined(WIN32) + buffer << GetCurrentThreadId(); #elif defined(HAVE_PTHREAD_THREADID_NP) uint64_t tid = 0; pthread_threadid_np(NULL, &tid); @@ -104,32 +121,138 @@ std::string LogEscape(const kj::StringTree& string, size_t max_size) return result; } -int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args) +std::array SocketPair() +{ +#ifdef WIN32 + SOCKET pair[2]; + KJ_WINSOCK(kj::_::win32Socketpair(pair)); +#else + int pair[2]; + KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, pair)); +#endif + return {pair[0], pair[1]}; +} + +//! Generate command line that the executable being invoked will split up using +//! the CommandLineToArgvW function, which expects arguments with spaces to be +//! quoted, quote characters to be backslash-escaped, and backslashes to also be +//! backslash-escaped, but only if they precede a quote character. +std::string CommandLineFromArgv(const std::vector& argv) { - int fds[2]; - if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) { - throw std::system_error(errno, std::system_category(), "socketpair"); + std::string out; + for (const auto& arg : argv) { + if (!out.empty()) out += " "; + if (!arg.empty() && arg.find_first_of(" \t\"") == std::string::npos) { + // Argument has no quotes or spaces so escaping not necessary. + out += arg; + } else { + out += '"'; // Start with a quote + for (size_t i = 0; i < arg.size(); ++i) { + if (arg[i] == '\\') { + // Count consecutive backslashes + size_t backslash_count = 0; + while (i < arg.size() && arg[i] == '\\') { + ++backslash_count; + ++i; + } + if (i < arg.size() && arg[i] == '"') { + // Backslashes before a quote need to be doubled + out.append(backslash_count * 2 + 1, '\\'); + out.push_back('"'); + } else { + // Otherwise, backslashes remain as-is + out.append(backslash_count, '\\'); + --i; // Compensate for the outer loop's increment + } + } else if (arg[i] == '"') { + // Escape double quotes with a backslash + out.push_back('\\'); + out.push_back('"'); + } else { + out.push_back(arg[i]); + } + } + out += '"'; // End with a quote + } } + return out; +} - pid = fork(); +ProcessId SpawnProcess(SocketId socket, ConnectInfoToArgsFn&& connect_info_to_args) +{ +#ifndef WIN32 + int pid{fork()}; if (pid == -1) { throw std::system_error(errno, std::system_category(), "fork"); } - // Parent process closes the descriptor for socket 0, child closes the descriptor for socket 1. - if (close(fds[pid ? 0 : 1]) != 0) { - throw std::system_error(errno, std::system_category(), "close"); - } if (!pid) { // Child process must close all potentially open descriptors, except socket 0. const int maxFd = MaxFd(); for (int fd = 3; fd < maxFd; ++fd) { - if (fd != fds[0]) { + if (fd != socket) { close(fd); } } - ExecProcess(fd_to_args(fds[0])); + + int flags = fcntl(socket, F_GETFD); + if (flags == -1) throw std::system_error(errno, std::system_category(), "fcntl F_GETFD"); + if (flags & FD_CLOEXEC) { + flags &= ~FD_CLOEXEC; + if (fcntl(socket, F_SETFD, flags) == -1) throw std::system_error(errno, std::system_category(), "fcntl F_SETFD"); + } + + ExecProcess(connect_info_to_args(std::to_string(socket))); } - return fds[1]; + return pid; +#else + // Create windows pipe to send pipe.ends[0] over to child process. + static std::atomic counter{1}; + ConnectInfo pipe_path{"\\\\.\\pipe\\mp-" + std::to_string(GetCurrentProcessId()) + "-" + std::to_string(counter.fetch_add(1))}; + HANDLE pipe{CreateNamedPipeA(pipe_path.c_str(), PIPE_ACCESS_OUTBOUND, PIPE_TYPE_MESSAGE | PIPE_WAIT, 1, 0, 0, 0, nullptr)}; + KJ_WIN32(pipe != INVALID_HANDLE_VALUE, "CreateNamedPipe failed"); + + // Start child process + std::string cmd{CommandLineFromArgv(connect_info_to_args(pipe_path))}; + STARTUPINFOA si{}; + si.cb = sizeof(si); + PROCESS_INFORMATION pi{}; + KJ_WIN32(CreateProcessA(nullptr, const_cast(cmd.c_str()), nullptr, nullptr, TRUE, 0, nullptr, nullptr, &si, &pi), "CreateProcess failed"); + CloseHandle(pi.hThread); // not needed + + // Duplicate socket for the child (now that we know its PID) + WSAPROTOCOL_INFO info{}; + KJ_WINSOCK(WSADuplicateSocket(socket, pi.dwProcessId, &info), "WSADuplicateSocket failed"); + + // Send socket to the child via the pipe + KJ_WIN32(ConnectNamedPipe(pipe, nullptr) || GetLastError() == ERROR_PIPE_CONNECTED, "ConnectNamedPipe failed"); + DWORD wr; + KJ_WIN32(WriteFile(pipe, &info, sizeof(info), &wr, nullptr) && wr == sizeof(info), "WriteFile(pipe) failed"); + CloseHandle(pipe); + + return reinterpret_cast(pi.hProcess); +#endif +} + +SocketId StartSpawned(const ConnectInfo& connect_info) +{ +#ifndef WIN32 + return std::stoi(connect_info); +#else + HANDLE pipe = CreateFileA(connect_info.c_str(), GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr); + KJ_WIN32(pipe != INVALID_HANDLE_VALUE, "CreateFile(pipe) failed"); + + WSAPROTOCOL_INFO info{}; + DWORD rd; + KJ_WIN32(ReadFile(pipe, &info, sizeof(info), &rd, nullptr) && rd == sizeof(info), "ReadFile(pipe) failed"); + CloseHandle(pipe); + + WSADATA dontcare; + KJ_WIN32(WSAStartup(MAKEWORD(2, 2), &dontcare) != 0, "WSAStartup() failed"); + + SOCKET socket{WSASocketA(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, &info, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT)}; + KJ_WINSOCK(socket, "WSASocket(FROM_PROTOCOL_INFO) failed"); + return socket; +#endif } void ExecProcess(const std::vector& args) @@ -149,13 +272,22 @@ void ExecProcess(const std::vector& args) } } -int WaitProcess(int pid) +int WaitProcess(ProcessId pid) { +#ifndef WIN32 int status; if (::waitpid(pid, &status, 0 /* options */) != pid) { throw std::system_error(errno, std::system_category(), "waitpid"); } return status; +#else + HANDLE handle{reinterpret_cast(pid)}; + DWORD result{WaitForSingleObject(handle, INFINITE)}; + KJ_WIN32(result != WAIT_OBJECT_0, "WaitForSingleObject(child) failed"); + KJ_WIN32(GetExitCodeProcess(handle, &result), "GetExitCodeProcess failed"); + CloseHandle(handle); + return result; +#endif } } // namespace mp