diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index 89314d162b3..bf8b66dcfa5 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -17,11 +17,6 @@ filegroup( ]), ) -load( - "//tensorflow:tensorflow.bzl", - "tf_cuda_library", -) - # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", @@ -66,7 +61,6 @@ cc_library( ":gdr_memory_manager", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:graph_mgr", @@ -108,15 +102,13 @@ cc_library( ":gdr_memory_manager", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:cancellable_call", "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", "//tensorflow/core/distributed_runtime:device_resolver_distributed", "//tensorflow/core/distributed_runtime:request_id", "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", "//tensorflow/core/distributed_runtime:worker_cache", - "//tensorflow/core/distributed_runtime:worker_env", - "//tensorflow/core/distributed_runtime:worker_interface", ], ) @@ -130,6 +122,9 @@ cc_library( ":gdr_memory_manager", ":gdr_rendezvous_mgr", ":gdr_worker", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], alwayslink = 1, diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h index a713c1c6b0b..1417e51e82c 100644 --- a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h @@ -16,14 +16,14 @@ limitations under the License. #define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ #include "tensorflow/contrib/gdr/gdr_memory_manager.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" #include "tensorflow/core/framework/collective.h" namespace tensorflow { -class CollectiveParamResolverDistributed; class ConfigProto; class DeviceMgr; -class DeviceResolverDistributed; class WorkerCacheInterface; class StepSequenceRequest; class StepSequenceResponse; diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index e4ae1eec231..c39cc0f9bce 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -82,8 +82,11 @@ Status GdrServer::Init() { }; TF_RETURN_IF_ERROR(remote_memory_manager_->Init()); - return GrpcServer::Init(nullptr, rendezvous_mgr_func, collective_mgr_func, - worker_func); + GrpcServerOptions opts; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + opts.collective_mgr_func = collective_mgr_func; + opts.worker_func = worker_func; + return GrpcServer::Init(opts); } Status GdrServer::Start() { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 86897e3c8ef..1204b8ca501 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -35,8 +35,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" -#include "tensorflow/core/protobuf/transport_options.pb.h" -#include "tensorflow/core/protobuf/worker.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index ac73182190f..1405c760d54 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -112,12 +112,7 @@ GrpcServer::~GrpcServer() { void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {} -Status GrpcServer::Init( - ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func, - const CollectiveMgrCreationFunction& collective_mgr_func, - const WorkerCreationFunction& worker_func, - const StatsPublisherFactory& stats_factory) { +Status GrpcServer::Init(const GrpcServerOptions& opts) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -165,9 +160,9 @@ Status GrpcServer::Init( worker_env_.device_mgr = new DeviceMgr(std::move(devices)); master_env_.local_devices = worker_env_.device_mgr->ListDevices(); worker_env_.local_devices = worker_env_.device_mgr->ListDevices(); - worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr + worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr ? new RpcRendezvousMgr(&worker_env_) - : rendezvous_mgr_func(&worker_env_); + : opts.rendezvous_mgr_func(&worker_env_); string unused; string default_worker_name; if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(), @@ -200,15 +195,16 @@ Status GrpcServer::Init( MaybeMutateBuilder(&builder); master_impl_ = CreateMaster(&master_env_); master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder); - worker_impl_ = worker_func ? worker_func(&worker_env_, config) - : NewGrpcWorker(&worker_env_, config); - worker_service_ = - NewGrpcWorkerService(worker_impl_.get(), &builder).release(); + worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config) + : NewGrpcWorker(&worker_env_, config); + worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder, + opts.worker_service_options) + .release(); eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder); // extra service: - if (service_func != nullptr) { - service_func(&worker_env_, &builder); + if (opts.service_func != nullptr) { + opts.service_func(&worker_env_, &builder); } server_ = builder.BuildAndStart(); @@ -222,9 +218,9 @@ Status GrpcServer::Init( WorkerCacheFactory(worker_cache_factory_options, &worker_cache)); CHECK_NE(nullptr, worker_cache); - if (collective_mgr_func) { + if (opts.collective_mgr_func) { worker_env_.collective_executor_mgr = - collective_mgr_func(config, &worker_env_, worker_cache); + opts.collective_mgr_func(config, &worker_env_, worker_cache); if (!worker_env_.collective_executor_mgr) { return errors::Internal( "collective_mgr_func did not return CollectiveExecutorMgr"); @@ -256,6 +252,7 @@ Status GrpcServer::Init( master_env_.ops = OpRegistry::Global(); master_env_.worker_cache = worker_cache; master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr; + StatsPublisherFactory stats_factory = opts.stats_factory; master_env_.master_session_factory = [config, stats_factory]( SessionOptions options, const MasterEnv* env, @@ -282,31 +279,6 @@ Status GrpcServer::Init( return Status::OK(); } -Status GrpcServer::Init( - ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func, - const CollectiveMgrCreationFunction& collective_mgr_func, - const WorkerCreationFunction& worker_func) { - return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func, - worker_func, CreateNoOpStatsPublisher); -} - -Status GrpcServer::Init( - ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func, - const CollectiveMgrCreationFunction& collective_mgr_func) { - return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func, - nullptr); -} - -Status GrpcServer::Init( - ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func) { - return Init(std::move(service_func), rendezvous_mgr_func, nullptr, nullptr); -} - -Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); } - Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec) { for (const auto& job : options.cluster_def->job()) { @@ -457,7 +429,9 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, std::unique_ptr ret( new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); ServiceInitFunction service_func = nullptr; - Status s = ret->Init(service_func, NewRpcRendezvousMgr, nullptr); + GrpcServerOptions options; + options.rendezvous_mgr_func = NewRpcRendezvousMgr; + Status s = ret->Init(); if (!s.ok()) { LOG(ERROR) << s; return s; @@ -471,8 +445,9 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, std::unique_ptr* out_server) { std::unique_ptr ret( new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); - ServiceInitFunction service_func = nullptr; - Status s = ret->Init(service_func, NewRpcRendezvousMgr, nullptr); + GrpcServerOptions options; + options.rendezvous_mgr_func = NewRpcRendezvousMgr; + Status s = ret->Init(options); if (!s.ok()) { LOG(ERROR) << s; return s; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index c7f543e5bfc..f66d7eb82e8 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ +// GrpcServer manages the lifecycle of an Eager, Worker and Master service. + #include #include "grpcpp/grpcpp.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_env.h" @@ -57,6 +60,15 @@ typedef std::function(WorkerEnv*, const ConfigProto& config)> WorkerCreationFunction; +struct GrpcServerOptions { + ServiceInitFunction service_func = nullptr; + RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr; + CollectiveMgrCreationFunction collective_mgr_func = nullptr; + WorkerCreationFunction worker_func = nullptr; + StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; + GrpcWorkerServiceOptions worker_service_options; +}; + class GrpcServer : public ServerInterface { protected: GrpcServer(const ServerDef& server_def, Env* env); @@ -86,25 +98,7 @@ class GrpcServer : public ServerInterface { std::shared_ptr channel_cache() { return channel_cache_; } protected: - Status Init(ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func, - const CollectiveMgrCreationFunction& collective_mgr_func, - const WorkerCreationFunction& worker_func, - const StatsPublisherFactory& stats_factory); - - Status Init(ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func, - const CollectiveMgrCreationFunction& collective_mgr_func, - const WorkerCreationFunction& worker_func); - - Status Init(ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func, - const CollectiveMgrCreationFunction& collective_mgr_func); - - Status Init(ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func); - - Status Init(); + Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); // A subclass can override this method to support secure credentials. virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index de80992095d..db0ad6124cc 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include +#include #include "grpcpp/alarm.h" #include "grpcpp/server_builder.h" @@ -41,6 +42,7 @@ limitations under the License. #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/protobuf/transport_options.pb.h" @@ -50,37 +52,6 @@ namespace tensorflow { namespace { -class GrpcWorkerService : public AsyncServiceInterface { - // TODO(ncteisen): consider adding a config var or flag for this - static constexpr const size_t kGrpcWorkerServiceThreadCount = 8; - - public: - GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder) - : is_shutdown_(false) { - builder->RegisterService(&worker_service_); - for (int i = 0; i < kGrpcWorkerServiceThreadCount; i++) { - threads_.emplace_back( - new GrpcWorkerServiceThread(worker, builder, &worker_service_)); - } - } - - void Shutdown() override { - bool did_shutdown = false; - { - mutex_lock l(service_shutdown_mu_); - if (!is_shutdown_) { - LOG(INFO) << "Shutting down GrpcWorkerService."; - is_shutdown_ = true; - did_shutdown = true; - } - } - if (did_shutdown) { - for (auto& worker_thread : threads_) { - worker_thread->Shutdown(); - } - } - } - // This macro creates a new request for the given RPC method name // (e.g., `ENQUEUE_REQUEST(GetStatus, false);`), and enqueues it on // `this->cq_`. @@ -105,6 +76,329 @@ class GrpcWorkerService : public AsyncServiceInterface { } \ } while (0) +#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel) \ + for (int i = 0; \ + i < gtl::FindWithDefault(queue_depth_, \ + static_cast(GrpcWorkerMethod::k##method), \ + default_depth); \ + ++i) { \ + ENQUEUE_REQUEST(method, supports_cancel); \ + } + +// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service +// requests. Each thread operates on an independent completion queue. +class GrpcWorkerServiceThread { + public: + explicit GrpcWorkerServiceThread( + GrpcWorker* worker, ::grpc::ServerBuilder* builder, + std::unordered_map queue_depth, + grpc::WorkerService::AsyncService* worker_service) + : worker_(worker), + queue_depth_(queue_depth), + worker_service_(worker_service), + is_shutdown_(false) { + cq_ = builder->AddCompletionQueue(); + } + + void Start() { + thread_.reset( + worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service", + [this]() { HandleRPCsLoop(); })); + } + + void Join() { thread_.reset(); } // Blocks until thread exits + + void Shutdown() { + { + mutex_lock lock(shutdown_mu_); + is_shutdown_ = true; + } + cq_->Shutdown(); + } + + private: + // Add one or more completion queue entries for each worker method, then + // begin servicing requests from the completion queue. + void HandleRPCsLoop() { + // TODO(ncteisen): This may require performance engineering. We can + // change the number of threads, the number of handlers per thread, + // or even decide to specialize certain threads to certain methods. + SETUP_FOR_REQUEST(GetStatus, 1, false); + SETUP_FOR_REQUEST(CreateWorkerSession, 1, false); + SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false); + SETUP_FOR_REQUEST(CleanupAll, 1, false); + SETUP_FOR_REQUEST(RegisterGraph, 1, false); + SETUP_FOR_REQUEST(DeregisterGraph, 1, false); + SETUP_FOR_REQUEST(Logging, 1, false); + SETUP_FOR_REQUEST(Tracing, 1, false); + SETUP_FOR_REQUEST(CompleteGroup, 10, true); + SETUP_FOR_REQUEST(CompleteInstance, 10, true); + SETUP_FOR_REQUEST(GetStepSequence, 10, true); + SETUP_FOR_REQUEST(RecvBuf, 500, true); + SETUP_FOR_REQUEST(RunGraph, 100, true); + SETUP_FOR_REQUEST(CleanupGraph, 100, false); + + // TODO(ncteisen): Determine a better policy for enqueuing the + // appropriate number of each request type. + for (int i = 0; + i < gtl::FindWithDefault( + queue_depth_, static_cast(GrpcWorkerMethod::kRecvTensor), + 1000); + ++i) { + EnqueueRecvTensorRequestRaw(); + } + + void* tag; + bool ok; + + while (cq_->Next(&tag, &ok)) { + UntypedCall::Tag* callback_tag = + static_cast::Tag*>(tag); + CHECK(callback_tag); + callback_tag->OnCompleted(this, ok); + } + } + + private: + void Schedule(std::function f) { + worker_->env()->compute_pool->Schedule(std::move(f)); + } + + // The following section contains one request handler method per + // RPC. The `FooHandler` method is called (indirectly) by + // `HandleRPCsLoop()` when the next Foo RPC is received. Each + // `FooHandler` call schedules a closure on `worker_->env()->compute_pool`, + // and is responsible for requesting the next Foo call by calling + // `ENQUEUE_REQUEST(Foo)`. + template + using WorkerCall = + Call; + + void GetStatusHandler(WorkerCall* call) { + Schedule([this, call]() { + Status s = worker_->GetStatus(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(GetStatus, false); + } + + void CreateWorkerSessionHandler( + WorkerCall* + call) { + Schedule([this, call]() { + Status s = worker_->CreateWorkerSession(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(CreateWorkerSession, false); + } + + void DeleteWorkerSessionHandler( + WorkerCall* + call) { + Schedule([this, call]() { + Status s = worker_->DeleteWorkerSession(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(DeleteWorkerSession, false); + } + + void CleanupAllHandler( + WorkerCall* call) { + Schedule([this, call]() { + Status s = worker_->CleanupAll(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(CleanupAll, false); + } + + void RegisterGraphHandler( + WorkerCall* call) { + Schedule([this, call]() { + Status s = worker_->RegisterGraph(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(RegisterGraph, false); + } + + void DeregisterGraphHandler( + WorkerCall* call) { + Schedule([this, call]() { + Status s = worker_->DeregisterGraph(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(DeregisterGraph, false); + } + + void RunGraphHandler(WorkerCall* call) { + Schedule([this, call]() { + CallOptions* call_opts = new CallOptions; + ProtoRunGraphRequest* wrapped_request = + new ProtoRunGraphRequest(&call->request); + NonOwnedProtoRunGraphResponse* wrapped_response = + new NonOwnedProtoRunGraphResponse(&call->response); + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response, + [call, call_opts, wrapped_request, + wrapped_response](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + delete wrapped_request; + delete wrapped_response; + call->SendResponse(ToGrpcStatus(s)); + }); + }); + ENQUEUE_REQUEST(RunGraph, true); + } + + void RecvTensorHandlerRaw( + WorkerCall* call) { + Schedule([this, call]() { + CallOptions* call_opts = new CallOptions; + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response, + [call, call_opts](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + call->SendResponse(ToGrpcStatus(s)); + }); + }); + EnqueueRecvTensorRequestRaw(); + } + + void CleanupGraphHandler( + WorkerCall* call) { + Schedule([this, call]() { + Status s = worker_->CleanupGraph(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(CleanupGraph, false); + } + + void LoggingHandler(WorkerCall* call) { + Schedule([this, call]() { + Status s = worker_->Logging(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(Logging, false); + } + + void TracingHandler(WorkerCall* call) { + Schedule([this, call]() { + Status s = worker_->Tracing(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(Tracing, false); + } + + void RecvBufHandler(WorkerCall* call) { + Schedule([this, call]() { + CallOptions* call_opts = new CallOptions; + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + worker_->RecvBufAsync(call_opts, &call->request, &call->response, + [call, call_opts](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + call->SendResponse(ToGrpcStatus(s)); + }); + }); + ENQUEUE_REQUEST(RecvBuf, true); + } + + void CompleteGroupHandler( + WorkerCall* call) { + Schedule([this, call]() { + CallOptions* call_opts = new CallOptions; + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + worker_->CompleteGroupAsync(call_opts, &call->request, &call->response, + [call, call_opts](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + call->SendResponse(ToGrpcStatus(s)); + }); + }); + ENQUEUE_REQUEST(CompleteGroup, true); + } + + void CompleteInstanceHandler( + WorkerCall* call) { + Schedule([this, call]() { + CallOptions* call_opts = new CallOptions; + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + worker_->CompleteInstanceAsync(call_opts, &call->request, &call->response, + [call, call_opts](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + call->SendResponse(ToGrpcStatus(s)); + }); + }); + ENQUEUE_REQUEST(CompleteInstance, false); + } + + void GetStepSequenceHandler( + WorkerCall* call) { + Schedule([this, call]() { + worker_->GetStepSequenceAsync( + &call->request, &call->response, + [call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); }); + }); + ENQUEUE_REQUEST(GetStepSequence, true); + } +#undef ENQUEUE_REQUEST + + void EnqueueRecvTensorRequestRaw() { + mutex_lock l(shutdown_mu_); + if (!is_shutdown_) { + Call:: + EnqueueRequestForMethod( + worker_service_, cq_.get(), + static_cast(GrpcWorkerMethod::kRecvTensor), + &GrpcWorkerServiceThread::RecvTensorHandlerRaw, + true /* supports cancel*/); + } + } + + GrpcWorker* const worker_ = nullptr; // Not owned. + std::unique_ptr<::grpc::ServerCompletionQueue> cq_; + std::unique_ptr thread_; + std::unordered_map queue_depth_; + grpc::WorkerService::AsyncService* const worker_service_; + + mutex shutdown_mu_; + bool is_shutdown_ GUARDED_BY(shutdown_mu_); + TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread); +}; + +class GrpcWorkerService : public AsyncServiceInterface { + public: + GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder, + GrpcWorkerServiceOptions options) + : is_shutdown_(false) { + builder->RegisterService(&worker_service_); + for (int i = 0; i < options.num_worker_threads; i++) { + threads_.emplace_back(new GrpcWorkerServiceThread( + worker, builder, options.queue_depth, &worker_service_)); + } + } + + void Shutdown() override { + bool did_shutdown = false; + { + mutex_lock l(service_shutdown_mu_); + if (!is_shutdown_) { + LOG(INFO) << "Shutting down GrpcWorkerService."; + is_shutdown_ = true; + did_shutdown = true; + } + } + if (did_shutdown) { + for (auto& worker_thread : threads_) { + worker_thread->Shutdown(); + } + } + } + // This method blocks forever handling requests from the completion queue. void HandleRPCsLoop() override { for (auto& worker_thread : threads_) { @@ -116,297 +410,6 @@ class GrpcWorkerService : public AsyncServiceInterface { } private: - // Thread wrapping class that drives work over a single gRPC - // CompletionQueue. - class GrpcWorkerServiceThread { - public: - explicit GrpcWorkerServiceThread( - GrpcWorker* worker, ::grpc::ServerBuilder* builder, - grpc::WorkerService::AsyncService* worker_service) - : worker_(worker), - worker_service_(worker_service), - is_shutdown_(false) { - cq_ = builder->AddCompletionQueue(); - } - - void Start() { - thread_.reset(worker_->env()->env->StartThread( - ThreadOptions(), "grpc_worker_service", - [this]() { HandleRPCsLoop(); })); - } - - void Join() { thread_.reset(); } // Blocks until thread exits - - void Shutdown() { - { - mutex_lock lock(shutdown_mu_); - is_shutdown_ = true; - } - cq_->Shutdown(); - } - - private: - void HandleRPCsLoop() { - // TODO(ncteisen): This may require performance engineering. We can - // change the number of threads, the number of handlers per thread, - // or even decide to specialize certain threads to certain methods. - ENQUEUE_REQUEST(GetStatus, false); - ENQUEUE_REQUEST(CreateWorkerSession, false); - ENQUEUE_REQUEST(DeleteWorkerSession, false); - ENQUEUE_REQUEST(CleanupAll, false); - ENQUEUE_REQUEST(RegisterGraph, false); - ENQUEUE_REQUEST(DeregisterGraph, false); - - // TODO(ncteisen): Determine a better policy for enqueuing the - // appropriate number of each request type. - for (int i = 0; i < 1000; ++i) { - EnqueueRecvTensorRequestRaw(); - } - for (int i = 0; i < 500; ++i) { - ENQUEUE_REQUEST(RecvBuf, true); - } - for (int i = 0; i < 100; ++i) { - ENQUEUE_REQUEST(RunGraph, true); - } - for (int i = 0; i < 100; ++i) { - ENQUEUE_REQUEST(CleanupGraph, false); - } - - ENQUEUE_REQUEST(Logging, false); - ENQUEUE_REQUEST(Tracing, false); - - for (int i = 0; i < 10; ++i) { - ENQUEUE_REQUEST(CompleteGroup, true); - ENQUEUE_REQUEST(CompleteInstance, true); - ENQUEUE_REQUEST(GetStepSequence, true); - } - - void* tag; - bool ok; - - while (cq_->Next(&tag, &ok)) { - UntypedCall::Tag* callback_tag = - static_cast::Tag*>(tag); - CHECK(callback_tag); - callback_tag->OnCompleted(this, ok); - } - } - - private: - void Schedule(std::function f) { - worker_->env()->compute_pool->Schedule(std::move(f)); - } - - // The following section contains one request handler method per - // RPC. The `FooHandler` method is called (indirectly) by - // `HandleRPCsLoop()` when the next Foo RPC is received. Each - // `FooHandler` call schedules a closure on `worker_->env()->compute_pool`, - // and is responsible for requesting the next Foo call by calling - // `ENQUEUE_REQUEST(Foo)`. - - template - using WorkerCall = - Call; - - void GetStatusHandler( - WorkerCall* call) { - Schedule([this, call]() { - Status s = worker_->GetStatus(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(GetStatus, false); - } - - void CreateWorkerSessionHandler( - WorkerCall* - call) { - Schedule([this, call]() { - Status s = - worker_->CreateWorkerSession(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(CreateWorkerSession, false); - } - - void DeleteWorkerSessionHandler( - WorkerCall* - call) { - Schedule([this, call]() { - Status s = - worker_->DeleteWorkerSession(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(DeleteWorkerSession, false); - } - - void CleanupAllHandler( - WorkerCall* call) { - Schedule([this, call]() { - Status s = worker_->CleanupAll(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(CleanupAll, false); - } - - void RegisterGraphHandler( - WorkerCall* call) { - Schedule([this, call]() { - Status s = worker_->RegisterGraph(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(RegisterGraph, false); - } - - void DeregisterGraphHandler( - WorkerCall* call) { - Schedule([this, call]() { - Status s = worker_->DeregisterGraph(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(DeregisterGraph, false); - } - - void RunGraphHandler(WorkerCall* call) { - Schedule([this, call]() { - CallOptions* call_opts = new CallOptions; - ProtoRunGraphRequest* wrapped_request = - new ProtoRunGraphRequest(&call->request); - NonOwnedProtoRunGraphResponse* wrapped_response = - new NonOwnedProtoRunGraphResponse(&call->response); - call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response, - [call, call_opts, wrapped_request, - wrapped_response](const Status& s) { - call->ClearCancelCallback(); - delete call_opts; - delete wrapped_request; - delete wrapped_response; - call->SendResponse(ToGrpcStatus(s)); - }); - }); - ENQUEUE_REQUEST(RunGraph, true); - } - - void RecvTensorHandlerRaw( - WorkerCall* call) { - Schedule([this, call]() { - CallOptions* call_opts = new CallOptions; - call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response, - [call, call_opts](const Status& s) { - call->ClearCancelCallback(); - delete call_opts; - call->SendResponse(ToGrpcStatus(s)); - }); - }); - EnqueueRecvTensorRequestRaw(); - } - - void CleanupGraphHandler( - WorkerCall* call) { - Schedule([this, call]() { - Status s = worker_->CleanupGraph(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(CleanupGraph, false); - } - - void LoggingHandler(WorkerCall* call) { - Schedule([this, call]() { - Status s = worker_->Logging(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(Logging, false); - } - - void TracingHandler(WorkerCall* call) { - Schedule([this, call]() { - Status s = worker_->Tracing(&call->request, &call->response); - call->SendResponse(ToGrpcStatus(s)); - }); - ENQUEUE_REQUEST(Tracing, false); - } - - void RecvBufHandler(WorkerCall* call) { - Schedule([this, call]() { - CallOptions* call_opts = new CallOptions; - call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - worker_->RecvBufAsync(call_opts, &call->request, &call->response, - [call, call_opts](const Status& s) { - call->ClearCancelCallback(); - delete call_opts; - call->SendResponse(ToGrpcStatus(s)); - }); - }); - ENQUEUE_REQUEST(RecvBuf, true); - } - - void CompleteGroupHandler( - WorkerCall* call) { - Schedule([this, call]() { - CallOptions* call_opts = new CallOptions; - call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - worker_->CompleteGroupAsync(call_opts, &call->request, &call->response, - [call, call_opts](const Status& s) { - call->ClearCancelCallback(); - delete call_opts; - call->SendResponse(ToGrpcStatus(s)); - }); - }); - ENQUEUE_REQUEST(CompleteGroup, true); - } - - void CompleteInstanceHandler( - WorkerCall* call) { - Schedule([this, call]() { - CallOptions* call_opts = new CallOptions; - call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - worker_->CompleteInstanceAsync(call_opts, &call->request, - &call->response, - [call, call_opts](const Status& s) { - call->ClearCancelCallback(); - delete call_opts; - call->SendResponse(ToGrpcStatus(s)); - }); - }); - ENQUEUE_REQUEST(CompleteInstance, false); - } - - void GetStepSequenceHandler( - WorkerCall* call) { - Schedule([this, call]() { - worker_->GetStepSequenceAsync( - &call->request, &call->response, - [call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); }); - }); - ENQUEUE_REQUEST(GetStepSequence, true); - } -#undef ENQUEUE_REQUEST - - void EnqueueRecvTensorRequestRaw() { - mutex_lock l(shutdown_mu_); - if (!is_shutdown_) { - Call:: - EnqueueRequestForMethod( - worker_service_, cq_.get(), - static_cast(GrpcWorkerMethod::kRecvTensor), - &GrpcWorkerServiceThread::RecvTensorHandlerRaw, - true /* supports cancel*/); - } - } - - GrpcWorker* const worker_ = nullptr; // Not owned. - std::unique_ptr<::grpc::ServerCompletionQueue> cq_; - std::unique_ptr thread_; - grpc::WorkerService::AsyncService* const worker_service_; - - mutex shutdown_mu_; - bool is_shutdown_ GUARDED_BY(shutdown_mu_); - TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread); - }; // GrpcWorkerServiceThread - grpc::WorkerService::AsyncService worker_service_; std::vector> threads_; @@ -640,9 +643,10 @@ std::unique_ptr NewGrpcWorker(WorkerEnv* env, } std::unique_ptr NewGrpcWorkerService( - GrpcWorker* worker, ::grpc::ServerBuilder* builder) { + GrpcWorker* worker, ::grpc::ServerBuilder* builder, + GrpcWorkerServiceOptions options) { return std::unique_ptr( - new GrpcWorkerService(worker, builder)); + new GrpcWorkerService(worker, builder, options)); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index 996617d385d..88beb6c2165 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ +#include #include "tensorflow/core/distributed_runtime/recent_request_ids.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h" #include "tensorflow/core/distributed_runtime/worker.h" namespace grpc { @@ -57,9 +59,17 @@ class GrpcWorker : public Worker { std::unique_ptr NewGrpcWorker(WorkerEnv* worker_env, const ConfigProto& config); +struct GrpcWorkerServiceOptions { + // Map from GrpcWorkerMethod id to queue depth. If set this overrides the + // default queue depth for a method. + std::unordered_map queue_depth; + int num_worker_threads = 8; +}; + // Returns an implementation of WorkerService rpc service. std::unique_ptr NewGrpcWorkerService( - GrpcWorker* worker, ::grpc::ServerBuilder* builder); + GrpcWorker* worker, ::grpc::ServerBuilder* builder, + GrpcWorkerServiceOptions opts = GrpcWorkerServiceOptions()); } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index 7915c3aafd8..d2ae4eeaeec 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -88,6 +88,7 @@ enum class GrpcWorkerMethod { kCompleteInstance, kGetStepSequence, }; + static const int kGrpcNumWorkerMethods = static_cast(GrpcWorkerMethod::kGetStepSequence) + 1;