gRPC service: add configuration parameters for queue depth and thread count.
Refactor GrpcServer, WorkerService to take option structs. PiperOrigin-RevId: 230836072
This commit is contained in:
parent
a827a4bee1
commit
5a901abe02
@ -17,11 +17,6 @@ filegroup(
|
|||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
load(
|
|
||||||
"//tensorflow:tensorflow.bzl",
|
|
||||||
"tf_cuda_library",
|
|
||||||
)
|
|
||||||
|
|
||||||
# For platform specific build config
|
# For platform specific build config
|
||||||
load(
|
load(
|
||||||
"//tensorflow/core:platform/default/build_config.bzl",
|
"//tensorflow/core:platform/default/build_config.bzl",
|
||||||
@ -66,7 +61,6 @@ cc_library(
|
|||||||
":gdr_memory_manager",
|
":gdr_memory_manager",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:gpu_runtime",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/distributed_runtime:graph_mgr",
|
"//tensorflow/core/distributed_runtime:graph_mgr",
|
||||||
@ -108,15 +102,13 @@ cc_library(
|
|||||||
":gdr_memory_manager",
|
":gdr_memory_manager",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/distributed_runtime:cancellable_call",
|
"//tensorflow/core/distributed_runtime:cancellable_call",
|
||||||
"//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
|
"//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
|
||||||
"//tensorflow/core/distributed_runtime:device_resolver_distributed",
|
"//tensorflow/core/distributed_runtime:device_resolver_distributed",
|
||||||
"//tensorflow/core/distributed_runtime:request_id",
|
"//tensorflow/core/distributed_runtime:request_id",
|
||||||
"//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
|
"//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
|
||||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||||
"//tensorflow/core/distributed_runtime:worker_env",
|
|
||||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -130,6 +122,9 @@ cc_library(
|
|||||||
":gdr_memory_manager",
|
":gdr_memory_manager",
|
||||||
":gdr_rendezvous_mgr",
|
":gdr_rendezvous_mgr",
|
||||||
":gdr_worker",
|
":gdr_worker",
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
|
||||||
|
"//tensorflow/core/distributed_runtime:device_resolver_distributed",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -16,14 +16,14 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
|
#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
|
||||||
|
|
||||||
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
|
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
|
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
|
||||||
#include "tensorflow/core/framework/collective.h"
|
#include "tensorflow/core/framework/collective.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
class CollectiveParamResolverDistributed;
|
|
||||||
class ConfigProto;
|
class ConfigProto;
|
||||||
class DeviceMgr;
|
class DeviceMgr;
|
||||||
class DeviceResolverDistributed;
|
|
||||||
class WorkerCacheInterface;
|
class WorkerCacheInterface;
|
||||||
class StepSequenceRequest;
|
class StepSequenceRequest;
|
||||||
class StepSequenceResponse;
|
class StepSequenceResponse;
|
||||||
|
@ -82,8 +82,11 @@ Status GdrServer::Init() {
|
|||||||
};
|
};
|
||||||
TF_RETURN_IF_ERROR(remote_memory_manager_->Init());
|
TF_RETURN_IF_ERROR(remote_memory_manager_->Init());
|
||||||
|
|
||||||
return GrpcServer::Init(nullptr, rendezvous_mgr_func, collective_mgr_func,
|
GrpcServerOptions opts;
|
||||||
worker_func);
|
opts.rendezvous_mgr_func = rendezvous_mgr_func;
|
||||||
|
opts.collective_mgr_func = collective_mgr_func;
|
||||||
|
opts.worker_func = worker_func;
|
||||||
|
return GrpcServer::Init(opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GdrServer::Start() {
|
Status GdrServer::Start() {
|
||||||
|
@ -35,8 +35,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/tracing.h"
|
#include "tensorflow/core/platform/tracing.h"
|
||||||
#include "tensorflow/core/protobuf/transport_options.pb.h"
|
|
||||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
@ -112,12 +112,7 @@ GrpcServer::~GrpcServer() {
|
|||||||
|
|
||||||
void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
|
void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
|
||||||
|
|
||||||
Status GrpcServer::Init(
|
Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
||||||
ServiceInitFunction service_func,
|
|
||||||
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
|
|
||||||
const CollectiveMgrCreationFunction& collective_mgr_func,
|
|
||||||
const WorkerCreationFunction& worker_func,
|
|
||||||
const StatsPublisherFactory& stats_factory) {
|
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
CHECK_EQ(state_, NEW);
|
CHECK_EQ(state_, NEW);
|
||||||
master_env_.env = env_;
|
master_env_.env = env_;
|
||||||
@ -165,9 +160,9 @@ Status GrpcServer::Init(
|
|||||||
worker_env_.device_mgr = new DeviceMgr(std::move(devices));
|
worker_env_.device_mgr = new DeviceMgr(std::move(devices));
|
||||||
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||||
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||||
worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
|
worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
|
||||||
? new RpcRendezvousMgr(&worker_env_)
|
? new RpcRendezvousMgr(&worker_env_)
|
||||||
: rendezvous_mgr_func(&worker_env_);
|
: opts.rendezvous_mgr_func(&worker_env_);
|
||||||
string unused;
|
string unused;
|
||||||
string default_worker_name;
|
string default_worker_name;
|
||||||
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
|
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
|
||||||
@ -200,15 +195,16 @@ Status GrpcServer::Init(
|
|||||||
MaybeMutateBuilder(&builder);
|
MaybeMutateBuilder(&builder);
|
||||||
master_impl_ = CreateMaster(&master_env_);
|
master_impl_ = CreateMaster(&master_env_);
|
||||||
master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
|
master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
|
||||||
worker_impl_ = worker_func ? worker_func(&worker_env_, config)
|
worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
|
||||||
: NewGrpcWorker(&worker_env_, config);
|
: NewGrpcWorker(&worker_env_, config);
|
||||||
worker_service_ =
|
worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
|
||||||
NewGrpcWorkerService(worker_impl_.get(), &builder).release();
|
opts.worker_service_options)
|
||||||
|
.release();
|
||||||
eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
|
eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
|
||||||
|
|
||||||
// extra service:
|
// extra service:
|
||||||
if (service_func != nullptr) {
|
if (opts.service_func != nullptr) {
|
||||||
service_func(&worker_env_, &builder);
|
opts.service_func(&worker_env_, &builder);
|
||||||
}
|
}
|
||||||
server_ = builder.BuildAndStart();
|
server_ = builder.BuildAndStart();
|
||||||
|
|
||||||
@ -222,9 +218,9 @@ Status GrpcServer::Init(
|
|||||||
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
|
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
|
||||||
CHECK_NE(nullptr, worker_cache);
|
CHECK_NE(nullptr, worker_cache);
|
||||||
|
|
||||||
if (collective_mgr_func) {
|
if (opts.collective_mgr_func) {
|
||||||
worker_env_.collective_executor_mgr =
|
worker_env_.collective_executor_mgr =
|
||||||
collective_mgr_func(config, &worker_env_, worker_cache);
|
opts.collective_mgr_func(config, &worker_env_, worker_cache);
|
||||||
if (!worker_env_.collective_executor_mgr) {
|
if (!worker_env_.collective_executor_mgr) {
|
||||||
return errors::Internal(
|
return errors::Internal(
|
||||||
"collective_mgr_func did not return CollectiveExecutorMgr");
|
"collective_mgr_func did not return CollectiveExecutorMgr");
|
||||||
@ -256,6 +252,7 @@ Status GrpcServer::Init(
|
|||||||
master_env_.ops = OpRegistry::Global();
|
master_env_.ops = OpRegistry::Global();
|
||||||
master_env_.worker_cache = worker_cache;
|
master_env_.worker_cache = worker_cache;
|
||||||
master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
|
master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
|
||||||
|
StatsPublisherFactory stats_factory = opts.stats_factory;
|
||||||
master_env_.master_session_factory =
|
master_env_.master_session_factory =
|
||||||
[config, stats_factory](
|
[config, stats_factory](
|
||||||
SessionOptions options, const MasterEnv* env,
|
SessionOptions options, const MasterEnv* env,
|
||||||
@ -282,31 +279,6 @@ Status GrpcServer::Init(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GrpcServer::Init(
|
|
||||||
ServiceInitFunction service_func,
|
|
||||||
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
|
|
||||||
const CollectiveMgrCreationFunction& collective_mgr_func,
|
|
||||||
const WorkerCreationFunction& worker_func) {
|
|
||||||
return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func,
|
|
||||||
worker_func, CreateNoOpStatsPublisher);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GrpcServer::Init(
|
|
||||||
ServiceInitFunction service_func,
|
|
||||||
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
|
|
||||||
const CollectiveMgrCreationFunction& collective_mgr_func) {
|
|
||||||
return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func,
|
|
||||||
nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GrpcServer::Init(
|
|
||||||
ServiceInitFunction service_func,
|
|
||||||
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
|
|
||||||
return Init(std::move(service_func), rendezvous_mgr_func, nullptr, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); }
|
|
||||||
|
|
||||||
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
|
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
|
||||||
GrpcChannelSpec* channel_spec) {
|
GrpcChannelSpec* channel_spec) {
|
||||||
for (const auto& job : options.cluster_def->job()) {
|
for (const auto& job : options.cluster_def->job()) {
|
||||||
@ -457,7 +429,9 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
|
|||||||
std::unique_ptr<GrpcServer> ret(
|
std::unique_ptr<GrpcServer> ret(
|
||||||
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
|
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
|
||||||
ServiceInitFunction service_func = nullptr;
|
ServiceInitFunction service_func = nullptr;
|
||||||
Status s = ret->Init(service_func, NewRpcRendezvousMgr, nullptr);
|
GrpcServerOptions options;
|
||||||
|
options.rendezvous_mgr_func = NewRpcRendezvousMgr;
|
||||||
|
// Pass the options populated above. Calling Init() with no argument would
// use a default-constructed GrpcServerOptions and drop rendezvous_mgr_func.
Status s = ret->Init(options);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
LOG(ERROR) << s;
|
LOG(ERROR) << s;
|
||||||
return s;
|
return s;
|
||||||
@ -471,8 +445,9 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
|
|||||||
std::unique_ptr<GrpcServer>* out_server) {
|
std::unique_ptr<GrpcServer>* out_server) {
|
||||||
std::unique_ptr<GrpcServer> ret(
|
std::unique_ptr<GrpcServer> ret(
|
||||||
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
|
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
|
||||||
ServiceInitFunction service_func = nullptr;
|
GrpcServerOptions options;
|
||||||
Status s = ret->Init(service_func, NewRpcRendezvousMgr, nullptr);
|
options.rendezvous_mgr_func = NewRpcRendezvousMgr;
|
||||||
|
Status s = ret->Init(options);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
LOG(ERROR) << s;
|
LOG(ERROR) << s;
|
||||||
return s;
|
return s;
|
||||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
|
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
|
||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
|
||||||
|
|
||||||
|
// GrpcServer manages the lifecycle of an Eager, Worker and Master service.
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "grpcpp/grpcpp.h"
|
#include "grpcpp/grpcpp.h"
|
||||||
@ -26,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/master_env.h"
|
#include "tensorflow/core/distributed_runtime/master_env.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
|
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
|
||||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||||
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||||
@ -57,6 +60,15 @@ typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*,
|
|||||||
const ConfigProto& config)>
|
const ConfigProto& config)>
|
||||||
WorkerCreationFunction;
|
WorkerCreationFunction;
|
||||||
|
|
||||||
|
struct GrpcServerOptions {
|
||||||
|
ServiceInitFunction service_func = nullptr;
|
||||||
|
RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr;
|
||||||
|
CollectiveMgrCreationFunction collective_mgr_func = nullptr;
|
||||||
|
WorkerCreationFunction worker_func = nullptr;
|
||||||
|
StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
|
||||||
|
GrpcWorkerServiceOptions worker_service_options;
|
||||||
|
};
|
||||||
|
|
||||||
class GrpcServer : public ServerInterface {
|
class GrpcServer : public ServerInterface {
|
||||||
protected:
|
protected:
|
||||||
GrpcServer(const ServerDef& server_def, Env* env);
|
GrpcServer(const ServerDef& server_def, Env* env);
|
||||||
@ -86,25 +98,7 @@ class GrpcServer : public ServerInterface {
|
|||||||
std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
|
std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status Init(ServiceInitFunction service_func,
|
Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
|
||||||
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
|
|
||||||
const CollectiveMgrCreationFunction& collective_mgr_func,
|
|
||||||
const WorkerCreationFunction& worker_func,
|
|
||||||
const StatsPublisherFactory& stats_factory);
|
|
||||||
|
|
||||||
Status Init(ServiceInitFunction service_func,
|
|
||||||
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
|
|
||||||
const CollectiveMgrCreationFunction& collective_mgr_func,
|
|
||||||
const WorkerCreationFunction& worker_func);
|
|
||||||
|
|
||||||
Status Init(ServiceInitFunction service_func,
|
|
||||||
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
|
|
||||||
const CollectiveMgrCreationFunction& collective_mgr_func);
|
|
||||||
|
|
||||||
Status Init(ServiceInitFunction service_func,
|
|
||||||
const RendezvousMgrCreationFunction& rendezvous_mgr_func);
|
|
||||||
|
|
||||||
Status Init();
|
|
||||||
|
|
||||||
// A subclass can override this method to support secure credentials.
|
// A subclass can override this method to support secure credentials.
|
||||||
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
|
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
|
||||||
|
|
||||||
#include <deque>
|
#include <deque>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "grpcpp/alarm.h"
|
#include "grpcpp/alarm.h"
|
||||||
#include "grpcpp/server_builder.h"
|
#include "grpcpp/server_builder.h"
|
||||||
@ -41,6 +42,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/collective.h"
|
#include "tensorflow/core/framework/collective.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/tracing.h"
|
#include "tensorflow/core/platform/tracing.h"
|
||||||
#include "tensorflow/core/protobuf/transport_options.pb.h"
|
#include "tensorflow/core/protobuf/transport_options.pb.h"
|
||||||
@ -50,37 +52,6 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class GrpcWorkerService : public AsyncServiceInterface {
|
|
||||||
// TODO(ncteisen): consider adding a config var or flag for this
|
|
||||||
static constexpr const size_t kGrpcWorkerServiceThreadCount = 8;
|
|
||||||
|
|
||||||
public:
|
|
||||||
GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder)
|
|
||||||
: is_shutdown_(false) {
|
|
||||||
builder->RegisterService(&worker_service_);
|
|
||||||
for (int i = 0; i < kGrpcWorkerServiceThreadCount; i++) {
|
|
||||||
threads_.emplace_back(
|
|
||||||
new GrpcWorkerServiceThread(worker, builder, &worker_service_));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Shutdown() override {
|
|
||||||
bool did_shutdown = false;
|
|
||||||
{
|
|
||||||
mutex_lock l(service_shutdown_mu_);
|
|
||||||
if (!is_shutdown_) {
|
|
||||||
LOG(INFO) << "Shutting down GrpcWorkerService.";
|
|
||||||
is_shutdown_ = true;
|
|
||||||
did_shutdown = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (did_shutdown) {
|
|
||||||
for (auto& worker_thread : threads_) {
|
|
||||||
worker_thread->Shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This macro creates a new request for the given RPC method name
|
// This macro creates a new request for the given RPC method name
|
||||||
// (e.g., `ENQUEUE_REQUEST(GetStatus, false);`), and enqueues it on
|
// (e.g., `ENQUEUE_REQUEST(GetStatus, false);`), and enqueues it on
|
||||||
// `this->cq_`.
|
// `this->cq_`.
|
||||||
@ -105,6 +76,329 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
// Enqueues the initial batch of requests for `method` on this thread's
// completion queue. The batch size is the entry for the method in
// `queue_depth_`, or `default_depth` when the map has no entry.
//
// The lookup result is copied into a local int once: FindWithDefault is called
// with temporaries here, so its result must not be kept as a reference, and
// the depth does not change while the loop runs. The body is wrapped in
// do { } while (0), like ENQUEUE_REQUEST, so the macro behaves as a single
// statement (call sites must end with a semicolon).
#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)      \
  do {                                                                  \
    const int setup_depth = gtl::FindWithDefault(                       \
        queue_depth_, static_cast<int>(GrpcWorkerMethod::k##method),    \
        default_depth);                                                 \
    for (int setup_i = 0; setup_i < setup_depth; ++setup_i) {           \
      ENQUEUE_REQUEST(method, supports_cancel);                         \
    }                                                                   \
  } while (0)
|
||||||
|
|
||||||
|
// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
|
||||||
|
// requests. Each thread operates on an independent completion queue.
|
||||||
|
class GrpcWorkerServiceThread {
|
||||||
|
public:
|
||||||
|
explicit GrpcWorkerServiceThread(
|
||||||
|
GrpcWorker* worker, ::grpc::ServerBuilder* builder,
|
||||||
|
std::unordered_map<int, int> queue_depth,
|
||||||
|
grpc::WorkerService::AsyncService* worker_service)
|
||||||
|
: worker_(worker),
|
||||||
|
queue_depth_(queue_depth),
|
||||||
|
worker_service_(worker_service),
|
||||||
|
is_shutdown_(false) {
|
||||||
|
cq_ = builder->AddCompletionQueue();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Start() {
|
||||||
|
thread_.reset(
|
||||||
|
worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
|
||||||
|
[this]() { HandleRPCsLoop(); }));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Join() { thread_.reset(); } // Blocks until thread exits
|
||||||
|
|
||||||
|
void Shutdown() {
|
||||||
|
{
|
||||||
|
mutex_lock lock(shutdown_mu_);
|
||||||
|
is_shutdown_ = true;
|
||||||
|
}
|
||||||
|
cq_->Shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Add one or more completion queue entries for each worker method, then
|
||||||
|
// begin servicing requests from the completion queue.
|
||||||
|
void HandleRPCsLoop() {
|
||||||
|
// TODO(ncteisen): This may require performance engineering. We can
|
||||||
|
// change the number of threads, the number of handlers per thread,
|
||||||
|
// or even decide to specialize certain threads to certain methods.
|
||||||
|
SETUP_FOR_REQUEST(GetStatus, 1, false);
|
||||||
|
SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
|
||||||
|
SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
|
||||||
|
SETUP_FOR_REQUEST(CleanupAll, 1, false);
|
||||||
|
SETUP_FOR_REQUEST(RegisterGraph, 1, false);
|
||||||
|
SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
|
||||||
|
SETUP_FOR_REQUEST(Logging, 1, false);
|
||||||
|
SETUP_FOR_REQUEST(Tracing, 1, false);
|
||||||
|
SETUP_FOR_REQUEST(CompleteGroup, 10, true);
|
||||||
|
SETUP_FOR_REQUEST(CompleteInstance, 10, true);
|
||||||
|
SETUP_FOR_REQUEST(GetStepSequence, 10, true);
|
||||||
|
SETUP_FOR_REQUEST(RecvBuf, 500, true);
|
||||||
|
SETUP_FOR_REQUEST(RunGraph, 100, true);
|
||||||
|
SETUP_FOR_REQUEST(CleanupGraph, 100, false);
|
||||||
|
|
||||||
|
// TODO(ncteisen): Determine a better policy for enqueuing the
|
||||||
|
// appropriate number of each request type.
|
||||||
|
for (int i = 0;
|
||||||
|
i < gtl::FindWithDefault(
|
||||||
|
queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
|
||||||
|
1000);
|
||||||
|
++i) {
|
||||||
|
EnqueueRecvTensorRequestRaw();
|
||||||
|
}
|
||||||
|
|
||||||
|
void* tag;
|
||||||
|
bool ok;
|
||||||
|
|
||||||
|
while (cq_->Next(&tag, &ok)) {
|
||||||
|
UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
|
||||||
|
static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
|
||||||
|
CHECK(callback_tag);
|
||||||
|
callback_tag->OnCompleted(this, ok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Schedule(std::function<void()> f) {
|
||||||
|
worker_->env()->compute_pool->Schedule(std::move(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following section contains one request handler method per
|
||||||
|
// RPC. The `FooHandler` method is called (indirectly) by
|
||||||
|
// `HandleRPCsLoop()` when the next Foo RPC is received. Each
|
||||||
|
// `FooHandler` call schedules a closure on `worker_->env()->compute_pool`,
|
||||||
|
// and is responsible for requesting the next Foo call by calling
|
||||||
|
// `ENQUEUE_REQUEST(Foo)`.
|
||||||
|
template <class RequestMessage, class ResponseMessage>
|
||||||
|
using WorkerCall =
|
||||||
|
Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
|
||||||
|
RequestMessage, ResponseMessage>;
|
||||||
|
|
||||||
|
void GetStatusHandler(WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->GetStatus(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(GetStatus, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CreateWorkerSessionHandler(
|
||||||
|
WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
|
||||||
|
call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->CreateWorkerSession(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(CreateWorkerSession, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeleteWorkerSessionHandler(
|
||||||
|
WorkerCall<DeleteWorkerSessionRequest, DeleteWorkerSessionResponse>*
|
||||||
|
call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->DeleteWorkerSession(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(DeleteWorkerSession, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CleanupAllHandler(
|
||||||
|
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->CleanupAll(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(CleanupAll, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegisterGraphHandler(
|
||||||
|
WorkerCall<RegisterGraphRequest, RegisterGraphResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->RegisterGraph(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(RegisterGraph, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeregisterGraphHandler(
|
||||||
|
WorkerCall<DeregisterGraphRequest, DeregisterGraphResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->DeregisterGraph(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(DeregisterGraph, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
CallOptions* call_opts = new CallOptions;
|
||||||
|
ProtoRunGraphRequest* wrapped_request =
|
||||||
|
new ProtoRunGraphRequest(&call->request);
|
||||||
|
NonOwnedProtoRunGraphResponse* wrapped_response =
|
||||||
|
new NonOwnedProtoRunGraphResponse(&call->response);
|
||||||
|
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
||||||
|
worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
|
||||||
|
[call, call_opts, wrapped_request,
|
||||||
|
wrapped_response](const Status& s) {
|
||||||
|
call->ClearCancelCallback();
|
||||||
|
delete call_opts;
|
||||||
|
delete wrapped_request;
|
||||||
|
delete wrapped_response;
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(RunGraph, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RecvTensorHandlerRaw(
|
||||||
|
WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
CallOptions* call_opts = new CallOptions;
|
||||||
|
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
||||||
|
worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response,
|
||||||
|
[call, call_opts](const Status& s) {
|
||||||
|
call->ClearCancelCallback();
|
||||||
|
delete call_opts;
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
EnqueueRecvTensorRequestRaw();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CleanupGraphHandler(
|
||||||
|
WorkerCall<CleanupGraphRequest, CleanupGraphResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->CleanupGraph(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(CleanupGraph, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->Logging(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(Logging, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
Status s = worker_->Tracing(&call->request, &call->response);
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(Tracing, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
CallOptions* call_opts = new CallOptions;
|
||||||
|
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
||||||
|
worker_->RecvBufAsync(call_opts, &call->request, &call->response,
|
||||||
|
[call, call_opts](const Status& s) {
|
||||||
|
call->ClearCancelCallback();
|
||||||
|
delete call_opts;
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(RecvBuf, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CompleteGroupHandler(
|
||||||
|
WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
CallOptions* call_opts = new CallOptions;
|
||||||
|
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
||||||
|
worker_->CompleteGroupAsync(call_opts, &call->request, &call->response,
|
||||||
|
[call, call_opts](const Status& s) {
|
||||||
|
call->ClearCancelCallback();
|
||||||
|
delete call_opts;
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(CompleteGroup, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CompleteInstanceHandler(
|
||||||
|
WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
CallOptions* call_opts = new CallOptions;
|
||||||
|
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
||||||
|
worker_->CompleteInstanceAsync(call_opts, &call->request, &call->response,
|
||||||
|
[call, call_opts](const Status& s) {
|
||||||
|
call->ClearCancelCallback();
|
||||||
|
delete call_opts;
|
||||||
|
call->SendResponse(ToGrpcStatus(s));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(CompleteInstance, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetStepSequenceHandler(
|
||||||
|
WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
|
||||||
|
Schedule([this, call]() {
|
||||||
|
worker_->GetStepSequenceAsync(
|
||||||
|
&call->request, &call->response,
|
||||||
|
[call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); });
|
||||||
|
});
|
||||||
|
ENQUEUE_REQUEST(GetStepSequence, true);
|
||||||
|
}
|
||||||
|
#undef ENQUEUE_REQUEST
|
||||||
|
|
||||||
|
void EnqueueRecvTensorRequestRaw() {
|
||||||
|
mutex_lock l(shutdown_mu_);
|
||||||
|
if (!is_shutdown_) {
|
||||||
|
Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
|
||||||
|
RecvTensorRequest, ::grpc::ByteBuffer>::
|
||||||
|
EnqueueRequestForMethod(
|
||||||
|
worker_service_, cq_.get(),
|
||||||
|
static_cast<int>(GrpcWorkerMethod::kRecvTensor),
|
||||||
|
&GrpcWorkerServiceThread::RecvTensorHandlerRaw,
|
||||||
|
true /* supports cancel*/);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Worker whose methods the handlers call to execute each RPC.
GrpcWorker* const worker_ = nullptr; // Not owned.
|
||||||
|
// Completion queue handed to EnqueueRequestForMethod when posting requests.
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
|
||||||
|
// NOTE(review): presumably the thread that polls cq_; it is started outside
// this excerpt -- confirm.
std::unique_ptr<Thread> thread_;
|
||||||
|
// NOTE(review): looks like per-method queue depths keyed by GrpcWorkerMethod
// id (cf. GrpcWorkerServiceOptions::queue_depth) -- confirm in the constructor.
std::unordered_map<int, int> queue_depth_;
|
||||||
|
// Async service passed to EnqueueRequestForMethod alongside cq_.
grpc::WorkerService::AsyncService* const worker_service_;
|
||||||
|
|
||||||
|
// Protects is_shutdown_.
mutex shutdown_mu_;
|
||||||
|
// Once true, EnqueueRecvTensorRequestRaw() stops posting new requests.
bool is_shutdown_ GUARDED_BY(shutdown_mu_);
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
|
||||||
|
};
|
||||||
|
|
||||||
|
class GrpcWorkerService : public AsyncServiceInterface {
|
||||||
|
public:
|
||||||
|
// Registers the async worker service with `builder`, then constructs
// `options.num_worker_threads` GrpcWorkerServiceThread objects, passing each
// one `options.queue_depth` and a pointer to the shared service.
GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                  GrpcWorkerServiceOptions options)
    : is_shutdown_(false) {
  builder->RegisterService(&worker_service_);
  for (int created = 0; created < options.num_worker_threads; ++created) {
    std::unique_ptr<GrpcWorkerServiceThread> service_thread(
        new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                    &worker_service_));
    threads_.push_back(std::move(service_thread));
  }
}
|
||||||
|
|
||||||
|
void Shutdown() override {
|
||||||
|
bool did_shutdown = false;
|
||||||
|
{
|
||||||
|
mutex_lock l(service_shutdown_mu_);
|
||||||
|
if (!is_shutdown_) {
|
||||||
|
LOG(INFO) << "Shutting down GrpcWorkerService.";
|
||||||
|
is_shutdown_ = true;
|
||||||
|
did_shutdown = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (did_shutdown) {
|
||||||
|
for (auto& worker_thread : threads_) {
|
||||||
|
worker_thread->Shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// This method blocks forever handling requests from the completion queue.
|
// This method blocks forever handling requests from the completion queue.
|
||||||
void HandleRPCsLoop() override {
|
void HandleRPCsLoop() override {
|
||||||
for (auto& worker_thread : threads_) {
|
for (auto& worker_thread : threads_) {
|
||||||
@ -116,297 +410,6 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Thread wrapping class that drives work over a single gRPC
|
|
||||||
// CompletionQueue.
|
|
||||||
class GrpcWorkerServiceThread {
|
|
||||||
public:
|
|
||||||
explicit GrpcWorkerServiceThread(
|
|
||||||
GrpcWorker* worker, ::grpc::ServerBuilder* builder,
|
|
||||||
grpc::WorkerService::AsyncService* worker_service)
|
|
||||||
: worker_(worker),
|
|
||||||
worker_service_(worker_service),
|
|
||||||
is_shutdown_(false) {
|
|
||||||
cq_ = builder->AddCompletionQueue();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Start() {
|
|
||||||
thread_.reset(worker_->env()->env->StartThread(
|
|
||||||
ThreadOptions(), "grpc_worker_service",
|
|
||||||
[this]() { HandleRPCsLoop(); }));
|
|
||||||
}
|
|
||||||
|
|
||||||
void Join() { thread_.reset(); } // Blocks until thread exits
|
|
||||||
|
|
||||||
void Shutdown() {
|
|
||||||
{
|
|
||||||
mutex_lock lock(shutdown_mu_);
|
|
||||||
is_shutdown_ = true;
|
|
||||||
}
|
|
||||||
cq_->Shutdown();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
void HandleRPCsLoop() {
|
|
||||||
// TODO(ncteisen): This may require performance engineering. We can
|
|
||||||
// change the number of threads, the number of handlers per thread,
|
|
||||||
// or even decide to specialize certain threads to certain methods.
|
|
||||||
ENQUEUE_REQUEST(GetStatus, false);
|
|
||||||
ENQUEUE_REQUEST(CreateWorkerSession, false);
|
|
||||||
ENQUEUE_REQUEST(DeleteWorkerSession, false);
|
|
||||||
ENQUEUE_REQUEST(CleanupAll, false);
|
|
||||||
ENQUEUE_REQUEST(RegisterGraph, false);
|
|
||||||
ENQUEUE_REQUEST(DeregisterGraph, false);
|
|
||||||
|
|
||||||
// TODO(ncteisen): Determine a better policy for enqueuing the
|
|
||||||
// appropriate number of each request type.
|
|
||||||
for (int i = 0; i < 1000; ++i) {
|
|
||||||
EnqueueRecvTensorRequestRaw();
|
|
||||||
}
|
|
||||||
for (int i = 0; i < 500; ++i) {
|
|
||||||
ENQUEUE_REQUEST(RecvBuf, true);
|
|
||||||
}
|
|
||||||
for (int i = 0; i < 100; ++i) {
|
|
||||||
ENQUEUE_REQUEST(RunGraph, true);
|
|
||||||
}
|
|
||||||
for (int i = 0; i < 100; ++i) {
|
|
||||||
ENQUEUE_REQUEST(CleanupGraph, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
ENQUEUE_REQUEST(Logging, false);
|
|
||||||
ENQUEUE_REQUEST(Tracing, false);
|
|
||||||
|
|
||||||
for (int i = 0; i < 10; ++i) {
|
|
||||||
ENQUEUE_REQUEST(CompleteGroup, true);
|
|
||||||
ENQUEUE_REQUEST(CompleteInstance, true);
|
|
||||||
ENQUEUE_REQUEST(GetStepSequence, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* tag;
|
|
||||||
bool ok;
|
|
||||||
|
|
||||||
while (cq_->Next(&tag, &ok)) {
|
|
||||||
UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
|
|
||||||
static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
|
|
||||||
CHECK(callback_tag);
|
|
||||||
callback_tag->OnCompleted(this, ok);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
void Schedule(std::function<void()> f) {
|
|
||||||
worker_->env()->compute_pool->Schedule(std::move(f));
|
|
||||||
}
|
|
||||||
|
|
||||||
// The following section contains one request handler method per
|
|
||||||
// RPC. The `FooHandler` method is called (indirectly) by
|
|
||||||
// `HandleRPCsLoop()` when the next Foo RPC is received. Each
|
|
||||||
// `FooHandler` call schedules a closure on `worker_->env()->compute_pool`,
|
|
||||||
// and is responsible for requesting the next Foo call by calling
|
|
||||||
// `ENQUEUE_REQUEST(Foo)`.
|
|
||||||
|
|
||||||
template <class RequestMessage, class ResponseMessage>
|
|
||||||
using WorkerCall =
|
|
||||||
Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
|
|
||||||
RequestMessage, ResponseMessage>;
|
|
||||||
|
|
||||||
void GetStatusHandler(
|
|
||||||
WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s = worker_->GetStatus(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(GetStatus, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateWorkerSessionHandler(
|
|
||||||
WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
|
|
||||||
call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s =
|
|
||||||
worker_->CreateWorkerSession(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(CreateWorkerSession, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DeleteWorkerSessionHandler(
|
|
||||||
WorkerCall<DeleteWorkerSessionRequest, DeleteWorkerSessionResponse>*
|
|
||||||
call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s =
|
|
||||||
worker_->DeleteWorkerSession(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(DeleteWorkerSession, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CleanupAllHandler(
|
|
||||||
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s = worker_->CleanupAll(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(CleanupAll, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RegisterGraphHandler(
|
|
||||||
WorkerCall<RegisterGraphRequest, RegisterGraphResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s = worker_->RegisterGraph(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(RegisterGraph, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DeregisterGraphHandler(
|
|
||||||
WorkerCall<DeregisterGraphRequest, DeregisterGraphResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s = worker_->DeregisterGraph(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(DeregisterGraph, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
CallOptions* call_opts = new CallOptions;
|
|
||||||
ProtoRunGraphRequest* wrapped_request =
|
|
||||||
new ProtoRunGraphRequest(&call->request);
|
|
||||||
NonOwnedProtoRunGraphResponse* wrapped_response =
|
|
||||||
new NonOwnedProtoRunGraphResponse(&call->response);
|
|
||||||
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
|
||||||
worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
|
|
||||||
[call, call_opts, wrapped_request,
|
|
||||||
wrapped_response](const Status& s) {
|
|
||||||
call->ClearCancelCallback();
|
|
||||||
delete call_opts;
|
|
||||||
delete wrapped_request;
|
|
||||||
delete wrapped_response;
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(RunGraph, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RecvTensorHandlerRaw(
|
|
||||||
WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
CallOptions* call_opts = new CallOptions;
|
|
||||||
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
|
||||||
worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response,
|
|
||||||
[call, call_opts](const Status& s) {
|
|
||||||
call->ClearCancelCallback();
|
|
||||||
delete call_opts;
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
});
|
|
||||||
EnqueueRecvTensorRequestRaw();
|
|
||||||
}
|
|
||||||
|
|
||||||
void CleanupGraphHandler(
|
|
||||||
WorkerCall<CleanupGraphRequest, CleanupGraphResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s = worker_->CleanupGraph(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(CleanupGraph, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s = worker_->Logging(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(Logging, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
Status s = worker_->Tracing(&call->request, &call->response);
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(Tracing, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
CallOptions* call_opts = new CallOptions;
|
|
||||||
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
|
||||||
worker_->RecvBufAsync(call_opts, &call->request, &call->response,
|
|
||||||
[call, call_opts](const Status& s) {
|
|
||||||
call->ClearCancelCallback();
|
|
||||||
delete call_opts;
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(RecvBuf, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CompleteGroupHandler(
|
|
||||||
WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
CallOptions* call_opts = new CallOptions;
|
|
||||||
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
|
||||||
worker_->CompleteGroupAsync(call_opts, &call->request, &call->response,
|
|
||||||
[call, call_opts](const Status& s) {
|
|
||||||
call->ClearCancelCallback();
|
|
||||||
delete call_opts;
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(CompleteGroup, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CompleteInstanceHandler(
|
|
||||||
WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
CallOptions* call_opts = new CallOptions;
|
|
||||||
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
|
|
||||||
worker_->CompleteInstanceAsync(call_opts, &call->request,
|
|
||||||
&call->response,
|
|
||||||
[call, call_opts](const Status& s) {
|
|
||||||
call->ClearCancelCallback();
|
|
||||||
delete call_opts;
|
|
||||||
call->SendResponse(ToGrpcStatus(s));
|
|
||||||
});
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(CompleteInstance, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void GetStepSequenceHandler(
|
|
||||||
WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
|
|
||||||
Schedule([this, call]() {
|
|
||||||
worker_->GetStepSequenceAsync(
|
|
||||||
&call->request, &call->response,
|
|
||||||
[call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); });
|
|
||||||
});
|
|
||||||
ENQUEUE_REQUEST(GetStepSequence, true);
|
|
||||||
}
|
|
||||||
#undef ENQUEUE_REQUEST
|
|
||||||
|
|
||||||
void EnqueueRecvTensorRequestRaw() {
|
|
||||||
mutex_lock l(shutdown_mu_);
|
|
||||||
if (!is_shutdown_) {
|
|
||||||
Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
|
|
||||||
RecvTensorRequest, ::grpc::ByteBuffer>::
|
|
||||||
EnqueueRequestForMethod(
|
|
||||||
worker_service_, cq_.get(),
|
|
||||||
static_cast<int>(GrpcWorkerMethod::kRecvTensor),
|
|
||||||
&GrpcWorkerServiceThread::RecvTensorHandlerRaw,
|
|
||||||
true /* supports cancel*/);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
GrpcWorker* const worker_ = nullptr; // Not owned.
|
|
||||||
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
|
|
||||||
std::unique_ptr<Thread> thread_;
|
|
||||||
grpc::WorkerService::AsyncService* const worker_service_;
|
|
||||||
|
|
||||||
mutex shutdown_mu_;
|
|
||||||
bool is_shutdown_ GUARDED_BY(shutdown_mu_);
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
|
|
||||||
}; // GrpcWorkerServiceThread
|
|
||||||
|
|
||||||
grpc::WorkerService::AsyncService worker_service_;
|
grpc::WorkerService::AsyncService worker_service_;
|
||||||
std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
|
std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
|
||||||
|
|
||||||
@ -640,9 +643,10 @@ std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
|
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
|
||||||
GrpcWorker* worker, ::grpc::ServerBuilder* builder) {
|
GrpcWorker* worker, ::grpc::ServerBuilder* builder,
|
||||||
|
GrpcWorkerServiceOptions options) {
|
||||||
return std::unique_ptr<AsyncServiceInterface>(
|
return std::unique_ptr<AsyncServiceInterface>(
|
||||||
new GrpcWorkerService(worker, builder));
|
new GrpcWorkerService(worker, builder, options));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,7 +16,9 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
|
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker.h"
|
#include "tensorflow/core/distributed_runtime/worker.h"
|
||||||
|
|
||||||
namespace grpc {
|
namespace grpc {
|
||||||
@ -57,9 +59,17 @@ class GrpcWorker : public Worker {
|
|||||||
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env,
|
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env,
|
||||||
const ConfigProto& config);
|
const ConfigProto& config);
|
||||||
|
|
||||||
|
// Options accepted by NewGrpcWorkerService() / the GrpcWorkerService
// constructor.
struct GrpcWorkerServiceOptions {
|
||||||
|
// Map from GrpcWorkerMethod id to queue depth. If set this overrides the
|
||||||
|
// default queue depth for a method.
|
||||||
|
std::unordered_map<int, int> queue_depth;
|
||||||
|
// Number of GrpcWorkerServiceThread objects the service constructs.
int num_worker_threads = 8;
|
||||||
|
};
|
||||||
|
|
||||||
// Returns an implementation of WorkerService rpc service.
|
// Returns an implementation of WorkerService rpc service.
|
||||||
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
|
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
|
||||||
GrpcWorker* worker, ::grpc::ServerBuilder* builder);
|
GrpcWorker* worker, ::grpc::ServerBuilder* builder,
|
||||||
|
GrpcWorkerServiceOptions opts = GrpcWorkerServiceOptions());
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -88,6 +88,7 @@ enum class GrpcWorkerMethod {
|
|||||||
kCompleteInstance,
|
kCompleteInstance,
|
||||||
kGetStepSequence,
|
kGetStepSequence,
|
||||||
};
|
};
|
||||||
|
|
||||||
static const int kGrpcNumWorkerMethods =
|
static const int kGrpcNumWorkerMethods =
|
||||||
static_cast<int>(GrpcWorkerMethod::kGetStepSequence) + 1;
|
static_cast<int>(GrpcWorkerMethod::kGetStepSequence) + 1;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user