Report remote target in error messages for gRPC eager service requests.
PiperOrigin-RevId: 311634462 Change-Id: Ib0550c172e419ea17dac9ffa28c18b9e1a03b3cc
This commit is contained in:
parent
90077f8c7c
commit
d5e0f468cd
|
@ -106,8 +106,8 @@ class GrpcEagerClientThread : public core::RefCounted {
|
|||
class GrpcEagerClient : public EagerClient {
|
||||
public:
|
||||
GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
|
||||
GrpcEagerClientThread* thread)
|
||||
: stub_(channel), thread_(thread) {
|
||||
GrpcEagerClientThread* thread, const string& target)
|
||||
: stub_(channel), thread_(thread), target_(target) {
|
||||
// Hold a reference to make sure the corresponding EagerClientThread
|
||||
// outlives the client.
|
||||
thread_->Ref();
|
||||
|
@ -127,7 +127,8 @@ class GrpcEagerClient : public EagerClient {
|
|||
new RPCState<protobuf::Message>( \
|
||||
&stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
|
||||
response, std::move(done_wrapped), /*call_opts=*/nullptr, \
|
||||
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true); \
|
||||
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, \
|
||||
&target_); \
|
||||
}
|
||||
|
||||
CLIENT_METHOD(CreateContext);
|
||||
|
@ -146,7 +147,8 @@ class GrpcEagerClient : public EagerClient {
|
|||
new RPCState<protobuf::Message>(
|
||||
&stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
|
||||
response, std::move(done_wrapped), /*call_opts=*/nullptr,
|
||||
/*threadpool=*/nullptr);
|
||||
/*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
|
||||
&target_);
|
||||
|
||||
VLOG(1) << "Sending RPC to close remote eager context "
|
||||
<< request->DebugString();
|
||||
|
@ -194,6 +196,7 @@ class GrpcEagerClient : public EagerClient {
|
|||
private:
|
||||
::grpc::GenericStub stub_;
|
||||
const GrpcEagerClientThread* thread_;
|
||||
const string target_;
|
||||
|
||||
::grpc::CompletionQueue* cq_;
|
||||
|
||||
|
@ -236,7 +239,7 @@ class GrpcEagerClientCache : public EagerClientCache {
|
|||
int assigned_index = AssignClientToThread(target);
|
||||
GrpcEagerClientThread* thread = threads_[assigned_index].get();
|
||||
core::RefCountPtr<EagerClient> worker(
|
||||
new GrpcEagerClient(shared, thread));
|
||||
new GrpcEagerClient(shared, thread, target));
|
||||
it = clients_.emplace(target, std::move(worker)).first;
|
||||
}
|
||||
|
||||
|
|
|
@ -210,7 +210,8 @@ void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
|
|||
get_stub(index), &completion_queue_, *get_method_ptr(index),
|
||||
call->request(), call->response(),
|
||||
/*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
|
||||
nullptr /*threadpool*/, fail_fast_, timeout_in_ms_, 0 /* max_retries */);
|
||||
/*threadpool=*/nullptr, fail_fast_, timeout_in_ms_, /*max_retries=*/0,
|
||||
/*target=*/nullptr);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -45,7 +45,7 @@ class RPCState : public GrpcClientCQTag {
|
|||
const ::grpc::string& method, const protobuf::Message& request,
|
||||
Response* response, StatusCallback done, CallOptions* call_opts,
|
||||
thread::ThreadPool* threadpool, int32 max_retries = 0,
|
||||
bool fail_fast = true)
|
||||
bool fail_fast = true, const string* target = nullptr)
|
||||
: RPCState(
|
||||
stub, cq, method, request, response, std::move(done), call_opts,
|
||||
threadpool,
|
||||
|
@ -63,7 +63,7 @@ class RPCState : public GrpcClientCQTag {
|
|||
#endif // PLATFORM_GOOGLE
|
||||
return x;
|
||||
}(),
|
||||
/*timeout_in_ms=*/0, max_retries) {
|
||||
/*timeout_in_ms=*/0, max_retries, target) {
|
||||
}
|
||||
|
||||
template <typename Request>
|
||||
|
@ -71,7 +71,7 @@ class RPCState : public GrpcClientCQTag {
|
|||
const ::grpc::string& method, const Request& request,
|
||||
Response* response, StatusCallback done, CallOptions* call_opts,
|
||||
thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms,
|
||||
int32 max_retries)
|
||||
int32 max_retries, const string* target)
|
||||
: call_opts_(call_opts),
|
||||
threadpool_(threadpool),
|
||||
done_(std::move(done)),
|
||||
|
@ -80,7 +80,8 @@ class RPCState : public GrpcClientCQTag {
|
|||
cq_(cq),
|
||||
stub_(stub),
|
||||
method_(method),
|
||||
fail_fast_(fail_fast) {
|
||||
fail_fast_(fail_fast),
|
||||
target_(target) {
|
||||
response_ = response;
|
||||
::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
|
||||
if (!s.ok()) {
|
||||
|
@ -152,10 +153,13 @@ class RPCState : public GrpcClientCQTag {
|
|||
StartCall();
|
||||
} else {
|
||||
// Attach additional GRPC error information if any to the final status
|
||||
s = Status(s.code(),
|
||||
strings::StrCat(s.error_message(),
|
||||
"\nAdditional GRPC error information:\n",
|
||||
context_->debug_error_string()));
|
||||
string error_msg = s.error_message();
|
||||
strings::StrAppend(&error_msg, "\nAdditional GRPC error information");
|
||||
if (target_) {
|
||||
strings::StrAppend(&error_msg, " from remote target ", *target_);
|
||||
}
|
||||
strings::StrAppend(&error_msg, ":\n:", context_->debug_error_string());
|
||||
s = Status(s.code(), error_msg);
|
||||
// Always treat gRPC cancellation as a derived error. This ensures that
|
||||
// other error types are preferred during status aggregation. (gRPC
|
||||
// cancellation messages do not contain the original status message).
|
||||
|
@ -196,6 +200,7 @@ class RPCState : public GrpcClientCQTag {
|
|||
::grpc::GenericStub* stub_;
|
||||
::grpc::string method_;
|
||||
bool fail_fast_;
|
||||
const string* target_;
|
||||
};
|
||||
|
||||
// Represents state associated with one streaming RPC call.
|
||||
|
|
Loading…
Reference in New Issue