From 8f23523ff80d58f336c6e5502bf3e7c2937244e5 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Mon, 30 Dec 2019 10:27:19 -0800 Subject: [PATCH] Hold a reference to ensure the eager client outlives the function data op. PiperOrigin-RevId: 287567727 Change-Id: I86f060e22ab3ef7421c34a53f40008f418b99fa0 --- .../eager/cluster_function_library_runtime.cc | 4 ++-- .../eager/cluster_function_library_runtime.h | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index 3f940284396..6f395f04290 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -131,7 +131,7 @@ void EagerClusterFunctionLibraryRuntime::Run( function_data = &function_data_[handle]; } - EagerClient* eager_client = function_data->eager_client; + EagerClient* eager_client = function_data->eager_client.get(); if (eager_client == nullptr) { done(errors::Internal("Could not find eager client")); return; @@ -195,7 +195,7 @@ void EagerClusterFunctionLibraryRuntime::CleanUp( function_data = &function_data_[handle]; } - EagerClient* eager_client = function_data->eager_client; + EagerClient* eager_client = function_data->eager_client.get(); if (eager_client == nullptr) { done(errors::Internal("Could not find eager client")); return; diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h index c5b7ada2241..3d21637225e 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h @@ -70,12 +70,16 @@ class EagerClusterFunctionLibraryRuntime struct FunctionData { const string target; - EagerClient* eager_client = nullptr; + core::RefCountPtr eager_client; std::unique_ptr op; FunctionData(const string& target, EagerClient* eager_client, std::unique_ptr op) - : target(target), eager_client(eager_client), op(std::move(op)) {} + : target(target), + eager_client(core::RefCountPtr(eager_client)), + op(std::move(op)) { + eager_client->Ref(); + } }; mutable mutex mu_;