Report remote target name for worker service RPCs.

PiperOrigin-RevId: 312095453
Change-Id: I73fc7948f994426b8d62bdefd5573cfe3b5b793d
This commit is contained in:
Haoyu Zhang 2020-05-18 09:35:47 -07:00 committed by TensorFlower Gardener
parent cfdb943405
commit dbc0fffedb
3 changed files with 15 additions and 10 deletions

View File

@@ -45,7 +45,7 @@ class GrpcRemoteWorker : public WorkerInterface {
   explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                             ::grpc::CompletionQueue* completion_queue,
                             thread::ThreadPool* callback_threadpool,
-                            WorkerCacheLogger* logger)
+                            WorkerCacheLogger* logger, const string& target)
       : channel_(std::move(channel)),
         stub_(channel_),
         cq_(completion_queue),
@@ -66,7 +66,8 @@ class GrpcRemoteWorker : public WorkerInterface {
         instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
         getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
         markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
-        logger_(logger) {}
+        logger_(logger),
+        target_(target) {}
   ~GrpcRemoteWorker() override {}
@@ -273,7 +274,7 @@ class GrpcRemoteWorker : public WorkerInterface {
                     bool fail_fast = true) {
     new RPCState<protobuf::Message>(
         &stub_, cq_, method, *request, response, std::move(done), call_opts,
-        callback_threadpool_, /*max_retries=*/0, fail_fast);
+        callback_threadpool_, /*max_retries=*/0, fail_fast, &target_);
   }
   void IssueRequest(const protobuf::Message* request, TensorResponse* response,
@@ -281,7 +282,8 @@ class GrpcRemoteWorker : public WorkerInterface {
                     CallOptions* call_opts = nullptr) {
     new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
                                  std::move(done), call_opts,
-                                 callback_threadpool_);
+                                 callback_threadpool_, /*max_retries=*/0,
+                                 /*fail_fast=*/true, &target_);
   }
   void IssueMarkRecvFinishedRequest(int64 request_id) {
@@ -321,6 +323,7 @@ class GrpcRemoteWorker : public WorkerInterface {
   // Support for logging.
   WorkerCacheLogger* logger_;
+  const string target_;
   TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
 };
@@ -328,9 +331,10 @@ class GrpcRemoteWorker : public WorkerInterface {
 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                      ::grpc::CompletionQueue* completion_queue,
                                      thread::ThreadPool* callback_threadpool,
-                                     WorkerCacheLogger* logger) {
+                                     WorkerCacheLogger* logger,
+                                     const string& target) {
   return new GrpcRemoteWorker(std::move(channel), completion_queue,
-                              callback_threadpool, logger);
+                              callback_threadpool, logger, target);
 }
 }  // namespace tensorflow

View File

@@ -29,7 +29,8 @@ class WorkerInterface;
 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                      ::grpc::CompletionQueue* completion_queue,
                                      thread::ThreadPool* callback_threadpool,
-                                     WorkerCacheLogger* logger);
+                                     WorkerCacheLogger* logger,
+                                     const string& target);
 }  // namespace tensorflow

View File

@@ -69,9 +69,9 @@ class GrpcWorkerCache : public WorkerCachePartial {
         return nullptr;
       }
       size_t index = AssignWorkerToThread(target);
-      return NewGrpcRemoteWorker(channel,
-                                 worker_env_->GetCompletionQueue(index),
-                                 worker_env_->GetThreadPool(), &logger_);
+      return NewGrpcRemoteWorker(
+          channel, worker_env_->GetCompletionQueue(index),
+          worker_env_->GetThreadPool(), &logger_, target);
     }
   }