Fix heap-use-after-free bug in RunHandlerPool

Here is the sequence of events which leads to heap-use-after-free:

1) The test execution finishes and the destructor of RunHandlerPool is called
2) The RunHandlerPool destructor attempts to free its std::unique_ptr<RunHandlerPool::Impl>
3) RunHandlerPool::Impl destructor frees its std::vector<RunHandler::Impl*> sorted_active_handlers_. And RunHandler::Impl frees its RunHandlerThreadPool::ThreadWorkSource tws_
4) At this point, the std::unique_ptr<RunHandlerThreadPool> run_handler_thread_pool_, which belongs to RunHandlerPool::Impl, has not been freed yet. So its threads are still running and may try to access RunHandlerThreadPool::ThreadWorkSource tws_, which have been freed in step 3).

Given the understanding of the root cause, the solution is to ensure that RunHandlerPool::Impl frees its run_handler_thread_pool_ before freeing other pointers.

PiperOrigin-RevId: 254394702
This commit is contained in:
Dong Lin 2019-06-21 07:30:55 -07:00 committed by TensorFlower Gardener
parent 847a10f32d
commit a858d3ae90
2 changed files with 8 additions and 4 deletions

View File

@ -370,6 +370,10 @@ class RunHandlerPool::Impl {
DCHECK_EQ(handlers_.size(), max_handlers_);
DCHECK_EQ(free_handlers_.size(), handlers_.size());
DCHECK_EQ(sorted_active_handlers_.size(), 0);
// Stop the threads in run_handler_thread_pool_ before freeing other
// pointers. Otherwise a thread may try to access a pointer after the
// pointer has been freed.
run_handler_thread_pool_.reset();
}
RunHandlerThreadPool* run_handler_thread_pool() {

View File

@ -37,7 +37,8 @@ TEST(RunHandlerUtilTest, TestBasicScheduling) {
int num_threads = 2;
int num_handlers = 10;
std::unique_ptr<RunHandlerPool> pool(new RunHandlerPool(num_threads));
std::unique_ptr<RunHandlerPool> pool(
new RunHandlerPool(num_threads, num_threads));
// RunHandler has 2 * num_threads (inter + intra) -
// all should be able to run concurrently.
@ -46,9 +47,8 @@ TEST(RunHandlerUtilTest, TestBasicScheduling) {
BlockingCounter counter(2 * num_handlers * num_threads);
int num_test_threads = 10;
thread::ThreadPool test_pool(Env::Default(), "test", num_test_threads);
for (int i = 0; i < 10; ++i) {
thread::ThreadPool test_pool(Env::Default(), "test", num_handlers);
for (int i = 0; i < num_handlers; ++i) {
test_pool.Schedule([&counter, &barrier1, &barrier2, &pool, i,
num_threads]() {
auto handler = pool->Get();