Report remote target name for worker service RPCs.

PiperOrigin-RevId: 312095453
Change-Id: I73fc7948f994426b8d62bdefd5573cfe3b5b793d
This commit is contained in:
Haoyu Zhang 2020-05-18 09:35:47 -07:00 committed by TensorFlower Gardener
parent cfdb943405
commit dbc0fffedb
3 changed files with 15 additions and 10 deletions

View File

@ -45,7 +45,7 @@ class GrpcRemoteWorker : public WorkerInterface {
explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
thread::ThreadPool* callback_threadpool,
WorkerCacheLogger* logger)
WorkerCacheLogger* logger, const string& target)
: channel_(std::move(channel)),
stub_(channel_),
cq_(completion_queue),
@ -66,7 +66,8 @@ class GrpcRemoteWorker : public WorkerInterface {
instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
logger_(logger) {}
logger_(logger),
target_(target) {}
~GrpcRemoteWorker() override {}
@ -273,7 +274,7 @@ class GrpcRemoteWorker : public WorkerInterface {
bool fail_fast = true) {
new RPCState<protobuf::Message>(
&stub_, cq_, method, *request, response, std::move(done), call_opts,
callback_threadpool_, /*max_retries=*/0, fail_fast);
callback_threadpool_, /*max_retries=*/0, fail_fast, &target_);
}
void IssueRequest(const protobuf::Message* request, TensorResponse* response,
@ -281,7 +282,8 @@ class GrpcRemoteWorker : public WorkerInterface {
CallOptions* call_opts = nullptr) {
new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
std::move(done), call_opts,
callback_threadpool_);
callback_threadpool_, /*max_retries=*/0,
/*fail_fast=*/true, &target_);
}
void IssueMarkRecvFinishedRequest(int64 request_id) {
@ -321,6 +323,7 @@ class GrpcRemoteWorker : public WorkerInterface {
// Support for logging.
WorkerCacheLogger* logger_;
const string target_;
TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};
@ -328,9 +331,10 @@ class GrpcRemoteWorker : public WorkerInterface {
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
thread::ThreadPool* callback_threadpool,
WorkerCacheLogger* logger) {
WorkerCacheLogger* logger,
const string& target) {
return new GrpcRemoteWorker(std::move(channel), completion_queue,
callback_threadpool, logger);
callback_threadpool, logger, target);
}
} // namespace tensorflow

View File

@ -29,7 +29,8 @@ class WorkerInterface;
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
thread::ThreadPool* callback_threadpool,
WorkerCacheLogger* logger);
WorkerCacheLogger* logger,
const string& target);
} // namespace tensorflow

View File

@ -69,9 +69,9 @@ class GrpcWorkerCache : public WorkerCachePartial {
return nullptr;
}
size_t index = AssignWorkerToThread(target);
return NewGrpcRemoteWorker(channel,
worker_env_->GetCompletionQueue(index),
worker_env_->GetThreadPool(), &logger_);
return NewGrpcRemoteWorker(
channel, worker_env_->GetCompletionQueue(index),
worker_env_->GetThreadPool(), &logger_, target);
}
}