Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#pragma once

#include <utility>
namespace olp {
namespace client {

Expand Down Expand Up @@ -58,9 +59,10 @@ inline void CancellationContext::CancelOperation() {
return;
}

impl_->sub_operation_cancel_token_.Cancel();
impl_->sub_operation_cancel_token_ = CancellationToken();
auto token = CancellationToken();
std::swap(token, impl_->sub_operation_cancel_token_);
impl_->is_cancelled_ = true;
token.Cancel();
}

inline bool CancellationContext::IsCancelled() const {
Expand Down
51 changes: 47 additions & 4 deletions olp-cpp-sdk-core/include/olp/core/client/TaskContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,19 @@ class CORE_API TaskContext {
*/
void SetExecutors(Exec execute_func, Callback callback,
client::CancellationContext context) {
impl_ = std::make_shared<TaskContextImpl<ExecResult>>(
std::move(execute_func), std::move(callback), std::move(context));
auto impl = std::make_shared<TaskContextImpl<ExecResult>>(
std::move(execute_func), std::move(callback), context);
std::weak_ptr<TaskContextImpl<ExecResult>> weak_impl = impl;
context.ExecuteOrCancelled(
[weak_impl]() -> CancellationToken {
return CancellationToken([weak_impl]() {
if (auto impl = weak_impl.lock()) {
impl->PreExecuteCancel();
}
});
},
[]() {});
impl_ = std::move(impl);
}

/**
Expand Down Expand Up @@ -195,8 +206,6 @@ class CORE_API TaskContext {
context_(std::move(context)),
state_{State::PENDING} {}

~TaskContextImpl() override{};

/**
* @brief Checks for the cancellation, executes the task, and calls
* the callback with the result or error.
Expand Down Expand Up @@ -249,6 +258,40 @@ class CORE_API TaskContext {
state_.store(State::COMPLETED);
}

void PreExecuteCancel() {
State expected_state = State::PENDING;

if (!state_.compare_exchange_strong(expected_state, State::IN_PROGRESS)) {
return;
}

// Moving the user callback and function guarantee that they are
// executed exactly once
ExecuteFunc function = nullptr;
UserCallback callback = nullptr;

{
std::lock_guard<std::mutex> lock(mutex_);
function = std::move(execute_func_);
callback = std::move(callback_);
}

Response user_response =
client::ApiError(client::ErrorCode::Cancelled, "Cancelled");

if (callback) {
callback(std::move(user_response));
}

// Resources need to be released before the notification, else lambas
// would have captured resources like network or `TaskScheduler`.
function = nullptr;
callback = nullptr;

condition_.Notify();
state_.store(State::COMPLETED);
}

/**
* @brief Cancels the operation and waits for the notification.
*
Expand Down
70 changes: 55 additions & 15 deletions olp-cpp-sdk-core/tests/client/TaskContextTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,18 @@ class TaskContextTestable : public TaskContext {
void SetExecutors(Exec execute_func, Callback callback,
CancellationContext context) {
auto impl = std::make_shared<TaskContextImpl<ExecResult>>(
std::move(execute_func), std::move(callback), std::move(context));
notify = [=]() { impl->condition_.Notify(); };
std::move(execute_func), std::move(callback), context);
std::weak_ptr<TaskContextImpl<ExecResult>> weak_impl = impl;
context.ExecuteOrCancelled(
[weak_impl]() -> CancellationToken {
return CancellationToken([weak_impl]() {
if (auto impl = weak_impl.lock()) {
impl->PreExecuteCancel();
}
});
},
[]() {});
notify = [impl]() { impl->condition_.Notify(); };
impl_ = impl;
}
};
Expand Down Expand Up @@ -127,28 +137,58 @@ TEST(TaskContextTest, ExecuteSimple) {
}

TEST(TaskContextTest, BlockingCancel) {
ExecuteFunc func = [&](CancellationContext c) -> Response {
EXPECT_TRUE(c.IsCancelled());
return std::string("Success");
};

Response response;

Callback callback = [&](Response r) { response = std::move(r); };

TaskContext context = TaskContext::Create(func, callback);
{
SCOPED_TRACE("Pre-exec cancellation");
bool executed = false;
ExecuteFunc func = [&](CancellationContext) -> Response {
executed = true;
return std::string("Success");
};

TaskContext context = TaskContext::Create(func, callback);
EXPECT_TRUE(context.BlockingCancel(std::chrono::seconds(0)));
EXPECT_FALSE(executed);
EXPECT_FALSE(response.IsSuccessful());
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
}

EXPECT_FALSE(context.BlockingCancel(std::chrono::seconds(0)));
{
SCOPED_TRACE("Cancel during execution");
Condition continue_execution;
Condition execution_started;
int execution_count = 0;
response = Response{};
ExecuteFunc func = [&](CancellationContext c) -> Response {
++execution_count;
execution_started.Notify();
EXPECT_TRUE(continue_execution.Wait(kWaitTime));
const auto deadline = std::chrono::steady_clock::now() + kWaitTime;
while (!c.IsCancelled() && std::chrono::steady_clock::now() < deadline) {
std::this_thread::yield();
}
EXPECT_TRUE(c.IsCancelled());
return std::string("Success");
};
TaskContext context = TaskContext::Create(func, callback);

std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); });
std::thread execute_thread([&]() { context.Execute(); });
EXPECT_TRUE(execution_started.Wait());

std::thread execute_thread([&]() { context.Execute(); });
std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); });

execute_thread.join();
cancel_thread.join();
continue_execution.Notify();

EXPECT_FALSE(response.IsSuccessful());
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
execute_thread.join();
cancel_thread.join();

EXPECT_EQ(execution_count, 1);
EXPECT_FALSE(response.IsSuccessful());
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
}
}

TEST(TaskContextTest, BlockingCancelIsWaiting) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* License-Filename: LICENSE
*/

#include <future>
#include <thread>

#include <gtest/gtest.h>
Expand Down Expand Up @@ -1138,15 +1139,20 @@ TEST(VersionedLayerClientTest, PrefetchPartitionsCancel) {
settings);
{
SCOPED_TRACE("Cancel request");
std::promise<void> block_promise;
auto block_future = block_promise.get_future();
settings.task_scheduler->ScheduleTask(
[&block_future]() { block_future.get(); });
std::promise<void> block_task;
std::promise<void> block_main;
auto block_task_future = block_task.get_future();
auto block_main_future = block_main.get_future();
settings.task_scheduler->ScheduleTask([&block_task_future, &block_main]() {
block_main.set_value();
block_task_future.get();
});
auto cancellable = client.PrefetchPartitions(request, nullptr);

// cancel the request and unblock queue
cancellable.GetCancellationToken().Cancel();
block_promise.set_value();
block_main_future.wait();
block_task.set_value();
auto future = cancellable.GetFuture();

ASSERT_EQ(future.wait_for(kTimeout), std::future_status::ready);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,10 @@ TEST(VolatileLayerClientImplTest, GetDataCancelOnClientDestroy) {
read::DataResponse data_response;
{
// Client owns the task scheduler
auto caller_thread_id = std::this_thread::get_id();
read::VolatileLayerClientImpl client(kHrn, kLayerId, std::move(settings));
client.GetData(read::DataRequest().WithPartitionId(kPartitionId),
[&](read::DataResponse response) {
data_response = std::move(response);
EXPECT_NE(caller_thread_id, std::this_thread::get_id());
});
}

Expand Down Expand Up @@ -1013,8 +1011,7 @@ TEST(VolatileLayerClientImplTest, PrefetchTilesCancelOnClientDestroy) {

read::PrefetchTilesResponse response;
{
// Client owns the task scheduler
auto caller_thread_id = std::this_thread::get_id();
// Client owns the task schedule
read::VolatileLayerClientImpl client(kHrn, kLayerId, std::move(settings));
std::vector<olp::geo::TileKey> tile_keys = {
olp::geo::TileKey::FromHereTile(kTileId)};
Expand All @@ -1023,11 +1020,10 @@ TEST(VolatileLayerClientImplTest, PrefetchTilesCancelOnClientDestroy) {
.WithMinLevel(11)
.WithMaxLevel(12);

client.PrefetchTiles(
request, [&](read::PrefetchTilesResponse prefetch_response) {
response = std::move(prefetch_response);
EXPECT_NE(caller_thread_id, std::this_thread::get_id());
});
client.PrefetchTiles(request,
[&](read::PrefetchTilesResponse prefetch_response) {
response = std::move(prefetch_response);
});
}

// Callback must be called during client destructor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,21 @@ TEST_F(VersionedLayerClientPrefetchPartitionsTest, PrefetchPartitionsCancel) {
const auto request =
read::PrefetchPartitionsRequest().WithPartitionIds(partitions);

std::promise<void> block_promise;
auto block_future = block_promise.get_future();
std::promise<void> block_task_promise;
std::promise<void> block_main_promise;
auto block_future = block_task_promise.get_future();
auto block_main_future = block_main_promise.get_future();
settings_.task_scheduler->ScheduleTask(
[&block_future]() { block_future.get(); });
[&block_future, &block_main_promise]() {
block_main_promise.set_value();
block_future.get();
});
auto cancellable = client.PrefetchPartitions(request);

// cancel the request and unblock queue
cancellable.GetCancellationToken().Cancel();
block_promise.set_value();
block_main_future.get();
block_task_promise.set_value();
auto future = cancellable.GetFuture();

ASSERT_EQ(future.wait_for(kTimeout), std::future_status::ready);
Expand Down
Loading