Report remote target name for worker service RPCs.
PiperOrigin-RevId: 312095453 Change-Id: I73fc7948f994426b8d62bdefd5573cfe3b5b793d
This commit is contained in:
parent
cfdb943405
commit
dbc0fffedb
@ -45,7 +45,7 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
||||
::grpc::CompletionQueue* completion_queue,
|
||||
thread::ThreadPool* callback_threadpool,
|
||||
WorkerCacheLogger* logger)
|
||||
WorkerCacheLogger* logger, const string& target)
|
||||
: channel_(std::move(channel)),
|
||||
stub_(channel_),
|
||||
cq_(completion_queue),
|
||||
@ -66,7 +66,8 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
|
||||
getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
|
||||
markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
|
||||
logger_(logger) {}
|
||||
logger_(logger),
|
||||
target_(target) {}
|
||||
|
||||
~GrpcRemoteWorker() override {}
|
||||
|
||||
@ -273,7 +274,7 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
bool fail_fast = true) {
|
||||
new RPCState<protobuf::Message>(
|
||||
&stub_, cq_, method, *request, response, std::move(done), call_opts,
|
||||
callback_threadpool_, /*max_retries=*/0, fail_fast);
|
||||
callback_threadpool_, /*max_retries=*/0, fail_fast, &target_);
|
||||
}
|
||||
|
||||
void IssueRequest(const protobuf::Message* request, TensorResponse* response,
|
||||
@ -281,7 +282,8 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
CallOptions* call_opts = nullptr) {
|
||||
new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
|
||||
std::move(done), call_opts,
|
||||
callback_threadpool_);
|
||||
callback_threadpool_, /*max_retries=*/0,
|
||||
/*fail_fast=*/true, &target_);
|
||||
}
|
||||
|
||||
void IssueMarkRecvFinishedRequest(int64 request_id) {
|
||||
@ -321,6 +323,7 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
|
||||
// Support for logging.
|
||||
WorkerCacheLogger* logger_;
|
||||
const string target_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
|
||||
};
|
||||
@ -328,9 +331,10 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
||||
::grpc::CompletionQueue* completion_queue,
|
||||
thread::ThreadPool* callback_threadpool,
|
||||
WorkerCacheLogger* logger) {
|
||||
WorkerCacheLogger* logger,
|
||||
const string& target) {
|
||||
return new GrpcRemoteWorker(std::move(channel), completion_queue,
|
||||
callback_threadpool, logger);
|
||||
callback_threadpool, logger, target);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -29,7 +29,8 @@ class WorkerInterface;
|
||||
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
||||
::grpc::CompletionQueue* completion_queue,
|
||||
thread::ThreadPool* callback_threadpool,
|
||||
WorkerCacheLogger* logger);
|
||||
WorkerCacheLogger* logger,
|
||||
const string& target);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -69,9 +69,9 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
||||
return nullptr;
|
||||
}
|
||||
size_t index = AssignWorkerToThread(target);
|
||||
return NewGrpcRemoteWorker(channel,
|
||||
worker_env_->GetCompletionQueue(index),
|
||||
worker_env_->GetThreadPool(), &logger_);
|
||||
return NewGrpcRemoteWorker(
|
||||
channel, worker_env_->GetCompletionQueue(index),
|
||||
worker_env_->GetThreadPool(), &logger_, target);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user