Parallel device: avoid deadlocks when the EagerContext's default executor is async
Creates one sync executor per thread. Requires fixing a tangential use-after-free where the context assumed all of the thread-local executors were still allocated at shutdown.

PiperOrigin-RevId: 316783819
Change-Id: I62e7a91dcccb847d4e1c2a5f08e30c2877556618
parent 267f956246
commit f8657c62c6
Changed directories:
tensorflow/c/eager
tensorflow/core/common_runtime/eager
tensorflow/python/distribute/parallel_device
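For orientation, here is a minimal sketch of the per-thread setup the commit message describes. It is not code from this commit; it only reuses TFE C API calls that appear in the diff below, and it omits error handling:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Hypothetical helper: install a private synchronous executor for the
// calling thread. Ops launched from this thread then run synchronously even
// if the context's default executor is async, which is what avoids the
// collective deadlock described above.
void RunWithPrivateSyncExecutor(TFE_Context* ctx) {
  TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
  TFE_ContextSetExecutorForThread(ctx, executor);
  // ... launch ops from this thread ...
  // With this commit, the executor and the context may be destroyed in
  // either order.
  TFE_DeleteExecutor(executor);
}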
In tensorflow/c/eager:

@@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
   TFE_DeleteCancellationManager(c_mgr);
 }
 
+TEST(CAPI, ExecutorContextDestructionOrder) {
+  TF_Status* status = TF_NewStatus();
+
+  {
+    TFE_ContextOptions* opts = TFE_NewContextOptions();
+    TFE_Context* ctx = TFE_NewContext(opts, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_DeleteContextOptions(opts);
+    TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
+    TFE_ContextSetExecutorForThread(ctx, executor);
+
+    TFE_DeleteContext(ctx);
+    TFE_DeleteExecutor(executor);
+  }
+
+  {
+    TFE_ContextOptions* opts = TFE_NewContextOptions();
+    TFE_Context* ctx = TFE_NewContext(opts, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_DeleteContextOptions(opts);
+    TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
+    TFE_ContextSetExecutorForThread(ctx, executor);
+
+    TFE_DeleteExecutor(executor);
+    TFE_DeleteContext(ctx);
+  }
+  TF_DeleteStatus(status);
+}
+
 TEST(CAPI, Function_ident_CPU) {
   // First create a simple identity function.
   TF_Graph* function_graph = TF_NewGraph();
@@ -37,6 +37,15 @@ class StatusDeleter {
 
 using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
 
+class ExecutorDeleter {
+ public:
+  void operator()(TFE_Executor* to_delete) const {
+    TFE_DeleteExecutor(to_delete);
+  }
+};
+
+using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
+
 }  // namespace
 
 // Allows a single op at a time to be launched without blocking.
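As a usage note (a hypothetical call site, not part of the diff), the alias lets an executor be owned RAII-style:

// TFE_DeleteExecutor runs automatically when `executor` leaves scope.
ExecutorPtr executor(TFE_NewExecutor(/*is_async=*/false));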
@@ -51,6 +60,13 @@ class DeviceThread {
   explicit DeviceThread(const std::string& device)
       : status_(TF_NewStatus()),
         device_(device),
+        // If the context's default executor is set to async, re-using that in
+        // each thread would cause collectives to deadlock. For consistency we
+        // create a new sync executor for every thread.
+        //
+        // TODO(allenl): We should have an async API that works with the
+        // parallel device.
+        executor_(TFE_NewExecutor(/*is_async=*/false)),
         op_(nullptr),
         thread_(tensorflow::Env::Default()->StartThread(
             tensorflow::ThreadOptions(), "parallel_device_execute",
@@ -105,6 +121,7 @@ class DeviceThread {
   StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
 
   const std::string device_;
+  ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
   mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
   std::unique_ptr<Thread> thread_;
 };
@@ -186,6 +203,7 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
                            std::vector<TensorHandlePtr>* outputs,
                            TF_Status* status) const {
   if (op_ == nullptr) {
+    TFE_ContextSetExecutorForThread(context, executor_.get());
     op_.reset(TFE_NewOp(context, operation_name, status));
     if (TF_GetCode(status) != TF_OK) return;
     TFE_OpSetDevice(op_.get(), device_.c_str(), status);
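TFE_ContextSetExecutorForThread applies to the calling thread, and Execute runs on the DeviceThread's own thread, which is presumably why the binding happens lazily here rather than in the constructor (which runs on the creating thread). A sketch of that distinction, with a hypothetical worker and no error handling:

#include <thread>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Hypothetical worker body: the executor binding must be made from the
// thread that will launch ops, not from the thread that created the worker.
void WorkerBody(TFE_Context* ctx, TFE_Executor* executor) {
  TFE_ContextSetExecutorForThread(ctx, executor);  // binds *this* thread
  // ... launch ops here ...
}

void Launch(TFE_Context* ctx) {
  TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
  std::thread worker(WorkerBody, ctx, executor);
  worker.join();
  TFE_DeleteExecutor(executor);
}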
@@ -412,6 +412,7 @@ void TestCollective(bool async) {
       TF_NewStatus(), TF_DeleteStatus);
   std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
       TFE_NewContextOptions(), TFE_DeleteContextOptions);
+  TFE_ContextOptionsSetAsync(opts.get(), async);
   std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
       TF_CreateConfig(
           /*xla*/ false,
@@ -423,9 +424,6 @@ void TestCollective(bool async) {
   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
-      TFE_NewExecutor(async), TFE_DeleteExecutor);
-  TFE_ContextSetExecutorForThread(context.get(), executor.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
   std::array<const char*, 2> underlying_devices{
@@ -455,8 +453,6 @@ void TestCollective(bool async) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
   ExpectScalarEq<float>(result_components[0].get(), 3.);
   ExpectScalarEq<float>(result_components[1].get(), 3.);
-  // Destroying the context's default executor first isn't safe.
-  context.reset();
 }
 
 TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
In tensorflow/core/common_runtime/eager:

@@ -341,7 +341,28 @@ void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
   if (executor == &default_executor_) {
     thread_local_executor_.erase(std::this_thread::get_id());
   } else {
-    thread_local_executor_[std::this_thread::get_id()] = executor;
+    auto thread_id = std::this_thread::get_id();
+    thread_local_executor_[thread_id] = executor;
+    auto& executors_with_cleanups = has_cleanup_[thread_id];
+    if (executors_with_cleanups.find(executor) ==
+        executors_with_cleanups.end()) {
+      executors_with_cleanups.insert(executor);
+      // If the executor is deleted before this context, we need to remove it
+      // from the map to avoid attempting to sync it in our destructor.
+      std::function<void()> cleanup([this, thread_id, executor]() {
+        {
+          tensorflow::mutex_lock l(executor_map_mu_);
+          auto existing = thread_local_executor_.find(thread_id);
+          if (existing != thread_local_executor_.end() &&
+              existing->second == executor) {
+            thread_local_executor_.erase(thread_id);
+          }
+          has_cleanup_[thread_id].erase(executor);
+        }
+      });
+      executor->AddCleanup(reinterpret_cast<intptr_t>(this),
+                           std::move(cleanup));
+    }
   }
 }
 
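This hunk is one half of a two-sided deregistration protocol; the EagerContext destructor hunk below is the other half. The following self-contained sketch shows the whole pattern with hypothetical Owner and Worker classes standing in for EagerContext and EagerExecutor; it is plain C++ and compiles without TensorFlow:

#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

class Worker {  // stands in for EagerExecutor
 public:
  ~Worker() {
    // Run the cleanups of every owner that is still alive (compare the
    // EagerExecutor destructor hunk below).
    for (auto& entry : cleanups_) {
      for (auto& cleanup : entry.second) cleanup();
    }
  }
  void AddCleanup(std::intptr_t key, std::function<void()> callback) {
    cleanups_[key].push_back(std::move(callback));
  }
  void RemoveCleanups(std::intptr_t key) { cleanups_.erase(key); }

 private:
  std::unordered_map<std::intptr_t, std::vector<std::function<void()>>>
      cleanups_;
};

class Owner {  // stands in for EagerContext
 public:
  ~Owner() {
    // Copy under the lock, then notify without it (mirrors ~EagerContext).
    std::unordered_set<Worker*> copy;
    {
      std::lock_guard<std::mutex> lock(mu_);
      copy = workers_;
    }
    for (Worker* worker : copy) {
      worker->RemoveCleanups(reinterpret_cast<std::intptr_t>(this));
    }
  }
  void Register(Worker* worker) {
    std::lock_guard<std::mutex> lock(mu_);
    if (workers_.insert(worker).second) {
      // If the worker dies first, its cleanup erases it from our map, so we
      // never touch a dangling pointer in ~Owner.
      worker->AddCleanup(reinterpret_cast<std::intptr_t>(this),
                         [this, worker]() {
                           std::lock_guard<std::mutex> lock(mu_);
                           workers_.erase(worker);
                         });
    }
  }

 private:
  std::mutex mu_;
  std::unordered_set<Worker*> workers_;
};

int main() {
  {
    Worker worker;
    Owner owner;  // destroyed first: deregisters its cleanup from worker
    owner.Register(&worker);
  }
  {
    Owner owner;
    auto worker = std::make_unique<Worker>();
    owner.Register(worker.get());
    worker.reset();  // destroyed first: its cleanup removes it from owner
  }
  return 0;
}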
@@ -525,6 +546,15 @@ EagerContext::~EagerContext() {
   custom_devices_.clear();
 
   ClearCachesAndThreadExecutors();
+  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
+  {
+    mutex_lock l(executor_map_mu_);
+    executors_copy = thread_local_executor_;
+  }
+  for (const auto& entry : executors_copy) {
+    // Let the executor know that its cleanup closure is no longer valid.
+    entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
+  }
   for (auto& entry : registered_functions_) {
     while (!entry.second->Unref()) {
       // remove all references.
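A likely reason for the snapshot above is to avoid holding executor_map_mu_ while calling into the executors, since the cleanup closures registered earlier take the same mutex. A generic, compilable sketch of the idiom, with hypothetical names:

#include <functional>
#include <mutex>
#include <unordered_map>

// Hypothetical registry used only to illustrate the snapshot idiom.
std::mutex map_mu;
std::unordered_map<int, std::function<void()>> callbacks;

void NotifyAll() {
  // Snapshot under the lock...
  std::unordered_map<int, std::function<void()>> copy;
  {
    std::lock_guard<std::mutex> lock(map_mu);
    copy = callbacks;
  }
  // ...then call out without it, so a callback may itself lock map_mu
  // (as the executor cleanups in the previous hunk do) without deadlocking.
  for (auto& entry : copy) entry.second();
}

int main() {
  {
    std::lock_guard<std::mutex> lock(map_mu);
    callbacks[0] = [] {
      std::lock_guard<std::mutex> inner(map_mu);  // safe: lock not held here
      callbacks.clear();
    };
  }
  NotifyAll();
  return 0;
}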
@@ -639,6 +639,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   // Not owned.
   std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
       TF_GUARDED_BY(executor_map_mu_);
+  std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
+      has_cleanup_ TF_GUARDED_BY(executor_map_mu_);
 
   const bool log_memory_;
 
@@ -46,6 +46,11 @@ EagerExecutor::~EagerExecutor() {
   tensorflow::mutex_lock l(node_queue_mutex_);
   state_ = ExecutorState::kShutDown;
   nodes_pending_.notify_all();
+  for (const auto& cleanups_for_key : cleanups_) {
+    for (const std::function<void()>& cleanup : cleanups_for_key.second) {
+      cleanup();
+    }
+  }
 }
 
 Status EagerExecutor::ShutDown() {
@@ -413,4 +418,10 @@ Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
   return Status::OK();
 }
 
+void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
+  cleanups_[key].push_back(callback);
+}
+
+void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
+
 }  // namespace tensorflow
@@ -153,6 +153,13 @@ class EagerExecutor {
 
   bool ok() const TF_NO_THREAD_SAFETY_ANALYSIS { return ok_; }
 
+  // On destruction, runs `callback`. Used by the EagerContext for clearing
+  // thread-local executors.
+  void AddCleanup(intptr_t key, std::function<void()> callback);
+  // If `key` (e.g. a context) is destroyed before the executor, the associated
+  // callbacks are no longer safe to run.
+  void RemoveCleanups(intptr_t key);
+
  private:
   // Possible states for this executor.
   // Executor starts in kActive state. When Shutdown() is called, Executor
@@ -250,6 +257,9 @@ class EagerExecutor {
   const eager::EagerClient* last_eager_client_;
 
   const bool enable_async_wait_for_remote_function_;
+
+  // Callbacks to run on destruction.
+  std::unordered_map<intptr_t, std::vector<std::function<void()>>> cleanups_;
 };
 
 inline bool EagerExecutor::Async() const { return thread_ != nullptr; }
In tensorflow/python/distribute/parallel_device:

@@ -23,6 +23,7 @@ import threading
 from tensorflow.python.distribute.parallel_device import parallel_device
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.module import module
@@ -136,7 +137,7 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
     self.assertIn(self.device.components[0], outputs[0].backing_device)
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
-  def test_collective_reduce_async(self):
+  def test_collective_reduce_async_scope(self):
    # Note that ops on the parallel device currently don't execute
    # asynchronously. The test is just that we don't get deadlocks.
    with context.async_scope(), ops.device(self.device.name):
@@ -149,6 +150,27 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
     self.assertIn(self.device.components[0], outputs[0].backing_device)
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
+  def test_collective_reduce_async_context(self):
+    previous = config.get_synchronous_execution()
+    try:
+      context._reset_context()
+      config.set_synchronous_execution(False)
+      self.setUp()
+      # Note that ops on the parallel device currently don't execute
+      # asynchronously. The test is just that we don't get deadlocks.
+      with ops.device(self.device.name):
+        x = self.device.pack(
+            [constant_op.constant(-1.5),
+             constant_op.constant(3.5)])
+        reduced = _collective_sum(x, num_replicas=2)
+        outputs = self.device.unpack(reduced)
+      self.assertAllClose([2., 2.], outputs)
+      self.assertIn(self.device.components[0], outputs[0].backing_device)
+      self.assertIn(self.device.components[1], outputs[1].backing_device)
+    finally:
+      context._reset_context()
+      config.set_synchronous_execution(previous)
+
   def test_checkpointing(self):
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     with self.device.scope():