diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc index a7a30798a93..fbfc5d20665 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc +++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc @@ -26,9 +26,9 @@ using ::grpc::ServerBuilder; using ::grpc::ServerContext; GrpcDispatcherImpl::GrpcDispatcherImpl( - ServerBuilder* server_builder, const experimental::DispatcherConfig& config) + const experimental::DispatcherConfig& config, ServerBuilder& server_builder) : impl_(config) { - server_builder->RegisterService(this); + server_builder.RegisterService(this); VLOG(1) << "Registered data service dispatcher"; } diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h index 81f1cbf6f02..171deed4792 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl.h +++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h @@ -25,18 +25,12 @@ namespace tensorflow { namespace data { // This class is a wrapper that handles communication for gRPC. -// -// Example usage: -// -// ::grpc::ServerBuilder builder; -// // configure builder -// GrpcDispatcherImpl data_service(&builder); -// builder.BuildAndStart() -// class GrpcDispatcherImpl : public DispatcherService::Service { public: - explicit GrpcDispatcherImpl(::grpc::ServerBuilder* server_builder, - const experimental::DispatcherConfig& config); + // Constructs a GrpcDispatcherImpl with the given config, and registers it + // with `server_builder`. + explicit GrpcDispatcherImpl(const experimental::DispatcherConfig& config, + ::grpc::ServerBuilder& server_builder); ~GrpcDispatcherImpl() override {} Status Start(); diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc index b3a37fe0eec..ef386be4640 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.cc +++ b/tensorflow/core/data/service/grpc_worker_impl.cc @@ -24,10 +24,10 @@ namespace data { using ::grpc::ServerBuilder; using ::grpc::ServerContext; -GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder, - const experimental::WorkerConfig& config) +GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config, + ServerBuilder& server_builder) : impl_(config) { - server_builder->RegisterService(this); + server_builder.RegisterService(this); VLOG(1) << "Registered data service worker"; } diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h index c42e5639385..3d30af9a806 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.h +++ b/tensorflow/core/data/service/grpc_worker_impl.h @@ -25,18 +25,12 @@ namespace tensorflow { namespace data { // This class is a wrapper that handles communication for gRPC. -// -// Example usage: -// -// ::grpc::ServerBuilder builder; -// // configure builder -// GrpcWorkerImpl data_service(&builder); -// builder.BuildAndStart() -// class GrpcWorkerImpl : public WorkerService::Service { public: - explicit GrpcWorkerImpl(::grpc::ServerBuilder* server_builder, - const experimental::WorkerConfig& config); + // Constructs a GrpcWorkerImpl with the given config, and registers it with + // `server_builder`. + explicit GrpcWorkerImpl(const experimental::WorkerConfig& config, + ::grpc::ServerBuilder& server_builder); ~GrpcWorkerImpl() override {} Status Start(const std::string& worker_address); diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc index 4ee186cd9ec..4c9e4c11503 100644 --- a/tensorflow/core/data/service/server_lib.cc +++ b/tensorflow/core/data/service/server_lib.cc @@ -51,8 +51,8 @@ Status GrpcDataServerBase::Start() { credentials, &bound_port_); builder.SetMaxReceiveMessageSize(-1); - AddDataServiceToBuilder(&builder); - AddProfilerServiceToBuilder(&builder); + AddDataServiceToBuilder(builder); + AddProfilerServiceToBuilder(builder); server_ = builder.BuildAndStart(); if (!server_) { return errors::Internal("Could not start gRPC server"); @@ -81,9 +81,9 @@ void GrpcDataServerBase::Join() { server_->Wait(); } int GrpcDataServerBase::BoundPort() { return bound_port(); } void GrpcDataServerBase::AddProfilerServiceToBuilder( - ::grpc::ServerBuilder* builder) { + ::grpc::ServerBuilder& builder) { profiler_service_ = CreateProfilerService(); - builder->RegisterService(profiler_service_.get()); + builder.RegisterService(profiler_service_.get()); } DispatchGrpcDataServer::DispatchGrpcDataServer( @@ -94,8 +94,8 @@ DispatchGrpcDataServer::DispatchGrpcDataServer( DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; } void DispatchGrpcDataServer::AddDataServiceToBuilder( - ::grpc::ServerBuilder* builder) { - service_ = absl::make_unique(builder, config_).release(); + ::grpc::ServerBuilder& builder) { + service_ = absl::make_unique(config_, builder).release(); } Status DispatchGrpcDataServer::StartServiceInternal() { @@ -122,8 +122,8 @@ WorkerGrpcDataServer::WorkerGrpcDataServer( WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; } void WorkerGrpcDataServer::AddDataServiceToBuilder( - ::grpc::ServerBuilder* builder) { - service_ = absl::make_unique(builder, config_).release(); + ::grpc::ServerBuilder& builder) { + service_ = absl::make_unique(config_, builder).release(); } Status WorkerGrpcDataServer::StartServiceInternal() { @@ -139,14 +139,14 @@ Status WorkerGrpcDataServer::StartServiceInternal() { } Status NewDispatchServer(const experimental::DispatcherConfig& config, - std::unique_ptr* out_server) { - *out_server = absl::make_unique(config); + std::unique_ptr& out_server) { + out_server = absl::make_unique(config); return Status::OK(); } Status NewWorkerServer(const experimental::WorkerConfig& config, - std::unique_ptr* out_server) { - *out_server = absl::make_unique(config); + std::unique_ptr& out_server) { + out_server = absl::make_unique(config); return Status::OK(); } diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h index 0ddc80676c3..c45ec144652 100644 --- a/tensorflow/core/data/service/server_lib.h +++ b/tensorflow/core/data/service/server_lib.h @@ -53,8 +53,8 @@ class GrpcDataServerBase { int BoundPort(); protected: - virtual void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) = 0; - void AddProfilerServiceToBuilder(::grpc::ServerBuilder* builder); + virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0; + void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder); // Starts the service. This will be called after building the service, so // bound_port() will return the actual bound port. virtual Status StartServiceInternal() = 0; @@ -84,7 +84,7 @@ class DispatchGrpcDataServer : public GrpcDataServerBase { Status NumWorkers(int* num_workers); protected: - void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override; + void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; Status StartServiceInternal() override; private: @@ -99,7 +99,7 @@ class WorkerGrpcDataServer : public GrpcDataServerBase { ~WorkerGrpcDataServer() override; protected: - void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override; + void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; Status StartServiceInternal() override; private: @@ -108,13 +108,13 @@ class WorkerGrpcDataServer : public GrpcDataServerBase { GrpcWorkerImpl* service_; }; -// Creates a dispatch tf.data server and stores it in `*out_server`. +// Creates a dispatch tf.data server and stores it in `out_server`. Status NewDispatchServer(const experimental::DispatcherConfig& config, - std::unique_ptr* out_server); + std::unique_ptr& out_server); -// Creates a worker tf.data server and stores it in `*out_server`. +// Creates a worker tf.data server and stores it in `out_server`. Status NewWorkerServer(const experimental::WorkerConfig& config, - std::unique_ptr* out_server); + std::unique_ptr& out_server); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc index 8ae3f191407..49f7eaef30d 100644 --- a/tensorflow/core/data/service/test_cluster.cc +++ b/tensorflow/core/data/service/test_cluster.cc @@ -49,7 +49,7 @@ Status TestCluster::Initialize() { experimental::DispatcherConfig config; config.set_port(0); config.set_protocol(kProtocol); - TF_RETURN_IF_ERROR(NewDispatchServer(config, &dispatcher_)); + TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher_)); TF_RETURN_IF_ERROR(dispatcher_->Start()); dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort()); workers_.reserve(num_workers_); @@ -67,7 +67,7 @@ Status TestCluster::AddWorker() { config.set_protocol(kProtocol); config.set_dispatcher_address(dispatcher_address_); config.set_worker_address("localhost:%port%"); - TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker)); + TF_RETURN_IF_ERROR(NewWorkerServer(config, worker)); TF_RETURN_IF_ERROR(worker->Start()); worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort())); workers_.push_back(std::move(worker)); diff --git a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc index b268ba2403a..8ce904eecba 100644 --- a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc +++ b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc @@ -63,7 +63,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) { } std::unique_ptr server; tensorflow::Status status = - tensorflow::data::NewDispatchServer(config, &server); + tensorflow::data::NewDispatchServer(config, server); tensorflow::MaybeRaiseFromStatus(status); return server; }, @@ -80,7 +80,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) { } std::unique_ptr server; tensorflow::Status status = - tensorflow::data::NewWorkerServer(config, &server); + tensorflow::data::NewWorkerServer(config, server); tensorflow::MaybeRaiseFromStatus(status); return server; },