From b12f17ac5abbbc8970008ec44f712bcbd8667b6f Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Fri, 26 Jul 2019 14:17:15 -0700 Subject: [PATCH] Explicitly cancel and wait for RPCs to complete when deleting TFE_Context We want TFE_Context to own EagerContext and actually call its destructor in ~TFE_Context. The issue was that there can be outstanding RPCs holding references to EagerContext when ~TFE_Context runs. This change adds a method to EagerContext to close all remote clients and wait for all RPCs to complete. This should cause all remote tensor handles to drop their references to EagerContext. Recent Python changes made EagerTensors reference eager Context. This makes sure that Context outlives EagerTensors. Python interpreter will not call TFE_DeleteContext until after all local tensor handles have been deleted. With this and previous changes, there should not be any local or remote tensor handles left when TFE_Context unrefs EagerContext. This CL adds a remote test that deletes the context without waiting for RPCs to finish and deletes a test that called TFE_DeleteContext before deleting tensor handles because this order is no longer valid. Finally, the analogoues changes are done to eager ServerContext. PiperOrigin-RevId: 260211891 --- tensorflow/c/eager/c_api.h | 3 ++ tensorflow/c/eager/c_api_internal.h | 9 +++- tensorflow/c/eager/c_api_test.cc | 42 +++++++++++++------ tensorflow/c/eager/c_api_test_util.cc | 18 ++++++++ tensorflow/c/eager/c_api_test_util.h | 4 +- .../core/common_runtime/eager/context.cc | 40 ++++++++++++++++++ .../core/common_runtime/eager/context.h | 12 ++++++ .../eager/eager_service_impl.cc | 4 ++ .../eager/eager_service_impl.h | 5 ++- .../rpc/eager/grpc_eager_client.cc | 8 +++- .../core/distributed_runtime/rpc/grpc_state.h | 10 +++++ 11 files changed, 136 insertions(+), 19 deletions(-) diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 96b5bd25bcf..c408d8642b2 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -89,6 +89,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); // "Context" under which operations/functions are executed. It encapsulates // things like the available devices, resource manager etc. +// TFE_Context must outlive all tensor handles created using it. In other +// words, TFE_DeleteContext() must be called after all tensor handles have +// been deleted (with TFE_DeleteTensorHandle). // // TODO(ashankar): Merge with TF_Session? typedef struct TFE_Context TFE_Context; diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 83af62314c8..293422bc992 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -76,7 +76,14 @@ struct TFE_Context { async, device_mgr, device_mgr_owned, rendezvous, custom_kernel_creator)) {} - ~TFE_Context() { context->Unref(); } + ~TFE_Context() { + // TODO(iga): Add a separate API method to shutdown TFE_Context so that we + // don't send RPCs and block in destructor. + context->WaitForAndCloseRemoteContexts(); + // context->RefCountIsOne() should be true here. + // TODO(iga): Remove EagerContext refcounting. + context->Unref(); + } tensorflow::EagerContext* context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index e68352214f4..e089dac8c04 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -330,7 +330,7 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } -void TestRemoteExecuteDeleteTensorAfterContext(bool async) { +void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); // This server def has the task index set to 0. @@ -356,33 +356,49 @@ void TestRemoteExecuteDeleteTensorAfterContext(bool async) { TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + // Use large matrices so that RPCs don't return before we get a chance + // to call TFE_DeleteContext. + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(); const char remote_device_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; auto* h0_task1 = TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* h1_task1 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); TFE_DeleteTensorHandle(h0_task0); - - TFE_ContextAsyncWait(ctx, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContext(ctx); - - // Delete tensors after context is deleted. + TFE_DeleteTensorHandle(h1_task0); TFE_DeleteTensorHandle(h0_task1); + TFE_DeleteTensorHandle(h1_task1); + TFE_DeleteTensorHandle(retvals[0]); - TF_DeleteStatus(status); + TFE_DeleteOp(matmul); + + TFE_DeleteContext(ctx); // TODO(b/136478427): Figure out how to correctly shut the server down. worker_server.release(); } -TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) { - TestRemoteExecuteDeleteTensorAfterContext(false); +TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) { + TestRemoteExecuteDeleteContextWithOutstandingRPC(false); } -TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) { - TestRemoteExecuteDeleteTensorAfterContext(true); + +TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) { + TestRemoteExecuteDeleteContextWithOutstandingRPC(true); } void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 10d95e61451..51566b35a9f 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -85,6 +85,24 @@ TFE_TensorHandle* TestMatrixTensorHandle() { return th; } +TFE_TensorHandle* TestMatrixTensorHandle100x100() { + constexpr int64_t dims[] = {100, 100}; + constexpr int num_elements = dims[0] * dims[1]; + float data[num_elements]; + for (int i = 0; i < num_elements; ++i) { + data[i] = 1.0f; + } + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() { int64_t dims[] = {3, 2}; double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index d0c20ac3743..28062222cf0 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -16,7 +16,6 @@ limitations under the License. #define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ #include "tensorflow/c/eager/c_api.h" - #include "tensorflow/core/platform/types.h" // Return a tensor handle containing a float scalar @@ -34,6 +33,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle(); // Return a tensor handle containing a 2x2 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle(); +// Return a tensor handle containing a 100x100 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle100x100(); + // Return a tensor handle containing a 3x2 matrix of doubles TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(); diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 89d51d6fc8a..44284cd6360 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -232,9 +232,49 @@ void EagerContext::CloseRemoteContexts() { } counter.Wait(); + + remote_contexts_.clear(); } + #endif // !IS_MOBILE_PLATFORM +void EagerContext::WaitForAndCloseRemoteContexts() { + ClearCaches(); + +#if !defined(IS_MOBILE_PLATFORM) + { + mutex_lock l(keep_alive_thread_shutdown_mu_); + shutting_down_ = true; + keep_alive_thread_cv_.notify_all(); + } + keep_alive_thread_.reset(); + + mutex_lock l(remote_state_mu_); + if (!remote_contexts_.empty() && is_master_) { + CloseRemoteContexts(); + } + + default_executor_.ShutDown().IgnoreError(); + std::unordered_map executors_copy; + { + mutex_lock l(executor_map_mu_); + executors_copy = thread_local_executor_; + } + for (const auto& it : executors_copy) { + it.second->ShutDown().IgnoreError(); + } + + // This shuts down the completion queue and joins the thread polling it. + // The thread exits only after the completion queue has been drained of all + // the events. These events' completion should invoke all remaining RPC + // callbacks. + // This also deletes all EagerClient instances. There should not be any + // references to EagerClients left after all RPCs and async ops have been + // finished. + remote_eager_workers_ = nullptr; +#endif // !IS_MOBILE_PLATFORM +} + EagerContext::~EagerContext() { ClearCaches(); for (auto& entry : registered_functions_) { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 21502b6ff4a..a07e2378e9f 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -310,8 +310,20 @@ class EagerContext : public core::RefCounted { // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used // instead (which in-turn use WorkerService.RecvTensor RPCs). bool UseSendTensorRPC() { return use_send_tensor_rpc_; } + #endif // IS_MOBILE_PLATFORM + // Closes remote eager contexts, waits for all RPCs to finish, and + // destroys the EagerClientCache. No RPCs can be made through this context + // after this method has been called. + // This method exists to aid a clean shutdown. It causes all RPCs to finish + // and remote TensorHandles to release their references to this context. + // To avoid deadlocks, this method must not be called on the thread + // processing RPCs because it makes RPCs and waits for their completion. + // + // On mobile, it just cleans the caches. + void WaitForAndCloseRemoteContexts(); + bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; } tensorflow::Env* TFEnv() const { return env_; } diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 7bd1efa2c27..33714f98811 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -150,6 +150,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, remote_workers, request->context_id(), std::move(rendezvous_creator), std::move(remote_mgr)); if (!s.ok()) { + VLOG(1) << "EagerContext::InitializeRemoteWorker failed with " + << s.ToString(); delete ctx; return s; } @@ -293,6 +295,8 @@ Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, CloseContextResponse* response) { + VLOG(1) << "Executing EagerService::CloseContext for context " + << request->context_id(); ServerContext* context = nullptr; if (!GetServerContext(request->context_id(), &context).ok()) { // Swallow the error here. diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index b64c0ffb28c..c868d85d497 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ - #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h" @@ -111,7 +110,9 @@ class EagerServiceImpl { RecordAccess(); } ~ServerContext() { - + ctx_->WaitForAndCloseRemoteContexts(); + // ctx_->RefCountIsOne() should be true here. + // TODO(iga): Remove EagerContext refcounting. ctx_->Unref(); } diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index da5d43abe72..a3c9c446d0a 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -59,8 +59,12 @@ class GrpcEagerClient : public EagerClient { &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request, response, std::move(done), nullptr, nullptr); - if (enqueue_dispatchers_.find(request->context_id()) != - enqueue_dispatchers_.end()) { + VLOG(1) << "Sending RPC to close remote eager context " + << request->DebugString(); + + const auto& it = enqueue_dispatchers_.find(request->context_id()); + if (it != enqueue_dispatchers_.end()) { + it->second.CancelCall(); enqueue_dispatchers_.erase(request->context_id()); } else { LOG(ERROR) << "Remote EagerContext with id " << request->context_id() diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h index 10c9af37056..d05d38f7804 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h @@ -597,6 +597,16 @@ class StreamingRPCDispatcher { } } + // Request to cancel the current streaming call. Non-blocking. + void CancelCall() { + mutex_lock l(mu_); + if (state_ == nullptr) { + return; + } + context_->TryCancel(); + state_ = nullptr; + } + private: void CreateStreamingState() EXCLUSIVE_LOCKS_REQUIRED(mu_) { // ClientContext cannot be reused across calls.