diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 42313f852c4..f2ec8331d8c 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -677,7 +677,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); - LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer( + LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer( std::move(server), grpc_server->worker_env()->device_mgr, grpc_server->worker_env()->collective_executor_mgr)); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 02424c1f17b..1ef6ddd25ff 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -228,7 +228,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( remote_workers, rendezvous_id, keep_alive_secs, server_def, - remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); + remote_eager_workers.get(), ctx->context->Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); @@ -247,7 +247,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - return ctx->context.InitializeRemote( + return ctx->context->InitializeRemote( std::move(server), std::move(remote_eager_workers), std::move(remote_device_mgr), remote_contexts, r, device_mgr, keep_alive_secs); @@ -350,7 +350,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - status->status = ctx->context.SetAsyncForThread(enable); + status->status = ctx->context->SetAsyncForThread(enable); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -390,15 +390,15 @@ void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); - if (ctx->context.remote_device_mgr()) { - ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); + ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response); + if (ctx->context->remote_device_mgr()) { + ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response); } return list; } void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.ClearCaches(); + status->status = ctx->context->ClearCaches(); } // Set server_def on the context, possibly updating it. @@ -424,7 +424,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - ctx->context.SetThreadLocalDevicePlacementPolicy( + ctx->context->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -434,19 +434,19 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { return static_cast( - ctx->context.GetDevicePlacementPolicy()); + ctx->context->GetDevicePlacementPolicy()); } void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.AsyncWait(); + status->status = ctx->context->AsyncWait(); } void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.GetStatus(); + status->status = ctx->context->GetStatus(); } void TFE_ContextAsyncClearError(TFE_Context* ctx) { - ctx->context.ClearAsyncError(); + ctx->context->ClearAsyncError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { @@ -606,7 +606,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, return new TFE_Op(ctx, name, false, types, new TFE_OpInferenceContext(op_def)); } - if (!ctx->context.FindFunctionByName(name)) { + if (!ctx->context->FindFunctionByName(name)) { status->status = tensorflow::errors::NotFound( "'", name, "' is neither a type of a primitive operation nor a name " @@ -904,7 +904,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, const char* device_name, TF_Status* status) { tensorflow::TensorHandle* handle; - status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, + status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context, device_name, &handle); if (status->status.ok()) { return new TFE_TensorHandle(handle); @@ -921,26 +921,26 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - status->status = ctx->context.AddFunctionDef(function_def); + status->status = ctx->context->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - status->status = ctx->context.AddFunctionDef(function->fdef); + status->status = ctx->context->AddFunctionDef(function->fdef); } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { - return ctx->context.FindFunctionDef(name) != nullptr; + return ctx->context->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(true); - ctx->context.SetShouldStoreStepStats(true); + ctx->context->SetShouldStoreGraphs(true); + ctx->context->SetShouldStoreStepStats(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(false); - ctx->context.SetShouldStoreStepStats(false); + ctx->context->SetShouldStoreGraphs(false); + ctx->context->SetShouldStoreStepStats(false); } } // extern "C" @@ -969,9 +969,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); - status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); - ctx->context.ClearRunMetadata(); + tensorflow::mutex_lock ml(*ctx->context->MetadataMu()); + status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf); + ctx->context->ClearRunMetadata(); } namespace { @@ -987,9 +987,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace -void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); } -void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); } namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index acf7e9c92d0..0c170ead40a 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -63,7 +63,7 @@ TFE_ProfilerContext* TFE_NewProfilerContext() { void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context, TFE_Context* eager_context) { - profiler_context->profiler_context.eager_context = &eager_context->context; + profiler_context->profiler_context.eager_context = eager_context->context; } void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) { @@ -77,11 +77,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) { } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(true); + ctx->context->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(false); + ctx->context->SetShouldStoreGraphs(false); } bool TFE_ProfilerClientStartTracing(const char* service_addr, diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 24d647d5212..061b0e5adcd 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -62,13 +62,16 @@ struct TFE_Context { const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, tensorflow::Rendezvous* rendezvous, const tensorflow::CustomKernelCreator* custom_kernel_creator) - : context(opts, - static_cast( - default_policy), - async, device_mgr, device_mgr_owned, rendezvous, - custom_kernel_creator) {} + : context(new tensorflow::EagerContext( + opts, + static_cast( + default_policy), + async, device_mgr, device_mgr_owned, rendezvous, + custom_kernel_creator)) {} - tensorflow::EagerContext context; + ~TFE_Context() { context->Unref(); } + + tensorflow::EagerContext* context; }; struct TFE_TensorHandle { @@ -108,7 +111,7 @@ struct TFE_Op { TFE_Op(TFE_Context* ctx, const char* op, bool is_function, const tensorflow::AttrTypeMap* t, TFE_OpInferenceContext* inference_ctx) - : operation(&ctx->context, op, is_function, t), + : operation(ctx->context, op, is_function, t), inference_ctx(inference_ctx) {} tensorflow::EagerOperation operation; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index abc733bfaa1..1d579377eb3 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -14,10 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_internal.h" #include + #include "absl/strings/match.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" @@ -297,6 +298,61 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } +void TestRemoteExecuteDeleteTensorAfterContext(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContext(ctx); + + // Delete tensors after context is deleted. + TFE_DeleteTensorHandle(h0_task1); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) { + TestRemoteExecuteDeleteTensorAfterContext(false); +} +TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) { + TestRemoteExecuteDeleteTensorAfterContext(true); +} + void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, const std::vector& expected_values) { std::unique_ptr status( diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index c961bc5eb68..24cb860147f 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -76,7 +76,7 @@ class RunMetadataListener { virtual void BeforeClearRunMetadata() = 0; }; -class EagerContext { +class EagerContext : public core::RefCounted { public: // TODO: remove this constructor once we migrate all callers to the next one. EagerContext(const SessionOptions& opts, diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index cc858a7fe9d..3e6a1c2376d 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/random/random.h" @@ -683,7 +684,9 @@ Status EagerLocalExecute(EagerOperation* op, std::function GetRemoteTensorDestructor( EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id, uint64 op_id, int output_num) { + ctx->Ref(); return [ctx, eager_client, context_id, op_id, output_num]() { + auto cleanup = gtl::MakeCleanup([ctx]() { ctx->Unref(); }); if (!ctx->HasActiveRemoteContext(context_id)) { // This means that this tensor was pointing to a remote device, which // has been changed out from under us. Simply return since there is diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 2b27ac826bd..448319fd3ed 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -104,10 +104,10 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, // Initialize remote tensor communication based on worker session. TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); - std::unique_ptr ctx(new tensorflow::EagerContext( + tensorflow::EagerContext* ctx = new tensorflow::EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - request->async(), device_mgr, false, r, nullptr)); + request->async(), device_mgr, false, r, nullptr); std::vector device_attributes; device_mgr->ListDeviceAttributes(&device_attributes); @@ -122,9 +122,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, do { context_id = random::New64(); } while (contexts_.find(context_id) != contexts_.end()); - contexts_.emplace( - context_id, - new ServerContext(std::move(ctx), request->keep_alive_secs(), env_)); + contexts_.emplace(context_id, + new ServerContext(ctx, request->keep_alive_secs(), env_)); } response->set_context_id(context_id); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index 2784c5d26e4..a33c678e789 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -104,9 +104,9 @@ class EagerServiceImpl { // and the EagerContext). class ServerContext : public core::RefCounted { public: - explicit ServerContext(std::unique_ptr ctx, + explicit ServerContext(tensorflow::EagerContext* ctx, int64 destroy_after_secs, const WorkerEnv* env) - : ctx_(std::move(ctx)), env_(env) { + : ctx_(ctx), env_(env) { destroy_after_micros_ = destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; RecordAccess(); @@ -115,9 +115,11 @@ class EagerServiceImpl { for (const auto& entry : tensors_) { entry.second->Unref(); } + + ctx_->Unref(); } - tensorflow::EagerContext* Context() const { return ctx_.get(); } + tensorflow::EagerContext* Context() const { return ctx_; } void AddOperationOutputs( const gtl::ArraySlice& handles, @@ -179,7 +181,7 @@ class EagerServiceImpl { RemoteTensorHandleInternalEquals>; // The context for this execution. - std::unique_ptr ctx_; + tensorflow::EagerContext* ctx_; // The state related to the context for this execution. mutex tensors_mu_; diff --git a/tensorflow/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc index 87f37697468..1c036c2ebd7 100644 --- a/tensorflow/lite/delegates/flex/delegate_data.cc +++ b/tensorflow/lite/delegates/flex/delegate_data.cc @@ -22,7 +22,9 @@ namespace tflite { namespace flex { DelegateData::DelegateData() {} -DelegateData::~DelegateData() {} +DelegateData::~DelegateData() { + if (eager_context_) eager_context_->Unref(); +} tensorflow::Status DelegateData::Prepare( const tensorflow::SessionOptions& session_options) { @@ -40,10 +42,10 @@ tensorflow::Status DelegateData::Prepare( // Note that Rendezvous is ref-counted so it will be automatically deleted. tensorflow::Rendezvous* rendezvous = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - eager_context_.reset(new tensorflow::EagerContext( + eager_context_ = new tensorflow::EagerContext( session_options, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - /*async=*/false, std::move(device_mgr), rendezvous)); + /*async=*/false, std::move(device_mgr), rendezvous); return tensorflow::Status(); } diff --git a/tensorflow/lite/delegates/flex/delegate_data.h b/tensorflow/lite/delegates/flex/delegate_data.h index 20d6b40a5d2..5f88cfbf444 100644 --- a/tensorflow/lite/delegates/flex/delegate_data.h +++ b/tensorflow/lite/delegates/flex/delegate_data.h @@ -39,7 +39,7 @@ class DelegateData { // The EagerContext that is required for execution of Flex Ops. // Note: The context is lazily created after the first call to |Prepare()|. - tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } + tensorflow::EagerContext* GetEagerContext() { return eager_context_; } // Map from TF Lite tensor index to TensorFlow tensor for a given context. BufferMap* GetBufferMap(const TfLiteContext* context) { @@ -48,7 +48,7 @@ class DelegateData { private: // Will be null until Prepare() is called and completes successfully. - std::unique_ptr eager_context_; + tensorflow::EagerContext* eager_context_ = nullptr; // TODO(b/112439500): Clean up stale BufferMap instances after adding the // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate. std::unordered_map buffer_map_;