Return a status from EagerClientCache::GetClient
PiperOrigin-RevId: 243093267
This commit is contained in:
parent
42df8cc49c
commit
72194852a2
@ -143,7 +143,9 @@ tensorflow::Status CreateRemoteContexts(
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.set_async(async);
|
||||
request.set_keep_alive_secs(keep_alive_secs);
|
||||
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
|
||||
tensorflow::eager::EagerClient* eager_client;
|
||||
TF_RETURN_IF_ERROR(
|
||||
remote_eager_workers->GetClient(remote_worker, &eager_client));
|
||||
if (eager_client == nullptr) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
|
@ -175,8 +175,9 @@ void EagerContext::CloseRemoteContexts() {
|
||||
|
||||
int i = 0;
|
||||
for (const auto& worker_and_context_id : remote_contexts_) {
|
||||
auto* client =
|
||||
remote_eager_workers_->GetClient(worker_and_context_id.first);
|
||||
eager::EagerClient* client;
|
||||
Status s =
|
||||
remote_eager_workers_->GetClient(worker_and_context_id.first, &client);
|
||||
|
||||
requests[i].set_context_id(worker_and_context_id.second);
|
||||
client->CloseContextAsync(
|
||||
@ -321,8 +322,9 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
|
||||
requests[i].set_context_id(target_and_context_id.second);
|
||||
*requests[i].mutable_function_def() = fdef;
|
||||
|
||||
auto* eager_client =
|
||||
remote_eager_workers_->GetClient(target_and_context_id.first);
|
||||
eager::EagerClient* eager_client;
|
||||
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(
|
||||
target_and_context_id.first, &eager_client));
|
||||
|
||||
eager_client->RegisterFunctionAsync(
|
||||
&requests[i], &responses[i],
|
||||
@ -409,7 +411,8 @@ Status EagerContext::GetClientAndContextID(Device* device,
|
||||
string device_task_name;
|
||||
TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name));
|
||||
|
||||
*client = remote_eager_workers_->GetClient(device_task_name);
|
||||
TF_RETURN_IF_ERROR(
|
||||
remote_eager_workers_->GetClient(device_task_name, client));
|
||||
|
||||
if (*client == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
@ -537,8 +540,17 @@ Status EagerContext::InitializeRemote(
|
||||
if (keep_alive_secs_ > 0) {
|
||||
{
|
||||
for (const auto& worker_and_context_id : remote_contexts_) {
|
||||
auto* client = remote_eager_workers_->GetClient(
|
||||
worker_and_context_id.first);
|
||||
eager::EagerClient* client;
|
||||
Status s = remote_eager_workers_->GetClient(
|
||||
worker_and_context_id.first, &client);
|
||||
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Keep-alive thread was unable to find "
|
||||
"a client for target "
|
||||
<< worker_and_context_id.first
|
||||
<< ". Got error: " << s;
|
||||
continue;
|
||||
}
|
||||
|
||||
eager::KeepAliveRequest* request =
|
||||
new eager::KeepAliveRequest;
|
||||
|
@ -48,7 +48,7 @@ class EagerClient {
|
||||
class EagerClientCache {
|
||||
public:
|
||||
virtual ~EagerClientCache() {}
|
||||
virtual EagerClient* GetClient(const string& target) = 0;
|
||||
virtual Status GetClient(const string& target, EagerClient** client) = 0;
|
||||
};
|
||||
|
||||
} // namespace eager
|
||||
|
@ -66,21 +66,19 @@ class GrpcEagerClientCache : public EagerClientCache {
|
||||
|
||||
~GrpcEagerClientCache() override { threads_.clear(); }
|
||||
|
||||
EagerClient* GetClient(const string& target) override {
|
||||
Status GetClient(const string& target, EagerClient** client) override {
|
||||
auto it = clients_.find(target);
|
||||
if (it == clients_.end()) {
|
||||
tensorflow::SharedGrpcChannelPtr shared =
|
||||
cache_->FindWorkerChannel(target);
|
||||
// TODO(b/129072590): The check here is to prevent a segfault if 'target'
|
||||
// is unknown. Return a Status here instead.
|
||||
CHECK(shared) << "Unknown gRPC target " << target;
|
||||
auto worker = std::unique_ptr<EagerClient>(new GrpcEagerClient(
|
||||
shared, threads_[AssignClientToThread(target)].completion_queue()));
|
||||
|
||||
it = clients_.emplace(target, std::move(worker)).first;
|
||||
}
|
||||
|
||||
return it->second.get();
|
||||
*client = it->second.get();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user