Explicitly cancel and wait for RPCs to complete when deleting TFE_Context
We want TFE_Context to own EagerContext and actually call its destructor in ~TFE_Context. The issue was that there can be outstanding RPCs holding references to EagerContext when ~TFE_Context runs. This change adds a method to EagerContext to close all remote clients and wait for all RPCs to complete. This should cause all remote tensor handles to drop their references to EagerContext. Recent Python changes made EagerTensors reference eager Context. This makes sure that Context outlives EagerTensors. Python interpreter will not call TFE_DeleteContext until after all local tensor handles have been deleted. With this and previous changes, there should not be any local or remote tensor handles left when TFE_Context unrefs EagerContext. This CL adds a remote test that deletes the context without waiting for RPCs to finish and deletes a test that called TFE_DeleteContext before deleting tensor handles because this order is no longer valid. Finally, the analogoues changes are done to eager ServerContext. PiperOrigin-RevId: 260211891
This commit is contained in:
parent
8a5d5406f9
commit
b12f17ac5a
@ -89,6 +89,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
|
||||
|
||||
// "Context" under which operations/functions are executed. It encapsulates
|
||||
// things like the available devices, resource manager etc.
|
||||
// TFE_Context must outlive all tensor handles created using it. In other
|
||||
// words, TFE_DeleteContext() must be called after all tensor handles have
|
||||
// been deleted (with TFE_DeleteTensorHandle).
|
||||
//
|
||||
// TODO(ashankar): Merge with TF_Session?
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
|
@ -76,7 +76,14 @@ struct TFE_Context {
|
||||
async, device_mgr, device_mgr_owned, rendezvous,
|
||||
custom_kernel_creator)) {}
|
||||
|
||||
~TFE_Context() { context->Unref(); }
|
||||
~TFE_Context() {
|
||||
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
|
||||
// don't send RPCs and block in destructor.
|
||||
context->WaitForAndCloseRemoteContexts();
|
||||
// context->RefCountIsOne() should be true here.
|
||||
// TODO(iga): Remove EagerContext refcounting.
|
||||
context->Unref();
|
||||
}
|
||||
|
||||
tensorflow::EagerContext* context;
|
||||
};
|
||||
|
@ -330,7 +330,7 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -356,33 +356,49 @@ void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
// Use large matrices so that RPCs don't return before we get a chance
|
||||
// to call TFE_DeleteContext.
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100();
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* h1_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// Delete tensors after context is deleted.
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h0_task1);
|
||||
TFE_DeleteTensorHandle(h1_task1);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
|
||||
TestRemoteExecuteDeleteTensorAfterContext(false);
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
|
||||
TestRemoteExecuteDeleteTensorAfterContext(true);
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
|
@ -85,6 +85,24 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100() {
|
||||
constexpr int64_t dims[] = {100, 100};
|
||||
constexpr int num_elements = dims[0] * dims[1];
|
||||
float data[num_elements];
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
data[i] = 1.0f;
|
||||
}
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
|
||||
int64_t dims[] = {3, 2};
|
||||
double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
// Return a tensor handle containing a float scalar
|
||||
@ -34,6 +33,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle();
|
||||
// Return a tensor handle containing a 2x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle();
|
||||
|
||||
// Return a tensor handle containing a 100x100 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100();
|
||||
|
||||
// Return a tensor handle containing a 3x2 matrix of doubles
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
|
||||
|
||||
|
@ -232,9 +232,49 @@ void EagerContext::CloseRemoteContexts() {
|
||||
}
|
||||
|
||||
counter.Wait();
|
||||
|
||||
remote_contexts_.clear();
|
||||
}
|
||||
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
void EagerContext::WaitForAndCloseRemoteContexts() {
|
||||
ClearCaches();
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
{
|
||||
mutex_lock l(keep_alive_thread_shutdown_mu_);
|
||||
shutting_down_ = true;
|
||||
keep_alive_thread_cv_.notify_all();
|
||||
}
|
||||
keep_alive_thread_.reset();
|
||||
|
||||
mutex_lock l(remote_state_mu_);
|
||||
if (!remote_contexts_.empty() && is_master_) {
|
||||
CloseRemoteContexts();
|
||||
}
|
||||
|
||||
default_executor_.ShutDown().IgnoreError();
|
||||
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
|
||||
{
|
||||
mutex_lock l(executor_map_mu_);
|
||||
executors_copy = thread_local_executor_;
|
||||
}
|
||||
for (const auto& it : executors_copy) {
|
||||
it.second->ShutDown().IgnoreError();
|
||||
}
|
||||
|
||||
// This shuts down the completion queue and joins the thread polling it.
|
||||
// The thread exits only after the completion queue has been drained of all
|
||||
// the events. These events' completion should invoke all remaining RPC
|
||||
// callbacks.
|
||||
// This also deletes all EagerClient instances. There should not be any
|
||||
// references to EagerClients left after all RPCs and async ops have been
|
||||
// finished.
|
||||
remote_eager_workers_ = nullptr;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
EagerContext::~EagerContext() {
|
||||
ClearCaches();
|
||||
for (auto& entry : registered_functions_) {
|
||||
|
@ -310,8 +310,20 @@ class EagerContext : public core::RefCounted {
|
||||
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
|
||||
// instead (which in-turn use WorkerService.RecvTensor RPCs).
|
||||
bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
|
||||
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
|
||||
// Closes remote eager contexts, waits for all RPCs to finish, and
|
||||
// destroys the EagerClientCache. No RPCs can be made through this context
|
||||
// after this method has been called.
|
||||
// This method exists to aid a clean shutdown. It causes all RPCs to finish
|
||||
// and remote TensorHandles to release their references to this context.
|
||||
// To avoid deadlocks, this method must not be called on the thread
|
||||
// processing RPCs because it makes RPCs and waits for their completion.
|
||||
//
|
||||
// On mobile, it just cleans the caches.
|
||||
void WaitForAndCloseRemoteContexts();
|
||||
|
||||
bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
|
||||
|
||||
tensorflow::Env* TFEnv() const { return env_; }
|
||||
|
@ -150,6 +150,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
||||
remote_workers, request->context_id(), std::move(rendezvous_creator),
|
||||
std::move(remote_mgr));
|
||||
if (!s.ok()) {
|
||||
VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
|
||||
<< s.ToString();
|
||||
delete ctx;
|
||||
return s;
|
||||
}
|
||||
@ -293,6 +295,8 @@ Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
|
||||
|
||||
Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
|
||||
CloseContextResponse* response) {
|
||||
VLOG(1) << "Executing EagerService::CloseContext for context "
|
||||
<< request->context_id();
|
||||
ServerContext* context = nullptr;
|
||||
if (!GetServerContext(request->context_id(), &context).ok()) {
|
||||
// Swallow the error here.
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
|
||||
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
|
||||
@ -111,7 +110,9 @@ class EagerServiceImpl {
|
||||
RecordAccess();
|
||||
}
|
||||
~ServerContext() {
|
||||
|
||||
ctx_->WaitForAndCloseRemoteContexts();
|
||||
// ctx_->RefCountIsOne() should be true here.
|
||||
// TODO(iga): Remove EagerContext refcounting.
|
||||
ctx_->Unref();
|
||||
}
|
||||
|
||||
|
@ -59,8 +59,12 @@ class GrpcEagerClient : public EagerClient {
|
||||
&stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
|
||||
response, std::move(done), nullptr, nullptr);
|
||||
|
||||
if (enqueue_dispatchers_.find(request->context_id()) !=
|
||||
enqueue_dispatchers_.end()) {
|
||||
VLOG(1) << "Sending RPC to close remote eager context "
|
||||
<< request->DebugString();
|
||||
|
||||
const auto& it = enqueue_dispatchers_.find(request->context_id());
|
||||
if (it != enqueue_dispatchers_.end()) {
|
||||
it->second.CancelCall();
|
||||
enqueue_dispatchers_.erase(request->context_id());
|
||||
} else {
|
||||
LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
|
||||
|
@ -597,6 +597,16 @@ class StreamingRPCDispatcher {
|
||||
}
|
||||
}
|
||||
|
||||
// Request to cancel the current streaming call. Non-blocking.
|
||||
void CancelCall() {
|
||||
mutex_lock l(mu_);
|
||||
if (state_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
context_->TryCancel();
|
||||
state_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
void CreateStreamingState() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
// ClientContext cannot be reused across calls.
|
||||
|
Loading…
Reference in New Issue
Block a user