Report remote target in error messages for gRPC eager service requests.

PiperOrigin-RevId: 311634462
Change-Id: Ib0550c172e419ea17dac9ffa28c18b9e1a03b3cc
This commit is contained in:
Haoyu Zhang 2020-05-14 17:03:36 -07:00 committed by TensorFlower Gardener
parent 90077f8c7c
commit d5e0f468cd
3 changed files with 23 additions and 14 deletions

View File

@ -106,8 +106,8 @@ class GrpcEagerClientThread : public core::RefCounted {
class GrpcEagerClient : public EagerClient {
public:
GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
GrpcEagerClientThread* thread)
: stub_(channel), thread_(thread) {
GrpcEagerClientThread* thread, const string& target)
: stub_(channel), thread_(thread), target_(target) {
// Hold a reference to make sure the corresponding EagerClientThread
// outlives the client.
thread_->Ref();
@ -127,7 +127,8 @@ class GrpcEagerClient : public EagerClient {
new RPCState<protobuf::Message>( \
&stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
response, std::move(done_wrapped), /*call_opts=*/nullptr, \
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true); \
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, \
&target_); \
}
CLIENT_METHOD(CreateContext);
@ -146,7 +147,8 @@ class GrpcEagerClient : public EagerClient {
new RPCState<protobuf::Message>(
&stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
response, std::move(done_wrapped), /*call_opts=*/nullptr,
/*threadpool=*/nullptr);
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
&target_);
VLOG(1) << "Sending RPC to close remote eager context "
<< request->DebugString();
@ -194,6 +196,7 @@ class GrpcEagerClient : public EagerClient {
private:
::grpc::GenericStub stub_;
const GrpcEagerClientThread* thread_;
const string target_;
::grpc::CompletionQueue* cq_;
@ -236,7 +239,7 @@ class GrpcEagerClientCache : public EagerClientCache {
int assigned_index = AssignClientToThread(target);
GrpcEagerClientThread* thread = threads_[assigned_index].get();
core::RefCountPtr<EagerClient> worker(
new GrpcEagerClient(shared, thread));
new GrpcEagerClient(shared, thread, target));
it = clients_.emplace(target, std::move(worker)).first;
}

View File

@ -210,7 +210,8 @@ void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
get_stub(index), &completion_queue_, *get_method_ptr(index),
call->request(), call->response(),
/*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
nullptr /*threadpool*/, fail_fast_, timeout_in_ms_, 0 /* max_retries */);
/*threadpool=*/nullptr, fail_fast_, timeout_in_ms_, /*max_retries=*/0,
/*target=*/nullptr);
}
} // namespace tensorflow

View File

@ -45,7 +45,7 @@ class RPCState : public GrpcClientCQTag {
const ::grpc::string& method, const protobuf::Message& request,
Response* response, StatusCallback done, CallOptions* call_opts,
thread::ThreadPool* threadpool, int32 max_retries = 0,
bool fail_fast = true)
bool fail_fast = true, const string* target = nullptr)
: RPCState(
stub, cq, method, request, response, std::move(done), call_opts,
threadpool,
@ -63,7 +63,7 @@ class RPCState : public GrpcClientCQTag {
#endif // PLATFORM_GOOGLE
return x;
}(),
/*timeout_in_ms=*/0, max_retries) {
/*timeout_in_ms=*/0, max_retries, target) {
}
template <typename Request>
@ -71,7 +71,7 @@ class RPCState : public GrpcClientCQTag {
const ::grpc::string& method, const Request& request,
Response* response, StatusCallback done, CallOptions* call_opts,
thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms,
int32 max_retries)
int32 max_retries, const string* target)
: call_opts_(call_opts),
threadpool_(threadpool),
done_(std::move(done)),
@ -80,7 +80,8 @@ class RPCState : public GrpcClientCQTag {
cq_(cq),
stub_(stub),
method_(method),
fail_fast_(fail_fast) {
fail_fast_(fail_fast),
target_(target) {
response_ = response;
::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
if (!s.ok()) {
@ -152,10 +153,13 @@ class RPCState : public GrpcClientCQTag {
StartCall();
} else {
// Attach additional GRPC error information if any to the final status
s = Status(s.code(),
strings::StrCat(s.error_message(),
"\nAdditional GRPC error information:\n",
context_->debug_error_string()));
string error_msg = s.error_message();
strings::StrAppend(&error_msg, "\nAdditional GRPC error information");
if (target_) {
strings::StrAppend(&error_msg, " from remote target ", *target_);
}
strings::StrAppend(&error_msg, ":\n:", context_->debug_error_string());
s = Status(s.code(), error_msg);
// Always treat gRPC cancellation as a derived error. This ensures that
// other error types are preferred during status aggregation. (gRPC
// cancellation messages do not contain the original status message).
@ -196,6 +200,7 @@ class RPCState : public GrpcClientCQTag {
::grpc::GenericStub* stub_;
::grpc::string method_;
bool fail_fast_;
const string* target_;
};
// Represents state associated with one streaming RPC call.