Parallel device: avoid deadlocks when the EagerContext's default executor is async

Creates one sync executor per thread.

Requires fixing a tangential use-after-free where the context assumed all of the thread-local executors were still allocated at shutdown.
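At the C API level, the pattern this change adopts looks roughly like the sketch below: the context's default executor may be asynchronous, but each parallel-device thread installs its own synchronous executor. This is a minimal, illustrative example assuming only the TFE calls that appear in the diffs below (error checking trimmed); it is not code from the change itself.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // The context's default executor is async...
  TFE_ContextOptionsSetAsync(opts, /*enable=*/1);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  // ...but this thread installs its own sync executor; per the comment in
  // parallel_device_lib.cc below, re-using a shared async executor across
  // threads can deadlock collectives.
  TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
  TFE_ContextSetExecutorForThread(ctx, executor);
  // ... build and run ops on this thread ...
  // With the cleanup fix below, either destruction order is safe.
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}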

PiperOrigin-RevId: 316783819
Change-Id: I62e7a91dcccb847d4e1c2a5f08e30c2877556618
Allen Lavoie 2020-06-16 16:48:19 -07:00 committed by TensorFlower Gardener
parent 267f956246
commit f8657c62c6
8 changed files with 125 additions and 7 deletions

View File

@@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
TFE_DeleteCancellationManager(c_mgr);
}
TEST(CAPI, ExecutorContextDestructionOrder) {
TF_Status* status = TF_NewStatus();
{
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteContext(ctx);
TFE_DeleteExecutor(executor);
}
{
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TF_DeleteStatus(status);
}
TEST(CAPI, Function_ident_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();

View File

@@ -37,6 +37,15 @@ class StatusDeleter {
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
} // namespace
// Allows a single op at a time to be launched without blocking.
@@ -51,6 +60,13 @@ class DeviceThread {
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
// If the context's default executor is set to async, re-using that in
// each thread would cause collectives to deadlock. For consistency we
// create a new sync executor for every thread.
//
// TODO(allenl): We should have an async API that works with the
// parallel device.
executor_(TFE_NewExecutor(/*is_async=*/false)),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
@@ -105,6 +121,7 @@ class DeviceThread {
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
@@ -186,6 +203,7 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
TFE_ContextSetExecutorForThread(context, executor_.get());
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
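Since EagerContext::SetExecutorForThread (see context.cc below) keys thread-local executors by the calling thread's id, the executor has to be installed from the device thread itself, which is why the call sits inside DeviceThread::Execute. A rough, illustrative sketch of that path (hypothetical helper name; the real code reuses op_, stages inputs, and reports status back to the launching thread):

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"

// Assumed to be called on the device thread that owns `sync_executor`;
// not part of the change.
void RunOpOnThisThread(TFE_Context* context, TFE_Executor* sync_executor,
                       const char* operation_name, const char* device,
                       TF_Status* status) {
  // Associates `sync_executor` with *this* thread inside the context.
  TFE_ContextSetExecutorForThread(context, sync_executor);
  TFE_Op* op = TFE_NewOp(context, operation_name, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op, device, status);
  // ... add inputs with TFE_OpAddInput and run with TFE_Execute ...
  TFE_DeleteOp(op);
}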

View File

@@ -412,6 +412,7 @@ void TestCollective(bool async) {
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
TFE_ContextOptionsSetAsync(opts.get(), async);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
@@ -423,9 +424,6 @@ void TestCollective(bool async) {
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
TFE_NewExecutor(async), TFE_DeleteExecutor);
TFE_ContextSetExecutorForThread(context.get(), executor.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{
@@ -455,8 +453,6 @@ void TestCollective(bool async) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
// Destroying the context's default executor first isn't safe.
context.reset();
}
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }

View File

@@ -341,7 +341,28 @@ void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
if (executor == &default_executor_) {
thread_local_executor_.erase(std::this_thread::get_id());
} else {
thread_local_executor_[std::this_thread::get_id()] = executor;
auto thread_id = std::this_thread::get_id();
thread_local_executor_[thread_id] = executor;
auto& executors_with_cleanups = has_cleanup_[thread_id];
if (executors_with_cleanups.find(executor) ==
executors_with_cleanups.end()) {
executors_with_cleanups.insert(executor);
// If the executor is deleted before this context, we need to remove it
// from the map to avoid attempting to sync it in our destructor.
std::function<void()> cleanup([this, thread_id, executor]() {
{
tensorflow::mutex_lock l(executor_map_mu_);
auto existing = thread_local_executor_.find(thread_id);
if (existing != thread_local_executor_.end() &&
existing->second == executor) {
thread_local_executor_.erase(thread_id);
}
has_cleanup_[thread_id].erase(executor);
}
});
executor->AddCleanup(reinterpret_cast<intptr_t>(this),
std::move(cleanup));
}
}
}
@@ -525,6 +546,15 @@ EagerContext::~EagerContext() {
custom_devices_.clear();
ClearCachesAndThreadExecutors();
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
{
mutex_lock l(executor_map_mu_);
executors_copy = thread_local_executor_;
}
for (const auto& entry : executors_copy) {
// Let the executor know that its cleanup closure is no longer valid.
entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
}
for (auto& entry : registered_functions_) {
while (!entry.second->Unref()) {
// remove all references.
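The use-after-free fix above is a two-way handshake: the executor runs its registered cleanups when it is destroyed, and the context revokes its cleanups when it is destroyed. The stand-in classes below sketch that handshake in isolation (illustrative names, not TensorFlow types; thread-id keying and locking omitted), showing why neither destruction order touches freed memory:

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for EagerExecutor's cleanup hooks.
class MiniExecutor {
 public:
  ~MiniExecutor() {
    // Run every registered cleanup, e.g. telling contexts to forget us.
    for (const auto& entry : cleanups_) {
      for (const std::function<void()>& cleanup : entry.second) cleanup();
    }
  }
  void AddCleanup(intptr_t key, std::function<void()> callback) {
    cleanups_[key].push_back(std::move(callback));
  }
  void RemoveCleanups(intptr_t key) { cleanups_.erase(key); }

 private:
  std::unordered_map<intptr_t, std::vector<std::function<void()>>> cleanups_;
};

// Illustrative stand-in for the EagerContext side of the handshake.
class MiniContext {
 public:
  void SetExecutorForThread(MiniExecutor* executor) {
    if (!executors_.insert(executor).second) return;  // already registered
    // If the executor dies first, it erases itself from our set, so our
    // destructor never dereferences a dangling pointer.
    executor->AddCleanup(reinterpret_cast<intptr_t>(this),
                         [this, executor]() { executors_.erase(executor); });
  }
  ~MiniContext() {
    // If we die first, revoke our callbacks so the executor never runs a
    // closure that captures this destroyed context.
    for (MiniExecutor* executor : executors_) {
      executor->RemoveCleanups(reinterpret_cast<intptr_t>(this));
    }
  }

 private:
  std::unordered_set<MiniExecutor*> executors_;
};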

View File

@@ -639,6 +639,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
// Not owned.
std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
TF_GUARDED_BY(executor_map_mu_);
std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
has_cleanup_ TF_GUARDED_BY(executor_map_mu_);
const bool log_memory_;

View File

@@ -46,6 +46,11 @@ EagerExecutor::~EagerExecutor() {
tensorflow::mutex_lock l(node_queue_mutex_);
state_ = ExecutorState::kShutDown;
nodes_pending_.notify_all();
for (const auto& cleanups_for_key : cleanups_) {
for (const std::function<void()>& cleanup : cleanups_for_key.second) {
cleanup();
}
}
}
Status EagerExecutor::ShutDown() {
@@ -413,4 +418,10 @@ Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
return Status::OK();
}
void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
cleanups_[key].push_back(callback);
}
void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
} // namespace tensorflow

View File

@@ -153,6 +153,13 @@ class EagerExecutor {
bool ok() const TF_NO_THREAD_SAFETY_ANALYSIS { return ok_; }
// On destruction, runs `callback`. Used by the EagerContext for clearing
// thread-local executors.
void AddCleanup(intptr_t key, std::function<void()> callback);
// If `key` (e.g. a context) is destroyed before the executor, the associated
// callbacks are no longer safe to run.
void RemoveCleanups(intptr_t key);
private:
// Possible states for this executor.
// Executor starts in kActive state. When Shutdown() is called, Executor
@@ -250,6 +257,9 @@ class EagerExecutor {
const eager::EagerClient* last_eager_client_;
const bool enable_async_wait_for_remote_function_;
// Callbacks to run on destruction.
std::unordered_map<intptr_t, std::vector<std::function<void()>>> cleanups_;
};
inline bool EagerExecutor::Async() const { return thread_ != nullptr; }

View File

@@ -23,6 +23,7 @@ import threading
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.module import module
@@ -136,7 +137,7 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_collective_reduce_async(self):
def test_collective_reduce_async_scope(self):
# Note that ops on the parallel device currently don't execute
# asynchronously. The test is just that we don't get deadlocks.
with context.async_scope(), ops.device(self.device.name):
@@ -149,6 +150,27 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_collective_reduce_async_context(self):
previous = config.get_synchronous_execution()
try:
context._reset_context()
config.set_synchronous_execution(False)
self.setUp()
# Note that ops on the parallel device currently don't execute
# asynchronously. The test is just that we don't get deadlocks.
with ops.device(self.device.name):
x = self.device.pack(
[constant_op.constant(-1.5),
constant_op.constant(3.5)])
reduced = _collective_sum(x, num_replicas=2)
outputs = self.device.unpack(reduced)
self.assertAllClose([2., 2.], outputs)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
finally:
context._reset_context()
config.set_synchronous_execution(previous)
def test_checkpointing(self):
prefix = os.path.join(self.get_temp_dir(), "ckpt")
with self.device.scope():