diff --git a/be/src/exec/scan/scanner_scheduler.h b/be/src/exec/scan/scanner_scheduler.h index de1553b026ed35..3635ffc4d11826 100644 --- a/be/src/exec/scan/scanner_scheduler.h +++ b/be/src/exec/scan/scanner_scheduler.h @@ -292,13 +292,28 @@ class TaskExecutorSimplifiedScanScheduler final : public ScannerScheduler { Status submit_scan_task(SimplifiedScanTask scan_task) override { if (!_is_stop) { + if (scan_task.scanner_context == nullptr) { + return Status::InternalError("scanner pool {} got null scanner context.", + _sched_name); + } + if (scan_task.scan_task == nullptr) { + return Status::InternalError("scanner pool {} got null scan task.", + _sched_name); + } + auto task_handle = scan_task.scanner_context->task_handle(); + if (task_handle == nullptr) { + return Status::InternalError( + "scanner pool {} got null task handle, scan task first schedule: {}, " + "scanner context: {}", + _sched_name, scan_task.scan_task->is_first_schedule, + scan_task.scanner_context->debug_string()); + } std::shared_ptr split_runner; if (scan_task.scan_task->is_first_schedule) { split_runner = std::make_shared("scanner_split_runner", scan_task.scan_func); RETURN_IF_ERROR(split_runner->init()); - auto result = _task_executor->enqueue_splits( - scan_task.scanner_context->task_handle(), false, {split_runner}); + auto result = _task_executor->enqueue_splits(task_handle, false, {split_runner}); if (!result.has_value()) { LOG(WARNING) << "enqueue_splits failed: " << result.error(); return result.error(); @@ -309,8 +324,7 @@ class TaskExecutorSimplifiedScanScheduler final : public ScannerScheduler { if (split_runner == nullptr) { return Status::OK(); } - RETURN_IF_ERROR(_task_executor->re_enqueue_split( - scan_task.scanner_context->task_handle(), false, split_runner)); + RETURN_IF_ERROR(_task_executor->re_enqueue_split(task_handle, false, split_runner)); } scan_task.scan_task->split_runner = split_runner; return Status::OK(); diff --git a/be/src/exec/scan/task_executor/time_sharing/time_sharing_task_executor.cpp b/be/src/exec/scan/task_executor/time_sharing/time_sharing_task_executor.cpp index 494568785e67e7..138d2df197a731 100644 --- a/be/src/exec/scan/task_executor/time_sharing/time_sharing_task_executor.cpp +++ b/be/src/exec/scan/task_executor/time_sharing/time_sharing_task_executor.cpp @@ -48,6 +48,24 @@ extern ::doris::MetricPrototype METRIC_thread_pool_task_execution_count_total; extern ::doris::MetricPrototype METRIC_thread_pool_task_wait_worker_time_ns_total; extern ::doris::MetricPrototype METRIC_thread_pool_task_wait_worker_count_total; +namespace { + +Result> get_time_sharing_task_handle( + const std::shared_ptr& task_handle, const char* operation) { + if (task_handle == nullptr) { + return ResultError(Status::InternalError("{} got null task handle", operation)); + } + + auto handle = std::dynamic_pointer_cast(task_handle); + if (handle == nullptr) { + return ResultError(Status::InternalError("{} got invalid task handle type, task id: {}", + operation, task_handle->task_id().to_string())); + } + return handle; +} + +} // namespace + SplitThreadPoolToken::SplitThreadPoolToken(TimeSharingTaskExecutor* pool, TimeSharingTaskExecutor::ExecutionMode mode, std::shared_ptr split_queue, @@ -744,7 +762,7 @@ Status TimeSharingTaskExecutor::add_task(const TaskId& task_id, } Status TimeSharingTaskExecutor::remove_task(std::shared_ptr task_handle) { - auto handle = std::dynamic_pointer_cast(task_handle); + auto handle = DORIS_TRY(get_time_sharing_task_handle(task_handle, "remove_task")); std::vector> splits_to_destroy; { @@ -807,7 +825,11 @@ Result>> TimeSharingTaskExecutor::enque } }}; std::vector> finished_futures; - auto handle = std::dynamic_pointer_cast(task_handle); + auto handle_result = get_time_sharing_task_handle(task_handle, "enqueue_splits"); + if (!handle_result.has_value()) { + return ResultError(handle_result.error()); + } + auto handle = handle_result.value(); { std::unique_lock lock(_mutex); for (const auto& task_split : splits) { @@ -840,7 +862,7 @@ Result>> TimeSharingTaskExecutor::enque Status TimeSharingTaskExecutor::re_enqueue_split(std::shared_ptr task_handle, bool intermediate, const std::shared_ptr& split) { - auto handle = std::dynamic_pointer_cast(task_handle); + auto handle = DORIS_TRY(get_time_sharing_task_handle(task_handle, "re_enqueue_split")); std::shared_ptr prioritized_split = handle->get_split(split, intermediate); prioritized_split->reset_level_priority(); diff --git a/be/test/exec/executor/time_sharing/time_sharing_task_executor_test.cpp b/be/test/exec/executor/time_sharing/time_sharing_task_executor_test.cpp index ede32923d4177c..bfb4c76632fb5f 100644 --- a/be/test/exec/executor/time_sharing/time_sharing_task_executor_test.cpp +++ b/be/test/exec/executor/time_sharing/time_sharing_task_executor_test.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include "common/exception.h" @@ -337,6 +338,20 @@ class QueueOnlySplitRunner : public SplitRunner { std::atomic _finished {false}; }; +class TestingTaskHandle final : public TaskHandle { +public: + explicit TestingTaskHandle(std::string task_id) : _task_id(std::move(task_id)) {} + + Status init() override { return Status::OK(); } + + bool is_closed() const override { return false; } + + TaskId task_id() const override { return _task_id; } + +private: + TaskId _task_id; +}; + class TimeSharingTaskExecutorTest : public testing::Test { protected: void SetUp() override {} @@ -422,6 +437,48 @@ TEST_F(TimeSharingTaskExecutorTest, test_remove_task_clears_queued_task_count) { executor.stop(); } +TEST_F(TimeSharingTaskExecutorTest, test_invalid_task_handle_returns_error) { + auto ticker = std::make_shared(); + + TimeSharingTaskExecutor::ThreadConfig thread_config; + thread_config.thread_name = "invalid_task_handle"; + thread_config.workload_group = "normal"; + TimeSharingTaskExecutor executor(thread_config, 0, 1, 1, ticker); + ASSERT_TRUE(executor.init().ok()); + + auto split = std::make_shared(); + + auto null_enqueue_result = executor.enqueue_splits(nullptr, false, {split}); + ASSERT_FALSE(null_enqueue_result.has_value()); + EXPECT_NE(std::string(null_enqueue_result.error().msg()).find("null task handle"), + std::string::npos); + + Status null_re_enqueue_status = executor.re_enqueue_split(nullptr, false, split); + ASSERT_FALSE(null_re_enqueue_status.ok()); + EXPECT_NE(std::string(null_re_enqueue_status.msg()).find("null task handle"), + std::string::npos); + + Status null_remove_status = executor.remove_task(nullptr); + ASSERT_FALSE(null_remove_status.ok()); + EXPECT_NE(std::string(null_remove_status.msg()).find("null task handle"), std::string::npos); + + auto invalid_task_handle = std::make_shared("invalid_task"); + auto invalid_enqueue_result = executor.enqueue_splits(invalid_task_handle, false, {split}); + ASSERT_FALSE(invalid_enqueue_result.has_value()); + EXPECT_NE(std::string(invalid_enqueue_result.error().msg()).find("invalid task handle type"), + std::string::npos); + + Status invalid_re_enqueue_status = executor.re_enqueue_split(invalid_task_handle, false, split); + ASSERT_FALSE(invalid_re_enqueue_status.ok()); + EXPECT_NE(std::string(invalid_re_enqueue_status.msg()).find("invalid task handle type"), + std::string::npos); + + Status invalid_remove_status = executor.remove_task(invalid_task_handle); + ASSERT_FALSE(invalid_remove_status.ok()); + EXPECT_NE(std::string(invalid_remove_status.msg()).find("invalid task handle type"), + std::string::npos); +} + TEST_F(TimeSharingTaskExecutorTest, test_tasks_complete) { auto ticker = std::make_shared();