Refcount the EagerContext.

Currently, the only place we increase the refcount is in remote tensor handle
destructors.

Unfortunately, since tensorhandle/context lifetime is tied to python objects, it did not seem possible to enforce deletion of the context to be after the deletion of all tensor handles. This also means that if you remote tensors, you will not be able to delete the EagerContext. This is mostly fine, since we don't support deletion of eager contexts in any case, and we only run into this issue during python teardown.

PiperOrigin-RevId: 246178464
This commit is contained in:
Akshay Modi 2019-05-01 12:33:48 -07:00 committed by TensorFlower Gardener
parent 750ceb3938
commit eee7e2f0c5
11 changed files with 118 additions and 53 deletions

View File

@ -677,7 +677,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer(
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
std::move(server), grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));

View File

@ -228,7 +228,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, rendezvous_id, keep_alive_secs, server_def,
remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
@ -247,7 +247,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
auto* device_mgr = grpc_server->worker_env()->device_mgr;
return ctx->context.InitializeRemote(
return ctx->context->InitializeRemote(
std::move(server), std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts, r, device_mgr,
keep_alive_secs);
@ -350,7 +350,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char enable,
TF_Status* status) {
status->status = ctx->context.SetAsyncForThread(enable);
status->status = ctx->context->SetAsyncForThread(enable);
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
@ -390,15 +390,15 @@ void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context.remote_device_mgr()) {
ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context->remote_device_mgr()) {
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
}
void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) {
status->status = ctx->context.ClearCaches();
status->status = ctx->context->ClearCaches();
}
// Set server_def on the context, possibly updating it.
@ -424,7 +424,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context.SetThreadLocalDevicePlacementPolicy(
ctx->context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
@ -434,19 +434,19 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
return static_cast<TFE_ContextDevicePlacementPolicy>(
ctx->context.GetDevicePlacementPolicy());
ctx->context->GetDevicePlacementPolicy());
}
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
status->status = ctx->context.AsyncWait();
status->status = ctx->context->AsyncWait();
}
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
status->status = ctx->context.GetStatus();
status->status = ctx->context->GetStatus();
}
void TFE_ContextAsyncClearError(TFE_Context* ctx) {
ctx->context.ClearAsyncError();
ctx->context->ClearAsyncError();
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
@ -606,7 +606,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
return new TFE_Op(ctx, name, false, types,
new TFE_OpInferenceContext(op_def));
}
if (!ctx->context.FindFunctionByName(name)) {
if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
@ -904,7 +904,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
const char* device_name,
TF_Status* status) {
tensorflow::TensorHandle* handle;
status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
@ -921,26 +921,26 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
status->status = ctx->context.AddFunctionDef(function_def);
status->status = ctx->context->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
status->status = ctx->context.AddFunctionDef(function->fdef);
status->status = ctx->context->AddFunctionDef(function->fdef);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
return ctx->context.FindFunctionDef(name) != nullptr;
return ctx->context->FindFunctionDef(name) != nullptr;
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(true);
ctx->context.SetShouldStoreStepStats(true);
ctx->context->SetShouldStoreGraphs(true);
ctx->context->SetShouldStoreStepStats(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(false);
ctx->context.SetShouldStoreStepStats(false);
ctx->context->SetShouldStoreGraphs(false);
ctx->context->SetShouldStoreStepStats(false);
}
} // extern "C"
@ -969,9 +969,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
ctx->context.ClearRunMetadata();
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
ctx->context->ClearRunMetadata();
}
namespace {
@ -987,9 +987,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
}
} // namespace
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,

View File

@ -63,7 +63,7 @@ TFE_ProfilerContext* TFE_NewProfilerContext() {
void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context,
TFE_Context* eager_context) {
profiler_context->profiler_context.eager_context = &eager_context->context;
profiler_context->profiler_context.eager_context = eager_context->context;
}
void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) {
@ -77,11 +77,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) {
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(true);
ctx->context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(false);
ctx->context->SetShouldStoreGraphs(false);
}
bool TFE_ProfilerClientStartTracing(const char* service_addr,

View File

@ -62,13 +62,16 @@ struct TFE_Context {
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
async, device_mgr, device_mgr_owned, rendezvous,
custom_kernel_creator) {}
: context(new tensorflow::EagerContext(
opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
async, device_mgr, device_mgr_owned, rendezvous,
custom_kernel_creator)) {}
tensorflow::EagerContext context;
~TFE_Context() { context->Unref(); }
tensorflow::EagerContext* context;
};
struct TFE_TensorHandle {
@ -108,7 +111,7 @@ struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* inference_ctx)
: operation(&ctx->context, op, is_function, t),
: operation(ctx->context, op, is_function, t),
inference_ctx(inference_ctx) {}
tensorflow::EagerOperation operation;

View File

@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include <string.h>
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
@ -297,6 +298,61 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContext(ctx);
// Delete tensors after context is deleted.
TFE_DeleteTensorHandle(h0_task1);
TF_DeleteStatus(status);
// TODO(nareshmodi): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
TestRemoteExecuteDeleteTensorAfterContext(false);
}
TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
TestRemoteExecuteDeleteTensorAfterContext(true);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(

View File

@ -76,7 +76,7 @@ class RunMetadataListener {
virtual void BeforeClearRunMetadata() = 0;
};
class EagerContext {
class EagerContext : public core::RefCounted {
public:
// TODO: remove this constructor once we migrate all callers to the next one.
EagerContext(const SessionOptions& opts,

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/random/random.h"
@ -683,7 +684,9 @@ Status EagerLocalExecute(EagerOperation* op,
std::function<void()> GetRemoteTensorDestructor(
EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
uint64 op_id, int output_num) {
ctx->Ref();
return [ctx, eager_client, context_id, op_id, output_num]() {
auto cleanup = gtl::MakeCleanup([ctx]() { ctx->Unref(); });
if (!ctx->HasActiveRemoteContext(context_id)) {
// This means that this tensor was pointing to a remote device, which
// has been changed out from under us. Simply return since there is

View File

@ -104,10 +104,10 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
// Initialize remote tensor communication based on worker session.
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
request->async(), device_mgr, false, r, nullptr));
request->async(), device_mgr, false, r, nullptr);
std::vector<DeviceAttributes> device_attributes;
device_mgr->ListDeviceAttributes(&device_attributes);
@ -122,9 +122,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
do {
context_id = random::New64();
} while (contexts_.find(context_id) != contexts_.end());
contexts_.emplace(
context_id,
new ServerContext(std::move(ctx), request->keep_alive_secs(), env_));
contexts_.emplace(context_id,
new ServerContext(ctx, request->keep_alive_secs(), env_));
}
response->set_context_id(context_id);

View File

@ -104,9 +104,9 @@ class EagerServiceImpl {
// and the EagerContext).
class ServerContext : public core::RefCounted {
public:
explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx,
explicit ServerContext(tensorflow::EagerContext* ctx,
int64 destroy_after_secs, const WorkerEnv* env)
: ctx_(std::move(ctx)), env_(env) {
: ctx_(ctx), env_(env) {
destroy_after_micros_ =
destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
RecordAccess();
@ -115,9 +115,11 @@ class EagerServiceImpl {
for (const auto& entry : tensors_) {
entry.second->Unref();
}
ctx_->Unref();
}
tensorflow::EagerContext* Context() const { return ctx_.get(); }
tensorflow::EagerContext* Context() const { return ctx_; }
void AddOperationOutputs(
const gtl::ArraySlice<tensorflow::TensorHandle*>& handles,
@ -179,7 +181,7 @@ class EagerServiceImpl {
RemoteTensorHandleInternalEquals>;
// The context for this execution.
std::unique_ptr<tensorflow::EagerContext> ctx_;
tensorflow::EagerContext* ctx_;
// The state related to the context for this execution.
mutex tensors_mu_;

View File

@ -22,7 +22,9 @@ namespace tflite {
namespace flex {
DelegateData::DelegateData() {}
DelegateData::~DelegateData() {}
DelegateData::~DelegateData() {
if (eager_context_) eager_context_->Unref();
}
tensorflow::Status DelegateData::Prepare(
const tensorflow::SessionOptions& session_options) {
@ -40,10 +42,10 @@ tensorflow::Status DelegateData::Prepare(
// Note that Rendezvous is ref-counted so it will be automatically deleted.
tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
eager_context_.reset(new tensorflow::EagerContext(
eager_context_ = new tensorflow::EagerContext(
session_options,
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
/*async=*/false, std::move(device_mgr), rendezvous));
/*async=*/false, std::move(device_mgr), rendezvous);
return tensorflow::Status();
}

View File

@ -39,7 +39,7 @@ class DelegateData {
// The EagerContext that is required for execution of Flex Ops.
// Note: The context is lazily created after the first call to |Prepare()|.
tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
tensorflow::EagerContext* GetEagerContext() { return eager_context_; }
// Map from TF Lite tensor index to TensorFlow tensor for a given context.
BufferMap* GetBufferMap(const TfLiteContext* context) {
@ -48,7 +48,7 @@ class DelegateData {
private:
// Will be null until Prepare() is called and completes successfully.
std::unique_ptr<tensorflow::EagerContext> eager_context_;
tensorflow::EagerContext* eager_context_ = nullptr;
// TODO(b/112439500): Clean up stale BufferMap instances after adding the
// necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
std::unordered_map<const TfLiteContext*, BufferMap> buffer_map_;