Use proto to configure tf.data service worker server.

This simplifies adding new configuration properties, so that we don't need to plumb new properties through. This also gives us a single place to document all configuration options (in the .proto file).

PiperOrigin-RevId: 322934622
Change-Id: I547740e3c9224c7b74ecf2853672ffeb226d61d1
This commit is contained in:
A. Unique TensorFlower 2020-07-23 22:02:59 -07:00 committed by TensorFlower Gardener
parent 918466d131
commit fd1481780b
11 changed files with 71 additions and 65 deletions

View File

@ -227,7 +227,6 @@ cc_library(
deps = [
":worker_cc_grpc_proto",
":worker_impl",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
tf_grpc_cc_dependency(),
],

View File

@ -26,8 +26,9 @@ using ::grpc::ServerContext;
using ::grpc::Status;
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
const experimental::WorkerConfig& config)
: impl_(config) {
const std::string& dispatcher_address,
const std::string& protocol)
: impl_(dispatcher_address, protocol) {
server_builder->RegisterService(this);
VLOG(1) << "Registered data service worker";
}

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "grpcpp/server_builder.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/service/worker_impl.h"
#include "tensorflow/core/protobuf/data/experimental/service_config.pb.h"
namespace tensorflow {
namespace data {
@ -36,7 +35,8 @@ namespace data {
class GrpcWorkerImpl : public WorkerService::Service {
public:
explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
const experimental::WorkerConfig& config);
const std::string& dispatcher_address,
const std::string& protocol);
~GrpcWorkerImpl() override {}
void Start(const std::string& worker_address);

View File

@ -79,7 +79,8 @@ DispatchGrpcDataServer::DispatchGrpcDataServer(
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
service_ = absl::make_unique<GrpcDispatcherImpl>(builder, config_).release();
auto service = absl::make_unique<GrpcDispatcherImpl>(builder, config_);
service_ = service.release();
}
Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
@ -95,17 +96,22 @@ Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
}
WorkerGrpcDataServer::WorkerGrpcDataServer(
const experimental::WorkerConfig& config)
: GrpcDataServerBase(config.port(), config.protocol()), config_(config) {}
int port, const std::string& protocol,
const std::string& dispatcher_address, const std::string& worker_address)
: GrpcDataServerBase(port, protocol),
dispatcher_address_(dispatcher_address),
worker_address_(worker_address) {}
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
service_ = absl::make_unique<GrpcWorkerImpl>(builder, config_).release();
auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
protocol_);
service_ = service.release();
}
Status WorkerGrpcDataServer::StartServiceInternal() {
std::string worker_address = config_.worker_address();
std::string worker_address = worker_address_;
if (worker_address.empty()) {
worker_address = absl::StrCat("localhost:", kPortPlaceholder);
}
@ -122,9 +128,19 @@ Status NewDispatchServer(const experimental::DispatcherConfig& config,
return Status::OK();
}
Status NewWorkerServer(const experimental::WorkerConfig& config,
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& dispatcher_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
*out_server = absl::make_unique<WorkerGrpcDataServer>(config);
return NewWorkerServer(port, protocol, dispatcher_address,
/*worker_address=*/"", out_server);
}
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& dispatcher_address,
const std::string& worker_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
*out_server = absl::make_unique<WorkerGrpcDataServer>(
port, protocol, dispatcher_address, worker_address);
return Status::OK();
}

View File

@ -91,7 +91,9 @@ class DispatchGrpcDataServer : public GrpcDataServerBase {
class WorkerGrpcDataServer : public GrpcDataServerBase {
public:
explicit WorkerGrpcDataServer(const experimental::WorkerConfig& config);
WorkerGrpcDataServer(int requested_port, const std::string& protocol,
const std::string& dispatcher_address,
const std::string& worker_address);
~WorkerGrpcDataServer() override;
protected:
@ -99,7 +101,8 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
Status StartServiceInternal() override;
private:
const experimental::WorkerConfig config_;
const std::string dispatcher_address_;
const std::string worker_address_;
// Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
GrpcWorkerImpl* service_;
};
@ -109,7 +112,23 @@ Status NewDispatchServer(const experimental::DispatcherConfig& config,
std::unique_ptr<DispatchGrpcDataServer>* out_server);
// Creates a worker tf.data server and stores it in `*out_server`.
Status NewWorkerServer(const experimental::WorkerConfig& config,
//
// The port can be a specific port or 0. If the port is 0, an available port
// will be chosen in Start(). This value can be queried with BoundPort().
//
// The worker_address argument is optional. If left empty, it will default to
// "localhost:%port%". When the worker registers with the dispatcher, the worker
// will report the worker address, so that the dispatcher can tell clients where
// to read from. The address may contain the placeholder "%port%", which will be
// replaced with the value of BoundPort().
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& dispatcher_address,
const std::string& worker_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server);
// Creates a worker using the default worker_address.
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& dispatcher_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server);
} // namespace data

View File

@ -62,12 +62,8 @@ Status TestCluster::Initialize() {
Status TestCluster::AddWorker() {
std::unique_ptr<WorkerGrpcDataServer> worker;
experimental::WorkerConfig config;
config.set_port(0);
config.set_protocol(kProtocol);
config.set_dispatcher_address(dispatcher_address_);
config.set_worker_address("localhost:%port%");
TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker));
TF_RETURN_IF_ERROR(
NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
TF_RETURN_IF_ERROR(worker->Start());
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
workers_.push_back(std::move(worker));

View File

@ -46,8 +46,8 @@ auto* tf_data_service_created =
} // namespace
DataServiceWorkerImpl::DataServiceWorkerImpl(
const experimental::WorkerConfig& config)
: config_(config) {
const std::string& dispatcher_address, const std::string& protocol)
: dispatcher_address_(dispatcher_address), protocol_(protocol) {
tf_data_service_created->GetCell()->Set(true);
}
@ -68,7 +68,7 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
Status s = Register();
while (!s.ok()) {
LOG(WARNING) << "Failed to register with dispatcher at "
<< config_.dispatcher_address() << ": " << s;
<< dispatcher_address_ << ": " << s;
Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
s = Register();
}
@ -173,17 +173,17 @@ Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
if (!dispatcher_stub_) {
::grpc::ChannelArguments args;
std::shared_ptr<::grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
config_.protocol(), &credentials));
auto channel = ::grpc::CreateCustomChannel(config_.dispatcher_address(),
credentials, args);
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
auto channel =
::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
dispatcher_stub_ = DispatcherService::NewStub(channel);
}
return Status::OK();
}
Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(3) << "Registering with dispatcher at " << config_.dispatcher_address();
VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
RegisterWorkerRequest req;
req.set_worker_address(worker_address_);

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/data/experimental/service_config.pb.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
@ -30,14 +29,12 @@ namespace data {
// A TensorFlow DataService serves dataset elements over RPC.
class DataServiceWorkerImpl {
public:
explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config);
explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
const std::string& protocol);
~DataServiceWorkerImpl();
// Starts the worker. The worker needs to know its own address so that it can
// register with the dispatcher. This is set in `Start` instead of in the
// constructor because the worker may be binding to port `0`, in which case
// the address isn't known until the worker has started and decided which port
// to bind to.
// register with the dispatcher.
void Start(const std::string& worker_address);
// See worker.proto for API documentation.
@ -70,7 +67,9 @@ class DataServiceWorkerImpl {
std::unique_ptr<standalone::Iterator> iterator;
} Task;
const experimental::WorkerConfig config_;
const std::string dispatcher_address_;
// Protocol for communicating with the dispatcher.
const std::string protocol_;
// The worker's own address.
std::string worker_address_;

View File

@ -10,18 +10,3 @@ message DispatcherConfig {
// The protocol for the dispatcher to use when connecting to workers.
string protocol = 2;
}
// Configuration for a tf.data service WorkerServer.
message WorkerConfig {
// The port for the worker to bind to. A value of 0 indicates that the
// worker may bind to any available port.
int64 port = 1;
// The protocol for the worker to use when connecting to the dispatcher.
string protocol = 2;
// The address of the dispatcher to register with.
string dispatcher_address = 3;
// The address of the worker server. The substring "%port%", if specified,
// will be replaced with the worker's bound port. This is useful when the port
// is set to `0`.
string worker_address = 4;
}

View File

@ -205,13 +205,8 @@ class WorkerServer(object):
protocol = "grpc"
self._protocol = protocol
config = service_config_pb2.WorkerConfig(
port=port,
protocol=protocol,
dispatcher_address=dispatcher_address,
worker_address=worker_address)
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
config.SerializeToString())
port, protocol, dispatcher_address, worker_address)
if start:
self._server.start()

View File

@ -69,16 +69,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
m.def(
"TF_DATA_NewWorkerServer",
[](std::string serialized_worker_config)
[](int port, std::string protocol, std::string dispatcher_address,
std::string worker_address)
-> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
tensorflow::data::experimental::WorkerConfig config;
if (!config.ParseFromString(serialized_worker_config)) {
tensorflow::MaybeRaiseFromStatus(tensorflow::errors::InvalidArgument(
"Failed to deserialize worker config."));
}
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
tensorflow::Status status =
tensorflow::data::NewWorkerServer(config, &server);
tensorflow::Status status = tensorflow::data::NewWorkerServer(
port, protocol, dispatcher_address, worker_address, &server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},