From fd1481780bcc57426bf3158a8f94eab95e529384 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 23 Jul 2020 22:02:59 -0700
Subject: [PATCH] Use proto to configure tf.data service worker server.

This simplifies adding new configuration properties, so that we don't need to
plumb new properties through. This also gives us a single place to document
all configuration options (in the .proto file).

PiperOrigin-RevId: 322934622
Change-Id: I547740e3c9224c7b74ecf2853672ffeb226d61d1
---
 tensorflow/core/data/service/BUILD            |  1 -
 .../core/data/service/grpc_worker_impl.cc     |  5 ++--
 .../core/data/service/grpc_worker_impl.h      |  4 +--
 tensorflow/core/data/service/server_lib.cc    | 30 ++++++++++++++-----
 tensorflow/core/data/service/server_lib.h     | 25 ++++++++++++++--
 tensorflow/core/data/service/test_cluster.cc  |  8 ++---
 tensorflow/core/data/service/worker_impl.cc   | 16 +++++-----
 tensorflow/core/data/service/worker_impl.h    | 13 ++++----
 .../data/experimental/service_config.proto    | 15 ----------
 .../data/experimental/service/server_lib.py   |  7 +----
 .../service/server_lib_wrapper.cc             | 12 +++-----
 11 files changed, 71 insertions(+), 65 deletions(-)

diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index 913cbf26cf0..d7cc7a3e528 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -227,7 +227,6 @@ cc_library(
     deps = [
         ":worker_cc_grpc_proto",
         ":worker_impl",
-        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/distributed_runtime/rpc:grpc_util",
         tf_grpc_cc_dependency(),
     ],
diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc
index c76e1062753..0cddfce4e0b 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.cc
+++ b/tensorflow/core/data/service/grpc_worker_impl.cc
@@ -26,8 +26,9 @@ using ::grpc::ServerContext;
 using ::grpc::Status;
 
 GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
-                               const experimental::WorkerConfig& config)
-    : impl_(config) {
+                               const std::string& dispatcher_address,
+                               const std::string& protocol)
+    : impl_(dispatcher_address, protocol) {
   server_builder->RegisterService(this);
   VLOG(1) << "Registered data service worker";
 }
diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h
index b0881143a57..169ae29ea37 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.h
+++ b/tensorflow/core/data/service/grpc_worker_impl.h
@@ -19,7 +19,6 @@ limitations under the License.
#include "grpcpp/server_builder.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/data/service/worker_impl.h" -#include "tensorflow/core/protobuf/data/experimental/service_config.pb.h" namespace tensorflow { namespace data { @@ -36,7 +35,8 @@ namespace data { class GrpcWorkerImpl : public WorkerService::Service { public: explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder, - const experimental::WorkerConfig& config); + const std::string& dispatcher_address, + const std::string& protocol); ~GrpcWorkerImpl() override {} void Start(const std::string& worker_address); diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc index 648a189717e..6d912b1c802 100644 --- a/tensorflow/core/data/service/server_lib.cc +++ b/tensorflow/core/data/service/server_lib.cc @@ -79,7 +79,8 @@ DispatchGrpcDataServer::DispatchGrpcDataServer( DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; } void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) { - service_ = absl::make_unique(builder, config_).release(); + auto service = absl::make_unique(builder, config_); + service_ = service.release(); } Status DispatchGrpcDataServer::NumWorkers(int* num_workers) { @@ -95,17 +96,22 @@ Status DispatchGrpcDataServer::NumWorkers(int* num_workers) { } WorkerGrpcDataServer::WorkerGrpcDataServer( - const experimental::WorkerConfig& config) - : GrpcDataServerBase(config.port(), config.protocol()), config_(config) {} + int port, const std::string& protocol, + const std::string& dispatcher_address, const std::string& worker_address) + : GrpcDataServerBase(port, protocol), + dispatcher_address_(dispatcher_address), + worker_address_(worker_address) {} WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; } void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) { - service_ = absl::make_unique(builder, config_).release(); + auto service = absl::make_unique(builder, dispatcher_address_, + protocol_); + service_ = service.release(); } Status WorkerGrpcDataServer::StartServiceInternal() { - std::string worker_address = config_.worker_address(); + std::string worker_address = worker_address_; if (worker_address.empty()) { worker_address = absl::StrCat("localhost:", kPortPlaceholder); } @@ -122,9 +128,19 @@ Status NewDispatchServer(const experimental::DispatcherConfig& config, return Status::OK(); } -Status NewWorkerServer(const experimental::WorkerConfig& config, +Status NewWorkerServer(int port, const std::string& protocol, + const std::string& dispatcher_address, std::unique_ptr* out_server) { - *out_server = absl::make_unique(config); + return NewWorkerServer(port, protocol, dispatcher_address, + /*worker_address=*/"", out_server); +} + +Status NewWorkerServer(int port, const std::string& protocol, + const std::string& dispatcher_address, + const std::string& worker_address, + std::unique_ptr* out_server) { + *out_server = absl::make_unique( + port, protocol, dispatcher_address, worker_address); return Status::OK(); } diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h index 365241753fb..d147f47c5e4 100644 --- a/tensorflow/core/data/service/server_lib.h +++ b/tensorflow/core/data/service/server_lib.h @@ -91,7 +91,9 @@ class DispatchGrpcDataServer : public GrpcDataServerBase { class WorkerGrpcDataServer : public GrpcDataServerBase { public: - explicit WorkerGrpcDataServer(const experimental::WorkerConfig& config); + 
+                       const std::string& dispatcher_address,
+                       const std::string& worker_address);
   ~WorkerGrpcDataServer() override;
 
  protected:
@@ -99,7 +101,8 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
   Status StartServiceInternal() override;
 
  private:
-  const experimental::WorkerConfig config_;
+  const std::string dispatcher_address_;
+  const std::string worker_address_;
   // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
   GrpcWorkerImpl* service_;
 };
@@ -109,7 +112,23 @@ Status NewDispatchServer(const experimental::DispatcherConfig& config,
                          std::unique_ptr<DispatchGrpcDataServer>* out_server);
 
 // Creates a worker tf.data server and stores it in `*out_server`.
-Status NewWorkerServer(const experimental::WorkerConfig& config,
+//
+// The port can be a specific port or 0. If the port is 0, an available port
+// will be chosen in Start(). This value can be queried with BoundPort().
+//
+// The worker_address argument is optional. If left empty, it will default to
+// "localhost:%port%". When the worker registers with the dispatcher, the worker
+// will report the worker address, so that the dispatcher can tell clients where
+// to read from. The address may contain the placeholder "%port%", which will be
+// replaced with the value of BoundPort().
+Status NewWorkerServer(int port, const std::string& protocol,
+                       const std::string& dispatcher_address,
+                       const std::string& worker_address,
+                       std::unique_ptr<WorkerGrpcDataServer>* out_server);
+
+// Creates a worker using the default worker_address.
+Status NewWorkerServer(int port, const std::string& protocol,
+                       const std::string& dispatcher_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server);
 
 }  // namespace data
diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc
index 8ae3f191407..ad0d2be87d8 100644
--- a/tensorflow/core/data/service/test_cluster.cc
+++ b/tensorflow/core/data/service/test_cluster.cc
@@ -62,12 +62,8 @@ Status TestCluster::Initialize() {
 
 Status TestCluster::AddWorker() {
   std::unique_ptr<WorkerGrpcDataServer> worker;
-  experimental::WorkerConfig config;
-  config.set_port(0);
-  config.set_protocol(kProtocol);
-  config.set_dispatcher_address(dispatcher_address_);
-  config.set_worker_address("localhost:%port%");
-  TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker));
+  TF_RETURN_IF_ERROR(
+      NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
   TF_RETURN_IF_ERROR(worker->Start());
   worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
   workers_.push_back(std::move(worker));
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index 39508b1eab0..00659e1d048 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -46,8 +46,8 @@ auto* tf_data_service_created =
 }  // namespace
 
 DataServiceWorkerImpl::DataServiceWorkerImpl(
-    const experimental::WorkerConfig& config)
-    : config_(config) {
+    const std::string& dispatcher_address, const std::string& protocol)
+    : dispatcher_address_(dispatcher_address), protocol_(protocol) {
   tf_data_service_created->GetCell()->Set(true);
 }
 
@@ -68,7 +68,7 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
   Status s = Register();
   while (!s.ok()) {
     LOG(WARNING) << "Failed to register with dispatcher at "
-                 << config_.dispatcher_address() << ": " << s;
+                 << dispatcher_address_ << ": " << s;
     Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
     s = Register();
   }
@@ -173,17 +173,17 @@ Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
   if (!dispatcher_stub_) {
     ::grpc::ChannelArguments args;
     std::shared_ptr<::grpc::ChannelCredentials> credentials;
-    TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
-        config_.protocol(), &credentials));
-    auto channel = ::grpc::CreateCustomChannel(config_.dispatcher_address(),
-                                               credentials, args);
+    TF_RETURN_IF_ERROR(
+        CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
+    auto channel =
+        ::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
     dispatcher_stub_ = DispatcherService::NewStub(channel);
   }
   return Status::OK();
 }
 
 Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-  VLOG(3) << "Registering with dispatcher at " << config_.dispatcher_address();
+  VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
   TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
   RegisterWorkerRequest req;
   req.set_worker_address(worker_address_);
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
index 6961312ee34..adb3e97bbea 100644
--- a/tensorflow/core/data/service/worker_impl.h
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -21,7 +21,6 @@ limitations under the License.
 #include "tensorflow/core/data/service/worker.pb.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/protobuf/data/experimental/service_config.pb.h"
 #include "tensorflow/core/public/session.h"
 
 namespace tensorflow {
@@ -30,14 +29,12 @@ namespace data {
 // A TensorFlow DataService serves dataset elements over RPC.
 class DataServiceWorkerImpl {
  public:
-  explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config);
+  explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
+                                 const std::string& protocol);
   ~DataServiceWorkerImpl();
 
   // Starts the worker. The worker needs to know its own address so that it can
-  // register with the dispatcher. This is set in `Start` instead of in the
-  // constructor because the worker may be binding to port `0`, in which case
-  // the address isn't known until the worker has started and decided which port
-  // to bind to.
+  // register with the dispatcher.
   void Start(const std::string& worker_address);
 
   // See worker.proto for API documentation.
@@ -70,7 +67,9 @@ class DataServiceWorkerImpl {
     std::unique_ptr<standalone::Iterator> iterator;
   } Task;
 
-  const experimental::WorkerConfig config_;
+  const std::string dispatcher_address_;
+  // Protocol for communicating with the dispatcher.
+  const std::string protocol_;
   // The worker's own address.
   std::string worker_address_;
 
diff --git a/tensorflow/core/protobuf/data/experimental/service_config.proto b/tensorflow/core/protobuf/data/experimental/service_config.proto
index 8708b923720..5dcc3c69083 100644
--- a/tensorflow/core/protobuf/data/experimental/service_config.proto
+++ b/tensorflow/core/protobuf/data/experimental/service_config.proto
@@ -10,18 +10,3 @@ message DispatcherConfig {
   // The protocol for the dispatcher to use when connecting to workers.
   string protocol = 2;
 }
-
-// Configuration for a tf.data service WorkerServer.
-message WorkerConfig {
-  // The port for the worker to bind to. A value of 0 indicates that the
-  // worker may bind to any available port.
-  int64 port = 1;
-  // The protocol for the worker to use when connecting to the dispatcher.
-  string protocol = 2;
-  // The address of the dispatcher to register with.
-  string dispatcher_address = 3;
-  // The address of the worker server. The substring "%port%", if specified,
-  // will be replaced with the worker's bound port. This is useful when the port
-  // is set to `0`.
-  string worker_address = 4;
-}
diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py
index 99dc9297901..3e355565308 100644
--- a/tensorflow/python/data/experimental/service/server_lib.py
+++ b/tensorflow/python/data/experimental/service/server_lib.py
@@ -205,13 +205,8 @@ class WorkerServer(object):
       protocol = "grpc"
 
     self._protocol = protocol
-    config = service_config_pb2.WorkerConfig(
-        port=port,
-        protocol=protocol,
-        dispatcher_address=dispatcher_address,
-        worker_address=worker_address)
     self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
-        config.SerializeToString())
+        port, protocol, dispatcher_address, worker_address)
 
     if start:
       self._server.start()
diff --git a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
index f59c1fb90bf..b8250aaeda6 100644
--- a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
+++ b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
@@ -69,16 +69,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
 
   m.def(
       "TF_DATA_NewWorkerServer",
-      [](std::string serialized_worker_config)
+      [](int port, std::string protocol, std::string dispatcher_address,
+         std::string worker_address)
          -> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
-        tensorflow::data::experimental::WorkerConfig config;
-        if (!config.ParseFromString(serialized_worker_config)) {
-          tensorflow::MaybeRaiseFromStatus(tensorflow::errors::InvalidArgument(
-              "Failed to deserialize worker config."));
-        }
        std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
-        tensorflow::Status status =
-            tensorflow::data::NewWorkerServer(config, &server);
+        tensorflow::Status status = tensorflow::data::NewWorkerServer(
+            port, protocol, dispatcher_address, worker_address, &server);
         tensorflow::MaybeRaiseFromStatus(status);
         return server;
       },
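
For reference, a minimal C++ usage sketch of the worker API declared in
server_lib.h above (NewWorkerServer(), Start(), BoundPort()). The
StartLocalWorker helper name and the "grpc" protocol literal are illustrative
assumptions, not part of the patch; error handling is reduced to early returns.

#include <memory>
#include <string>

#include "tensorflow/core/data/service/server_lib.h"
#include "tensorflow/core/lib/core/status.h"

// Starts a tf.data service worker that registers with an already running
// dispatcher at `dispatcher_address`, handing the server back to the caller,
// who must keep it alive for as long as it should serve.
tensorflow::Status StartLocalWorker(
    const std::string& dispatcher_address,
    std::unique_ptr<tensorflow::data::WorkerGrpcDataServer>* out_worker) {
  // Port 0 asks the worker to bind to any available port; with the default
  // worker_address ("localhost:%port%") the chosen port can be read back with
  // BoundPort() after Start() has returned.
  tensorflow::Status s = tensorflow::data::NewWorkerServer(
      /*port=*/0, /*protocol=*/"grpc", dispatcher_address, out_worker);
  if (!s.ok()) return s;
  return (*out_worker)->Start();
}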