Update "master" to "dispatch"/"dispatcher" in tf.data service terminology.

"Dispatcher" is more descriptive and follows the guidance at https://developers.google.com/style/word-list#master

PiperOrigin-RevId: 321613785
Change-Id: Iaa576d35f0581e21278101f8b31201ba737a6865
Andrew Audibert 2020-07-16 11:52:42 -07:00
parent 0aa1d61fad
commit 6b9a9d98bb
30 changed files with 367 additions and 353 deletions
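
For readers skimming the diff, a minimal before/after sketch of the user-facing Python rename, based on the `server_lib` changes shown below (ports and setup are illustrative; the old names are kept only as comments):

```
import tensorflow as tf

# Old terminology (removed by this change):
# master = tf.data.experimental.service.MasterServer(port=0)
# worker = tf.data.experimental.service.WorkerServer(
#     port=0, master_address=master.target.split("://")[1])

# New terminology introduced by this change:
dispatcher = tf.data.experimental.service.DispatchServer(port=0)
worker = tf.data.experimental.service.WorkerServer(
    port=0, dispatcher_address=dispatcher.target.split("://")[1])
```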

View File

@ -28,8 +28,8 @@ tf_proto_library(
)
tf_proto_library(
name = "master_proto",
srcs = ["master.proto"],
name = "dispatcher_proto",
srcs = ["dispatcher.proto"],
has_services = 1,
cc_api_version = 2,
protodeps = tf_additional_all_protos() + [
@ -49,17 +49,17 @@ tf_proto_library(
)
cc_library(
name = "master_impl",
srcs = ["master_impl.cc"],
name = "dispatcher_impl",
srcs = ["dispatcher_impl.cc"],
hdrs = [
"master_impl.h",
"dispatcher_impl.h",
],
deps = [
":common_proto_cc",
":credentials_factory",
":data_service",
":dispatcher_proto_cc",
":grpc_util",
":master_proto_cc",
":worker_cc_grpc_proto",
":worker_proto_cc",
"//tensorflow/c:c_api_internal",
@ -86,9 +86,9 @@ cc_library(
deps = [
":common_proto_cc",
":credentials_factory",
":dispatcher_cc_grpc_proto",
":dispatcher_proto_cc",
":grpc_util",
":master_cc_grpc_proto",
":master_proto_cc",
":worker_proto_cc",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_helper",
@ -207,12 +207,12 @@ tf_cc_test(
)
cc_library(
name = "grpc_master_impl",
srcs = ["grpc_master_impl.cc"],
hdrs = ["grpc_master_impl.h"],
name = "grpc_dispatcher_impl",
srcs = ["grpc_dispatcher_impl.cc"],
hdrs = ["grpc_dispatcher_impl.h"],
deps = [
":master_cc_grpc_proto",
":master_impl",
":dispatcher_cc_grpc_proto",
":dispatcher_impl",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
tf_grpc_cc_dependency(),
],
@ -250,7 +250,7 @@ cc_library(
],
deps = [
":credentials_factory",
":grpc_master_impl",
":grpc_dispatcher_impl",
":grpc_util",
":grpc_worker_impl",
"//tensorflow/core:lib",
@ -268,9 +268,9 @@ cc_library(
],
deps = [
":credentials_factory",
":dispatcher_cc_grpc_proto",
":dispatcher_proto_cc",
":grpc_util",
":master_cc_grpc_proto",
":master_proto_cc",
":worker_cc_grpc_proto",
":worker_proto_cc",
"//tensorflow/core:framework",
@ -287,12 +287,12 @@ tf_cc_test(
tags = ["no_windows"],
deps = [
":data_service",
":grpc_master_impl",
":dispatcher_cc_grpc_proto",
":dispatcher_proto_cc",
":grpc_dispatcher_impl",
":grpc_util",
":grpc_worker_impl",
":local_credentials_factory",
":master_cc_grpc_proto",
":master_proto_cc",
":server_lib",
":test_cluster",
":test_util",
@ -309,11 +309,11 @@ tf_cc_test(
)
cc_grpc_library(
name = "master_cc_grpc_proto",
srcs = [":master_proto"],
name = "dispatcher_cc_grpc_proto",
srcs = [":dispatcher_proto"],
generate_mocks = True,
grpc_only = True,
deps = [":master_proto_cc"],
deps = [":dispatcher_proto_cc"],
)
cc_grpc_library(

View File

@ -18,8 +18,8 @@ limitations under the License.
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
@ -54,8 +54,8 @@ std::string ProcessingModeToString(ProcessingMode mode) {
}
}
Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
int64* dataset_id) {
Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
int64* dataset_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrRegisterDatasetRequest req;
*req.mutable_dataset()->mutable_graph() = dataset;
@ -69,9 +69,9 @@ Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
return Status::OK();
}
Status DataServiceMasterClient::CreateJob(int64 dataset_id,
ProcessingMode processing_mode,
int64* job_id) {
Status DataServiceDispatcherClient::CreateJob(int64 dataset_id,
ProcessingMode processing_mode,
int64* job_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
CreateJobRequest req;
req.set_dataset_id(dataset_id);
@ -88,11 +88,9 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
return Status::OK();
}
Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
ProcessingMode processing_mode,
const std::string& job_name,
int job_name_index,
int64* job_id) {
Status DataServiceDispatcherClient::GetOrCreateJob(
int64 dataset_id, ProcessingMode processing_mode,
const std::string& job_name, int job_name_index, int64* job_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrCreateJobRequest req;
req.set_dataset_id(dataset_id);
@ -112,9 +110,9 @@ Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
return Status::OK();
}
Status DataServiceMasterClient::GetTasks(int64 job_id,
std::vector<TaskInfo>* tasks,
bool* job_finished) {
Status DataServiceDispatcherClient::GetTasks(int64 job_id,
std::vector<TaskInfo>* tasks,
bool* job_finished) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetTasksRequest req;
req.set_job_id(job_id);
@ -132,7 +130,8 @@ Status DataServiceMasterClient::GetTasks(int64 job_id,
return Status::OK();
}
Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
Status DataServiceDispatcherClient::GetWorkers(
std::vector<WorkerInfo>* workers) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetWorkersRequest req;
GetWorkersResponse resp;
@ -148,12 +147,12 @@ Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
return Status::OK();
}
Status DataServiceMasterClient::EnsureInitialized() {
Status DataServiceDispatcherClient::EnsureInitialized() {
std::shared_ptr<grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
auto channel = grpc::CreateChannel(address_, credentials);
stub_ = MasterService::NewStub(channel);
stub_ = DispatcherService::NewStub(channel);
return Status::OK();
}
@ -187,10 +186,11 @@ Status DataServiceWorkerClient::EnsureInitialized() {
return Status::OK();
}
Status CreateDataServiceMasterClient(
Status CreateDataServiceDispatcherClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceMasterClient>* out) {
auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
std::unique_ptr<DataServiceDispatcherClient>* out) {
auto client =
absl::make_unique<DataServiceDispatcherClient>(address, protocol);
TF_RETURN_IF_ERROR(client->Initialize());
*out = std::move(client);
return Status::OK();

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -67,11 +67,11 @@ class DataServiceClientBase {
const std::string protocol_;
};
// Client for communicating with the tf.data service master.
class DataServiceMasterClient : public DataServiceClientBase {
// Client for communicating with the tf.data service dispatcher.
class DataServiceDispatcherClient : public DataServiceClientBase {
public:
DataServiceMasterClient(const std::string& address,
const std::string& protocol)
DataServiceDispatcherClient(const std::string& address,
const std::string& protocol)
: DataServiceClientBase(address, protocol) {}
// Registers a dataset with the tf.data service, and stores the generated
@ -90,13 +90,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
const std::string& job_name, int job_name_index,
int64* job_id);
// Queries the master for the tasks associated with the specified job.
// Queries the dispatcher for the tasks associated with the specified job.
// The tasks will be stored in *tasks, and whether the job is finished will
// be stored in `*job_finished`.
Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
bool* job_finished);
// Queries the master for its registered workers. The worker info will be
// Queries the dispatcher for its registered workers. The worker info will be
// stored in `*workers`.
Status GetWorkers(std::vector<WorkerInfo>* workers);
@ -104,7 +104,7 @@ class DataServiceMasterClient : public DataServiceClientBase {
Status EnsureInitialized() override;
private:
std::unique_ptr<MasterService::Stub> stub_;
std::unique_ptr<DispatcherService::Stub> stub_;
};
// Client for communicating with the tf.data service worker.
@ -127,10 +127,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
std::unique_ptr<WorkerService::Stub> stub_;
};
// Creates and initializes a new tf.data service master client.
Status CreateDataServiceMasterClient(
// Creates and initializes a new tf.data service dispatcher client.
Status CreateDataServiceDispatcherClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceMasterClient>* out);
std::unique_ptr<DataServiceDispatcherClient>* out);
// Creates and initializes a new tf.data service worker client.
Status CreateDataServiceWorkerClient(

View File

@ -19,9 +19,9 @@ limitations under the License.
#include "grpcpp/security/credentials.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/service/server_lib.h"
#include "tensorflow/core/data/service/test_cluster.h"
#include "tensorflow/core/data/service/test_util.h"
@ -66,9 +66,10 @@ TEST(DataService, ProcessingModeToString) {
TEST(DataService, GetWorkers) {
TestCluster cluster(1);
TF_ASSERT_OK(cluster.Initialize());
DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(),
kProtocol);
std::vector<WorkerInfo> workers;
TF_EXPECT_OK(master.GetWorkers(&workers));
TF_EXPECT_OK(dispatcher.GetWorkers(&workers));
EXPECT_EQ(1, workers.size());
}

View File

@ -110,11 +110,11 @@ message GetWorkersResponse {
repeated WorkerInfo workers = 1;
}
service MasterService {
// Registers a worker with the master.
service DispatcherService {
// Registers a worker with the dispatcher.
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
// Updates the master with information about the worker's state.
// Updates the dispatcher with information about the worker's state.
rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
// Registers a dataset with the server, or returns its id if it is already
@ -134,6 +134,6 @@ service MasterService {
// Reports a list of all tasks for a job.
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
// Reports a list of all workers registered with the master.
// Reports a list of all workers registered with the dispatcher.
rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse);
}
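
To make the RPC surface above concrete, here is a rough sketch (not part of the change) of how these DispatcherService RPCs get exercised through the public Python API; the mapping in the comments is inferred from the client, worker, and dataset-op implementations elsewhere in this diff:

```
import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer(port=0)

# Starting a worker makes it call RegisterWorker on the dispatcher; its
# heartbeat thread later reports task status via WorkerUpdate.
worker = tf.data.experimental.service.WorkerServer(
    port=0, dispatcher_address=dispatcher.target.split("://")[1])

# distribute() registers the dataset graph (GetOrRegisterDataset), creates a
# job (CreateJob, or GetOrCreateJob when a job name is given), and the client
# then polls GetTasks to learn which workers to read elements from.
dataset = tf.data.Dataset.range(5)
dataset = dataset.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs", service=dispatcher.target))
print(list(dataset.as_numpy_iterator()))
```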

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/master_impl.h"
#include "tensorflow/core/data/service/dispatcher_impl.h"
#include <memory>
#include <tuple>
@ -26,8 +26,8 @@ limitations under the License.
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
@ -53,10 +53,10 @@ Status CreateWorkerStub(const std::string& address,
}
} // namespace
DataServiceMasterImpl::DataServiceMasterImpl(const std::string protocol)
DataServiceDispatcherImpl::DataServiceDispatcherImpl(const std::string protocol)
: protocol_(protocol) {}
Status DataServiceMasterImpl::RegisterWorker(
Status DataServiceDispatcherImpl::RegisterWorker(
const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
VLOG(3) << "Received register worker request";
mutex_lock l(mu_);
@ -86,8 +86,8 @@ Status DataServiceMasterImpl::RegisterWorker(
return Status::OK();
}
Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
WorkerUpdateResponse* response) {
Status DataServiceDispatcherImpl::WorkerUpdate(
const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
mutex_lock l(mu_);
int64 worker_id = request->worker_id();
for (auto& update : request->updates()) {
@ -106,7 +106,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
return Status::OK();
}
Status DataServiceMasterImpl::GetOrRegisterDataset(
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
const GetOrRegisterDatasetRequest* request,
GetOrRegisterDatasetResponse* response) {
uint64 fingerprint;
@ -128,8 +128,8 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
return Status::OK();
}
int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
const DatasetDef& dataset)
int64 DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
const DatasetDef& dataset)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 dataset_id = next_dataset_id_++;
auto new_dataset =
@ -142,8 +142,8 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
return dataset_id;
}
Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
CreateJobResponse* response) {
Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request,
CreateJobResponse* response) {
VLOG(3) << "Received create job request for dataset id "
<< request->dataset_id();
ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
@ -157,7 +157,7 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
return Status::OK();
}
Status DataServiceMasterImpl::GetOrCreateJob(
Status DataServiceDispatcherImpl::GetOrCreateJob(
const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
VLOG(3) << "Received get or create job request for dataset id "
<< request->dataset_id() << " with name " << request->job_name()
@ -193,7 +193,7 @@ Status DataServiceMasterImpl::GetOrCreateJob(
}
// Validates that the job matches the given processing_mode and dataset_id.
Status DataServiceMasterImpl::ValidateMatchingJob(
Status DataServiceDispatcherImpl::ValidateMatchingJob(
const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
DCHECK(job.name().has_value());
std::string job_name = job.name().value();
@ -214,10 +214,10 @@ Status DataServiceMasterImpl::ValidateMatchingJob(
return Status::OK();
}
Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
ProcessingMode processing_mode,
absl::optional<std::string> job_name,
int64* out_job_id) LOCKS_EXCLUDED(mu_) {
Status DataServiceDispatcherImpl::CreateJob(
int64 dataset_id, ProcessingMode processing_mode,
absl::optional<std::string> job_name, int64* out_job_id)
LOCKS_EXCLUDED(mu_) {
switch (processing_mode) {
case ProcessingMode::PARALLEL_EPOCHS:
break;
@ -274,14 +274,16 @@ Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
return Status::OK();
}
const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTask(
const DataServiceDispatcherImpl::Task& DataServiceDispatcherImpl::CreateTask(
Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
return CreateTaskLocked(job, worker_address);
}
const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
Job* job, const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const DataServiceDispatcherImpl::Task&
DataServiceDispatcherImpl::CreateTaskLocked(Job* job,
const std::string& worker_address)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 task_id = next_task_id_++;
DCHECK(!tasks_.contains(task_id));
tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
@ -290,7 +292,7 @@ const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
return tasks_.at(task_id);
}
Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
if (!worker->stub()) {
std::unique_ptr<WorkerService::Stub> stub;
TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
@ -299,8 +301,8 @@ Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
return Status::OK();
}
Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
Worker* worker)
Status DataServiceDispatcherImpl::AllocateTaskToWorker(const Task& task,
Worker* worker)
LOCKS_EXCLUDED(mu_) {
TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
grpc::ClientContext client_ctx;
@ -322,8 +324,8 @@ Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
return Status::OK();
}
Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
GetTasksResponse* response) {
Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
GetTasksResponse* response) {
mutex_lock l(mu_);
VLOG(3) << "Looking up tasks for job id " << request->job_id();
auto it = jobs_.find(request->job_id());
@ -346,8 +348,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
return Status::OK();
}
Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request,
GetWorkersResponse* response) {
Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
GetWorkersResponse* response) {
mutex_lock l(mu_);
VLOG(3) << "Enter GetWorkers";
for (auto& worker : workers_) {

View File

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
@ -40,11 +40,11 @@ namespace data {
// ProcessingModeDef which determines what data it produces.
// * Task: A job is broken into multiple tasks, which each represent
// iterating over all of or part of the dataset. Workers process tasks.
class DataServiceMasterImpl {
class DataServiceDispatcherImpl {
public:
explicit DataServiceMasterImpl(const std::string protocol);
explicit DataServiceDispatcherImpl(const std::string protocol);
// See master.proto for API documentation.
// See dispatcher.proto for API documentation.
/// Worker-facing API.
Status RegisterWorker(const RegisterWorkerRequest* request,
@ -191,7 +191,7 @@ class DataServiceMasterImpl {
// Creates a new task for a job, returning a reference to the task.
const Task& CreateTask(Job* job, const std::string& worker_address)
LOCKS_EXCLUDED(mu_);
// Same as `CreateTask`, but expects that the master lock is already held.
// Same as `CreateTask`, but expects that the dispatcher lock is already held.
const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Validates that an existing job matches the given processing_mode and
@ -225,10 +225,10 @@ class DataServiceMasterImpl {
absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
#endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/grpc_master_impl.h"
#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
#include "grpcpp/server_context.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@ -25,18 +25,18 @@ using ::grpc::ServerBuilder;
using ::grpc::ServerContext;
using ::grpc::Status;
GrpcMasterImpl::GrpcMasterImpl(ServerBuilder* server_builder,
const std::string& protocol)
GrpcDispatcherImpl::GrpcDispatcherImpl(ServerBuilder* server_builder,
const std::string& protocol)
: impl_(protocol) {
server_builder->RegisterService(this);
VLOG(1) << "Registered data service master";
VLOG(1) << "Registered data service dispatcher";
}
#define HANDLER(method) \
Status GrpcMasterImpl::method(ServerContext* context, \
const method##Request* request, \
method##Response* response) { \
return ToGrpcStatus(impl_.method(request, response)); \
#define HANDLER(method) \
Status GrpcDispatcherImpl::method(ServerContext* context, \
const method##Request* request, \
method##Response* response) { \
return ToGrpcStatus(impl_.method(request, response)); \
}
HANDLER(RegisterWorker);
HANDLER(WorkerUpdate);

View File

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
#include "grpcpp/server_builder.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/master_impl.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher_impl.h"
namespace tensorflow {
namespace data {
@ -29,14 +29,14 @@ namespace data {
//
// ::grpc::ServerBuilder builder;
// // configure builder
// GrpcMasterImpl data_service(&builder);
// GrpcDispatcherImpl data_service(&builder);
// builder.BuildAndStart()
//
class GrpcMasterImpl : public MasterService::Service {
class GrpcDispatcherImpl : public DispatcherService::Service {
public:
explicit GrpcMasterImpl(grpc::ServerBuilder* server_builder,
const std::string& protocol);
~GrpcMasterImpl() override {}
explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
const std::string& protocol);
~GrpcDispatcherImpl() override {}
#define HANDLER(method) \
grpc::Status method(grpc::ServerContext* context, \
@ -52,12 +52,12 @@ class GrpcMasterImpl : public MasterService::Service {
#undef HANDLER
private:
DataServiceMasterImpl impl_;
DataServiceDispatcherImpl impl_;
TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterImpl);
TF_DISALLOW_COPY_AND_ASSIGN(GrpcDispatcherImpl);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_

View File

@ -26,9 +26,9 @@ using ::grpc::ServerContext;
using ::grpc::Status;
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& protocol)
: impl_(master_address, protocol) {
: impl_(dispatcher_address, protocol) {
server_builder->RegisterService(this);
VLOG(1) << "Registered data service worker";
}

View File

@ -35,7 +35,7 @@ namespace data {
class GrpcWorkerImpl : public WorkerService::Service {
public:
explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& protocol);
~GrpcWorkerImpl() override {}

View File

@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/data/service/server_lib.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/grpc_master_impl.h"
#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/grpc_worker_impl.h"
#include "tensorflow/core/platform/errors.h"
@ -72,18 +72,18 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
int GrpcDataServerBase::BoundPort() { return bound_port(); }
MasterGrpcDataServer::MasterGrpcDataServer(int port,
const std::string& protocol)
DispatchGrpcDataServer::DispatchGrpcDataServer(int port,
const std::string& protocol)
: GrpcDataServerBase(port, protocol) {}
MasterGrpcDataServer::~MasterGrpcDataServer() { delete service_; }
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
void MasterGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
auto service = absl::make_unique<GrpcMasterImpl>(builder, protocol_);
void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
auto service = absl::make_unique<GrpcDispatcherImpl>(builder, protocol_);
service_ = service.release();
}
Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
GetWorkersRequest req;
GetWorkersResponse resp;
grpc::ServerContext ctx;
@ -95,19 +95,18 @@ Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
return Status::OK();
}
WorkerGrpcDataServer::WorkerGrpcDataServer(int port,
const std::string& protocol,
const std::string& master_address,
const std::string& worker_address)
WorkerGrpcDataServer::WorkerGrpcDataServer(
int port, const std::string& protocol,
const std::string& dispatcher_address, const std::string& worker_address)
: GrpcDataServerBase(port, protocol),
master_address_(master_address),
dispatcher_address_(dispatcher_address),
worker_address_(worker_address) {}
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
auto service =
absl::make_unique<GrpcWorkerImpl>(builder, master_address_, protocol_);
auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
protocol_);
service_ = service.release();
}
@ -123,25 +122,25 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
return Status::OK();
}
Status NewMasterServer(int port, const std::string& protocol,
std::unique_ptr<MasterGrpcDataServer>* out_server) {
*out_server = absl::make_unique<MasterGrpcDataServer>(port, protocol);
Status NewDispatchServer(int port, const std::string& protocol,
std::unique_ptr<DispatchGrpcDataServer>* out_server) {
*out_server = absl::make_unique<DispatchGrpcDataServer>(port, protocol);
return Status::OK();
}
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
return NewWorkerServer(port, protocol, master_address, /*worker_address=*/"",
out_server);
return NewWorkerServer(port, protocol, dispatcher_address,
/*worker_address=*/"", out_server);
}
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& worker_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
*out_server = absl::make_unique<WorkerGrpcDataServer>(
port, protocol, master_address, worker_address);
port, protocol, dispatcher_address, worker_address);
return Status::OK();
}

View File

@ -25,7 +25,7 @@ namespace data {
// Forward declared because transitively depending on .grpc.pb.h files causes
// issues in the pywrap build.
class GrpcMasterImpl;
class GrpcDispatcherImpl;
class GrpcWorkerImpl;
// A grpc server for the tf.data service.
@ -35,7 +35,7 @@ class GrpcDataServerBase {
// server will find an available port in `Start()`. The chosen port can be
// found in the output of `Target()`.
//
// master_address is only needed for worker data servers.
// dispatcher_address is only needed for worker data servers.
GrpcDataServerBase(int requested_port, const std::string& protocol);
virtual ~GrpcDataServerBase() {}
@ -70,12 +70,12 @@ class GrpcDataServerBase {
std::unique_ptr<grpc::Server> server_;
};
class MasterGrpcDataServer : public GrpcDataServerBase {
class DispatchGrpcDataServer : public GrpcDataServerBase {
public:
MasterGrpcDataServer(int requested_port, const std::string& protocol);
~MasterGrpcDataServer() override;
DispatchGrpcDataServer(int requested_port, const std::string& protocol);
~DispatchGrpcDataServer() override;
// Returns the number of workers registered with the master.
// Returns the number of workers registered with the dispatcher.
Status NumWorkers(int* num_workers);
protected:
@ -83,14 +83,14 @@ class MasterGrpcDataServer : public GrpcDataServerBase {
Status StartServiceInternal() override { return Status::OK(); }
private:
// Owned. We use a raw pointer because GrpcMasterImpl is forward-declared.
GrpcMasterImpl* service_;
// Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
GrpcDispatcherImpl* service_;
};
class WorkerGrpcDataServer : public GrpcDataServerBase {
public:
WorkerGrpcDataServer(int requested_port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& worker_address);
~WorkerGrpcDataServer() override;
@ -99,15 +99,15 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
Status StartServiceInternal() override;
private:
const std::string master_address_;
const std::string dispatcher_address_;
const std::string worker_address_;
// Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
GrpcWorkerImpl* service_;
};
// Creates a master tf.data server and stores it in `*out_server`.
Status NewMasterServer(int port, const std::string& protocol,
std::unique_ptr<MasterGrpcDataServer>* out_server);
// Creates a dispatch tf.data server and stores it in `*out_server`.
Status NewDispatchServer(int port, const std::string& protocol,
std::unique_ptr<DispatchGrpcDataServer>* out_server);
// Creates a worker tf.data server and stores it in `*out_server`.
//
@ -115,18 +115,18 @@ Status NewMasterServer(int port, const std::string& protocol,
// will be chosen in Start(). This value can be queried with BoundPort().
//
// The worker_address argument is optional. If left empty, it will default to
// "localhost:%port%". When the worker registers with the master, the worker
// will report the worker address, so that the master can tell clients where to
// read from. The address may contain the placeholder "%port%", which will be
// "localhost:%port%". When the worker registers with the dispatcher, the worker
// will report the worker address, so that the dispatcher can tell clients where
// to read from. The address may contain the placeholder "%port%", which will be
// replaced with the value of BoundPort().
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& worker_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server);
// Creates a worker using the default worker_address.
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server);
} // namespace data
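
The C++ factory functions above (`NewDispatchServer`, `NewWorkerServer`) are exposed to Python through the pybind module changed later in this diff. A short sketch of the `worker_address` placeholder behavior they document, assuming the Python wrappers from `server_lib.py` (the explicit placeholder value here is illustrative, since it is also the default):

```
from tensorflow.python.data.experimental.service import server_lib

dispatcher = server_lib.DispatchServer(port=0)

# "%port%" is replaced with the port the worker actually binds to, which is
# what the dispatcher reports to clients; this matters when port=0 lets the
# worker pick any free port.
worker = server_lib.WorkerServer(
    port=0,
    dispatcher_address=dispatcher._address,
    worker_address="localhost:%port%")
```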

View File

@ -45,9 +45,9 @@ Status TestCluster::Initialize() {
"Test cluster has already been initialized.");
}
initialized_ = true;
TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_));
TF_RETURN_IF_ERROR(master_->Start());
master_address_ = absl::StrCat("localhost:", master_->BoundPort());
TF_RETURN_IF_ERROR(NewDispatchServer(/*port=*/0, kProtocol, &dispatcher_));
TF_RETURN_IF_ERROR(dispatcher_->Start());
dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
workers_.reserve(num_workers_);
worker_addresses_.reserve(num_workers_);
for (int i = 0; i < num_workers_; ++i) {
@ -59,14 +59,14 @@ Status TestCluster::Initialize() {
Status TestCluster::AddWorker() {
std::unique_ptr<WorkerGrpcDataServer> worker;
TF_RETURN_IF_ERROR(
NewWorkerServer(/*port=*/0, kProtocol, master_address_, &worker));
NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
TF_RETURN_IF_ERROR(worker->Start());
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
workers_.push_back(std::move(worker));
return Status::OK();
}
std::string TestCluster::MasterAddress() { return master_address_; }
std::string TestCluster::DispatcherAddress() { return dispatcher_address_; }
std::string TestCluster::WorkerAddress(int index) {
DCHECK_GE(index, 0);

View File

@ -24,7 +24,7 @@ namespace data {
// Helper class for unit testing a tf.data service cluster.
class TestCluster {
public:
// Creates a new test cluster with a master and `num_workers` workers.
// Creates a new test cluster with a dispatcher and `num_workers` workers.
explicit TestCluster(int num_workers);
// Initializes the test cluster. This must be called before interacting with
@ -32,8 +32,8 @@ class TestCluster {
Status Initialize();
// Adds a new worker to the cluster.
Status AddWorker();
// Returns the master address in the form "hostname:port".
std::string MasterAddress();
// Returns the dispatcher address in the form "hostname:port".
std::string DispatcherAddress();
// Returns the address of the worker at the specified index, in the form
// "hostname:port". The index must be non-negative and less than the number of
// workers in the cluster.
@ -42,8 +42,8 @@ class TestCluster {
private:
bool initialized_ = false;
int num_workers_;
std::unique_ptr<MasterGrpcDataServer> master_;
std::string master_address_;
std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
std::string dispatcher_address_;
std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
std::vector<std::string> worker_addresses_;
};

View File

@ -21,9 +21,9 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@ -45,9 +45,9 @@ auto* tf_data_service_created =
"has been created.");
} // namespace
DataServiceWorkerImpl::DataServiceWorkerImpl(const std::string& master_address,
const std::string& protocol)
: master_address_(master_address), protocol_(protocol) {
DataServiceWorkerImpl::DataServiceWorkerImpl(
const std::string& dispatcher_address, const std::string& protocol)
: dispatcher_address_(dispatcher_address), protocol_(protocol) {
tf_data_service_created->GetCell()->Set(true);
}
@ -67,14 +67,13 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
heartbeat_thread_.reset(thread);
Status s = Register();
while (!s.ok()) {
LOG(WARNING) << "Failed to register with master at " << master_address_
<< ": " << s;
LOG(WARNING) << "Failed to register with dispatcher at "
<< dispatcher_address_ << ": " << s;
Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
s = Register();
}
}
Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
ProcessTaskResponse* response) {
mutex_lock l(mu_);
@ -169,29 +168,29 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
return Status::OK();
}
Status DataServiceWorkerImpl::EnsureMasterStubInitialized()
Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!master_stub_) {
if (!dispatcher_stub_) {
::grpc::ChannelArguments args;
std::shared_ptr<::grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
auto channel =
::grpc::CreateCustomChannel(master_address_, credentials, args);
master_stub_ = MasterService::NewStub(channel);
::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
dispatcher_stub_ = DispatcherService::NewStub(channel);
}
return Status::OK();
}
Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(3) << "Registering with master at " << master_address_;
TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
RegisterWorkerRequest req;
req.set_worker_address(worker_address_);
RegisterWorkerResponse resp;
grpc::ClientContext ctx;
grpc::Status s = master_stub_->RegisterWorker(&ctx, req, &resp);
grpc::Status s = dispatcher_stub_->RegisterWorker(&ctx, req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to register worker", s);
}
@ -205,8 +204,8 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(3) << "Sending " << pending_completed_tasks_.size()
<< " task updates to master";
TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
<< " task updates to dispatcher";
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
WorkerUpdateRequest req;
req.set_worker_id(worker_id_);
for (int task_id : pending_completed_tasks_) {
@ -217,7 +216,7 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
WorkerUpdateResponse resp;
grpc::ClientContext ctx;
grpc::Status s = master_stub_->WorkerUpdate(&ctx, req, &resp);
grpc::Status s = dispatcher_stub_->WorkerUpdate(&ctx, req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to send task updates", s);
}
@ -238,7 +237,7 @@ void DataServiceWorkerImpl::HeartbeatThread() {
}
Status s = SendTaskUpdate();
if (!s.ok()) {
LOG(WARNING) << "Failed to send task updates to master: " << s;
LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
}
}
}

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/core/status.h"
@ -29,17 +29,17 @@ namespace data {
// A TensorFlow DataService serves dataset elements over RPC.
class DataServiceWorkerImpl {
public:
explicit DataServiceWorkerImpl(const std::string& master_address,
explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
const std::string& protocol);
~DataServiceWorkerImpl();
// Starts the worker. The worker needs to know its own address so that it can
// register with the master.
// register with the dispatcher.
void Start(const std::string& worker_address);
// See worker.proto for API documentation.
/// Master-facing API.
/// Dispatcher-facing API.
Status ProcessTask(const ProcessTaskRequest* request,
ProcessTaskResponse* response);
@ -48,15 +48,15 @@ class DataServiceWorkerImpl {
GetElementResponse* response);
private:
// Sets master_stub_ if it isn't already set.
Status EnsureMasterStubInitialized();
// Registers the worker with the master.
// Sets dispatcher_stub_ if it isn't already set.
Status EnsureDispatcherStubInitialized();
// Registers the worker with the dispatcher.
Status Register();
// Sends task status to the master.
// Sends task status to the dispatcher.
Status SendTaskUpdate();
// Creates an iterator to process a task.
Status ProcessTaskInternal(const TaskDef& task);
// A thread for updating the master with worker status.
// A thread for updating the dispatcher with worker status.
void HeartbeatThread();
typedef struct Task {
@ -67,18 +67,19 @@ class DataServiceWorkerImpl {
std::unique_ptr<standalone::Iterator> iterator;
} Task;
const std::string master_address_;
// Protocol for communicating with the master.
const std::string dispatcher_address_;
// Protocol for communicating with the dispatcher.
const std::string protocol_;
// The worker's own address.
std::string worker_address_;
mutex mu_;
int64 worker_id_ TF_GUARDED_BY(mu_);
std::unique_ptr<MasterService::Stub> master_stub_ TF_GUARDED_BY(mu_);
std::unique_ptr<DispatcherService::Stub> dispatcher_stub_ TF_GUARDED_BY(mu_);
// Information about tasks, keyed by task ids.
absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
// List of completed tasks which haven't yet been communicated to the master.
// List of completed tasks which haven't yet been communicated to the
// dispatcher.
std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Condition variable for notifying the heartbeat thread.

View File

@ -69,7 +69,7 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
// Dataset for reading data from the tf.data service non-deterministically.
//
// This dataset interleaves dataset elements produced by multiple tf.data
// workers. We periodically query the tf.data master to determine which workers
// workers. We periodically query the dispatcher to determine which workers
// to read from (in case workers are added or removed).
class DataServiceDatasetOp::Dataset : public DatasetBase {
public:
@ -199,12 +199,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
Status Initialize(IteratorContext* ctx) override {
VLOG(3) << "Connecting to " << dataset()->address_
<< " in data service dataset op";
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
if (dataset()->job_name_.empty()) {
TF_RETURN_IF_ERROR(master.CreateJob(
TF_RETURN_IF_ERROR(dispatcher.CreateJob(
dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
} else {
TF_RETURN_IF_ERROR(master.GetOrCreateJob(
TF_RETURN_IF_ERROR(dispatcher.GetOrCreateJob(
dataset()->dataset_id_, dataset()->processing_mode_,
dataset()->job_name_, iterator_index_, &job_id_));
}
@ -283,11 +284,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// Periodically refresh the task list.
// Maintain one thread fetching elements for each task.
// TODO(aaudibert): Instead of polling, have master send updates when
// TODO(aaudibert): Instead of polling, have dispatcher send updates when
// the list of tasks changes.
void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
VLOG(3) << "Starting task thread manager";
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
uint64 next_check = Env::Default()->NowMicros();
while (true) {
{
@ -305,18 +307,19 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
return;
}
}
UpdateTasks(&master);
UpdateTasks(&dispatcher);
UpdateWorkerThreads(ctx.get());
next_check = Env::Default()->NowMicros() +
dataset()->task_refresh_interval_ms_ * 1000;
}
}
void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) {
void UpdateTasks(DataServiceDispatcherClient* dispatcher)
LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Updating tasks";
std::vector<TaskInfo> tasks;
bool job_finished;
Status s = master->GetTasks(job_id_, &tasks, &job_finished);
Status s = dispatcher->GetTasks(job_id_, &tasks, &job_finished);
if (!s.ok()) {
LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
<< s;

View File

@ -53,7 +53,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
DataServiceMasterClient client(address, protocol);
DataServiceDispatcherClient client(address, protocol);
int64 dataset_id;
OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));

View File

@ -25,7 +25,7 @@ namespace data {
// Registers a dataset with the tf.data service.
//
// The address and protocol inputs are used to connect to the tf.data master.
// The address and protocol inputs are used to connect to the dispatcher.
// The external state policy attribute determines whether to ignore, warn, or
// error out when the dataset contains external state.
// The op produces a dataset id for identifying the registered dataset.

View File

@ -77,7 +77,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
amount of memory used, since `distribute` won't use more than
`element_size` * `max_outstanding_requests` of memory.
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
the master for task changes.
the dispatcher for task changes.
"""
if job_name is None:
@ -173,7 +173,7 @@ def _distribute(processing_mode,
of memory used, since `distribute` won't use more than `element_size` *
`max_outstanding_requests` of memory.
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
master for task changes.
dispatcher for task changes.
Returns:
Dataset: A `Dataset` of the elements produced by the data service.

View File

@ -19,5 +19,5 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops.data_service_ops import distribute
from tensorflow.python.data.experimental.service.server_lib import MasterServer
from tensorflow.python.data.experimental.service.server_lib import DispatchServer
from tensorflow.python.data.experimental.service.server_lib import WorkerServer

View File

@ -24,35 +24,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.service.MasterServer", v1=[])
class MasterServer(object):
"""An in-process tf.data service master server.
@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer(object):
"""An in-process tf.data service dispatch server.
A `tf.data.experimental.service.MasterServer` coordinates a cluster of
A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
register themselves with the master.
register themselves with the dispatcher.
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> master_address = master.target.split("://")[1]
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address)
... port=0, dispatcher_address=dispatcher_address)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
When starting a dedicated tf.data master process, use join() to block
When starting a dedicated tf.data dispatch process, use join() to block
indefinitely after starting up the server.
```
master = tf.data.experimental.service.MasterServer(port=5050)
master.join()
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
dispatcher.join()
```
"""
def __init__(self, port, protocol=None, start=True):
"""Creates a new master server.
"""Creates a new dispatch server.
Args:
port: Specifies the port to bind to.
@ -68,15 +68,16 @@ class MasterServer(object):
if protocol is None:
protocol = "grpc"
self._protocol = protocol
self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(port, protocol)
self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(port, protocol)
if start:
self._server.start()
def start(self):
"""Starts this server.
>>> master = tf.data.experimental.service.MasterServer(port=0, start=False)
>>> master.start()
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
... start=False)
>>> dispatcher.start()
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
@ -87,11 +88,11 @@ class MasterServer(object):
def join(self):
"""Blocks until the server has shut down.
This is useful when starting a dedicated master process.
This is useful when starting a dedicated dispatch process.
```
master = tf.data.experimental.service.MasterServer(port=5050)
master.join()
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
dispatcher.join()
```
Raises:
@ -104,10 +105,10 @@ class MasterServer(object):
def target(self):
"""Returns a target that can be used to connect to the server.
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))
The returned string will be in the form protocol://address, e.g.
"grpc://localhost:5050".
@ -136,7 +137,7 @@ class MasterServer(object):
return "localhost:{0}".format(self._server.bound_port())
def _num_workers(self):
"""Returns the number of workers registered with the master."""
"""Returns the number of workers registered with the dispatcher."""
return self._server.num_workers()
@ -147,15 +148,15 @@ class WorkerServer(object):
A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
processing for user-defined datasets, and provides the resulting elements over
RPC. A worker is associated with a single
`tf.data.experimental.service.MasterServer`.
`tf.data.experimental.service.DispatchServer`.
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> master_address = master.target.split("://")[1]
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address)
... port=0, dispatcher_address=dispatcher_address)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
@ -164,14 +165,14 @@ class WorkerServer(object):
```
worker = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050")
port=5051, dispatcher_address="grpc://localhost:5050")
worker.join()
```
"""
def __init__(self,
port,
master_address,
dispatcher_address,
worker_address=None,
protocol=None,
start=True):
@ -180,11 +181,12 @@ class WorkerServer(object):
Args:
port: Specifies the port to bind to. A value of 0 indicates that the
worker can bind to any available port.
master_address: Specifies the address of the master server.
dispatcher_address: Specifies the address of the dispatcher.
worker_address: (Optional.) Specifies the address of the worker server.
This address is passed to the master server so that the master can tell
clients how to connect to this worker. Defaults to `"localhost:%port%"`,
where `%port%` will be replaced with the port used by the worker.
This address is passed to the dispatcher so that the dispatcher can
tell clients how to connect to this worker. Defaults to
`"localhost:%port%"`, where `%port%` will be replaced with the port used
by the worker.
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
start: (Optional.) Boolean, indicating whether to start the server after
@ -201,7 +203,7 @@ class WorkerServer(object):
self._protocol = protocol
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
port, protocol, master_address, worker_address)
port, protocol, dispatcher_address, worker_address)
if start:
self._server.start()
@ -221,7 +223,7 @@ class WorkerServer(object):
```
worker_server = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050")
port=5051, dispatcher_address="grpc://localhost:5050")
worker_server.join()
```

View File

@ -25,68 +25,68 @@ from tensorflow.python.platform import test
class ServerLibTest(test.TestCase):
def testStartMaster(self):
master = server_lib.MasterServer(0, start=False)
master.start()
def testStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0, start=False)
dispatcher.start()
def testMultipleStartMaster(self):
master = server_lib.MasterServer(0, start=True)
master.start()
def testMultipleStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0, start=True)
dispatcher.start()
def testStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address, start=False)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
worker.start()
def testMultipleStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address, start=True)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
worker.start()
def testStopMaster(self):
master = server_lib.MasterServer(0)
master._stop()
master._stop()
def testStopDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
dispatcher._stop()
def testStopWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
worker._stop()
def testStopStartMaster(self):
master = server_lib.MasterServer(0)
master._stop()
def testStopStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
with self.assertRaisesRegex(
RuntimeError, "Server cannot be started after it has been stopped"):
master.start()
dispatcher.start()
def testStopStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
with self.assertRaisesRegex(
RuntimeError, "Server cannot be started after it has been stopped"):
worker.start()
def testJoinMaster(self):
master = server_lib.MasterServer(0)
master._stop()
master.join()
def testJoinDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
dispatcher.join()
def testJoinWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
worker.join()
def testMasterNumWorkers(self):
master = server_lib.MasterServer(0)
self.assertEqual(0, master._num_workers())
worker1 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable
self.assertEqual(1, master._num_workers())
worker2 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable
self.assertEqual(2, master._num_workers())
def testDispatcherNumWorkers(self):
dispatcher = server_lib.DispatchServer(0)
self.assertEqual(0, dispatcher._num_workers())
worker1 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
self.assertEqual(1, dispatcher._num_workers())
worker2 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
self.assertEqual(2, dispatcher._num_workers())
if __name__ == "__main__":

View File

@ -28,13 +28,14 @@ limitations under the License.
namespace py = pybind11;
PYBIND11_MODULE(_pywrap_server_lib, m) {
py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer")
.def("start", &tensorflow::data::MasterGrpcDataServer::Start)
.def("stop", &tensorflow::data::MasterGrpcDataServer::Stop)
.def("join", &tensorflow::data::MasterGrpcDataServer::Join)
.def("bound_port", &tensorflow::data::MasterGrpcDataServer::BoundPort)
py::class_<tensorflow::data::DispatchGrpcDataServer>(m,
"DispatchGrpcDataServer")
.def("start", &tensorflow::data::DispatchGrpcDataServer::Start)
.def("stop", &tensorflow::data::DispatchGrpcDataServer::Stop)
.def("join", &tensorflow::data::DispatchGrpcDataServer::Join)
.def("bound_port", &tensorflow::data::DispatchGrpcDataServer::BoundPort)
.def("num_workers",
[](tensorflow::data::MasterGrpcDataServer* server) -> int {
[](tensorflow::data::DispatchGrpcDataServer* server) -> int {
int num_workers;
tensorflow::Status status = server->NumWorkers(&num_workers);
tensorflow::MaybeRaiseFromStatus(status);
@ -48,12 +49,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
.def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
m.def(
"TF_DATA_NewMasterServer",
"TF_DATA_NewDispatchServer",
[](int port, std::string protocol)
-> std::unique_ptr<tensorflow::data::MasterGrpcDataServer> {
std::unique_ptr<tensorflow::data::MasterGrpcDataServer> server;
-> std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> {
std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
tensorflow::Status status =
tensorflow::data::NewMasterServer(port, protocol, &server);
tensorflow::data::NewDispatchServer(port, protocol, &server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},
@ -61,12 +62,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
m.def(
"TF_DATA_NewWorkerServer",
[](int port, std::string protocol, std::string master_address,
[](int port, std::string protocol, std::string dispatcher_address,
std::string worker_address)
-> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
tensorflow::Status status = tensorflow::data::NewWorkerServer(
port, protocol, master_address, worker_address, &server);
port, protocol, dispatcher_address, worker_address, &server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},
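
Note: these bindings are private plumbing behind the Python servers. A hedged sketch of calling the renamed entry points directly (the _pywrap_server_lib import path is an assumption here; ordinary code should go through server_lib instead):

from tensorflow.python.data.experimental.service import _pywrap_server_lib

# TF_DATA_NewDispatchServer(port, protocol) replaces TF_DATA_NewMasterServer.
dispatcher = _pywrap_server_lib.TF_DATA_NewDispatchServer(0, "grpc")
dispatcher.start()

# TF_DATA_NewWorkerServer's third argument is now the dispatcher address.
worker = _pywrap_server_lib.TF_DATA_NewWorkerServer(
    0, "grpc", "localhost:%d" % dispatcher.bound_port(),
    "localhost")  # worker_address value is a placeholder for illustration
worker.start()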

View File

@ -59,23 +59,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
num_workers: The number of workers in the cluster.
Returns:
The address of the master.
The address of the dispatcher.
"""
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._servers = []
for _ in range(num_workers):
self._servers.append(
server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL))
port=0,
dispatcher_address=self._dispatcher._address,
protocol=PROTOCOL))
return self._master._address
return self._dispatcher._address
@combinations.generate(test_base.eager_only_combinations())
def testDistributeBasic(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
results = [elem.numpy() for elem in ds]
self.assertEqual(list(range(num_elements)), results)
@ -83,10 +85,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testDifferentShuffleOrders(self):
random_seed.set_random_seed(None)
num_elements = 100
master_address = self.create_cluster(2)
dispatcher_address = self.create_cluster(2)
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.shuffle(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
output = [elem.numpy() for elem in ds]
# The output will be two sequences of range(num_elements)
@ -104,9 +106,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testMultipleEpochs(self):
num_elements = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
for _ in range(10):
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
@ -114,9 +116,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testRepeatedDataset(self):
num_elements = 10
num_repetitions = 5
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
ds = ds.repeat(num_repetitions)
self.assertDatasetProduces(
ds, expected_output=num_repetitions * list(range(num_elements)))
@ -125,12 +127,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testConcurrentEpoch(self):
num_elements = 10
num_datasets = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
iterators = []
results = []
for _ in range(num_datasets):
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
iterators.append(iter(ds))
results.append([])
@ -146,9 +148,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.skipTest("Not yet implemented")
num_elements = 10
num_iterators = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
result = []
iterators = []
for _ in range(num_iterators):
@ -170,20 +172,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testMultiWorker(self):
num_workers = 3
num_elements = 10
master_address = self.create_cluster(num_workers)
dispatcher_address = self.create_cluster(num_workers)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
results = [elem.numpy() for elem in ds]
self.assertCountEqual(num_workers * list(range(num_elements)), results)
@combinations.generate(test_base.eager_only_combinations())
def testAddWorkerMidJob(self):
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master._address)
ds = _make_distributed_dataset(ds, self._dispatcher._address)
iterator = iter(ds)
results = []
# Read halfway through the dataset.
@ -191,10 +193,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
results.append(next(iterator).numpy())
self._new_worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
# Wait for the new worker to register with the master.
while self._master._num_workers() < 2:
# Wait for the new worker to register with the dispatcher.
while self._dispatcher._num_workers() < 2:
time.sleep(10 / 1000) # 10ms
for elem in iterator:
@ -206,12 +208,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
combinations.times(test_base.eager_only_combinations(),
combinations.combine(use_same_port=[True, False])))
def testRestartWorker(self, use_same_port):
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master._address)
ds = _make_distributed_dataset(ds, self._dispatcher._address)
iterator = iter(ds)
# Read halfway through the dataset.
midpoint = num_elements // 2
@ -224,7 +226,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
port = int(self._worker._address.split(":")[1])
self._worker._stop()
self._new_worker = server_lib.WorkerServer(
port=port, master_address=self._master._address, protocol=PROTOCOL)
port=port,
dispatcher_address=self._dispatcher._address,
protocol=PROTOCOL)
# There may have been some elements prefetched from the first worker
# before it was stopped.
@ -259,12 +263,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testInsideFunction(self):
num_workers = 3
num_elements = 10
master_address = self.create_cluster(num_workers)
dispatcher_address = self.create_cluster(num_workers)
@def_function.function
def f():
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
result = tensor_array_ops.TensorArray(
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
i = 0
@ -279,10 +283,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testSharedJobName(self):
num_elements = 100
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
iter1 = iter(ds1)
iter2 = iter(ds2)
results = []
@ -298,20 +302,22 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testDifferentJobNames(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2")
ds1 = _make_distributed_dataset(
ds, dispatcher_address, job_name="job_name1")
ds2 = _make_distributed_dataset(
ds, dispatcher_address, job_name="job_name2")
self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, list(range(num_elements)))
@combinations.generate(test_base.eager_only_combinations())
def testSharedJobNameMultiIteration(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
# iteration 1
self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, [])
@ -323,11 +329,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testSharedJobNameRepeat(self):
num_elements = 100
num_repetitions = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds1 = ds1.repeat(num_repetitions)
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = ds2.repeat(num_repetitions)
results = []
iter1 = iter(ds1)
@ -345,7 +351,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testApplyDeterminismOption(self):
elements = list(range(10))
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
def dataset_fn(delay_ms):
@ -362,7 +368,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
opts = dataset_ops.Options()
opts.experimental_deterministic = False
ds = ds.with_options(opts)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
return ds
self.checkDeterminism(
@ -379,8 +385,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_external_state_policy = external_state_policy
ds = ds.with_options(options)
master_address = self.create_cluster(3)
ds = _make_distributed_dataset(ds, master_address)
dispatcher_address = self.create_cluster(3)
ds = _make_distributed_dataset(ds, dispatcher_address)
next(iter(ds))
@combinations.generate(
@ -400,12 +406,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testDistributeFromInterleave(self):
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(2)
def interleave_fn(_):
ds = dataset_ops.Dataset.range(2)
_make_distributed_dataset(ds, master_address)
_make_distributed_dataset(ds, dispatcher_address)
return ds
with self.assertRaisesRegex(
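
Note: the _make_distributed_dataset helper used throughout this test is defined outside the hunks shown here. A rough sketch of such a helper in terms of the tf.data service distribute transformation (the real helper may pass extra tuning arguments):

from tensorflow.python.data.experimental.ops import data_service_ops


def _make_distributed_dataset(dataset, dispatcher_address, job_name=None):
  # Reads `dataset` through the tf.data service running at the dispatcher
  # (formerly "master") address returned by create_cluster() above.
  return dataset.apply(
      data_service_ops.distribute(
          processing_mode="parallel_epochs",
          service="grpc://" + dispatcher_address,
          job_name=job_name))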

View File

@ -1,6 +1,6 @@
path: "tensorflow.data.experimental.service.MasterServer"
path: "tensorflow.data.experimental.service.DispatchServer"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.MasterServer\'>"
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatchServer\'>"
is_instance: "<type \'object\'>"
member {
name: "target"

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'port\', \'master_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
argspec: "args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
}
member_method {
name: "join"

View File

@ -1,7 +1,7 @@
path: "tensorflow.data.experimental.service"
tf_module {
member {
name: "MasterServer"
name: "DispatchServer"
mtype: "<type \'type\'>"
}
member {
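
Note: taken together, these goldens document the renamed public surface under tf.data.experimental.service. A short end-to-end sketch against that public API (assuming a TensorFlow build that includes this change):

import tensorflow as tf

# DispatchServer replaces MasterServer in the public tf.data service API.
dispatcher = tf.data.experimental.service.DispatchServer(port=0)
worker = tf.data.experimental.service.WorkerServer(
    port=0, dispatcher_address=dispatcher.target.split("://")[1])

dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=dispatcher.target))
print(list(dataset.as_numpy_iterator()))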

View File

@ -99,8 +99,8 @@ tensorflow::data::GrpcDataServerBase::Join
tensorflow::data::GrpcDataServerBase::Start
tensorflow::data::GrpcDataServerBase::Stop
tensorflow::data::GrpcDataServerBase::BoundPort
tensorflow::data::MasterGrpcDataServer::NumWorkers
tensorflow::data::NewMasterServer
tensorflow::data::DispatchGrpcDataServer::NumWorkers
tensorflow::data::NewDispatchServer
tensorflow::data::NewWorkerServer
[protos_all] # device_lib, dtypes