Report remote target name for worker service RPCs.
PiperOrigin-RevId: 312095453 Change-Id: I73fc7948f994426b8d62bdefd5573cfe3b5b793d
This commit is contained in:
parent
cfdb943405
commit
dbc0fffedb
@ -45,7 +45,7 @@ class GrpcRemoteWorker : public WorkerInterface {
|
|||||||
explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
||||||
::grpc::CompletionQueue* completion_queue,
|
::grpc::CompletionQueue* completion_queue,
|
||||||
thread::ThreadPool* callback_threadpool,
|
thread::ThreadPool* callback_threadpool,
|
||||||
WorkerCacheLogger* logger)
|
WorkerCacheLogger* logger, const string& target)
|
||||||
: channel_(std::move(channel)),
|
: channel_(std::move(channel)),
|
||||||
stub_(channel_),
|
stub_(channel_),
|
||||||
cq_(completion_queue),
|
cq_(completion_queue),
|
||||||
@ -66,7 +66,8 @@ class GrpcRemoteWorker : public WorkerInterface {
|
|||||||
instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
|
instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
|
||||||
getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
|
getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
|
||||||
markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
|
markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
|
||||||
logger_(logger) {}
|
logger_(logger),
|
||||||
|
target_(target) {}
|
||||||
|
|
||||||
~GrpcRemoteWorker() override {}
|
~GrpcRemoteWorker() override {}
|
||||||
|
|
||||||
@ -273,7 +274,7 @@ class GrpcRemoteWorker : public WorkerInterface {
|
|||||||
bool fail_fast = true) {
|
bool fail_fast = true) {
|
||||||
new RPCState<protobuf::Message>(
|
new RPCState<protobuf::Message>(
|
||||||
&stub_, cq_, method, *request, response, std::move(done), call_opts,
|
&stub_, cq_, method, *request, response, std::move(done), call_opts,
|
||||||
callback_threadpool_, /*max_retries=*/0, fail_fast);
|
callback_threadpool_, /*max_retries=*/0, fail_fast, &target_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void IssueRequest(const protobuf::Message* request, TensorResponse* response,
|
void IssueRequest(const protobuf::Message* request, TensorResponse* response,
|
||||||
@ -281,7 +282,8 @@ class GrpcRemoteWorker : public WorkerInterface {
|
|||||||
CallOptions* call_opts = nullptr) {
|
CallOptions* call_opts = nullptr) {
|
||||||
new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
|
new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
|
||||||
std::move(done), call_opts,
|
std::move(done), call_opts,
|
||||||
callback_threadpool_);
|
callback_threadpool_, /*max_retries=*/0,
|
||||||
|
/*fail_fast=*/true, &target_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void IssueMarkRecvFinishedRequest(int64 request_id) {
|
void IssueMarkRecvFinishedRequest(int64 request_id) {
|
||||||
@ -321,6 +323,7 @@ class GrpcRemoteWorker : public WorkerInterface {
|
|||||||
|
|
||||||
// Support for logging.
|
// Support for logging.
|
||||||
WorkerCacheLogger* logger_;
|
WorkerCacheLogger* logger_;
|
||||||
|
const string target_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
|
TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
|
||||||
};
|
};
|
||||||
@ -328,9 +331,10 @@ class GrpcRemoteWorker : public WorkerInterface {
|
|||||||
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
||||||
::grpc::CompletionQueue* completion_queue,
|
::grpc::CompletionQueue* completion_queue,
|
||||||
thread::ThreadPool* callback_threadpool,
|
thread::ThreadPool* callback_threadpool,
|
||||||
WorkerCacheLogger* logger) {
|
WorkerCacheLogger* logger,
|
||||||
|
const string& target) {
|
||||||
return new GrpcRemoteWorker(std::move(channel), completion_queue,
|
return new GrpcRemoteWorker(std::move(channel), completion_queue,
|
||||||
callback_threadpool, logger);
|
callback_threadpool, logger, target);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -29,7 +29,8 @@ class WorkerInterface;
|
|||||||
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
|
||||||
::grpc::CompletionQueue* completion_queue,
|
::grpc::CompletionQueue* completion_queue,
|
||||||
thread::ThreadPool* callback_threadpool,
|
thread::ThreadPool* callback_threadpool,
|
||||||
WorkerCacheLogger* logger);
|
WorkerCacheLogger* logger,
|
||||||
|
const string& target);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -69,9 +69,9 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
size_t index = AssignWorkerToThread(target);
|
size_t index = AssignWorkerToThread(target);
|
||||||
return NewGrpcRemoteWorker(channel,
|
return NewGrpcRemoteWorker(
|
||||||
worker_env_->GetCompletionQueue(index),
|
channel, worker_env_->GetCompletionQueue(index),
|
||||||
worker_env_->GetThreadPool(), &logger_);
|
worker_env_->GetThreadPool(), &logger_, target);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user