diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 85431acdf0c..6e706179863 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -45,7 +45,7 @@ class GrpcRemoteWorker : public WorkerInterface { explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, thread::ThreadPool* callback_threadpool, - WorkerCacheLogger* logger) + WorkerCacheLogger* logger, const string& target) : channel_(std::move(channel)), stub_(channel_), cq_(completion_queue), @@ -66,7 +66,8 @@ class GrpcRemoteWorker : public WorkerInterface { instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)), getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)), markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)), - logger_(logger) {} + logger_(logger), + target_(target) {} ~GrpcRemoteWorker() override {} @@ -273,7 +274,7 @@ class GrpcRemoteWorker : public WorkerInterface { bool fail_fast = true) { new RPCState( &stub_, cq_, method, *request, response, std::move(done), call_opts, - callback_threadpool_, /*max_retries=*/0, fail_fast); + callback_threadpool_, /*max_retries=*/0, fail_fast, &target_); } void IssueRequest(const protobuf::Message* request, TensorResponse* response, @@ -281,7 +282,8 @@ class GrpcRemoteWorker : public WorkerInterface { CallOptions* call_opts = nullptr) { new RPCState(&stub_, cq_, method, *request, response, std::move(done), call_opts, - callback_threadpool_); + callback_threadpool_, /*max_retries=*/0, + /*fail_fast=*/true, &target_); } void IssueMarkRecvFinishedRequest(int64 request_id) { @@ -321,6 +323,7 @@ class GrpcRemoteWorker : public WorkerInterface { // Support for logging. WorkerCacheLogger* logger_; + const string target_; TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker); }; @@ -328,9 +331,10 @@ class GrpcRemoteWorker : public WorkerInterface { WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, thread::ThreadPool* callback_threadpool, - WorkerCacheLogger* logger) { + WorkerCacheLogger* logger, + const string& target) { return new GrpcRemoteWorker(std::move(channel), completion_queue, - callback_threadpool, logger); + callback_threadpool, logger, target); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h index c0a49ecfc38..97e590e0ad1 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h @@ -29,7 +29,8 @@ class WorkerInterface; WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, thread::ThreadPool* callback_threadpool, - WorkerCacheLogger* logger); + WorkerCacheLogger* logger, + const string& target); } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index f6b6e15a2ba..1d75728ddd2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -69,9 +69,9 @@ class GrpcWorkerCache : public WorkerCachePartial { return nullptr; } size_t index = AssignWorkerToThread(target); - return NewGrpcRemoteWorker(channel, - worker_env_->GetCompletionQueue(index), - worker_env_->GetThreadPool(), &logger_); + return NewGrpcRemoteWorker( + channel, worker_env_->GetCompletionQueue(index), + worker_env_->GetThreadPool(), &logger_, target); } }