diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc
index fbc8e243394..86be606bbd7 100644
--- a/tensorflow/core/framework/run_handler.cc
+++ b/tensorflow/core/framework/run_handler.cc
@@ -36,7 +36,9 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+// LINT.IfChange
 static constexpr int32 kMaxConcurrentHandlers = 128;
+// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)
 
 // TODO(azaks): Refactor with thread:ThreadPool
 class RunHandlerEnvironment {
@@ -948,16 +950,19 @@ class RunHandlerPool::Impl {
     RunHandler::Impl* handler_impl;
     {
       mutex_lock l(mu_);
-      if (free_handlers_.empty()) {
+      if (!has_free_handler()) {
         profiler::TraceMe activity(
             [&] {
               return strings::StrCat("WaitingForHandler#step_id=", step_id,
                                      "#");
             },
             profiler::TraceMeLevel::kInfo);
-        if (!mu_.AwaitWithDeadline(
-                Condition(this, &Impl::has_free_handler),
-                EnvTime::NowNanos() + timeout_in_ms * 1000 * 1000)) {
+        // A timeout of 0 means "no deadline": block until a handler frees up.
+        if (timeout_in_ms == 0) {
+          mu_.Await(Condition(this, &Impl::has_free_handler));
+        } else if (!mu_.AwaitWithDeadline(
+                       Condition(this, &Impl::has_free_handler),
+                       EnvTime::NowNanos() + timeout_in_ms * 1000 * 1000)) {
           return nullptr;
         }
       }
diff --git a/tensorflow/core/framework/run_handler_test.cc b/tensorflow/core/framework/run_handler_test.cc
index 71b1fbc8d8d..8de3e3ba6bb 100644
--- a/tensorflow/core/framework/run_handler_test.cc
+++ b/tensorflow/core/framework/run_handler_test.cc
@@ -205,5 +205,41 @@ TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPool) {
   delete tp;
 }
 
+// Verifies RunHandlerPool::Get() timeout handling once the pool is exhausted:
+// a non-zero timeout gives up and returns nullptr, whereas a timeout of 0
+// blocks until another handler is released.
+TEST_F(RunHandlerTest, TestWaitTimeout) {
+  std::unique_ptr<RunHandlerPool> pool(new RunHandlerPool(1, 1));
+
+  // Exhaust the pool by acquiring every handler it can hand out concurrently.
+  std::vector<std::unique_ptr<RunHandler>> blocking_handles;
+  const int32 kMaxConcurrentHandlers = 128;  // Copied from run_handler.cc.
+  blocking_handles.reserve(kMaxConcurrentHandlers);
+  for (int i = 0; i < kMaxConcurrentHandlers; ++i) {
+    blocking_handles.push_back(pool->Get(i));
+  }
+
+  // A subsequent request with a non-zero timeout will fail by returning
+  // nullptr.
+  auto null_handle = pool->Get(kMaxConcurrentHandlers, 1);
+  EXPECT_EQ(null_handle.get(), nullptr);
+
+  // A subsequent request with no timeout (0) will succeed once one of the
+  // blocking handles is returned.
+  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);
+  // Zero-initialized so the value is never indeterminate when read.
+  std::atomic<int64> release_time{0};
+
+  tp->Schedule([&blocking_handles, &release_time]() {
+    Env::Default()->SleepForMicroseconds(5000);
+    release_time = EnvTime::NowNanos();
+    blocking_handles[0].reset();
+  });
+
+  auto next_handle = pool->Get(kMaxConcurrentHandlers + 1, 0);
+  EXPECT_GT(EnvTime::NowNanos(), release_time);
+  EXPECT_NE(next_handle.get(), nullptr);
+}
+
 }  // namespace
 }  // namespace tensorflow