[tf.data service] Improve pointer/reference usage in server_lib.
This CL updates the code to use mutable references for output parameters, since they cannot be null. PiperOrigin-RevId: 329009092 Change-Id: Ib382300f7d11657fec4ee210a3e60e5fcc26a218
This commit is contained in:
parent
3b87d2932a
commit
88eebe512c
@ -26,9 +26,9 @@ using ::grpc::ServerBuilder;
|
||||
using ::grpc::ServerContext;
|
||||
|
||||
GrpcDispatcherImpl::GrpcDispatcherImpl(
|
||||
ServerBuilder* server_builder, const experimental::DispatcherConfig& config)
|
||||
const experimental::DispatcherConfig& config, ServerBuilder& server_builder)
|
||||
: impl_(config) {
|
||||
server_builder->RegisterService(this);
|
||||
server_builder.RegisterService(this);
|
||||
VLOG(1) << "Registered data service dispatcher";
|
||||
}
|
||||
|
||||
|
@ -25,18 +25,12 @@ namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// This class is a wrapper that handles communication for gRPC.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// ::grpc::ServerBuilder builder;
|
||||
// // configure builder
|
||||
// GrpcDispatcherImpl data_service(&builder);
|
||||
// builder.BuildAndStart()
|
||||
//
|
||||
class GrpcDispatcherImpl : public DispatcherService::Service {
|
||||
public:
|
||||
explicit GrpcDispatcherImpl(::grpc::ServerBuilder* server_builder,
|
||||
const experimental::DispatcherConfig& config);
|
||||
// Constructs a GrpcDispatcherImpl with the given config, and registers it
|
||||
// with `server_builder`.
|
||||
explicit GrpcDispatcherImpl(const experimental::DispatcherConfig& config,
|
||||
::grpc::ServerBuilder& server_builder);
|
||||
~GrpcDispatcherImpl() override {}
|
||||
|
||||
Status Start();
|
||||
|
@ -24,10 +24,10 @@ namespace data {
|
||||
using ::grpc::ServerBuilder;
|
||||
using ::grpc::ServerContext;
|
||||
|
||||
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
|
||||
const experimental::WorkerConfig& config)
|
||||
GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config,
|
||||
ServerBuilder& server_builder)
|
||||
: impl_(config) {
|
||||
server_builder->RegisterService(this);
|
||||
server_builder.RegisterService(this);
|
||||
VLOG(1) << "Registered data service worker";
|
||||
}
|
||||
|
||||
|
@ -25,18 +25,12 @@ namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// This class is a wrapper that handles communication for gRPC.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// ::grpc::ServerBuilder builder;
|
||||
// // configure builder
|
||||
// GrpcWorkerImpl data_service(&builder);
|
||||
// builder.BuildAndStart()
|
||||
//
|
||||
class GrpcWorkerImpl : public WorkerService::Service {
|
||||
public:
|
||||
explicit GrpcWorkerImpl(::grpc::ServerBuilder* server_builder,
|
||||
const experimental::WorkerConfig& config);
|
||||
// Constructs a GrpcWorkerImpl with the given config, and registers it with
|
||||
// `server_builder`.
|
||||
explicit GrpcWorkerImpl(const experimental::WorkerConfig& config,
|
||||
::grpc::ServerBuilder& server_builder);
|
||||
~GrpcWorkerImpl() override {}
|
||||
|
||||
Status Start(const std::string& worker_address);
|
||||
|
@ -51,8 +51,8 @@ Status GrpcDataServerBase::Start() {
|
||||
credentials, &bound_port_);
|
||||
builder.SetMaxReceiveMessageSize(-1);
|
||||
|
||||
AddDataServiceToBuilder(&builder);
|
||||
AddProfilerServiceToBuilder(&builder);
|
||||
AddDataServiceToBuilder(builder);
|
||||
AddProfilerServiceToBuilder(builder);
|
||||
server_ = builder.BuildAndStart();
|
||||
if (!server_) {
|
||||
return errors::Internal("Could not start gRPC server");
|
||||
@ -81,9 +81,9 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
|
||||
int GrpcDataServerBase::BoundPort() { return bound_port(); }
|
||||
|
||||
void GrpcDataServerBase::AddProfilerServiceToBuilder(
|
||||
::grpc::ServerBuilder* builder) {
|
||||
::grpc::ServerBuilder& builder) {
|
||||
profiler_service_ = CreateProfilerService();
|
||||
builder->RegisterService(profiler_service_.get());
|
||||
builder.RegisterService(profiler_service_.get());
|
||||
}
|
||||
|
||||
DispatchGrpcDataServer::DispatchGrpcDataServer(
|
||||
@ -94,8 +94,8 @@ DispatchGrpcDataServer::DispatchGrpcDataServer(
|
||||
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
|
||||
|
||||
void DispatchGrpcDataServer::AddDataServiceToBuilder(
|
||||
::grpc::ServerBuilder* builder) {
|
||||
service_ = absl::make_unique<GrpcDispatcherImpl>(builder, config_).release();
|
||||
::grpc::ServerBuilder& builder) {
|
||||
service_ = absl::make_unique<GrpcDispatcherImpl>(config_, builder).release();
|
||||
}
|
||||
|
||||
Status DispatchGrpcDataServer::StartServiceInternal() {
|
||||
@ -122,8 +122,8 @@ WorkerGrpcDataServer::WorkerGrpcDataServer(
|
||||
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
|
||||
|
||||
void WorkerGrpcDataServer::AddDataServiceToBuilder(
|
||||
::grpc::ServerBuilder* builder) {
|
||||
service_ = absl::make_unique<GrpcWorkerImpl>(builder, config_).release();
|
||||
::grpc::ServerBuilder& builder) {
|
||||
service_ = absl::make_unique<GrpcWorkerImpl>(config_, builder).release();
|
||||
}
|
||||
|
||||
Status WorkerGrpcDataServer::StartServiceInternal() {
|
||||
@ -139,14 +139,14 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
|
||||
}
|
||||
|
||||
Status NewDispatchServer(const experimental::DispatcherConfig& config,
|
||||
std::unique_ptr<DispatchGrpcDataServer>* out_server) {
|
||||
*out_server = absl::make_unique<DispatchGrpcDataServer>(config);
|
||||
std::unique_ptr<DispatchGrpcDataServer>& out_server) {
|
||||
out_server = absl::make_unique<DispatchGrpcDataServer>(config);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NewWorkerServer(const experimental::WorkerConfig& config,
|
||||
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
|
||||
*out_server = absl::make_unique<WorkerGrpcDataServer>(config);
|
||||
std::unique_ptr<WorkerGrpcDataServer>& out_server) {
|
||||
out_server = absl::make_unique<WorkerGrpcDataServer>(config);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -53,8 +53,8 @@ class GrpcDataServerBase {
|
||||
int BoundPort();
|
||||
|
||||
protected:
|
||||
virtual void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) = 0;
|
||||
void AddProfilerServiceToBuilder(::grpc::ServerBuilder* builder);
|
||||
virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0;
|
||||
void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder);
|
||||
// Starts the service. This will be called after building the service, so
|
||||
// bound_port() will return the actual bound port.
|
||||
virtual Status StartServiceInternal() = 0;
|
||||
@ -84,7 +84,7 @@ class DispatchGrpcDataServer : public GrpcDataServerBase {
|
||||
Status NumWorkers(int* num_workers);
|
||||
|
||||
protected:
|
||||
void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
|
||||
void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
|
||||
Status StartServiceInternal() override;
|
||||
|
||||
private:
|
||||
@ -99,7 +99,7 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
|
||||
~WorkerGrpcDataServer() override;
|
||||
|
||||
protected:
|
||||
void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
|
||||
void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
|
||||
Status StartServiceInternal() override;
|
||||
|
||||
private:
|
||||
@ -108,13 +108,13 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
|
||||
GrpcWorkerImpl* service_;
|
||||
};
|
||||
|
||||
// Creates a dispatch tf.data server and stores it in `*out_server`.
|
||||
// Creates a dispatch tf.data server and stores it in `out_server`.
|
||||
Status NewDispatchServer(const experimental::DispatcherConfig& config,
|
||||
std::unique_ptr<DispatchGrpcDataServer>* out_server);
|
||||
std::unique_ptr<DispatchGrpcDataServer>& out_server);
|
||||
|
||||
// Creates a worker tf.data server and stores it in `*out_server`.
|
||||
// Creates a worker tf.data server and stores it in `out_server`.
|
||||
Status NewWorkerServer(const experimental::WorkerConfig& config,
|
||||
std::unique_ptr<WorkerGrpcDataServer>* out_server);
|
||||
std::unique_ptr<WorkerGrpcDataServer>& out_server);
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -49,7 +49,7 @@ Status TestCluster::Initialize() {
|
||||
experimental::DispatcherConfig config;
|
||||
config.set_port(0);
|
||||
config.set_protocol(kProtocol);
|
||||
TF_RETURN_IF_ERROR(NewDispatchServer(config, &dispatcher_));
|
||||
TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher_));
|
||||
TF_RETURN_IF_ERROR(dispatcher_->Start());
|
||||
dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
|
||||
workers_.reserve(num_workers_);
|
||||
@ -67,7 +67,7 @@ Status TestCluster::AddWorker() {
|
||||
config.set_protocol(kProtocol);
|
||||
config.set_dispatcher_address(dispatcher_address_);
|
||||
config.set_worker_address("localhost:%port%");
|
||||
TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker));
|
||||
TF_RETURN_IF_ERROR(NewWorkerServer(config, worker));
|
||||
TF_RETURN_IF_ERROR(worker->Start());
|
||||
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
|
||||
workers_.push_back(std::move(worker));
|
||||
|
@ -63,7 +63,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
|
||||
}
|
||||
std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
|
||||
tensorflow::Status status =
|
||||
tensorflow::data::NewDispatchServer(config, &server);
|
||||
tensorflow::data::NewDispatchServer(config, server);
|
||||
tensorflow::MaybeRaiseFromStatus(status);
|
||||
return server;
|
||||
},
|
||||
@ -80,7 +80,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
|
||||
}
|
||||
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
|
||||
tensorflow::Status status =
|
||||
tensorflow::data::NewWorkerServer(config, &server);
|
||||
tensorflow::data::NewWorkerServer(config, server);
|
||||
tensorflow::MaybeRaiseFromStatus(status);
|
||||
return server;
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user