From eee7e2f0c572a7ccd57570d7e154e1f87e338337 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Wed, 1 May 2019 12:33:48 -0700 Subject: [PATCH] Refcount the EagerContext. Currently, the only place we increase the refcount is in remote tensor handle destructors. Unfortunately, since tensorhandle/context lifetime is tied to python objects, it did not seem possible to enforce deletion of the context to be after the deletion of all tensor handles. This also means that if you remote tensors, you will not be able to delete the EagerContext. This is mostly fine, since we don't support deletion of eager contexts in any case, and we only run into this issue during python teardown. PiperOrigin-RevId: 246178464 --- tensorflow/c/c_api_experimental.cc | 2 +- tensorflow/c/eager/c_api.cc | 52 ++++++++--------- tensorflow/c/eager/c_api_experimental.cc | 6 +- tensorflow/c/eager/c_api_internal.h | 17 +++--- tensorflow/c/eager/c_api_test.cc | 58 ++++++++++++++++++- .../core/common_runtime/eager/context.h | 2 +- .../core/common_runtime/eager/execute.cc | 3 + .../eager/eager_service_impl.cc | 9 ++- .../eager/eager_service_impl.h | 10 ++-- .../lite/delegates/flex/delegate_data.cc | 8 ++- .../lite/delegates/flex/delegate_data.h | 4 +- 11 files changed, 118 insertions(+), 53 deletions(-) diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 42313f852c4..f2ec8331d8c 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -677,7 +677,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); - LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer( + LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer( std::move(server), grpc_server->worker_env()->device_mgr, grpc_server->worker_env()->collective_executor_mgr)); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 02424c1f17b..1ef6ddd25ff 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -228,7 +228,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( remote_workers, rendezvous_id, keep_alive_secs, server_def, - remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); + remote_eager_workers.get(), ctx->context->Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); @@ -247,7 +247,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - return ctx->context.InitializeRemote( + return ctx->context->InitializeRemote( std::move(server), std::move(remote_eager_workers), std::move(remote_device_mgr), remote_contexts, r, device_mgr, keep_alive_secs); @@ -350,7 +350,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - status->status = ctx->context.SetAsyncForThread(enable); + status->status = ctx->context->SetAsyncForThread(enable); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -390,15 +390,15 @@ void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); - if (ctx->context.remote_device_mgr()) { - ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); + ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response); + if (ctx->context->remote_device_mgr()) { + ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response); } return list; } void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.ClearCaches(); + status->status = ctx->context->ClearCaches(); } // Set server_def on the context, possibly updating it. @@ -424,7 +424,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - ctx->context.SetThreadLocalDevicePlacementPolicy( + ctx->context->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -434,19 +434,19 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { return static_cast( - ctx->context.GetDevicePlacementPolicy()); + ctx->context->GetDevicePlacementPolicy()); } void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.AsyncWait(); + status->status = ctx->context->AsyncWait(); } void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.GetStatus(); + status->status = ctx->context->GetStatus(); } void TFE_ContextAsyncClearError(TFE_Context* ctx) { - ctx->context.ClearAsyncError(); + ctx->context->ClearAsyncError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { @@ -606,7 +606,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, return new TFE_Op(ctx, name, false, types, new TFE_OpInferenceContext(op_def)); } - if (!ctx->context.FindFunctionByName(name)) { + if (!ctx->context->FindFunctionByName(name)) { status->status = tensorflow::errors::NotFound( "'", name, "' is neither a type of a primitive operation nor a name " @@ -904,7 +904,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, const char* device_name, TF_Status* status) { tensorflow::TensorHandle* handle; - status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, + status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context, device_name, &handle); if (status->status.ok()) { return new TFE_TensorHandle(handle); @@ -921,26 +921,26 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - status->status = ctx->context.AddFunctionDef(function_def); + status->status = ctx->context->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - status->status = ctx->context.AddFunctionDef(function->fdef); + status->status = ctx->context->AddFunctionDef(function->fdef); } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { - return ctx->context.FindFunctionDef(name) != nullptr; + return ctx->context->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(true); - ctx->context.SetShouldStoreStepStats(true); + ctx->context->SetShouldStoreGraphs(true); + ctx->context->SetShouldStoreStepStats(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(false); - ctx->context.SetShouldStoreStepStats(false); + ctx->context->SetShouldStoreGraphs(false); + ctx->context->SetShouldStoreStepStats(false); } } // extern "C" @@ -969,9 +969,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); - status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); - ctx->context.ClearRunMetadata(); + tensorflow::mutex_lock ml(*ctx->context->MetadataMu()); + status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf); + ctx->context->ClearRunMetadata(); } namespace { @@ -987,9 +987,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace -void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); } -void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); } namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index acf7e9c92d0..0c170ead40a 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -63,7 +63,7 @@ TFE_ProfilerContext* TFE_NewProfilerContext() { void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context, TFE_Context* eager_context) { - profiler_context->profiler_context.eager_context = &eager_context->context; + profiler_context->profiler_context.eager_context = eager_context->context; } void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) { @@ -77,11 +77,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) { } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(true); + ctx->context->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(false); + ctx->context->SetShouldStoreGraphs(false); } bool TFE_ProfilerClientStartTracing(const char* service_addr, diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 24d647d5212..061b0e5adcd 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -62,13 +62,16 @@ struct TFE_Context { const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, tensorflow::Rendezvous* rendezvous, const tensorflow::CustomKernelCreator* custom_kernel_creator) - : context(opts, - static_cast( - default_policy), - async, device_mgr, device_mgr_owned, rendezvous, - custom_kernel_creator) {} + : context(new tensorflow::EagerContext( + opts, + static_cast( + default_policy), + async, device_mgr, device_mgr_owned, rendezvous, + custom_kernel_creator)) {} - tensorflow::EagerContext context; + ~TFE_Context() { context->Unref(); } + + tensorflow::EagerContext* context; }; struct TFE_TensorHandle { @@ -108,7 +111,7 @@ struct TFE_Op { TFE_Op(TFE_Context* ctx, const char* op, bool is_function, const tensorflow::AttrTypeMap* t, TFE_OpInferenceContext* inference_ctx) - : operation(&ctx->context, op, is_function, t), + : operation(ctx->context, op, is_function, t), inference_ctx(inference_ctx) {} tensorflow::EagerOperation operation; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index abc733bfaa1..1d579377eb3 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -14,10 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_internal.h" #include + #include "absl/strings/match.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" @@ -297,6 +298,61 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } +void TestRemoteExecuteDeleteTensorAfterContext(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContext(ctx); + + // Delete tensors after context is deleted. + TFE_DeleteTensorHandle(h0_task1); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) { + TestRemoteExecuteDeleteTensorAfterContext(false); +} +TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) { + TestRemoteExecuteDeleteTensorAfterContext(true); +} + void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, const std::vector& expected_values) { std::unique_ptr status( diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index c961bc5eb68..24cb860147f 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -76,7 +76,7 @@ class RunMetadataListener { virtual void BeforeClearRunMetadata() = 0; }; -class EagerContext { +class EagerContext : public core::RefCounted { public: // TODO: remove this constructor once we migrate all callers to the next one. EagerContext(const SessionOptions& opts, diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index cc858a7fe9d..3e6a1c2376d 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/random/random.h" @@ -683,7 +684,9 @@ Status EagerLocalExecute(EagerOperation* op, std::function GetRemoteTensorDestructor( EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id, uint64 op_id, int output_num) { + ctx->Ref(); return [ctx, eager_client, context_id, op_id, output_num]() { + auto cleanup = gtl::MakeCleanup([ctx]() { ctx->Unref(); }); if (!ctx->HasActiveRemoteContext(context_id)) { // This means that this tensor was pointing to a remote device, which // has been changed out from under us. Simply return since there is diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 2b27ac826bd..448319fd3ed 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -104,10 +104,10 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, // Initialize remote tensor communication based on worker session. TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); - std::unique_ptr ctx(new tensorflow::EagerContext( + tensorflow::EagerContext* ctx = new tensorflow::EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - request->async(), device_mgr, false, r, nullptr)); + request->async(), device_mgr, false, r, nullptr); std::vector device_attributes; device_mgr->ListDeviceAttributes(&device_attributes); @@ -122,9 +122,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, do { context_id = random::New64(); } while (contexts_.find(context_id) != contexts_.end()); - contexts_.emplace( - context_id, - new ServerContext(std::move(ctx), request->keep_alive_secs(), env_)); + contexts_.emplace(context_id, + new ServerContext(ctx, request->keep_alive_secs(), env_)); } response->set_context_id(context_id); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index 2784c5d26e4..a33c678e789 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -104,9 +104,9 @@ class EagerServiceImpl { // and the EagerContext). class ServerContext : public core::RefCounted { public: - explicit ServerContext(std::unique_ptr ctx, + explicit ServerContext(tensorflow::EagerContext* ctx, int64 destroy_after_secs, const WorkerEnv* env) - : ctx_(std::move(ctx)), env_(env) { + : ctx_(ctx), env_(env) { destroy_after_micros_ = destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; RecordAccess(); @@ -115,9 +115,11 @@ class EagerServiceImpl { for (const auto& entry : tensors_) { entry.second->Unref(); } + + ctx_->Unref(); } - tensorflow::EagerContext* Context() const { return ctx_.get(); } + tensorflow::EagerContext* Context() const { return ctx_; } void AddOperationOutputs( const gtl::ArraySlice& handles, @@ -179,7 +181,7 @@ class EagerServiceImpl { RemoteTensorHandleInternalEquals>; // The context for this execution. - std::unique_ptr ctx_; + tensorflow::EagerContext* ctx_; // The state related to the context for this execution. mutex tensors_mu_; diff --git a/tensorflow/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc index 87f37697468..1c036c2ebd7 100644 --- a/tensorflow/lite/delegates/flex/delegate_data.cc +++ b/tensorflow/lite/delegates/flex/delegate_data.cc @@ -22,7 +22,9 @@ namespace tflite { namespace flex { DelegateData::DelegateData() {} -DelegateData::~DelegateData() {} +DelegateData::~DelegateData() { + if (eager_context_) eager_context_->Unref(); +} tensorflow::Status DelegateData::Prepare( const tensorflow::SessionOptions& session_options) { @@ -40,10 +42,10 @@ tensorflow::Status DelegateData::Prepare( // Note that Rendezvous is ref-counted so it will be automatically deleted. tensorflow::Rendezvous* rendezvous = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - eager_context_.reset(new tensorflow::EagerContext( + eager_context_ = new tensorflow::EagerContext( session_options, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - /*async=*/false, std::move(device_mgr), rendezvous)); + /*async=*/false, std::move(device_mgr), rendezvous); return tensorflow::Status(); } diff --git a/tensorflow/lite/delegates/flex/delegate_data.h b/tensorflow/lite/delegates/flex/delegate_data.h index 20d6b40a5d2..5f88cfbf444 100644 --- a/tensorflow/lite/delegates/flex/delegate_data.h +++ b/tensorflow/lite/delegates/flex/delegate_data.h @@ -39,7 +39,7 @@ class DelegateData { // The EagerContext that is required for execution of Flex Ops. // Note: The context is lazily created after the first call to |Prepare()|. - tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } + tensorflow::EagerContext* GetEagerContext() { return eager_context_; } // Map from TF Lite tensor index to TensorFlow tensor for a given context. BufferMap* GetBufferMap(const TfLiteContext* context) { @@ -48,7 +48,7 @@ class DelegateData { private: // Will be null until Prepare() is called and completes successfully. - std::unique_ptr eager_context_; + tensorflow::EagerContext* eager_context_ = nullptr; // TODO(b/112439500): Clean up stale BufferMap instances after adding the // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate. std::unordered_map buffer_map_;