[tf.data service] Improve pointer/reference usage in server_lib.

This CL updates the code to use mutable references for output parameters, since they cannot be null.

PiperOrigin-RevId: 329009092
Change-Id: Ib382300f7d11657fec4ee210a3e60e5fcc26a218
This commit is contained in:
Andrew Audibert 2020-08-28 14:24:56 -07:00 committed by TensorFlower Gardener
parent 3b87d2932a
commit 88eebe512c
8 changed files with 37 additions and 49 deletions

View File

@ -26,9 +26,9 @@ using ::grpc::ServerBuilder;
using ::grpc::ServerContext;
GrpcDispatcherImpl::GrpcDispatcherImpl(
ServerBuilder* server_builder, const experimental::DispatcherConfig& config)
const experimental::DispatcherConfig& config, ServerBuilder& server_builder)
: impl_(config) {
server_builder->RegisterService(this);
server_builder.RegisterService(this);
VLOG(1) << "Registered data service dispatcher";
}

View File

@ -25,18 +25,12 @@ namespace tensorflow {
namespace data {
// This class is a wrapper that handles communication for gRPC.
//
// Example usage:
//
// ::grpc::ServerBuilder builder;
// // configure builder
// GrpcDispatcherImpl data_service(&builder);
// builder.BuildAndStart()
//
class GrpcDispatcherImpl : public DispatcherService::Service {
public:
explicit GrpcDispatcherImpl(::grpc::ServerBuilder* server_builder,
const experimental::DispatcherConfig& config);
// Constructs a GrpcDispatcherImpl with the given config, and registers it
// with `server_builder`.
explicit GrpcDispatcherImpl(const experimental::DispatcherConfig& config,
::grpc::ServerBuilder& server_builder);
~GrpcDispatcherImpl() override {}
Status Start();

View File

@ -24,10 +24,10 @@ namespace data {
using ::grpc::ServerBuilder;
using ::grpc::ServerContext;
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
const experimental::WorkerConfig& config)
GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config,
ServerBuilder& server_builder)
: impl_(config) {
server_builder->RegisterService(this);
server_builder.RegisterService(this);
VLOG(1) << "Registered data service worker";
}

View File

@ -25,18 +25,12 @@ namespace tensorflow {
namespace data {
// This class is a wrapper that handles communication for gRPC.
//
// Example usage:
//
// ::grpc::ServerBuilder builder;
// // configure builder
// GrpcWorkerImpl data_service(&builder);
// builder.BuildAndStart()
//
class GrpcWorkerImpl : public WorkerService::Service {
public:
explicit GrpcWorkerImpl(::grpc::ServerBuilder* server_builder,
const experimental::WorkerConfig& config);
// Constructs a GrpcWorkerImpl with the given config, and registers it with
// `server_builder`.
explicit GrpcWorkerImpl(const experimental::WorkerConfig& config,
::grpc::ServerBuilder& server_builder);
~GrpcWorkerImpl() override {}
Status Start(const std::string& worker_address);

View File

@ -51,8 +51,8 @@ Status GrpcDataServerBase::Start() {
credentials, &bound_port_);
builder.SetMaxReceiveMessageSize(-1);
AddDataServiceToBuilder(&builder);
AddProfilerServiceToBuilder(&builder);
AddDataServiceToBuilder(builder);
AddProfilerServiceToBuilder(builder);
server_ = builder.BuildAndStart();
if (!server_) {
return errors::Internal("Could not start gRPC server");
@ -81,9 +81,9 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
int GrpcDataServerBase::BoundPort() { return bound_port(); }
void GrpcDataServerBase::AddProfilerServiceToBuilder(
::grpc::ServerBuilder* builder) {
::grpc::ServerBuilder& builder) {
profiler_service_ = CreateProfilerService();
builder->RegisterService(profiler_service_.get());
builder.RegisterService(profiler_service_.get());
}
DispatchGrpcDataServer::DispatchGrpcDataServer(
@ -94,8 +94,8 @@ DispatchGrpcDataServer::DispatchGrpcDataServer(
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
void DispatchGrpcDataServer::AddDataServiceToBuilder(
::grpc::ServerBuilder* builder) {
service_ = absl::make_unique<GrpcDispatcherImpl>(builder, config_).release();
::grpc::ServerBuilder& builder) {
service_ = absl::make_unique<GrpcDispatcherImpl>(config_, builder).release();
}
Status DispatchGrpcDataServer::StartServiceInternal() {
@ -122,8 +122,8 @@ WorkerGrpcDataServer::WorkerGrpcDataServer(
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
void WorkerGrpcDataServer::AddDataServiceToBuilder(
::grpc::ServerBuilder* builder) {
service_ = absl::make_unique<GrpcWorkerImpl>(builder, config_).release();
::grpc::ServerBuilder& builder) {
service_ = absl::make_unique<GrpcWorkerImpl>(config_, builder).release();
}
Status WorkerGrpcDataServer::StartServiceInternal() {
@ -139,14 +139,14 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
}
Status NewDispatchServer(const experimental::DispatcherConfig& config,
std::unique_ptr<DispatchGrpcDataServer>* out_server) {
*out_server = absl::make_unique<DispatchGrpcDataServer>(config);
std::unique_ptr<DispatchGrpcDataServer>& out_server) {
out_server = absl::make_unique<DispatchGrpcDataServer>(config);
return Status::OK();
}
Status NewWorkerServer(const experimental::WorkerConfig& config,
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
*out_server = absl::make_unique<WorkerGrpcDataServer>(config);
std::unique_ptr<WorkerGrpcDataServer>& out_server) {
out_server = absl::make_unique<WorkerGrpcDataServer>(config);
return Status::OK();
}

View File

@ -53,8 +53,8 @@ class GrpcDataServerBase {
int BoundPort();
protected:
virtual void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) = 0;
void AddProfilerServiceToBuilder(::grpc::ServerBuilder* builder);
virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0;
void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder);
// Starts the service. This will be called after building the service, so
// bound_port() will return the actual bound port.
virtual Status StartServiceInternal() = 0;
@ -84,7 +84,7 @@ class DispatchGrpcDataServer : public GrpcDataServerBase {
Status NumWorkers(int* num_workers);
protected:
void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
Status StartServiceInternal() override;
private:
@ -99,7 +99,7 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
~WorkerGrpcDataServer() override;
protected:
void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
Status StartServiceInternal() override;
private:
@ -108,13 +108,13 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
GrpcWorkerImpl* service_;
};
// Creates a dispatch tf.data server and stores it in `*out_server`.
// Creates a dispatch tf.data server and stores it in `out_server`.
Status NewDispatchServer(const experimental::DispatcherConfig& config,
std::unique_ptr<DispatchGrpcDataServer>* out_server);
std::unique_ptr<DispatchGrpcDataServer>& out_server);
// Creates a worker tf.data server and stores it in `*out_server`.
// Creates a worker tf.data server and stores it in `out_server`.
Status NewWorkerServer(const experimental::WorkerConfig& config,
std::unique_ptr<WorkerGrpcDataServer>* out_server);
std::unique_ptr<WorkerGrpcDataServer>& out_server);
} // namespace data
} // namespace tensorflow

View File

@ -49,7 +49,7 @@ Status TestCluster::Initialize() {
experimental::DispatcherConfig config;
config.set_port(0);
config.set_protocol(kProtocol);
TF_RETURN_IF_ERROR(NewDispatchServer(config, &dispatcher_));
TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher_));
TF_RETURN_IF_ERROR(dispatcher_->Start());
dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
workers_.reserve(num_workers_);
@ -67,7 +67,7 @@ Status TestCluster::AddWorker() {
config.set_protocol(kProtocol);
config.set_dispatcher_address(dispatcher_address_);
config.set_worker_address("localhost:%port%");
TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker));
TF_RETURN_IF_ERROR(NewWorkerServer(config, worker));
TF_RETURN_IF_ERROR(worker->Start());
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
workers_.push_back(std::move(worker));

View File

@ -63,7 +63,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
}
std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
tensorflow::Status status =
tensorflow::data::NewDispatchServer(config, &server);
tensorflow::data::NewDispatchServer(config, server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},
@ -80,7 +80,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
}
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
tensorflow::Status status =
tensorflow::data::NewWorkerServer(config, &server);
tensorflow::data::NewWorkerServer(config, server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},