From 5e49b85d0c35c09afa28e8b64437de0b17445863 Mon Sep 17 00:00:00 2001 From: Mykhailo Diachenko Date: Wed, 1 Apr 2026 18:56:13 +0300 Subject: [PATCH] Preemptive cancellation of TaskContext jobs - Enable pre-execution cancellation in TaskContext by registering a cancellation token in constructor that invokes PreExecuteCancel() - Clear cancellation token in ~TaskContextImpl to prevent dangling refs - Fix CancellationContext::CancelOperation() to swap out the token before calling Cancel(), avoiding re-entrancy issues - Update BlockingCancel test to cover pre-execution cancellation Relates-To: HERESDK-12253 Signed-off-by: Mykhailo Diachenko --- .../olp/core/client/CancellationContext.inl | 6 +- .../include/olp/core/client/TaskContext.h | 51 ++++++++++++-- .../tests/client/TaskContextTest.cpp | 70 +++++++++++++++---- .../tests/VersionedLayerClientImplTest.cpp | 16 +++-- .../tests/VolatileLayerClientImplTest.cpp | 14 ++-- ...ionedLayerClientPrefetchPartitionsTest.cpp | 14 ++-- 6 files changed, 132 insertions(+), 39 deletions(-) diff --git a/olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl b/olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl index 2d58fc903..eeb39bc98 100644 --- a/olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl +++ b/olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl @@ -19,6 +19,7 @@ #pragma once +#include namespace olp { namespace client { @@ -58,9 +59,10 @@ inline void CancellationContext::CancelOperation() { return; } - impl_->sub_operation_cancel_token_.Cancel(); - impl_->sub_operation_cancel_token_ = CancellationToken(); + auto token = CancellationToken(); + std::swap(token, impl_->sub_operation_cancel_token_); impl_->is_cancelled_ = true; + token.Cancel(); } inline bool CancellationContext::IsCancelled() const { diff --git a/olp-cpp-sdk-core/include/olp/core/client/TaskContext.h b/olp-cpp-sdk-core/include/olp/core/client/TaskContext.h index e8a1f32fc..060d50e30 100644 --- a/olp-cpp-sdk-core/include/olp/core/client/TaskContext.h +++ b/olp-cpp-sdk-core/include/olp/core/client/TaskContext.h @@ -127,8 +127,19 @@ class CORE_API TaskContext { */ void SetExecutors(Exec execute_func, Callback callback, client::CancellationContext context) { - impl_ = std::make_shared>( - std::move(execute_func), std::move(callback), std::move(context)); + auto impl = std::make_shared>( + std::move(execute_func), std::move(callback), context); + std::weak_ptr> weak_impl = impl; + context.ExecuteOrCancelled( + [weak_impl]() -> CancellationToken { + return CancellationToken([weak_impl]() { + if (auto impl = weak_impl.lock()) { + impl->PreExecuteCancel(); + } + }); + }, + []() {}); + impl_ = std::move(impl); } /** @@ -195,8 +206,6 @@ class CORE_API TaskContext { context_(std::move(context)), state_{State::PENDING} {} - ~TaskContextImpl() override{}; - /** * @brief Checks for the cancellation, executes the task, and calls * the callback with the result or error. @@ -249,6 +258,40 @@ class CORE_API TaskContext { state_.store(State::COMPLETED); } + void PreExecuteCancel() { + State expected_state = State::PENDING; + + if (!state_.compare_exchange_strong(expected_state, State::IN_PROGRESS)) { + return; + } + + // Moving the user callback and function guarantee that they are + // executed exactly once + ExecuteFunc function = nullptr; + UserCallback callback = nullptr; + + { + std::lock_guard lock(mutex_); + function = std::move(execute_func_); + callback = std::move(callback_); + } + + Response user_response = + client::ApiError(client::ErrorCode::Cancelled, "Cancelled"); + + if (callback) { + callback(std::move(user_response)); + } + + // Resources need to be released before the notification, else lambas + // would have captured resources like network or `TaskScheduler`. + function = nullptr; + callback = nullptr; + + condition_.Notify(); + state_.store(State::COMPLETED); + } + /** * @brief Cancels the operation and waits for the notification. * diff --git a/olp-cpp-sdk-core/tests/client/TaskContextTest.cpp b/olp-cpp-sdk-core/tests/client/TaskContextTest.cpp index d2037e09b..485f94c40 100644 --- a/olp-cpp-sdk-core/tests/client/TaskContextTest.cpp +++ b/olp-cpp-sdk-core/tests/client/TaskContextTest.cpp @@ -60,8 +60,18 @@ class TaskContextTestable : public TaskContext { void SetExecutors(Exec execute_func, Callback callback, CancellationContext context) { auto impl = std::make_shared>( - std::move(execute_func), std::move(callback), std::move(context)); - notify = [=]() { impl->condition_.Notify(); }; + std::move(execute_func), std::move(callback), context); + std::weak_ptr> weak_impl = impl; + context.ExecuteOrCancelled( + [weak_impl]() -> CancellationToken { + return CancellationToken([weak_impl]() { + if (auto impl = weak_impl.lock()) { + impl->PreExecuteCancel(); + } + }); + }, + []() {}); + notify = [impl]() { impl->condition_.Notify(); }; impl_ = impl; } }; @@ -127,28 +137,58 @@ TEST(TaskContextTest, ExecuteSimple) { } TEST(TaskContextTest, BlockingCancel) { - ExecuteFunc func = [&](CancellationContext c) -> Response { - EXPECT_TRUE(c.IsCancelled()); - return std::string("Success"); - }; - Response response; Callback callback = [&](Response r) { response = std::move(r); }; - TaskContext context = TaskContext::Create(func, callback); + { + SCOPED_TRACE("Pre-exec cancellation"); + bool executed = false; + ExecuteFunc func = [&](CancellationContext) -> Response { + executed = true; + return std::string("Success"); + }; + + TaskContext context = TaskContext::Create(func, callback); + EXPECT_TRUE(context.BlockingCancel(std::chrono::seconds(0))); + EXPECT_FALSE(executed); + EXPECT_FALSE(response.IsSuccessful()); + EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled); + } - EXPECT_FALSE(context.BlockingCancel(std::chrono::seconds(0))); + { + SCOPED_TRACE("Cancel during execution"); + Condition continue_execution; + Condition execution_started; + int execution_count = 0; + response = Response{}; + ExecuteFunc func = [&](CancellationContext c) -> Response { + ++execution_count; + execution_started.Notify(); + EXPECT_TRUE(continue_execution.Wait(kWaitTime)); + const auto deadline = std::chrono::steady_clock::now() + kWaitTime; + while (!c.IsCancelled() && std::chrono::steady_clock::now() < deadline) { + std::this_thread::yield(); + } + EXPECT_TRUE(c.IsCancelled()); + return std::string("Success"); + }; + TaskContext context = TaskContext::Create(func, callback); - std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); }); + std::thread execute_thread([&]() { context.Execute(); }); + EXPECT_TRUE(execution_started.Wait()); - std::thread execute_thread([&]() { context.Execute(); }); + std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); }); - execute_thread.join(); - cancel_thread.join(); + continue_execution.Notify(); - EXPECT_FALSE(response.IsSuccessful()); - EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled); + execute_thread.join(); + cancel_thread.join(); + + EXPECT_EQ(execution_count, 1); + EXPECT_FALSE(response.IsSuccessful()); + EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled); + } } TEST(TaskContextTest, BlockingCancelIsWaiting) { diff --git a/olp-cpp-sdk-dataservice-read/tests/VersionedLayerClientImplTest.cpp b/olp-cpp-sdk-dataservice-read/tests/VersionedLayerClientImplTest.cpp index 7e41bc113..155dd2948 100644 --- a/olp-cpp-sdk-dataservice-read/tests/VersionedLayerClientImplTest.cpp +++ b/olp-cpp-sdk-dataservice-read/tests/VersionedLayerClientImplTest.cpp @@ -17,6 +17,7 @@ * License-Filename: LICENSE */ +#include #include #include @@ -1138,15 +1139,20 @@ TEST(VersionedLayerClientTest, PrefetchPartitionsCancel) { settings); { SCOPED_TRACE("Cancel request"); - std::promise block_promise; - auto block_future = block_promise.get_future(); - settings.task_scheduler->ScheduleTask( - [&block_future]() { block_future.get(); }); + std::promise block_task; + std::promise block_main; + auto block_task_future = block_task.get_future(); + auto block_main_future = block_main.get_future(); + settings.task_scheduler->ScheduleTask([&block_task_future, &block_main]() { + block_main.set_value(); + block_task_future.get(); + }); auto cancellable = client.PrefetchPartitions(request, nullptr); // cancel the request and unblock queue cancellable.GetCancellationToken().Cancel(); - block_promise.set_value(); + block_main_future.wait(); + block_task.set_value(); auto future = cancellable.GetFuture(); ASSERT_EQ(future.wait_for(kTimeout), std::future_status::ready); diff --git a/olp-cpp-sdk-dataservice-read/tests/VolatileLayerClientImplTest.cpp b/olp-cpp-sdk-dataservice-read/tests/VolatileLayerClientImplTest.cpp index 52d6424e5..de74d4582 100644 --- a/olp-cpp-sdk-dataservice-read/tests/VolatileLayerClientImplTest.cpp +++ b/olp-cpp-sdk-dataservice-read/tests/VolatileLayerClientImplTest.cpp @@ -357,12 +357,10 @@ TEST(VolatileLayerClientImplTest, GetDataCancelOnClientDestroy) { read::DataResponse data_response; { // Client owns the task scheduler - auto caller_thread_id = std::this_thread::get_id(); read::VolatileLayerClientImpl client(kHrn, kLayerId, std::move(settings)); client.GetData(read::DataRequest().WithPartitionId(kPartitionId), [&](read::DataResponse response) { data_response = std::move(response); - EXPECT_NE(caller_thread_id, std::this_thread::get_id()); }); } @@ -1013,8 +1011,7 @@ TEST(VolatileLayerClientImplTest, PrefetchTilesCancelOnClientDestroy) { read::PrefetchTilesResponse response; { - // Client owns the task scheduler - auto caller_thread_id = std::this_thread::get_id(); + // Client owns the task schedule read::VolatileLayerClientImpl client(kHrn, kLayerId, std::move(settings)); std::vector tile_keys = { olp::geo::TileKey::FromHereTile(kTileId)}; @@ -1023,11 +1020,10 @@ TEST(VolatileLayerClientImplTest, PrefetchTilesCancelOnClientDestroy) { .WithMinLevel(11) .WithMaxLevel(12); - client.PrefetchTiles( - request, [&](read::PrefetchTilesResponse prefetch_response) { - response = std::move(prefetch_response); - EXPECT_NE(caller_thread_id, std::this_thread::get_id()); - }); + client.PrefetchTiles(request, + [&](read::PrefetchTilesResponse prefetch_response) { + response = std::move(prefetch_response); + }); } // Callback must be called during client destructor. diff --git a/tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientPrefetchPartitionsTest.cpp b/tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientPrefetchPartitionsTest.cpp index 7ba4404c1..7ff8f558f 100644 --- a/tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientPrefetchPartitionsTest.cpp +++ b/tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientPrefetchPartitionsTest.cpp @@ -301,15 +301,21 @@ TEST_F(VersionedLayerClientPrefetchPartitionsTest, PrefetchPartitionsCancel) { const auto request = read::PrefetchPartitionsRequest().WithPartitionIds(partitions); - std::promise block_promise; - auto block_future = block_promise.get_future(); + std::promise block_task_promise; + std::promise block_main_promise; + auto block_future = block_task_promise.get_future(); + auto block_main_future = block_main_promise.get_future(); settings_.task_scheduler->ScheduleTask( - [&block_future]() { block_future.get(); }); + [&block_future, &block_main_promise]() { + block_main_promise.set_value(); + block_future.get(); + }); auto cancellable = client.PrefetchPartitions(request); // cancel the request and unblock queue cancellable.GetCancellationToken().Cancel(); - block_promise.set_value(); + block_main_future.get(); + block_task_promise.set_value(); auto future = cancellable.GetFuture(); ASSERT_EQ(future.wait_for(kTimeout), std::future_status::ready);