Use proto to configure tf.data service worker server.
This simplifies adding new configuration properties, so that we don't need to plumb new properties through. This also gives us a single place to document all configuration options (in the .proto file). PiperOrigin-RevId: 322934622 Change-Id: I547740e3c9224c7b74ecf2853672ffeb226d61d1
This commit is contained in:
parent
918466d131
commit
fd1481780b
@ -227,7 +227,6 @@ cc_library(
|
||||
deps = [
|
||||
":worker_cc_grpc_proto",
|
||||
":worker_impl",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
|
||||
@ -26,8 +26,9 @@ using ::grpc::ServerContext;
|
||||
using ::grpc::Status;
|
||||
|
||||
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
|
||||
const experimental::WorkerConfig& config)
|
||||
: impl_(config) {
|
||||
const std::string& dispatcher_address,
|
||||
const std::string& protocol)
|
||||
: impl_(dispatcher_address, protocol) {
|
||||
server_builder->RegisterService(this);
|
||||
VLOG(1) << "Registered data service worker";
|
||||
}
|
||||
|
||||
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include "grpcpp/server_builder.h"
|
||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/worker_impl.h"
|
||||
#include "tensorflow/core/protobuf/data/experimental/service_config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -36,7 +35,8 @@ namespace data {
|
||||
class GrpcWorkerImpl : public WorkerService::Service {
|
||||
public:
|
||||
explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
|
||||
const experimental::WorkerConfig& config);
|
||||
const std::string& dispatcher_address,
|
||||
const std::string& protocol);
|
||||
~GrpcWorkerImpl() override {}
|
||||
|
||||
void Start(const std::string& worker_address);
|
||||
|
||||
@ -79,7 +79,8 @@ DispatchGrpcDataServer::DispatchGrpcDataServer(
|
||||
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
|
||||
|
||||
void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
|
||||
service_ = absl::make_unique<GrpcDispatcherImpl>(builder, config_).release();
|
||||
auto service = absl::make_unique<GrpcDispatcherImpl>(builder, config_);
|
||||
service_ = service.release();
|
||||
}
|
||||
|
||||
Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
|
||||
@ -95,17 +96,22 @@ Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
|
||||
}
|
||||
|
||||
WorkerGrpcDataServer::WorkerGrpcDataServer(
|
||||
const experimental::WorkerConfig& config)
|
||||
: GrpcDataServerBase(config.port(), config.protocol()), config_(config) {}
|
||||
int port, const std::string& protocol,
|
||||
const std::string& dispatcher_address, const std::string& worker_address)
|
||||
: GrpcDataServerBase(port, protocol),
|
||||
dispatcher_address_(dispatcher_address),
|
||||
worker_address_(worker_address) {}
|
||||
|
||||
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
|
||||
|
||||
void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
|
||||
service_ = absl::make_unique<GrpcWorkerImpl>(builder, config_).release();
|
||||
auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
|
||||
protocol_);
|
||||
service_ = service.release();
|
||||
}
|
||||
|
||||
Status WorkerGrpcDataServer::StartServiceInternal() {
|
||||
std::string worker_address = config_.worker_address();
|
||||
std::string worker_address = worker_address_;
|
||||
if (worker_address.empty()) {
|
||||
worker_address = absl::StrCat("localhost:", kPortPlaceholder);
|
||||
}
|
||||
@ -122,9 +128,19 @@ Status NewDispatchServer(const experimental::DispatcherConfig& config,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NewWorkerServer(const experimental::WorkerConfig& config,
|
||||
Status NewWorkerServer(int port, const std::string& protocol,
|
||||
const std::string& dispatcher_address,
|
||||
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
|
||||
*out_server = absl::make_unique<WorkerGrpcDataServer>(config);
|
||||
return NewWorkerServer(port, protocol, dispatcher_address,
|
||||
/*worker_address=*/"", out_server);
|
||||
}
|
||||
|
||||
Status NewWorkerServer(int port, const std::string& protocol,
|
||||
const std::string& dispatcher_address,
|
||||
const std::string& worker_address,
|
||||
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
|
||||
*out_server = absl::make_unique<WorkerGrpcDataServer>(
|
||||
port, protocol, dispatcher_address, worker_address);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -91,7 +91,9 @@ class DispatchGrpcDataServer : public GrpcDataServerBase {
|
||||
|
||||
class WorkerGrpcDataServer : public GrpcDataServerBase {
|
||||
public:
|
||||
explicit WorkerGrpcDataServer(const experimental::WorkerConfig& config);
|
||||
WorkerGrpcDataServer(int requested_port, const std::string& protocol,
|
||||
const std::string& dispatcher_address,
|
||||
const std::string& worker_address);
|
||||
~WorkerGrpcDataServer() override;
|
||||
|
||||
protected:
|
||||
@ -99,7 +101,8 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
|
||||
Status StartServiceInternal() override;
|
||||
|
||||
private:
|
||||
const experimental::WorkerConfig config_;
|
||||
const std::string dispatcher_address_;
|
||||
const std::string worker_address_;
|
||||
// Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
|
||||
GrpcWorkerImpl* service_;
|
||||
};
|
||||
@ -109,7 +112,23 @@ Status NewDispatchServer(const experimental::DispatcherConfig& config,
|
||||
std::unique_ptr<DispatchGrpcDataServer>* out_server);
|
||||
|
||||
// Creates a worker tf.data server and stores it in `*out_server`.
|
||||
Status NewWorkerServer(const experimental::WorkerConfig& config,
|
||||
//
|
||||
// The port can be a specific port or 0. If the port is 0, an available port
|
||||
// will be chosen in Start(). This value can be queried with BoundPort().
|
||||
//
|
||||
// The worker_address argument is optional. If left empty, it will default to
|
||||
// "localhost:%port%". When the worker registers with the dispatcher, the worker
|
||||
// will report the worker address, so that the dispatcher can tell clients where
|
||||
// to read from. The address may contain the placeholder "%port%", which will be
|
||||
// replaced with the value of BoundPort().
|
||||
Status NewWorkerServer(int port, const std::string& protocol,
|
||||
const std::string& dispatcher_address,
|
||||
const std::string& worker_address,
|
||||
std::unique_ptr<WorkerGrpcDataServer>* out_server);
|
||||
|
||||
// Creates a worker using the default worker_address.
|
||||
Status NewWorkerServer(int port, const std::string& protocol,
|
||||
const std::string& dispatcher_address,
|
||||
std::unique_ptr<WorkerGrpcDataServer>* out_server);
|
||||
|
||||
} // namespace data
|
||||
|
||||
@ -62,12 +62,8 @@ Status TestCluster::Initialize() {
|
||||
|
||||
Status TestCluster::AddWorker() {
|
||||
std::unique_ptr<WorkerGrpcDataServer> worker;
|
||||
experimental::WorkerConfig config;
|
||||
config.set_port(0);
|
||||
config.set_protocol(kProtocol);
|
||||
config.set_dispatcher_address(dispatcher_address_);
|
||||
config.set_worker_address("localhost:%port%");
|
||||
TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker));
|
||||
TF_RETURN_IF_ERROR(
|
||||
NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
|
||||
TF_RETURN_IF_ERROR(worker->Start());
|
||||
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
|
||||
workers_.push_back(std::move(worker));
|
||||
|
||||
@ -46,8 +46,8 @@ auto* tf_data_service_created =
|
||||
} // namespace
|
||||
|
||||
DataServiceWorkerImpl::DataServiceWorkerImpl(
|
||||
const experimental::WorkerConfig& config)
|
||||
: config_(config) {
|
||||
const std::string& dispatcher_address, const std::string& protocol)
|
||||
: dispatcher_address_(dispatcher_address), protocol_(protocol) {
|
||||
tf_data_service_created->GetCell()->Set(true);
|
||||
}
|
||||
|
||||
@ -68,7 +68,7 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
|
||||
Status s = Register();
|
||||
while (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to register with dispatcher at "
|
||||
<< config_.dispatcher_address() << ": " << s;
|
||||
<< dispatcher_address_ << ": " << s;
|
||||
Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
|
||||
s = Register();
|
||||
}
|
||||
@ -173,17 +173,17 @@ Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
|
||||
if (!dispatcher_stub_) {
|
||||
::grpc::ChannelArguments args;
|
||||
std::shared_ptr<::grpc::ChannelCredentials> credentials;
|
||||
TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
|
||||
config_.protocol(), &credentials));
|
||||
auto channel = ::grpc::CreateCustomChannel(config_.dispatcher_address(),
|
||||
credentials, args);
|
||||
TF_RETURN_IF_ERROR(
|
||||
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
|
||||
auto channel =
|
||||
::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
|
||||
dispatcher_stub_ = DispatcherService::NewStub(channel);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
VLOG(3) << "Registering with dispatcher at " << config_.dispatcher_address();
|
||||
VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
|
||||
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
|
||||
RegisterWorkerRequest req;
|
||||
req.set_worker_address(worker_address_);
|
||||
|
||||
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/data/service/worker.pb.h"
|
||||
#include "tensorflow/core/data/standalone.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/data/experimental/service_config.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -30,14 +29,12 @@ namespace data {
|
||||
// A TensorFlow DataService serves dataset elements over RPC.
|
||||
class DataServiceWorkerImpl {
|
||||
public:
|
||||
explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config);
|
||||
explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
|
||||
const std::string& protocol);
|
||||
~DataServiceWorkerImpl();
|
||||
|
||||
// Starts the worker. The worker needs to know its own address so that it can
|
||||
// register with the dispatcher. This is set in `Start` instead of in the
|
||||
// constructor because the worker may be binding to port `0`, in which case
|
||||
// the address isn't known until the worker has started and decided which port
|
||||
// to bind to.
|
||||
// register with the dispatcher.
|
||||
void Start(const std::string& worker_address);
|
||||
|
||||
// See worker.proto for API documentation.
|
||||
@ -70,7 +67,9 @@ class DataServiceWorkerImpl {
|
||||
std::unique_ptr<standalone::Iterator> iterator;
|
||||
} Task;
|
||||
|
||||
const experimental::WorkerConfig config_;
|
||||
const std::string dispatcher_address_;
|
||||
// Protocol for communicating with the dispatcher.
|
||||
const std::string protocol_;
|
||||
// The worker's own address.
|
||||
std::string worker_address_;
|
||||
|
||||
|
||||
@ -10,18 +10,3 @@ message DispatcherConfig {
|
||||
// The protocol for the dispatcher to use when connecting to workers.
|
||||
string protocol = 2;
|
||||
}
|
||||
|
||||
// Configuration for a tf.data service WorkerServer.
|
||||
message WorkerConfig {
|
||||
// The port for the worker to bind to. A value of 0 indicates that the
|
||||
// worker may bind to any available port.
|
||||
int64 port = 1;
|
||||
// The protocol for the worker to use when connecting to the dispatcher.
|
||||
string protocol = 2;
|
||||
// The address of the dispatcher to register with.
|
||||
string dispatcher_address = 3;
|
||||
// The address of the worker server. The substring "%port%", if specified,
|
||||
// will be replaced with the worker's bound port. This is useful when the port
|
||||
// is set to `0`.
|
||||
string worker_address = 4;
|
||||
}
|
||||
|
||||
@ -205,13 +205,8 @@ class WorkerServer(object):
|
||||
protocol = "grpc"
|
||||
|
||||
self._protocol = protocol
|
||||
config = service_config_pb2.WorkerConfig(
|
||||
port=port,
|
||||
protocol=protocol,
|
||||
dispatcher_address=dispatcher_address,
|
||||
worker_address=worker_address)
|
||||
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
|
||||
config.SerializeToString())
|
||||
port, protocol, dispatcher_address, worker_address)
|
||||
if start:
|
||||
self._server.start()
|
||||
|
||||
|
||||
@ -69,16 +69,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
|
||||
|
||||
m.def(
|
||||
"TF_DATA_NewWorkerServer",
|
||||
[](std::string serialized_worker_config)
|
||||
[](int port, std::string protocol, std::string dispatcher_address,
|
||||
std::string worker_address)
|
||||
-> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
|
||||
tensorflow::data::experimental::WorkerConfig config;
|
||||
if (!config.ParseFromString(serialized_worker_config)) {
|
||||
tensorflow::MaybeRaiseFromStatus(tensorflow::errors::InvalidArgument(
|
||||
"Failed to deserialize worker config."));
|
||||
}
|
||||
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
|
||||
tensorflow::Status status =
|
||||
tensorflow::data::NewWorkerServer(config, &server);
|
||||
tensorflow::Status status = tensorflow::data::NewWorkerServer(
|
||||
port, protocol, dispatcher_address, worker_address, &server);
|
||||
tensorflow::MaybeRaiseFromStatus(status);
|
||||
return server;
|
||||
},
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user