diff --git a/be/src/util/threadpool.cpp b/be/src/util/threadpool.cpp index e5dce3adb6..ebce8e20f2 100644 --- a/be/src/util/threadpool.cpp +++ b/be/src/util/threadpool.cpp @@ -557,20 +557,26 @@ void ThreadPool::dispatch_thread() { state == ThreadPoolToken::State::QUIESCING); --token->_active_threads; --token->_num_submitted_tasks; + + // handle shutdown && idle if (token->_active_threads == 0) { if (state == ThreadPoolToken::State::QUIESCING) { DCHECK(token->_entries.empty()); token->transition(ThreadPoolToken::State::QUIESCED); } else if (token->_entries.empty()) { token->transition(ThreadPoolToken::State::IDLE); - } else if (token->mode() == ExecutionMode::SERIAL) { - _queue.emplace_back(token); - ++token->_num_submitted_tasks; - --token->_num_unsubmitted_tasks; } - } else if (token->mode() == ExecutionMode::CONCURRENT && - token->_num_submitted_tasks < token->_max_concurrency && - token->_num_unsubmitted_tasks > 0) { + } + + // We decrease _num_submitted_tasks holding lock, so the following DCHECK works. + DCHECK(token->_num_submitted_tasks < token->_max_concurrency); + + // If token->state is running and there are unsubmitted tasks in the token, we put + // the token back. + if (token->_num_unsubmitted_tasks > 0 && state == ThreadPoolToken::State::RUNNING) { + // SERIAL: if _entries is not empty, then num_unsubmitted_tasks must be greater than 0. + // CONCURRENT: we have to check _num_unsubmitted_tasks because there may be at least 2 + // threads are running for the token. _queue.emplace_back(token); ++token->_num_submitted_tasks; --token->_num_unsubmitted_tasks; diff --git a/be/test/util/threadpool_test.cpp b/be/test/util/threadpool_test.cpp index eceda73f55..67924e8711 100644 --- a/be/test/util/threadpool_test.cpp +++ b/be/test/util/threadpool_test.cpp @@ -893,4 +893,46 @@ TEST_F(ThreadPoolTest, TestThreadPoolDynamicAdjustMaximumMinimum) { EXPECT_EQ(0, _pool->num_threads()); } +TEST_F(ThreadPoolTest, TestThreadTokenSerial) { + std::unique_ptr thread_pool; + ThreadPoolBuilder("my_pool") + .set_min_threads(0) + .set_max_threads(1) + .set_max_queue_size(10) + .set_idle_timeout(std::chrono::milliseconds(2000)) + .build(&thread_pool); + + std::unique_ptr token1 = + thread_pool->new_token(ThreadPool::ExecutionMode::SERIAL, 2); + token1->submit_func(std::bind(&MyFunc, 0, 1)); + std::cout << "after submit 1" << std::endl; + token1->wait(); + ASSERT_EQ(0, token1->num_tasks()); + for (int i = 0; i < 10; i++) { + token1->submit_func(std::bind(&MyFunc, i, 1)); + } + std::cout << "after submit 1" << std::endl; + token1->wait(); + ASSERT_EQ(0, token1->num_tasks()); +} + +TEST_F(ThreadPoolTest, TestThreadTokenConcurrent) { + std::unique_ptr thread_pool; + ThreadPoolBuilder("my_pool") + .set_min_threads(0) + .set_max_threads(1) + .set_max_queue_size(10) + .set_idle_timeout(std::chrono::milliseconds(2000)) + .build(&thread_pool); + + std::unique_ptr token1 = + thread_pool->new_token(ThreadPool::ExecutionMode::CONCURRENT, 2); + for (int i = 0; i < 10; i++) { + token1->submit_func(std::bind(&MyFunc, i, 1)); + } + std::cout << "after submit 1" << std::endl; + token1->wait(); + ASSERT_EQ(0, token1->num_tasks()); +} + } // namespace doris