Skip to content

Commit e1b7e11

Browse files
committed
Fix threads not stopping when stopped from resumed coroutines
1 parent 2557cd8 commit e1b7e11

File tree

5 files changed

+127
-0
lines changed

5 files changed

+127
-0
lines changed

src/engine/internal/llvm/instructions/procedures.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ LLVMInstruction *Procedures::buildCallProcedure(LLVMInstruction *ins)
8787
m_builder.CreateCondBr(done, nextBranch, suspendBranch);
8888

8989
m_builder.SetInsertPoint(nextBranch);
90+
91+
// The thread could be stopped from the coroutine
92+
llvm::BasicBlock *afterResumeBranch = llvm::BasicBlock::Create(llvmCtx, "", function);
93+
llvm::Value *isFinished = m_builder.CreateCall(m_utils.functions().resolve_llvm_is_thread_finished(), m_utils.executionContextPtr());
94+
m_builder.CreateCondBr(isFinished, m_utils.endThreadBranch(), afterResumeBranch);
95+
96+
m_builder.SetInsertPoint(afterResumeBranch);
9097
}
9198

9299
m_utils.reloadVariables();

src/engine/internal/llvm/llvmbuildutils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ void LLVMBuildUtils::end(LLVMInstruction *lastInstruction, LLVMRegister *lastCon
227227

228228
switch (m_codeType) {
229229
case Compiler::CodeType::Script:
230+
// Mark the thread as finished
231+
m_builder.CreateCall(m_functions.resolve_llvm_mark_thread_as_finished(), { m_executionContextPtr });
232+
230233
// Return a sentinel value (special pointer) to terminate any procedure callers
231234
if (m_warp)
232235
m_builder.CreateRet(threadEndSentinel());

src/engine/internal/llvm/llvmfunctions.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ extern "C"
3535
{
3636
return static_cast<LLVMExecutionContext *>(ctx)->getStringArray(functionId);
3737
}
38+
39+
LIBSCRATCHCPP_EXPORT void llvm_mark_thread_as_finished(ExecutionContext *ctx)
40+
{
41+
static_cast<LLVMExecutionContext *>(ctx)->setFinished(true);
42+
}
43+
44+
LIBSCRATCHCPP_EXPORT bool llvm_is_thread_finished(ExecutionContext *ctx)
45+
{
46+
return static_cast<LLVMExecutionContext *>(ctx)->finished();
47+
}
3848
}
3949

4050
LLVMFunctions::LLVMFunctions(LLVMCompilerContext *ctx, llvm::IRBuilder<> *builder) :
@@ -282,6 +292,18 @@ llvm::FunctionCallee LLVMFunctions::resolve_llvm_get_string_array()
282292
return callee;
283293
}
284294

295+
llvm::FunctionCallee LLVMFunctions::resolve_llvm_mark_thread_as_finished()
296+
{
297+
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_ctx->llvmCtx()), 0);
298+
return resolveFunction("llvm_mark_thread_as_finished", llvm::FunctionType::get(m_builder->getVoidTy(), { pointerType }, false));
299+
}
300+
301+
llvm::FunctionCallee LLVMFunctions::resolve_llvm_is_thread_finished()
302+
{
303+
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_ctx->llvmCtx()), 0);
304+
return resolveFunction("llvm_is_thread_finished", llvm::FunctionType::get(m_builder->getInt1Ty(), { pointerType }, false));
305+
}
306+
285307
llvm::FunctionCallee LLVMFunctions::resolve_string_pool_new()
286308
{
287309
return resolveFunction("string_pool_new", llvm::FunctionType::get(m_stringPtrType->getPointerTo(), false));

src/engine/internal/llvm/llvmfunctions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class LLVMFunctions
4848
llvm::FunctionCallee resolve_llvm_random_int64();
4949
llvm::FunctionCallee resolve_llvm_random_bool();
5050
llvm::FunctionCallee resolve_llvm_get_string_array();
51+
llvm::FunctionCallee resolve_llvm_mark_thread_as_finished();
52+
llvm::FunctionCallee resolve_llvm_is_thread_finished();
5153
llvm::FunctionCallee resolve_string_pool_new();
5254
llvm::FunctionCallee resolve_string_pool_free();
5355
llvm::FunctionCallee resolve_string_alloc();

test/llvm/llvmcodebuilder_test.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3686,6 +3686,99 @@ TEST_F(LLVMCodeBuilderTest, ProcedureThreadStop_NonWarp)
36863686
ASSERT_TRUE(code->isFinished(ctx.get()));
36873687
}
36883688

3689+
TEST_F(LLVMCodeBuilderTest, ProcedureThreadStop_NonWarp_AfterYield)
3690+
{
3691+
Sprite sprite;
3692+
3693+
// Inner procedure (proc2): yields via a repeat loop, then stops the thread
3694+
// This exercises the coroutine resume path where the sentinel must propagate
3695+
BlockPrototype prototype2;
3696+
prototype2.setProcCode("proc2");
3697+
prototype2.setWarp(false);
3698+
3699+
LLVMCodeBuilder *builder = m_utils.createBuilder(&sprite, &prototype2);
3700+
CompilerValue *v = builder->addConstValue("inner_before");
3701+
builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
3702+
3703+
// This repeat loop causes the coroutine to suspend (yield) on each iteration
3704+
v = builder->addConstValue(2);
3705+
builder->beginRepeatLoop(v);
3706+
{
3707+
v = builder->addConstValue("inner_loop");
3708+
builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
3709+
}
3710+
builder->endLoop();
3711+
3712+
// After the loop completes, stop the thread
3713+
builder->createThreadStop();
3714+
3715+
// This should NOT execute
3716+
v = builder->addConstValue("inner_after");
3717+
builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
3718+
3719+
auto proc2Code = builder->build();
3720+
3721+
// Outer procedure (proc1): calls proc2
3722+
BlockPrototype prototype1;
3723+
prototype1.setProcCode("proc1");
3724+
prototype1.setWarp(false);
3725+
3726+
builder = m_utils.createBuilder(&sprite, &prototype1);
3727+
3728+
v = builder->addConstValue("outer_before");
3729+
builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
3730+
builder->createProcedureCall(&prototype2, {});
3731+
3732+
// This should NOT execute (thread was stopped by proc2)
3733+
v = builder->addConstValue("outer_after");
3734+
builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
3735+
auto proc1Code = builder->build();
3736+
3737+
// Root script: calls proc1
3738+
builder = m_utils.createBuilder(&sprite, false);
3739+
v = builder->addConstValue("script_before");
3740+
builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
3741+
builder->createProcedureCall(&prototype1, {});
3742+
3743+
// This should NOT execute (thread was stopped by proc2 via proc1)
3744+
v = builder->addConstValue("script_after");
3745+
builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
3746+
3747+
auto code = builder->build();
3748+
Script script(&sprite, nullptr, nullptr);
3749+
script.setCode(code);
3750+
3751+
Thread thread(&sprite, nullptr, &script);
3752+
auto ctx = code->createExecutionContext(&thread);
3753+
3754+
// First run: enters the repeat loop in proc2, yields after first iteration
3755+
std::string expected1 =
3756+
"script_before\n"
3757+
"outer_before\n"
3758+
"inner_before\n"
3759+
"inner_loop\n";
3760+
3761+
testing::internal::CaptureStdout();
3762+
code->run(ctx.get());
3763+
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected1);
3764+
ASSERT_FALSE(code->isFinished(ctx.get()));
3765+
3766+
// Second run: second iteration of the repeat loop, yields again
3767+
std::string expected2 = "inner_loop\n";
3768+
testing::internal::CaptureStdout();
3769+
code->run(ctx.get());
3770+
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected2);
3771+
ASSERT_FALSE(code->isFinished(ctx.get()));
3772+
3773+
// Third run: loop is done, createThreadStop() fires, sentinel must propagate
3774+
// through the coroutine resume path back to the caller
3775+
// Neither "inner_after", "outer_after", nor "script_after" should print
3776+
testing::internal::CaptureStdout();
3777+
code->run(ctx.get());
3778+
ASSERT_EQ(testing::internal::GetCapturedStdout(), "");
3779+
ASSERT_TRUE(code->isFinished(ctx.get()));
3780+
}
3781+
36893782
TEST_F(LLVMCodeBuilderTest, HatPredicates)
36903783
{
36913784
Sprite sprite;

0 commit comments

Comments
 (0)