diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index bebd179401e..d2a887a82f8 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -28,8 +28,8 @@ tf_proto_library(
 )
 
 tf_proto_library(
-    name = "master_proto",
-    srcs = ["master.proto"],
+    name = "dispatcher_proto",
+    srcs = ["dispatcher.proto"],
     has_services = 1,
     cc_api_version = 2,
     protodeps = tf_additional_all_protos() + [
@@ -49,17 +49,17 @@ tf_proto_library(
 )
 
 cc_library(
-    name = "master_impl",
-    srcs = ["master_impl.cc"],
+    name = "dispatcher_impl",
+    srcs = ["dispatcher_impl.cc"],
     hdrs = [
-        "master_impl.h",
+        "dispatcher_impl.h",
     ],
     deps = [
         ":common_proto_cc",
         ":credentials_factory",
         ":data_service",
+        ":dispatcher_proto_cc",
         ":grpc_util",
-        ":master_proto_cc",
         ":worker_cc_grpc_proto",
         ":worker_proto_cc",
         "//tensorflow/c:c_api_internal",
@@ -86,9 +86,9 @@ cc_library(
     deps = [
         ":common_proto_cc",
         ":credentials_factory",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
         ":grpc_util",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
         ":worker_proto_cc",
         "//tensorflow/c:c_api_internal",
         "//tensorflow/c:tf_status_helper",
@@ -207,12 +207,12 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "grpc_master_impl",
-    srcs = ["grpc_master_impl.cc"],
-    hdrs = ["grpc_master_impl.h"],
+    name = "grpc_dispatcher_impl",
+    srcs = ["grpc_dispatcher_impl.cc"],
+    hdrs = ["grpc_dispatcher_impl.h"],
     deps = [
-        ":master_cc_grpc_proto",
-        ":master_impl",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_impl",
         "//tensorflow/core/distributed_runtime/rpc:grpc_util",
         tf_grpc_cc_dependency(),
     ],
@@ -250,7 +250,7 @@ cc_library(
     ],
     deps = [
         ":credentials_factory",
-        ":grpc_master_impl",
+        ":grpc_dispatcher_impl",
         ":grpc_util",
         ":grpc_worker_impl",
         "//tensorflow/core:lib",
@@ -268,9 +268,9 @@ cc_library(
     ],
     deps = [
         ":credentials_factory",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
         ":grpc_util",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
         ":worker_cc_grpc_proto",
         ":worker_proto_cc",
         "//tensorflow/core:framework",
@@ -287,12 +287,12 @@ tf_cc_test(
     tags = ["no_windows"],
    deps = [
         ":data_service",
-        ":grpc_master_impl",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
+        ":grpc_dispatcher_impl",
         ":grpc_util",
         ":grpc_worker_impl",
         ":local_credentials_factory",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
         ":server_lib",
         ":test_cluster",
         ":test_util",
@@ -309,11 +309,11 @@ tf_cc_test(
 )
 
 cc_grpc_library(
-    name = "master_cc_grpc_proto",
-    srcs = [":master_proto"],
+    name = "dispatcher_cc_grpc_proto",
+    srcs = [":dispatcher_proto"],
     generate_mocks = True,
     grpc_only = True,
-    deps = [":master_proto_cc"],
+    deps = [":dispatcher_proto_cc"],
 )
 
 cc_grpc_library(
diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc
index d4e08c77f35..be09b10c1fc 100644
--- a/tensorflow/core/data/service/data_service.cc
+++ b/tensorflow/core/data/service/data_service.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" #include "tensorflow/core/data/service/credentials_factory.h" +#include "tensorflow/core/data/service/dispatcher.grpc.pb.h" #include "tensorflow/core/data/service/grpc_util.h" -#include "tensorflow/core/data/service/master.grpc.pb.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/framework/dataset.h" @@ -54,8 +54,8 @@ std::string ProcessingModeToString(ProcessingMode mode) { } } -Status DataServiceMasterClient::RegisterDataset(GraphDef dataset, - int64* dataset_id) { +Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset, + int64* dataset_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetOrRegisterDatasetRequest req; *req.mutable_dataset()->mutable_graph() = dataset; @@ -69,9 +69,9 @@ Status DataServiceMasterClient::RegisterDataset(GraphDef dataset, return Status::OK(); } -Status DataServiceMasterClient::CreateJob(int64 dataset_id, - ProcessingMode processing_mode, - int64* job_id) { +Status DataServiceDispatcherClient::CreateJob(int64 dataset_id, + ProcessingMode processing_mode, + int64* job_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); CreateJobRequest req; req.set_dataset_id(dataset_id); @@ -88,11 +88,9 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id, return Status::OK(); } -Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id, - ProcessingMode processing_mode, - const std::string& job_name, - int job_name_index, - int64* job_id) { +Status DataServiceDispatcherClient::GetOrCreateJob( + int64 dataset_id, ProcessingMode processing_mode, + const std::string& job_name, int job_name_index, int64* job_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetOrCreateJobRequest req; req.set_dataset_id(dataset_id); @@ -112,9 +110,9 @@ Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id, return Status::OK(); } -Status DataServiceMasterClient::GetTasks(int64 job_id, - std::vector* tasks, - bool* job_finished) { +Status DataServiceDispatcherClient::GetTasks(int64 job_id, + std::vector* tasks, + bool* job_finished) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetTasksRequest req; req.set_job_id(job_id); @@ -132,7 +130,8 @@ Status DataServiceMasterClient::GetTasks(int64 job_id, return Status::OK(); } -Status DataServiceMasterClient::GetWorkers(std::vector* workers) { +Status DataServiceDispatcherClient::GetWorkers( + std::vector* workers) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetWorkersRequest req; GetWorkersResponse resp; @@ -148,12 +147,12 @@ Status DataServiceMasterClient::GetWorkers(std::vector* workers) { return Status::OK(); } -Status DataServiceMasterClient::EnsureInitialized() { +Status DataServiceDispatcherClient::EnsureInitialized() { std::shared_ptr credentials; TF_RETURN_IF_ERROR( CredentialsFactory::CreateClientCredentials(protocol_, &credentials)); auto channel = grpc::CreateChannel(address_, credentials); - stub_ = MasterService::NewStub(channel); + stub_ = DispatcherService::NewStub(channel); return Status::OK(); } @@ -187,10 +186,11 @@ Status DataServiceWorkerClient::EnsureInitialized() { return Status::OK(); } -Status CreateDataServiceMasterClient( +Status CreateDataServiceDispatcherClient( const std::string& address, const std::string& protocol, - std::unique_ptr* out) { - auto client = absl::make_unique(address, protocol); + std::unique_ptr* out) { + auto client = + absl::make_unique(address, protocol); TF_RETURN_IF_ERROR(client->Initialize()); *out = std::move(client); return Status::OK(); diff --git 
index bb5a8a470f0..d0e46c82ff5 100644
--- a/tensorflow/core/data/service/data_service.h
+++ b/tensorflow/core/data/service/data_service.h
@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 
-#include "tensorflow/core/data/service/master.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -67,11 +67,11 @@ class DataServiceClientBase {
   const std::string protocol_;
 };
 
-// Client for communicating with the tf.data service master.
-class DataServiceMasterClient : public DataServiceClientBase {
+// Client for communicating with the tf.data service dispatcher.
+class DataServiceDispatcherClient : public DataServiceClientBase {
  public:
-  DataServiceMasterClient(const std::string& address,
-                          const std::string& protocol)
+  DataServiceDispatcherClient(const std::string& address,
+                              const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}
 
   // Registers a dataset with the tf.data service, and stores the generated
@@ -90,13 +90,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
                        const std::string& job_name, int job_name_index,
                        int64* job_id);
 
-  // Queries the master for the tasks associated with the specified job.
+  // Queries the dispatcher for the tasks associated with the specified job.
   // The tasks will be stored in *tasks, and whether the job is finished will
   // be stored in `*job_finished`.
   Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
                   bool* job_finished);
 
-  // Queries the master for its registered workers. The worker info will be
+  // Queries the dispatcher for its registered workers. The worker info will be
   // stored in `*workers`.
   Status GetWorkers(std::vector<WorkerInfo>* workers);
 
@@ -104,7 +104,7 @@ class DataServiceMasterClient : public DataServiceClientBase {
   Status EnsureInitialized() override;
 
  private:
-  std::unique_ptr<MasterService::Stub> stub_;
+  std::unique_ptr<DispatcherService::Stub> stub_;
 };
 
 // Client for communicating with the tf.data service worker.
@@ -127,10 +127,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
   std::unique_ptr<WorkerService::Stub> stub_;
 };
 
-// Creates and initializes a new tf.data service master client.
-Status CreateDataServiceMasterClient(
+// Creates and initializes a new tf.data service dispatcher client.
+Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceMasterClient>* out);
+    std::unique_ptr<DataServiceDispatcherClient>* out);
 
 // Creates and initializes a new tf.data service worker client.
 Status CreateDataServiceWorkerClient(
diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc
index 19392393eeb..607570054b4 100644
--- a/tensorflow/core/data/service/data_service_test.cc
+++ b/tensorflow/core/data/service/data_service_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include "grpcpp/security/credentials.h" #include "absl/strings/str_split.h" #include "tensorflow/core/data/compression_utils.h" +#include "tensorflow/core/data/service/dispatcher.grpc.pb.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/grpc_util.h" -#include "tensorflow/core/data/service/master.grpc.pb.h" -#include "tensorflow/core/data/service/master.pb.h" #include "tensorflow/core/data/service/server_lib.h" #include "tensorflow/core/data/service/test_cluster.h" #include "tensorflow/core/data/service/test_util.h" @@ -66,9 +66,10 @@ TEST(DataService, ProcessingModeToString) { TEST(DataService, GetWorkers) { TestCluster cluster(1); TF_ASSERT_OK(cluster.Initialize()); - DataServiceMasterClient master(cluster.MasterAddress(), kProtocol); + DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(), + kProtocol); std::vector workers; - TF_EXPECT_OK(master.GetWorkers(&workers)); + TF_EXPECT_OK(dispatcher.GetWorkers(&workers)); EXPECT_EQ(1, workers.size()); } diff --git a/tensorflow/core/data/service/master.proto b/tensorflow/core/data/service/dispatcher.proto similarity index 94% rename from tensorflow/core/data/service/master.proto rename to tensorflow/core/data/service/dispatcher.proto index 661264cc41b..119fe675f2a 100644 --- a/tensorflow/core/data/service/master.proto +++ b/tensorflow/core/data/service/dispatcher.proto @@ -110,11 +110,11 @@ message GetWorkersResponse { repeated WorkerInfo workers = 1; } -service MasterService { - // Registers a worker with the master. +service DispatcherService { + // Registers a worker with the dispatcher. rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse); - // Updates the master with information about the worker's state. + // Updates the dispatcher with information about the worker's state. rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse); // Registers a dataset with the server, or returns its id if it is already @@ -134,6 +134,6 @@ service MasterService { // Reports a list of all tasks for a job. rpc GetTasks(GetTasksRequest) returns (GetTasksResponse); - // Reports a list of all workers registered with the master. + // Reports a list of all workers registered with the dispatcher. rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse); } diff --git a/tensorflow/core/data/service/master_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc similarity index 87% rename from tensorflow/core/data/service/master_impl.cc rename to tensorflow/core/data/service/dispatcher_impl.cc index 5c7917b4154..22a86570b46 100644 --- a/tensorflow/core/data/service/master_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/data/service/master_impl.h" +#include "tensorflow/core/data/service/dispatcher_impl.h" #include #include @@ -26,8 +26,8 @@ limitations under the License. 
#include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/credentials_factory.h" #include "tensorflow/core/data/service/data_service.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/grpc_util.h" -#include "tensorflow/core/data/service/master.pb.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/kernels/data/dataset_utils.h" @@ -53,10 +53,10 @@ Status CreateWorkerStub(const std::string& address, } } // namespace -DataServiceMasterImpl::DataServiceMasterImpl(const std::string protocol) +DataServiceDispatcherImpl::DataServiceDispatcherImpl(const std::string protocol) : protocol_(protocol) {} -Status DataServiceMasterImpl::RegisterWorker( +Status DataServiceDispatcherImpl::RegisterWorker( const RegisterWorkerRequest* request, RegisterWorkerResponse* response) { VLOG(3) << "Received register worker request"; mutex_lock l(mu_); @@ -86,8 +86,8 @@ Status DataServiceMasterImpl::RegisterWorker( return Status::OK(); } -Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request, - WorkerUpdateResponse* response) { +Status DataServiceDispatcherImpl::WorkerUpdate( + const WorkerUpdateRequest* request, WorkerUpdateResponse* response) { mutex_lock l(mu_); int64 worker_id = request->worker_id(); for (auto& update : request->updates()) { @@ -106,7 +106,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request, return Status::OK(); } -Status DataServiceMasterImpl::GetOrRegisterDataset( +Status DataServiceDispatcherImpl::GetOrRegisterDataset( const GetOrRegisterDatasetRequest* request, GetOrRegisterDatasetResponse* response) { uint64 fingerprint; @@ -128,8 +128,8 @@ Status DataServiceMasterImpl::GetOrRegisterDataset( return Status::OK(); } -int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint, - const DatasetDef& dataset) +int64 DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint, + const DatasetDef& dataset) EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64 dataset_id = next_dataset_id_++; auto new_dataset = @@ -142,8 +142,8 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint, return dataset_id; } -Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request, - CreateJobResponse* response) { +Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request, + CreateJobResponse* response) { VLOG(3) << "Received create job request for dataset id " << request->dataset_id(); ProcessingMode processing_mode = ProcessingMode(request->processing_mode()); @@ -157,7 +157,7 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request, return Status::OK(); } -Status DataServiceMasterImpl::GetOrCreateJob( +Status DataServiceDispatcherImpl::GetOrCreateJob( const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) { VLOG(3) << "Received get or create job request for dataset id " << request->dataset_id() << " with name " << request->job_name() @@ -193,7 +193,7 @@ Status DataServiceMasterImpl::GetOrCreateJob( } // Validates that the job matches the given processing_mode and dataset_id. 
-Status DataServiceMasterImpl::ValidateMatchingJob(
+Status DataServiceDispatcherImpl::ValidateMatchingJob(
     const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
   DCHECK(job.name().has_value());
   std::string job_name = job.name().value();
@@ -214,10 +214,10 @@ Status DataServiceMasterImpl::ValidateMatchingJob(
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
-                                        ProcessingMode processing_mode,
-                                        absl::optional<std::string> job_name,
-                                        int64* out_job_id) LOCKS_EXCLUDED(mu_) {
+Status DataServiceDispatcherImpl::CreateJob(
+    int64 dataset_id, ProcessingMode processing_mode,
+    absl::optional<std::string> job_name, int64* out_job_id)
+    LOCKS_EXCLUDED(mu_) {
   switch (processing_mode) {
     case ProcessingMode::PARALLEL_EPOCHS:
       break;
@@ -274,14 +274,16 @@ Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTask(
+const DataServiceDispatcherImpl::Task& DataServiceDispatcherImpl::CreateTask(
     Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
   mutex_lock l(mu_);
   return CreateTaskLocked(job, worker_address);
 }
 
-const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
-    Job* job, const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+const DataServiceDispatcherImpl::Task&
+DataServiceDispatcherImpl::CreateTaskLocked(Job* job,
+                                            const std::string& worker_address)
+    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   int64 task_id = next_task_id_++;
   DCHECK(!tasks_.contains(task_id));
   tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
@@ -290,7 +292,7 @@ const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
   return tasks_.at(task_id);
 }
 
-Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
+Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
   if (!worker->stub()) {
     std::unique_ptr<WorkerService::Stub> stub;
     TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
@@ -299,8 +301,8 @@ Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
-                                                   Worker* worker)
+Status DataServiceDispatcherImpl::AllocateTaskToWorker(const Task& task,
+                                                       Worker* worker)
     LOCKS_EXCLUDED(mu_) {
   TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
   grpc::ClientContext client_ctx;
@@ -322,8 +324,8 @@ Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
-                                       GetTasksResponse* response) {
+Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
+                                           GetTasksResponse* response) {
   mutex_lock l(mu_);
   VLOG(3) << "Looking up tasks for job id " << request->job_id();
   auto it = jobs_.find(request->job_id());
@@ -346,8 +348,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request,
-                                         GetWorkersResponse* response) {
+Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
+                                             GetWorkersResponse* response) {
   mutex_lock l(mu_);
   VLOG(3) << "Enter GetWorkers";
   for (auto& worker : workers_) {
diff --git a/tensorflow/core/data/service/master_impl.h b/tensorflow/core/data/service/dispatcher_impl.h
similarity index 94%
rename from tensorflow/core/data/service/master_impl.h
rename to tensorflow/core/data/service/dispatcher_impl.h
index 67df2613118..84770f7056f 100644
--- a/tensorflow/core/data/service/master_impl.h
+++ b/tensorflow/core/data/service/dispatcher_impl.h
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
-#define TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/data_service.h"
-#include "tensorflow/core/data/service/master.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -40,11 +40,11 @@ namespace data {
 //   ProcessingModeDef which determines what data it produces.
 // * Task: A job is broken into multiple tasks, which each represent
 //   iterating over all of or part of the dataset. Workers process tasks.
-class DataServiceMasterImpl {
+class DataServiceDispatcherImpl {
  public:
-  explicit DataServiceMasterImpl(const std::string protocol);
+  explicit DataServiceDispatcherImpl(const std::string protocol);
 
-  // See master.proto for API documentation.
+  // See dispatcher.proto for API documentation.
 
   /// Worker-facing API.
   Status RegisterWorker(const RegisterWorkerRequest* request,
@@ -191,7 +191,7 @@ class DataServiceMasterImpl {
   // Creates a new task for a job, returning a reference to the task.
   const Task& CreateTask(Job* job, const std::string& worker_address)
       LOCKS_EXCLUDED(mu_);
-  // Same as `CreateTask`, but expects that the master lock is already held.
+  // Same as `CreateTask`, but expects that the dispatcher lock is already held.
   const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Validates that an existing job matches the given processing_mode and
@@ -225,10 +225,10 @@ class DataServiceMasterImpl {
   absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
       TF_GUARDED_BY(mu_);
 
-  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
+  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
 };
 
 }  // namespace data
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
+#endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
diff --git a/tensorflow/core/data/service/grpc_master_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
similarity index 73%
rename from tensorflow/core/data/service/grpc_master_impl.cc
rename to tensorflow/core/data/service/grpc_dispatcher_impl.cc
index 20ad58a0115..38ecc7057be 100644
--- a/tensorflow/core/data/service/grpc_master_impl.cc
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/data/service/grpc_master_impl.h"
+#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
 
 #include "grpcpp/server_context.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -25,18 +25,18 @@ using ::grpc::ServerBuilder;
 using ::grpc::ServerContext;
 using ::grpc::Status;
 
-GrpcMasterImpl::GrpcMasterImpl(ServerBuilder* server_builder,
-                               const std::string& protocol)
+GrpcDispatcherImpl::GrpcDispatcherImpl(ServerBuilder* server_builder,
+                                       const std::string& protocol)
     : impl_(protocol) {
   server_builder->RegisterService(this);
-  VLOG(1) << "Registered data service master";
+  VLOG(1) << "Registered data service dispatcher";
 }
 
-#define HANDLER(method)                                   \
-  Status GrpcMasterImpl::method(ServerContext* context,   \
-                                const method##Request* request, \
-                                method##Response* response) {    \
-    return ToGrpcStatus(impl_.method(request, response)); \
+#define HANDLER(method)                                       \
+  Status GrpcDispatcherImpl::method(ServerContext* context,   \
+                                    const method##Request* request, \
+                                    method##Response* response) {   \
+    return ToGrpcStatus(impl_.method(request, response));     \
   }
 HANDLER(RegisterWorker);
 HANDLER(WorkerUpdate);
diff --git a/tensorflow/core/data/service/grpc_master_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h
similarity index 68%
rename from tensorflow/core/data/service/grpc_master_impl.h
rename to tensorflow/core/data/service/grpc_dispatcher_impl.h
index d29bb6759f0..f407bd64127 100644
--- a/tensorflow/core/data/service/grpc_master_impl.h
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
-#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
 
 #include "grpcpp/server_builder.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master_impl.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher_impl.h"
 
 namespace tensorflow {
 namespace data {
@@ -29,14 +29,14 @@ namespace data {
 //
 // ::grpc::ServerBuilder builder;
 // // configure builder
-// GrpcMasterImpl data_service(&builder);
+// GrpcDispatcherImpl data_service(&builder);
 // builder.BuildAndStart()
 //
-class GrpcMasterImpl : public MasterService::Service {
+class GrpcDispatcherImpl : public DispatcherService::Service {
  public:
-  explicit GrpcMasterImpl(grpc::ServerBuilder* server_builder,
-                          const std::string& protocol);
-  ~GrpcMasterImpl() override {}
+  explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
+                              const std::string& protocol);
+  ~GrpcDispatcherImpl() override {}
 
 #define HANDLER(method)                               \
   grpc::Status method(grpc::ServerContext* context,   \
@@ -52,12 +52,12 @@ class GrpcMasterImpl : public MasterService::Service {
 #undef HANDLER
 
  private:
-  DataServiceMasterImpl impl_;
+  DataServiceDispatcherImpl impl_;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterImpl);
+  TF_DISALLOW_COPY_AND_ASSIGN(GrpcDispatcherImpl);
 };
 
 }  // namespace data
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
+#endif  // TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc
index 7884fa063ba..0cddfce4e0b 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.cc
+++ b/tensorflow/core/data/service/grpc_worker_impl.cc
@@ -26,9 +26,9 @@ using ::grpc::ServerContext;
 using ::grpc::Status;
 
 GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
-                               const std::string& master_address,
+                               const std::string& dispatcher_address,
                                const std::string& protocol)
-    : impl_(master_address, protocol) {
+    : impl_(dispatcher_address, protocol) {
   server_builder->RegisterService(this);
   VLOG(1) << "Registered data service worker";
 }
diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h
index b7ece2a7738..169ae29ea37 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.h
+++ b/tensorflow/core/data/service/grpc_worker_impl.h
@@ -35,7 +35,7 @@ namespace data {
 class GrpcWorkerImpl : public WorkerService::Service {
  public:
   explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
-                          const std::string& master_address,
+                          const std::string& dispatcher_address,
                           const std::string& protocol);
   ~GrpcWorkerImpl() override {}
diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc
index 33c2232f4dc..4f34bf9d0c7 100644
--- a/tensorflow/core/data/service/server_lib.cc
+++ b/tensorflow/core/data/service/server_lib.cc
@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/server_lib.h"
 
 #include "tensorflow/core/data/service/credentials_factory.h"
-#include "tensorflow/core/data/service/grpc_master_impl.h"
+#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/grpc_worker_impl.h"
 #include "tensorflow/core/platform/errors.h"
@@ -72,18 +72,18 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
 
 int GrpcDataServerBase::BoundPort() { return bound_port(); }
 
-MasterGrpcDataServer::MasterGrpcDataServer(int port,
-                                           const std::string& protocol)
+DispatchGrpcDataServer::DispatchGrpcDataServer(int port,
+                                               const std::string& protocol)
     : GrpcDataServerBase(port, protocol) {}
 
-MasterGrpcDataServer::~MasterGrpcDataServer() { delete service_; }
+DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
 
-void MasterGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
-  auto service = absl::make_unique<GrpcMasterImpl>(builder, protocol_);
+void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
+  auto service = absl::make_unique<GrpcDispatcherImpl>(builder, protocol_);
   service_ = service.release();
 }
 
-Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
+Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
   GetWorkersRequest req;
   GetWorkersResponse resp;
   grpc::ServerContext ctx;
@@ -95,19 +95,18 @@ Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
   return Status::OK();
 }
 
-WorkerGrpcDataServer::WorkerGrpcDataServer(int port,
-                                           const std::string& protocol,
-                                           const std::string& master_address,
-                                           const std::string& worker_address)
+WorkerGrpcDataServer::WorkerGrpcDataServer(
+    int port, const std::string& protocol,
+    const std::string& dispatcher_address, const std::string& worker_address)
     : GrpcDataServerBase(port, protocol),
-      master_address_(master_address),
+      dispatcher_address_(dispatcher_address),
       worker_address_(worker_address) {}
 
 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
 
 void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
-  auto service =
-      absl::make_unique<GrpcWorkerImpl>(builder, master_address_, protocol_);
+  auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
+                                                   protocol_);
   service_ = service.release();
 }
 
@@ -123,25 +122,25 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
   return Status::OK();
 }
 
-Status NewMasterServer(int port, const std::string& protocol,
-                       std::unique_ptr<MasterGrpcDataServer>* out_server) {
-  *out_server = absl::make_unique<MasterGrpcDataServer>(port, protocol);
+Status NewDispatchServer(int port, const std::string& protocol,
+                         std::unique_ptr<DispatchGrpcDataServer>* out_server) {
+  *out_server = absl::make_unique<DispatchGrpcDataServer>(port, protocol);
   return Status::OK();
 }
 
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server) {
-  return NewWorkerServer(port, protocol, master_address, /*worker_address=*/"",
-                         out_server);
+  return NewWorkerServer(port, protocol, dispatcher_address,
+                         /*worker_address=*/"", out_server);
 }
 
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server) {
   *out_server = absl::make_unique<WorkerGrpcDataServer>(
-      port, protocol, master_address, worker_address);
+      port, protocol, dispatcher_address, worker_address);
   return Status::OK();
 }
diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h
index 72bec665c8e..2190c7a56fe 100644
--- a/tensorflow/core/data/service/server_lib.h
+++ b/tensorflow/core/data/service/server_lib.h
@@ -25,7 +25,7 @@ namespace data {
 
 // Forward declared because transitively depending on .grpc.pb.h files causes
 // issues in the pywrap build.
-class GrpcMasterImpl;
+class GrpcDispatcherImpl;
 class GrpcWorkerImpl;
 
 // A grpc server for the tf.data service.
@@ -35,7 +35,7 @@ class GrpcDataServerBase {
   // server will find an available port in `Start()`. The chosen port can be
   // found in the output of `Target()`.
   //
-  // master_address is only needed for worker data servers.
+  // dispatcher_address is only needed for worker data servers.
   GrpcDataServerBase(int requested_port, const std::string& protocol);
   virtual ~GrpcDataServerBase() {}
 
@@ -70,12 +70,12 @@ class GrpcDataServerBase {
   std::unique_ptr<grpc::Server> server_;
 };
 
-class MasterGrpcDataServer : public GrpcDataServerBase {
+class DispatchGrpcDataServer : public GrpcDataServerBase {
  public:
-  MasterGrpcDataServer(int requested_port, const std::string& protocol);
-  ~MasterGrpcDataServer() override;
+  DispatchGrpcDataServer(int requested_port, const std::string& protocol);
+  ~DispatchGrpcDataServer() override;
 
-  // Returns the number of workers registerd with the master.
+  // Returns the number of workers registered with the dispatcher.
   Status NumWorkers(int* num_workers);
 
  protected:
@@ -83,14 +83,14 @@ class MasterGrpcDataServer : public GrpcDataServerBase {
   Status StartServiceInternal() override { return Status::OK(); }
 
  private:
-  // Owned. We use a raw pointer because GrpcMasterImpl is forward-declared.
-  GrpcMasterImpl* service_;
+  // Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
+  GrpcDispatcherImpl* service_;
 };
 
 class WorkerGrpcDataServer : public GrpcDataServerBase {
  public:
   WorkerGrpcDataServer(int requested_port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address);
   ~WorkerGrpcDataServer() override;
 
@@ -99,15 +99,15 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
   Status StartServiceInternal() override;
 
  private:
-  const std::string master_address_;
+  const std::string dispatcher_address_;
   const std::string worker_address_;
   // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
   GrpcWorkerImpl* service_;
 };
 
-// Creates a master tf.data server and stores it in `*out_server`.
-Status NewMasterServer(int port, const std::string& protocol,
-                       std::unique_ptr<MasterGrpcDataServer>* out_server);
+// Creates a dispatch tf.data server and stores it in `*out_server`.
+Status NewDispatchServer(int port, const std::string& protocol,
+                         std::unique_ptr<DispatchGrpcDataServer>* out_server);
 
 // Creates a worker tf.data server and stores it in `*out_server`.
 //
@@ -115,18 +115,18 @@ Status NewMasterServer(int port, const std::string& protocol,
 // will be chosen in Start(). This value can be queried with BoundPort().
 //
 // The worker_address argument is optional. If left empty, it will default to
-// "localhost:%port%". When the worker registers with the master, the worker
-// will report the worker address, so that the master can tell clients where to
-// read from. The address may contain the placeholder "%port%", which will be
+// "localhost:%port%". When the worker registers with the dispatcher, the worker
+// will report the worker address, so that the dispatcher can tell clients where
+// to read from. The address may contain the placeholder "%port%", which will be
 // replaced with the value of BoundPort().
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server);
 
 // Creates a worker using the default worker_address.
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server);
 
 }  // namespace data
diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc
index ded3ebb91b5..4066a75a374 100644
--- a/tensorflow/core/data/service/test_cluster.cc
+++ b/tensorflow/core/data/service/test_cluster.cc
@@ -45,9 +45,9 @@ Status TestCluster::Initialize() {
         "Test cluster has already been initialized.");
   }
   initialized_ = true;
-  TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_));
-  TF_RETURN_IF_ERROR(master_->Start());
-  master_address_ = absl::StrCat("localhost:", master_->BoundPort());
+  TF_RETURN_IF_ERROR(NewDispatchServer(/*port=*/0, kProtocol, &dispatcher_));
+  TF_RETURN_IF_ERROR(dispatcher_->Start());
+  dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
   workers_.reserve(num_workers_);
   worker_addresses_.reserve(num_workers_);
   for (int i = 0; i < num_workers_; ++i) {
@@ -59,14 +59,14 @@ Status TestCluster::Initialize() {
 
 Status TestCluster::AddWorker() {
   std::unique_ptr<WorkerGrpcDataServer> worker;
   TF_RETURN_IF_ERROR(
-      NewWorkerServer(/*port=*/0, kProtocol, master_address_, &worker));
+      NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
   TF_RETURN_IF_ERROR(worker->Start());
   worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
   workers_.push_back(std::move(worker));
   return Status::OK();
 }
 
-std::string TestCluster::MasterAddress() { return master_address_; }
+std::string TestCluster::DispatcherAddress() { return dispatcher_address_; }
 
 std::string TestCluster::WorkerAddress(int index) {
   DCHECK_GE(index, 0);
diff --git a/tensorflow/core/data/service/test_cluster.h b/tensorflow/core/data/service/test_cluster.h
index c4b05ad0543..c5ca3db4c74 100644
--- a/tensorflow/core/data/service/test_cluster.h
+++ b/tensorflow/core/data/service/test_cluster.h
@@ -24,7 +24,7 @@ namespace data {
 // Helper class for unit testing a tf.data service cluster.
 class TestCluster {
  public:
-  // Creates a new test cluster with a master and `num_workers` workers.
+  // Creates a new test cluster with a dispatcher and `num_workers` workers.
   explicit TestCluster(int num_workers);
 
   // Initializes the test cluster. This must be called before interacting with
@@ -32,8 +32,8 @@ class TestCluster {
   Status Initialize();
   // Adds a new worker to the cluster.
   Status AddWorker();
-  // Returns the master address in the form "hostname:port".
-  std::string MasterAddress();
+  // Returns the dispatcher address in the form "hostname:port".
+  std::string DispatcherAddress();
   // Returns the address of the worker at the specified index, in the form
   // "hostname:port". The index must be non-negative and less than the number of
   // workers in the cluster.
@@ -42,8 +42,8 @@ class TestCluster {
 private:
   bool initialized_ = false;
   int num_workers_;
-  std::unique_ptr<MasterGrpcDataServer> master_;
-  std::string master_address_;
+  std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
+  std::string dispatcher_address_;
   std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
   std::vector<std::string> worker_addresses_;
 };
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index 151410bb219..00659e1d048 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/data/dataset.pb.h" #include "tensorflow/core/data/service/credentials_factory.h" +#include "tensorflow/core/data/service/dispatcher.grpc.pb.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/grpc_util.h" -#include "tensorflow/core/data/service/master.grpc.pb.h" -#include "tensorflow/core/data/service/master.pb.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -45,9 +45,9 @@ auto* tf_data_service_created = "has been created."); } // namespace -DataServiceWorkerImpl::DataServiceWorkerImpl(const std::string& master_address, - const std::string& protocol) - : master_address_(master_address), protocol_(protocol) { +DataServiceWorkerImpl::DataServiceWorkerImpl( + const std::string& dispatcher_address, const std::string& protocol) + : dispatcher_address_(dispatcher_address), protocol_(protocol) { tf_data_service_created->GetCell()->Set(true); } @@ -67,14 +67,13 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) { heartbeat_thread_.reset(thread); Status s = Register(); while (!s.ok()) { - LOG(WARNING) << "Failed to register with master at " << master_address_ - << ": " << s; + LOG(WARNING) << "Failed to register with dispatcher at " + << dispatcher_address_ << ": " << s; Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros); s = Register(); } } - Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request, ProcessTaskResponse* response) { mutex_lock l(mu_); @@ -169,29 +168,29 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, return Status::OK(); } -Status DataServiceWorkerImpl::EnsureMasterStubInitialized() +Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!master_stub_) { + if (!dispatcher_stub_) { ::grpc::ChannelArguments args; std::shared_ptr<::grpc::ChannelCredentials> credentials; TF_RETURN_IF_ERROR( CredentialsFactory::CreateClientCredentials(protocol_, &credentials)); auto channel = - ::grpc::CreateCustomChannel(master_address_, credentials, args); - master_stub_ = MasterService::NewStub(channel); + ::grpc::CreateCustomChannel(dispatcher_address_, credentials, args); + dispatcher_stub_ = DispatcherService::NewStub(channel); } return Status::OK(); } Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - VLOG(3) << "Registering with master at " << master_address_; - TF_RETURN_IF_ERROR(EnsureMasterStubInitialized()); + VLOG(3) << "Registering with dispatcher at " << dispatcher_address_; + TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized()); RegisterWorkerRequest req; req.set_worker_address(worker_address_); RegisterWorkerResponse resp; grpc::ClientContext ctx; - grpc::Status s = master_stub_->RegisterWorker(&ctx, req, &resp); + grpc::Status s = dispatcher_stub_->RegisterWorker(&ctx, req, &resp); if (!s.ok()) { return grpc_util::WrapError("Failed to register worker", s); } @@ -205,8 +204,8 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) { Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) { VLOG(3) << "Sending " << pending_completed_tasks_.size() - << " task updates to master"; - TF_RETURN_IF_ERROR(EnsureMasterStubInitialized()); + << " task updates to dispatcher"; + TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized()); WorkerUpdateRequest req; req.set_worker_id(worker_id_); for (int 
@@ -217,7 +216,7 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 
   WorkerUpdateResponse resp;
   grpc::ClientContext ctx;
-  grpc::Status s = master_stub_->WorkerUpdate(&ctx, req, &resp);
+  grpc::Status s = dispatcher_stub_->WorkerUpdate(&ctx, req, &resp);
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to send task updates", s);
   }
@@ -238,7 +237,7 @@ void DataServiceWorkerImpl::HeartbeatThread() {
     }
     Status s = SendTaskUpdate();
     if (!s.ok()) {
-      LOG(WARNING) << "Failed to send task updates to master: " << s;
+      LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
     }
   }
 }
diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h
index 8c5fc2ea51c..adb3e97bbea 100644
--- a/tensorflow/core/data/service/worker_impl.h
+++ b/tensorflow/core/data/service/worker_impl.h
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.pb.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -29,17 +29,17 @@ namespace data {
 // A TensorFlow DataService serves dataset elements over RPC.
 class DataServiceWorkerImpl {
  public:
-  explicit DataServiceWorkerImpl(const std::string& master_address,
+  explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
                                  const std::string& protocol);
   ~DataServiceWorkerImpl();
 
   // Starts the worker. The worker needs to know its own address so that it can
-  // register with the master.
+  // register with the dispatcher.
   void Start(const std::string& worker_address);
 
   // See worker.proto for API documentation.
 
-  /// Master-facing API.
+  /// Dispatcher-facing API.
   Status ProcessTask(const ProcessTaskRequest* request,
                      ProcessTaskResponse* response);
 
@@ -48,15 +48,15 @@ class DataServiceWorkerImpl {
                     GetElementResponse* response);
 
  private:
-  // Sets master_stub_ if it isn't already set.
-  Status EnsureMasterStubInitialized();
-  // Registers the worker with the master.
+  // Sets dispatcher_stub_ if it isn't already set.
+  Status EnsureDispatcherStubInitialized();
+  // Registers the worker with the dispatcher.
   Status Register();
-  // Sends task status to the master.
+  // Sends task status to the dispatcher.
   Status SendTaskUpdate();
   // Creates an iterator to process a task.
   Status ProcessTaskInternal(const TaskDef& task);
-  // A thread for updating the master with worker status.
+  // A thread for updating the dispatcher with worker status.
   void HeartbeatThread();
 
   typedef struct Task {
@@ -67,18 +67,19 @@ class DataServiceWorkerImpl {
     std::unique_ptr<standalone::Iterator> iterator;
   } Task;
 
-  const std::string master_address_;
-  // Protocol for communicating with the master.
+  const std::string dispatcher_address_;
+  // Protocol for communicating with the dispatcher.
   const std::string protocol_;
   // The worker's own address.
   std::string worker_address_;
 
   mutex mu_;
   int64 worker_id_ TF_GUARDED_BY(mu_);
-  std::unique_ptr<MasterService::Stub> master_stub_ TF_GUARDED_BY(mu_);
+  std::unique_ptr<DispatcherService::Stub> dispatcher_stub_ TF_GUARDED_BY(mu_);
   // Information about tasks, keyed by task ids.
   absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
-  // List of completed tasks which haven't yet been communicated to the master.
+  // List of completed tasks which haven't yet been communicated to the
+  // dispatcher.
   std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
   // Condition variable for notifying the heartbeat thread.
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index ee8f72bc663..0c2c5254590 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -69,7 +69,7 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000;  // 1 second.
 // Dataset for reading data from the tf.data service non-deterministically.
 //
 // This dataset interleaves dataset elements produced by multiple tf.data
-// workers. We periodically query the tf.data master to determine which workers
+// workers. We periodically query the dispatcher to determine which workers
 // to read from (in case workers are added or removed).
 class DataServiceDatasetOp::Dataset : public DatasetBase {
  public:
@@ -199,12 +199,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       VLOG(3) << "Connecting to " << dataset()->address_
               << " in data service dataset op";
-      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
+      DataServiceDispatcherClient dispatcher(dataset()->address_,
+                                             dataset()->protocol_);
       if (dataset()->job_name_.empty()) {
-        TF_RETURN_IF_ERROR(master.CreateJob(
+        TF_RETURN_IF_ERROR(dispatcher.CreateJob(
             dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
       } else {
-        TF_RETURN_IF_ERROR(master.GetOrCreateJob(
+        TF_RETURN_IF_ERROR(dispatcher.GetOrCreateJob(
            dataset()->dataset_id_, dataset()->processing_mode_,
            dataset()->job_name_, iterator_index_, &job_id_));
       }
@@ -283,11 +284,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
 
     // Periodically refresh the task list.
     // Maintain one thread fetching elements for each task.
-    // TODO(aaudibert): Instead of polling, have master send updates when
+    // TODO(aaudibert): Instead of polling, have dispatcher send updates when
     // the list of tasks changes.
     void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
       VLOG(3) << "Starting task thread manager";
-      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
+      DataServiceDispatcherClient dispatcher(dataset()->address_,
+                                             dataset()->protocol_);
       uint64 next_check = Env::Default()->NowMicros();
       while (true) {
         {
@@ -305,18 +307,19 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
             return;
           }
         }
-        UpdateTasks(&master);
+        UpdateTasks(&dispatcher);
         UpdateWorkerThreads(ctx.get());
         next_check = Env::Default()->NowMicros() +
                      dataset()->task_refresh_interval_ms_ * 1000;
       }
     }
 
-    void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) {
+    void UpdateTasks(DataServiceDispatcherClient* dispatcher)
+        LOCKS_EXCLUDED(mu_) {
       VLOG(3) << "Updating tasks";
       std::vector<TaskInfo> tasks;
       bool job_finished;
-      Status s = master->GetTasks(job_id_, &tasks, &job_finished);
+      Status s = dispatcher->GetTasks(job_id_, &tasks, &job_finished);
       if (!s.ok()) {
         LOG(WARNING) << "Failed to get task info for job id " << job_id_
                      << ": " << s;
diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.cc b/tensorflow/core/kernels/data/experimental/data_service_ops.cc
index c6a54baad64..d9ef42d4afa 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_ops.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_ops.cc
@@ -53,7 +53,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(
       ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
 
-  DataServiceMasterClient client(address, protocol);
+  DataServiceDispatcherClient client(address, protocol);
   int64 dataset_id;
   OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));
 
diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.h b/tensorflow/core/kernels/data/experimental/data_service_ops.h
index b7d66938ae6..b3d6233aa52 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_ops.h
+++ b/tensorflow/core/kernels/data/experimental/data_service_ops.h
@@ -25,7 +25,7 @@ namespace data {
 // Registers a dataset with the tf.data service.
 //
-// The address and protocol inputs are used to connect to the tf.data master.
+// The address and protocol inputs are used to connect to the dispatcher.
 // The external state policy attribute determines whether to ignore, warn, or
 // error out when the dataset contains external state.
 // The op produces a dataset id for identifying the registered dataset.
diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py
index 01ec155a89c..49f923eba47 100644
--- a/tensorflow/python/data/experimental/ops/data_service_ops.py
+++ b/tensorflow/python/data/experimental/ops/data_service_ops.py
@@ -77,7 +77,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
         amount of memory used, since `distribute` won't use more than
         `element_size` * `max_outstanding_requests` of memory.
       task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
-        the master for task changes.
+        the dispatcher for task changes.
     """
 
     if job_name is None:
@@ -173,7 +173,7 @@ def _distribute(processing_mode,
     of memory used, since `distribute` won't use more than `element_size` *
    `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
-      master for task changes.
+      dispatcher for task changes.
 
   Returns:
     Dataset: A `Dataset` of the elements produced by the data service.
diff --git a/tensorflow/python/data/experimental/service/__init__.py b/tensorflow/python/data/experimental/service/__init__.py
index aecc07965bb..7887e53600a 100644
--- a/tensorflow/python/data/experimental/service/__init__.py
+++ b/tensorflow/python/data/experimental/service/__init__.py
@@ -19,5 +19,5 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.data.experimental.ops.data_service_ops import distribute
-from tensorflow.python.data.experimental.service.server_lib import MasterServer
+from tensorflow.python.data.experimental.service.server_lib import DispatchServer
 from tensorflow.python.data.experimental.service.server_lib import WorkerServer
diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py
index f249af671a6..5a7ce73b4c7 100644
--- a/tensorflow/python/data/experimental/service/server_lib.py
+++ b/tensorflow/python/data/experimental/service/server_lib.py
@@ -24,35 +24,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
 from tensorflow.python.util.tf_export import tf_export
 
 
-@tf_export("data.experimental.service.MasterServer", v1=[])
-class MasterServer(object):
-  """An in-process tf.data service master server.
+@tf_export("data.experimental.service.DispatchServer", v1=[])
+class DispatchServer(object):
+  """An in-process tf.data service dispatch server.
 
-  A `tf.data.experimental.service.MasterServer` coordinates a cluster of
+  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
   `tf.data.experimental.service.WorkerServer`s. When the workers start, they
-  register themselves with the master.
+  register themselves with the dispatcher.
 
-  >>> master = tf.data.experimental.service.MasterServer(port=0)
-  >>> master_address = master.target.split("://")[1]
+  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, master_address=master_address)
+  ...     port=0, dispatcher_address=dispatcher_address)
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
-  ...     processing_mode="parallel_epochs", service=master.target))
+  ...     processing_mode="parallel_epochs", service=dispatcher.target))
   >>> print(list(dataset.as_numpy_iterator()))
   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
 
-  When starting a dedicated tf.data master process, use join() to block
+  When starting a dedicated tf.data dispatch process, use join() to block
   indefinitely after starting up the server.
 
   ```
-  master = tf.data.experimental.service.MasterServer(port=5050)
-  master.join()
+  dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
+  dispatcher.join()
   ```
   """
 
   def __init__(self, port, protocol=None, start=True):
-    """Creates a new master server.
+    """Creates a new dispatch server.
 
     Args:
       port: Specifies the port to bind to.
@@ -68,15 +68,16 @@ class MasterServer(object):
     if protocol is None:
       protocol = "grpc"
 
     self._protocol = protocol
-    self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(port, protocol)
+    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(port, protocol)
     if start:
      self._server.start()
 
   def start(self):
     """Starts this server.
 
-    >>> master = tf.data.experimental.service.MasterServer(port=0, start=False)
-    >>> master.start()
+    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
+    ...                                                          start=False)
+    >>> dispatcher.start()
 
     Raises:
       tf.errors.OpError: Or one of its subclasses if an error occurs while
@@ -87,11 +88,11 @@ class MasterServer(object):
   def join(self):
     """Blocks until the server has shut down.
 
-    This is useful when starting a dedicated master process.
+    This is useful when starting a dedicated dispatch process.
 
     ```
-    master = tf.data.experimental.service.MasterServer(port=5050)
-    master.join()
+    dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
+    dispatcher.join()
     ```
 
     Raises:
@@ -104,10 +105,10 @@ class MasterServer(object):
   def target(self):
     """Returns a target that can be used to connect to the server.
 
-    >>> master = tf.data.experimental.service.MasterServer(port=0)
+    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
     >>> dataset = tf.data.Dataset.range(10)
     >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
-    ...     processing_mode="parallel_epochs", service=master.target))
+    ...     processing_mode="parallel_epochs", service=dispatcher.target))
 
     The returned string will be in the form protocol://address, e.g.
     "grpc://localhost:5050".
@@ -136,7 +137,7 @@ class MasterServer(object):
     return "localhost:{0}".format(self._server.bound_port())
 
   def _num_workers(self):
-    """Returns the number of workers registered with the master."""
+    """Returns the number of workers registered with the dispatcher."""
     return self._server.num_workers()
 
 
@@ -147,15 +148,15 @@ class WorkerServer(object):
   A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
   processing for user-defined datasets, and provides the resulting elements
   over RPC. A worker is associated with a single
-  `tf.data.experimental.service.MasterServer`.
+  `tf.data.experimental.service.DispatchServer`.
 
-  >>> master = tf.data.experimental.service.MasterServer(port=0)
-  >>> master_address = master.target.split("://")[1]
+  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, master_address=master_address)
+  ...     port=0, dispatcher_address=dispatcher_address)
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
-  ...     processing_mode="parallel_epochs", service=master.target))
+  ...     processing_mode="parallel_epochs", service=dispatcher.target))
   >>> print(list(dataset.as_numpy_iterator()))
   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
 
@@ -164,14 +165,14 @@ class WorkerServer(object):
 
   ```
   worker = tf.data.experimental.service.WorkerServer(
-      port=5051, master_address="grpc://localhost:5050")
+      port=5051, dispatcher_address="grpc://localhost:5050")
   worker.join()
   ```
   """
 
   def __init__(self,
                port,
-               master_address,
+               dispatcher_address,
                worker_address=None,
                protocol=None,
                start=True):
@@ -180,11 +181,12 @@ class WorkerServer(object):
     Args:
       port: Specifies the port to bind to. A value of 0 indicates that the
        worker can bind to any available port.
-      master_address: Specifies the address of the master server.
+      dispatcher_address: Specifies the address of the dispatcher.
       worker_address: (Optional.) Specifies the address of the worker server.
-        This address is passed to the master server so that the master can tell
-        clients how to connect to this worker. Defaults to `"localhost:%port%"`,
-        where `%port%` will be replaced with the port used by the worker.
+        This address is passed to the dispatcher so that the dispatcher can
+        tell clients how to connect to this worker. Defaults to
diff --git a/tensorflow/python/data/experimental/service/server_lib_test.py b/tensorflow/python/data/experimental/service/server_lib_test.py
index 74eb11dc59c..f7354e64a3a 100644
--- a/tensorflow/python/data/experimental/service/server_lib_test.py
+++ b/tensorflow/python/data/experimental/service/server_lib_test.py
@@ -25,68 +25,68 @@ from tensorflow.python.platform import test
 
 class ServerLibTest(test.TestCase):
 
-  def testStartMaster(self):
-    master = server_lib.MasterServer(0, start=False)
-    master.start()
+  def testStartDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0, start=False)
+    dispatcher.start()
 
-  def testMultipleStartMaster(self):
-    master = server_lib.MasterServer(0, start=True)
-    master.start()
+  def testMultipleStartDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0, start=True)
+    dispatcher.start()
 
   def testStartWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address, start=False)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
     worker.start()
 
   def testMultipleStartWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address, start=True)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
     worker.start()
 
-  def testStopMaster(self):
-    master = server_lib.MasterServer(0)
-    master._stop()
-    master._stop()
+  def testStopDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0)
+    dispatcher._stop()
+    dispatcher._stop()
 
   def testStopWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address)
     worker._stop()
     worker._stop()
 
-  def testStopStartMaster(self):
-    master = server_lib.MasterServer(0)
-    master._stop()
+  def testStopStartDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0)
+    dispatcher._stop()
     with self.assertRaisesRegex(
         RuntimeError, "Server cannot be started after it has been stopped"):
-      master.start()
+      dispatcher.start()
 
   def testStopStartWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address)
     worker._stop()
     with self.assertRaisesRegex(
         RuntimeError, "Server cannot be started after it has been stopped"):
       worker.start()
 
-  def testJoinMaster(self):
-    master = server_lib.MasterServer(0)
-    master._stop()
-    master.join()
+  def testJoinDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0)
+    dispatcher._stop()
+    dispatcher.join()
 
   def testJoinWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address)
     worker._stop()
     worker.join()
 
-  def testMasterNumWorkers(self):
-    master = server_lib.MasterServer(0)
-    self.assertEqual(0, master._num_workers())
-    worker1 = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
-    self.assertEqual(1, master._num_workers())
-    worker2 = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
-    self.assertEqual(2, master._num_workers())
+  def testDispatcherNumWorkers(self):
+    dispatcher = server_lib.DispatchServer(0)
+    self.assertEqual(0, dispatcher._num_workers())
+    worker1 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
+    self.assertEqual(1, dispatcher._num_workers())
+    worker2 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
+    self.assertEqual(2, dispatcher._num_workers())
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
index 03453a56c7f..e288179dd36 100644
--- a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
+++ b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
@@ -28,13 +28,14 @@ limitations under the License.
 namespace py = pybind11;
 
 PYBIND11_MODULE(_pywrap_server_lib, m) {
-  py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer")
-      .def("start", &tensorflow::data::MasterGrpcDataServer::Start)
-      .def("stop", &tensorflow::data::MasterGrpcDataServer::Stop)
-      .def("join", &tensorflow::data::MasterGrpcDataServer::Join)
-      .def("bound_port", &tensorflow::data::MasterGrpcDataServer::BoundPort)
+  py::class_<tensorflow::data::DispatchGrpcDataServer>(m,
+                                                       "DispatchGrpcDataServer")
+      .def("start", &tensorflow::data::DispatchGrpcDataServer::Start)
+      .def("stop", &tensorflow::data::DispatchGrpcDataServer::Stop)
+      .def("join", &tensorflow::data::DispatchGrpcDataServer::Join)
+      .def("bound_port", &tensorflow::data::DispatchGrpcDataServer::BoundPort)
       .def("num_workers",
-           [](tensorflow::data::MasterGrpcDataServer* server) -> int {
+           [](tensorflow::data::DispatchGrpcDataServer* server) -> int {
             int num_workers;
             tensorflow::Status status = server->NumWorkers(&num_workers);
             tensorflow::MaybeRaiseFromStatus(status);
@@ -48,12 +49,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
       .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
 
   m.def(
-      "TF_DATA_NewMasterServer",
+      "TF_DATA_NewDispatchServer",
       [](int port, std::string protocol)
-          -> std::unique_ptr<tensorflow::data::MasterGrpcDataServer> {
-        std::unique_ptr<tensorflow::data::MasterGrpcDataServer> server;
+          -> std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> {
+        std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
         tensorflow::Status status =
-            tensorflow::data::NewMasterServer(port, protocol, &server);
+            tensorflow::data::NewDispatchServer(port, protocol, &server);
         tensorflow::MaybeRaiseFromStatus(status);
         return server;
       },
@@ -61,12 +62,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
 
   m.def(
       "TF_DATA_NewWorkerServer",
-      [](int port, std::string protocol, std::string master_address,
+      [](int port, std::string protocol, std::string dispatcher_address,
          std::string worker_address)
           -> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
         std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
         tensorflow::Status status = tensorflow::data::NewWorkerServer(
-            port, protocol, master_address, worker_address, &server);
+            port, protocol, dispatcher_address, worker_address, &server);
         tensorflow::MaybeRaiseFromStatus(status);
         return server;
       },
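The wrapper above is the private binding surface that `server_lib.py` calls into. A minimal sketch of how the renamed entry points line up, assuming the binding layout shown in this diff (user code should go through the public `DispatchServer`/`WorkerServer` classes instead):

```python
# Sketch only: private API, shown to illustrate the renamed bindings.
from tensorflow.python.data.experimental.service import _pywrap_server_lib

# TF_DATA_NewDispatchServer(port, protocol) -> DispatchGrpcDataServer
dispatch = _pywrap_server_lib.TF_DATA_NewDispatchServer(0, "grpc")
dispatch.start()
address = "localhost:{}".format(dispatch.bound_port())

# TF_DATA_NewWorkerServer(port, protocol, dispatcher_address, worker_address);
# per the WorkerServer docstring, %port% is replaced with the bound port.
worker = _pywrap_server_lib.TF_DATA_NewWorkerServer(
    0, "grpc", address, "localhost:%port%")
worker.start()

# num_workers() reports registrations; it may still be 0 immediately after
# start(), which is why the tests below poll it.
print(dispatch.num_workers())
```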
diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
index 488bf97f184..98db4fb0d4b 100644
--- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
@@ -59,23 +59,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       num_workers: The number of workers in the cluster.
 
     Returns:
-      The address of the master.
+      The address of the dispatcher.
     """
-    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
+    self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
     self._servers = []
     for _ in range(num_workers):
       self._servers.append(
           server_lib.WorkerServer(
-              port=0, master_address=self._master._address, protocol=PROTOCOL))
+              port=0,
+              dispatcher_address=self._dispatcher._address,
+              protocol=PROTOCOL))
 
-    return self._master._address
+    return self._dispatcher._address
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeBasic(self):
     num_elements = 10
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     results = [elem.numpy() for elem in ds]
     self.assertEqual(list(range(num_elements)), results)
@@ -83,10 +85,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testDifferentShuffleOrders(self):
     random_seed.set_random_seed(None)
     num_elements = 100
-    master_address = self.create_cluster(2)
+    dispatcher_address = self.create_cluster(2)
     ds = dataset_ops.Dataset.range(num_elements)
     ds = ds.shuffle(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     output = [elem.numpy() for elem in ds]
 
     # The output will be two sequences of range(num_elements)
@@ -104,9 +106,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testMultipleEpochs(self):
     num_elements = 3
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     for _ in range(10):
       self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
@@ -114,9 +116,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testRepeatedDataset(self):
     num_elements = 10
     num_repetitions = 5
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     ds = ds.repeat(num_repetitions)
     self.assertDatasetProduces(
         ds, expected_output=num_repetitions * list(range(num_elements)))
@@ -125,12 +127,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testConcurrentEpoch(self):
     num_elements = 10
     num_datasets = 3
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     iterators = []
     results = []
     for _ in range(num_datasets):
       ds = dataset_ops.Dataset.range(num_elements)
-      ds = _make_distributed_dataset(ds, master_address)
+      ds = _make_distributed_dataset(ds, dispatcher_address)
       iterators.append(iter(ds))
       results.append([])
@@ -146,9 +148,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.skipTest("Not yet implemented")
     num_elements = 10
     num_iterators = 3
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     result = []
     iterators = []
     for _ in range(num_iterators):
@@ -170,20 +172,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testMultiWorker(self):
     num_workers = 3
     num_elements = 10
-    master_address = self.create_cluster(num_workers)
+    dispatcher_address = self.create_cluster(num_workers)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     results = [elem.numpy() for elem in ds]
     self.assertCountEqual(num_workers * list(range(num_elements)), results)
 
   @combinations.generate(test_base.eager_only_combinations())
   def testAddWorkerMidJob(self):
-    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
+    self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
     self._worker = server_lib.WorkerServer(
-        port=0, master_address=self._master._address, protocol=PROTOCOL)
+        port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, self._master._address)
+    ds = _make_distributed_dataset(ds, self._dispatcher._address)
     iterator = iter(ds)
     results = []
     # Read halfway through the dataset.
@@ -191,10 +193,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       results.append(next(iterator).numpy())
 
     self._new_worker = server_lib.WorkerServer(
-        port=0, master_address=self._master._address, protocol=PROTOCOL)
+        port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
 
-    # Wait for the new worker to register with the master.
-    while self._master._num_workers() < 2:
+    # Wait for the new worker to register with the dispatcher.
+    while self._dispatcher._num_workers() < 2:
       time.sleep(10 / 1000)  # 10ms
 
     for elem in iterator:
@@ -206,12 +208,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       combinations.times(test_base.eager_only_combinations(),
                          combinations.combine(use_same_port=[True, False])))
   def testRestartWorker(self, use_same_port):
-    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
+    self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
     self._worker = server_lib.WorkerServer(
-        port=0, master_address=self._master._address, protocol=PROTOCOL)
+        port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, self._master._address)
+    ds = _make_distributed_dataset(ds, self._dispatcher._address)
     iterator = iter(ds)
     # Read halfway through the dataset.
     midpoint = num_elements // 2
@@ -224,7 +226,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       port = int(self._worker._address.split(":")[1])
     self._worker._stop()
     self._new_worker = server_lib.WorkerServer(
-        port=port, master_address=self._master._address, protocol=PROTOCOL)
+        port=port,
+        dispatcher_address=self._dispatcher._address,
+        protocol=PROTOCOL)
 
     # There may have been some elements prefetched from the first worker
     # before it was stopped.
@@ -259,12 +263,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testInsideFunction(self):
     num_workers = 3
     num_elements = 10
-    master_address = self.create_cluster(num_workers)
+    dispatcher_address = self.create_cluster(num_workers)
 
     @def_function.function
     def f():
       ds = dataset_ops.Dataset.range(num_elements)
-      ds = _make_distributed_dataset(ds, master_address)
+      ds = _make_distributed_dataset(ds, dispatcher_address)
       result = tensor_array_ops.TensorArray(
           dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
       i = 0
@@ -279,10 +283,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testSharedJobName(self):
     num_elements = 100
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
-    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
+    ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
     iter1 = iter(ds1)
     iter2 = iter(ds2)
     results = []
@@ -298,20 +302,22 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testDifferentJobNames(self):
     num_elements = 10
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1")
-    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2")
+    ds1 = _make_distributed_dataset(
+        ds, dispatcher_address, job_name="job_name1")
+    ds2 = _make_distributed_dataset(
+        ds, dispatcher_address, job_name="job_name2")
     self.assertDatasetProduces(ds1, list(range(num_elements)))
     self.assertDatasetProduces(ds2, list(range(num_elements)))
 
   @combinations.generate(test_base.eager_only_combinations())
   def testSharedJobNameMultiIteration(self):
     num_elements = 10
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
-    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
+    ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
     # iteration 1
     self.assertDatasetProduces(ds1, list(range(num_elements)))
     self.assertDatasetProduces(ds2, [])
@@ -323,11 +329,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testSharedJobNameRepeat(self):
     num_elements = 100
     num_repetitions = 3
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
+    ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
     ds1 = ds1.repeat(num_repetitions)
-    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
     ds2 = ds2.repeat(num_repetitions)
     results = []
     iter1 = iter(ds1)
@@ -345,7 +351,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testApplyDeterminismOption(self):
     elements = list(range(10))
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
 
     def dataset_fn(delay_ms):
 
@@ -362,7 +368,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       opts = dataset_ops.Options()
       opts.experimental_deterministic = False
       ds = ds.with_options(opts)
-      ds = _make_distributed_dataset(ds, master_address)
+      ds = _make_distributed_dataset(ds, dispatcher_address)
       return ds
 
     self.checkDeterminism(
@@ -379,8 +385,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       options.experimental_external_state_policy = external_state_policy
       ds = ds.with_options(options)
 
-    master_address = self.create_cluster(3)
-    ds = _make_distributed_dataset(ds, master_address)
+    dispatcher_address = self.create_cluster(3)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     next(iter(ds))
 
   @combinations.generate(
@@ -400,12 +406,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeFromInterleave(self):
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(2)
 
     def interleave_fn(_):
       ds = dataset_ops.Dataset.range(2)
-      _make_distributed_dataset(ds, master_address)
+      _make_distributed_dataset(ds, dispatcher_address)
       return ds
 
     with self.assertRaisesRegex(
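The shared-job tests above are the behavioral core of this API: two datasets created with the same `job_name` split one stream of elements between them, while distinct names each get a full copy. A sketch of that behavior with the public API (the tests themselves use a file-local `_make_distributed_dataset` helper whose definition lies outside this diff):

```python
# Sketch only: mirrors testSharedJobName/testDifferentJobNames via the public
# API, assuming an in-process dispatcher and worker as in the docstrings above.
import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer(port=0)
worker = tf.data.experimental.service.WorkerServer(
    port=0, dispatcher_address=dispatcher.target.split("://")[1])

ds = tf.data.Dataset.range(100)
distribute = tf.data.experimental.service.distribute
ds1 = ds.apply(distribute(processing_mode="parallel_epochs",
                          service=dispatcher.target, job_name="shared_job"))
ds2 = ds.apply(distribute(processing_mode="parallel_epochs",
                          service=dispatcher.target, job_name="shared_job"))

# With a shared job_name, iter(ds1) and iter(ds2) together see each element
# exactly once; with distinct job names, each dataset yields all 100 elements.
```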
"args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " } member_method { name: "join" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.pbtxt index 347dd3c74b1..00f0035e082 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.data.experimental.service" tf_module { member { - name: "MasterServer" + name: "DispatchServer" mtype: "" } member { diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 07f5906aa08..b256015582a 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -99,8 +99,8 @@ tensorflow::data::GrpcDataServerBase::Join tensorflow::data::GrpcDataServerBase::Start tensorflow::data::GrpcDataServerBase::Stop tensorflow::data::GrpcDataServerBase::BoundPort -tensorflow::data::MasterGrpcDataServer::NumWorkers -tensorflow::data::NewMasterServer +tensorflow::data::DispatchGrpcDataServer::NumWorkers +tensorflow::data::NewDispatchServer tensorflow::data::NewWorkerServer [protos_all] # device_lib, dtypes