From ec22d0c724b9279234312f5610885c7fc8d3d536 Mon Sep 17 00:00:00 2001 From: stevenfontanella Date: Tue, 3 Mar 2026 06:19:24 +0000 Subject: [PATCH] Implementation for threads in spec tests --- debug_sb.wast | 48 +++++++ patch.py | 35 +++++ scripts/test/shared.py | 2 +- src/tools/wasm-shell.cpp | 296 ++++++++++++++++++++++++++++++++++++--- src/wasm-interpreter.h | 235 ++++++++++++++++++++++++++++--- src/wasm/wasm.cpp | 1 + test/spec/waitqueue.wast | 47 +++++++ test_check.py | 42 ++++++ test_check.wast | 32 +++++ test_debug.py | 58 ++++++++ 10 files changed, 753 insertions(+), 43 deletions(-) create mode 100644 debug_sb.wast create mode 100644 patch.py create mode 100644 test_check.py create mode 100644 test_check.wast create mode 100644 test_debug.py diff --git a/debug_sb.wast b/debug_sb.wast new file mode 100644 index 00000000000..36c08628e4a --- /dev/null +++ b/debug_sb.wast @@ -0,0 +1,48 @@ + +(module $Mem + (memory (export "shared") 1 1 shared) +) +(register "mem") + +(thread $T1 (shared (module $Mem)) + (register "mem" $Mem) + (module + (memory (import "mem" "shared") 1 1 shared) + (func (export "run") + (local i32) + (i32.atomic.store (i32.const 0) (i32.const 1)) + (i32.atomic.load (i32.const 4)) + (local.set 0) + (i32.store (i32.const 24) (local.get 0)) + ) + ) + (invoke "run") +) + +(thread $T2 (shared (module $Mem)) + (register "mem" $Mem) + (module + (memory (import "mem" "shared") 1 1 shared) + (func (export "run") + (local i32) + (i32.atomic.store (i32.const 4) (i32.const 1)) + (i32.atomic.load (i32.const 0)) + (local.set 0) + (i32.store (i32.const 32) (local.get 0)) + ) + ) + (invoke "run") +) + +(wait $T1) +(wait $T2) + +(module $Check + (memory (import "mem" "shared") 1 1 shared) + (func (export "check") (result i32 i32) + (i32.load (i32.const 32)) ;; Load L_1 first so it fails at index 0 + (i32.load (i32.const 24)) + ) +) + +(assert_return (invoke $Check "check") (i32.const 999) (i32.const 999)) diff --git a/patch.py b/patch.py new file mode 100644 index 00000000000..a50c0dde3e1 --- /dev/null +++ b/patch.py @@ -0,0 +1,35 @@ + +with open("src/wasm-interpreter.h") as f: + content = f.read() + +# Add debug prints to doAtomicLoad and doAtomicStore +load_hook = """ Literal doAtomicLoad(Address addr, + Index bytes, + Type type, + Name memoryName, + Address memorySize, + MemoryOrder order) { + std::cerr << "doAtomicLoad addr=" << addr << " memoryName=" << memoryName << " instance=" << this << "\\n";""" + +content = content.replace(""" Literal doAtomicLoad(Address addr, + Index bytes, + Type type, + Name memoryName, + Address memorySize, + MemoryOrder order) {""", load_hook) + +store_hook = """ void doAtomicStore(Address addr, + Index bytes, + Literal toStore, + Name memoryName, + Address memorySize) { + std::cerr << "doAtomicStore addr=" << addr << " val=" << toStore << " memoryName=" << memoryName << " instance=" << this << "\\n";""" + +content = content.replace(""" void doAtomicStore(Address addr, + Index bytes, + Literal toStore, + Name memoryName, + Address memorySize) {""", store_hook) + +with open("src/wasm-interpreter.h", "w") as f: + f.write(content) diff --git a/scripts/test/shared.py b/scripts/test/shared.py index 18ca85d9033..e18785f0826 100644 --- a/scripts/test/shared.py +++ b/scripts/test/shared.py @@ -400,7 +400,7 @@ def get_tests(test_dir, extensions=[], recursive=False): 'threads/thread.wast', # Requires better support for multi-threaded tests - 'threads/wait_notify.wast', + # 'threads/wait_notify.wast', # Non-natural alignment is invalid for atomic operations 'threads/atomic.wast', diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 32c1b98ad4e..d26e09eb320 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -51,9 +51,29 @@ struct Shell { Name lastInstance; std::optional lastModuleDefinition; + size_t anonymousModuleCounter = 0; + + std::shared_ptr sharedWaitState; + Options& options; - Shell(Options& options) : options(options) { buildSpectestModule(); } + struct ThreadState { + Name name; + std::vector commands; + size_t pc = 0; + bool isSuspended = false; + std::shared_ptr instance = nullptr; + std::shared_ptr suspendedCont = nullptr; + bool done = false; + Name lastInstance; + std::optional lastModuleDefinition; + }; + std::vector activeThreads; + + Shell(Options& options) : options(options) { + sharedWaitState = std::make_shared(); + buildSpectestModule(); + } Result<> run(WASTScript& script) { size_t i = 0; @@ -102,11 +122,210 @@ struct Shell { } } - // Run threads in a blocking manner for now. - // TODO: yield on blocking instructions e.g. memory.atomic.wait32. - Result<> doThread(ThreadBlock& thread) { return run(thread.commands); } + Result<> doThread(ThreadBlock& thread) { + ThreadState state; + state.name = thread.name; + state.commands = thread.commands; + state.lastInstance = lastInstance; + state.lastModuleDefinition = lastModuleDefinition; + activeThreads.push_back(std::move(state)); + return Ok{}; + } + + Result<> doWait(Wait& wait) { + bool found = false; + for (auto& t : activeThreads) { + if (t.name == wait.thread) { + found = true; + break; + } + } + if (!found) { + return Err{"wait called for unknown thread"}; + } + + // Round-robin execution + while (true) { + bool anyProgress = false; + bool targetDone = false; + + size_t numThreads = activeThreads.size(); + for (size_t i = 0; i < numThreads; ++i) { + if (activeThreads[i].done) { + if (activeThreads[i].name == wait.thread) + targetDone = true; + continue; + } + + if (activeThreads[i].isSuspended) { + // Check if it's still waiting. WaitQueue sets `isWaiting` to false + // when notified. + bool stillWaiting = activeThreads[i].suspendedCont && + activeThreads[i].suspendedCont->isWaiting; + + if (!stillWaiting) { + // It was woken up! We need to resume it. + activeThreads[i].isSuspended = false; + Flow flow; + try { + flow = activeThreads[i].instance->resumeContinuation( + activeThreads[i].suspendedCont); + } catch (TrapException&) { + std::cerr << "Thread " << activeThreads[i].name + << " trapped upon resume\n"; + activeThreads[i].done = true; + anyProgress = true; + continue; + } catch (...) { + WASM_UNREACHABLE("unexpected error during resume"); + } + activeThreads[i].suspendedCont = nullptr; + + if (flow.breakTo == THREAD_SUSPEND_FLOW) { + // Suspended again + activeThreads[i].isSuspended = true; + activeThreads[i].suspendedCont = + activeThreads[i].instance->getSuspendedContinuation(); + anyProgress = true; + } else if (flow.suspendTag) { + activeThreads[i].instance->clearContinuationStore(); + activeThreads[i].done = true; // unhandled suspension + anyProgress = true; + } else { + auto& cmd = activeThreads[i].commands[activeThreads[i].pc].cmd; + if (auto* assnVar = std::get_if(&cmd)) { + if (auto* assn = std::get_if(assnVar)) { + auto assnRes = + assertResult(ActionResult(flow.values), assn->expected); + if (assnRes.getErr()) { + std::cerr << "Thread " << activeThreads[i].name + << " error: " << assnRes.getErr()->msg << "\n"; + activeThreads[i].done = true; + } else { + activeThreads[i].pc++; + } + } else { + activeThreads[i].pc++; + } + } else { + activeThreads[i] + .pc++; // Completed the command that originally suspended! + } + anyProgress = true; + } + } + } else { + // Normal execution of the next command. + std::swap(lastInstance, activeThreads[i].lastInstance); + std::swap(lastModuleDefinition, + activeThreads[i].lastModuleDefinition); + + if (activeThreads[i].pc < activeThreads[i].commands.size()) { + auto& cmd = activeThreads[i].commands[activeThreads[i].pc].cmd; + Action* trackAction = nullptr; + if (auto* act = std::get_if(&cmd)) { + trackAction = act; + } else if (auto* assnVar = std::get_if(&cmd)) { + if (auto* assn = std::get_if(assnVar)) { + trackAction = &assn->action; + } + } + + if (trackAction) { + auto result = doAction(*trackAction); + if (std::get_if(&result)) { + activeThreads[i].isSuspended = true; + if (auto* invoke = std::get_if(trackAction)) { + activeThreads[i].instance = + instances[invoke->base ? *invoke->base : lastInstance]; + activeThreads[i].suspendedCont = + activeThreads[i].instance->getSuspendedContinuation(); + std::cerr + << "THREAD " << i << " SUSPENDED. suspendedCont is " + << (activeThreads[i].suspendedCont ? "VALID" : "NULL") + << " instance addr=" << activeThreads[i].instance.get() + << "\n"; + } else { + std::cerr + << "THREAD " << i + << " SUSPENDED but trackAction is NOT InvokeAction!\n"; + } + anyProgress = true; + } else { + if (auto* assnVar = std::get_if(&cmd)) { + if (auto* assn = std::get_if(assnVar)) { + auto assnRes = assertResult(result, assn->expected); + if (assnRes.getErr()) { + std::cerr << "Thread " << activeThreads[i].name + << " error: " << assnRes.getErr()->msg << "\n"; + activeThreads[i].done = true; + } else { + activeThreads[i].pc++; + } + } else { + activeThreads[i].pc++; + } + } else { + activeThreads[i].pc++; + } + anyProgress = true; + } + } else if (auto* waitCmd = std::get_if(&cmd)) { + bool waitFound = false; + bool waitDone = false; + // Avoid using an index loop here since activeThreads might be + // accessed + for (size_t j = 0; j < activeThreads.size(); ++j) { + if (activeThreads[j].name == waitCmd->thread) { + waitFound = true; + waitDone = activeThreads[j].done; + break; + } + } + if (!waitFound) { + std::cerr << "Thread " << activeThreads[i].name + << " error: wait called for unknown thread\n"; + activeThreads[i].done = true; + anyProgress = true; + } else if (waitDone) { + activeThreads[i].pc++; + anyProgress = true; + } + } else { + // Not an action, wait, or assert_return, just run it + // (e.g. module instantiation or other assertions) + auto res = runCommand(cmd); + if (res.getErr()) { + std::cerr << "Thread " << activeThreads[i].name + << " error: " << res.getErr()->msg << "\n"; + activeThreads[i].done = true; + } else { + activeThreads[i].pc++; + anyProgress = true; + } + } + } else { + activeThreads[i].done = true; + anyProgress = true; // finishing counts as progress + } + + std::swap(lastInstance, activeThreads[i].lastInstance); + std::swap(lastModuleDefinition, + activeThreads[i].lastModuleDefinition); + } + } - Result<> doWait(Wait& wait) { return Ok{}; } + if (targetDone) { + break; + } + + if (!anyProgress) { + // Find if target is still suspended + return Err{"deadlock! no threads can make progress"}; + } + } + return Ok{}; + } Result> makeModule(WASTModule& mod) { std::shared_ptr wasm; @@ -173,6 +392,22 @@ struct Shell { auto instance = std::make_shared(wasm, interface.get(), linkedInstances); + // In multithreaded WASM, instances within the same thread should share a + // stack. However, the `linkedInstances` might contain modules (like memory) + // shared across ALL threads. If we blindly inherit `continuationStore` from + // `linkedInstances`, all threads will share the same execution stack, + // causing segfaults. Therefore, we MUST give this instance a fresh + // ContinuationStore for its thread execution unless it is supposed to be + // part of an existing thread's execution. For now, in `wasm-shell`, we + // simplify by giving every top-level module a fresh store but sharing the + // WAIT state. (Called function execution across modules will temporarily + // push to their respective stores, which is not perfect natively but avoids + // stack data races). Actually, `activeThreads[i]` implies each thread has + // its own stack. + auto store = std::make_shared(); + store->sharedWaitState = sharedWaitState; + instance->setContinuationStore(store); + lastInstance = instanceName; // Even if instantiation fails, the module may have partially instantiated @@ -201,6 +436,10 @@ struct Shell { CHECK_ERR(module); auto wasm = *module; + if (!wasm->name.is()) { + wasm->name = Name(std::string("anonymous_") + + std::to_string(anonymousModuleCounter++)); + } CHECK_ERR(validateModule(*wasm)); modules[wasm->name] = wasm; @@ -233,13 +472,15 @@ struct Shell { struct HostLimitResult {}; struct ExceptionResult {}; struct SuspensionResult {}; + struct ThreadSuspendResult {}; using ActionResult = std::variant; + SuspensionResult, + ThreadSuspendResult>; - std::string resultToString(ActionResult& result) { + std::string resultToString(const ActionResult& result) { if (std::get_if(&result)) { return "trap"; } else if (std::get_if(&result)) { @@ -248,6 +489,8 @@ struct Shell { return "exception"; } else if (std::get_if(&result)) { return "suspension"; + } else if (std::get_if(&result)) { + return "thread_suspend"; } else if (auto* vals = std::get_if(&result)) { std::stringstream ss; ss << *vals; @@ -265,6 +508,8 @@ struct Shell { return TrapResult{}; } auto& instance = it->second; + std::cerr << "doAction invoke name=" << invoke->name + << " instance addr=" << instance.get() << "\n"; Flow flow; try { flow = instance->callExport(invoke->name, invoke->args); @@ -277,6 +522,9 @@ struct Shell { } catch (...) { WASM_UNREACHABLE("unexpected error"); } + if (flow.breakTo == THREAD_SUSPEND_FLOW) { + return ThreadSuspendResult{}; + } if (flow.suspendTag) { // This is an unhandled suspension. Handle it here - clear the // suspension state - so nothing else is affected. @@ -352,15 +600,15 @@ struct Shell { return Ok{}; } - Result<> assertReturn(AssertReturn& assn) { + Result<> assertResult(const ActionResult& result, + const std::vector& expected) { std::stringstream err; - auto result = doAction(assn.action); auto* values = std::get_if(&result); if (!values) { return Err{std::string("expected return, got ") + resultToString(result)}; } - if (values->size() != assn.expected.size()) { - err << "expected " << assn.expected.size() << " values, got " + if (values->size() != expected.size()) { + err << "expected " << expected.size() << " values, got " << resultToString(result); return Err{err.str()}; } @@ -375,13 +623,13 @@ struct Shell { }; Literal val = (*values)[i]; - auto& expected = assn.expected[i]; - if (auto* v = std::get_if(&expected)) { + auto& exp = expected[i]; + if (auto* v = std::get_if(&exp)) { if (val != *v) { err << "expected " << *v << ", got " << val << atIndex(); return Err{err.str()}; } - } else if (auto* ref = std::get_if(&expected)) { + } else if (auto* ref = std::get_if(&exp)) { if (!val.type.isRef() || !HeapType::isSubType(val.type.getHeapType(), ref->type)) { err << "expected " << ref->type << " reference, got " << val @@ -389,23 +637,23 @@ struct Shell { return Err{err.str()}; } } else if ([[maybe_unused]] auto* nullRef = - std::get_if(&expected)) { + std::get_if(&exp)) { if (!val.isNull()) { err << "expected ref.null, got " << val << atIndex(); return Err{err.str()}; } - } else if (auto* nan = std::get_if(&expected)) { + } else if (auto* nan = std::get_if(&exp)) { auto check = checkNaN(val, *nan); if (auto* e = check.getErr()) { err << e->msg << atIndex(); return Err{err.str()}; } - } else if (auto* lanes = std::get_if(&expected)) { + } else if (auto* lanes = std::get_if(&exp)) { switch (lanes->size()) { case 4: { auto vals = val.getLanesF32x4(); - for (Index i = 0; i < 4; ++i) { - auto check = checkLane(vals[i], (*lanes)[i], i); + for (Index j = 0; j < 4; ++j) { + auto check = checkLane(vals[j], (*lanes)[j], j); if (auto* e = check.getErr()) { err << e->msg << atIndex(); return Err{err.str()}; @@ -415,8 +663,8 @@ struct Shell { } case 2: { auto vals = val.getLanesF64x2(); - for (Index i = 0; i < 2; ++i) { - auto check = checkLane(vals[i], (*lanes)[i], i); + for (Index j = 0; j < 2; ++j) { + auto check = checkLane(vals[j], (*lanes)[j], j); if (auto* e = check.getErr()) { err << e->msg << atIndex(); return Err{err.str()}; @@ -428,12 +676,16 @@ struct Shell { WASM_UNREACHABLE("unexpected number of lanes"); } } else { - WASM_UNREACHABLE("unexpected expectation"); + WASM_UNREACHABLE("unexpected result expectation"); } } return Ok{}; } + Result<> assertReturn(AssertReturn& assn) { + return assertResult(doAction(assn.action), assn.expected); + } + Result<> assertAction(AssertAction& assn) { std::stringstream err; auto result = doAction(assn.action); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index b0281636bb6..9bce9e3ee43 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -71,7 +71,8 @@ struct NonconstantException {}; // Utilities -extern Name RETURN_FLOW, RETURN_CALL_FLOW, NONCONSTANT_FLOW, SUSPEND_FLOW; +extern Name RETURN_FLOW, RETURN_CALL_FLOW, NONCONSTANT_FLOW, SUSPEND_FLOW, + THREAD_SUSPEND_FLOW; // Stuff that flows around during executing expressions: a literal, or a change // in control flow. @@ -87,13 +88,15 @@ class Flow { : values(std::move(values)), breakTo(breakTo) {} Flow(Name breakTo, Tag* suspendTag, Literals&& values) : values(std::move(values)), breakTo(breakTo), suspendTag(suspendTag) { - assert(breakTo == SUSPEND_FLOW); + assert(breakTo == SUSPEND_FLOW || breakTo == THREAD_SUSPEND_FLOW); } Literals values; Name breakTo; // if non-null, a break is going on Tag* suspendTag = nullptr; // if non-null, breakTo must be SUSPEND_FLOW, and - // this is the tag being suspended + // this is the tag being suspended. If breakTo is + // THREAD_SUSPEND_FLOW, this represents the thread + // suspending and this field is not used. // A helper function for the common case where there is only one value const Literal& getSingleValue() { @@ -281,6 +284,10 @@ struct ContData { // resume_throw_ref). Literal exception; + // If set, this continuation was suspended into a wait queue by a thread + // and has not yet been woken up. + bool isWaiting = false; + // Whether we executed. Continuations are one-shot, so they may not be // executed a second time. bool executed = false; @@ -289,6 +296,21 @@ struct ContData { ContData(Literal func, HeapType type) : func(func), type(type) {} }; +// Shared execution state of a set of instantiated modules. +struct SharedWaitState { + // The wait queue for threads waiting on addresses (represented by GCData and + // field index). + std::unordered_map< + std::shared_ptr, + std::unordered_map>>> + waitQueues; + + // The wait queue for linear memory addresses. + std::map, + std::unordered_map>>> + memoryWaitQueues; +}; + // Shared execution state of a set of instantiated modules. struct ContinuationStore { // The current continuations, in a stack. At the top of the stack is the @@ -303,6 +325,10 @@ struct ContinuationStore { // Set when we are resuming execution, that is, re-winding the stack. bool resuming = false; + + std::shared_ptr sharedWaitState; + + ContinuationStore() { sharedWaitState = std::make_shared(); } }; // Execute an expression @@ -477,14 +503,33 @@ class ExpressionRunner : public OverriddenVisitor { // expression that is in this map, then it will just return that value. std::unordered_map restoredValuesMap; - // Shared execution state for continuations. This can be null if the - // instance does not want to ever suspend/resume. std::shared_ptr continuationStore; +public: + void setContinuationStore(std::shared_ptr store) { + continuationStore = store; + } + + void setSharedWaitState(std::shared_ptr state) { + if (continuationStore) { + continuationStore->sharedWaitState = state; + } + } + std::shared_ptr getCurrContinuationOrNull() { - if (!continuationStore || continuationStore->continuations.empty()) { + if (!continuationStore) { + std::cerr << "getCurrContinuationOrNull: store is NULL! this=" << this + << "\n"; return {}; } + if (continuationStore->continuations.empty()) { + std::cerr << "getCurrContinuationOrNull: continuations is EMPTY! this=" + << this << "\n"; + return {}; + } + std::cerr << "getCurrContinuationOrNull: returning back, size=" + << continuationStore->continuations.size() << " this=" << this + << "\n"; return continuationStore->continuations.back(); } @@ -2237,13 +2282,85 @@ class ExpressionRunner : public OverriddenVisitor { } Flow visitStructWait(StructWait* curr) { - WASM_UNREACHABLE("struct.wait not implemented"); - return Flow(); + VISIT(ref, curr->ref) + VISIT(expected, curr->expected) + VISIT(timeout, + curr->timeout) // We ignore timeout in the simulation for simplicity + + auto data = ref.getSingleValue().getGCData(); + if (!data) { + trap("null ref"); + } + + auto& field = data->values[curr->index]; + if (field != expected.getSingleValue()) { + return Literal(int32_t(1)); // not-equal, don't wait + } + + if (self()->isResuming()) { + // We have been notified and resumed. + // Clear the resume state and continue. + auto currContinuation = self()->getCurrContinuation(); + assert(curr == currContinuation->resumeExpr); + self()->continuationStore->resuming = false; + assert(currContinuation->resumeInfo.empty()); + assert(self()->restoredValuesMap.empty()); + return Literal(int32_t(0)); // ok, woken up + } + + // We need to wait. Create a continuation and suspend the thread. + auto old = self()->getCurrContinuationOrNull(); + auto new_ = std::make_shared(); + if (old) { + self()->popCurrContinuation(); + } + self()->pushCurrContinuation(new_); + new_->resumeExpr = curr; + new_->isWaiting = true; + + self() + ->continuationStore->sharedWaitState->waitQueues[data][curr->index] + .push_back(new_); + + return Flow(THREAD_SUSPEND_FLOW); } Flow visitStructNotify(StructNotify* curr) { - WASM_UNREACHABLE("struct.notify not implemented"); - return Flow(); + VISIT(ref, curr->ref) + VISIT(count, curr->count) + + auto data = ref.getSingleValue().getGCData(); + if (!data) { + trap("null ref"); + } + + int32_t countVal = count.getSingleValue().geti32(); + int32_t woken = 0; + + auto& store = self()->continuationStore; + auto it1 = store->sharedWaitState->waitQueues.find(data); + if (it1 != store->sharedWaitState->waitQueues.end()) { + auto& fieldQueues = it1->second; + auto it2 = fieldQueues.find(curr->index); + if (it2 != fieldQueues.end()) { + auto& queue = it2->second; + while (!queue.empty() && woken < countVal) { + // The waking thread will be executed by the wasm-shell scheduler. + // In the reference interpreter, awake continuations should be + // tracked. Since wasm-shell handles interleaved threads, we don't + // automatically execute them here. Wait! wasm-shell scheduler needs + // to know which threads are ready. Our ContinuationStore wait queues + // structure just pops them. The scheduler wrapper will need a way to + // track all active threads. + auto wokeCont = queue.front(); + wokeCont->isWaiting = false; + queue.erase(queue.begin()); + woken++; + } + } + } + + return Literal(woken); } // Arbitrary deterministic limit on size. If we need to allocate a Literals @@ -2707,6 +2824,10 @@ class ExpressionRunner : public OverriddenVisitor { virtual void hostLimit(std::string_view why) { WASM_UNREACHABLE("unimp"); } + virtual void invokeMain(const std::string& startName) { + WASM_UNREACHABLE("unimp"); + } + virtual void throwException(const WasmException& exn) { WASM_UNREACHABLE("unimp"); } @@ -3250,6 +3371,42 @@ class ModuleRunnerBase : public ExpressionRunner { Flow callExport(Name name) { return callExport(name, Literals()); } + std::shared_ptr getSuspendedContinuation() { + return this->getCurrContinuationOrNull(); + } + + Flow resumeContinuation(std::shared_ptr contData, + Literals arguments = {}) { + if (contData->executed) { + this->trap("continuation already executed"); + } + contData->executed = true; + + if (contData->resumeArguments.empty()) { + contData->resumeArguments = arguments; + } + + this->pushCurrContinuation(contData); + this->continuationStore->resuming = true; +#if WASM_INTERPRETER_DEBUG + std::cout << this->indent() << "resuming func " << contData->func.getFunc() + << '\n'; +#endif + + Flow ret = contData->func.getFuncData()->doCall(arguments); + + if (this->isResuming()) { + // if we didn't suspend again natively, clear resuming flag + this->continuationStore->resuming = false; + } + + if (ret.breakTo != THREAD_SUSPEND_FLOW && !ret.suspendTag) { + // The coroutine finished normally. + this->popCurrContinuation(); + } + return ret; + } + Literal getExportedFunction(Name name) { Export* export_ = wasm.getExportOrNull(name); if (!export_ || export_->kind != ExternalKind::Function) { @@ -4129,26 +4286,64 @@ class ModuleRunnerBase : public ExpressionRunner { if (loaded != expected.getSingleValue()) { return Literal(int32_t(1)); // not equal } - // TODO: Add threads support. For now, report a host limit here, as there - // are no other threads that can wake us up. Without such threads, - // we'd hang if there is no timeout, and even if there is a timeout - // then we can hang for a long time if it is in a loop. The only - // timeout value we allow here for now is 0. - if (timeout.getSingleValue().getInteger() != 0) { - hostLimit("threads support"); + + if (self()->isResuming()) { + auto currContinuation = self()->getCurrContinuation(); + assert(curr == currContinuation->resumeExpr); + self()->continuationStore->resuming = false; + assert(currContinuation->resumeInfo.empty()); + assert(self()->restoredValuesMap.empty()); + return Literal(int32_t(0)); // ok, woken up + } + + auto old = self()->getCurrContinuationOrNull(); + auto new_ = std::make_shared(); + if (old) { + self()->popCurrContinuation(); } - return Literal(int32_t(2)); // Timed out + self()->pushCurrContinuation(new_); + new_->resumeExpr = curr; + new_->isWaiting = true; + + self() + ->continuationStore->sharedWaitState + ->memoryWaitQueues[{info.instance, info.name}][addr] + .push_back(new_); + + return Flow(THREAD_SUSPEND_FLOW); } Flow visitAtomicNotify(AtomicNotify* curr) { VISIT(ptr, curr->ptr) VISIT(count, curr->notifyCount) auto info = getMemoryInstanceInfo(curr->memory); + auto memorySize = info.instance->getMemorySize(info.name); auto memorySizeBytes = info.instance->getMemorySizeBytes(info.name); auto addr = info.instance->getFinalAddress( curr, ptr.getSingleValue(), 4, memorySizeBytes); // Just check TODO actual threads support info.instance->checkAtomicAddress(addr, 4, memorySizeBytes); - return Literal(int32_t(0)); // none woken up + + int32_t countVal = count.getSingleValue().geti32(); + int32_t woken = 0; + + auto& store = self()->continuationStore; + auto it1 = store->sharedWaitState->memoryWaitQueues.find( + {(void*)info.instance, info.name}); + if (it1 != store->sharedWaitState->memoryWaitQueues.end()) { + auto& addressQueues = it1->second; + auto it2 = addressQueues.find(addr); + if (it2 != addressQueues.end()) { + auto& queue = it2->second; + while (!queue.empty() && woken < countVal) { + auto wokeCont = queue.front(); + wokeCont->isWaiting = false; + queue.erase(queue.begin()); + woken++; + } + } + } + + return Literal(woken); } Flow visitSIMDLoad(SIMDLoad* curr) { switch (curr->op) { @@ -5135,7 +5330,7 @@ class ModuleRunnerBase : public ExpressionRunner { flow.breakTo = Name(); } - if (flow.breakTo != SUSPEND_FLOW) { + if (flow.breakTo != SUSPEND_FLOW && flow.breakTo != THREAD_SUSPEND_FLOW) { // We are normally executing (not suspending), and therefore cannot still // be breaking, which would mean we missed our stop. assert(!flow.breaking() || flow.breakTo == RETURN_FLOW); diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index c05a23dbc9a..906e8a22419 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -28,6 +28,7 @@ Name RETURN_FLOW("*return:)*"); Name RETURN_CALL_FLOW("*return-call:)*"); Name NONCONSTANT_FLOW("*nonconstant:)*"); Name SUSPEND_FLOW("*suspend:)*"); +Name THREAD_SUSPEND_FLOW("*thread_suspend:)*"); namespace BinaryConsts::CustomSections { diff --git a/test/spec/waitqueue.wast b/test/spec/waitqueue.wast index cd0631ef1da..122a31b555d 100644 --- a/test/spec/waitqueue.wast +++ b/test/spec/waitqueue.wast @@ -96,3 +96,50 @@ (struct.get $t 0 (global.get $g)) ) ) + +(module $Mem + (type $Wq (struct (field (mut waitqueue)))) + (global $wq (export "wq") (mut (ref null $Wq)) (ref.null $Wq)) + + (func $init (export "init") + (global.set $wq (struct.new $Wq (i32.const 0))) + ) +) + +(register "mem") + +(invoke $Mem "init") + +(thread $T1 (shared (module $Mem)) + (register "mem" $Mem) + (module + (type $Wq (struct (field (mut waitqueue)))) + (global $wq (import "mem" "wq") (mut (ref null $Wq))) + + (func (export "run_wait") (result i32) + ;; Wait on the waitqueue, expecting value 0, infinite timeout (-1) + (struct.wait $Wq 0 (global.get $wq) (i32.const 0) (i64.const -1)) + ) + ) + ;; This thread will suspend on struct.wait + (invoke "run_wait") +) + +(thread $T2 (shared (module $Mem)) + (register "mem" $Mem) + (module + (type $Wq (struct (field (mut waitqueue)))) + (global $wq (import "mem" "wq") (mut (ref null $Wq))) + + (func (export "run_notify") (result i32) + ;; Notify 1 waiter on the waitqueue + (struct.notify $Wq 0 (global.get $wq) (i32.const 1)) + ) + ) + ;; This thread will notify the waitqueue and wake 1 thread + (assert_return (invoke "run_notify") (i32.const 1)) +) + +;; Wait for threads to complete +(wait $T1) +(wait $T2) diff --git a/test_check.py b/test_check.py new file mode 100644 index 00000000000..8a1a27deb81 --- /dev/null +++ b/test_check.py @@ -0,0 +1,42 @@ +import subprocess + +wast_content = """ +(module $Mem + (memory (export "shared") 1 1 shared) +) +(register "mem") + +(module $Check + (memory (import "mem" "shared") 1 1 shared) + + (func (export "check") (result i32) + (local i32 i32) + ;; Manually store values to simulate L_0=0 and L_1=1 + (i32.store (i32.const 24) (i32.const 0)) + (i32.store (i32.const 32) (i32.const 1)) + + (i32.load (i32.const 24)) + (local.set 0) + (i32.load (i32.const 32)) + (local.set 1) + + ;; allowed results: (L_0 = 1 && L_1 = 1) || (L_0 = 0 && L_1 = 1) || (L_0 = 1 && L_1 = 0) + + (i32.and (i32.eq (local.get 0) (i32.const 1)) (i32.eq (local.get 1) (i32.const 1))) + (i32.and (i32.eq (local.get 0) (i32.const 0)) (i32.eq (local.get 1) (i32.const 1))) + (i32.and (i32.eq (local.get 0) (i32.const 1)) (i32.eq (local.get 1) (i32.const 0))) + (i32.or) + (i32.or) + (return) + ) +) + +(assert_return (invoke $Check "check") (i32.const 1)) +""" + +with open("test_check.wast", "w") as f: + f.write(wast_content) + +r = subprocess.run(["./bin/wasm-shell", "test_check.wast"], capture_output=True, text=True) +print("STDOUT:", r.stdout) +print("STDERR:", r.stderr) diff --git a/test_check.wast b/test_check.wast new file mode 100644 index 00000000000..8c614fab571 --- /dev/null +++ b/test_check.wast @@ -0,0 +1,32 @@ + +(module $Mem + (memory (export "shared") 1 1 shared) +) +(register "mem") + +(module $Check + (memory (import "mem" "shared") 1 1 shared) + + (func (export "check") (result i32) + (local i32 i32) + ;; Manually store values to simulate L_0=0 and L_1=1 + (i32.store (i32.const 24) (i32.const 0)) + (i32.store (i32.const 32) (i32.const 1)) + + (i32.load (i32.const 24)) + (local.set 0) + (i32.load (i32.const 32)) + (local.set 1) + + ;; allowed results: (L_0 = 1 && L_1 = 1) || (L_0 = 0 && L_1 = 1) || (L_0 = 1 && L_1 = 0) + + (i32.and (i32.eq (local.get 0) (i32.const 1)) (i32.eq (local.get 1) (i32.const 1))) + (i32.and (i32.eq (local.get 0) (i32.const 0)) (i32.eq (local.get 1) (i32.const 1))) + (i32.and (i32.eq (local.get 0) (i32.const 1)) (i32.eq (local.get 1) (i32.const 0))) + (i32.or) + (i32.or) + (return) + ) +) + +(assert_return (invoke $Check "check") (i32.const 1)) diff --git a/test_debug.py b/test_debug.py new file mode 100644 index 00000000000..5ccd4ddb768 --- /dev/null +++ b/test_debug.py @@ -0,0 +1,58 @@ +import subprocess + +wast_content = """ +(module $Mem + (memory (export "shared") 1 1 shared) +) +(register "mem") + +(thread $T1 (shared (module $Mem)) + (register "mem" $Mem) + (module + (memory (import "mem" "shared") 1 1 shared) + (func (export "run") + (local i32) + (i32.atomic.store (i32.const 0) (i32.const 1)) + (i32.atomic.load (i32.const 4)) + (local.set 0) + (i32.store (i32.const 24) (local.get 0)) + ) + ) + (invoke "run") +) + +(thread $T2 (shared (module $Mem)) + (register "mem" $Mem) + (module + (memory (import "mem" "shared") 1 1 shared) + (func (export "run") + (local i32) + (i32.atomic.store (i32.const 4) (i32.const 1)) + (i32.atomic.load (i32.const 0)) + (local.set 0) + (i32.store (i32.const 32) (local.get 0)) + ) + ) + (invoke "run") +) + +(wait $T1) +(wait $T2) + +(module $Check + (memory (import "mem" "shared") 1 1 shared) + (func (export "check") (result i32 i32) + (i32.load (i32.const 32)) ;; Load L_1 first so it fails at index 0 + (i32.load (i32.const 24)) + ) +) + +(assert_return (invoke $Check "check") (i32.const 999) (i32.const 999)) +""" + +with open("debug_sb.wast", "w") as f: + f.write(wast_content) + +r = subprocess.run(["./bin/wasm-shell", "debug_sb.wast"], capture_output=True, text=True) +print("STDOUT:", r.stdout) +print("STDERR:", r.stderr)