Return a status from EagerClientCache::GetClient

PiperOrigin-RevId: 243093267
This commit is contained in:
Akshay Modi 2019-04-11 10:18:49 -07:00 committed by TensorFlower Gardener
parent 42df8cc49c
commit 72194852a2
4 changed files with 26 additions and 14 deletions

View File

@ -143,7 +143,9 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
tensorflow::eager::EagerClient* eager_client;
TF_RETURN_IF_ERROR(
remote_eager_workers->GetClient(remote_worker, &eager_client));
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);

View File

@ -175,8 +175,9 @@ void EagerContext::CloseRemoteContexts() {
int i = 0;
for (const auto& worker_and_context_id : remote_contexts_) {
auto* client =
remote_eager_workers_->GetClient(worker_and_context_id.first);
eager::EagerClient* client;
Status s =
remote_eager_workers_->GetClient(worker_and_context_id.first, &client);
requests[i].set_context_id(worker_and_context_id.second);
client->CloseContextAsync(
@ -321,8 +322,9 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
requests[i].set_context_id(target_and_context_id.second);
*requests[i].mutable_function_def() = fdef;
auto* eager_client =
remote_eager_workers_->GetClient(target_and_context_id.first);
eager::EagerClient* eager_client;
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(
target_and_context_id.first, &eager_client));
eager_client->RegisterFunctionAsync(
&requests[i], &responses[i],
@ -409,7 +411,8 @@ Status EagerContext::GetClientAndContextID(Device* device,
string device_task_name;
TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name));
*client = remote_eager_workers_->GetClient(device_task_name);
TF_RETURN_IF_ERROR(
remote_eager_workers_->GetClient(device_task_name, client));
if (*client == nullptr) {
return errors::InvalidArgument(
@ -537,8 +540,17 @@ Status EagerContext::InitializeRemote(
if (keep_alive_secs_ > 0) {
{
for (const auto& worker_and_context_id : remote_contexts_) {
auto* client = remote_eager_workers_->GetClient(
worker_and_context_id.first);
eager::EagerClient* client;
Status s = remote_eager_workers_->GetClient(
worker_and_context_id.first, &client);
if (!s.ok()) {
LOG(WARNING) << "Keep-alive thread was unable to find "
"a client for target "
<< worker_and_context_id.first
<< ". Got error: " << s;
continue;
}
eager::KeepAliveRequest* request =
new eager::KeepAliveRequest;

View File

@ -48,7 +48,7 @@ class EagerClient {
class EagerClientCache {
public:
virtual ~EagerClientCache() {}
virtual EagerClient* GetClient(const string& target) = 0;
virtual Status GetClient(const string& target, EagerClient** client) = 0;
};
} // namespace eager

View File

@ -66,21 +66,19 @@ class GrpcEagerClientCache : public EagerClientCache {
~GrpcEagerClientCache() override { threads_.clear(); }
EagerClient* GetClient(const string& target) override {
Status GetClient(const string& target, EagerClient** client) override {
auto it = clients_.find(target);
if (it == clients_.end()) {
tensorflow::SharedGrpcChannelPtr shared =
cache_->FindWorkerChannel(target);
// TODO(b/129072590): The check here is to prevent a segfault if 'target'
// is unknown. Return a Status here instead.
CHECK(shared) << "Unknown gRPC target " << target;
auto worker = std::unique_ptr<EagerClient>(new GrpcEagerClient(
shared, threads_[AssignClientToThread(target)].completion_queue()));
it = clients_.emplace(target, std::move(worker)).first;
}
return it->second.get();
*client = it->second.get();
return Status::OK();
}
private: