Update "master" to "dispatch"/"dispatcher" in tf.data service terminology.
Dispatcher is more descriptive and follows the guidance in https://developers.google.com/style/word-list#master

PiperOrigin-RevId: 321613785
Change-Id: Iaa576d35f0581e21278101f8b31201ba737a6865
parent 0aa1d61fad
commit 6b9a9d98bb
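For reference, a minimal sketch of the renamed user-facing Python API, assembled from the doctest examples updated in this change (assumes a TensorFlow build that already includes this commit):

```
import tensorflow as tf

# Start an in-process dispatch server (formerly MasterServer) on a free port.
dispatcher = tf.data.experimental.service.DispatchServer(port=0)
dispatcher_address = dispatcher.target.split("://")[1]

# Workers now take `dispatcher_address` instead of `master_address`.
worker = tf.data.experimental.service.WorkerServer(
    port=0, dispatcher_address=dispatcher_address)

# Datasets are distributed by pointing `service` at the dispatcher's target.
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs", service=dispatcher.target))
print(list(dataset.as_numpy_iterator()))  # [0, 1, ..., 9]
```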
@@ -28,8 +28,8 @@ tf_proto_library(
 )
 
 tf_proto_library(
-    name = "master_proto",
-    srcs = ["master.proto"],
+    name = "dispatcher_proto",
+    srcs = ["dispatcher.proto"],
     has_services = 1,
     cc_api_version = 2,
     protodeps = tf_additional_all_protos() + [
@@ -49,17 +49,17 @@ tf_proto_library(
 )
 
 cc_library(
-    name = "master_impl",
-    srcs = ["master_impl.cc"],
+    name = "dispatcher_impl",
+    srcs = ["dispatcher_impl.cc"],
     hdrs = [
-        "master_impl.h",
+        "dispatcher_impl.h",
     ],
     deps = [
        ":common_proto_cc",
        ":credentials_factory",
        ":data_service",
+        ":dispatcher_proto_cc",
        ":grpc_util",
-        ":master_proto_cc",
        ":worker_cc_grpc_proto",
        ":worker_proto_cc",
        "//tensorflow/c:c_api_internal",
@@ -86,9 +86,9 @@ cc_library(
     deps = [
        ":common_proto_cc",
        ":credentials_factory",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
        ":grpc_util",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
        ":worker_proto_cc",
        "//tensorflow/c:c_api_internal",
        "//tensorflow/c:tf_status_helper",
@@ -207,12 +207,12 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "grpc_master_impl",
-    srcs = ["grpc_master_impl.cc"],
-    hdrs = ["grpc_master_impl.h"],
+    name = "grpc_dispatcher_impl",
+    srcs = ["grpc_dispatcher_impl.cc"],
+    hdrs = ["grpc_dispatcher_impl.h"],
     deps = [
-        ":master_cc_grpc_proto",
-        ":master_impl",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_impl",
        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
        tf_grpc_cc_dependency(),
    ],
@@ -250,7 +250,7 @@ cc_library(
    ],
    deps = [
        ":credentials_factory",
-        ":grpc_master_impl",
+        ":grpc_dispatcher_impl",
        ":grpc_util",
        ":grpc_worker_impl",
        "//tensorflow/core:lib",
@@ -268,9 +268,9 @@ cc_library(
    ],
    deps = [
        ":credentials_factory",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
        ":grpc_util",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
        ":worker_cc_grpc_proto",
        ":worker_proto_cc",
        "//tensorflow/core:framework",
@@ -287,12 +287,12 @@ tf_cc_test(
    tags = ["no_windows"],
    deps = [
        ":data_service",
-        ":grpc_master_impl",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
+        ":grpc_dispatcher_impl",
        ":grpc_util",
        ":grpc_worker_impl",
        ":local_credentials_factory",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
        ":server_lib",
        ":test_cluster",
        ":test_util",
@@ -309,11 +309,11 @@ tf_cc_test(
 )
 
 cc_grpc_library(
-    name = "master_cc_grpc_proto",
-    srcs = [":master_proto"],
+    name = "dispatcher_cc_grpc_proto",
+    srcs = [":dispatcher_proto"],
     generate_mocks = True,
     grpc_only = True,
-    deps = [":master_proto_cc"],
+    deps = [":dispatcher_proto_cc"],
 )
 
 cc_grpc_library(
@@ -18,8 +18,8 @@ limitations under the License.
 #include "grpcpp/create_channel.h"
 #include "grpcpp/security/credentials.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
@@ -54,8 +54,8 @@ std::string ProcessingModeToString(ProcessingMode mode) {
   }
 }
 
-Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
-                                                int64* dataset_id) {
+Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
+                                                    int64* dataset_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetOrRegisterDatasetRequest req;
   *req.mutable_dataset()->mutable_graph() = dataset;
@@ -69,9 +69,9 @@ Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::CreateJob(int64 dataset_id,
-                                          ProcessingMode processing_mode,
-                                          int64* job_id) {
+Status DataServiceDispatcherClient::CreateJob(int64 dataset_id,
+                                              ProcessingMode processing_mode,
+                                              int64* job_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   CreateJobRequest req;
   req.set_dataset_id(dataset_id);
@@ -88,11 +88,9 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
-                                               ProcessingMode processing_mode,
-                                               const std::string& job_name,
-                                               int job_name_index,
-                                               int64* job_id) {
+Status DataServiceDispatcherClient::GetOrCreateJob(
+    int64 dataset_id, ProcessingMode processing_mode,
+    const std::string& job_name, int job_name_index, int64* job_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetOrCreateJobRequest req;
   req.set_dataset_id(dataset_id);
@@ -112,9 +110,9 @@ Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::GetTasks(int64 job_id,
-                                         std::vector<TaskInfo>* tasks,
-                                         bool* job_finished) {
+Status DataServiceDispatcherClient::GetTasks(int64 job_id,
+                                             std::vector<TaskInfo>* tasks,
+                                             bool* job_finished) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetTasksRequest req;
   req.set_job_id(job_id);
@@ -132,7 +130,8 @@ Status DataServiceMasterClient::GetTasks(int64 job_id,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
+Status DataServiceDispatcherClient::GetWorkers(
+    std::vector<WorkerInfo>* workers) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetWorkersRequest req;
   GetWorkersResponse resp;
@@ -148,12 +147,12 @@ Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
   return Status::OK();
 }
 
-Status DataServiceMasterClient::EnsureInitialized() {
+Status DataServiceDispatcherClient::EnsureInitialized() {
   std::shared_ptr<grpc::ChannelCredentials> credentials;
   TF_RETURN_IF_ERROR(
       CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
   auto channel = grpc::CreateChannel(address_, credentials);
-  stub_ = MasterService::NewStub(channel);
+  stub_ = DispatcherService::NewStub(channel);
   return Status::OK();
 }
 
@@ -187,10 +186,11 @@ Status DataServiceWorkerClient::EnsureInitialized() {
   return Status::OK();
 }
 
-Status CreateDataServiceMasterClient(
+Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceMasterClient>* out) {
-  auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
+    std::unique_ptr<DataServiceDispatcherClient>* out) {
+  auto client =
+      absl::make_unique<DataServiceDispatcherClient>(address, protocol);
   TF_RETURN_IF_ERROR(client->Initialize());
   *out = std::move(client);
   return Status::OK();
@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 
-#include "tensorflow/core/data/service/master.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -67,11 +67,11 @@ class DataServiceClientBase {
   const std::string protocol_;
 };
 
-// Client for communicating with the tf.data service master.
-class DataServiceMasterClient : public DataServiceClientBase {
+// Client for communicating with the tf.data service dispatcher.
+class DataServiceDispatcherClient : public DataServiceClientBase {
  public:
-  DataServiceMasterClient(const std::string& address,
-                          const std::string& protocol)
+  DataServiceDispatcherClient(const std::string& address,
+                              const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}
 
   // Registers a dataset with the tf.data service, and stores the generated
@@ -90,13 +90,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
                         const std::string& job_name, int job_name_index,
                         int64* job_id);
 
-  // Queries the master for the tasks associated with the specified job.
+  // Queries the dispatcher for the tasks associated with the specified job.
   // The tasks will be stored in *tasks, and whether the job is finished will
   // be stored in `*job_finished`.
   Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
                   bool* job_finished);
 
-  // Queries the master for its registered workers. The worker info will be
+  // Queries the dispatcher for its registered workers. The worker info will be
   // stored in `*workers`.
   Status GetWorkers(std::vector<WorkerInfo>* workers);
 
@@ -104,7 +104,7 @@ class DataServiceMasterClient : public DataServiceClientBase {
   Status EnsureInitialized() override;
 
  private:
-  std::unique_ptr<MasterService::Stub> stub_;
+  std::unique_ptr<DispatcherService::Stub> stub_;
 };
 
 // Client for communicating with the tf.data service worker.
@@ -127,10 +127,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
   std::unique_ptr<WorkerService::Stub> stub_;
 };
 
-// Creates and initializes a new tf.data service master client.
-Status CreateDataServiceMasterClient(
+// Creates and initializes a new tf.data service dispatcher client.
+Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceMasterClient>* out);
+    std::unique_ptr<DataServiceDispatcherClient>* out);
 
 // Creates and initializes a new tf.data service worker client.
 Status CreateDataServiceWorkerClient(
@@ -19,9 +19,9 @@ limitations under the License.
 #include "grpcpp/security/credentials.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/core/data/compression_utils.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/server_lib.h"
 #include "tensorflow/core/data/service/test_cluster.h"
 #include "tensorflow/core/data/service/test_util.h"
@@ -66,9 +66,10 @@ TEST(DataService, ProcessingModeToString) {
 TEST(DataService, GetWorkers) {
   TestCluster cluster(1);
   TF_ASSERT_OK(cluster.Initialize());
-  DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
+  DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(),
+                                         kProtocol);
   std::vector<WorkerInfo> workers;
-  TF_EXPECT_OK(master.GetWorkers(&workers));
+  TF_EXPECT_OK(dispatcher.GetWorkers(&workers));
   EXPECT_EQ(1, workers.size());
 }
 
@@ -110,11 +110,11 @@ message GetWorkersResponse {
   repeated WorkerInfo workers = 1;
 }
 
-service MasterService {
-  // Registers a worker with the master.
+service DispatcherService {
+  // Registers a worker with the dispatcher.
   rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
 
-  // Updates the master with information about the worker's state.
+  // Updates the dispatcher with information about the worker's state.
   rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
 
   // Registers a dataset with the server, or returns its id if it is already
@@ -134,6 +134,6 @@ service MasterService {
   // Reports a list of all tasks for a job.
   rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
 
-  // Reports a list of all workers registered with the master.
+  // Reports a list of all workers registered with the dispatcher.
   rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse);
 }
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/data/service/master_impl.h"
+#include "tensorflow/core/data/service/dispatcher_impl.h"
 
 #include <memory>
 #include <tuple>
@@ -26,8 +26,8 @@ limitations under the License.
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
 #include "tensorflow/core/data/service/data_service.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
@@ -53,10 +53,10 @@ Status CreateWorkerStub(const std::string& address,
 }
 }  // namespace
 
-DataServiceMasterImpl::DataServiceMasterImpl(const std::string protocol)
+DataServiceDispatcherImpl::DataServiceDispatcherImpl(const std::string protocol)
     : protocol_(protocol) {}
 
-Status DataServiceMasterImpl::RegisterWorker(
+Status DataServiceDispatcherImpl::RegisterWorker(
     const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
   VLOG(3) << "Received register worker request";
   mutex_lock l(mu_);
@@ -86,8 +86,8 @@ Status DataServiceMasterImpl::RegisterWorker(
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
-                                           WorkerUpdateResponse* response) {
+Status DataServiceDispatcherImpl::WorkerUpdate(
+    const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
   mutex_lock l(mu_);
   int64 worker_id = request->worker_id();
   for (auto& update : request->updates()) {
@@ -106,7 +106,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetOrRegisterDataset(
+Status DataServiceDispatcherImpl::GetOrRegisterDataset(
     const GetOrRegisterDatasetRequest* request,
     GetOrRegisterDatasetResponse* response) {
   uint64 fingerprint;
@@ -128,8 +128,8 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
   return Status::OK();
 }
 
-int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
-                                             const DatasetDef& dataset)
+int64 DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
+                                                 const DatasetDef& dataset)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   int64 dataset_id = next_dataset_id_++;
   auto new_dataset =
@@ -142,8 +142,8 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
   return dataset_id;
 }
 
-Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
-                                        CreateJobResponse* response) {
+Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request,
+                                            CreateJobResponse* response) {
   VLOG(3) << "Received create job request for dataset id "
           << request->dataset_id();
   ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
@@ -157,7 +157,7 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetOrCreateJob(
+Status DataServiceDispatcherImpl::GetOrCreateJob(
     const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
   VLOG(3) << "Received get or create job request for dataset id "
           << request->dataset_id() << " with name " << request->job_name()
@@ -193,7 +193,7 @@ Status DataServiceMasterImpl::GetOrCreateJob(
 }
 
 // Validates that the job matches the given processing_mode and dataset_id.
-Status DataServiceMasterImpl::ValidateMatchingJob(
+Status DataServiceDispatcherImpl::ValidateMatchingJob(
     const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
   DCHECK(job.name().has_value());
   std::string job_name = job.name().value();
@@ -214,10 +214,10 @@ Status DataServiceMasterImpl::ValidateMatchingJob(
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
-                                        ProcessingMode processing_mode,
-                                        absl::optional<std::string> job_name,
-                                        int64* out_job_id) LOCKS_EXCLUDED(mu_) {
+Status DataServiceDispatcherImpl::CreateJob(
+    int64 dataset_id, ProcessingMode processing_mode,
+    absl::optional<std::string> job_name, int64* out_job_id)
+    LOCKS_EXCLUDED(mu_) {
   switch (processing_mode) {
     case ProcessingMode::PARALLEL_EPOCHS:
       break;
@@ -274,14 +274,16 @@ Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTask(
+const DataServiceDispatcherImpl::Task& DataServiceDispatcherImpl::CreateTask(
     Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
   mutex_lock l(mu_);
   return CreateTaskLocked(job, worker_address);
 }
 
-const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
-    Job* job, const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+const DataServiceDispatcherImpl::Task&
+DataServiceDispatcherImpl::CreateTaskLocked(Job* job,
+                                            const std::string& worker_address)
+    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   int64 task_id = next_task_id_++;
   DCHECK(!tasks_.contains(task_id));
   tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
@@ -290,7 +292,7 @@ const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
   return tasks_.at(task_id);
 }
 
-Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
+Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
   if (!worker->stub()) {
     std::unique_ptr<WorkerService::Stub> stub;
     TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
@@ -299,8 +301,8 @@ Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
-                                                   Worker* worker)
+Status DataServiceDispatcherImpl::AllocateTaskToWorker(const Task& task,
+                                                       Worker* worker)
     LOCKS_EXCLUDED(mu_) {
   TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
   grpc::ClientContext client_ctx;
@@ -322,8 +324,8 @@ Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
-                                       GetTasksResponse* response) {
+Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
+                                           GetTasksResponse* response) {
   mutex_lock l(mu_);
   VLOG(3) << "Looking up tasks for job id " << request->job_id();
   auto it = jobs_.find(request->job_id());
@@ -346,8 +348,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request,
-                                         GetWorkersResponse* response) {
+Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
+                                             GetWorkersResponse* response) {
   mutex_lock l(mu_);
   VLOG(3) << "Enter GetWorkers";
   for (auto& worker : workers_) {
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
-#define TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/data_service.h"
-#include "tensorflow/core/data/service/master.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -40,11 +40,11 @@ namespace data {
 //     ProcessingModeDef which determines what data it produces.
 // * Task: A job is broken into multiple tasks, which each represent
 //     iterating over all of or part of the dataset. Workers process tasks.
-class DataServiceMasterImpl {
+class DataServiceDispatcherImpl {
  public:
-  explicit DataServiceMasterImpl(const std::string protocol);
+  explicit DataServiceDispatcherImpl(const std::string protocol);
 
-  // See master.proto for API documentation.
+  // See dispatcher.proto for API documentation.
 
   /// Worker-facing API.
   Status RegisterWorker(const RegisterWorkerRequest* request,
@@ -191,7 +191,7 @@ class DataServiceMasterImpl {
   // Creates a new task for a job, returning a reference to the task.
   const Task& CreateTask(Job* job, const std::string& worker_address)
       LOCKS_EXCLUDED(mu_);
-  // Same as `CreateTask`, but expects that the master lock is already held.
+  // Same as `CreateTask`, but expects that the dispatcher lock is already held.
   const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Validates that an existing job matches the given processing_mode and
@@ -225,10 +225,10 @@ class DataServiceMasterImpl {
   absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
       TF_GUARDED_BY(mu_);
 
-  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
+  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
 };
 
 }  // namespace data
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
+#endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/data/service/grpc_master_impl.h"
+#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
 
 #include "grpcpp/server_context.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -25,18 +25,18 @@ using ::grpc::ServerBuilder;
 using ::grpc::ServerContext;
 using ::grpc::Status;
 
-GrpcMasterImpl::GrpcMasterImpl(ServerBuilder* server_builder,
-                               const std::string& protocol)
+GrpcDispatcherImpl::GrpcDispatcherImpl(ServerBuilder* server_builder,
+                                       const std::string& protocol)
     : impl_(protocol) {
   server_builder->RegisterService(this);
-  VLOG(1) << "Registered data service master";
+  VLOG(1) << "Registered data service dispatcher";
 }
 
-#define HANDLER(method) \
-  Status GrpcMasterImpl::method(ServerContext* context, \
-                                const method##Request* request, \
-                                method##Response* response) { \
-    return ToGrpcStatus(impl_.method(request, response)); \
+#define HANDLER(method) \
+  Status GrpcDispatcherImpl::method(ServerContext* context, \
+                                    const method##Request* request, \
+                                    method##Response* response) { \
+    return ToGrpcStatus(impl_.method(request, response)); \
   }
 HANDLER(RegisterWorker);
 HANDLER(WorkerUpdate);
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
-#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
 
 #include "grpcpp/server_builder.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master_impl.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher_impl.h"
 
 namespace tensorflow {
 namespace data {
@@ -29,14 +29,14 @@ namespace data {
 //
 // ::grpc::ServerBuilder builder;
 // // configure builder
-// GrpcMasterImpl data_service(&builder);
+// GrpcDispatcherImpl data_service(&builder);
 // builder.BuildAndStart()
 //
-class GrpcMasterImpl : public MasterService::Service {
+class GrpcDispatcherImpl : public DispatcherService::Service {
  public:
-  explicit GrpcMasterImpl(grpc::ServerBuilder* server_builder,
-                          const std::string& protocol);
-  ~GrpcMasterImpl() override {}
+  explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
+                              const std::string& protocol);
+  ~GrpcDispatcherImpl() override {}
 
 #define HANDLER(method) \
   grpc::Status method(grpc::ServerContext* context, \
@@ -52,12 +52,12 @@ class GrpcMasterImpl : public MasterService::Service {
 #undef HANDLER
 
  private:
-  DataServiceMasterImpl impl_;
+  DataServiceDispatcherImpl impl_;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterImpl);
+  TF_DISALLOW_COPY_AND_ASSIGN(GrpcDispatcherImpl);
 };
 
 }  // namespace data
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
+#endif  // TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
@@ -26,9 +26,9 @@ using ::grpc::ServerContext;
 using ::grpc::Status;
 
 GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
-                               const std::string& master_address,
+                               const std::string& dispatcher_address,
                                const std::string& protocol)
-    : impl_(master_address, protocol) {
+    : impl_(dispatcher_address, protocol) {
   server_builder->RegisterService(this);
   VLOG(1) << "Registered data service worker";
 }
@@ -35,7 +35,7 @@ namespace data {
 class GrpcWorkerImpl : public WorkerService::Service {
  public:
   explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
-                          const std::string& master_address,
+                          const std::string& dispatcher_address,
                           const std::string& protocol);
   ~GrpcWorkerImpl() override {}
 
@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/server_lib.h"
 
 #include "tensorflow/core/data/service/credentials_factory.h"
-#include "tensorflow/core/data/service/grpc_master_impl.h"
+#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/grpc_worker_impl.h"
 #include "tensorflow/core/platform/errors.h"
@@ -72,18 +72,18 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
 
 int GrpcDataServerBase::BoundPort() { return bound_port(); }
 
-MasterGrpcDataServer::MasterGrpcDataServer(int port,
-                                           const std::string& protocol)
+DispatchGrpcDataServer::DispatchGrpcDataServer(int port,
+                                               const std::string& protocol)
     : GrpcDataServerBase(port, protocol) {}
 
-MasterGrpcDataServer::~MasterGrpcDataServer() { delete service_; }
+DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
 
-void MasterGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
-  auto service = absl::make_unique<GrpcMasterImpl>(builder, protocol_);
+void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
+  auto service = absl::make_unique<GrpcDispatcherImpl>(builder, protocol_);
   service_ = service.release();
 }
 
-Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
+Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
   GetWorkersRequest req;
   GetWorkersResponse resp;
   grpc::ServerContext ctx;
@@ -95,19 +95,18 @@ Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
   return Status::OK();
 }
 
-WorkerGrpcDataServer::WorkerGrpcDataServer(int port,
-                                           const std::string& protocol,
-                                           const std::string& master_address,
-                                           const std::string& worker_address)
+WorkerGrpcDataServer::WorkerGrpcDataServer(
+    int port, const std::string& protocol,
+    const std::string& dispatcher_address, const std::string& worker_address)
     : GrpcDataServerBase(port, protocol),
-      master_address_(master_address),
+      dispatcher_address_(dispatcher_address),
       worker_address_(worker_address) {}
 
 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
 
 void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
-  auto service =
-      absl::make_unique<GrpcWorkerImpl>(builder, master_address_, protocol_);
+  auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
+                                                   protocol_);
   service_ = service.release();
 }
 
@@ -123,25 +122,25 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
   return Status::OK();
 }
 
-Status NewMasterServer(int port, const std::string& protocol,
-                       std::unique_ptr<MasterGrpcDataServer>* out_server) {
-  *out_server = absl::make_unique<MasterGrpcDataServer>(port, protocol);
+Status NewDispatchServer(int port, const std::string& protocol,
+                         std::unique_ptr<DispatchGrpcDataServer>* out_server) {
+  *out_server = absl::make_unique<DispatchGrpcDataServer>(port, protocol);
   return Status::OK();
 }
 
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server) {
-  return NewWorkerServer(port, protocol, master_address, /*worker_address=*/"",
-                         out_server);
+  return NewWorkerServer(port, protocol, dispatcher_address,
+                         /*worker_address=*/"", out_server);
 }
 
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server) {
   *out_server = absl::make_unique<WorkerGrpcDataServer>(
-      port, protocol, master_address, worker_address);
+      port, protocol, dispatcher_address, worker_address);
   return Status::OK();
 }
 
@@ -25,7 +25,7 @@ namespace data {
 
 // Forward declared because transitively depending on .grpc.pb.h files causes
 // issues in the pywrap build.
-class GrpcMasterImpl;
+class GrpcDispatcherImpl;
 class GrpcWorkerImpl;
 
 // A grpc server for the tf.data service.
@@ -35,7 +35,7 @@ class GrpcDataServerBase {
   // server will find an available port in `Start()`. The chosen port can be
   // found in the output of `Target()`.
   //
-  // master_address is only needed for worker data servers.
+  // dispatcher_address is only needed for worker data servers.
   GrpcDataServerBase(int requested_port, const std::string& protocol);
   virtual ~GrpcDataServerBase() {}
 
@@ -70,12 +70,12 @@ class GrpcDataServerBase {
   std::unique_ptr<grpc::Server> server_;
 };
 
-class MasterGrpcDataServer : public GrpcDataServerBase {
+class DispatchGrpcDataServer : public GrpcDataServerBase {
  public:
-  MasterGrpcDataServer(int requested_port, const std::string& protocol);
-  ~MasterGrpcDataServer() override;
+  DispatchGrpcDataServer(int requested_port, const std::string& protocol);
+  ~DispatchGrpcDataServer() override;
 
-  // Returns the number of workers registerd with the master.
+  // Returns the number of workers registerd with the dispatcher.
   Status NumWorkers(int* num_workers);
 
  protected:
@@ -83,14 +83,14 @@ class MasterGrpcDataServer : public GrpcDataServerBase {
   Status StartServiceInternal() override { return Status::OK(); }
 
  private:
-  // Owned. We use a raw pointer because GrpcMasterImpl is forward-declared.
-  GrpcMasterImpl* service_;
+  // Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
+  GrpcDispatcherImpl* service_;
 };
 
 class WorkerGrpcDataServer : public GrpcDataServerBase {
  public:
   WorkerGrpcDataServer(int requested_port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address);
   ~WorkerGrpcDataServer() override;
 
@@ -99,15 +99,15 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
   Status StartServiceInternal() override;
 
  private:
-  const std::string master_address_;
+  const std::string dispatcher_address_;
   const std::string worker_address_;
   // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
   GrpcWorkerImpl* service_;
 };
 
-// Creates a master tf.data server and stores it in `*out_server`.
-Status NewMasterServer(int port, const std::string& protocol,
-                       std::unique_ptr<MasterGrpcDataServer>* out_server);
+// Creates a dispatch tf.data server and stores it in `*out_server`.
+Status NewDispatchServer(int port, const std::string& protocol,
+                         std::unique_ptr<DispatchGrpcDataServer>* out_server);
 
 // Creates a worker tf.data server and stores it in `*out_server`.
 //
@@ -115,18 +115,18 @@ Status NewMasterServer(int port, const std::string& protocol,
 // will be chosen in Start(). This value can be queried with BoundPort().
 //
 // The worker_address argument is optional. If left empty, it will default to
-// "localhost:%port%". When the worker registers with the master, the worker
-// will report the worker address, so that the master can tell clients where to
-// read from. The address may contain the placeholder "%port%", which will be
+// "localhost:%port%". When the worker registers with the dispatcher, the worker
+// will report the worker address, so that the dispatcher can tell clients where
+// to read from. The address may contain the placeholder "%port%", which will be
 // replaced with the value of BoundPort().
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server);
 
 // Creates a worker using the default worker_address.
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server);
 
 }  // namespace data
@@ -45,9 +45,9 @@ Status TestCluster::Initialize() {
         "Test cluster has already been initialized.");
   }
   initialized_ = true;
-  TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_));
-  TF_RETURN_IF_ERROR(master_->Start());
-  master_address_ = absl::StrCat("localhost:", master_->BoundPort());
+  TF_RETURN_IF_ERROR(NewDispatchServer(/*port=*/0, kProtocol, &dispatcher_));
+  TF_RETURN_IF_ERROR(dispatcher_->Start());
+  dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
   workers_.reserve(num_workers_);
   worker_addresses_.reserve(num_workers_);
   for (int i = 0; i < num_workers_; ++i) {
@@ -59,14 +59,14 @@ Status TestCluster::Initialize() {
 Status TestCluster::AddWorker() {
   std::unique_ptr<WorkerGrpcDataServer> worker;
   TF_RETURN_IF_ERROR(
-      NewWorkerServer(/*port=*/0, kProtocol, master_address_, &worker));
+      NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
   TF_RETURN_IF_ERROR(worker->Start());
   worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
   workers_.push_back(std::move(worker));
   return Status::OK();
 }
 
-std::string TestCluster::MasterAddress() { return master_address_; }
+std::string TestCluster::DispatcherAddress() { return dispatcher_address_; }
 
 std::string TestCluster::WorkerAddress(int index) {
   DCHECK_GE(index, 0);
@@ -24,7 +24,7 @@ namespace data {
 // Helper class for unit testing a tf.data service cluster.
 class TestCluster {
  public:
-  // Creates a new test cluster with a master and `num_workers` workers.
+  // Creates a new test cluster with a dispatcher and `num_workers` workers.
   explicit TestCluster(int num_workers);
 
   // Initializes the test cluster. This must be called before interacting with
@@ -32,8 +32,8 @@ class TestCluster {
   Status Initialize();
   // Adds a new worker to the cluster.
   Status AddWorker();
-  // Returns the master address in the form "hostname:port".
-  std::string MasterAddress();
+  // Returns the dispatcher address in the form "hostname:port".
+  std::string DispatcherAddress();
   // Returns the address of the worker at the specified index, in the form
   // "hostname:port". The index must be non-negative and less than the number of
   // workers in the cluster.
@@ -42,8 +42,8 @@ class TestCluster {
 private:
  bool initialized_ = false;
  int num_workers_;
-  std::unique_ptr<MasterGrpcDataServer> master_;
-  std::string master_address_;
+  std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
+  std::string dispatcher_address_;
  std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
  std::vector<std::string> worker_addresses_;
 };
@@ -21,9 +21,9 @@ limitations under the License.
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/data/dataset.pb.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -45,9 +45,9 @@ auto* tf_data_service_created =
         "has been created.");
 }  // namespace
 
-DataServiceWorkerImpl::DataServiceWorkerImpl(const std::string& master_address,
-                                             const std::string& protocol)
-    : master_address_(master_address), protocol_(protocol) {
+DataServiceWorkerImpl::DataServiceWorkerImpl(
+    const std::string& dispatcher_address, const std::string& protocol)
+    : dispatcher_address_(dispatcher_address), protocol_(protocol) {
   tf_data_service_created->GetCell()->Set(true);
 }
 
@@ -67,14 +67,13 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
   heartbeat_thread_.reset(thread);
   Status s = Register();
   while (!s.ok()) {
-    LOG(WARNING) << "Failed to register with master at " << master_address_
-                 << ": " << s;
+    LOG(WARNING) << "Failed to register with dispatcher at "
+                 << dispatcher_address_ << ": " << s;
     Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
     s = Register();
   }
 }
 
-
 Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
                                           ProcessTaskResponse* response) {
   mutex_lock l(mu_);
@@ -169,29 +168,29 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
   return Status::OK();
 }
 
-Status DataServiceWorkerImpl::EnsureMasterStubInitialized()
+Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-  if (!master_stub_) {
+  if (!dispatcher_stub_) {
     ::grpc::ChannelArguments args;
     std::shared_ptr<::grpc::ChannelCredentials> credentials;
    TF_RETURN_IF_ERROR(
        CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
    auto channel =
-        ::grpc::CreateCustomChannel(master_address_, credentials, args);
-    master_stub_ = MasterService::NewStub(channel);
+        ::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
+    dispatcher_stub_ = DispatcherService::NewStub(channel);
  }
  return Status::OK();
 }
 
 Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-  VLOG(3) << "Registering with master at " << master_address_;
-  TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
+  VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
+  TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
  RegisterWorkerRequest req;
  req.set_worker_address(worker_address_);
  RegisterWorkerResponse resp;
 
  grpc::ClientContext ctx;
-  grpc::Status s = master_stub_->RegisterWorker(&ctx, req, &resp);
+  grpc::Status s = dispatcher_stub_->RegisterWorker(&ctx, req, &resp);
  if (!s.ok()) {
    return grpc_util::WrapError("Failed to register worker", s);
  }
@@ -205,8 +204,8 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 
 Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   VLOG(3) << "Sending " << pending_completed_tasks_.size()
-          << " task updates to master";
-  TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
+          << " task updates to dispatcher";
+  TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
   WorkerUpdateRequest req;
   req.set_worker_id(worker_id_);
   for (int task_id : pending_completed_tasks_) {
@@ -217,7 +216,7 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 
   WorkerUpdateResponse resp;
   grpc::ClientContext ctx;
-  grpc::Status s = master_stub_->WorkerUpdate(&ctx, req, &resp);
+  grpc::Status s = dispatcher_stub_->WorkerUpdate(&ctx, req, &resp);
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to send task updates", s);
   }
@@ -238,7 +237,7 @@ void DataServiceWorkerImpl::HeartbeatThread() {
     }
     Status s = SendTaskUpdate();
    if (!s.ok()) {
-      LOG(WARNING) << "Failed to send task updates to master: " << s;
+      LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
    }
  }
 }
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.pb.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -29,17 +29,17 @@ namespace data {
 // A TensorFlow DataService serves dataset elements over RPC.
 class DataServiceWorkerImpl {
  public:
-  explicit DataServiceWorkerImpl(const std::string& master_address,
+  explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
                                  const std::string& protocol);
   ~DataServiceWorkerImpl();
 
   // Starts the worker. The worker needs to know its own address so that it can
-  // register with the master.
+  // register with the dispatcher.
   void Start(const std::string& worker_address);
 
   // See worker.proto for API documentation.
 
-  /// Master-facing API.
+  /// Dispatcher-facing API.
   Status ProcessTask(const ProcessTaskRequest* request,
                      ProcessTaskResponse* response);
 
@@ -48,15 +48,15 @@ class DataServiceWorkerImpl {
                     GetElementResponse* response);
 
  private:
-  // Sets master_stub_ if it isn't already set.
-  Status EnsureMasterStubInitialized();
-  // Registers the worker with the master.
+  // Sets dispatcher_stub_ if it isn't already set.
+  Status EnsureDispatcherStubInitialized();
+  // Registers the worker with the dispatcher.
   Status Register();
-  // Sends task status to the master.
+  // Sends task status to the dispatcher.
   Status SendTaskUpdate();
   // Creates an iterator to process a task.
   Status ProcessTaskInternal(const TaskDef& task);
-  // A thread for updating the master with worker status.
+  // A thread for updating the dispatcher with worker status.
   void HeartbeatThread();
 
   typedef struct Task {
@@ -67,18 +67,19 @@ class DataServiceWorkerImpl {
     std::unique_ptr<standalone::Iterator> iterator;
   } Task;
 
-  const std::string master_address_;
-  // Protocol for communicating with the master.
+  const std::string dispatcher_address_;
+  // Protocol for communicating with the dispatcher.
   const std::string protocol_;
   // The worker's own address.
   std::string worker_address_;
 
   mutex mu_;
   int64 worker_id_ TF_GUARDED_BY(mu_);
-  std::unique_ptr<MasterService::Stub> master_stub_ TF_GUARDED_BY(mu_);
+  std::unique_ptr<DispatcherService::Stub> dispatcher_stub_ TF_GUARDED_BY(mu_);
   // Information about tasks, keyed by task ids.
   absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
-  // List of completed tasks which haven't yet been communicated to the master.
+  // List of completed tasks which haven't yet been communicated to the
+  // dispatcher.
   std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
   // Condition variable for notifying the heartbeat thread.
@@ -69,7 +69,7 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
 // Dataset for reading data from the tf.data service non-deterministically.
 //
 // This dataset interleaves dataset elements produced by multiple tf.data
-// workers. We periodically query the tf.data master to determine which workers
+// workers. We periodically query the dispatcher to determine which workers
 // to read from (in case workers are added or removed).
 class DataServiceDatasetOp::Dataset : public DatasetBase {
  public:
@@ -199,12 +199,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       VLOG(3) << "Connecting to " << dataset()->address_
              << " in data service dataset op";
-      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
+      DataServiceDispatcherClient dispatcher(dataset()->address_,
+                                             dataset()->protocol_);
      if (dataset()->job_name_.empty()) {
-        TF_RETURN_IF_ERROR(master.CreateJob(
+        TF_RETURN_IF_ERROR(dispatcher.CreateJob(
            dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
      } else {
-        TF_RETURN_IF_ERROR(master.GetOrCreateJob(
+        TF_RETURN_IF_ERROR(dispatcher.GetOrCreateJob(
            dataset()->dataset_id_, dataset()->processing_mode_,
            dataset()->job_name_, iterator_index_, &job_id_));
      }
@@ -283,11 +284,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
 
    // Periodically refresh the task list.
    // Maintain one thread fetching elements for each task.
-    // TODO(aaudibert): Instead of polling, have master send updates when
+    // TODO(aaudibert): Instead of polling, have dispatcher send updates when
    // the list of tasks changes.
    void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
      VLOG(3) << "Starting task thread manager";
-      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
+      DataServiceDispatcherClient dispatcher(dataset()->address_,
+                                             dataset()->protocol_);
      uint64 next_check = Env::Default()->NowMicros();
      while (true) {
        {
@@ -305,18 +307,19 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
            return;
          }
        }
-        UpdateTasks(&master);
+        UpdateTasks(&dispatcher);
        UpdateWorkerThreads(ctx.get());
        next_check = Env::Default()->NowMicros() +
                     dataset()->task_refresh_interval_ms_ * 1000;
      }
    }
 
-    void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) {
+    void UpdateTasks(DataServiceDispatcherClient* dispatcher)
+        LOCKS_EXCLUDED(mu_) {
      VLOG(3) << "Updating tasks";
      std::vector<TaskInfo> tasks;
      bool job_finished;
-      Status s = master->GetTasks(job_id_, &tasks, &job_finished);
+      Status s = dispatcher->GetTasks(job_id_, &tasks, &job_finished);
      if (!s.ok()) {
        LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
                     << s;
@@ -53,7 +53,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(
       ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
 
-  DataServiceMasterClient client(address, protocol);
+  DataServiceDispatcherClient client(address, protocol);
   int64 dataset_id;
   OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));
 
@@ -25,7 +25,7 @@ namespace data {
 
 // Registers a dataset with the tf.data service.
 //
-// The address and protocol inputs are used to connect to the tf.data master.
+// The address and protocol inputs are used to connect to the dispatcher.
 // The external state policy attribute determines whether to ignore, warn, or
 // error out when the dataset contains external state.
 // The op produces a dataset id for identifying the registered dataset.
@@ -77,7 +77,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
         amount of memory used, since `distribute` won't use more than
         `element_size` * `max_outstanding_requests` of memory.
       task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
-        the master for task changes.
+        the dispatcher for task changes.
     """
 
     if job_name is None:
@@ -173,7 +173,7 @@ def _distribute(processing_mode,
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
-      master for task changes.
+      dispatcher for task changes.
 
  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
@@ -19,5 +19,5 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.data.experimental.ops.data_service_ops import distribute
-from tensorflow.python.data.experimental.service.server_lib import MasterServer
+from tensorflow.python.data.experimental.service.server_lib import DispatchServer
 from tensorflow.python.data.experimental.service.server_lib import WorkerServer
@ -24,35 +24,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.util.tf_export import tf_export

@tf_export("data.experimental.service.MasterServer", v1=[])
class MasterServer(object):
"""An in-process tf.data service master server.
@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer(object):
"""An in-process tf.data service dispatch server.

A `tf.data.experimental.service.MasterServer` coordinates a cluster of
A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
register themselves with the master.
register themselves with the dispatcher.

>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> master_address = master.target.split("://")[1]
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address)
... port=0, dispatcher_address=dispatcher_address)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

When starting a dedicated tf.data master process, use join() to block
When starting a dedicated tf.data dispatch process, use join() to block
indefinitely after starting up the server.

```
master = tf.data.experimental.service.MasterServer(port=5050)
master.join()
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
dispatcher.join()
```
"""

def __init__(self, port, protocol=None, start=True):
"""Creates a new master server.
"""Creates a new dispatch server.

Args:
port: Specifies the port to bind to.

@ -68,15 +68,16 @@ class MasterServer(object):
if protocol is None:
protocol = "grpc"
self._protocol = protocol
self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(port, protocol)
self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(port, protocol)
if start:
self._server.start()

def start(self):
"""Starts this server.

>>> master = tf.data.experimental.service.MasterServer(port=0, start=False)
>>> master.start()
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
... start=False)
>>> dispatcher.start()

Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while

@ -87,11 +88,11 @@ class MasterServer(object):
def join(self):
"""Blocks until the server has shut down.

This is useful when starting a dedicated master process.
This is useful when starting a dedicated dispatch process.

```
master = tf.data.experimental.service.MasterServer(port=5050)
master.join()
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
dispatcher.join()
```

Raises:

@ -104,10 +105,10 @@ class MasterServer(object):
def target(self):
"""Returns a target that can be used to connect to the server.

>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))

The returned string will be in the form protocol://address, e.g.
"grpc://localhost:5050".

@ -136,7 +137,7 @@ class MasterServer(object):
return "localhost:{0}".format(self._server.bound_port())

def _num_workers(self):
"""Returns the number of workers registered with the master."""
"""Returns the number of workers registered with the dispatcher."""
return self._server.num_workers()

@ -147,15 +148,15 @@ class WorkerServer(object):
A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
processing for user-defined datasets, and provides the resulting elements over
RPC. A worker is associated with a single
`tf.data.experimental.service.MasterServer`.
`tf.data.experimental.service.DispatchServer`.

>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> master_address = master.target.split("://")[1]
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address)
... port=0, dispatcher_address=dispatcher_address)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

@ -164,14 +165,14 @@ class WorkerServer(object):

```
worker = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050")
port=5051, dispatcher_address="grpc://localhost:5050")
worker.join()
```
"""

def __init__(self,
port,
master_address,
dispatcher_address,
worker_address=None,
protocol=None,
start=True):

@ -180,11 +181,12 @@ class WorkerServer(object):
Args:
port: Specifies the port to bind to. A value of 0 indicates that the
worker can bind to any available port.
master_address: Specifies the address of the master server.
dispatcher_address: Specifies the address of the dispatcher.
worker_address: (Optional.) Specifies the address of the worker server.
This address is passed to the master server so that the master can tell
clients how to connect to this worker. Defaults to `"localhost:%port%"`,
where `%port%` will be replaced with the port used by the worker.
This address is passed to the dispatcher so that the dispatcher can
tell clients how to connect to this worker. Defaults to
`"localhost:%port%"`, where `%port%` will be replaced with the port used
by the worker.
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
start: (Optional.) Boolean, indicating whether to start the server after

@ -201,7 +203,7 @@ class WorkerServer(object):

self._protocol = protocol
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
port, protocol, master_address, worker_address)
port, protocol, dispatcher_address, worker_address)
if start:
self._server.start()

@ -221,7 +223,7 @@ class WorkerServer(object):

```
worker_server = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050")
port=5051, dispatcher_address="grpc://localhost:5050")
worker_server.join()
```

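Taken together, the renamed docstrings above describe the following end-to-end flow. A minimal sketch using only the classes and arguments shown in this diff (the port values and dataset are illustrative):

```
import tensorflow as tf

# In-process dispatch server (formerly MasterServer).
dispatcher = tf.data.experimental.service.DispatchServer(port=0)
dispatcher_address = dispatcher.target.split("://")[1]

# The worker now takes `dispatcher_address` instead of `master_address`.
worker = tf.data.experimental.service.WorkerServer(
    port=0, dispatcher_address=dispatcher_address)

# Client side: distribute a dataset through the dispatcher's target.
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=dispatcher.target))
print(list(dataset.as_numpy_iterator()))  # [0, 1, ..., 9]

# For dedicated server processes, block instead of exiting:
# dispatcher.join() or worker.join()
```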
@ -25,68 +25,68 @@ from tensorflow.python.platform import test

class ServerLibTest(test.TestCase):

def testStartMaster(self):
master = server_lib.MasterServer(0, start=False)
master.start()
def testStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0, start=False)
dispatcher.start()

def testMultipleStartMaster(self):
master = server_lib.MasterServer(0, start=True)
master.start()
def testMultipleStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0, start=True)
dispatcher.start()

def testStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address, start=False)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
worker.start()

def testMultipleStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address, start=True)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
worker.start()

def testStopMaster(self):
master = server_lib.MasterServer(0)
master._stop()
master._stop()
def testStopDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
dispatcher._stop()

def testStopWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
worker._stop()

def testStopStartMaster(self):
master = server_lib.MasterServer(0)
master._stop()
def testStopStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
with self.assertRaisesRegex(
RuntimeError, "Server cannot be started after it has been stopped"):
master.start()
dispatcher.start()

def testStopStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
with self.assertRaisesRegex(
RuntimeError, "Server cannot be started after it has been stopped"):
worker.start()

def testJoinMaster(self):
master = server_lib.MasterServer(0)
master._stop()
master.join()
def testJoinDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
dispatcher.join()

def testJoinWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
worker.join()

def testMasterNumWorkers(self):
master = server_lib.MasterServer(0)
self.assertEqual(0, master._num_workers())
worker1 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable
self.assertEqual(1, master._num_workers())
worker2 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable
self.assertEqual(2, master._num_workers())
def testDispatcherNumWorkers(self):
dispatcher = server_lib.DispatchServer(0)
self.assertEqual(0, dispatcher._num_workers())
worker1 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
self.assertEqual(1, dispatcher._num_workers())
worker2 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
self.assertEqual(2, dispatcher._num_workers())

if __name__ == "__main__":

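The tests above also pin down the lifecycle contract that survives the rename: stopping is idempotent, and a stopped server cannot be restarted. A short sketch of that contract, using the internal `server_lib` module the same way the tests do:

```
from tensorflow.python.data.experimental.service import server_lib

dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
dispatcher._stop()  # stopping twice is a no-op

try:
  dispatcher.start()
except RuntimeError:
  # "Server cannot be started after it has been stopped"
  pass
```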
@ -28,13 +28,14 @@ limitations under the License.
namespace py = pybind11;

PYBIND11_MODULE(_pywrap_server_lib, m) {
py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer")
.def("start", &tensorflow::data::MasterGrpcDataServer::Start)
.def("stop", &tensorflow::data::MasterGrpcDataServer::Stop)
.def("join", &tensorflow::data::MasterGrpcDataServer::Join)
.def("bound_port", &tensorflow::data::MasterGrpcDataServer::BoundPort)
py::class_<tensorflow::data::DispatchGrpcDataServer>(m,
"DispatchGrpcDataServer")
.def("start", &tensorflow::data::DispatchGrpcDataServer::Start)
.def("stop", &tensorflow::data::DispatchGrpcDataServer::Stop)
.def("join", &tensorflow::data::DispatchGrpcDataServer::Join)
.def("bound_port", &tensorflow::data::DispatchGrpcDataServer::BoundPort)
.def("num_workers",
[](tensorflow::data::MasterGrpcDataServer* server) -> int {
[](tensorflow::data::DispatchGrpcDataServer* server) -> int {
int num_workers;
tensorflow::Status status = server->NumWorkers(&num_workers);
tensorflow::MaybeRaiseFromStatus(status);

@ -48,12 +49,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
.def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);

m.def(
"TF_DATA_NewMasterServer",
"TF_DATA_NewDispatchServer",
[](int port, std::string protocol)
-> std::unique_ptr<tensorflow::data::MasterGrpcDataServer> {
std::unique_ptr<tensorflow::data::MasterGrpcDataServer> server;
-> std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> {
std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
tensorflow::Status status =
tensorflow::data::NewMasterServer(port, protocol, &server);
tensorflow::data::NewDispatchServer(port, protocol, &server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},

@ -61,12 +62,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {

m.def(
"TF_DATA_NewWorkerServer",
[](int port, std::string protocol, std::string master_address,
[](int port, std::string protocol, std::string dispatcher_address,
std::string worker_address)
-> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
tensorflow::Status status = tensorflow::data::NewWorkerServer(
port, protocol, master_address, worker_address, &server);
port, protocol, dispatcher_address, worker_address, &server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},

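For orientation, the renamed bindings registered above are consumed from Python roughly as the `server_lib.py` changes earlier in this diff show. A condensed sketch of the dispatcher binding only (internal API, shown just to illustrate the wiring):

```
from tensorflow.python.data.experimental.service import _pywrap_server_lib

# DispatchGrpcDataServer, exposed through TF_DATA_NewDispatchServer.
server = _pywrap_server_lib.TF_DATA_NewDispatchServer(0, "grpc")
server.start()
print(server.bound_port())   # OS-assigned port, since port 0 was requested
print(server.num_workers())  # 0 until a WorkerServer registers
server.stop()
```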
@ -59,23 +59,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
num_workers: The number of workers in the cluster.

Returns:
The address of the master.
The address of the dispatcher.
"""
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._servers = []
for _ in range(num_workers):
self._servers.append(
server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL))
port=0,
dispatcher_address=self._dispatcher._address,
protocol=PROTOCOL))

return self._master._address
return self._dispatcher._address

@combinations.generate(test_base.eager_only_combinations())
def testDistributeBasic(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
results = [elem.numpy() for elem in ds]
self.assertEqual(list(range(num_elements)), results)

@ -83,10 +85,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testDifferentShuffleOrders(self):
random_seed.set_random_seed(None)
num_elements = 100
master_address = self.create_cluster(2)
dispatcher_address = self.create_cluster(2)
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.shuffle(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
output = [elem.numpy() for elem in ds]

# The output will be two sequences of range(num_elements)

@ -104,9 +106,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testMultipleEpochs(self):
num_elements = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
for _ in range(10):
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])

@ -114,9 +116,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testRepeatedDataset(self):
num_elements = 10
num_repetitions = 5
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
ds = ds.repeat(num_repetitions)
self.assertDatasetProduces(
ds, expected_output=num_repetitions * list(range(num_elements)))

@ -125,12 +127,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testConcurrentEpoch(self):
num_elements = 10
num_datasets = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
iterators = []
results = []
for _ in range(num_datasets):
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
iterators.append(iter(ds))
results.append([])

@ -146,9 +148,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.skipTest("Not yet implemented")
num_elements = 10
num_iterators = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
result = []
iterators = []
for _ in range(num_iterators):

@ -170,20 +172,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testMultiWorker(self):
num_workers = 3
num_elements = 10
master_address = self.create_cluster(num_workers)
dispatcher_address = self.create_cluster(num_workers)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
results = [elem.numpy() for elem in ds]
self.assertCountEqual(num_workers * list(range(num_elements)), results)

@combinations.generate(test_base.eager_only_combinations())
def testAddWorkerMidJob(self):
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master._address)
ds = _make_distributed_dataset(ds, self._dispatcher._address)
iterator = iter(ds)
results = []
# Read halfway through the dataset.

@ -191,10 +193,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
results.append(next(iterator).numpy())

self._new_worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)

# Wait for the new worker to register with the master.
while self._master._num_workers() < 2:
# Wait for the new worker to register with the dispatcher.
while self._dispatcher._num_workers() < 2:
time.sleep(10 / 1000) # 10ms

for elem in iterator:

@ -206,12 +208,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
combinations.times(test_base.eager_only_combinations(),
combinations.combine(use_same_port=[True, False])))
def testRestartWorker(self, use_same_port):
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master._address)
ds = _make_distributed_dataset(ds, self._dispatcher._address)
iterator = iter(ds)
# Read halfway through the dataset.
midpoint = num_elements // 2

@ -224,7 +226,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
port = int(self._worker._address.split(":")[1])
self._worker._stop()
self._new_worker = server_lib.WorkerServer(
port=port, master_address=self._master._address, protocol=PROTOCOL)
port=port,
dispatcher_address=self._dispatcher._address,
protocol=PROTOCOL)

# There may have been some elements prefetched from the first worker
# before it was stopped.

@ -259,12 +263,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testInsideFunction(self):
num_workers = 3
num_elements = 10
master_address = self.create_cluster(num_workers)
dispatcher_address = self.create_cluster(num_workers)

@def_function.function
def f():
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
result = tensor_array_ops.TensorArray(
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
i = 0

@ -279,10 +283,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testSharedJobName(self):
num_elements = 100
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
iter1 = iter(ds1)
iter2 = iter(ds2)
results = []

@ -298,20 +302,22 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testDifferentJobNames(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2")
ds1 = _make_distributed_dataset(
ds, dispatcher_address, job_name="job_name1")
ds2 = _make_distributed_dataset(
ds, dispatcher_address, job_name="job_name2")
self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, list(range(num_elements)))

@combinations.generate(test_base.eager_only_combinations())
def testSharedJobNameMultiIteration(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
# iteration 1
self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, [])

@ -323,11 +329,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testSharedJobNameRepeat(self):
num_elements = 100
num_repetitions = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds1 = ds1.repeat(num_repetitions)
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = ds2.repeat(num_repetitions)
results = []
iter1 = iter(ds1)

@ -345,7 +351,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testApplyDeterminismOption(self):
elements = list(range(10))
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)

def dataset_fn(delay_ms):

@ -362,7 +368,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
opts = dataset_ops.Options()
opts.experimental_deterministic = False
ds = ds.with_options(opts)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
return ds

self.checkDeterminism(

@ -379,8 +385,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_external_state_policy = external_state_policy
ds = ds.with_options(options)

master_address = self.create_cluster(3)
ds = _make_distributed_dataset(ds, master_address)
dispatcher_address = self.create_cluster(3)
ds = _make_distributed_dataset(ds, dispatcher_address)
next(iter(ds))

@combinations.generate(

@ -400,12 +406,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):

@combinations.generate(test_base.eager_only_combinations())
def testDistributeFromInterleave(self):
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(2)

def interleave_fn(_):
ds = dataset_ops.Dataset.range(2)
_make_distributed_dataset(ds, master_address)
_make_distributed_dataset(ds, dispatcher_address)
return ds

with self.assertRaisesRegex(

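The `testSharedJobName*` cases above encode the user-visible behaviour of `job_name`: consumers that pass the same name share a single job, so each element is produced once across them. A sketch with the public API (assuming `distribute` accepts `job_name`, as the `_distribute` docstring earlier in this diff indicates, and reusing the in-process servers from the docstring example):

```
import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer(port=0)
worker = tf.data.experimental.service.WorkerServer(  # keeps the cluster alive
    port=0, dispatcher_address=dispatcher.target.split("://")[1])

base = tf.data.Dataset.range(100)

def distributed(ds):
  return ds.apply(
      tf.data.experimental.service.distribute(
          processing_mode="parallel_epochs",
          service=dispatcher.target,
          job_name="shared_job"))

ds1 = distributed(base)
ds2 = distributed(base)

# Across both readers, every element appears exactly once.
combined = [e.numpy() for e in ds1] + [e.numpy() for e in ds2]
assert sorted(combined) == list(range(100))
```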
@ -1,6 +1,6 @@
path: "tensorflow.data.experimental.service.MasterServer"
path: "tensorflow.data.experimental.service.DispatchServer"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.MasterServer\'>"
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatchServer\'>"
is_instance: "<type \'object\'>"
member {
name: "target"

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'port\', \'master_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
argspec: "args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
}
member_method {
name: "join"

@ -1,7 +1,7 @@
path: "tensorflow.data.experimental.service"
tf_module {
member {
name: "MasterServer"
name: "DispatchServer"
mtype: "<type \'type\'>"
}
member {

@ -99,8 +99,8 @@ tensorflow::data::GrpcDataServerBase::Join
tensorflow::data::GrpcDataServerBase::Start
tensorflow::data::GrpcDataServerBase::Stop
tensorflow::data::GrpcDataServerBase::BoundPort
tensorflow::data::MasterGrpcDataServer::NumWorkers
tensorflow::data::NewMasterServer
tensorflow::data::DispatchGrpcDataServer::NumWorkers
tensorflow::data::NewDispatchServer
tensorflow::data::NewWorkerServer

[protos_all] # device_lib, dtypes