Update "master" to "dispatch"/"dispatcher" in tf.data service terminology.
Dispatcher is more descriptive and follows the guidance in https://developers.google.com/style/word-list#master PiperOrigin-RevId: 321613785 Change-Id: Iaa576d35f0581e21278101f8b31201ba737a6865
parent 0aa1d61fad
commit 6b9a9d98bb
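Scope of the rename, for readers scanning the file-by-file diff below:

* master.proto becomes dispatcher.proto, and service MasterService becomes DispatcherService.
* DataServiceMasterClient becomes DataServiceDispatcherClient; CreateDataServiceMasterClient becomes CreateDataServiceDispatcherClient.
* DataServiceMasterImpl becomes DataServiceDispatcherImpl; GrpcMasterImpl becomes GrpcDispatcherImpl.
* MasterGrpcDataServer becomes DispatchGrpcDataServer; NewMasterServer becomes NewDispatchServer.
* master_address parameters and fields become dispatcher_address throughout.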
tensorflow/core/data/service/BUILD

@@ -28,8 +28,8 @@ tf_proto_library(
 )
 
 tf_proto_library(
-    name = "master_proto",
-    srcs = ["master.proto"],
+    name = "dispatcher_proto",
+    srcs = ["dispatcher.proto"],
     has_services = 1,
     cc_api_version = 2,
     protodeps = tf_additional_all_protos() + [
@@ -49,17 +49,17 @@ tf_proto_library(
 )
 
 cc_library(
-    name = "master_impl",
-    srcs = ["master_impl.cc"],
+    name = "dispatcher_impl",
+    srcs = ["dispatcher_impl.cc"],
     hdrs = [
-        "master_impl.h",
+        "dispatcher_impl.h",
    ],
     deps = [
         ":common_proto_cc",
         ":credentials_factory",
         ":data_service",
+        ":dispatcher_proto_cc",
         ":grpc_util",
-        ":master_proto_cc",
         ":worker_cc_grpc_proto",
         ":worker_proto_cc",
         "//tensorflow/c:c_api_internal",
@@ -86,9 +86,9 @@ cc_library(
     deps = [
         ":common_proto_cc",
         ":credentials_factory",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
         ":grpc_util",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
         ":worker_proto_cc",
         "//tensorflow/c:c_api_internal",
         "//tensorflow/c:tf_status_helper",
@@ -207,12 +207,12 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "grpc_master_impl",
-    srcs = ["grpc_master_impl.cc"],
-    hdrs = ["grpc_master_impl.h"],
+    name = "grpc_dispatcher_impl",
+    srcs = ["grpc_dispatcher_impl.cc"],
+    hdrs = ["grpc_dispatcher_impl.h"],
     deps = [
-        ":master_cc_grpc_proto",
-        ":master_impl",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_impl",
         "//tensorflow/core/distributed_runtime/rpc:grpc_util",
         tf_grpc_cc_dependency(),
     ],
@@ -250,7 +250,7 @@ cc_library(
     ],
     deps = [
         ":credentials_factory",
-        ":grpc_master_impl",
+        ":grpc_dispatcher_impl",
         ":grpc_util",
         ":grpc_worker_impl",
         "//tensorflow/core:lib",
@@ -268,9 +268,9 @@ cc_library(
     ],
     deps = [
         ":credentials_factory",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
         ":grpc_util",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
         ":worker_cc_grpc_proto",
         ":worker_proto_cc",
         "//tensorflow/core:framework",
@@ -287,12 +287,12 @@ tf_cc_test(
     tags = ["no_windows"],
     deps = [
         ":data_service",
-        ":grpc_master_impl",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
+        ":grpc_dispatcher_impl",
         ":grpc_util",
         ":grpc_worker_impl",
         ":local_credentials_factory",
-        ":master_cc_grpc_proto",
-        ":master_proto_cc",
         ":server_lib",
         ":test_cluster",
         ":test_util",
@@ -309,11 +309,11 @@ tf_cc_test(
 )
 
 cc_grpc_library(
-    name = "master_cc_grpc_proto",
-    srcs = [":master_proto"],
+    name = "dispatcher_cc_grpc_proto",
+    srcs = [":dispatcher_proto"],
     generate_mocks = True,
     grpc_only = True,
-    deps = [":master_proto_cc"],
+    deps = [":dispatcher_proto_cc"],
 )
 
 cc_grpc_library(
tensorflow/core/data/service/data_service.cc

@@ -18,8 +18,8 @@ limitations under the License.
 #include "grpcpp/create_channel.h"
 #include "grpcpp/security/credentials.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
 
@@ -54,7 +54,7 @@ std::string ProcessingModeToString(ProcessingMode mode) {
   }
 }
 
-Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
+Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
                                                 int64* dataset_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetOrRegisterDatasetRequest req;
@@ -69,7 +69,7 @@ Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::CreateJob(int64 dataset_id,
+Status DataServiceDispatcherClient::CreateJob(int64 dataset_id,
                                           ProcessingMode processing_mode,
                                           int64* job_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
@@ -88,11 +88,9 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
-                                               ProcessingMode processing_mode,
-                                               const std::string& job_name,
-                                               int job_name_index,
-                                               int64* job_id) {
+Status DataServiceDispatcherClient::GetOrCreateJob(
+    int64 dataset_id, ProcessingMode processing_mode,
+    const std::string& job_name, int job_name_index, int64* job_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetOrCreateJobRequest req;
   req.set_dataset_id(dataset_id);
@@ -112,7 +110,7 @@ Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::GetTasks(int64 job_id,
+Status DataServiceDispatcherClient::GetTasks(int64 job_id,
                                          std::vector<TaskInfo>* tasks,
                                          bool* job_finished) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
@@ -132,7 +130,8 @@ Status DataServiceMasterClient::GetTasks(int64 job_id,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
+Status DataServiceDispatcherClient::GetWorkers(
+    std::vector<WorkerInfo>* workers) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetWorkersRequest req;
   GetWorkersResponse resp;
@@ -148,12 +147,12 @@ Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
   return Status::OK();
 }
 
-Status DataServiceMasterClient::EnsureInitialized() {
+Status DataServiceDispatcherClient::EnsureInitialized() {
   std::shared_ptr<grpc::ChannelCredentials> credentials;
   TF_RETURN_IF_ERROR(
       CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
   auto channel = grpc::CreateChannel(address_, credentials);
-  stub_ = MasterService::NewStub(channel);
+  stub_ = DispatcherService::NewStub(channel);
   return Status::OK();
 }
 
@@ -187,10 +186,11 @@ Status DataServiceWorkerClient::EnsureInitialized() {
   return Status::OK();
 }
 
-Status CreateDataServiceMasterClient(
+Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceMasterClient>* out) {
-  auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
+    std::unique_ptr<DataServiceDispatcherClient>* out) {
+  auto client =
+      absl::make_unique<DataServiceDispatcherClient>(address, protocol);
   TF_RETURN_IF_ERROR(client->Initialize());
   *out = std::move(client);
   return Status::OK();
tensorflow/core/data/service/data_service.h

@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 
-#include "tensorflow/core/data/service/master.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -67,10 +67,10 @@ class DataServiceClientBase {
   const std::string protocol_;
 };
 
-// Client for communicating with the tf.data service master.
-class DataServiceMasterClient : public DataServiceClientBase {
+// Client for communicating with the tf.data service dispatcher.
+class DataServiceDispatcherClient : public DataServiceClientBase {
  public:
-  DataServiceMasterClient(const std::string& address,
+  DataServiceDispatcherClient(const std::string& address,
                           const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}
 
@@ -90,13 +90,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
                       const std::string& job_name, int job_name_index,
                       int64* job_id);
 
-  // Queries the master for the tasks associated with the specified job.
+  // Queries the dispatcher for the tasks associated with the specified job.
   // The tasks will be stored in *tasks, and whether the job is finished will
   // be stored in `*job_finished`.
   Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
                   bool* job_finished);
 
-  // Queries the master for its registered workers. The worker info will be
+  // Queries the dispatcher for its registered workers. The worker info will be
   // stored in `*workers`.
   Status GetWorkers(std::vector<WorkerInfo>* workers);
 
@@ -104,7 +104,7 @@ class DataServiceMasterClient : public DataServiceClientBase {
   Status EnsureInitialized() override;
 
  private:
-  std::unique_ptr<MasterService::Stub> stub_;
+  std::unique_ptr<DispatcherService::Stub> stub_;
 };
 
 // Client for communicating with the tf.data service worker.
@@ -127,10 +127,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
   std::unique_ptr<WorkerService::Stub> stub_;
 };
 
-// Creates and initializes a new tf.data service master client.
-Status CreateDataServiceMasterClient(
+// Creates and initializes a new tf.data service dispatcher client.
+Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceMasterClient>* out);
+    std::unique_ptr<DataServiceDispatcherClient>* out);
 
 // Creates and initializes a new tf.data service worker client.
 Status CreateDataServiceWorkerClient(
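Usage note: a minimal sketch of the renamed client API above in use. The helper name LogWorkerCount, the extra includes, and the protocol value passed by the caller are illustrative, not part of this change:

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace data {

// Connects to a running dispatcher and logs how many workers it knows about.
Status LogWorkerCount(const std::string& dispatcher_address,
                      const std::string& protocol) {
  std::unique_ptr<DataServiceDispatcherClient> dispatcher;
  TF_RETURN_IF_ERROR(CreateDataServiceDispatcherClient(dispatcher_address,
                                                       protocol, &dispatcher));
  std::vector<WorkerInfo> workers;
  TF_RETURN_IF_ERROR(dispatcher->GetWorkers(&workers));
  LOG(INFO) << "Dispatcher at " << dispatcher_address << " reports "
            << workers.size() << " worker(s)";
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow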
tensorflow/core/data/service/data_service_test.cc

@@ -19,9 +19,9 @@ limitations under the License.
 #include "grpcpp/security/credentials.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/core/data/compression_utils.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/server_lib.h"
 #include "tensorflow/core/data/service/test_cluster.h"
 #include "tensorflow/core/data/service/test_util.h"
@@ -66,9 +66,10 @@ TEST(DataService, ProcessingModeToString) {
 TEST(DataService, GetWorkers) {
   TestCluster cluster(1);
   TF_ASSERT_OK(cluster.Initialize());
-  DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
+  DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(),
+                                         kProtocol);
   std::vector<WorkerInfo> workers;
-  TF_EXPECT_OK(master.GetWorkers(&workers));
+  TF_EXPECT_OK(dispatcher.GetWorkers(&workers));
   EXPECT_EQ(1, workers.size());
 }
 
tensorflow/core/data/service/master.proto → tensorflow/core/data/service/dispatcher.proto

@@ -110,11 +110,11 @@ message GetWorkersResponse {
   repeated WorkerInfo workers = 1;
 }
 
-service MasterService {
-  // Registers a worker with the master.
+service DispatcherService {
+  // Registers a worker with the dispatcher.
   rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
 
-  // Updates the master with information about the worker's state.
+  // Updates the dispatcher with information about the worker's state.
   rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
 
   // Registers a dataset with the server, or returns its id if it is already
@@ -134,6 +134,6 @@ service MasterService {
   // Reports a list of all tasks for a job.
   rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
 
-  // Reports a list of all workers registered with the master.
+  // Reports a list of all workers registered with the dispatcher.
   rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse);
 }
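Usage note: building a raw client stub for the renamed DispatcherService, mirroring the stub-creation pattern in worker_impl.cc later in this diff. This assumes the generated code lives in namespace tensorflow::data (as the unqualified uses elsewhere in the diff suggest); the insecure credentials are illustrative only:

#include <memory>
#include <string>

#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"

// Returns a stub that can issue DispatcherService RPCs such as GetWorkers.
std::unique_ptr<tensorflow::data::DispatcherService::Stub> MakeDispatcherStub(
    const std::string& dispatcher_address) {
  auto channel = grpc::CreateChannel(dispatcher_address,
                                     grpc::InsecureChannelCredentials());
  return tensorflow::data::DispatcherService::NewStub(channel);
}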
tensorflow/core/data/service/master_impl.cc → tensorflow/core/data/service/dispatcher_impl.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/data/service/master_impl.h"
+#include "tensorflow/core/data/service/dispatcher_impl.h"
 
 #include <memory>
 #include <tuple>
@@ -26,8 +26,8 @@ limitations under the License.
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
 #include "tensorflow/core/data/service/data_service.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
@@ -53,10 +53,10 @@ Status CreateWorkerStub(const std::string& address,
 }
 }  // namespace
 
-DataServiceMasterImpl::DataServiceMasterImpl(const std::string protocol)
+DataServiceDispatcherImpl::DataServiceDispatcherImpl(const std::string protocol)
     : protocol_(protocol) {}
 
-Status DataServiceMasterImpl::RegisterWorker(
+Status DataServiceDispatcherImpl::RegisterWorker(
     const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
   VLOG(3) << "Received register worker request";
   mutex_lock l(mu_);
@@ -86,8 +86,8 @@ Status DataServiceMasterImpl::RegisterWorker(
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
-                                           WorkerUpdateResponse* response) {
+Status DataServiceDispatcherImpl::WorkerUpdate(
+    const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
   mutex_lock l(mu_);
   int64 worker_id = request->worker_id();
   for (auto& update : request->updates()) {
@@ -106,7 +106,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetOrRegisterDataset(
+Status DataServiceDispatcherImpl::GetOrRegisterDataset(
     const GetOrRegisterDatasetRequest* request,
     GetOrRegisterDatasetResponse* response) {
   uint64 fingerprint;
@@ -128,7 +128,7 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
   return Status::OK();
 }
 
-int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
+int64 DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
                                              const DatasetDef& dataset)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   int64 dataset_id = next_dataset_id_++;
@@ -142,7 +142,7 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
   return dataset_id;
 }
 
-Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
+Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request,
                                         CreateJobResponse* response) {
   VLOG(3) << "Received create job request for dataset id "
           << request->dataset_id();
@@ -157,7 +157,7 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetOrCreateJob(
+Status DataServiceDispatcherImpl::GetOrCreateJob(
     const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
   VLOG(3) << "Received get or create job request for dataset id "
           << request->dataset_id() << " with name " << request->job_name()
@@ -193,7 +193,7 @@ Status DataServiceMasterImpl::GetOrCreateJob(
 }
 
 // Validates that the job matches the given processing_mode and dataset_id.
-Status DataServiceMasterImpl::ValidateMatchingJob(
+Status DataServiceDispatcherImpl::ValidateMatchingJob(
     const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
   DCHECK(job.name().has_value());
   std::string job_name = job.name().value();
@@ -214,10 +214,10 @@ Status DataServiceMasterImpl::ValidateMatchingJob(
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
-                                        ProcessingMode processing_mode,
-                                        absl::optional<std::string> job_name,
-                                        int64* out_job_id) LOCKS_EXCLUDED(mu_) {
+Status DataServiceDispatcherImpl::CreateJob(
+    int64 dataset_id, ProcessingMode processing_mode,
+    absl::optional<std::string> job_name, int64* out_job_id)
+    LOCKS_EXCLUDED(mu_) {
   switch (processing_mode) {
     case ProcessingMode::PARALLEL_EPOCHS:
       break;
@@ -274,14 +274,16 @@ Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTask(
+const DataServiceDispatcherImpl::Task& DataServiceDispatcherImpl::CreateTask(
     Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
   mutex_lock l(mu_);
   return CreateTaskLocked(job, worker_address);
 }
 
-const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
-    Job* job, const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+const DataServiceDispatcherImpl::Task&
+DataServiceDispatcherImpl::CreateTaskLocked(Job* job,
+                                            const std::string& worker_address)
+    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   int64 task_id = next_task_id_++;
   DCHECK(!tasks_.contains(task_id));
   tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
@@ -290,7 +292,7 @@ const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
   return tasks_.at(task_id);
 }
 
-Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
+Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
   if (!worker->stub()) {
     std::unique_ptr<WorkerService::Stub> stub;
     TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
@@ -299,7 +301,7 @@ Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
+Status DataServiceDispatcherImpl::AllocateTaskToWorker(const Task& task,
                                                    Worker* worker)
     LOCKS_EXCLUDED(mu_) {
   TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
@@ -322,7 +324,7 @@ Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
+Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
                                        GetTasksResponse* response) {
   mutex_lock l(mu_);
   VLOG(3) << "Looking up tasks for job id " << request->job_id();
@@ -346,7 +348,7 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
   return Status::OK();
 }
 
-Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request,
+Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
                                          GetWorkersResponse* response) {
   mutex_lock l(mu_);
   VLOG(3) << "Enter GetWorkers";
tensorflow/core/data/service/master_impl.h → tensorflow/core/data/service/dispatcher_impl.h

@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
-#define TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/data_service.h"
-#include "tensorflow/core/data/service/master.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -40,11 +40,11 @@ namespace data {
 //   ProcessingModeDef which determines what data it produces.
 // * Task: A job is broken into multiple tasks, which each represent
 //   iterating over all of or part of the dataset. Workers process tasks.
-class DataServiceMasterImpl {
+class DataServiceDispatcherImpl {
  public:
-  explicit DataServiceMasterImpl(const std::string protocol);
+  explicit DataServiceDispatcherImpl(const std::string protocol);
 
-  // See master.proto for API documentation.
+  // See dispatcher.proto for API documentation.
 
   /// Worker-facing API.
   Status RegisterWorker(const RegisterWorkerRequest* request,
@@ -191,7 +191,7 @@ class DataServiceMasterImpl {
   // Creates a new task for a job, returning a reference to the task.
   const Task& CreateTask(Job* job, const std::string& worker_address)
       LOCKS_EXCLUDED(mu_);
-  // Same as `CreateTask`, but expects that the master lock is already held.
+  // Same as `CreateTask`, but expects that the dispatcher lock is already held.
   const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Validates that an existing job matches the given processing_mode and
@@ -225,10 +225,10 @@ class DataServiceMasterImpl {
   absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
       TF_GUARDED_BY(mu_);
 
-  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
+  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
 };
 
 }  // namespace data
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
+#endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
tensorflow/core/data/service/grpc_master_impl.cc → tensorflow/core/data/service/grpc_dispatcher_impl.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/data/service/grpc_master_impl.h"
+#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
 
 #include "grpcpp/server_context.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -25,15 +25,15 @@ using ::grpc::ServerBuilder;
 using ::grpc::ServerContext;
 using ::grpc::Status;
 
-GrpcMasterImpl::GrpcMasterImpl(ServerBuilder* server_builder,
+GrpcDispatcherImpl::GrpcDispatcherImpl(ServerBuilder* server_builder,
                                const std::string& protocol)
     : impl_(protocol) {
   server_builder->RegisterService(this);
-  VLOG(1) << "Registered data service master";
+  VLOG(1) << "Registered data service dispatcher";
 }
 
 #define HANDLER(method)                                     \
-  Status GrpcMasterImpl::method(ServerContext* context,     \
+  Status GrpcDispatcherImpl::method(ServerContext* context, \
                                 const method##Request* request, \
                                 method##Response* response) {   \
     return ToGrpcStatus(impl_.method(request, response));       \
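For reference, the HANDLER macro above stamps out one forwarding method per RPC of the renamed service; HANDLER(GetWorkers), for example, expands to roughly:

Status GrpcDispatcherImpl::GetWorkers(ServerContext* context,
                                      const GetWorkersRequest* request,
                                      GetWorkersResponse* response) {
  // Forward the gRPC call to the underlying DataServiceDispatcherImpl and
  // convert its tensorflow::Status into a grpc::Status.
  return ToGrpcStatus(impl_.GetWorkers(request, response));
}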
tensorflow/core/data/service/grpc_master_impl.h → tensorflow/core/data/service/grpc_dispatcher_impl.h

@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
-#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
+#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
+#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
 
 #include "grpcpp/server_builder.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master_impl.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher_impl.h"
 
 namespace tensorflow {
 namespace data {
@@ -29,14 +29,14 @@ namespace data {
 //
 //   ::grpc::ServerBuilder builder;
 //   // configure builder
-//   GrpcMasterImpl data_service(&builder);
+//   GrpcDispatcherImpl data_service(&builder);
 //   builder.BuildAndStart()
 //
-class GrpcMasterImpl : public MasterService::Service {
+class GrpcDispatcherImpl : public DispatcherService::Service {
  public:
-  explicit GrpcMasterImpl(grpc::ServerBuilder* server_builder,
+  explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
                           const std::string& protocol);
-  ~GrpcMasterImpl() override {}
+  ~GrpcDispatcherImpl() override {}
 
 #define HANDLER(method)                               \
   grpc::Status method(grpc::ServerContext* context,   \
@@ -52,12 +52,12 @@ class GrpcMasterImpl : public MasterService::Service {
 #undef HANDLER
 
  private:
-  DataServiceMasterImpl impl_;
+  DataServiceDispatcherImpl impl_;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterImpl);
+  TF_DISALLOW_COPY_AND_ASSIGN(GrpcDispatcherImpl);
 };
 
 }  // namespace data
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
+#endif  // TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
tensorflow/core/data/service/grpc_worker_impl.cc

@@ -26,9 +26,9 @@ using ::grpc::ServerContext;
 using ::grpc::Status;
 
 GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
-                               const std::string& master_address,
+                               const std::string& dispatcher_address,
                                const std::string& protocol)
-    : impl_(master_address, protocol) {
+    : impl_(dispatcher_address, protocol) {
   server_builder->RegisterService(this);
   VLOG(1) << "Registered data service worker";
 }
 
tensorflow/core/data/service/grpc_worker_impl.h

@@ -35,7 +35,7 @@ namespace data {
 class GrpcWorkerImpl : public WorkerService::Service {
  public:
   explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
-                          const std::string& master_address,
+                          const std::string& dispatcher_address,
                           const std::string& protocol);
   ~GrpcWorkerImpl() override {}
 
tensorflow/core/data/service/server_lib.cc

@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/server_lib.h"
 
 #include "tensorflow/core/data/service/credentials_factory.h"
-#include "tensorflow/core/data/service/grpc_master_impl.h"
+#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/grpc_worker_impl.h"
 #include "tensorflow/core/platform/errors.h"
@@ -72,18 +72,18 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
 
 int GrpcDataServerBase::BoundPort() { return bound_port(); }
 
-MasterGrpcDataServer::MasterGrpcDataServer(int port,
+DispatchGrpcDataServer::DispatchGrpcDataServer(int port,
                                            const std::string& protocol)
     : GrpcDataServerBase(port, protocol) {}
 
-MasterGrpcDataServer::~MasterGrpcDataServer() { delete service_; }
+DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
 
-void MasterGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
-  auto service = absl::make_unique<GrpcMasterImpl>(builder, protocol_);
+void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
+  auto service = absl::make_unique<GrpcDispatcherImpl>(builder, protocol_);
   service_ = service.release();
 }
 
-Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
+Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
   GetWorkersRequest req;
   GetWorkersResponse resp;
   grpc::ServerContext ctx;
@@ -95,19 +95,18 @@ Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
   return Status::OK();
 }
 
-WorkerGrpcDataServer::WorkerGrpcDataServer(int port,
-                                           const std::string& protocol,
-                                           const std::string& master_address,
-                                           const std::string& worker_address)
+WorkerGrpcDataServer::WorkerGrpcDataServer(
+    int port, const std::string& protocol,
+    const std::string& dispatcher_address, const std::string& worker_address)
     : GrpcDataServerBase(port, protocol),
-      master_address_(master_address),
+      dispatcher_address_(dispatcher_address),
       worker_address_(worker_address) {}
 
 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
 
 void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
-  auto service =
-      absl::make_unique<GrpcWorkerImpl>(builder, master_address_, protocol_);
+  auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
+                                                   protocol_);
   service_ = service.release();
 }
 
@@ -123,25 +122,25 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
   return Status::OK();
 }
 
-Status NewMasterServer(int port, const std::string& protocol,
-                       std::unique_ptr<MasterGrpcDataServer>* out_server) {
-  *out_server = absl::make_unique<MasterGrpcDataServer>(port, protocol);
+Status NewDispatchServer(int port, const std::string& protocol,
+                         std::unique_ptr<DispatchGrpcDataServer>* out_server) {
+  *out_server = absl::make_unique<DispatchGrpcDataServer>(port, protocol);
   return Status::OK();
 }
 
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server) {
-  return NewWorkerServer(port, protocol, master_address, /*worker_address=*/"",
-                         out_server);
+  return NewWorkerServer(port, protocol, dispatcher_address,
                         /*worker_address=*/"", out_server);
 }
 
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server) {
   *out_server = absl::make_unique<WorkerGrpcDataServer>(
-      port, protocol, master_address, worker_address);
+      port, protocol, dispatcher_address, worker_address);
   return Status::OK();
 }
 
tensorflow/core/data/service/server_lib.h

@@ -25,7 +25,7 @@ namespace data {
 
 // Forward declared because transitively depending on .grpc.pb.h files causes
 // issues in the pywrap build.
-class GrpcMasterImpl;
+class GrpcDispatcherImpl;
 class GrpcWorkerImpl;
 
 // A grpc server for the tf.data service.
@@ -35,7 +35,7 @@ class GrpcDataServerBase {
   // server will find an available port in `Start()`. The chosen port can be
   // found in the output of `Target()`.
   //
-  // master_address is only needed for worker data servers.
+  // dispatcher_address is only needed for worker data servers.
   GrpcDataServerBase(int requested_port, const std::string& protocol);
   virtual ~GrpcDataServerBase() {}
 
@@ -70,12 +70,12 @@ class GrpcDataServerBase {
   std::unique_ptr<grpc::Server> server_;
 };
 
-class MasterGrpcDataServer : public GrpcDataServerBase {
+class DispatchGrpcDataServer : public GrpcDataServerBase {
  public:
-  MasterGrpcDataServer(int requested_port, const std::string& protocol);
-  ~MasterGrpcDataServer() override;
+  DispatchGrpcDataServer(int requested_port, const std::string& protocol);
+  ~DispatchGrpcDataServer() override;
 
-  // Returns the number of workers registerd with the master.
+  // Returns the number of workers registered with the dispatcher.
   Status NumWorkers(int* num_workers);
 
  protected:
@@ -83,14 +83,14 @@ class MasterGrpcDataServer : public GrpcDataServerBase {
   Status StartServiceInternal() override { return Status::OK(); }
 
  private:
-  // Owned. We use a raw pointer because GrpcMasterImpl is forward-declared.
-  GrpcMasterImpl* service_;
+  // Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
+  GrpcDispatcherImpl* service_;
 };
 
 class WorkerGrpcDataServer : public GrpcDataServerBase {
  public:
   WorkerGrpcDataServer(int requested_port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address);
   ~WorkerGrpcDataServer() override;
 
@@ -99,15 +99,15 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
   Status StartServiceInternal() override;
 
  private:
-  const std::string master_address_;
+  const std::string dispatcher_address_;
   const std::string worker_address_;
   // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
   GrpcWorkerImpl* service_;
 };
 
-// Creates a master tf.data server and stores it in `*out_server`.
-Status NewMasterServer(int port, const std::string& protocol,
-                       std::unique_ptr<MasterGrpcDataServer>* out_server);
+// Creates a dispatch tf.data server and stores it in `*out_server`.
+Status NewDispatchServer(int port, const std::string& protocol,
+                         std::unique_ptr<DispatchGrpcDataServer>* out_server);
 
 // Creates a worker tf.data server and stores it in `*out_server`.
 //
@@ -115,18 +115,18 @@ Status NewMasterServer(int port, const std::string& protocol,
 // will be chosen in Start(). This value can be queried with BoundPort().
 //
 // The worker_address argument is optional. If left empty, it will default to
-// "localhost:%port%". When the worker registers with the master, the worker
-// will report the worker address, so that the master can tell clients where to
-// read from. The address may contain the placeholder "%port%", which will be
+// "localhost:%port%". When the worker registers with the dispatcher, the worker
+// will report the worker address, so that the dispatcher can tell clients where
+// to read from. The address may contain the placeholder "%port%", which will be
 // replaced with the value of BoundPort().
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        const std::string& worker_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server);
 
 // Creates a worker using the default worker_address.
 Status NewWorkerServer(int port, const std::string& protocol,
-                       const std::string& master_address,
+                       const std::string& dispatcher_address,
                        std::unique_ptr<WorkerGrpcDataServer>* out_server);
 
 }  // namespace data
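Usage note: a sketch of standing up a one-worker cluster with the renamed entry points above, mirroring TestCluster::Initialize below. The function name StartSmallCluster and the "grpc" protocol value are illustrative:

#include <memory>
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/data/service/server_lib.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {

// Starts a dispatcher on an OS-assigned port, then a worker pointed at it.
Status StartSmallCluster(std::unique_ptr<DispatchGrpcDataServer>* dispatcher,
                         std::unique_ptr<WorkerGrpcDataServer>* worker) {
  TF_RETURN_IF_ERROR(NewDispatchServer(/*port=*/0, "grpc", dispatcher));
  TF_RETURN_IF_ERROR((*dispatcher)->Start());
  std::string dispatcher_address =
      absl::StrCat("localhost:", (*dispatcher)->BoundPort());
  TF_RETURN_IF_ERROR(
      NewWorkerServer(/*port=*/0, "grpc", dispatcher_address, worker));
  TF_RETURN_IF_ERROR((*worker)->Start());
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow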
tensorflow/core/data/service/test_cluster.cc

@@ -45,9 +45,9 @@ Status TestCluster::Initialize() {
         "Test cluster has already been initialized.");
   }
   initialized_ = true;
-  TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_));
-  TF_RETURN_IF_ERROR(master_->Start());
-  master_address_ = absl::StrCat("localhost:", master_->BoundPort());
+  TF_RETURN_IF_ERROR(NewDispatchServer(/*port=*/0, kProtocol, &dispatcher_));
+  TF_RETURN_IF_ERROR(dispatcher_->Start());
+  dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
   workers_.reserve(num_workers_);
   worker_addresses_.reserve(num_workers_);
   for (int i = 0; i < num_workers_; ++i) {
@@ -59,14 +59,14 @@ Status TestCluster::Initialize() {
 Status TestCluster::AddWorker() {
   std::unique_ptr<WorkerGrpcDataServer> worker;
   TF_RETURN_IF_ERROR(
-      NewWorkerServer(/*port=*/0, kProtocol, master_address_, &worker));
+      NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
   TF_RETURN_IF_ERROR(worker->Start());
   worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
   workers_.push_back(std::move(worker));
   return Status::OK();
 }
 
-std::string TestCluster::MasterAddress() { return master_address_; }
+std::string TestCluster::DispatcherAddress() { return dispatcher_address_; }
 
 std::string TestCluster::WorkerAddress(int index) {
   DCHECK_GE(index, 0);
tensorflow/core/data/service/test_cluster.h

@@ -24,7 +24,7 @@ namespace data {
 // Helper class for unit testing a tf.data service cluster.
 class TestCluster {
  public:
-  // Creates a new test cluster with a master and `num_workers` workers.
+  // Creates a new test cluster with a dispatcher and `num_workers` workers.
   explicit TestCluster(int num_workers);
 
   // Initializes the test cluster. This must be called before interacting with
@@ -32,8 +32,8 @@ class TestCluster {
   Status Initialize();
   // Adds a new worker to the cluster.
   Status AddWorker();
-  // Returns the master address in the form "hostname:port".
-  std::string MasterAddress();
+  // Returns the dispatcher address in the form "hostname:port".
+  std::string DispatcherAddress();
   // Returns the address of the worker at the specified index, in the form
   // "hostname:port". The index must be non-negative and less than the number of
   // workers in the cluster.
@@ -42,8 +42,8 @@ class TestCluster {
  private:
   bool initialized_ = false;
   int num_workers_;
-  std::unique_ptr<MasterGrpcDataServer> master_;
-  std::string master_address_;
+  std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
+  std::string dispatcher_address_;
   std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
   std::vector<std::string> worker_addresses_;
 };
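Usage note: the updated test idiom, following the DataService.GetWorkers test earlier in this diff. The test name and worker count here are illustrative:

#include "tensorflow/core/data/service/test_cluster.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace data {

TEST(TestClusterTest, DispatcherAddressIsSet) {
  TestCluster cluster(/*num_workers=*/2);
  TF_ASSERT_OK(cluster.Initialize());
  // DispatcherAddress() replaces the old MasterAddress() accessor.
  EXPECT_FALSE(cluster.DispatcherAddress().empty());
  EXPECT_FALSE(cluster.WorkerAddress(0).empty());
}

}  // namespace data
}  // namespace tensorflow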
tensorflow/core/data/service/worker_impl.cc

@@ -21,9 +21,9 @@ limitations under the License.
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/data/dataset.pb.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
-#include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -45,9 +45,9 @@ auto* tf_data_service_created =
         "has been created.");
 }  // namespace
 
-DataServiceWorkerImpl::DataServiceWorkerImpl(const std::string& master_address,
-                                             const std::string& protocol)
-    : master_address_(master_address), protocol_(protocol) {
+DataServiceWorkerImpl::DataServiceWorkerImpl(
+    const std::string& dispatcher_address, const std::string& protocol)
+    : dispatcher_address_(dispatcher_address), protocol_(protocol) {
   tf_data_service_created->GetCell()->Set(true);
 }
 
@@ -67,14 +67,13 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
   heartbeat_thread_.reset(thread);
   Status s = Register();
   while (!s.ok()) {
-    LOG(WARNING) << "Failed to register with master at " << master_address_
-                 << ": " << s;
+    LOG(WARNING) << "Failed to register with dispatcher at "
+                 << dispatcher_address_ << ": " << s;
     Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
     s = Register();
   }
 }
 
-
 Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
                                           ProcessTaskResponse* response) {
   mutex_lock l(mu_);
@@ -169,29 +168,29 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
   return Status::OK();
 }
 
-Status DataServiceWorkerImpl::EnsureMasterStubInitialized()
+Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-  if (!master_stub_) {
+  if (!dispatcher_stub_) {
    ::grpc::ChannelArguments args;
    std::shared_ptr<::grpc::ChannelCredentials> credentials;
    TF_RETURN_IF_ERROR(
        CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
    auto channel =
-        ::grpc::CreateCustomChannel(master_address_, credentials, args);
-    master_stub_ = MasterService::NewStub(channel);
+        ::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
+    dispatcher_stub_ = DispatcherService::NewStub(channel);
   }
   return Status::OK();
 }
 
 Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-  VLOG(3) << "Registering with master at " << master_address_;
-  TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
+  VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
+  TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
   RegisterWorkerRequest req;
   req.set_worker_address(worker_address_);
   RegisterWorkerResponse resp;
 
   grpc::ClientContext ctx;
-  grpc::Status s = master_stub_->RegisterWorker(&ctx, req, &resp);
+  grpc::Status s = dispatcher_stub_->RegisterWorker(&ctx, req, &resp);
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to register worker", s);
   }
@@ -205,8 +204,8 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 
 Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   VLOG(3) << "Sending " << pending_completed_tasks_.size()
-          << " task updates to master";
-  TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
+          << " task updates to dispatcher";
+  TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
   WorkerUpdateRequest req;
   req.set_worker_id(worker_id_);
   for (int task_id : pending_completed_tasks_) {
|
||||||
|
@ -217,7 +216,7 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
|
|
||||||
WorkerUpdateResponse resp;
|
WorkerUpdateResponse resp;
|
||||||
grpc::ClientContext ctx;
|
grpc::ClientContext ctx;
|
||||||
grpc::Status s = master_stub_->WorkerUpdate(&ctx, req, &resp);
|
grpc::Status s = dispatcher_stub_->WorkerUpdate(&ctx, req, &resp);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return grpc_util::WrapError("Failed to send task updates", s);
|
return grpc_util::WrapError("Failed to send task updates", s);
|
||||||
}
|
}
|
||||||
|
@ -238,7 +237,7 @@ void DataServiceWorkerImpl::HeartbeatThread() {
|
||||||
}
|
}
|
||||||
Status s = SendTaskUpdate();
|
Status s = SendTaskUpdate();
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
LOG(WARNING) << "Failed to send task updates to master: " << s;
|
LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -17,7 +17,7 @@ limitations under the License.

 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
-#include "tensorflow/core/data/service/master.grpc.pb.h"
+#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.pb.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -29,17 +29,17 @@ namespace data {
 // A TensorFlow DataService serves dataset elements over RPC.
 class DataServiceWorkerImpl {
  public:
-  explicit DataServiceWorkerImpl(const std::string& master_address,
+  explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
                                  const std::string& protocol);
   ~DataServiceWorkerImpl();

   // Starts the worker. The worker needs to know its own address so that it can
-  // register with the master.
+  // register with the dispatcher.
   void Start(const std::string& worker_address);

   // See worker.proto for API documentation.

-  /// Master-facing API.
+  /// Dispatcher-facing API.
   Status ProcessTask(const ProcessTaskRequest* request,
                      ProcessTaskResponse* response);
@@ -48,15 +48,15 @@ class DataServiceWorkerImpl {
                     GetElementResponse* response);

  private:
-  // Sets master_stub_ if it isn't already set.
-  Status EnsureMasterStubInitialized();
-  // Registers the worker with the master.
+  // Sets dispatcher_stub_ if it isn't already set.
+  Status EnsureDispatcherStubInitialized();
+  // Registers the worker with the dispatcher.
   Status Register();
-  // Sends task status to the master.
+  // Sends task status to the dispatcher.
   Status SendTaskUpdate();
   // Creates an iterator to process a task.
   Status ProcessTaskInternal(const TaskDef& task);
-  // A thread for updating the master with worker status.
+  // A thread for updating the dispatcher with worker status.
   void HeartbeatThread();
@@ -67,18 +67,19 @@ class DataServiceWorkerImpl {
     std::unique_ptr<standalone::Iterator> iterator;
   } Task;

-  const std::string master_address_;
-  // Protocol for communicating with the master.
+  const std::string dispatcher_address_;
+  // Protocol for communicating with the dispatcher.
   const std::string protocol_;
   // The worker's own address.
   std::string worker_address_;

   mutex mu_;
   int64 worker_id_ TF_GUARDED_BY(mu_);
-  std::unique_ptr<MasterService::Stub> master_stub_ TF_GUARDED_BY(mu_);
+  std::unique_ptr<DispatcherService::Stub> dispatcher_stub_ TF_GUARDED_BY(mu_);
   // Information about tasks, keyed by task ids.
   absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
-  // List of completed tasks which haven't yet been communicated to the master.
+  // List of completed tasks which haven't yet been communicated to the
+  // dispatcher.
   std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
   bool cancelled_ TF_GUARDED_BY(mu_) = false;
   // Condition variable for notifying the heartbeat thread.
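To make the renamed surface concrete, here is a minimal Python sketch of the worker lifecycle the C++ above implements: register with the dispatcher, retrying until it is reachable, then push completed-task updates from a heartbeat loop. The stub object, its method names, and the 5-second interval are illustrative assumptions, not the real RPC API.

```python
import time

HEARTBEAT_INTERVAL_SECONDS = 5  # stand-in for kHeartbeatIntervalMicros


class WorkerSketch(object):
  """Toy model of DataServiceWorkerImpl's dispatcher-facing behavior."""

  def __init__(self, dispatcher_stub, worker_address):
    self._stub = dispatcher_stub  # hypothetical RPC stub, not the real API
    self._worker_address = worker_address
    self._pending_completed_tasks = []

  def start(self):
    # Retry registration until the dispatcher is reachable, as Start() does.
    while True:
      try:
        self._stub.register_worker(self._worker_address)
        return
      except ConnectionError as err:
        print("Failed to register with dispatcher:", err)
        time.sleep(HEARTBEAT_INTERVAL_SECONDS)

  def heartbeat_loop(self):
    # Periodically report completed tasks, as HeartbeatThread() does.
    while True:
      if self._pending_completed_tasks:
        self._stub.worker_update(list(self._pending_completed_tasks))
        del self._pending_completed_tasks[:]
      time.sleep(HEARTBEAT_INTERVAL_SECONDS)
```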
@@ -69,7 +69,7 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000;  // 1 second.
 // Dataset for reading data from the tf.data service non-deterministically.
 //
 // This dataset interleaves dataset elements produced by multiple tf.data
-// workers. We periodically query the tf.data master to determine which workers
+// workers. We periodically query the dispatcher to determine which workers
 // to read from (in case workers are added or removed).
 class DataServiceDatasetOp::Dataset : public DatasetBase {
  public:
@@ -199,12 +199,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       VLOG(3) << "Connecting to " << dataset()->address_
               << " in data service dataset op";
-      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
+      DataServiceDispatcherClient dispatcher(dataset()->address_,
+                                             dataset()->protocol_);
       if (dataset()->job_name_.empty()) {
-        TF_RETURN_IF_ERROR(master.CreateJob(
+        TF_RETURN_IF_ERROR(dispatcher.CreateJob(
             dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
       } else {
-        TF_RETURN_IF_ERROR(master.GetOrCreateJob(
+        TF_RETURN_IF_ERROR(dispatcher.GetOrCreateJob(
             dataset()->dataset_id_, dataset()->processing_mode_,
             dataset()->job_name_, iterator_index_, &job_id_));
       }
@@ -283,11 +284,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {

     // Periodically refresh the task list.
     // Maintain one thread fetching elements for each task.
-    // TODO(aaudibert): Instead of polling, have master send updates when
+    // TODO(aaudibert): Instead of polling, have dispatcher send updates when
     // the list of tasks changes.
     void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
       VLOG(3) << "Starting task thread manager";
-      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
+      DataServiceDispatcherClient dispatcher(dataset()->address_,
+                                             dataset()->protocol_);
       uint64 next_check = Env::Default()->NowMicros();
       while (true) {
         {
@@ -305,18 +307,19 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
             return;
           }
         }
-        UpdateTasks(&master);
+        UpdateTasks(&dispatcher);
         UpdateWorkerThreads(ctx.get());
         next_check = Env::Default()->NowMicros() +
                      dataset()->task_refresh_interval_ms_ * 1000;
       }
     }

-    void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) {
+    void UpdateTasks(DataServiceDispatcherClient* dispatcher)
+        LOCKS_EXCLUDED(mu_) {
       VLOG(3) << "Updating tasks";
       std::vector<TaskInfo> tasks;
       bool job_finished;
-      Status s = master->GetTasks(job_id_, &tasks, &job_finished);
+      Status s = dispatcher->GetTasks(job_id_, &tasks, &job_finished);
       if (!s.ok()) {
         LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
                      << s;
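The TODO above notes that the client currently polls rather than receiving pushed updates. A rough Python illustration of that refresh loop, under the assumption of a hypothetical `dispatcher.get_tasks` client call (the real logic lives in the C++ TaskThreadManager/UpdateTasks above):

```python
import time


def task_thread_manager(dispatcher, job_id, refresh_interval_ms=1000,
                        should_stop=lambda: False):
  """Polls a (hypothetical) dispatcher client for task changes."""
  tasks = {}
  next_check = time.monotonic()
  while not should_stop():
    if time.monotonic() >= next_check:
      # Mirrors UpdateTasks(): fetch the task list and reconcile readers.
      task_list, job_finished = dispatcher.get_tasks(job_id)
      tasks = {task.task_id: task for task in task_list}
      if job_finished:
        break
      next_check = time.monotonic() + refresh_interval_ms / 1000.0
    time.sleep(0.01)
  return tasks
```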
@@ -53,7 +53,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(
       ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));

-  DataServiceMasterClient client(address, protocol);
+  DataServiceDispatcherClient client(address, protocol);
   int64 dataset_id;
   OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));
@@ -25,7 +25,7 @@ namespace data {

 // Registers a dataset with the tf.data service.
 //
-// The address and protocol inputs are used to connect to the tf.data master.
+// The address and protocol inputs are used to connect to the dispatcher.
 // The external state policy attribute determines whether to ignore, warn, or
 // error out when the dataset contains external state.
 // The op produces a dataset id for identifying the registered dataset.
@@ -77,7 +77,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
         amount of memory used, since `distribute` won't use more than
         `element_size` * `max_outstanding_requests` of memory.
       task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
-        the master for task changes.
+        the dispatcher for task changes.
     """

     if job_name is None:
@@ -173,7 +173,7 @@ def _distribute(processing_mode,
       of memory used, since `distribute` won't use more than `element_size` *
       `max_outstanding_requests` of memory.
     task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
-      master for task changes.
+      dispatcher for task changes.

   Returns:
     Dataset: A `Dataset` of the elements produced by the data service.
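Putting the renamed pieces together, end-to-end usage follows the doctests shown elsewhere in this change (this assumes a TensorFlow build that includes the rename):

```python
import tensorflow as tf

# Start an in-process dispatcher and one worker; port=0 picks a free port.
dispatcher = tf.data.experimental.service.DispatchServer(port=0)
dispatcher_address = dispatcher.target.split("://")[1]
worker = tf.data.experimental.service.WorkerServer(
    port=0, dispatcher_address=dispatcher_address)

# Read a dataset through the service.
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=dispatcher.target))
print(list(dataset.as_numpy_iterator()))  # [0, 1, 2, ..., 9]
```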
@@ -19,5 +19,5 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.data.experimental.ops.data_service_ops import distribute
-from tensorflow.python.data.experimental.service.server_lib import MasterServer
+from tensorflow.python.data.experimental.service.server_lib import DispatchServer
 from tensorflow.python.data.experimental.service.server_lib import WorkerServer
@@ -24,35 +24,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
 from tensorflow.python.util.tf_export import tf_export


-@tf_export("data.experimental.service.MasterServer", v1=[])
-class MasterServer(object):
-  """An in-process tf.data service master server.
+@tf_export("data.experimental.service.DispatchServer", v1=[])
+class DispatchServer(object):
+  """An in-process tf.data service dispatch server.

-  A `tf.data.experimental.service.MasterServer` coordinates a cluster of
+  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
   `tf.data.experimental.service.WorkerServer`s. When the workers start, they
-  register themselves with the master.
+  register themselves with the dispatcher.

-  >>> master = tf.data.experimental.service.MasterServer(port=0)
-  >>> master_address = master.target.split("://")[1]
+  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, master_address=master_address)
+  ...     port=0, dispatcher_address=dispatcher_address)
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
-  ...     processing_mode="parallel_epochs", service=master.target))
+  ...     processing_mode="parallel_epochs", service=dispatcher.target))
   >>> print(list(dataset.as_numpy_iterator()))
   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

-  When starting a dedicated tf.data master process, use join() to block
+  When starting a dedicated tf.data dispatch process, use join() to block
   indefinitely after starting up the server.

   ```
-  master = tf.data.experimental.service.MasterServer(port=5050)
-  master.join()
+  dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
+  dispatcher.join()
   ```
   """

   def __init__(self, port, protocol=None, start=True):
-    """Creates a new master server.
+    """Creates a new dispatch server.

     Args:
       port: Specifies the port to bind to.
@@ -68,15 +68,16 @@ class MasterServer(object):
     if protocol is None:
       protocol = "grpc"
     self._protocol = protocol
-    self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(port, protocol)
+    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(port, protocol)
     if start:
       self._server.start()

   def start(self):
     """Starts this server.

-    >>> master = tf.data.experimental.service.MasterServer(port=0, start=False)
-    >>> master.start()
+    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
+    ...                                                          start=False)
+    >>> dispatcher.start()

     Raises:
       tf.errors.OpError: Or one of its subclasses if an error occurs while
@@ -87,11 +88,11 @@ class MasterServer(object):
   def join(self):
     """Blocks until the server has shut down.

-    This is useful when starting a dedicated master process.
+    This is useful when starting a dedicated dispatch process.

     ```
-    master = tf.data.experimental.service.MasterServer(port=5050)
-    master.join()
+    dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
+    dispatcher.join()
     ```

     Raises:
@@ -104,10 +105,10 @@ class MasterServer(object):
   def target(self):
     """Returns a target that can be used to connect to the server.

-    >>> master = tf.data.experimental.service.MasterServer(port=0)
+    >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
     >>> dataset = tf.data.Dataset.range(10)
     >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
-    ...     processing_mode="parallel_epochs", service=master.target))
+    ...     processing_mode="parallel_epochs", service=dispatcher.target))

     The returned string will be in the form protocol://address, e.g.
     "grpc://localhost:5050".
@@ -136,7 +137,7 @@ class MasterServer(object):
     return "localhost:{0}".format(self._server.bound_port())

   def _num_workers(self):
-    """Returns the number of workers registered with the master."""
+    """Returns the number of workers registered with the dispatcher."""
     return self._server.num_workers()
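For a dedicated dispatch process, the pattern from the docstring above is simply the following (the port value 5050 is an arbitrary choice from the docstring):

```python
import tensorflow as tf

# Run in its own process; join() blocks until the server shuts down.
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
dispatcher.join()
```

Clients on other machines would then pass `service="grpc://<dispatcher-host>:5050"` to `tf.data.experimental.service.distribute`.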
@@ -147,15 +148,15 @@ class WorkerServer(object):
   A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
   processing for user-defined datasets, and provides the resulting elements over
   RPC. A worker is associated with a single
-  `tf.data.experimental.service.MasterServer`.
+  `tf.data.experimental.service.DispatchServer`.

-  >>> master = tf.data.experimental.service.MasterServer(port=0)
-  >>> master_address = master.target.split("://")[1]
+  >>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
+  >>> dispatcher_address = dispatcher.target.split("://")[1]
   >>> worker = tf.data.experimental.service.WorkerServer(
-  ...     port=0, master_address=master_address)
+  ...     port=0, dispatcher_address=dispatcher_address)
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
-  ...     processing_mode="parallel_epochs", service=master.target))
+  ...     processing_mode="parallel_epochs", service=dispatcher.target))
   >>> print(list(dataset.as_numpy_iterator()))
   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

@@ -164,14 +165,14 @@ class WorkerServer(object):

   ```
   worker = tf.data.experimental.service.WorkerServer(
-      port=5051, master_address="grpc://localhost:5050")
+      port=5051, dispatcher_address="grpc://localhost:5050")
   worker.join()
   ```
   """

   def __init__(self,
                port,
-               master_address,
+               dispatcher_address,
                worker_address=None,
                protocol=None,
                start=True):
@@ -180,11 +181,12 @@ class WorkerServer(object):
     Args:
       port: Specifies the port to bind to. A value of 0 indicates that the
         worker can bind to any available port.
-      master_address: Specifies the address of the master server.
+      dispatcher_address: Specifies the address of the dispatcher.
       worker_address: (Optional.) Specifies the address of the worker server.
-        This address is passed to the master server so that the master can tell
-        clients how to connect to this worker. Defaults to `"localhost:%port%"`,
-        where `%port%` will be replaced with the port used by the worker.
+        This address is passed to the dispatcher so that the dispatcher can
+        tell clients how to connect to this worker. Defaults to
+        `"localhost:%port%"`, where `%port%` will be replaced with the port used
+        by the worker.
       protocol: (Optional.) Specifies the protocol to be used by the server.
         Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
       start: (Optional.) Boolean, indicating whether to start the server after
@@ -201,7 +203,7 @@ class WorkerServer(object):

     self._protocol = protocol
     self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
-        port, protocol, master_address, worker_address)
+        port, protocol, dispatcher_address, worker_address)
     if start:
       self._server.start()
@@ -221,7 +223,7 @@ class WorkerServer(object):

     ```
     worker_server = tf.data.experimental.service.WorkerServer(
-        port=5051, master_address="grpc://localhost:5050")
+        port=5051, dispatcher_address="grpc://localhost:5050")
     worker_server.join()
     ```
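A hedged example of the `worker_address` override described in the Args above (host names are placeholders): the address is what the dispatcher relays to clients, so it must be reachable from wherever the dataset is consumed, not just from the dispatcher.

```python
import tensorflow as tf

# A worker that advertises an externally reachable address to the dispatcher.
worker = tf.data.experimental.service.WorkerServer(
    port=5051,
    dispatcher_address="localhost:5050",
    worker_address="worker-host.example.com:5051")
worker.join()
```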
@@ -25,68 +25,68 @@ from tensorflow.python.platform import test

 class ServerLibTest(test.TestCase):

-  def testStartMaster(self):
-    master = server_lib.MasterServer(0, start=False)
-    master.start()
+  def testStartDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0, start=False)
+    dispatcher.start()

-  def testMultipleStartMaster(self):
-    master = server_lib.MasterServer(0, start=True)
-    master.start()
+  def testMultipleStartDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0, start=True)
+    dispatcher.start()

   def testStartWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address, start=False)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
     worker.start()

   def testMultipleStartWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address, start=True)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
     worker.start()

-  def testStopMaster(self):
-    master = server_lib.MasterServer(0)
-    master._stop()
-    master._stop()
+  def testStopDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0)
+    dispatcher._stop()
+    dispatcher._stop()

   def testStopWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address)
     worker._stop()
     worker._stop()

-  def testStopStartMaster(self):
-    master = server_lib.MasterServer(0)
-    master._stop()
+  def testStopStartDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0)
+    dispatcher._stop()
     with self.assertRaisesRegex(
         RuntimeError, "Server cannot be started after it has been stopped"):
-      master.start()
+      dispatcher.start()

   def testStopStartWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address)
     worker._stop()
     with self.assertRaisesRegex(
         RuntimeError, "Server cannot be started after it has been stopped"):
       worker.start()

-  def testJoinMaster(self):
-    master = server_lib.MasterServer(0)
-    master._stop()
-    master.join()
+  def testJoinDispatcher(self):
+    dispatcher = server_lib.DispatchServer(0)
+    dispatcher._stop()
+    dispatcher.join()

   def testJoinWorker(self):
-    master = server_lib.MasterServer(0)
-    worker = server_lib.WorkerServer(0, master._address)
+    dispatcher = server_lib.DispatchServer(0)
+    worker = server_lib.WorkerServer(0, dispatcher._address)
     worker._stop()
     worker.join()

-  def testMasterNumWorkers(self):
-    master = server_lib.MasterServer(0)
-    self.assertEqual(0, master._num_workers())
-    worker1 = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
-    self.assertEqual(1, master._num_workers())
-    worker2 = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
-    self.assertEqual(2, master._num_workers())
+  def testDispatcherNumWorkers(self):
+    dispatcher = server_lib.DispatchServer(0)
+    self.assertEqual(0, dispatcher._num_workers())
+    worker1 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
+    self.assertEqual(1, dispatcher._num_workers())
+    worker2 = server_lib.WorkerServer(0, dispatcher._address)  # pylint: disable=unused-variable
+    self.assertEqual(2, dispatcher._num_workers())


 if __name__ == "__main__":
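The stop/start tests above pin down a lifecycle rule worth stating plainly: stopping is idempotent, and a stopped server cannot be restarted. A small sketch of that behavior, using the private test-only helpers shown above:

```python
from tensorflow.python.data.experimental.service import server_lib

dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()  # stopping twice is fine, per testStopDispatcher
dispatcher._stop()
try:
  dispatcher.start()
except RuntimeError as err:
  print(err)  # "Server cannot be started after it has been stopped"
```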
@@ -28,13 +28,14 @@ limitations under the License.
 namespace py = pybind11;

 PYBIND11_MODULE(_pywrap_server_lib, m) {
-  py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer")
-      .def("start", &tensorflow::data::MasterGrpcDataServer::Start)
-      .def("stop", &tensorflow::data::MasterGrpcDataServer::Stop)
-      .def("join", &tensorflow::data::MasterGrpcDataServer::Join)
-      .def("bound_port", &tensorflow::data::MasterGrpcDataServer::BoundPort)
+  py::class_<tensorflow::data::DispatchGrpcDataServer>(m,
+                                                       "DispatchGrpcDataServer")
+      .def("start", &tensorflow::data::DispatchGrpcDataServer::Start)
+      .def("stop", &tensorflow::data::DispatchGrpcDataServer::Stop)
+      .def("join", &tensorflow::data::DispatchGrpcDataServer::Join)
+      .def("bound_port", &tensorflow::data::DispatchGrpcDataServer::BoundPort)
       .def("num_workers",
-           [](tensorflow::data::MasterGrpcDataServer* server) -> int {
+           [](tensorflow::data::DispatchGrpcDataServer* server) -> int {
             int num_workers;
             tensorflow::Status status = server->NumWorkers(&num_workers);
             tensorflow::MaybeRaiseFromStatus(status);
@@ -48,12 +49,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
       .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);

   m.def(
-      "TF_DATA_NewMasterServer",
+      "TF_DATA_NewDispatchServer",
       [](int port, std::string protocol)
-          -> std::unique_ptr<tensorflow::data::MasterGrpcDataServer> {
-        std::unique_ptr<tensorflow::data::MasterGrpcDataServer> server;
+          -> std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> {
+        std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
         tensorflow::Status status =
-            tensorflow::data::NewMasterServer(port, protocol, &server);
+            tensorflow::data::NewDispatchServer(port, protocol, &server);
         tensorflow::MaybeRaiseFromStatus(status);
         return server;
       },
@@ -61,12 +62,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {

   m.def(
       "TF_DATA_NewWorkerServer",
-      [](int port, std::string protocol, std::string master_address,
+      [](int port, std::string protocol, std::string dispatcher_address,
          std::string worker_address)
           -> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
         std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
         tensorflow::Status status = tensorflow::data::NewWorkerServer(
-            port, protocol, master_address, worker_address, &server);
+            port, protocol, dispatcher_address, worker_address, &server);
         tensorflow::MaybeRaiseFromStatus(status);
         return server;
       },
@@ -59,23 +59,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       num_workers: The number of workers in the cluster.

     Returns:
-      The address of the master.
+      The address of the dispatcher.
     """
-    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
+    self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
     self._servers = []
     for _ in range(num_workers):
       self._servers.append(
           server_lib.WorkerServer(
-              port=0, master_address=self._master._address, protocol=PROTOCOL))
+              port=0,
+              dispatcher_address=self._dispatcher._address,
+              protocol=PROTOCOL))

-    return self._master._address
+    return self._dispatcher._address

   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeBasic(self):
     num_elements = 10
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     results = [elem.numpy() for elem in ds]
     self.assertEqual(list(range(num_elements)), results)
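`_make_distributed_dataset` is a test-local helper whose definition is not part of this diff; presumably it wraps the public `distribute` transformation along these lines (a sketch under that assumption; the real helper may differ in detail):

```python
from tensorflow.python.data.experimental.ops import data_service_ops


def make_distributed_dataset(dataset, dispatcher_address, job_name=None):
  """Reads `dataset` through the tf.data service at `dispatcher_address`."""
  return dataset.apply(
      data_service_ops.distribute(
          processing_mode="parallel_epochs",
          service="grpc://" + dispatcher_address,
          job_name=job_name))
```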
@@ -83,10 +85,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testDifferentShuffleOrders(self):
     random_seed.set_random_seed(None)
     num_elements = 100
-    master_address = self.create_cluster(2)
+    dispatcher_address = self.create_cluster(2)
     ds = dataset_ops.Dataset.range(num_elements)
     ds = ds.shuffle(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     output = [elem.numpy() for elem in ds]

     # The output will be two sequences of range(num_elements)
@@ -104,9 +106,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testMultipleEpochs(self):
     num_elements = 3
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     for _ in range(10):
       self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
@@ -114,9 +116,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testRepeatedDataset(self):
     num_elements = 10
     num_repetitions = 5
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     ds = ds.repeat(num_repetitions)
     self.assertDatasetProduces(
         ds, expected_output=num_repetitions * list(range(num_elements)))
@@ -125,12 +127,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testConcurrentEpoch(self):
     num_elements = 10
     num_datasets = 3
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     iterators = []
     results = []
     for _ in range(num_datasets):
       ds = dataset_ops.Dataset.range(num_elements)
-      ds = _make_distributed_dataset(ds, master_address)
+      ds = _make_distributed_dataset(ds, dispatcher_address)
       iterators.append(iter(ds))
       results.append([])
@@ -146,9 +148,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.skipTest("Not yet implemented")
     num_elements = 10
     num_iterators = 3
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     result = []
     iterators = []
     for _ in range(num_iterators):
@@ -170,20 +172,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testMultiWorker(self):
     num_workers = 3
     num_elements = 10
-    master_address = self.create_cluster(num_workers)
+    dispatcher_address = self.create_cluster(num_workers)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, master_address)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     results = [elem.numpy() for elem in ds]
     self.assertCountEqual(num_workers * list(range(num_elements)), results)

   @combinations.generate(test_base.eager_only_combinations())
   def testAddWorkerMidJob(self):
-    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
+    self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
     self._worker = server_lib.WorkerServer(
-        port=0, master_address=self._master._address, protocol=PROTOCOL)
+        port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, self._master._address)
+    ds = _make_distributed_dataset(ds, self._dispatcher._address)
     iterator = iter(ds)
     results = []
     # Read halfway through the dataset.
@@ -191,10 +193,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       results.append(next(iterator).numpy())

     self._new_worker = server_lib.WorkerServer(
-        port=0, master_address=self._master._address, protocol=PROTOCOL)
+        port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)

-    # Wait for the new worker to register with the master.
-    while self._master._num_workers() < 2:
+    # Wait for the new worker to register with the dispatcher.
+    while self._dispatcher._num_workers() < 2:
       time.sleep(10 / 1000)  # 10ms

     for elem in iterator:
@@ -206,12 +208,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       combinations.times(test_base.eager_only_combinations(),
                          combinations.combine(use_same_port=[True, False])))
   def testRestartWorker(self, use_same_port):
-    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
+    self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
     self._worker = server_lib.WorkerServer(
-        port=0, master_address=self._master._address, protocol=PROTOCOL)
+        port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
-    ds = _make_distributed_dataset(ds, self._master._address)
+    ds = _make_distributed_dataset(ds, self._dispatcher._address)
     iterator = iter(ds)
     # Read halfway through the dataset.
     midpoint = num_elements // 2
@@ -224,7 +226,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     port = int(self._worker._address.split(":")[1])
     self._worker._stop()
     self._new_worker = server_lib.WorkerServer(
-        port=port, master_address=self._master._address, protocol=PROTOCOL)
+        port=port,
+        dispatcher_address=self._dispatcher._address,
+        protocol=PROTOCOL)

     # There may have been some elements prefetched from the first worker
     # before it was stopped.
@@ -259,12 +263,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testInsideFunction(self):
     num_workers = 3
     num_elements = 10
-    master_address = self.create_cluster(num_workers)
+    dispatcher_address = self.create_cluster(num_workers)

     @def_function.function
     def f():
       ds = dataset_ops.Dataset.range(num_elements)
-      ds = _make_distributed_dataset(ds, master_address)
+      ds = _make_distributed_dataset(ds, dispatcher_address)
       result = tensor_array_ops.TensorArray(
           dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
       i = 0
@@ -279,10 +283,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testSharedJobName(self):
     num_elements = 100
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
-    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
+    ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
     iter1 = iter(ds1)
     iter2 = iter(ds2)
     results = []
@@ -298,20 +302,22 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testDifferentJobNames(self):
     num_elements = 10
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1")
-    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2")
+    ds1 = _make_distributed_dataset(
+        ds, dispatcher_address, job_name="job_name1")
+    ds2 = _make_distributed_dataset(
+        ds, dispatcher_address, job_name="job_name2")
     self.assertDatasetProduces(ds1, list(range(num_elements)))
     self.assertDatasetProduces(ds2, list(range(num_elements)))

   @combinations.generate(test_base.eager_only_combinations())
   def testSharedJobNameMultiIteration(self):
     num_elements = 10
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
-    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
+    ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
     # iteration 1
     self.assertDatasetProduces(ds1, list(range(num_elements)))
     self.assertDatasetProduces(ds2, [])
@@ -323,11 +329,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testSharedJobNameRepeat(self):
     num_elements = 100
     num_repetitions = 3
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
-    ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
+    ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
     ds1 = ds1.repeat(num_repetitions)
-    ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
     ds2 = ds2.repeat(num_repetitions)
     results = []
     iter1 = iter(ds1)
@@ -345,7 +351,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.eager_only_combinations())
   def testApplyDeterminismOption(self):
     elements = list(range(10))
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)

     def dataset_fn(delay_ms):
@@ -362,7 +368,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
       opts = dataset_ops.Options()
       opts.experimental_deterministic = False
       ds = ds.with_options(opts)
-      ds = _make_distributed_dataset(ds, master_address)
+      ds = _make_distributed_dataset(ds, dispatcher_address)
       return ds

     self.checkDeterminism(
@@ -379,8 +385,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     options.experimental_external_state_policy = external_state_policy
     ds = ds.with_options(options)

-    master_address = self.create_cluster(3)
-    ds = _make_distributed_dataset(ds, master_address)
+    dispatcher_address = self.create_cluster(3)
+    ds = _make_distributed_dataset(ds, dispatcher_address)
     next(iter(ds))

   @combinations.generate(
@@ -400,12 +406,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):

   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeFromInterleave(self):
-    master_address = self.create_cluster(1)
+    dispatcher_address = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(2)

     def interleave_fn(_):
       ds = dataset_ops.Dataset.range(2)
-      _make_distributed_dataset(ds, master_address)
+      _make_distributed_dataset(ds, dispatcher_address)
       return ds

     with self.assertRaisesRegex(
@@ -1,6 +1,6 @@
-path: "tensorflow.data.experimental.service.MasterServer"
+path: "tensorflow.data.experimental.service.DispatchServer"
 tf_class {
-  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.MasterServer\'>"
+  is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatchServer\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "target"
@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'port\', \'master_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
+    argspec: "args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
   }
   member_method {
     name: "join"
@@ -1,7 +1,7 @@
 path: "tensorflow.data.experimental.service"
 tf_module {
   member {
-    name: "MasterServer"
+    name: "DispatchServer"
     mtype: "<type \'type\'>"
   }
   member {
@@ -99,8 +99,8 @@ tensorflow::data::GrpcDataServerBase::Join
 tensorflow::data::GrpcDataServerBase::Start
 tensorflow::data::GrpcDataServerBase::Stop
 tensorflow::data::GrpcDataServerBase::BoundPort
-tensorflow::data::MasterGrpcDataServer::NumWorkers
-tensorflow::data::NewMasterServer
+tensorflow::data::DispatchGrpcDataServer::NumWorkers
+tensorflow::data::NewDispatchServer
 tensorflow::data::NewWorkerServer

 [protos_all]  # device_lib, dtypes