From 72194852a227258c67192c04934a0f0adbde0b38 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Thu, 11 Apr 2019 10:18:49 -0700 Subject: [PATCH] Return a status from EagerClientCache::GetClient PiperOrigin-RevId: 243093267 --- tensorflow/c/eager/c_api.cc | 4 ++- .../core/common_runtime/eager/context.cc | 26 ++++++++++++++----- .../distributed_runtime/eager/eager_client.h | 2 +- .../rpc/eager/grpc_eager_client.cc | 8 +++--- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index b4e90832caf..f375a7ec3fa 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -143,7 +143,9 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); request.set_keep_alive_secs(keep_alive_secs); - auto* eager_client = remote_eager_workers->GetClient(remote_worker); + tensorflow::eager::EagerClient* eager_client; + TF_RETURN_IF_ERROR( + remote_eager_workers->GetClient(remote_worker, &eager_client)); if (eager_client == nullptr) { return tensorflow::errors::Internal( "Cannot find a client for the given target:", remote_worker); diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 97af6789116..bde58a005e7 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -175,8 +175,9 @@ void EagerContext::CloseRemoteContexts() { int i = 0; for (const auto& worker_and_context_id : remote_contexts_) { - auto* client = - remote_eager_workers_->GetClient(worker_and_context_id.first); + eager::EagerClient* client; + Status s = + remote_eager_workers_->GetClient(worker_and_context_id.first, &client); requests[i].set_context_id(worker_and_context_id.second); client->CloseContextAsync( @@ -321,8 +322,9 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { requests[i].set_context_id(target_and_context_id.second); *requests[i].mutable_function_def() = fdef; - auto* eager_client = - remote_eager_workers_->GetClient(target_and_context_id.first); + eager::EagerClient* eager_client; + TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient( + target_and_context_id.first, &eager_client)); eager_client->RegisterFunctionAsync( &requests[i], &responses[i], @@ -409,7 +411,8 @@ Status EagerContext::GetClientAndContextID(Device* device, string device_task_name; TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name)); - *client = remote_eager_workers_->GetClient(device_task_name); + TF_RETURN_IF_ERROR( + remote_eager_workers_->GetClient(device_task_name, client)); if (*client == nullptr) { return errors::InvalidArgument( @@ -537,8 +540,17 @@ Status EagerContext::InitializeRemote( if (keep_alive_secs_ > 0) { { for (const auto& worker_and_context_id : remote_contexts_) { - auto* client = remote_eager_workers_->GetClient( - worker_and_context_id.first); + eager::EagerClient* client; + Status s = remote_eager_workers_->GetClient( + worker_and_context_id.first, &client); + + if (!s.ok()) { + LOG(WARNING) << "Keep-alive thread was unable to find " + "a client for target " + << worker_and_context_id.first + << ". Got error: " << s; + continue; + } eager::KeepAliveRequest* request = new eager::KeepAliveRequest; diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h index 707f3234b97..3d1cb7af2e0 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_client.h +++ b/tensorflow/core/distributed_runtime/eager/eager_client.h @@ -48,7 +48,7 @@ class EagerClient { class EagerClientCache { public: virtual ~EagerClientCache() {} - virtual EagerClient* GetClient(const string& target) = 0; + virtual Status GetClient(const string& target, EagerClient** client) = 0; }; } // namespace eager diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index 53e3955d694..4a8e19dddf7 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -66,21 +66,19 @@ class GrpcEagerClientCache : public EagerClientCache { ~GrpcEagerClientCache() override { threads_.clear(); } - EagerClient* GetClient(const string& target) override { + Status GetClient(const string& target, EagerClient** client) override { auto it = clients_.find(target); if (it == clients_.end()) { tensorflow::SharedGrpcChannelPtr shared = cache_->FindWorkerChannel(target); - // TODO(b/129072590): The check here is to prevent a segfault if 'target' - // is unknown. Return a Status here instead. - CHECK(shared) << "Unknown gRPC target " << target; auto worker = std::unique_ptr(new GrpcEagerClient( shared, threads_[AssignClientToThread(target)].completion_queue())); it = clients_.emplace(target, std::move(worker)).first; } - return it->second.get(); + *client = it->second.get(); + return Status::OK(); } private: