Update "master" to "dispatch"/"dispatcher" in tf.data service terminology.

Dispatcher is more descriptive and follows the guidance in https://developers.google.com/style/word-list#master

PiperOrigin-RevId: 321613785
Change-Id: Iaa576d35f0581e21278101f8b31201ba737a6865
Andrew Audibert 2020-07-16 11:52:42 -07:00
parent 0aa1d61fad
commit 6b9a9d98bb
30 changed files with 367 additions and 353 deletions

View File

@@ -28,8 +28,8 @@ tf_proto_library(
 )
 tf_proto_library(
-name = "master_proto",
-srcs = ["master.proto"],
+name = "dispatcher_proto",
+srcs = ["dispatcher.proto"],
 has_services = 1,
 cc_api_version = 2,
 protodeps = tf_additional_all_protos() + [
@@ -49,17 +49,17 @@ tf_proto_library(
 )
 cc_library(
-name = "master_impl",
-srcs = ["master_impl.cc"],
+name = "dispatcher_impl",
+srcs = ["dispatcher_impl.cc"],
 hdrs = [
-"master_impl.h",
+"dispatcher_impl.h",
 ],
 deps = [
 ":common_proto_cc",
 ":credentials_factory",
 ":data_service",
+":dispatcher_proto_cc",
 ":grpc_util",
-":master_proto_cc",
 ":worker_cc_grpc_proto",
 ":worker_proto_cc",
 "//tensorflow/c:c_api_internal",
@@ -86,9 +86,9 @@ cc_library(
 deps = [
 ":common_proto_cc",
 ":credentials_factory",
+":dispatcher_cc_grpc_proto",
+":dispatcher_proto_cc",
 ":grpc_util",
-":master_cc_grpc_proto",
-":master_proto_cc",
 ":worker_proto_cc",
 "//tensorflow/c:c_api_internal",
 "//tensorflow/c:tf_status_helper",
@@ -207,12 +207,12 @@ tf_cc_test(
 )
 cc_library(
-name = "grpc_master_impl",
-srcs = ["grpc_master_impl.cc"],
-hdrs = ["grpc_master_impl.h"],
+name = "grpc_dispatcher_impl",
+srcs = ["grpc_dispatcher_impl.cc"],
+hdrs = ["grpc_dispatcher_impl.h"],
 deps = [
-":master_cc_grpc_proto",
-":master_impl",
+":dispatcher_cc_grpc_proto",
+":dispatcher_impl",
 "//tensorflow/core/distributed_runtime/rpc:grpc_util",
 tf_grpc_cc_dependency(),
 ],
@@ -250,7 +250,7 @@ cc_library(
 ],
 deps = [
 ":credentials_factory",
-":grpc_master_impl",
+":grpc_dispatcher_impl",
 ":grpc_util",
 ":grpc_worker_impl",
 "//tensorflow/core:lib",
@@ -268,9 +268,9 @@ cc_library(
 ],
 deps = [
 ":credentials_factory",
+":dispatcher_cc_grpc_proto",
+":dispatcher_proto_cc",
 ":grpc_util",
-":master_cc_grpc_proto",
-":master_proto_cc",
 ":worker_cc_grpc_proto",
 ":worker_proto_cc",
 "//tensorflow/core:framework",
@@ -287,12 +287,12 @@ tf_cc_test(
 tags = ["no_windows"],
 deps = [
 ":data_service",
-":grpc_master_impl",
+":dispatcher_cc_grpc_proto",
+":dispatcher_proto_cc",
+":grpc_dispatcher_impl",
 ":grpc_util",
 ":grpc_worker_impl",
 ":local_credentials_factory",
-":master_cc_grpc_proto",
-":master_proto_cc",
 ":server_lib",
 ":test_cluster",
 ":test_util",
@@ -309,11 +309,11 @@ tf_cc_test(
 )
 cc_grpc_library(
-name = "master_cc_grpc_proto",
-srcs = [":master_proto"],
+name = "dispatcher_cc_grpc_proto",
+srcs = [":dispatcher_proto"],
 generate_mocks = True,
 grpc_only = True,
-deps = [":master_proto_cc"],
+deps = [":dispatcher_proto_cc"],
 )
 cc_grpc_library(

View File

@@ -18,8 +18,8 @@ limitations under the License.
 #include "grpcpp/create_channel.h"
 #include "grpcpp/security/credentials.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
@@ -54,8 +54,8 @@ std::string ProcessingModeToString(ProcessingMode mode) {
 }
 }
-Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
+Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
 int64* dataset_id) {
 TF_RETURN_IF_ERROR(EnsureInitialized());
 GetOrRegisterDatasetRequest req;
 *req.mutable_dataset()->mutable_graph() = dataset;
@@ -69,9 +69,9 @@ Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
 return Status::OK();
 }
-Status DataServiceMasterClient::CreateJob(int64 dataset_id,
+Status DataServiceDispatcherClient::CreateJob(int64 dataset_id,
 ProcessingMode processing_mode,
 int64* job_id) {
 TF_RETURN_IF_ERROR(EnsureInitialized());
 CreateJobRequest req;
 req.set_dataset_id(dataset_id);
@@ -88,11 +88,9 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
 return Status::OK();
 }
-Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
-ProcessingMode processing_mode,
-const std::string& job_name,
-int job_name_index,
-int64* job_id) {
+Status DataServiceDispatcherClient::GetOrCreateJob(
+int64 dataset_id, ProcessingMode processing_mode,
+const std::string& job_name, int job_name_index, int64* job_id) {
 TF_RETURN_IF_ERROR(EnsureInitialized());
 GetOrCreateJobRequest req;
 req.set_dataset_id(dataset_id);
@@ -112,9 +110,9 @@ Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
 return Status::OK();
 }
-Status DataServiceMasterClient::GetTasks(int64 job_id,
+Status DataServiceDispatcherClient::GetTasks(int64 job_id,
 std::vector<TaskInfo>* tasks,
 bool* job_finished) {
 TF_RETURN_IF_ERROR(EnsureInitialized());
 GetTasksRequest req;
 req.set_job_id(job_id);
@@ -132,7 +130,8 @@ Status DataServiceMasterClient::GetTasks(int64 job_id,
 return Status::OK();
 }
-Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
+Status DataServiceDispatcherClient::GetWorkers(
+std::vector<WorkerInfo>* workers) {
 TF_RETURN_IF_ERROR(EnsureInitialized());
 GetWorkersRequest req;
 GetWorkersResponse resp;
@@ -148,12 +147,12 @@ Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
 return Status::OK();
 }
-Status DataServiceMasterClient::EnsureInitialized() {
+Status DataServiceDispatcherClient::EnsureInitialized() {
 std::shared_ptr<grpc::ChannelCredentials> credentials;
 TF_RETURN_IF_ERROR(
 CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
 auto channel = grpc::CreateChannel(address_, credentials);
-stub_ = MasterService::NewStub(channel);
+stub_ = DispatcherService::NewStub(channel);
 return Status::OK();
 }
@@ -187,10 +186,11 @@ Status DataServiceWorkerClient::EnsureInitialized() {
 return Status::OK();
 }
-Status CreateDataServiceMasterClient(
+Status CreateDataServiceDispatcherClient(
 const std::string& address, const std::string& protocol,
-std::unique_ptr<DataServiceMasterClient>* out) {
-auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
+std::unique_ptr<DataServiceDispatcherClient>* out) {
+auto client =
+absl::make_unique<DataServiceDispatcherClient>(address, protocol);
 TF_RETURN_IF_ERROR(client->Initialize());
 *out = std::move(client);
 return Status::OK();

View File

@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
-#include "tensorflow/core/data/service/master.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -67,11 +67,11 @@ class DataServiceClientBase {
 const std::string protocol_;
 };
-// Client for communicating with the tf.data service master.
-class DataServiceMasterClient : public DataServiceClientBase {
+// Client for communicating with the tf.data service dispatcher.
+class DataServiceDispatcherClient : public DataServiceClientBase {
 public:
-DataServiceMasterClient(const std::string& address,
+DataServiceDispatcherClient(const std::string& address,
 const std::string& protocol)
 : DataServiceClientBase(address, protocol) {}
 // Registers a dataset with the tf.data service, and stores the generated
@@ -90,13 +90,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
 const std::string& job_name, int job_name_index,
 int64* job_id);
-// Queries the master for the tasks associated with the specified job.
+// Queries the dispatcher for the tasks associated with the specified job.
 // The tasks will be stored in *tasks, and whether the job is finished will
 // be stored in `*job_finished`.
 Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
 bool* job_finished);
-// Queries the master for its registered workers. The worker info will be
+// Queries the dispatcher for its registered workers. The worker info will be
 // stored in `*workers`.
 Status GetWorkers(std::vector<WorkerInfo>* workers);
@@ -104,7 +104,7 @@ class DataServiceMasterClient : public DataServiceClientBase {
 Status EnsureInitialized() override;
 private:
-std::unique_ptr<MasterService::Stub> stub_;
+std::unique_ptr<DispatcherService::Stub> stub_;
 };
 // Client for communicating with the tf.data service worker.
@@ -127,10 +127,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
 std::unique_ptr<WorkerService::Stub> stub_;
 };
-// Creates and initializes a new tf.data service master client.
-Status CreateDataServiceMasterClient(
+// Creates and initializes a new tf.data service dispatcher client.
+Status CreateDataServiceDispatcherClient(
 const std::string& address, const std::string& protocol,
-std::unique_ptr<DataServiceMasterClient>* out);
+std::unique_ptr<DataServiceDispatcherClient>* out);
 // Creates and initializes a new tf.data service worker client.
 Status CreateDataServiceWorkerClient(
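
For orientation, a minimal sketch of how the renamed client is used from C++. The method and factory signatures come from the header above; the "grpc" protocol string, the PARALLEL_EPOCHS mode, and the wrapping helper function are illustrative assumptions, not part of this change.

#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/lib/core/status.h"

// Sketch only: assumes a dispatcher is already running at `address` and that
// `graph` holds a serialized dataset GraphDef.
tensorflow::Status RegisterAndStartJobSketch(const std::string& address,
                                             const tensorflow::GraphDef& graph) {
  using tensorflow::data::DataServiceDispatcherClient;
  using tensorflow::data::ProcessingMode;
  using tensorflow::data::TaskInfo;

  std::unique_ptr<DataServiceDispatcherClient> dispatcher;
  TF_RETURN_IF_ERROR(tensorflow::data::CreateDataServiceDispatcherClient(
      address, /*protocol=*/"grpc", &dispatcher));

  // Register the dataset, then create an anonymous job for it.
  tensorflow::int64 dataset_id;
  TF_RETURN_IF_ERROR(dispatcher->RegisterDataset(graph, &dataset_id));
  tensorflow::int64 job_id;
  TF_RETURN_IF_ERROR(dispatcher->CreateJob(
      dataset_id, ProcessingMode::PARALLEL_EPOCHS, &job_id));

  // Ask the dispatcher which worker tasks currently serve the job.
  std::vector<TaskInfo> tasks;
  bool job_finished;
  TF_RETURN_IF_ERROR(dispatcher->GetTasks(job_id, &tasks, &job_finished));
  return tensorflow::Status::OK();
}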

View File

@@ -19,9 +19,9 @@ limitations under the License.
 #include "grpcpp/security/credentials.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/core/data/compression_utils.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/server_lib.h"
 #include "tensorflow/core/data/service/test_cluster.h"
 #include "tensorflow/core/data/service/test_util.h"
@@ -66,9 +66,10 @@ TEST(DataService, ProcessingModeToString) {
 TEST(DataService, GetWorkers) {
 TestCluster cluster(1);
 TF_ASSERT_OK(cluster.Initialize());
-DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
+DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(),
+kProtocol);
 std::vector<WorkerInfo> workers;
-TF_EXPECT_OK(master.GetWorkers(&workers));
+TF_EXPECT_OK(dispatcher.GetWorkers(&workers));
 EXPECT_EQ(1, workers.size());
 }

View File

@@ -110,11 +110,11 @@ message GetWorkersResponse {
 repeated WorkerInfo workers = 1;
 }
-service MasterService {
-// Registers a worker with the master.
+service DispatcherService {
+// Registers a worker with the dispatcher.
 rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
-// Updates the master with information about the worker's state.
+// Updates the dispatcher with information about the worker's state.
 rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
 // Registers a dataset with the server, or returns its id if it is already
@@ -134,6 +134,6 @@ service MasterService {
 // Reports a list of all tasks for a job.
 rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
-// Reports a list of all workers registered with the master.
+// Reports a list of all workers registered with the dispatcher.
 rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse);
 }

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/data/service/master_impl.h"
+#include "tensorflow/core/data/service/dispatcher_impl.h"
 #include <memory>
 #include <tuple>
@@ -26,8 +26,8 @@ limitations under the License.
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
 #include "tensorflow/core/data/service/data_service.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
@@ -53,10 +53,10 @@ Status CreateWorkerStub(const std::string& address,
 }
 } // namespace
-DataServiceMasterImpl::DataServiceMasterImpl(const std::string protocol)
+DataServiceDispatcherImpl::DataServiceDispatcherImpl(const std::string protocol)
 : protocol_(protocol) {}
-Status DataServiceMasterImpl::RegisterWorker(
+Status DataServiceDispatcherImpl::RegisterWorker(
 const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
 VLOG(3) << "Received register worker request";
 mutex_lock l(mu_);
@@ -86,8 +86,8 @@ Status DataServiceMasterImpl::RegisterWorker(
 return Status::OK();
 }
-Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
-WorkerUpdateResponse* response) {
+Status DataServiceDispatcherImpl::WorkerUpdate(
+const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
 mutex_lock l(mu_);
 int64 worker_id = request->worker_id();
 for (auto& update : request->updates()) {
@@ -106,7 +106,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
 return Status::OK();
 }
-Status DataServiceMasterImpl::GetOrRegisterDataset(
+Status DataServiceDispatcherImpl::GetOrRegisterDataset(
 const GetOrRegisterDatasetRequest* request,
 GetOrRegisterDatasetResponse* response) {
 uint64 fingerprint;
@@ -128,8 +128,8 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
 return Status::OK();
 }
-int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
+int64 DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
 const DatasetDef& dataset)
 EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 int64 dataset_id = next_dataset_id_++;
 auto new_dataset =
@@ -142,8 +142,8 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
 return dataset_id;
 }
-Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
+Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request,
 CreateJobResponse* response) {
 VLOG(3) << "Received create job request for dataset id "
 << request->dataset_id();
 ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
@@ -157,7 +157,7 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
 return Status::OK();
 }
-Status DataServiceMasterImpl::GetOrCreateJob(
+Status DataServiceDispatcherImpl::GetOrCreateJob(
 const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
 VLOG(3) << "Received get or create job request for dataset id "
 << request->dataset_id() << " with name " << request->job_name()
@@ -193,7 +193,7 @@ Status DataServiceMasterImpl::GetOrCreateJob(
 }
 // Validates that the job matches the given processing_mode and dataset_id.
-Status DataServiceMasterImpl::ValidateMatchingJob(
+Status DataServiceDispatcherImpl::ValidateMatchingJob(
 const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
 DCHECK(job.name().has_value());
 std::string job_name = job.name().value();
@@ -214,10 +214,10 @@ Status DataServiceMasterImpl::ValidateMatchingJob(
 return Status::OK();
 }
-Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
-ProcessingMode processing_mode,
-absl::optional<std::string> job_name,
-int64* out_job_id) LOCKS_EXCLUDED(mu_) {
+Status DataServiceDispatcherImpl::CreateJob(
+int64 dataset_id, ProcessingMode processing_mode,
+absl::optional<std::string> job_name, int64* out_job_id)
+LOCKS_EXCLUDED(mu_) {
 switch (processing_mode) {
 case ProcessingMode::PARALLEL_EPOCHS:
 break;
@@ -274,14 +274,16 @@ Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
 return Status::OK();
 }
-const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTask(
+const DataServiceDispatcherImpl::Task& DataServiceDispatcherImpl::CreateTask(
 Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
 mutex_lock l(mu_);
 return CreateTaskLocked(job, worker_address);
 }
-const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
-Job* job, const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+const DataServiceDispatcherImpl::Task&
+DataServiceDispatcherImpl::CreateTaskLocked(Job* job,
+const std::string& worker_address)
+EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 int64 task_id = next_task_id_++;
 DCHECK(!tasks_.contains(task_id));
 tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
@@ -290,7 +292,7 @@ const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
 return tasks_.at(task_id);
 }
-Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
+Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
 if (!worker->stub()) {
 std::unique_ptr<WorkerService::Stub> stub;
 TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
@@ -299,8 +301,8 @@ Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
 return Status::OK();
 }
-Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
+Status DataServiceDispatcherImpl::AllocateTaskToWorker(const Task& task,
 Worker* worker)
 LOCKS_EXCLUDED(mu_) {
 TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
 grpc::ClientContext client_ctx;
@@ -322,8 +324,8 @@ Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
 return Status::OK();
 }
-Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
+Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
 GetTasksResponse* response) {
 mutex_lock l(mu_);
 VLOG(3) << "Looking up tasks for job id " << request->job_id();
 auto it = jobs_.find(request->job_id());
@@ -346,8 +348,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
 return Status::OK();
 }
-Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request,
+Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
 GetWorkersResponse* response) {
 mutex_lock l(mu_);
 VLOG(3) << "Enter GetWorkers";
 for (auto& worker : workers_) {

View File

@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
-#define TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/data_service.h"
-#include "tensorflow/core/data/service/master.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -40,11 +40,11 @@ namespace data {
 // ProcessingModeDef which determines what data it produces.
 // * Task: A job is broken into multiple tasks, which each represent
 // iterating over all of or part of the dataset. Workers process tasks.
-class DataServiceMasterImpl {
+class DataServiceDispatcherImpl {
 public:
-explicit DataServiceMasterImpl(const std::string protocol);
-// See master.proto for API documentation.
+explicit DataServiceDispatcherImpl(const std::string protocol);
+// See dispatcher.proto for API documentation.
 /// Worker-facing API.
 Status RegisterWorker(const RegisterWorkerRequest* request,
@@ -191,7 +191,7 @@ class DataServiceMasterImpl {
 // Creates a new task for a job, returning a reference to the task.
 const Task& CreateTask(Job* job, const std::string& worker_address)
 LOCKS_EXCLUDED(mu_);
-// Same as `CreateTask`, but expects that the master lock is already held.
+// Same as `CreateTask`, but expects that the dispatcher lock is already held.
 const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
 EXCLUSIVE_LOCKS_REQUIRED(mu_);
 // Validates that an existing job matches the given processing_mode and
@@ -225,10 +225,10 @@ class DataServiceMasterImpl {
 absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
 TF_GUARDED_BY(mu_);
-TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
+TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
 };
 } // namespace data
 } // namespace tensorflow
-#endif // TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
+#endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/data/service/grpc_master_impl.h"
+#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
 #include "grpcpp/server_context.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -25,18 +25,18 @@ using ::grpc::ServerBuilder;
 using ::grpc::ServerContext;
 using ::grpc::Status;
-GrpcMasterImpl::GrpcMasterImpl(ServerBuilder* server_builder,
+GrpcDispatcherImpl::GrpcDispatcherImpl(ServerBuilder* server_builder,
 const std::string& protocol)
 : impl_(protocol) {
 server_builder->RegisterService(this);
-VLOG(1) << "Registered data service master";
+VLOG(1) << "Registered data service dispatcher";
 }
 #define HANDLER(method) \
-Status GrpcMasterImpl::method(ServerContext* context, \
+Status GrpcDispatcherImpl::method(ServerContext* context, \
 const method##Request* request, \
 method##Response* response) { \
 return ToGrpcStatus(impl_.method(request, response)); \
 }
 HANDLER(RegisterWorker);
 HANDLER(WorkerUpdate);

View File

@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
-#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
 #include "grpcpp/server_builder.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master_impl.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher_impl.h"
 namespace tensorflow {
 namespace data {
@@ -29,14 +29,14 @@ namespace data {
 //
 // ::grpc::ServerBuilder builder;
 // // configure builder
-// GrpcMasterImpl data_service(&builder);
+// GrpcDispatcherImpl data_service(&builder);
 // builder.BuildAndStart()
 //
-class GrpcMasterImpl : public MasterService::Service {
+class GrpcDispatcherImpl : public DispatcherService::Service {
 public:
-explicit GrpcMasterImpl(grpc::ServerBuilder* server_builder,
+explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
 const std::string& protocol);
-~GrpcMasterImpl() override {}
+~GrpcDispatcherImpl() override {}
 #define HANDLER(method) \
 grpc::Status method(grpc::ServerContext* context, \
@@ -52,12 +52,12 @@ class GrpcMasterImpl : public MasterService::Service {
 #undef HANDLER
 private:
-DataServiceMasterImpl impl_;
-TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterImpl);
+DataServiceDispatcherImpl impl_;
+TF_DISALLOW_COPY_AND_ASSIGN(GrpcDispatcherImpl);
 };
 } // namespace data
 } // namespace tensorflow
-#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
+#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
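
The usage comment in this header expands to roughly the following sketch, which registers the dispatcher service on a grpc::ServerBuilder before BuildAndStart(). The listening address, credentials, and protocol string are illustrative assumptions; server_lib.h below provides the supported higher-level wrapper.

#include <memory>
#include "grpcpp/security/server_credentials.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"

void RunDispatcherSketch() {
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort("0.0.0.0:5050",  // illustrative address
                           ::grpc::InsecureServerCredentials());
  // The service registers itself with the builder in its constructor.
  tensorflow::data::GrpcDispatcherImpl data_service(&builder,
                                                    /*protocol=*/"grpc");
  std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
  server->Wait();
}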

View File

@@ -26,9 +26,9 @@ using ::grpc::ServerContext;
 using ::grpc::Status;
 GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
-const std::string& master_address,
+const std::string& dispatcher_address,
 const std::string& protocol)
-: impl_(master_address, protocol) {
+: impl_(dispatcher_address, protocol) {
 server_builder->RegisterService(this);
 VLOG(1) << "Registered data service worker";
 }

View File

@@ -35,7 +35,7 @@ namespace data {
 class GrpcWorkerImpl : public WorkerService::Service {
 public:
 explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
-const std::string& master_address,
+const std::string& dispatcher_address,
 const std::string& protocol);
 ~GrpcWorkerImpl() override {}

View File

@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/server_lib.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
-#include "tensorflow/core/data/service/grpc_master_impl.h"
+#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/grpc_worker_impl.h"
 #include "tensorflow/core/platform/errors.h"
@@ -72,18 +72,18 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
 int GrpcDataServerBase::BoundPort() { return bound_port(); }
-MasterGrpcDataServer::MasterGrpcDataServer(int port,
+DispatchGrpcDataServer::DispatchGrpcDataServer(int port,
 const std::string& protocol)
 : GrpcDataServerBase(port, protocol) {}
-MasterGrpcDataServer::~MasterGrpcDataServer() { delete service_; }
+DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
-void MasterGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
-auto service = absl::make_unique<GrpcMasterImpl>(builder, protocol_);
+void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
+auto service = absl::make_unique<GrpcDispatcherImpl>(builder, protocol_);
 service_ = service.release();
 }
-Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
+Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
 GetWorkersRequest req;
 GetWorkersResponse resp;
 grpc::ServerContext ctx;
@@ -95,19 +95,18 @@ Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
 return Status::OK();
 }
-WorkerGrpcDataServer::WorkerGrpcDataServer(int port,
-const std::string& protocol,
-const std::string& master_address,
-const std::string& worker_address)
+WorkerGrpcDataServer::WorkerGrpcDataServer(
+int port, const std::string& protocol,
+const std::string& dispatcher_address, const std::string& worker_address)
 : GrpcDataServerBase(port, protocol),
-master_address_(master_address),
+dispatcher_address_(dispatcher_address),
 worker_address_(worker_address) {}
 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
 void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
-auto service =
-absl::make_unique<GrpcWorkerImpl>(builder, master_address_, protocol_);
+auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
+protocol_);
 service_ = service.release();
 }
@@ -123,25 +122,25 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
 return Status::OK();
 }
-Status NewMasterServer(int port, const std::string& protocol,
-std::unique_ptr<MasterGrpcDataServer>* out_server) {
-*out_server = absl::make_unique<MasterGrpcDataServer>(port, protocol);
+Status NewDispatchServer(int port, const std::string& protocol,
+std::unique_ptr<DispatchGrpcDataServer>* out_server) {
+*out_server = absl::make_unique<DispatchGrpcDataServer>(port, protocol);
 return Status::OK();
 }
 Status NewWorkerServer(int port, const std::string& protocol,
-const std::string& master_address,
+const std::string& dispatcher_address,
 std::unique_ptr<WorkerGrpcDataServer>* out_server) {
-return NewWorkerServer(port, protocol, master_address, /*worker_address=*/"",
-out_server);
+return NewWorkerServer(port, protocol, dispatcher_address,
+/*worker_address=*/"", out_server);
 }
 Status NewWorkerServer(int port, const std::string& protocol,
-const std::string& master_address,
+const std::string& dispatcher_address,
 const std::string& worker_address,
 std::unique_ptr<WorkerGrpcDataServer>* out_server) {
 *out_server = absl::make_unique<WorkerGrpcDataServer>(
-port, protocol, master_address, worker_address);
+port, protocol, dispatcher_address, worker_address);
 return Status::OK();
 }

View File

@@ -25,7 +25,7 @@ namespace data {
 // Forward declared because transitively depending on .grpc.pb.h files causes
 // issues in the pywrap build.
-class GrpcMasterImpl;
+class GrpcDispatcherImpl;
 class GrpcWorkerImpl;
 // A grpc server for the tf.data service.
@@ -35,7 +35,7 @@ class GrpcDataServerBase {
 // server will find an available port in `Start()`. The chosen port can be
 // found in the output of `Target()`.
 //
-// master_address is only needed for worker data servers.
+// dispatcher_address is only needed for worker data servers.
 GrpcDataServerBase(int requested_port, const std::string& protocol);
 virtual ~GrpcDataServerBase() {}
@@ -70,12 +70,12 @@ class GrpcDataServerBase {
 std::unique_ptr<grpc::Server> server_;
 };
-class MasterGrpcDataServer : public GrpcDataServerBase {
+class DispatchGrpcDataServer : public GrpcDataServerBase {
 public:
-MasterGrpcDataServer(int requested_port, const std::string& protocol);
-~MasterGrpcDataServer() override;
-// Returns the number of workers registerd with the master.
+DispatchGrpcDataServer(int requested_port, const std::string& protocol);
+~DispatchGrpcDataServer() override;
+// Returns the number of workers registerd with the dispatcher.
 Status NumWorkers(int* num_workers);
 protected:
@@ -83,14 +83,14 @@ class MasterGrpcDataServer : public GrpcDataServerBase {
 Status StartServiceInternal() override { return Status::OK(); }
 private:
-// Owned. We use a raw pointer because GrpcMasterImpl is forward-declared.
-GrpcMasterImpl* service_;
+// Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
+GrpcDispatcherImpl* service_;
 };
 class WorkerGrpcDataServer : public GrpcDataServerBase {
 public:
 WorkerGrpcDataServer(int requested_port, const std::string& protocol,
-const std::string& master_address,
+const std::string& dispatcher_address,
 const std::string& worker_address);
 ~WorkerGrpcDataServer() override;
@@ -99,15 +99,15 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
 Status StartServiceInternal() override;
 private:
-const std::string master_address_;
+const std::string dispatcher_address_;
 const std::string worker_address_;
 // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
 GrpcWorkerImpl* service_;
 };
-// Creates a master tf.data server and stores it in `*out_server`.
-Status NewMasterServer(int port, const std::string& protocol,
-std::unique_ptr<MasterGrpcDataServer>* out_server);
+// Creates a dispatch tf.data server and stores it in `*out_server`.
+Status NewDispatchServer(int port, const std::string& protocol,
+std::unique_ptr<DispatchGrpcDataServer>* out_server);
 // Creates a worker tf.data server and stores it in `*out_server`.
 //
@@ -115,18 +115,18 @@ Status NewMasterServer(int port, const std::string& protocol,
 // will be chosen in Start(). This value can be queried with BoundPort().
 //
 // The worker_address argument is optional. If left empty, it will default to
-// "localhost:%port%". When the worker registers with the master, the worker
-// will report the worker address, so that the master can tell clients where to
-// read from. The address may contain the placeholder "%port%", which will be
+// "localhost:%port%". When the worker registers with the dispatcher, the worker
+// will report the worker address, so that the dispatcher can tell clients where
+// to read from. The address may contain the placeholder "%port%", which will be
 // replaced with the value of BoundPort().
 Status NewWorkerServer(int port, const std::string& protocol,
-const std::string& master_address,
+const std::string& dispatcher_address,
 const std::string& worker_address,
 std::unique_ptr<WorkerGrpcDataServer>* out_server);
 // Creates a worker using the default worker_address.
 Status NewWorkerServer(int port, const std::string& protocol,
-const std::string& master_address,
+const std::string& dispatcher_address,
 std::unique_ptr<WorkerGrpcDataServer>* out_server);
 } // namespace data
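
A hedged sketch of standing up a dispatch server and a single worker with the renamed factory functions, mirroring what test_cluster.cc below does. The port-0 and "grpc" protocol values and the wrapping helper function are illustrative assumptions.

#include <memory>
#include "absl/strings/str_cat.h"
#include "tensorflow/core/data/service/server_lib.h"

tensorflow::Status StartLocalClusterSketch() {
  using tensorflow::data::DispatchGrpcDataServer;
  using tensorflow::data::WorkerGrpcDataServer;

  std::unique_ptr<DispatchGrpcDataServer> dispatcher;
  TF_RETURN_IF_ERROR(tensorflow::data::NewDispatchServer(
      /*port=*/0, /*protocol=*/"grpc", &dispatcher));
  TF_RETURN_IF_ERROR(dispatcher->Start());
  // Workers are pointed at the dispatcher's address, not the other way around.
  std::string dispatcher_address =
      absl::StrCat("localhost:", dispatcher->BoundPort());

  std::unique_ptr<WorkerGrpcDataServer> worker;
  TF_RETURN_IF_ERROR(tensorflow::data::NewWorkerServer(
      /*port=*/0, /*protocol=*/"grpc", dispatcher_address, &worker));
  TF_RETURN_IF_ERROR(worker->Start());
  return tensorflow::Status::OK();
}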

View File

@@ -45,9 +45,9 @@ Status TestCluster::Initialize() {
 "Test cluster has already been initialized.");
 }
 initialized_ = true;
-TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_));
-TF_RETURN_IF_ERROR(master_->Start());
-master_address_ = absl::StrCat("localhost:", master_->BoundPort());
+TF_RETURN_IF_ERROR(NewDispatchServer(/*port=*/0, kProtocol, &dispatcher_));
+TF_RETURN_IF_ERROR(dispatcher_->Start());
+dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
 workers_.reserve(num_workers_);
 worker_addresses_.reserve(num_workers_);
 for (int i = 0; i < num_workers_; ++i) {
@@ -59,14 +59,14 @@ Status TestCluster::Initialize() {
 Status TestCluster::AddWorker() {
 std::unique_ptr<WorkerGrpcDataServer> worker;
 TF_RETURN_IF_ERROR(
-NewWorkerServer(/*port=*/0, kProtocol, master_address_, &worker));
+NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
 TF_RETURN_IF_ERROR(worker->Start());
 worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
 workers_.push_back(std::move(worker));
 return Status::OK();
 }
-std::string TestCluster::MasterAddress() { return master_address_; }
+std::string TestCluster::DispatcherAddress() { return dispatcher_address_; }
 std::string TestCluster::WorkerAddress(int index) {
 DCHECK_GE(index, 0);

View File

@@ -24,7 +24,7 @@ namespace data {
 // Helper class for unit testing a tf.data service cluster.
 class TestCluster {
 public:
-// Creates a new test cluster with a master and `num_workers` workers.
+// Creates a new test cluster with a dispatcher and `num_workers` workers.
 explicit TestCluster(int num_workers);
 // Initializes the test cluster. This must be called before interacting with
@@ -32,8 +32,8 @@ class TestCluster {
 Status Initialize();
 // Adds a new worker to the cluster.
 Status AddWorker();
-// Returns the master address in the form "hostname:port".
-std::string MasterAddress();
+// Returns the dispatcher address in the form "hostname:port".
+std::string DispatcherAddress();
 // Returns the address of the worker at the specified index, in the form
 // "hostname:port". The index must be non-negative and less than the number of
 // workers in the cluster.
@@ -42,8 +42,8 @@ class TestCluster {
 private:
 bool initialized_ = false;
 int num_workers_;
-std::unique_ptr<MasterGrpcDataServer> master_;
-std::string master_address_;
+std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
+std::string dispatcher_address_;
 std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
 std::vector<std::string> worker_addresses_;
 };
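
A short usage sketch of the renamed test helper, following the GetWorkers test shown earlier. The single-worker setup is taken from that test; the wrapping helper function is illustrative.

#include <string>
#include "tensorflow/core/data/service/test_cluster.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status SmokeTestClusterSketch() {
  tensorflow::data::TestCluster cluster(/*num_workers=*/1);
  TF_RETURN_IF_ERROR(cluster.Initialize());
  // Clients and newly added workers connect to the dispatcher address.
  std::string dispatcher_address = cluster.DispatcherAddress();
  std::string first_worker_address = cluster.WorkerAddress(/*index=*/0);
  LOG(INFO) << "Dispatcher at " << dispatcher_address << ", first worker at "
            << first_worker_address;
  TF_RETURN_IF_ERROR(cluster.AddWorker());
  return tensorflow::Status::OK();
}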

View File

@@ -21,9 +21,9 @@ limitations under the License.
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/data/dataset.pb.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -45,9 +45,9 @@ auto* tf_data_service_created =
 "has been created.");
 } // namespace
-DataServiceWorkerImpl::DataServiceWorkerImpl(const std::string& master_address,
-const std::string& protocol)
-: master_address_(master_address), protocol_(protocol) {
+DataServiceWorkerImpl::DataServiceWorkerImpl(
+const std::string& dispatcher_address, const std::string& protocol)
+: dispatcher_address_(dispatcher_address), protocol_(protocol) {
 tf_data_service_created->GetCell()->Set(true);
 }
@@ -67,14 +67,13 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
 heartbeat_thread_.reset(thread);
 Status s = Register();
 while (!s.ok()) {
-LOG(WARNING) << "Failed to register with master at " << master_address_
-<< ": " << s;
+LOG(WARNING) << "Failed to register with dispatcher at "
+<< dispatcher_address_ << ": " << s;
 Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
 s = Register();
 }
 }
 Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
 ProcessTaskResponse* response) {
 mutex_lock l(mu_);
@@ -169,29 +168,29 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
 return Status::OK();
 }
-Status DataServiceWorkerImpl::EnsureMasterStubInitialized()
+Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
 EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-if (!master_stub_) {
+if (!dispatcher_stub_) {
 ::grpc::ChannelArguments args;
 std::shared_ptr<::grpc::ChannelCredentials> credentials;
 TF_RETURN_IF_ERROR(
 CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
 auto channel =
-::grpc::CreateCustomChannel(master_address_, credentials, args);
-master_stub_ = MasterService::NewStub(channel);
+::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
+dispatcher_stub_ = DispatcherService::NewStub(channel);
 }
 return Status::OK();
 }
 Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-VLOG(3) << "Registering with master at " << master_address_;
-TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
+VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
+TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
 RegisterWorkerRequest req;
 req.set_worker_address(worker_address_);
 RegisterWorkerResponse resp;
 grpc::ClientContext ctx;
-grpc::Status s = master_stub_->RegisterWorker(&ctx, req, &resp);
+grpc::Status s = dispatcher_stub_->RegisterWorker(&ctx, req, &resp);
 if (!s.ok()) {
 return grpc_util::WrapError("Failed to register worker", s);
 }
@@ -205,8 +204,8 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 VLOG(3) << "Sending " << pending_completed_tasks_.size()
-<< " task updates to master";
-TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
+<< " task updates to dispatcher";
+TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
 WorkerUpdateRequest req;
 req.set_worker_id(worker_id_);
 for (int task_id : pending_completed_tasks_) {
@@ -217,7 +216,7 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 WorkerUpdateResponse resp;
 grpc::ClientContext ctx;
-grpc::Status s = master_stub_->WorkerUpdate(&ctx, req, &resp);
+grpc::Status s = dispatcher_stub_->WorkerUpdate(&ctx, req, &resp);
 if (!s.ok()) {
 return grpc_util::WrapError("Failed to send task updates", s);
 }
@@ -238,7 +237,7 @@ void DataServiceWorkerImpl::HeartbeatThread() {
 }
 Status s = SendTaskUpdate();
 if (!s.ok()) {
-LOG(WARNING) << "Failed to send task updates to master: " << s;
+LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
 }
 }
 }

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/master.grpc.pb.h" #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.pb.h" #include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h" #include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -29,17 +29,17 @@ namespace data {
// A TensorFlow DataService serves dataset elements over RPC. // A TensorFlow DataService serves dataset elements over RPC.
class DataServiceWorkerImpl { class DataServiceWorkerImpl {
public: public:
explicit DataServiceWorkerImpl(const std::string& master_address, explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
const std::string& protocol); const std::string& protocol);
~DataServiceWorkerImpl(); ~DataServiceWorkerImpl();
// Starts the worker. The worker needs to know its own address so that it can // Starts the worker. The worker needs to know its own address so that it can
// register with the master. // register with the dispatcher.
void Start(const std::string& worker_address); void Start(const std::string& worker_address);
// See worker.proto for API documentation. // See worker.proto for API documentation.
/// Master-facing API. /// Dispatcher-facing API.
Status ProcessTask(const ProcessTaskRequest* request, Status ProcessTask(const ProcessTaskRequest* request,
ProcessTaskResponse* response); ProcessTaskResponse* response);
@ -48,15 +48,15 @@ class DataServiceWorkerImpl {
GetElementResponse* response); GetElementResponse* response);
private: private:
// Sets master_stub_ if it isn't already set. // Sets dispatcher_stub_ if it isn't already set.
Status EnsureMasterStubInitialized(); Status EnsureDispatcherStubInitialized();
// Registers the worker with the master. // Registers the worker with the dispatcher.
Status Register(); Status Register();
// Sends task status to the master. // Sends task status to the dispatcher.
Status SendTaskUpdate(); Status SendTaskUpdate();
// Creates an iterator to process a task. // Creates an iterator to process a task.
Status ProcessTaskInternal(const TaskDef& task); Status ProcessTaskInternal(const TaskDef& task);
// A thread for updating the master with worker status. // A thread for updating the dispatcher with worker status.
void HeartbeatThread(); void HeartbeatThread();
typedef struct Task { typedef struct Task {
@ -67,18 +67,19 @@ class DataServiceWorkerImpl {
std::unique_ptr<standalone::Iterator> iterator; std::unique_ptr<standalone::Iterator> iterator;
} Task; } Task;
const std::string master_address_; const std::string dispatcher_address_;
// Protocol for communicating with the master. // Protocol for communicating with the dispatcher.
const std::string protocol_; const std::string protocol_;
// The worker's own address. // The worker's own address.
std::string worker_address_; std::string worker_address_;
mutex mu_; mutex mu_;
int64 worker_id_ TF_GUARDED_BY(mu_); int64 worker_id_ TF_GUARDED_BY(mu_);
std::unique_ptr<MasterService::Stub> master_stub_ TF_GUARDED_BY(mu_); std::unique_ptr<DispatcherService::Stub> dispatcher_stub_ TF_GUARDED_BY(mu_);
// Information about tasks, keyed by task ids. // Information about tasks, keyed by task ids.
absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_); absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
// List of completed tasks which haven't yet been communicated to the master. // List of completed tasks which haven't yet been communicated to the
// dispatcher.
std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_); std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false; bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Condition variable for notifying the heartbeat thread. // Condition variable for notifying the heartbeat thread.
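To make the renamed members above easier to follow, here is a heavily simplified, hypothetical Python sketch of the worker lifecycle they describe: register with the dispatcher on Start(), accept tasks from the dispatcher via ProcessTask(), and push completed-task updates to the dispatcher from a heartbeat thread. Every name in the sketch is illustrative; it is not the real C++ implementation or RPC surface.

```
import threading
import time


class WorkerSketch:
    """Toy model of the worker lifecycle described above, not the real implementation."""

    def __init__(self, dispatcher_address, protocol="grpc"):
        self.dispatcher_address = dispatcher_address
        self.protocol = protocol
        self.tasks = {}                     # analogous to tasks_
        self.pending_completed_tasks = []   # analogous to pending_completed_tasks_
        self.cancelled = False
        self.lock = threading.Lock()

    def start(self, worker_address):
        # Start(): register with the dispatcher, then launch the heartbeat thread.
        self.worker_id = self.register(worker_address)
        threading.Thread(target=self.heartbeat_thread, daemon=True).start()

    def process_task(self, task_id, dataset):
        # ProcessTask(): the dispatcher assigns a task; build an iterator for it.
        with self.lock:
            self.tasks[task_id] = iter(dataset)

    def heartbeat_thread(self):
        # HeartbeatThread(): periodically report completed tasks to the dispatcher.
        while not self.cancelled:
            with self.lock:
                done, self.pending_completed_tasks = self.pending_completed_tasks, []
            if done:
                self.send_task_update(done)
            time.sleep(1.0)

    def register(self, worker_address):
        ...  # RPC to the dispatcher; elided in this sketch

    def send_task_update(self, completed_task_ids):
        ...  # RPC to the dispatcher; elided in this sketch
```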

View File

@ -69,7 +69,7 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
// Dataset for reading data from the tf.data service non-deterministically. // Dataset for reading data from the tf.data service non-deterministically.
// //
// This dataset interleaves dataset elements produced by multiple tf.data // This dataset interleaves dataset elements produced by multiple tf.data
// workers. We periodically query the tf.data master to determine which workers // workers. We periodically query the dispatcher to determine which workers
// to read from (in case workers are added or removed). // to read from (in case workers are added or removed).
class DataServiceDatasetOp::Dataset : public DatasetBase { class DataServiceDatasetOp::Dataset : public DatasetBase {
public: public:
@ -199,12 +199,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
Status Initialize(IteratorContext* ctx) override { Status Initialize(IteratorContext* ctx) override {
VLOG(3) << "Connecting to " << dataset()->address_ VLOG(3) << "Connecting to " << dataset()->address_
<< " in data service dataset op"; << " in data service dataset op";
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_); DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
if (dataset()->job_name_.empty()) { if (dataset()->job_name_.empty()) {
TF_RETURN_IF_ERROR(master.CreateJob( TF_RETURN_IF_ERROR(dispatcher.CreateJob(
dataset()->dataset_id_, dataset()->processing_mode_, &job_id_)); dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
} else { } else {
TF_RETURN_IF_ERROR(master.GetOrCreateJob( TF_RETURN_IF_ERROR(dispatcher.GetOrCreateJob(
dataset()->dataset_id_, dataset()->processing_mode_, dataset()->dataset_id_, dataset()->processing_mode_,
dataset()->job_name_, iterator_index_, &job_id_)); dataset()->job_name_, iterator_index_, &job_id_));
} }
@ -283,11 +284,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// Periodically refresh the task list. // Periodically refresh the task list.
// Maintain one thread fetching elements for each task. // Maintain one thread fetching elements for each task.
// TODO(aaudibert): Instead of polling, have master send updates when // TODO(aaudibert): Instead of polling, have dispatcher send updates when
// the list of tasks changes. // the list of tasks changes.
void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) { void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
VLOG(3) << "Starting task thread manager"; VLOG(3) << "Starting task thread manager";
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_); DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
uint64 next_check = Env::Default()->NowMicros(); uint64 next_check = Env::Default()->NowMicros();
while (true) { while (true) {
{ {
@ -305,18 +307,19 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
return; return;
} }
} }
UpdateTasks(&master); UpdateTasks(&dispatcher);
UpdateWorkerThreads(ctx.get()); UpdateWorkerThreads(ctx.get());
next_check = Env::Default()->NowMicros() + next_check = Env::Default()->NowMicros() +
dataset()->task_refresh_interval_ms_ * 1000; dataset()->task_refresh_interval_ms_ * 1000;
} }
} }
void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) { void UpdateTasks(DataServiceDispatcherClient* dispatcher)
LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Updating tasks"; VLOG(3) << "Updating tasks";
std::vector<TaskInfo> tasks; std::vector<TaskInfo> tasks;
bool job_finished; bool job_finished;
Status s = master->GetTasks(job_id_, &tasks, &job_finished); Status s = dispatcher->GetTasks(job_id_, &tasks, &job_finished);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": " LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
<< s; << s;
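As a reading aid, here is a minimal Python sketch of the polling loop that TaskThreadManager/UpdateTasks above implement, assuming hypothetical `get_tasks` and `on_tasks_changed` callables in place of the real DispatcherService RPC and per-task reader threads:

```
import time


def task_thread_manager(get_tasks, on_tasks_changed, job_id,
                        task_refresh_interval_ms=1000,
                        should_stop=lambda: False):
    """Poll the dispatcher for the job's task list until the job finishes."""
    next_check = time.monotonic()
    while not should_stop():
        now = time.monotonic()
        if now >= next_check:
            tasks, job_finished = get_tasks(job_id)  # stands in for the GetTasks RPC
            on_tasks_changed(tasks)                  # add/remove per-task readers
            if job_finished:
                return
            next_check = now + task_refresh_interval_ms / 1000.0
        time.sleep(0.01)
```

Because the client only discovers workers through this periodic query, newly added or restarted workers are picked up mid-job on the next refresh rather than immediately.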

View File

@ -53,7 +53,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def)); ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
DataServiceMasterClient client(address, protocol); DataServiceDispatcherClient client(address, protocol);
int64 dataset_id; int64 dataset_id;
OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id)); OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));

View File

@ -25,7 +25,7 @@ namespace data {
// Registers a dataset with the tf.data service. // Registers a dataset with the tf.data service.
// //
// The address and protocol inputs are used to connect to the tf.data master. // The address and protocol inputs are used to connect to the dispatcher.
// The external state policy attribute determines whether to ignore, warn, or // The external state policy attribute determines whether to ignore, warn, or
// error out when the dataset contains external state. // error out when the dataset contains external state.
// The op produces a dataset id for identifying the registered dataset. // The op produces a dataset id for identifying the registered dataset.
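A hedged outline of how dataset registration fits into the client flow touched by this change: register the serialized dataset graph with the dispatcher, then create an anonymous job or join a named one. `dispatcher_client` and its methods are illustrative stand-ins for the C++ DataServiceDispatcherClient, not a public Python API.

```
def register_and_create_job(dispatcher_client, dataset_graph_def,
                            processing_mode, job_name=None):
    # 1. Register the serialized dataset graph; the dispatcher returns a
    #    dataset id that later jobs can refer to.
    dataset_id = dispatcher_client.register_dataset(dataset_graph_def)
    # 2. Create an anonymous job, or get/create a named job that multiple
    #    consumers can share.
    if job_name is None:
        job_id = dispatcher_client.create_job(dataset_id, processing_mode)
    else:
        job_id = dispatcher_client.get_or_create_job(dataset_id, processing_mode,
                                                     job_name)
    return dataset_id, job_id
```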

View File

@ -77,7 +77,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
amount of memory used, since `distribute` won't use more than amount of memory used, since `distribute` won't use more than
`element_size` * `max_outstanding_requests` of memory. `element_size` * `max_outstanding_requests` of memory.
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
the master for task changes. the dispatcher for task changes.
""" """
if job_name is None: if job_name is None:
@ -173,7 +173,7 @@ def _distribute(processing_mode,
of memory used, since `distribute` won't use more than `element_size` * of memory used, since `distribute` won't use more than `element_size` *
`max_outstanding_requests` of memory. `max_outstanding_requests` of memory.
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
master for task changes. dispatcher for task changes.
Returns: Returns:
Dataset: A `Dataset` of the elements produced by the data service. Dataset: A `Dataset` of the elements produced by the data service.
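For context, a minimal end-user sketch of the `distribute` transformation documented above; the service address is a placeholder, and the optional `job_name` argument (exercised by the tests later in this change) is assumed to be available in this version:

```
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs",
        service="grpc://dispatcher-host:5050",  # placeholder dispatcher address
        job_name="shared_job"))                 # consumers with the same name share a job

for element in dataset:
    ...  # elements are produced by the tf.data workers, not this process
```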

View File

@ -19,5 +19,5 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.data.experimental.ops.data_service_ops import distribute from tensorflow.python.data.experimental.ops.data_service_ops import distribute
from tensorflow.python.data.experimental.service.server_lib import MasterServer from tensorflow.python.data.experimental.service.server_lib import DispatchServer
from tensorflow.python.data.experimental.service.server_lib import WorkerServer from tensorflow.python.data.experimental.service.server_lib import WorkerServer

View File

@ -24,35 +24,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.service.MasterServer", v1=[]) @tf_export("data.experimental.service.DispatchServer", v1=[])
class MasterServer(object): class DispatchServer(object):
"""An in-process tf.data service master server. """An in-process tf.data service dispatch server.
A `tf.data.experimental.service.MasterServer` coordinates a cluster of A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
`tf.data.experimental.service.WorkerServer`s. When the workers start, they `tf.data.experimental.service.WorkerServer`s. When the workers start, they
register themselves with the master. register themselves with the dispatcher.
>>> master = tf.data.experimental.service.MasterServer(port=0) >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> master_address = master.target.split("://")[1] >>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer( >>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address) ... port=0, dispatcher_address=dispatcher_address)
>>> dataset = tf.data.Dataset.range(10) >>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute( >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target)) ... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator())) >>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
When starting a dedicated tf.data master process, use join() to block When starting a dedicated tf.data dispatch process, use join() to block
indefinitely after starting up the server. indefinitely after starting up the server.
``` ```
master = tf.data.experimental.service.MasterServer(port=5050) dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
master.join() dispatcher.join()
``` ```
""" """
def __init__(self, port, protocol=None, start=True): def __init__(self, port, protocol=None, start=True):
"""Creates a new master server. """Creates a new dispatch server.
Args: Args:
port: Specifies the port to bind to. port: Specifies the port to bind to.
@ -68,15 +68,16 @@ class MasterServer(object):
if protocol is None: if protocol is None:
protocol = "grpc" protocol = "grpc"
self._protocol = protocol self._protocol = protocol
self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(port, protocol) self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(port, protocol)
if start: if start:
self._server.start() self._server.start()
def start(self): def start(self):
"""Starts this server. """Starts this server.
>>> master = tf.data.experimental.service.MasterServer(port=0, start=False) >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
>>> master.start() ... start=False)
>>> dispatcher.start()
Raises: Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while tf.errors.OpError: Or one of its subclasses if an error occurs while
@ -87,11 +88,11 @@ class MasterServer(object):
def join(self): def join(self):
"""Blocks until the server has shut down. """Blocks until the server has shut down.
This is useful when starting a dedicated master process. This is useful when starting a dedicated dispatch process.
``` ```
master = tf.data.experimental.service.MasterServer(port=5050) dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
master.join() dispatcher.join()
``` ```
Raises: Raises:
@ -104,10 +105,10 @@ class MasterServer(object):
def target(self): def target(self):
"""Returns a target that can be used to connect to the server. """Returns a target that can be used to connect to the server.
>>> master = tf.data.experimental.service.MasterServer(port=0) >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dataset = tf.data.Dataset.range(10) >>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute( >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target)) ... processing_mode="parallel_epochs", service=dispatcher.target))
The returned string will be in the form protocol://address, e.g. The returned string will be in the form protocol://address, e.g.
"grpc://localhost:5050". "grpc://localhost:5050".
@ -136,7 +137,7 @@ class MasterServer(object):
return "localhost:{0}".format(self._server.bound_port()) return "localhost:{0}".format(self._server.bound_port())
def _num_workers(self): def _num_workers(self):
"""Returns the number of workers registered with the master.""" """Returns the number of workers registered with the dispatcher."""
return self._server.num_workers() return self._server.num_workers()
@ -147,15 +148,15 @@ class WorkerServer(object):
A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset` A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
processing for user-defined datasets, and provides the resulting elements over processing for user-defined datasets, and provides the resulting elements over
RPC. A worker is associated with a single RPC. A worker is associated with a single
`tf.data.experimental.service.MasterServer`. `tf.data.experimental.service.DispatchServer`.
>>> master = tf.data.experimental.service.MasterServer(port=0) >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> master_address = master.target.split("://")[1] >>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer( >>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address) ... port=0, dispatcher_address=dispatcher_address)
>>> dataset = tf.data.Dataset.range(10) >>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute( >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target)) ... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator())) >>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
@ -164,14 +165,14 @@ class WorkerServer(object):
``` ```
worker = tf.data.experimental.service.WorkerServer( worker = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050") port=5051, dispatcher_address="grpc://localhost:5050")
worker.join() worker.join()
``` ```
""" """
def __init__(self, def __init__(self,
port, port,
master_address, dispatcher_address,
worker_address=None, worker_address=None,
protocol=None, protocol=None,
start=True): start=True):
@ -180,11 +181,12 @@ class WorkerServer(object):
Args: Args:
port: Specifies the port to bind to. A value of 0 indicates that the port: Specifies the port to bind to. A value of 0 indicates that the
worker can bind to any available port. worker can bind to any available port.
master_address: Specifies the address of the master server. dispatcher_address: Specifies the address of the dispatcher.
worker_address: (Optional.) Specifies the address of the worker server. worker_address: (Optional.) Specifies the address of the worker server.
This address is passed to the master server so that the master can tell This address is passed to the dispatcher so that the dispatcher can
clients how to connect to this worker. Defaults to `"localhost:%port%"`, tell clients how to connect to this worker. Defaults to
where `%port%` will be replaced with the port used by the worker. `"localhost:%port%"`, where `%port%` will be replaced with the port used
by the worker.
protocol: (Optional.) Specifies the protocol to be used by the server. protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`. Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
start: (Optional.) Boolean, indicating whether to start the server after start: (Optional.) Boolean, indicating whether to start the server after
@ -201,7 +203,7 @@ class WorkerServer(object):
self._protocol = protocol self._protocol = protocol
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
port, protocol, master_address, worker_address) port, protocol, dispatcher_address, worker_address)
if start: if start:
self._server.start() self._server.start()
@ -221,7 +223,7 @@ class WorkerServer(object):
``` ```
worker_server = tf.data.experimental.service.WorkerServer( worker_server = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050") port=5051, dispatcher_address="grpc://localhost:5050")
worker_server.join() worker_server.join()
``` ```
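To illustrate the `worker_address` argument described above, a hedged two-process sketch: the dispatcher runs in one process, and each worker advertises an address that remote clients can actually reach. Hostnames and ports are placeholders.

```
import tensorflow as tf

# Process 1, on the dispatcher host (shown as comments so this snippet stays runnable):
#   dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
#   dispatcher.join()

# Process 2, on a worker host: dispatcher_address points at process 1, and
# worker_address is what the dispatcher hands to clients for reaching this worker.
worker = tf.data.experimental.service.WorkerServer(
    port=5051,
    dispatcher_address="dispatcher-host:5050",  # placeholder
    worker_address="worker-host-1:5051")        # placeholder; defaults to localhost:%port%
worker.join()
```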

View File

@ -25,68 +25,68 @@ from tensorflow.python.platform import test
class ServerLibTest(test.TestCase): class ServerLibTest(test.TestCase):
def testStartMaster(self): def testStartDispatcher(self):
master = server_lib.MasterServer(0, start=False) dispatcher = server_lib.DispatchServer(0, start=False)
master.start() dispatcher.start()
def testMultipleStartMaster(self): def testMultipleStartDispatcher(self):
master = server_lib.MasterServer(0, start=True) dispatcher = server_lib.DispatchServer(0, start=True)
master.start() dispatcher.start()
def testStartWorker(self): def testStartWorker(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, master._address, start=False) worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
worker.start() worker.start()
def testMultipleStartWorker(self): def testMultipleStartWorker(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, master._address, start=True) worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
worker.start() worker.start()
def testStopMaster(self): def testStopDispatcher(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
master._stop() dispatcher._stop()
master._stop() dispatcher._stop()
def testStopWorker(self): def testStopWorker(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, master._address) worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop() worker._stop()
worker._stop() worker._stop()
def testStopStartMaster(self): def testStopStartDispatcher(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
master._stop() dispatcher._stop()
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, "Server cannot be started after it has been stopped"): RuntimeError, "Server cannot be started after it has been stopped"):
master.start() dispatcher.start()
def testStopStartWorker(self): def testStopStartWorker(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, master._address) worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop() worker._stop()
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, "Server cannot be started after it has been stopped"): RuntimeError, "Server cannot be started after it has been stopped"):
worker.start() worker.start()
def testJoinMaster(self): def testJoinDispatcher(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
master._stop() dispatcher._stop()
master.join() dispatcher.join()
def testJoinWorker(self): def testJoinWorker(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, master._address) worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop() worker._stop()
worker.join() worker.join()
def testMasterNumWorkers(self): def testDispatcherNumWorkers(self):
master = server_lib.MasterServer(0) dispatcher = server_lib.DispatchServer(0)
self.assertEqual(0, master._num_workers()) self.assertEqual(0, dispatcher._num_workers())
worker1 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable worker1 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
self.assertEqual(1, master._num_workers()) self.assertEqual(1, dispatcher._num_workers())
worker2 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable worker2 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
self.assertEqual(2, master._num_workers()) self.assertEqual(2, dispatcher._num_workers())
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -28,13 +28,14 @@ limitations under the License.
namespace py = pybind11; namespace py = pybind11;
PYBIND11_MODULE(_pywrap_server_lib, m) { PYBIND11_MODULE(_pywrap_server_lib, m) {
py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer") py::class_<tensorflow::data::DispatchGrpcDataServer>(m,
.def("start", &tensorflow::data::MasterGrpcDataServer::Start) "DispatchGrpcDataServer")
.def("stop", &tensorflow::data::MasterGrpcDataServer::Stop) .def("start", &tensorflow::data::DispatchGrpcDataServer::Start)
.def("join", &tensorflow::data::MasterGrpcDataServer::Join) .def("stop", &tensorflow::data::DispatchGrpcDataServer::Stop)
.def("bound_port", &tensorflow::data::MasterGrpcDataServer::BoundPort) .def("join", &tensorflow::data::DispatchGrpcDataServer::Join)
.def("bound_port", &tensorflow::data::DispatchGrpcDataServer::BoundPort)
.def("num_workers", .def("num_workers",
[](tensorflow::data::MasterGrpcDataServer* server) -> int { [](tensorflow::data::DispatchGrpcDataServer* server) -> int {
int num_workers; int num_workers;
tensorflow::Status status = server->NumWorkers(&num_workers); tensorflow::Status status = server->NumWorkers(&num_workers);
tensorflow::MaybeRaiseFromStatus(status); tensorflow::MaybeRaiseFromStatus(status);
@ -48,12 +49,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
.def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort); .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
m.def( m.def(
"TF_DATA_NewMasterServer", "TF_DATA_NewDispatchServer",
[](int port, std::string protocol) [](int port, std::string protocol)
-> std::unique_ptr<tensorflow::data::MasterGrpcDataServer> { -> std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> {
std::unique_ptr<tensorflow::data::MasterGrpcDataServer> server; std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
tensorflow::Status status = tensorflow::Status status =
tensorflow::data::NewMasterServer(port, protocol, &server); tensorflow::data::NewDispatchServer(port, protocol, &server);
tensorflow::MaybeRaiseFromStatus(status); tensorflow::MaybeRaiseFromStatus(status);
return server; return server;
}, },
@ -61,12 +62,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
m.def( m.def(
"TF_DATA_NewWorkerServer", "TF_DATA_NewWorkerServer",
[](int port, std::string protocol, std::string master_address, [](int port, std::string protocol, std::string dispatcher_address,
std::string worker_address) std::string worker_address)
-> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> { -> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server; std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
tensorflow::Status status = tensorflow::data::NewWorkerServer( tensorflow::Status status = tensorflow::data::NewWorkerServer(
port, protocol, master_address, worker_address, &server); port, protocol, dispatcher_address, worker_address, &server);
tensorflow::MaybeRaiseFromStatus(status); tensorflow::MaybeRaiseFromStatus(status);
return server; return server;
}, },

View File

@ -59,23 +59,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
num_workers: The number of workers in the cluster. num_workers: The number of workers in the cluster.
Returns: Returns:
The address of the master. The address of the dispatcher.
""" """
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL) self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._servers = [] self._servers = []
for _ in range(num_workers): for _ in range(num_workers):
self._servers.append( self._servers.append(
server_lib.WorkerServer( server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)) port=0,
dispatcher_address=self._dispatcher._address,
protocol=PROTOCOL))
return self._master._address return self._dispatcher._address
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeBasic(self): def testDistributeBasic(self):
num_elements = 10 num_elements = 10
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
results = [elem.numpy() for elem in ds] results = [elem.numpy() for elem in ds]
self.assertEqual(list(range(num_elements)), results) self.assertEqual(list(range(num_elements)), results)
@ -83,10 +85,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testDifferentShuffleOrders(self): def testDifferentShuffleOrders(self):
random_seed.set_random_seed(None) random_seed.set_random_seed(None)
num_elements = 100 num_elements = 100
master_address = self.create_cluster(2) dispatcher_address = self.create_cluster(2)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = ds.shuffle(num_elements) ds = ds.shuffle(num_elements)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
output = [elem.numpy() for elem in ds] output = [elem.numpy() for elem in ds]
# The output will be two sequences of range(num_elements) # The output will be two sequences of range(num_elements)
@ -104,9 +106,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testMultipleEpochs(self): def testMultipleEpochs(self):
num_elements = 3 num_elements = 3
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
for _ in range(10): for _ in range(10):
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds]) self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
@ -114,9 +116,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testRepeatedDataset(self): def testRepeatedDataset(self):
num_elements = 10 num_elements = 10
num_repetitions = 5 num_repetitions = 5
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
ds = ds.repeat(num_repetitions) ds = ds.repeat(num_repetitions)
self.assertDatasetProduces( self.assertDatasetProduces(
ds, expected_output=num_repetitions * list(range(num_elements))) ds, expected_output=num_repetitions * list(range(num_elements)))
@ -125,12 +127,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testConcurrentEpoch(self): def testConcurrentEpoch(self):
num_elements = 10 num_elements = 10
num_datasets = 3 num_datasets = 3
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
iterators = [] iterators = []
results = [] results = []
for _ in range(num_datasets): for _ in range(num_datasets):
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
iterators.append(iter(ds)) iterators.append(iter(ds))
results.append([]) results.append([])
@ -146,9 +148,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.skipTest("Not yet implemented") self.skipTest("Not yet implemented")
num_elements = 10 num_elements = 10
num_iterators = 3 num_iterators = 3
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
result = [] result = []
iterators = [] iterators = []
for _ in range(num_iterators): for _ in range(num_iterators):
@ -170,20 +172,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testMultiWorker(self): def testMultiWorker(self):
num_workers = 3 num_workers = 3
num_elements = 10 num_elements = 10
master_address = self.create_cluster(num_workers) dispatcher_address = self.create_cluster(num_workers)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
results = [elem.numpy() for elem in ds] results = [elem.numpy() for elem in ds]
self.assertCountEqual(num_workers * list(range(num_elements)), results) self.assertCountEqual(num_workers * list(range(num_elements)), results)
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testAddWorkerMidJob(self): def testAddWorkerMidJob(self):
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL) self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._worker = server_lib.WorkerServer( self._worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL) port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
num_elements = 100 num_elements = 100
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master._address) ds = _make_distributed_dataset(ds, self._dispatcher._address)
iterator = iter(ds) iterator = iter(ds)
results = [] results = []
# Read halfway through the dataset. # Read halfway through the dataset.
@ -191,10 +193,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
results.append(next(iterator).numpy()) results.append(next(iterator).numpy())
self._new_worker = server_lib.WorkerServer( self._new_worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL) port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
# Wait for the new worker to register with the master. # Wait for the new worker to register with the dispatcher.
while self._master._num_workers() < 2: while self._dispatcher._num_workers() < 2:
time.sleep(10 / 1000) # 10ms time.sleep(10 / 1000) # 10ms
for elem in iterator: for elem in iterator:
@ -206,12 +208,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
combinations.times(test_base.eager_only_combinations(), combinations.times(test_base.eager_only_combinations(),
combinations.combine(use_same_port=[True, False]))) combinations.combine(use_same_port=[True, False])))
def testRestartWorker(self, use_same_port): def testRestartWorker(self, use_same_port):
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL) self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._worker = server_lib.WorkerServer( self._worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL) port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
num_elements = 100 num_elements = 100
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master._address) ds = _make_distributed_dataset(ds, self._dispatcher._address)
iterator = iter(ds) iterator = iter(ds)
# Read halfway through the dataset. # Read halfway through the dataset.
midpoint = num_elements // 2 midpoint = num_elements // 2
@ -224,7 +226,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
port = int(self._worker._address.split(":")[1]) port = int(self._worker._address.split(":")[1])
self._worker._stop() self._worker._stop()
self._new_worker = server_lib.WorkerServer( self._new_worker = server_lib.WorkerServer(
port=port, master_address=self._master._address, protocol=PROTOCOL) port=port,
dispatcher_address=self._dispatcher._address,
protocol=PROTOCOL)
# There may have been some elements prefetched from the first worker # There may have been some elements prefetched from the first worker
# before it was stopped. # before it was stopped.
@ -259,12 +263,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testInsideFunction(self): def testInsideFunction(self):
num_workers = 3 num_workers = 3
num_elements = 10 num_elements = 10
master_address = self.create_cluster(num_workers) dispatcher_address = self.create_cluster(num_workers)
@def_function.function @def_function.function
def f(): def f():
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
result = tensor_array_ops.TensorArray( result = tensor_array_ops.TensorArray(
dtypes.int64, size=num_workers * num_elements, dynamic_size=True) dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
i = 0 i = 0
@ -279,10 +283,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testSharedJobName(self): def testSharedJobName(self):
num_elements = 100 num_elements = 100
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name") ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name") ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
iter1 = iter(ds1) iter1 = iter(ds1)
iter2 = iter(ds2) iter2 = iter(ds2)
results = [] results = []
@ -298,20 +302,22 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDifferentJobNames(self): def testDifferentJobNames(self):
num_elements = 10 num_elements = 10
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1") ds1 = _make_distributed_dataset(
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2") ds, dispatcher_address, job_name="job_name1")
ds2 = _make_distributed_dataset(
ds, dispatcher_address, job_name="job_name2")
self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, list(range(num_elements))) self.assertDatasetProduces(ds2, list(range(num_elements)))
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testSharedJobNameMultiIteration(self): def testSharedJobNameMultiIteration(self):
num_elements = 10 num_elements = 10
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name") ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name") ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
# iteration 1 # iteration 1
self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, []) self.assertDatasetProduces(ds2, [])
@ -323,11 +329,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testSharedJobNameRepeat(self): def testSharedJobNameRepeat(self):
num_elements = 100 num_elements = 100
num_repetitions = 3 num_repetitions = 3
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name") ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds1 = ds1.repeat(num_repetitions) ds1 = ds1.repeat(num_repetitions)
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name") ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = ds2.repeat(num_repetitions) ds2 = ds2.repeat(num_repetitions)
results = [] results = []
iter1 = iter(ds1) iter1 = iter(ds1)
@ -345,7 +351,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testApplyDeterminismOption(self): def testApplyDeterminismOption(self):
elements = list(range(10)) elements = list(range(10))
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
def dataset_fn(delay_ms): def dataset_fn(delay_ms):
@ -362,7 +368,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
opts = dataset_ops.Options() opts = dataset_ops.Options()
opts.experimental_deterministic = False opts.experimental_deterministic = False
ds = ds.with_options(opts) ds = ds.with_options(opts)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
return ds return ds
self.checkDeterminism( self.checkDeterminism(
@ -379,8 +385,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_external_state_policy = external_state_policy options.experimental_external_state_policy = external_state_policy
ds = ds.with_options(options) ds = ds.with_options(options)
master_address = self.create_cluster(3) dispatcher_address = self.create_cluster(3)
ds = _make_distributed_dataset(ds, master_address) ds = _make_distributed_dataset(ds, dispatcher_address)
next(iter(ds)) next(iter(ds))
@combinations.generate( @combinations.generate(
@ -400,12 +406,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeFromInterleave(self): def testDistributeFromInterleave(self):
master_address = self.create_cluster(1) dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(2) ds = dataset_ops.Dataset.range(2)
def interleave_fn(_): def interleave_fn(_):
ds = dataset_ops.Dataset.range(2) ds = dataset_ops.Dataset.range(2)
_make_distributed_dataset(ds, master_address) _make_distributed_dataset(ds, dispatcher_address)
return ds return ds
with self.assertRaisesRegex( with self.assertRaisesRegex(
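The tests above call a `_make_distributed_dataset` helper that is defined earlier in this test file and not shown in the diff. A plausible reconstruction, for reading purposes only and not necessarily identical to the real helper, is:

```
from tensorflow.python.data.experimental.ops import data_service_ops

PROTOCOL = "grpc"  # matches the module-level constant the tests use


def _make_distributed_dataset(ds, dispatcher_address, job_name=None):
    # Route the dataset through the test cluster's dispatcher and workers.
    return ds.apply(
        data_service_ops.distribute(
            processing_mode="parallel_epochs",
            service=PROTOCOL + "://" + dispatcher_address,
            job_name=job_name))
```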

View File

@ -1,6 +1,6 @@
path: "tensorflow.data.experimental.service.MasterServer" path: "tensorflow.data.experimental.service.DispatchServer"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.MasterServer\'>" is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatchServer\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "target" name: "target"

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'port\', \'master_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " argspec: "args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "join" name: "join"

View File

@ -1,7 +1,7 @@
path: "tensorflow.data.experimental.service" path: "tensorflow.data.experimental.service"
tf_module { tf_module {
member { member {
name: "MasterServer" name: "DispatchServer"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member { member {

View File

@ -99,8 +99,8 @@ tensorflow::data::GrpcDataServerBase::Join
tensorflow::data::GrpcDataServerBase::Start tensorflow::data::GrpcDataServerBase::Start
tensorflow::data::GrpcDataServerBase::Stop tensorflow::data::GrpcDataServerBase::Stop
tensorflow::data::GrpcDataServerBase::BoundPort tensorflow::data::GrpcDataServerBase::BoundPort
tensorflow::data::MasterGrpcDataServer::NumWorkers tensorflow::data::DispatchGrpcDataServer::NumWorkers
tensorflow::data::NewMasterServer tensorflow::data::NewDispatchServer
tensorflow::data::NewWorkerServer tensorflow::data::NewWorkerServer
[protos_all] # device_lib, dtypes [protos_all] # device_lib, dtypes