Parallel device: avoid deadlocks when the EagerContext's default executor is async
Creates one sync executor per thread. Requires fixing a tangential use-after-free where the context assumed all of the thread-local executors were still allocated at shutdown. PiperOrigin-RevId: 316783819 Change-Id: I62e7a91dcccb847d4e1c2a5f08e30c2877556618
This commit is contained in:
parent
267f956246
commit
f8657c62c6
tensorflow
c/eager
core/common_runtime/eager
python/distribute/parallel_device
@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
|
||||
TFE_DeleteCancellationManager(c_mgr);
|
||||
}
|
||||
|
||||
TEST(CAPI, ExecutorContextDestructionOrder) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
{
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
|
||||
{
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, Function_ident_CPU) {
|
||||
// First create a simple identity function.
|
||||
TF_Graph* function_graph = TF_NewGraph();
|
||||
|
@ -37,6 +37,15 @@ class StatusDeleter {
|
||||
|
||||
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
|
||||
|
||||
class ExecutorDeleter {
|
||||
public:
|
||||
void operator()(TFE_Executor* to_delete) const {
|
||||
TFE_DeleteExecutor(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Allows a single op at a time to be launched without blocking.
|
||||
@ -51,6 +60,13 @@ class DeviceThread {
|
||||
explicit DeviceThread(const std::string& device)
|
||||
: status_(TF_NewStatus()),
|
||||
device_(device),
|
||||
// If the context's default exector is set to async, re-using that in
|
||||
// each thread would cause collectives to deadlock. For consistency we
|
||||
// create a new sync executor for every thread.
|
||||
//
|
||||
// TODO(allenl): We should have an async API that works with the
|
||||
// parallel device.
|
||||
executor_(TFE_NewExecutor(/*is_async=*/false)),
|
||||
op_(nullptr),
|
||||
thread_(tensorflow::Env::Default()->StartThread(
|
||||
tensorflow::ThreadOptions(), "parallel_device_execute",
|
||||
@ -105,6 +121,7 @@ class DeviceThread {
|
||||
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
|
||||
|
||||
const std::string device_;
|
||||
ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
|
||||
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
|
||||
std::unique_ptr<Thread> thread_;
|
||||
};
|
||||
@ -186,6 +203,7 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
|
||||
std::vector<TensorHandlePtr>* outputs,
|
||||
TF_Status* status) const {
|
||||
if (op_ == nullptr) {
|
||||
TFE_ContextSetExecutorForThread(context, executor_.get());
|
||||
op_.reset(TFE_NewOp(context, operation_name, status));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
|
||||
|
@ -412,6 +412,7 @@ void TestCollective(bool async) {
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
TFE_ContextOptionsSetAsync(opts.get(), async);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
@ -423,9 +424,6 @@ void TestCollective(bool async) {
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
|
||||
TFE_NewExecutor(async), TFE_DeleteExecutor);
|
||||
TFE_ContextSetExecutorForThread(context.get(), executor.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
@ -455,8 +453,6 @@ void TestCollective(bool async) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ExpectScalarEq<float>(result_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 3.);
|
||||
// Destroying the context's default executor first isn't safe.
|
||||
context.reset();
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
|
||||
|
@ -341,7 +341,28 @@ void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
|
||||
if (executor == &default_executor_) {
|
||||
thread_local_executor_.erase(std::this_thread::get_id());
|
||||
} else {
|
||||
thread_local_executor_[std::this_thread::get_id()] = executor;
|
||||
auto thread_id = std::this_thread::get_id();
|
||||
thread_local_executor_[thread_id] = executor;
|
||||
auto& executors_with_cleanups = has_cleanup_[thread_id];
|
||||
if (executors_with_cleanups.find(executor) ==
|
||||
executors_with_cleanups.end()) {
|
||||
executors_with_cleanups.insert(executor);
|
||||
// If the executor is deleted before this context, we need to remove it
|
||||
// from the map to avoid attempting to sync it in our destructor.
|
||||
std::function<void()> cleanup([this, thread_id, executor]() {
|
||||
{
|
||||
tensorflow::mutex_lock l(executor_map_mu_);
|
||||
auto existing = thread_local_executor_.find(thread_id);
|
||||
if (existing != thread_local_executor_.end() &&
|
||||
existing->second == executor) {
|
||||
thread_local_executor_.erase(thread_id);
|
||||
}
|
||||
has_cleanup_[thread_id].erase(executor);
|
||||
}
|
||||
});
|
||||
executor->AddCleanup(reinterpret_cast<intptr_t>(this),
|
||||
std::move(cleanup));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -525,6 +546,15 @@ EagerContext::~EagerContext() {
|
||||
custom_devices_.clear();
|
||||
|
||||
ClearCachesAndThreadExecutors();
|
||||
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
|
||||
{
|
||||
mutex_lock l(executor_map_mu_);
|
||||
executors_copy = thread_local_executor_;
|
||||
}
|
||||
for (const auto& entry : executors_copy) {
|
||||
// Let the executor know that its cleanup closure is no longer valid.
|
||||
entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
|
||||
}
|
||||
for (auto& entry : registered_functions_) {
|
||||
while (!entry.second->Unref()) {
|
||||
// remove all references.
|
||||
|
@ -639,6 +639,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
// Not owned.
|
||||
std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
|
||||
TF_GUARDED_BY(executor_map_mu_);
|
||||
std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
|
||||
has_cleanup_ TF_GUARDED_BY(executor_map_mu_);
|
||||
|
||||
const bool log_memory_;
|
||||
|
||||
|
@ -46,6 +46,11 @@ EagerExecutor::~EagerExecutor() {
|
||||
tensorflow::mutex_lock l(node_queue_mutex_);
|
||||
state_ = ExecutorState::kShutDown;
|
||||
nodes_pending_.notify_all();
|
||||
for (const auto& cleanups_for_key : cleanups_) {
|
||||
for (const std::function<void()>& cleanup : cleanups_for_key.second) {
|
||||
cleanup();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status EagerExecutor::ShutDown() {
|
||||
@ -413,4 +418,10 @@ Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
|
||||
cleanups_[key].push_back(callback);
|
||||
}
|
||||
|
||||
void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -153,6 +153,13 @@ class EagerExecutor {
|
||||
|
||||
bool ok() const TF_NO_THREAD_SAFETY_ANALYSIS { return ok_; }
|
||||
|
||||
// On destruction, runs `callback`. Used by the EagerContext for clearing
|
||||
// thread-local executors.
|
||||
void AddCleanup(intptr_t key, std::function<void()> callback);
|
||||
// If `key` (e.g. a context) is destroyed before the executor, the associated
|
||||
// callbacks are no longer safe to run.
|
||||
void RemoveCleanups(intptr_t key);
|
||||
|
||||
private:
|
||||
// Possible states for this executor.
|
||||
// Executor starts in kActive state. When Shutdown() is called, Executor
|
||||
@ -250,6 +257,9 @@ class EagerExecutor {
|
||||
const eager::EagerClient* last_eager_client_;
|
||||
|
||||
const bool enable_async_wait_for_remote_function_;
|
||||
|
||||
// Callbacks to run on destruction.
|
||||
std::unordered_map<intptr_t, std::vector<std::function<void()>>> cleanups_;
|
||||
};
|
||||
|
||||
inline bool EagerExecutor::Async() const { return thread_ != nullptr; }
|
||||
|
@ -23,6 +23,7 @@ import threading
|
||||
from tensorflow.python.distribute.parallel_device import parallel_device
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.module import module
|
||||
@ -136,7 +137,7 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
self.assertIn(self.device.components[0], outputs[0].backing_device)
|
||||
self.assertIn(self.device.components[1], outputs[1].backing_device)
|
||||
|
||||
def test_collective_reduce_async(self):
|
||||
def test_collective_reduce_async_scope(self):
|
||||
# Note that ops on the parallel device currently don't execute
|
||||
# asynchronously. The test is just that we don't get deadlocks.
|
||||
with context.async_scope(), ops.device(self.device.name):
|
||||
@ -149,6 +150,27 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
self.assertIn(self.device.components[0], outputs[0].backing_device)
|
||||
self.assertIn(self.device.components[1], outputs[1].backing_device)
|
||||
|
||||
def test_collective_reduce_async_context(self):
|
||||
previous = config.get_synchronous_execution()
|
||||
try:
|
||||
context._reset_context()
|
||||
config.set_synchronous_execution(False)
|
||||
self.setUp()
|
||||
# Note that ops on the parallel device currently don't execute
|
||||
# asynchronously. The test is just that we don't get deadlocks.
|
||||
with ops.device(self.device.name):
|
||||
x = self.device.pack(
|
||||
[constant_op.constant(-1.5),
|
||||
constant_op.constant(3.5)])
|
||||
reduced = _collective_sum(x, num_replicas=2)
|
||||
outputs = self.device.unpack(reduced)
|
||||
self.assertAllClose([2., 2.], outputs)
|
||||
self.assertIn(self.device.components[0], outputs[0].backing_device)
|
||||
self.assertIn(self.device.components[1], outputs[1].backing_device)
|
||||
finally:
|
||||
context._reset_context()
|
||||
config.set_synchronous_execution(previous)
|
||||
|
||||
def test_checkpointing(self):
|
||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.device.scope():
|
||||
|
Loading…
Reference in New Issue
Block a user