[tf.data service] Improve pointer/reference usage in server_lib.
This CL updates the code to use mutable references for output parameters, since they cannot be null. PiperOrigin-RevId: 329009092 Change-Id: Ib382300f7d11657fec4ee210a3e60e5fcc26a218
This commit is contained in:
parent
3b87d2932a
commit
88eebe512c
@ -26,9 +26,9 @@ using ::grpc::ServerBuilder;
|
|||||||
using ::grpc::ServerContext;
|
using ::grpc::ServerContext;
|
||||||
|
|
||||||
GrpcDispatcherImpl::GrpcDispatcherImpl(
|
GrpcDispatcherImpl::GrpcDispatcherImpl(
|
||||||
ServerBuilder* server_builder, const experimental::DispatcherConfig& config)
|
const experimental::DispatcherConfig& config, ServerBuilder& server_builder)
|
||||||
: impl_(config) {
|
: impl_(config) {
|
||||||
server_builder->RegisterService(this);
|
server_builder.RegisterService(this);
|
||||||
VLOG(1) << "Registered data service dispatcher";
|
VLOG(1) << "Registered data service dispatcher";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,18 +25,12 @@ namespace tensorflow {
|
|||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
// This class is a wrapper that handles communication for gRPC.
|
// This class is a wrapper that handles communication for gRPC.
|
||||||
//
|
|
||||||
// Example usage:
|
|
||||||
//
|
|
||||||
// ::grpc::ServerBuilder builder;
|
|
||||||
// // configure builder
|
|
||||||
// GrpcDispatcherImpl data_service(&builder);
|
|
||||||
// builder.BuildAndStart()
|
|
||||||
//
|
|
||||||
class GrpcDispatcherImpl : public DispatcherService::Service {
|
class GrpcDispatcherImpl : public DispatcherService::Service {
|
||||||
public:
|
public:
|
||||||
explicit GrpcDispatcherImpl(::grpc::ServerBuilder* server_builder,
|
// Constructs a GrpcDispatcherImpl with the given config, and registers it
|
||||||
const experimental::DispatcherConfig& config);
|
// with `server_builder`.
|
||||||
|
explicit GrpcDispatcherImpl(const experimental::DispatcherConfig& config,
|
||||||
|
::grpc::ServerBuilder& server_builder);
|
||||||
~GrpcDispatcherImpl() override {}
|
~GrpcDispatcherImpl() override {}
|
||||||
|
|
||||||
Status Start();
|
Status Start();
|
||||||
|
@ -24,10 +24,10 @@ namespace data {
|
|||||||
using ::grpc::ServerBuilder;
|
using ::grpc::ServerBuilder;
|
||||||
using ::grpc::ServerContext;
|
using ::grpc::ServerContext;
|
||||||
|
|
||||||
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
|
GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config,
|
||||||
const experimental::WorkerConfig& config)
|
ServerBuilder& server_builder)
|
||||||
: impl_(config) {
|
: impl_(config) {
|
||||||
server_builder->RegisterService(this);
|
server_builder.RegisterService(this);
|
||||||
VLOG(1) << "Registered data service worker";
|
VLOG(1) << "Registered data service worker";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,18 +25,12 @@ namespace tensorflow {
|
|||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
// This class is a wrapper that handles communication for gRPC.
|
// This class is a wrapper that handles communication for gRPC.
|
||||||
//
|
|
||||||
// Example usage:
|
|
||||||
//
|
|
||||||
// ::grpc::ServerBuilder builder;
|
|
||||||
// // configure builder
|
|
||||||
// GrpcWorkerImpl data_service(&builder);
|
|
||||||
// builder.BuildAndStart()
|
|
||||||
//
|
|
||||||
class GrpcWorkerImpl : public WorkerService::Service {
|
class GrpcWorkerImpl : public WorkerService::Service {
|
||||||
public:
|
public:
|
||||||
explicit GrpcWorkerImpl(::grpc::ServerBuilder* server_builder,
|
// Constructs a GrpcWorkerImpl with the given config, and registers it with
|
||||||
const experimental::WorkerConfig& config);
|
// `server_builder`.
|
||||||
|
explicit GrpcWorkerImpl(const experimental::WorkerConfig& config,
|
||||||
|
::grpc::ServerBuilder& server_builder);
|
||||||
~GrpcWorkerImpl() override {}
|
~GrpcWorkerImpl() override {}
|
||||||
|
|
||||||
Status Start(const std::string& worker_address);
|
Status Start(const std::string& worker_address);
|
||||||
|
@ -51,8 +51,8 @@ Status GrpcDataServerBase::Start() {
|
|||||||
credentials, &bound_port_);
|
credentials, &bound_port_);
|
||||||
builder.SetMaxReceiveMessageSize(-1);
|
builder.SetMaxReceiveMessageSize(-1);
|
||||||
|
|
||||||
AddDataServiceToBuilder(&builder);
|
AddDataServiceToBuilder(builder);
|
||||||
AddProfilerServiceToBuilder(&builder);
|
AddProfilerServiceToBuilder(builder);
|
||||||
server_ = builder.BuildAndStart();
|
server_ = builder.BuildAndStart();
|
||||||
if (!server_) {
|
if (!server_) {
|
||||||
return errors::Internal("Could not start gRPC server");
|
return errors::Internal("Could not start gRPC server");
|
||||||
@ -81,9 +81,9 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
|
|||||||
int GrpcDataServerBase::BoundPort() { return bound_port(); }
|
int GrpcDataServerBase::BoundPort() { return bound_port(); }
|
||||||
|
|
||||||
void GrpcDataServerBase::AddProfilerServiceToBuilder(
|
void GrpcDataServerBase::AddProfilerServiceToBuilder(
|
||||||
::grpc::ServerBuilder* builder) {
|
::grpc::ServerBuilder& builder) {
|
||||||
profiler_service_ = CreateProfilerService();
|
profiler_service_ = CreateProfilerService();
|
||||||
builder->RegisterService(profiler_service_.get());
|
builder.RegisterService(profiler_service_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
DispatchGrpcDataServer::DispatchGrpcDataServer(
|
DispatchGrpcDataServer::DispatchGrpcDataServer(
|
||||||
@ -94,8 +94,8 @@ DispatchGrpcDataServer::DispatchGrpcDataServer(
|
|||||||
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
|
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
|
||||||
|
|
||||||
void DispatchGrpcDataServer::AddDataServiceToBuilder(
|
void DispatchGrpcDataServer::AddDataServiceToBuilder(
|
||||||
::grpc::ServerBuilder* builder) {
|
::grpc::ServerBuilder& builder) {
|
||||||
service_ = absl::make_unique<GrpcDispatcherImpl>(builder, config_).release();
|
service_ = absl::make_unique<GrpcDispatcherImpl>(config_, builder).release();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DispatchGrpcDataServer::StartServiceInternal() {
|
Status DispatchGrpcDataServer::StartServiceInternal() {
|
||||||
@ -122,8 +122,8 @@ WorkerGrpcDataServer::WorkerGrpcDataServer(
|
|||||||
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
|
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
|
||||||
|
|
||||||
void WorkerGrpcDataServer::AddDataServiceToBuilder(
|
void WorkerGrpcDataServer::AddDataServiceToBuilder(
|
||||||
::grpc::ServerBuilder* builder) {
|
::grpc::ServerBuilder& builder) {
|
||||||
service_ = absl::make_unique<GrpcWorkerImpl>(builder, config_).release();
|
service_ = absl::make_unique<GrpcWorkerImpl>(config_, builder).release();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status WorkerGrpcDataServer::StartServiceInternal() {
|
Status WorkerGrpcDataServer::StartServiceInternal() {
|
||||||
@ -139,14 +139,14 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status NewDispatchServer(const experimental::DispatcherConfig& config,
|
Status NewDispatchServer(const experimental::DispatcherConfig& config,
|
||||||
std::unique_ptr<DispatchGrpcDataServer>* out_server) {
|
std::unique_ptr<DispatchGrpcDataServer>& out_server) {
|
||||||
*out_server = absl::make_unique<DispatchGrpcDataServer>(config);
|
out_server = absl::make_unique<DispatchGrpcDataServer>(config);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status NewWorkerServer(const experimental::WorkerConfig& config,
|
Status NewWorkerServer(const experimental::WorkerConfig& config,
|
||||||
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
|
std::unique_ptr<WorkerGrpcDataServer>& out_server) {
|
||||||
*out_server = absl::make_unique<WorkerGrpcDataServer>(config);
|
out_server = absl::make_unique<WorkerGrpcDataServer>(config);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,8 +53,8 @@ class GrpcDataServerBase {
|
|||||||
int BoundPort();
|
int BoundPort();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) = 0;
|
virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0;
|
||||||
void AddProfilerServiceToBuilder(::grpc::ServerBuilder* builder);
|
void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder);
|
||||||
// Starts the service. This will be called after building the service, so
|
// Starts the service. This will be called after building the service, so
|
||||||
// bound_port() will return the actual bound port.
|
// bound_port() will return the actual bound port.
|
||||||
virtual Status StartServiceInternal() = 0;
|
virtual Status StartServiceInternal() = 0;
|
||||||
@ -84,7 +84,7 @@ class DispatchGrpcDataServer : public GrpcDataServerBase {
|
|||||||
Status NumWorkers(int* num_workers);
|
Status NumWorkers(int* num_workers);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
|
void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
|
||||||
Status StartServiceInternal() override;
|
Status StartServiceInternal() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -99,7 +99,7 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
|
|||||||
~WorkerGrpcDataServer() override;
|
~WorkerGrpcDataServer() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
|
void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
|
||||||
Status StartServiceInternal() override;
|
Status StartServiceInternal() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -108,13 +108,13 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
|
|||||||
GrpcWorkerImpl* service_;
|
GrpcWorkerImpl* service_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Creates a dispatch tf.data server and stores it in `*out_server`.
|
// Creates a dispatch tf.data server and stores it in `out_server`.
|
||||||
Status NewDispatchServer(const experimental::DispatcherConfig& config,
|
Status NewDispatchServer(const experimental::DispatcherConfig& config,
|
||||||
std::unique_ptr<DispatchGrpcDataServer>* out_server);
|
std::unique_ptr<DispatchGrpcDataServer>& out_server);
|
||||||
|
|
||||||
// Creates a worker tf.data server and stores it in `*out_server`.
|
// Creates a worker tf.data server and stores it in `out_server`.
|
||||||
Status NewWorkerServer(const experimental::WorkerConfig& config,
|
Status NewWorkerServer(const experimental::WorkerConfig& config,
|
||||||
std::unique_ptr<WorkerGrpcDataServer>* out_server);
|
std::unique_ptr<WorkerGrpcDataServer>& out_server);
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -49,7 +49,7 @@ Status TestCluster::Initialize() {
|
|||||||
experimental::DispatcherConfig config;
|
experimental::DispatcherConfig config;
|
||||||
config.set_port(0);
|
config.set_port(0);
|
||||||
config.set_protocol(kProtocol);
|
config.set_protocol(kProtocol);
|
||||||
TF_RETURN_IF_ERROR(NewDispatchServer(config, &dispatcher_));
|
TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher_));
|
||||||
TF_RETURN_IF_ERROR(dispatcher_->Start());
|
TF_RETURN_IF_ERROR(dispatcher_->Start());
|
||||||
dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
|
dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
|
||||||
workers_.reserve(num_workers_);
|
workers_.reserve(num_workers_);
|
||||||
@ -67,7 +67,7 @@ Status TestCluster::AddWorker() {
|
|||||||
config.set_protocol(kProtocol);
|
config.set_protocol(kProtocol);
|
||||||
config.set_dispatcher_address(dispatcher_address_);
|
config.set_dispatcher_address(dispatcher_address_);
|
||||||
config.set_worker_address("localhost:%port%");
|
config.set_worker_address("localhost:%port%");
|
||||||
TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker));
|
TF_RETURN_IF_ERROR(NewWorkerServer(config, worker));
|
||||||
TF_RETURN_IF_ERROR(worker->Start());
|
TF_RETURN_IF_ERROR(worker->Start());
|
||||||
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
|
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
|
||||||
workers_.push_back(std::move(worker));
|
workers_.push_back(std::move(worker));
|
||||||
|
@ -63,7 +63,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
|
|||||||
}
|
}
|
||||||
std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
|
std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
|
||||||
tensorflow::Status status =
|
tensorflow::Status status =
|
||||||
tensorflow::data::NewDispatchServer(config, &server);
|
tensorflow::data::NewDispatchServer(config, server);
|
||||||
tensorflow::MaybeRaiseFromStatus(status);
|
tensorflow::MaybeRaiseFromStatus(status);
|
||||||
return server;
|
return server;
|
||||||
},
|
},
|
||||||
@ -80,7 +80,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
|
|||||||
}
|
}
|
||||||
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
|
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
|
||||||
tensorflow::Status status =
|
tensorflow::Status status =
|
||||||
tensorflow::data::NewWorkerServer(config, &server);
|
tensorflow::data::NewWorkerServer(config, server);
|
||||||
tensorflow::MaybeRaiseFromStatus(status);
|
tensorflow::MaybeRaiseFromStatus(status);
|
||||||
return server;
|
return server;
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user