Include the remote target name in gRPC error messages.

GrpcEagerClient now remembers the target it was created for and hands a
pointer to that string to every RPCState it creates. When a call fails,
RPCState appends " from remote target <target>" to the "Additional GRPC error
information" section of the final status. GrpcRPCFactory passes nullptr, so
its messages keep the previous wording.

NOTE(review): template argument lists (after `template`, `core::RefCountPtr`
and `new RPCState`) appear to have been stripped from this copy of the patch;
confirm against the upstream commit before applying.

diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
index de4f36ea24d..752bfdf71a1 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
@@ -106,8 +106,8 @@ class GrpcEagerClientThread : public core::RefCounted {
 class GrpcEagerClient : public EagerClient {
  public:
   GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
-                  GrpcEagerClientThread* thread)
-      : stub_(channel), thread_(thread) {
+                  GrpcEagerClientThread* thread, const string& target)
+      : stub_(channel), thread_(thread), target_(target) {
     // Hold a reference to make sure the corresponding EagerClientThread
     // outlives the client.
     thread_->Ref();
@@ -127,7 +127,8 @@ class GrpcEagerClient : public EagerClient {
     new RPCState(                                                           \
         &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request,  \
         response, std::move(done_wrapped), /*call_opts=*/nullptr,          \
-        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true);    \
+        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,     \
+        &target_);                                                         \
   }
 
   CLIENT_METHOD(CreateContext);
@@ -146,7 +147,8 @@ class GrpcEagerClient : public EagerClient {
     new RPCState(
         &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
         response, std::move(done_wrapped), /*call_opts=*/nullptr,
-        /*threadpool=*/nullptr);
+        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
+        &target_);
 
     VLOG(1) << "Sending RPC to close remote eager context "
             << request->DebugString();
@@ -194,6 +196,7 @@ class GrpcEagerClient : public EagerClient {
  private:
   ::grpc::GenericStub stub_;
   const GrpcEagerClientThread* thread_;
+  const string target_;
 
   ::grpc::CompletionQueue* cq_;
 
@@ -236,7 +239,7 @@ class GrpcEagerClientCache : public EagerClientCache {
       int assigned_index = AssignClientToThread(target);
       GrpcEagerClientThread* thread = threads_[assigned_index].get();
       core::RefCountPtr worker(
-          new GrpcEagerClient(shared, thread));
+          new GrpcEagerClient(shared, thread, target));
       it = clients_.emplace(target, std::move(worker)).first;
     }
 
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
index 272d6bb1b20..bcb98baaeb9 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -210,7 +210,8 @@ void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
       get_stub(index), &completion_queue_, *get_method_ptr(index),
       call->request(), call->response(),
       /*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
-      nullptr /*threadpool*/, fail_fast_, timeout_in_ms_, 0 /* max_retries */);
+      /*threadpool=*/nullptr, fail_fast_, timeout_in_ms_, /*max_retries=*/0,
+      /*target=*/nullptr);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index c72ba6035a4..041b6e51ffb 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -45,7 +45,7 @@ class RPCState : public GrpcClientCQTag {
            const ::grpc::string& method, const protobuf::Message& request,
            Response* response, StatusCallback done, CallOptions* call_opts,
            thread::ThreadPool* threadpool, int32 max_retries = 0,
-           bool fail_fast = true)
+           bool fail_fast = true, const string* target = nullptr)
       : RPCState(
             stub, cq, method, request, response, std::move(done), call_opts,
             threadpool,
@@ -63,7 +63,7 @@ class RPCState : public GrpcClientCQTag {
 #endif  // PLATFORM_GOOGLE
               return x;
             }(),
-            /*timeout_in_ms=*/0, max_retries) {
+            /*timeout_in_ms=*/0, max_retries, target) {
   }
 
   template
@@ -71,7 +71,7 @@ class RPCState : public GrpcClientCQTag {
            const ::grpc::string& method, const Request& request,
            Response* response, StatusCallback done, CallOptions* call_opts,
            thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms,
-           int32 max_retries)
+           int32 max_retries, const string* target)
       : call_opts_(call_opts),
         threadpool_(threadpool),
         done_(std::move(done)),
@@ -80,7 +80,8 @@ class RPCState : public GrpcClientCQTag {
         cq_(cq),
         stub_(stub),
         method_(method),
-        fail_fast_(fail_fast) {
+        fail_fast_(fail_fast),
+        target_(target) {
     response_ = response;
     ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
     if (!s.ok()) {
@@ -152,10 +153,13 @@ class RPCState : public GrpcClientCQTag {
         StartCall();
       } else {
         // Attach additional GRPC error information if any to the final status
-        s = Status(s.code(),
-                   strings::StrCat(s.error_message(),
-                                   "\nAdditional GRPC error information:\n",
-                                   context_->debug_error_string()));
+        string error_msg = s.error_message();
+        strings::StrAppend(&error_msg, "\nAdditional GRPC error information");
+        if (target_) {
+          strings::StrAppend(&error_msg, " from remote target ", *target_);
+        }
+        strings::StrAppend(&error_msg, ":\n", context_->debug_error_string());
+        s = Status(s.code(), error_msg);
         // Always treat gRPC cancellation as a derived error. This ensures that
         // other error types are preferred during status aggregation. (gRPC
         // cancellation messages do not contain the original status message).
@@ -196,6 +200,10 @@ class RPCState : public GrpcClientCQTag {
   ::grpc::GenericStub* stub_;
   ::grpc::string method_;
   bool fail_fast_;
+  // Optional remote target name used to annotate error messages; may be
+  // nullptr. Only the pointer is stored. NOTE(review): the pointee presumably
+  // has to outlive this RPCState -- confirm against callers.
+  const string* target_;
 };
 
 // Represents state associated with one streaming RPC call.