From f8657c62c60dffe01e27f8d47028b533c0837d2c Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Tue, 16 Jun 2020 16:48:19 -0700
Subject: [PATCH] Parallel device: avoid deadlocks when the EagerContext's
 default executor is async

Creates one sync executor per thread. Requires fixing a tangential
use-after-free where the context assumed all of the thread-local executors
were still allocated at shutdown.

PiperOrigin-RevId: 316783819
Change-Id: I62e7a91dcccb847d4e1c2a5f08e30c2877556618
---
 tensorflow/c/eager/c_api_experimental_test.cc    | 29 +++++++++++++++++
 .../parallel_device/parallel_device_lib.cc       | 18 +++++++++++
 .../parallel_device/parallel_device_test.cc      |  6 +---
 .../core/common_runtime/eager/context.cc         | 32 ++++++++++++++++++-
 .../core/common_runtime/eager/context.h          |  2 ++
 .../common_runtime/eager/eager_executor.cc       | 11 +++++++
 .../common_runtime/eager/eager_executor.h        | 10 ++++++
 .../parallel_device/parallel_device_test.py      | 24 +++++++++++++-
 8 files changed, 125 insertions(+), 7 deletions(-)

diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc
index 0c058398299..a4d31417073 100644
--- a/tensorflow/c/eager/c_api_experimental_test.cc
+++ b/tensorflow/c/eager/c_api_experimental_test.cc
@@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
   TFE_DeleteCancellationManager(c_mgr);
 }
 
+TEST(CAPI, ExecutorContextDestructionOrder) {
+  TF_Status* status = TF_NewStatus();
+
+  {
+    TFE_ContextOptions* opts = TFE_NewContextOptions();
+    TFE_Context* ctx = TFE_NewContext(opts, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_DeleteContextOptions(opts);
+    TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
+    TFE_ContextSetExecutorForThread(ctx, executor);
+
+    TFE_DeleteContext(ctx);
+    TFE_DeleteExecutor(executor);
+  }
+
+  {
+    TFE_ContextOptions* opts = TFE_NewContextOptions();
+    TFE_Context* ctx = TFE_NewContext(opts, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_DeleteContextOptions(opts);
+    TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
+    TFE_ContextSetExecutorForThread(ctx, executor);
+
+    TFE_DeleteExecutor(executor);
+    TFE_DeleteContext(ctx);
+  }
+  TF_DeleteStatus(status);
+}
+
 TEST(CAPI, Function_ident_CPU) {
   // First create a simple identity function.
   TF_Graph* function_graph = TF_NewGraph();
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
index 98cd4812610..d0149b29c08 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
@@ -37,6 +37,15 @@ class StatusDeleter {
 
 using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
 
+class ExecutorDeleter {
+ public:
+  void operator()(TFE_Executor* to_delete) const {
+    TFE_DeleteExecutor(to_delete);
+  }
+};
+
+using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
+
 }  // namespace
 
 // Allows a single op at a time to be launched without blocking.
@@ -51,6 +60,13 @@ class DeviceThread {
   explicit DeviceThread(const std::string& device)
       : status_(TF_NewStatus()),
         device_(device),
+        // If the context's default executor is set to async, re-using that in
+        // each thread would cause collectives to deadlock. For consistency we
+        // create a new sync executor for every thread.
+        //
+        // TODO(allenl): We should have an async API that works with the
+        // parallel device.
+        executor_(TFE_NewExecutor(/*is_async=*/false)),
         op_(nullptr),
         thread_(tensorflow::Env::Default()->StartThread(
             tensorflow::ThreadOptions(), "parallel_device_execute",
@@ -105,6 +121,7 @@ class DeviceThread {
   StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
   const std::string device_;
+  ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
   mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
   std::unique_ptr<Thread> thread_;
 };
@@ -186,6 +203,7 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
                            std::vector<TensorHandlePtr>* outputs,
                            TF_Status* status) const {
   if (op_ == nullptr) {
+    TFE_ContextSetExecutorForThread(context, executor_.get());
     op_.reset(TFE_NewOp(context, operation_name, status));
     if (TF_GetCode(status) != TF_OK) return;
     TFE_OpSetDevice(op_.get(), device_.c_str(), status);
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
index e5412dbba61..2fa183d50f6 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -412,6 +412,7 @@ void TestCollective(bool async) {
       TF_NewStatus(), TF_DeleteStatus);
   std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
       TFE_NewContextOptions(), TFE_DeleteContextOptions);
+  TFE_ContextOptionsSetAsync(opts.get(), async);
   std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
       TF_CreateConfig(
           /*xla*/ false,
@@ -423,9 +424,6 @@
   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
-      TFE_NewExecutor(async), TFE_DeleteExecutor);
-  TFE_ContextSetExecutorForThread(context.get(), executor.get());
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
   std::array<const char*, 2> underlying_devices{
@@ -455,8 +453,6 @@
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
   ExpectScalarEq<float>(result_components[0].get(), 3.);
   ExpectScalarEq<float>(result_components[1].get(), 3.);
-  // Destroying the context's default executor first isn't safe.
-  context.reset();
 }
 
 TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 5d8cb3da6bc..970c2bcbb89 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -341,7 +341,28 @@ void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
   if (executor == &default_executor_) {
     thread_local_executor_.erase(std::this_thread::get_id());
   } else {
-    thread_local_executor_[std::this_thread::get_id()] = executor;
+    auto thread_id = std::this_thread::get_id();
+    thread_local_executor_[thread_id] = executor;
+    auto& executors_with_cleanups = has_cleanup_[thread_id];
+    if (executors_with_cleanups.find(executor) ==
+        executors_with_cleanups.end()) {
+      executors_with_cleanups.insert(executor);
+      // If the executor is deleted before this context, we need to remove it
+      // from the map to avoid attempting to sync it in our destructor.
+      std::function<void()> cleanup([this, thread_id, executor]() {
+        {
+          tensorflow::mutex_lock l(executor_map_mu_);
+          auto existing = thread_local_executor_.find(thread_id);
+          if (existing != thread_local_executor_.end() &&
+              existing->second == executor) {
+            thread_local_executor_.erase(thread_id);
+          }
+          has_cleanup_[thread_id].erase(executor);
+        }
+      });
+      executor->AddCleanup(reinterpret_cast<intptr_t>(this),
+                           std::move(cleanup));
+    }
   }
 }
 
@@ -525,6 +546,15 @@ EagerContext::~EagerContext() {
   custom_devices_.clear();
 
   ClearCachesAndThreadExecutors();
+  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
+  {
+    mutex_lock l(executor_map_mu_);
+    executors_copy = thread_local_executor_;
+  }
+  for (const auto& entry : executors_copy) {
+    // Let the executor know that its cleanup closure is no longer valid.
+    entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
+  }
   for (auto& entry : registered_functions_) {
     while (!entry.second->Unref()) {  // remove all references.
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index fa57afecbaf..cb6d09f8f1d 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -639,6 +639,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   // Not owned.
   std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
       TF_GUARDED_BY(executor_map_mu_);
+  std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
+      has_cleanup_ TF_GUARDED_BY(executor_map_mu_);
 
   const bool log_memory_;
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index ddfdabf9472..7fe321edffd 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -46,6 +46,11 @@ EagerExecutor::~EagerExecutor() {
   tensorflow::mutex_lock l(node_queue_mutex_);
   state_ = ExecutorState::kShutDown;
   nodes_pending_.notify_all();
+  for (const auto& cleanups_for_key : cleanups_) {
+    for (const std::function<void()>& cleanup : cleanups_for_key.second) {
+      cleanup();
+    }
+  }
 }
 
 Status EagerExecutor::ShutDown() {
@@ -413,4 +418,10 @@ Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
   return Status::OK();
 }
 
+void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
+  cleanups_[key].push_back(callback);
+}
+
+void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h
index aa8864c7ad6..34847abc26a 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.h
+++ b/tensorflow/core/common_runtime/eager/eager_executor.h
@@ -153,6 +153,13 @@ class EagerExecutor {
 
   bool ok() const TF_NO_THREAD_SAFETY_ANALYSIS { return ok_; }
 
+  // On destruction, runs `callback`. Used by the EagerContext for clearing
+  // thread-local executors.
+  void AddCleanup(intptr_t key, std::function<void()> callback);
+  // If `key` (e.g. a context) is destroyed before the executor, the associated
+  // callbacks are no longer safe to run.
+  void RemoveCleanups(intptr_t key);
+
  private:
   // Possible states for this executor.
   // Executor starts in kActive state. When Shutdown() is called, Executor
@@ -250,6 +257,9 @@ class EagerExecutor {
   const eager::EagerClient* last_eager_client_;
 
   const bool enable_async_wait_for_remote_function_;
+
+  // Callbacks to run on destruction.
+  std::unordered_map<intptr_t, std::vector<std::function<void()>>> cleanups_;
 };
 
 inline bool EagerExecutor::Async() const { return thread_ != nullptr; }
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
index 9dbf258f70f..8fc3dcb5816 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device_test.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
@@ -23,6 +23,7 @@ import threading
 from tensorflow.python.distribute.parallel_device import parallel_device
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.module import module
@@ -136,7 +137,7 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
     self.assertIn(self.device.components[0], outputs[0].backing_device)
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
-  def test_collective_reduce_async(self):
+  def test_collective_reduce_async_scope(self):
     # Note that ops on the parallel device currently don't execute
     # asynchronously. The test is just that we don't get deadlocks.
     with context.async_scope(), ops.device(self.device.name):
@@ -149,6 +150,27 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
     self.assertIn(self.device.components[0], outputs[0].backing_device)
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
+  def test_collective_reduce_async_context(self):
+    previous = config.get_synchronous_execution()
+    try:
+      context._reset_context()
+      config.set_synchronous_execution(False)
+      self.setUp()
+      # Note that ops on the parallel device currently don't execute
+      # asynchronously. The test is just that we don't get deadlocks.
+      with ops.device(self.device.name):
+        x = self.device.pack(
+            [constant_op.constant(-1.5),
+             constant_op.constant(3.5)])
+        reduced = _collective_sum(x, num_replicas=2)
+        outputs = self.device.unpack(reduced)
+      self.assertAllClose([2., 2.], outputs)
+      self.assertIn(self.device.components[0], outputs[0].backing_device)
+      self.assertIn(self.device.components[1], outputs[1].backing_device)
+    finally:
+      context._reset_context()
+      config.set_synchronous_execution(previous)
+
   def test_checkpointing(self):
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     with self.device.scope():
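
Illustrative sketch (hypothetical class and member names, not TensorFlow API; the mutexes and has_cleanup_ deduplication of the real code are omitted): the use-after-free fix above is a two-way handshake in which an executor runs registered cleanup callbacks when it is destroyed, while a context revokes its callbacks if it is destroyed first. Either destruction order then leaves no dangling pointer, which is what the new ExecutorContextDestructionOrder test exercises from the C API.

// Minimal sketch of the cleanup-callback pattern; names are hypothetical.
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

class Executor {
 public:
  ~Executor() {
    // Run the cleanups still registered. A context destroyed before this
    // executor has already called RemoveCleanups(), so nothing stale runs.
    for (auto& entry : cleanups_) {
      for (auto& cleanup : entry.second) cleanup();
    }
  }
  void AddCleanup(std::intptr_t key, std::function<void()> cleanup) {
    cleanups_[key].push_back(std::move(cleanup));
  }
  void RemoveCleanups(std::intptr_t key) { cleanups_.erase(key); }

 private:
  std::unordered_map<std::intptr_t, std::vector<std::function<void()>>> cleanups_;
};

class Context {
 public:
  void SetExecutorForThread(int thread_id, Executor* executor) {
    executors_[thread_id] = executor;
    // If the executor dies first, its destructor erases this map entry so the
    // context never syncs a freed executor at shutdown.
    executor->AddCleanup(reinterpret_cast<std::intptr_t>(this),
                         [this, thread_id] { executors_.erase(thread_id); });
  }
  ~Context() {
    // The context dies first: revoke its closures so executors do not run
    // them later against freed memory.
    for (auto& entry : executors_) {
      entry.second->RemoveCleanups(reinterpret_cast<std::intptr_t>(this));
    }
  }

 private:
  std::unordered_map<int, Executor*> executors_;  // not owned
};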