diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 05f3e851f7f..7e2a85ba6d3 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -46,6 +46,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -55,6 +56,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/context.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -1358,6 +1360,9 @@ class ExecutorState { // Clean up when this executor is done. void Finish(); + // Schedule Finish() on a separate thread if it needs to wait for deferred + // async ops to complete; otherwise run it on the current thread. + void ScheduleFinish(); // A standalone routine for this expression so that we can express // that we don't want thread safety analysis on this reference (it's @@ -1778,7 +1783,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { const bool completed = NodeDone(s, state->item->node, ready, stats, nullptr); delete state; - if (completed) Finish(); + if (completed) ScheduleFinish(); }; nodestats::SetOpStart(stats); device->ComputeAsync(async, &state->ctx, done); @@ -1865,7 +1870,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { } // while !inline_ready.empty() // This thread of computation is done if completed = true. - if (completed) Finish(); + if (completed) ScheduleFinish(); } Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, @@ -2421,6 +2426,25 @@ void ExecutorState::DumpState() { } } +void ExecutorState::ScheduleFinish() { + int num_deferred_ops; + { + mutex_lock lock(num_deferred_ops_mu_); + num_deferred_ops = num_deferred_ops_; + } + if (num_deferred_ops > 0) { + // Finish() may be blocked waiting for deferred async ops to complete. The + // execution of deferred async ops may be waiting for non-enqueued ops of + // other executors to complete. So running Finish() on the current thread + // (inter-op threadpool thread) may lead to a deadlock due to threadpool + // exhaustion. Instead, we run it on a separate thread to unblock the + // threadpool thread. + Env::Default()->SchedClosure([this]() { Finish(); }); + } else { + Finish(); + } +} + void ExecutorState::Finish() { mu_.lock(); auto status = status_;