diff --git a/include/scratchcpp/compiler.h b/include/scratchcpp/compiler.h index 7cc8d44d9..fea2a29ca 100644 --- a/include/scratchcpp/compiler.h +++ b/include/scratchcpp/compiler.h @@ -149,8 +149,11 @@ class LIBSCRATCHCPP_EXPORT Compiler void warp(); void createYield(); + void createStop(); - void createStopWithoutSync(); + void createThreadStop(); + + void invalidateTarget(); void createProcedureCall(BlockPrototype *prototype, const Compiler::Args &args); diff --git a/src/blocks/controlblocks.cpp b/src/blocks/controlblocks.cpp index 807d1b78e..e340ef415 100644 --- a/src/blocks/controlblocks.cpp +++ b/src/blocks/controlblocks.cpp @@ -90,7 +90,8 @@ CompilerValue *ControlBlocks::compileStop(Compiler *compiler) if (str == "all") { compiler->addFunctionCallWithCtx("control_stop_all", Compiler::StaticType::Void); - compiler->createStop(); + compiler->invalidateTarget(); // if this is a clone, it doesn't exist anymore + compiler->createThreadStop(); } else if (str == "this script") compiler->createStop(); else if (str == "other scripts in sprite" || str == "other scripts in stage") @@ -181,7 +182,8 @@ CompilerValue *ControlBlocks::compileDeleteThisClone(Compiler *compiler) { CompilerValue *deleted = compiler->addTargetFunctionCall("control_delete_this_clone", Compiler::StaticType::Bool); compiler->beginIfStatement(deleted); - compiler->createStopWithoutSync(); // sync happens before the function call + compiler->invalidateTarget(); // the sprite doesn't exist anymore + compiler->createThreadStop(); // callers should be stopped too compiler->endIf(); return nullptr; } diff --git a/src/engine/compiler.cpp b/src/engine/compiler.cpp index f113998e4..e87311771 100644 --- a/src/engine/compiler.cpp +++ b/src/engine/compiler.cpp @@ -705,14 +705,19 @@ void Compiler::createStop() impl->builder->createStop(); } +/*! Creates a stop thread (current script and procedure callers) instruction. */ +void Compiler::createThreadStop() +{ + impl->builder->createThreadStop(); +} + /*! - * Creates a stop script without synchronization instruction.\n - * Use this if synchronization is not possible at the stop point. - * \note Only use this when everything is synchronized, e. g. after a function call. + * Creates a sprite/stage invalidation point.\n + * Use this if synchronization is not possible because the target has been deleted. */ -void Compiler::createStopWithoutSync() +void Compiler::invalidateTarget() { - impl->builder->createStopWithoutSync(); + impl->builder->invalidateTarget(); } /*! Creates a call to the procedure with the given prototype. */ diff --git a/src/engine/internal/icodebuilder.h b/src/engine/internal/icodebuilder.h index 1c793c1d1..16e81384a 100644 --- a/src/engine/internal/icodebuilder.h +++ b/src/engine/internal/icodebuilder.h @@ -99,7 +99,9 @@ class ICodeBuilder virtual void yield() = 0; virtual void createStop() = 0; - virtual void createStopWithoutSync() = 0; + virtual void createThreadStop() = 0; + + virtual void invalidateTarget() = 0; virtual void createProcedureCall(BlockPrototype *prototype, const Compiler::Args &args) = 0; }; diff --git a/src/engine/internal/llvm/instructions/control.cpp b/src/engine/internal/llvm/instructions/control.cpp index 5afc458c8..22a34cfd8 100644 --- a/src/engine/internal/llvm/instructions/control.cpp +++ b/src/engine/internal/llvm/instructions/control.cpp @@ -60,8 +60,12 @@ ProcessResult Control::process(LLVMInstruction *ins) ret.next = buildStop(ins); break; - case LLVMInstruction::Type::StopWithoutSync: - ret.next = buildStopWithoutSync(ins); + case LLVMInstruction::Type::ThreadStop: + ret.next = buildThreadStop(ins); + break; + + case LLVMInstruction::Type::InvalidateTarget: + ret.next = buildInvalidateTarget(ins); break; default: @@ -337,14 +341,27 @@ LLVMInstruction *Control::buildEndLoop(LLVMInstruction *ins) LLVMInstruction *Control::buildStop(LLVMInstruction *ins) { m_utils.syncVariables(); - return buildStopWithoutSync(ins); + + m_builder.CreateBr(m_utils.endBranch()); + llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function()); + m_builder.SetInsertPoint(nextBranch); + + return ins->next; } -LLVMInstruction *Control::buildStopWithoutSync(LLVMInstruction *ins) +LLVMInstruction *Control::buildThreadStop(LLVMInstruction *ins) { - m_builder.CreateBr(m_utils.endBranch()); + m_utils.syncVariables(); + m_builder.CreateBr(m_utils.endThreadBranch()); + llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function()); m_builder.SetInsertPoint(nextBranch); return ins->next; } + +LLVMInstruction *Control::buildInvalidateTarget(LLVMInstruction *ins) +{ + m_utils.invalidateTarget(); + return ins->next; +} diff --git a/src/engine/internal/llvm/instructions/control.h b/src/engine/internal/llvm/instructions/control.h index bb294f9b5..9b18fdeb7 100644 --- a/src/engine/internal/llvm/instructions/control.h +++ b/src/engine/internal/llvm/instructions/control.h @@ -27,7 +27,8 @@ class Control : public InstructionGroup LLVMInstruction *buildBeginLoopCondition(LLVMInstruction *ins); LLVMInstruction *buildEndLoop(LLVMInstruction *ins); LLVMInstruction *buildStop(LLVMInstruction *ins); - LLVMInstruction *buildStopWithoutSync(LLVMInstruction *ins); + LLVMInstruction *buildThreadStop(LLVMInstruction *ins); + LLVMInstruction *buildInvalidateTarget(LLVMInstruction *ins); }; } // namespace libscratchcpp::llvmins diff --git a/src/engine/internal/llvm/instructions/procedures.cpp b/src/engine/internal/llvm/instructions/procedures.cpp index 0b9c9b645..cac59650a 100644 --- a/src/engine/internal/llvm/instructions/procedures.cpp +++ b/src/engine/internal/llvm/instructions/procedures.cpp @@ -66,9 +66,17 @@ LLVMInstruction *Procedures::buildCallProcedure(LLVMInstruction *ins) args.push_back(m_utils.createValue(arg.second)); } + // Call the procedure llvm::Value *handle = m_builder.CreateCall(m_utils.functions().resolveFunction(name, type), args); + // Check for end thread sentinel value + llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(llvmCtx, "", function); + llvm::Value *endThread = m_builder.CreateICmpEQ(handle, m_utils.threadEndSentinel()); + m_builder.CreateCondBr(endThread, m_utils.endThreadBranch(), nextBranch); + m_builder.SetInsertPoint(nextBranch); + if (!m_utils.warp() && !ins->procedurePrototype->warp()) { + // Handle suspend llvm::BasicBlock *suspendBranch = llvm::BasicBlock::Create(llvmCtx, "", function); llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(llvmCtx, "", function); m_builder.CreateCondBr(m_builder.CreateIsNull(handle), nextBranch, suspendBranch); @@ -79,6 +87,13 @@ LLVMInstruction *Procedures::buildCallProcedure(LLVMInstruction *ins) m_builder.CreateCondBr(done, nextBranch, suspendBranch); m_builder.SetInsertPoint(nextBranch); + + // The thread could be stopped from the coroutine + llvm::BasicBlock *afterResumeBranch = llvm::BasicBlock::Create(llvmCtx, "", function); + llvm::Value *isFinished = m_builder.CreateCall(m_utils.functions().resolve_llvm_is_thread_finished(), m_utils.executionContextPtr()); + m_builder.CreateCondBr(isFinished, m_utils.endThreadBranch(), afterResumeBranch); + + m_builder.SetInsertPoint(afterResumeBranch); } m_utils.reloadVariables(); diff --git a/src/engine/internal/llvm/llvmbuildutils.cpp b/src/engine/internal/llvm/llvmbuildutils.cpp index 717626a4d..8dc05aaff 100644 --- a/src/engine/internal/llvm/llvmbuildutils.cpp +++ b/src/engine/internal/llvm/llvmbuildutils.cpp @@ -161,8 +161,13 @@ void LLVMBuildUtils::init(llvm::Function *function, BlockPrototype *procedurePro reloadVariables(); reloadLists(); - // Create end branch + // Mark the target as valid + m_targetValidFlag = m_builder.CreateAlloca(m_builder.getInt1Ty()); + m_builder.CreateStore(m_builder.getInt1(true), m_targetValidFlag); + + // Create end branches m_endBranch = llvm::BasicBlock::Create(m_llvmCtx, "end", m_function); + m_endThreadBranch = llvm::BasicBlock::Create(m_llvmCtx, "endThread", m_function); } void LLVMBuildUtils::end(LLVMInstruction *lastInstruction, LLVMRegister *lastConstant) @@ -184,9 +189,9 @@ void LLVMBuildUtils::end(LLVMInstruction *lastInstruction, LLVMRegister *lastCon syncVariables(); m_builder.CreateBr(m_endBranch); + // End branch m_builder.SetInsertPoint(m_endBranch); - // End the script function llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0); switch (m_codeType) { @@ -216,6 +221,33 @@ void LLVMBuildUtils::end(LLVMInstruction *lastInstruction, LLVMRegister *lastCon m_builder.CreateRet(castValue(lastConstant, Compiler::StaticType::Bool)); break; } + + // Thread end branch (stop the entire thread, including procedure callers) + m_builder.SetInsertPoint(m_endThreadBranch); + + switch (m_codeType) { + case Compiler::CodeType::Script: + // Mark the thread as finished + m_builder.CreateCall(m_functions.resolve_llvm_mark_thread_as_finished(), { m_executionContextPtr }); + + // Return a sentinel value (special pointer) to terminate any procedure callers + if (m_warp) + m_builder.CreateRet(threadEndSentinel()); + else if (m_procedurePrototype) + m_coroutine->endWithSentinel(threadEndSentinel()); + else { + // There's no need to return the sentinel value in standard scripts because they don't have any callers + m_coroutine->end(); + } + + break; + + case Compiler::CodeType::Reporter: + case Compiler::CodeType::HatPredicate: + // Procedures cannot be called by these scripts, so we don't have to return the sentinel value + m_builder.CreateBr(m_endBranch); + break; + } } LLVMCompilerContext *LLVMBuildUtils::compilerCtx() const @@ -323,6 +355,17 @@ llvm::BasicBlock *LLVMBuildUtils::endBranch() const return m_endBranch; } +llvm::BasicBlock *LLVMBuildUtils::endThreadBranch() const +{ + return m_endThreadBranch; +} + +llvm::Value *LLVMBuildUtils::threadEndSentinel() const +{ + llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0); + return m_builder.CreateIntToPtr(m_builder.getInt64(1), pointerType, "threadEndSentinel"); +} + BlockPrototype *LLVMBuildUtils::procedurePrototype() const { return m_procedurePrototype; @@ -401,6 +444,12 @@ LLVMListPtr &LLVMBuildUtils::listPtr(List *list) void LLVMBuildUtils::syncVariables() { + llvm::BasicBlock *syncBlock = llvm::BasicBlock::Create(m_llvmCtx, "syncVariables", m_function); + llvm::BasicBlock *syncNextBlock = llvm::BasicBlock::Create(m_llvmCtx, "syncVariables.next", m_function); + m_builder.CreateCondBr(m_builder.CreateLoad(m_builder.getInt1Ty(), m_targetValidFlag), syncBlock, syncNextBlock); + + m_builder.SetInsertPoint(syncBlock); + // Copy stack variables to the actual variables for (auto &[var, varPtr] : m_variablePtrs) { llvm::BasicBlock *copyBlock = llvm::BasicBlock::Create(m_llvmCtx, "syncVar", m_function); @@ -414,6 +463,9 @@ void LLVMBuildUtils::syncVariables() m_builder.SetInsertPoint(nextBlock); } + + m_builder.CreateBr(syncNextBlock); + m_builder.SetInsertPoint(syncNextBlock); } void LLVMBuildUtils::reloadVariables() @@ -444,6 +496,11 @@ void LLVMBuildUtils::reloadLists() } } +void LLVMBuildUtils::invalidateTarget() +{ + m_builder.CreateStore(m_builder.getInt1(false), m_targetValidFlag); +} + std::vector &LLVMBuildUtils::ifStatements() { return m_ifStatements; diff --git a/src/engine/internal/llvm/llvmbuildutils.h b/src/engine/internal/llvm/llvmbuildutils.h index 20a0a1a82..2a7938081 100644 --- a/src/engine/internal/llvm/llvmbuildutils.h +++ b/src/engine/internal/llvm/llvmbuildutils.h @@ -56,6 +56,9 @@ class LIBSCRATCHCPP_TEST_EXPORT LLVMBuildUtils size_t stringCount() const; llvm::BasicBlock *endBranch() const; + llvm::BasicBlock *endThreadBranch() const; + + llvm::Value *threadEndSentinel() const; BlockPrototype *procedurePrototype() const; bool warp() const; @@ -79,6 +82,7 @@ class LIBSCRATCHCPP_TEST_EXPORT LLVMBuildUtils void syncVariables(); void reloadVariables(); void reloadLists(); + void invalidateTarget(); std::vector &ifStatements(); std::vector &loops(); @@ -166,6 +170,7 @@ class LIBSCRATCHCPP_TEST_EXPORT LLVMBuildUtils llvm::Value *m_functionIdValue = nullptr; llvm::BasicBlock *m_endBranch = nullptr; + llvm::BasicBlock *m_endThreadBranch = nullptr; llvm::StructType *m_valueDataType = nullptr; llvm::StructType *m_stringPtrType = nullptr; @@ -182,6 +187,7 @@ class LIBSCRATCHCPP_TEST_EXPORT LLVMBuildUtils llvm::Value *m_targetVariables = nullptr; llvm::Value *m_targetLists = nullptr; llvm::Value *m_warpArg = nullptr; + llvm::Value *m_targetValidFlag = nullptr; std::unique_ptr m_coroutine; diff --git a/src/engine/internal/llvm/llvmcodebuilder.cpp b/src/engine/internal/llvm/llvmcodebuilder.cpp index aa7682929..d6f6bb1c8 100644 --- a/src/engine/internal/llvm/llvmcodebuilder.cpp +++ b/src/engine/internal/llvm/llvmcodebuilder.cpp @@ -567,9 +567,15 @@ void LLVMCodeBuilder::createStop() m_instructions.addInstruction(ins); } -void LLVMCodeBuilder::createStopWithoutSync() +void LLVMCodeBuilder::createThreadStop() { - auto ins = std::make_shared(LLVMInstruction::Type::StopWithoutSync, m_loopCondition); + auto ins = std::make_shared(LLVMInstruction::Type::ThreadStop, m_loopCondition); + m_instructions.addInstruction(ins); +} + +void LLVMCodeBuilder::invalidateTarget() +{ + auto ins = std::make_shared(LLVMInstruction::Type::InvalidateTarget, m_loopCondition); m_instructions.addInstruction(ins); } diff --git a/src/engine/internal/llvm/llvmcodebuilder.h b/src/engine/internal/llvm/llvmcodebuilder.h index 7a85eb34f..3fb986e5e 100644 --- a/src/engine/internal/llvm/llvmcodebuilder.h +++ b/src/engine/internal/llvm/llvmcodebuilder.h @@ -113,7 +113,9 @@ class LIBSCRATCHCPP_TEST_EXPORT LLVMCodeBuilder : public ICodeBuilder void yield() override; void createStop() override; - void createStopWithoutSync() override; + void createThreadStop() override; + + void invalidateTarget() override; void createProcedureCall(BlockPrototype *prototype, const Compiler::Args &args) override; diff --git a/src/engine/internal/llvm/llvmcoroutine.cpp b/src/engine/internal/llvm/llvmcoroutine.cpp index 02c4ecb16..74069a416 100644 --- a/src/engine/internal/llvm/llvmcoroutine.cpp +++ b/src/engine/internal/llvm/llvmcoroutine.cpp @@ -35,21 +35,32 @@ LLVMCoroutine::LLVMCoroutine(llvm::Module *module, llvm::IRBuilder<> *builder, l // Begin m_handle = builder->CreateCall(coroBegin, { coroIdRet, alloc }); + m_didSuspendVar = builder->CreateAlloca(builder->getInt1Ty(), nullptr, "didSuspend"); builder->CreateStore(builder->getInt1(false), m_didSuspendVar); + + m_sentinelVar = builder->CreateAlloca(pointerType, nullptr, "sentinel"); + builder->CreateStore(nullPointer, m_sentinelVar); + llvm::BasicBlock *entry = builder->GetInsertBlock(); // Create suspend branch m_suspendBlock = llvm::BasicBlock::Create(ctx, "suspend", func); builder->SetInsertPoint(m_suspendBlock); + + llvm::Value *sentinelValue = builder->CreateLoad(pointerType, m_sentinelVar); + llvm::Value *sentinelIsNull = builder->CreateIsNull(sentinelValue); builder->CreateCall(coroEnd, { m_handle, builder->getInt1(false), llvm::ConstantTokenNone::get(ctx) }); - builder->CreateRet(m_handle); + builder->CreateRet(builder->CreateSelect(sentinelIsNull, m_handle, sentinelValue)); // Create free branches m_freeMemRetBlock = llvm::BasicBlock::Create(ctx, "freeMemRet", func); builder->SetInsertPoint(m_freeMemRetBlock); + + sentinelValue = builder->CreateLoad(pointerType, m_sentinelVar); + sentinelIsNull = builder->CreateIsNull(sentinelValue); builder->CreateFree(alloc); - builder->CreateRet(llvm::ConstantPointerNull::get(pointerType)); + builder->CreateRet(builder->CreateSelect(sentinelIsNull, llvm::ConstantPointerNull::get(pointerType), sentinelValue)); llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create(ctx, "free", func); builder->SetInsertPoint(freeBranch); @@ -63,6 +74,15 @@ LLVMCoroutine::LLVMCoroutine(llvm::Module *module, llvm::IRBuilder<> *builder, l llvm::Value *needFree = builder->CreateIsNotNull(mem); builder->CreateCondBr(needFree, freeBranch, m_suspendBlock); + // Create final suspend point + m_finalSuspendBlock = llvm::BasicBlock::Create(ctx, "finalSuspend", m_function); + + m_builder->SetInsertPoint(m_finalSuspendBlock); + llvm::Value *suspendResult = m_builder->CreateCall(llvm::Intrinsic::getDeclaration(m_module, llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get(ctx), m_builder->getInt1(true) }); + llvm::SwitchInst *sw = m_builder->CreateSwitch(suspendResult, m_suspendBlock, 2); + sw->addCase(m_builder->getInt8(0), m_freeMemRetBlock); + sw->addCase(m_builder->getInt8(1), m_cleanupBlock); + builder->SetInsertPoint(entry); } @@ -135,20 +155,11 @@ llvm::Value *LLVMCoroutine::createResume(llvm::Module *module, llvm::IRBuilder<> void LLVMCoroutine::end() { - llvm::LLVMContext &ctx = m_builder->getContext(); - - // Add final suspend point - llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(ctx, "end", m_function); - llvm::BasicBlock *finalSuspendBranch = llvm::BasicBlock::Create(ctx, "finalSuspend", m_function); - m_builder->CreateCondBr(m_builder->CreateLoad(m_builder->getInt1Ty(), m_didSuspendVar), finalSuspendBranch, endBranch); - - m_builder->SetInsertPoint(finalSuspendBranch); - llvm::Value *suspendResult = m_builder->CreateCall(llvm::Intrinsic::getDeclaration(m_module, llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get(ctx), m_builder->getInt1(true) }); - llvm::SwitchInst *sw = m_builder->CreateSwitch(suspendResult, m_suspendBlock, 2); - sw->addCase(m_builder->getInt8(0), endBranch); // unreachable - sw->addCase(m_builder->getInt8(1), m_cleanupBlock); + m_builder->CreateCondBr(m_builder->CreateLoad(m_builder->getInt1Ty(), m_didSuspendVar), m_finalSuspendBlock, m_freeMemRetBlock); +} - // Jump to "free and return" branch - m_builder->SetInsertPoint(endBranch); - m_builder->CreateBr(m_freeMemRetBlock); +void LLVMCoroutine::endWithSentinel(llvm::Value *sentinel) +{ + m_builder->CreateStore(sentinel, m_sentinelVar); + end(); } diff --git a/src/engine/internal/llvm/llvmcoroutine.h b/src/engine/internal/llvm/llvmcoroutine.h index fa3d79fd9..a04a2ad10 100644 --- a/src/engine/internal/llvm/llvmcoroutine.h +++ b/src/engine/internal/llvm/llvmcoroutine.h @@ -24,6 +24,7 @@ class LLVMCoroutine void createSuspend(); static llvm::Value *createResume(llvm::Module *module, llvm::IRBuilder<> *builder, llvm::Function *function, llvm::Value *coroHandle); void end(); + void endWithSentinel(llvm::Value *sentinel); private: llvm::Module *m_module = nullptr; @@ -31,9 +32,11 @@ class LLVMCoroutine llvm::Function *m_function = nullptr; llvm::Value *m_handle = nullptr; llvm::BasicBlock *m_suspendBlock = nullptr; + llvm::BasicBlock *m_finalSuspendBlock = nullptr; llvm::BasicBlock *m_cleanupBlock = nullptr; llvm::BasicBlock *m_freeMemRetBlock = nullptr; llvm::Value *m_didSuspendVar = nullptr; + llvm::Value *m_sentinelVar = nullptr; }; } // namespace libscratchcpp diff --git a/src/engine/internal/llvm/llvmexecutablecode.cpp b/src/engine/internal/llvm/llvmexecutablecode.cpp index 1acb79f42..ce645a007 100644 --- a/src/engine/internal/llvm/llvmexecutablecode.cpp +++ b/src/engine/internal/llvm/llvmexecutablecode.cpp @@ -16,6 +16,8 @@ using namespace libscratchcpp; +static const void *END_THREAD_SENTINEL = (void *)0x1; + LLVMExecutableCode::LLVMExecutableCode( LLVMCompilerContext *ctx, function_id_t functionId, @@ -70,10 +72,11 @@ void LLVMExecutableCode::run(ExecutionContext *context) MainFunctionType f = std::get(m_mainFunction); void *handle = f(context, target, target->variableData(), target->listData()); - if (!handle) + if (!handle || handle == END_THREAD_SENTINEL) { ctx->setFinished(true); - - ctx->setCoroutineHandle(handle); + ctx->setCoroutineHandle(nullptr); + } else + ctx->setCoroutineHandle(handle); } } diff --git a/src/engine/internal/llvm/llvmfunctions.cpp b/src/engine/internal/llvm/llvmfunctions.cpp index b5f90d550..4681d846d 100644 --- a/src/engine/internal/llvm/llvmfunctions.cpp +++ b/src/engine/internal/llvm/llvmfunctions.cpp @@ -35,6 +35,16 @@ extern "C" { return static_cast(ctx)->getStringArray(functionId); } + + LIBSCRATCHCPP_EXPORT void llvm_mark_thread_as_finished(ExecutionContext *ctx) + { + static_cast(ctx)->setFinished(true); + } + + LIBSCRATCHCPP_EXPORT bool llvm_is_thread_finished(ExecutionContext *ctx) + { + return static_cast(ctx)->finished(); + } } LLVMFunctions::LLVMFunctions(LLVMCompilerContext *ctx, llvm::IRBuilder<> *builder) : @@ -282,6 +292,18 @@ llvm::FunctionCallee LLVMFunctions::resolve_llvm_get_string_array() return callee; } +llvm::FunctionCallee LLVMFunctions::resolve_llvm_mark_thread_as_finished() +{ + llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_ctx->llvmCtx()), 0); + return resolveFunction("llvm_mark_thread_as_finished", llvm::FunctionType::get(m_builder->getVoidTy(), { pointerType }, false)); +} + +llvm::FunctionCallee LLVMFunctions::resolve_llvm_is_thread_finished() +{ + llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_ctx->llvmCtx()), 0); + return resolveFunction("llvm_is_thread_finished", llvm::FunctionType::get(m_builder->getInt1Ty(), { pointerType }, false)); +} + llvm::FunctionCallee LLVMFunctions::resolve_string_pool_new() { return resolveFunction("string_pool_new", llvm::FunctionType::get(m_stringPtrType->getPointerTo(), false)); diff --git a/src/engine/internal/llvm/llvmfunctions.h b/src/engine/internal/llvm/llvmfunctions.h index 3aba66916..cf608c744 100644 --- a/src/engine/internal/llvm/llvmfunctions.h +++ b/src/engine/internal/llvm/llvmfunctions.h @@ -48,6 +48,8 @@ class LLVMFunctions llvm::FunctionCallee resolve_llvm_random_int64(); llvm::FunctionCallee resolve_llvm_random_bool(); llvm::FunctionCallee resolve_llvm_get_string_array(); + llvm::FunctionCallee resolve_llvm_mark_thread_as_finished(); + llvm::FunctionCallee resolve_llvm_is_thread_finished(); llvm::FunctionCallee resolve_string_pool_new(); llvm::FunctionCallee resolve_string_pool_free(); llvm::FunctionCallee resolve_string_alloc(); diff --git a/src/engine/internal/llvm/llvminstruction.h b/src/engine/internal/llvm/llvminstruction.h index 558258aa6..98d815bd0 100644 --- a/src/engine/internal/llvm/llvminstruction.h +++ b/src/engine/internal/llvm/llvminstruction.h @@ -77,7 +77,8 @@ struct LLVMInstruction BeginLoopCondition, EndLoop, Stop, - StopWithoutSync, + ThreadStop, + InvalidateTarget, CallProcedure, ProcedureArg }; diff --git a/test/compiler/compiler_test.cpp b/test/compiler/compiler_test.cpp index dfb14143e..cb876c31f 100644 --- a/test/compiler/compiler_test.cpp +++ b/test/compiler/compiler_test.cpp @@ -1673,14 +1673,28 @@ TEST_F(CompilerTest, CreateStop) compile(m_compiler.get(), block.get()); } -TEST_F(CompilerTest, CreateStopWithoutSync) +TEST_F(CompilerTest, CreateThreadStop) { auto block = std::make_shared("", ""); block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { - EXPECT_CALL(*m_builder, createStopWithoutSync()); - compiler->createStopWithoutSync(); + EXPECT_CALL(*m_builder, createThreadStop()); + compiler->createThreadStop(); + return nullptr; + }); + + compile(m_compiler.get(), block.get()); +} + +TEST_F(CompilerTest, InvalidateTarget) +{ + + auto block = std::make_shared("", ""); + + block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { + EXPECT_CALL(*m_builder, invalidateTarget()); + compiler->invalidateTarget(); return nullptr; }); diff --git a/test/engine/engine_test.cpp b/test/engine/engine_test.cpp index 66121708b..ebeab4a7d 100644 --- a/test/engine/engine_test.cpp +++ b/test/engine/engine_test.cpp @@ -2280,3 +2280,105 @@ TEST(EngineTest, BroadcastAndWaitCaseInsensitive) ASSERT_VAR(stage, "passed"); ASSERT_TRUE(GET_VAR(stage, "passed")->value().toBool()); } + +TEST(EngineTest, CloneScriptsAreStoppedFromProcedureDeletingTheClone) +{ + // Regtest for #660 + Project p("regtest_projects/660_stop_script_from_procedure_clone.sb3"); + ASSERT_TRUE(p.load()); + + auto engine = p.engine(); + + Stage *stage = engine->stage(); + ASSERT_TRUE(stage); + + engine->run(); + + ASSERT_VAR(stage, "passed"); + ASSERT_TRUE(GET_VAR(stage, "passed")->value().toBool()); +} + +TEST(EngineTest, ScriptsAreNotStoppedFromProcedureStoppingTheProcedureScript) +{ + // Regtest for #660 + Project p("regtest_projects/660_stop_script_from_procedure_stop_this_script.sb3"); + ASSERT_TRUE(p.load()); + + auto engine = p.engine(); + + Stage *stage = engine->stage(); + ASSERT_TRUE(stage); + + engine->run(); + + ASSERT_VAR(stage, "passed"); + ASSERT_TRUE(GET_VAR(stage, "passed")->value().toBool()); +} + +TEST(EngineTest, ScriptsAreNotStoppedFromProcedureStoppingOtherScripts) +{ + // Regtest for #660 + Project p("regtest_projects/660_stop_script_from_procedure_stop_other_scripts.sb3"); + ASSERT_TRUE(p.load()); + + auto engine = p.engine(); + + Stage *stage = engine->stage(); + ASSERT_TRUE(stage); + + engine->run(); + + ASSERT_VAR(stage, "passed"); + ASSERT_TRUE(GET_VAR(stage, "passed")->value().toBool()); +} + +TEST(EngineTest, ScriptsAreStoppedFromProcedureStoppingAll) +{ + // Regtest for #660 + Project p("regtest_projects/660_stop_script_from_procedure_stop_all.sb3"); + ASSERT_TRUE(p.load()); + + auto engine = p.engine(); + + Stage *stage = engine->stage(); + ASSERT_TRUE(stage); + + engine->run(); + + ASSERT_VAR(stage, "passed"); + ASSERT_TRUE(GET_VAR(stage, "passed")->value().toBool()); +} + +TEST(EngineTest, ScriptsAreStoppedFromProcedureStoppingAll_NonWarp) +{ + // Regtest for #660 + Project p("regtest_projects/660_stop_script_from_procedure_stop_all_non_warp.sb3"); + ASSERT_TRUE(p.load()); + + auto engine = p.engine(); + + Stage *stage = engine->stage(); + ASSERT_TRUE(stage); + + engine->run(); + + ASSERT_VAR(stage, "passed"); + ASSERT_TRUE(GET_VAR(stage, "passed")->value().toBool()); +} + +TEST(EngineTest, ScriptsAreStoppedFromProcedureStoppingAll_Nested) +{ + // Regtest for #660 + Project p("regtest_projects/660_stop_script_from_procedure_stop_all_nested.sb3"); + ASSERT_TRUE(p.load()); + + auto engine = p.engine(); + + Stage *stage = engine->stage(); + ASSERT_TRUE(stage); + + engine->run(); + + ASSERT_VAR(stage, "passed"); + ASSERT_TRUE(GET_VAR(stage, "passed")->value().toBool()); +} diff --git a/test/llvm/llvmcodebuilder_test.cpp b/test/llvm/llvmcodebuilder_test.cpp index 498e79a92..82340901d 100644 --- a/test/llvm/llvmcodebuilder_test.cpp +++ b/test/llvm/llvmcodebuilder_test.cpp @@ -3552,6 +3552,233 @@ TEST_F(LLVMCodeBuilderTest, UndefinedProcedure) ASSERT_TRUE(code->isFinished(ctx.get())); } +TEST_F(LLVMCodeBuilderTest, ProcedureThreadStop_Warp) +{ + Sprite sprite; + + // Inner procedure (proc2): prints "inner_before", then stops the thread + BlockPrototype prototype2; + prototype2.setProcCode("proc2"); + prototype2.setWarp(true); + LLVMCodeBuilder *builder = m_utils.createBuilder(&sprite, &prototype2); + + CompilerValue *v = builder->addConstValue("inner_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + builder->createThreadStop(); + + // This should NOT execute + v = builder->addConstValue("inner_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + auto proc2Code = builder->build(); + + // Outer procedure (proc1): prints "outer_before", calls proc2, then prints "outer_after" + BlockPrototype prototype1; + prototype1.setProcCode("proc1"); + prototype1.setWarp(true); + builder = m_utils.createBuilder(&sprite, &prototype1); + + v = builder->addConstValue("outer_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + builder->createProcedureCall(&prototype2, {}); + + // This should NOT execute (thread was stopped by proc2) + v = builder->addConstValue("outer_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + auto proc1Code = builder->build(); + + // Root script: prints "script_before", calls proc1, then prints "script_after" + builder = m_utils.createBuilder(&sprite, true); + + v = builder->addConstValue("script_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + builder->createProcedureCall(&prototype1, {}); + + // This should NOT execute (thread was stopped by proc2 via proc1) + v = builder->addConstValue("script_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + std::string expected = + "script_before\n" + "outer_before\n" + "inner_before\n"; + + auto code = builder->build(); + Script script(&sprite, nullptr, nullptr); + script.setCode(code); + Thread thread(&sprite, nullptr, &script); + auto ctx = code->createExecutionContext(&thread); + + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); + ASSERT_TRUE(code->isFinished(ctx.get())); +} + +TEST_F(LLVMCodeBuilderTest, ProcedureThreadStop_NonWarp) +{ + Sprite sprite; + + // Inner procedure (proc2): prints "inner_before", then stops the thread + BlockPrototype prototype2; + prototype2.setProcCode("proc2"); + prototype2.setWarp(false); + LLVMCodeBuilder *builder = m_utils.createBuilder(&sprite, &prototype2); + + CompilerValue *v = builder->addConstValue("inner_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + builder->createThreadStop(); + + // This should NOT execute + v = builder->addConstValue("inner_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + auto proc2Code = builder->build(); + + // Outer procedure (proc1): prints "outer_before", calls proc2, then prints "outer_after" + BlockPrototype prototype1; + prototype1.setProcCode("proc1"); + prototype1.setWarp(false); + builder = m_utils.createBuilder(&sprite, &prototype1); + + v = builder->addConstValue("outer_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + builder->createProcedureCall(&prototype2, {}); + + // This should NOT execute (thread was stopped by proc2) + v = builder->addConstValue("outer_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + auto proc1Code = builder->build(); + + // Root script: prints "script_before", calls proc1, then prints "script_after" + builder = m_utils.createBuilder(&sprite, false); + + v = builder->addConstValue("script_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + builder->createProcedureCall(&prototype1, {}); + + // This should NOT execute (thread was stopped by proc2 via proc1) + v = builder->addConstValue("script_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + std::string expected = + "script_before\n" + "outer_before\n" + "inner_before\n"; + + auto code = builder->build(); + Script script(&sprite, nullptr, nullptr); + script.setCode(code); + Thread thread(&sprite, nullptr, &script); + auto ctx = code->createExecutionContext(&thread); + + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); + ASSERT_TRUE(code->isFinished(ctx.get())); +} + +TEST_F(LLVMCodeBuilderTest, ProcedureThreadStop_NonWarp_AfterYield) +{ + Sprite sprite; + + // Inner procedure (proc2): yields via a repeat loop, then stops the thread + // This exercises the coroutine resume path where the sentinel must propagate + BlockPrototype prototype2; + prototype2.setProcCode("proc2"); + prototype2.setWarp(false); + + LLVMCodeBuilder *builder = m_utils.createBuilder(&sprite, &prototype2); + CompilerValue *v = builder->addConstValue("inner_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + // This repeat loop causes the coroutine to suspend (yield) on each iteration + v = builder->addConstValue(2); + builder->beginRepeatLoop(v); + { + v = builder->addConstValue("inner_loop"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + } + builder->endLoop(); + + // After the loop completes, stop the thread + builder->createThreadStop(); + + // This should NOT execute + v = builder->addConstValue("inner_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + auto proc2Code = builder->build(); + + // Outer procedure (proc1): calls proc2 + BlockPrototype prototype1; + prototype1.setProcCode("proc1"); + prototype1.setWarp(false); + + builder = m_utils.createBuilder(&sprite, &prototype1); + + v = builder->addConstValue("outer_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + builder->createProcedureCall(&prototype2, {}); + + // This should NOT execute (thread was stopped by proc2) + v = builder->addConstValue("outer_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + auto proc1Code = builder->build(); + + // Root script: calls proc1 + builder = m_utils.createBuilder(&sprite, false); + v = builder->addConstValue("script_before"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + builder->createProcedureCall(&prototype1, {}); + + // This should NOT execute (thread was stopped by proc2 via proc1) + v = builder->addConstValue("script_after"); + builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + auto code = builder->build(); + Script script(&sprite, nullptr, nullptr); + script.setCode(code); + + Thread thread(&sprite, nullptr, &script); + auto ctx = code->createExecutionContext(&thread); + + // First run: enters the repeat loop in proc2, yields after first iteration + std::string expected1 = + "script_before\n" + "outer_before\n" + "inner_before\n" + "inner_loop\n"; + + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected1); + ASSERT_FALSE(code->isFinished(ctx.get())); + + // Second run: second iteration of the repeat loop, yields again + std::string expected2 = "inner_loop\n"; + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected2); + ASSERT_FALSE(code->isFinished(ctx.get())); + + // Third run: loop is done, createThreadStop() fires, sentinel must propagate + // through the coroutine resume path back to the caller + // Neither "inner_after", "outer_after", nor "script_after" should print + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), ""); + ASSERT_TRUE(code->isFinished(ctx.get())); +} + TEST_F(LLVMCodeBuilderTest, HatPredicates) { Sprite sprite; diff --git a/test/mocks/codebuildermock.h b/test/mocks/codebuildermock.h index 1d9225a23..9cdeba022 100644 --- a/test/mocks/codebuildermock.h +++ b/test/mocks/codebuildermock.h @@ -88,7 +88,9 @@ class CodeBuilderMock : public ICodeBuilder MOCK_METHOD(void, yield, (), (override)); MOCK_METHOD(void, createStop, (), (override)); - MOCK_METHOD(void, createStopWithoutSync, (), (override)); + MOCK_METHOD(void, createThreadStop, (), (override)); + + MOCK_METHOD(void, invalidateTarget, (), (override)); MOCK_METHOD(void, createProcedureCall, (BlockPrototype *, const Compiler::Args &), (override)); }; diff --git a/test/regtest_projects/660_stop_script_from_procedure_clone.sb3 b/test/regtest_projects/660_stop_script_from_procedure_clone.sb3 new file mode 100644 index 000000000..54065eab7 Binary files /dev/null and b/test/regtest_projects/660_stop_script_from_procedure_clone.sb3 differ diff --git a/test/regtest_projects/660_stop_script_from_procedure_stop_all.sb3 b/test/regtest_projects/660_stop_script_from_procedure_stop_all.sb3 new file mode 100644 index 000000000..a52437581 Binary files /dev/null and b/test/regtest_projects/660_stop_script_from_procedure_stop_all.sb3 differ diff --git a/test/regtest_projects/660_stop_script_from_procedure_stop_all_nested.sb3 b/test/regtest_projects/660_stop_script_from_procedure_stop_all_nested.sb3 new file mode 100644 index 000000000..35213c419 Binary files /dev/null and b/test/regtest_projects/660_stop_script_from_procedure_stop_all_nested.sb3 differ diff --git a/test/regtest_projects/660_stop_script_from_procedure_stop_all_non_warp.sb3 b/test/regtest_projects/660_stop_script_from_procedure_stop_all_non_warp.sb3 new file mode 100644 index 000000000..32e404135 Binary files /dev/null and b/test/regtest_projects/660_stop_script_from_procedure_stop_all_non_warp.sb3 differ diff --git a/test/regtest_projects/660_stop_script_from_procedure_stop_other_scripts.sb3 b/test/regtest_projects/660_stop_script_from_procedure_stop_other_scripts.sb3 new file mode 100644 index 000000000..b0f80633f Binary files /dev/null and b/test/regtest_projects/660_stop_script_from_procedure_stop_other_scripts.sb3 differ diff --git a/test/regtest_projects/660_stop_script_from_procedure_stop_this_script.sb3 b/test/regtest_projects/660_stop_script_from_procedure_stop_this_script.sb3 new file mode 100644 index 000000000..0a2303e80 Binary files /dev/null and b/test/regtest_projects/660_stop_script_from_procedure_stop_this_script.sb3 differ