Explicitly cancel and wait for RPCs to complete when deleting TFE_Context

We want TFE_Context to own EagerContext and actually call its destructor
in ~TFE_Context. The issue was that there can be outstanding RPCs holding
references to EagerContext when ~TFE_Context runs. This change adds a method
to EagerContext to close all remote clients and wait for all RPCs to complete.
This should cause all remote tensor handles to drop their references to
EagerContext.

Recent Python changes made EagerTensors hold a reference to the eager Context,
which ensures that the Context outlives EagerTensors. The Python interpreter will
not call TFE_DeleteContext until all local tensor handles have been deleted.

With this and previous changes, there should not be any local or remote tensor
handles left when TFE_Context unrefs EagerContext.

This CL adds a remote test that deletes the context without waiting for RPCs
to finish, and removes a test that called TFE_DeleteContext before deleting
tensor handles, since that order is no longer valid.

Finally, analogous changes are made to the eager ServerContext.

PiperOrigin-RevId: 260211891
Igor Ganichev 2019-07-26 14:17:15 -07:00 committed by TensorFlower Gardener
parent 8a5d5406f9
commit b12f17ac5a
11 changed files with 136 additions and 19 deletions

View File

@@ -89,6 +89,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
// "Context" under which operations/functions are executed. It encapsulates
// things like the available devices, resource manager etc.
// TFE_Context must outlive all tensor handles created using it. In other
// words, TFE_DeleteContext() must be called after all tensor handles have
// been deleted (with TFE_DeleteTensorHandle).
//
// TODO(ashankar): Merge with TF_Session?
typedef struct TFE_Context TFE_Context;
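
A minimal usage sketch of the contract documented above (not part of the commit):
every tensor handle is deleted before TFE_DeleteContext. The helper name
TeardownExample is illustrative and error handling is elided.

#include <cstring>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

void TeardownExample() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Build a scalar handle just so the context has something to track.
  float value = 1.0f;
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
                                   sizeof(float));
  std::memcpy(TF_TensorData(t), &value, sizeof(float));
  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
  TF_DeleteTensor(t);

  // ... run ops with TFE_Execute, possibly against remote devices ...

  // Handles first, context last. With this change, deleting the context also
  // waits for outstanding remote RPCs before releasing the EagerContext.
  TFE_DeleteTensorHandle(h);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}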

View File

@@ -76,7 +76,14 @@ struct TFE_Context {
async, device_mgr, device_mgr_owned, rendezvous,
custom_kernel_creator)) {}
~TFE_Context() { context->Unref(); }
~TFE_Context() {
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
// don't send RPCs and block in destructor.
context->WaitForAndCloseRemoteContexts();
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
context->Unref();
}
tensorflow::EagerContext* context;
};
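
For context on the RefCountIsOne()/Unref() comments: EagerContext derives from
core::RefCounted (see the class declaration later in this diff), so the final
Unref() is what actually runs its destructor. A simplified, self-contained
sketch of that semantics, not TensorFlow's implementation, showing why remote
handles must drop their references before the owner's Unref() for destruction
to happen here:

#include <cassert>
#include <iostream>

// Simplified stand-in for tensorflow::core::RefCounted.
class RefCounted {
 public:
  void Ref() { ++count_; }
  void Unref() {
    if (--count_ == 0) delete this;  // the last reference runs the destructor
  }
  bool RefCountIsOne() const { return count_ == 1; }

 protected:
  virtual ~RefCounted() = default;

 private:
  int count_ = 1;  // starts with one reference held by the creator
};

class FakeEagerContext : public RefCounted {
 public:
  ~FakeEagerContext() override { std::cout << "destructor ran\n"; }
};

int main() {
  FakeEagerContext* ctx = new FakeEagerContext;
  ctx->Ref();    // e.g. a remote tensor handle still referencing the context
  ctx->Unref();  // the handle lets go once its RPCs have completed
  assert(ctx->RefCountIsOne());
  ctx->Unref();  // the owner's Unref() now actually destroys the context
  return 0;
}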

View File

@@ -330,7 +330,7 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
@@ -356,33 +356,49 @@ void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100();
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContext(ctx);
// Delete tensors after context is deleted.
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TF_DeleteStatus(status);
TFE_DeleteOp(matmul);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
TestRemoteExecuteDeleteTensorAfterContext(false);
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
TestRemoteExecuteDeleteTensorAfterContext(true);
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,

View File

@@ -85,6 +85,24 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle100x100() {
constexpr int64_t dims[] = {100, 100};
constexpr int num_elements = dims[0] * dims[1];
float data[num_elements];
for (int i = 0; i < num_elements; ++i) {
data[i] = 1.0f;
}
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
int64_t dims[] = {3, 2};
double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

View File

@@ -16,7 +16,6 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/types.h"
// Return a tensor handle containing a float scalar
@@ -34,6 +33,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle();
// Return a tensor handle containing a 2x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle();
// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100();
// Return a tensor handle containing a 3x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();

View File

@@ -232,9 +232,49 @@ void EagerContext::CloseRemoteContexts() {
}
counter.Wait();
remote_contexts_.clear();
}
#endif // !IS_MOBILE_PLATFORM
void EagerContext::WaitForAndCloseRemoteContexts() {
ClearCaches();
#if !defined(IS_MOBILE_PLATFORM)
{
mutex_lock l(keep_alive_thread_shutdown_mu_);
shutting_down_ = true;
keep_alive_thread_cv_.notify_all();
}
keep_alive_thread_.reset();
mutex_lock l(remote_state_mu_);
if (!remote_contexts_.empty() && is_master_) {
CloseRemoteContexts();
}
default_executor_.ShutDown().IgnoreError();
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
{
mutex_lock l(executor_map_mu_);
executors_copy = thread_local_executor_;
}
for (const auto& it : executors_copy) {
it.second->ShutDown().IgnoreError();
}
// This shuts down the completion queue and joins the thread polling it.
// The thread exits only after the completion queue has been drained of all
// the events. These events' completion should invoke all remaining RPC
// callbacks.
// This also deletes all EagerClient instances. There should not be any
// references to EagerClients left after all RPCs and async ops have been
// finished.
remote_eager_workers_ = nullptr;
#endif // !IS_MOBILE_PLATFORM
}
EagerContext::~EagerContext() {
ClearCaches();
for (auto& entry : registered_functions_) {

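WaitForAndCloseRemoteContexts above snapshots the per-thread executor map while
holding executor_map_mu_ and then calls ShutDown() outside the lock, presumably
so the potentially blocking shutdowns never run under the map's mutex. A minimal
standalone sketch of that idiom with hypothetical Executor/ExecutorRegistry
types (std::mutex in place of TensorFlow's mutex):

#include <mutex>
#include <thread>
#include <unordered_map>

// Hypothetical stand-in for EagerExecutor: ShutDown() may block until all
// pending work has drained.
class Executor {
 public:
  void ShutDown() { /* wait for queued ops to finish */ }
};

class ExecutorRegistry {
 public:
  void Register(std::thread::id tid, Executor* e) {
    std::lock_guard<std::mutex> l(mu_);
    executors_[tid] = e;
  }

  // Mirrors the shape of the shutdown loop above: copy the map under the
  // lock, then run the potentially blocking ShutDown() calls outside it.
  void ShutDownAll() {
    std::unordered_map<std::thread::id, Executor*> copy;
    {
      std::lock_guard<std::mutex> l(mu_);
      copy = executors_;
    }
    for (const auto& it : copy) {
      it.second->ShutDown();
    }
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::thread::id, Executor*> executors_;
};
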
View File

@@ -310,8 +310,20 @@ class EagerContext : public core::RefCounted {
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
// instead (which in-turn use WorkerService.RecvTensor RPCs).
bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
#endif // IS_MOBILE_PLATFORM
// Closes remote eager contexts, waits for all RPCs to finish, and
// destroys the EagerClientCache. No RPCs can be made through this context
// after this method has been called.
// This method exists to aid a clean shutdown. It causes all RPCs to finish
// and remote TensorHandles to release their references to this context.
// To avoid deadlocks, this method must not be called on the thread
// processing RPCs because it makes RPCs and waits for their completion.
//
// On mobile, it just cleans the caches.
void WaitForAndCloseRemoteContexts();
bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
tensorflow::Env* TFEnv() const { return env_; }
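
The comment above warns against calling WaitForAndCloseRemoteContexts on the
thread that processes RPCs, since that thread would block waiting for
completions that only it can deliver. A self-contained toy illustration of that
kind of self-wait deadlock (not TensorFlow code); the one-second timeout exists
only so the program terminates:

#include <chrono>
#include <condition_variable>
#include <functional>
#include <future>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

// A single-threaded task loop standing in for "the thread processing RPCs".
class SingleThreadLoop {
 public:
  SingleThreadLoop() : worker_([this] { Run(); }) {}
  ~SingleThreadLoop() {
    {
      std::lock_guard<std::mutex> l(mu_);
      done_ = true;
    }
    cv_.notify_all();
    worker_.join();
  }
  void Post(std::function<void()> fn) {
    {
      std::lock_guard<std::mutex> l(mu_);
      tasks_.push(std::move(fn));
    }
    cv_.notify_all();
  }

 private:
  void Run() {
    for (;;) {
      std::function<void()> fn;
      {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return done_ || !tasks_.empty(); });
        if (tasks_.empty()) return;  // done_ set and queue drained
        fn = std::move(tasks_.front());
        tasks_.pop();
      }
      fn();
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  bool done_ = false;
  std::thread worker_;
};

int main() {
  std::promise<void> close_done;
  std::future<void> close_done_future = close_done.get_future();
  SingleThreadLoop rpc_thread;
  // The "handler" blocks waiting for a completion that only this same thread
  // could process, so it can never arrive while the handler is blocked.
  rpc_thread.Post([&] {
    rpc_thread.Post([&] { close_done.set_value(); });  // completion callback
    if (close_done_future.wait_for(std::chrono::seconds(1)) !=
        std::future_status::ready) {
      std::cout << "deadlocked: the handler thread is waiting on itself\n";
    }
  });
  return 0;  // ~SingleThreadLoop drains the queue and joins the worker
}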

View File

@@ -150,6 +150,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
remote_workers, request->context_id(), std::move(rendezvous_creator),
std::move(remote_mgr));
if (!s.ok()) {
VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
<< s.ToString();
delete ctx;
return s;
}
@@ -293,6 +295,8 @@ Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
CloseContextResponse* response) {
VLOG(1) << "Executing EagerService::CloseContext for context "
<< request->context_id();
ServerContext* context = nullptr;
if (!GetServerContext(request->context_id(), &context).ok()) {
// Swallow the error here.

View File

@@ -16,7 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
@@ -111,7 +110,9 @@ class EagerServiceImpl {
RecordAccess();
}
~ServerContext() {
ctx_->WaitForAndCloseRemoteContexts();
// ctx_->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
ctx_->Unref();
}

View File

@@ -59,8 +59,12 @@ class GrpcEagerClient : public EagerClient {
&stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
response, std::move(done), nullptr, nullptr);
if (enqueue_dispatchers_.find(request->context_id()) !=
enqueue_dispatchers_.end()) {
VLOG(1) << "Sending RPC to close remote eager context "
<< request->DebugString();
const auto& it = enqueue_dispatchers_.find(request->context_id());
if (it != enqueue_dispatchers_.end()) {
it->second.CancelCall();
enqueue_dispatchers_.erase(request->context_id());
} else {
LOG(ERROR) << "Remote EagerContext with id " << request->context_id()

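CloseContextAsync above also cancels the per-context streaming enqueue call and
drops its dispatcher when a remote context is closed. A minimal standalone
sketch of that bookkeeping, with a hypothetical Dispatcher type standing in for
StreamingRPCDispatcher:

#include <cstdint>
#include <iostream>
#include <unordered_map>

// Hypothetical stand-in for StreamingRPCDispatcher<EnqueueResponse>.
struct Dispatcher {
  void CancelCall() { std::cout << "cancel requested\n"; }  // non-blocking
};

// Cancel the streaming call for this context (if any) and forget its
// dispatcher, mirroring the shape of the logic in CloseContextAsync above.
void CancelEnqueueStreamFor(
    std::unordered_map<uint64_t, Dispatcher>& enqueue_dispatchers,
    uint64_t context_id) {
  const auto& it = enqueue_dispatchers.find(context_id);
  if (it != enqueue_dispatchers.end()) {
    it->second.CancelCall();
    enqueue_dispatchers.erase(context_id);
  } else {
    std::cout << "no streaming call in flight for context " << context_id
              << "\n";
  }
}

int main() {
  std::unordered_map<uint64_t, Dispatcher> dispatchers;
  dispatchers[42] = Dispatcher{};
  CancelEnqueueStreamFor(dispatchers, 42);  // cancels and erases
  CancelEnqueueStreamFor(dispatchers, 42);  // now a no-op with a log line
  return 0;
}
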
View File

@@ -597,6 +597,16 @@ class StreamingRPCDispatcher {
}
}
// Request to cancel the current streaming call. Non-blocking.
void CancelCall() {
mutex_lock l(mu_);
if (state_ == nullptr) {
return;
}
context_->TryCancel();
state_ = nullptr;
}
private:
void CreateStreamingState() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
// ClientContext cannot be reused across calls.