[tf.data service] Add a DataTransferClient/Server abstraction to support transfer methods other than gRPC

PiperOrigin-RevId: 352868792
Change-Id: I47cfd8473a917ddbd60490fe7cd6a9ba7e1ba34f
This commit is contained in:
A. Unique TensorFlower 2021-01-20 13:54:40 -08:00 committed by TensorFlower Gardener
parent 2263b49a6e
commit 3e3364c6f1
29 changed files with 543 additions and 107 deletions

View File

@ -30,6 +30,8 @@
for synchronous training workloads where example sizes vary. With strict
round robin reads, users can guarantee that consumers get similar-sized
examples in the same step.
* tf.data service supports custom data transfer protocols (other than
gRPC).
## Bug Fixes and Other Changes

View File

@ -20,6 +20,8 @@ load(
"tf_cc_test",
)
package_group(name = "data_transfer_visibility")
package(
default_visibility = [
"//tensorflow:internal",
@ -53,6 +55,9 @@ tf_proto_library(
":common_proto",
"//tensorflow/core/data:dataset_proto",
],
visibility = [
":data_transfer_visibility",
],
)
cc_library(
@ -86,6 +91,7 @@ cc_library(
],
deps = [
":credentials_factory",
":data_transfer",
":dispatcher_cc_grpc_proto",
":grpc_util",
":worker_cc_grpc_proto",
@ -124,6 +130,37 @@ tf_cc_test(
] + tf_protos_profiler_service(),
)
cc_library(
name = "data_transfer",
srcs = ["data_transfer.cc"],
hdrs = ["data_transfer.h"],
visibility = [
":data_transfer_visibility",
],
deps = [
":worker_proto_cc",
"//tensorflow/core/data:dataset_proto_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
tf_cc_test(
name = "data_transfer_test",
srcs = ["data_transfer_test.cc"],
deps = [
":data_transfer",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "dataset_store",
srcs = ["dataset_store.cc"],
@ -335,6 +372,7 @@ cc_library(
],
deps = [
":credentials_factory",
":data_transfer",
":grpc_dispatcher_impl",
":grpc_util",
":grpc_worker_impl",

View File

@ -30,6 +30,8 @@ message TaskDef {
message TaskInfo {
// The address of the worker processing the task.
string worker_address = 1;
// The transfer address of the worker processing the task.
string transfer_address = 4;
// The task id.
int64 task_id = 2;
// The id of the job that the task is part of.

View File

@ -31,6 +31,7 @@ namespace data {
namespace {
constexpr const char kParallelEpochs[] = "parallel_epochs";
constexpr const char kDistributedEpoch[] = "distributed_epoch";
} // namespace
Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) {
@ -57,11 +58,13 @@ std::string ProcessingModeToString(ProcessingMode mode) {
}
Status DataServiceDispatcherClient::WorkerHeartbeat(
const std::string& worker_address, const std::vector<int64>& current_tasks,
std::vector<TaskDef>& new_tasks, std::vector<int64>& tasks_to_delete) {
const std::string& worker_address, const std::string& transfer_address,
const std::vector<int64>& current_tasks, std::vector<TaskDef>& new_tasks,
std::vector<int64>& tasks_to_delete) {
TF_RETURN_IF_ERROR(EnsureInitialized());
WorkerHeartbeatRequest req;
req.set_worker_address(worker_address);
req.set_transfer_address(transfer_address);
for (int64 task : current_tasks) {
req.add_current_tasks(task);
}
@ -244,69 +247,112 @@ Status DataServiceDispatcherClient::EnsureInitialized() {
return Status::OK();
}
// DataTransferClient that fetches elements from a tf.data service worker over
// gRPC, using the WorkerService stub.
class GrpcDataTransferClient : public DataTransferClient {
 public:
  // Opens a channel to `address` using `credentials` and creates the
  // WorkerService stub that GetElement() issues requests through.
  GrpcDataTransferClient(std::shared_ptr<grpc::ChannelCredentials> credentials,
                         std::string address) {
    grpc::ChannelArguments args;
    // NOTE(review): -1 is presumably gRPC's "no limit" value for the received
    // message size -- confirm against the ChannelArguments documentation.
    args.SetMaxReceiveMessageSize(-1);
    auto channel = grpc::CreateCustomChannel(address, credentials, args);
    stub_ = WorkerService::NewStub(channel);
  }

  // Requests the next element of task `task_id` from the worker.
  // `consumer_index` and `round_index` are added to the request only when they
  // hold a value. On success, `end_of_sequence` is set from the response and
  // `element` is overwritten only if `end_of_sequence` is false.
  //
  // Returns Cancelled if TryCancel() was called before the request was
  // registered, or the wrapped gRPC error if the RPC fails.
  Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
                    absl::optional<int64> round_index,
                    CompressedElement& element,
                    bool& end_of_sequence) override {
    GetElementRequest req;
    req.set_task_id(task_id);
    if (consumer_index.has_value()) {
      req.set_consumer_index(consumer_index.value());
    }
    if (round_index.has_value()) {
      req.set_round_index(round_index.value());
    }
    GetElementResponse resp;
    grpc::ClientContext ctx;
    {
      // The cancellation check and the context registration must share one
      // critical section. If they were separate, a TryCancel() landing between
      // them would neither reject this call nor see `ctx` in
      // `active_contexts_`, so the request would run uncancelled.
      mutex_lock l(mu_);
      if (cancelled_) {
        return errors::Cancelled("Client was cancelled.");
      }
      active_contexts_.insert(&ctx);
    }
    grpc::Status s = stub_->GetElement(&ctx, req, &resp);
    {
      // `ctx` is about to go out of scope; make sure TryCancel() can no longer
      // reach it.
      mutex_lock l(mu_);
      active_contexts_.erase(&ctx);
    }
    if (!s.ok()) {
      return grpc_util::WrapError("Failed to get element", s);
    }
    end_of_sequence = resp.end_of_sequence();
    if (!end_of_sequence) {
      element = std::move(*resp.mutable_compressed_element());
    }
    return Status::OK();
  }

  // Marks the client as cancelled, so later GetElement() calls return
  // Cancelled, and asks gRPC to cancel every request currently in flight.
  void TryCancel() override {
    mutex_lock l(mu_);
    cancelled_ = true;
    for (const auto& ctx : active_contexts_) {
      ctx->TryCancel();
    }
  }

 private:
  mutex mu_;
  // Written only in the constructor; used by GetElement() without holding
  // `mu_`.
  std::unique_ptr<WorkerService::Stub> stub_;
  // Set of all currently active client contexts. Used to support
  // cancellation.
  absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
  // Indicates that the client has been cancelled, so no further requests should
  // be accepted.
  bool cancelled_ GUARDED_BY(mu_) = false;
};
class GrpcTransferClientRegistrar {
public:
GrpcTransferClientRegistrar() {
DataTransferClient::Register(
"grpc", [](DataTransferClient::Config config,
std::unique_ptr<DataTransferClient>* out) {
std::shared_ptr<grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
config.protocol, &credentials));
*out = std::make_unique<GrpcDataTransferClient>(credentials,
config.address);
return Status::OK();
});
}
};
static GrpcTransferClientRegistrar registrar;
Status DataServiceWorkerClient::GetElement(int64 task_id,
absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
CompressedElement& element,
bool& end_of_sequence) {
TF_RETURN_IF_ERROR(EnsureInitialized());
{
mutex_lock l(mu_);
if (cancelled_) {
return errors::Cancelled("Client was cancelled.");
}
}
GetElementRequest req;
req.set_task_id(task_id);
if (consumer_index.has_value()) {
req.set_consumer_index(consumer_index.value());
}
if (round_index.has_value()) {
req.set_round_index(round_index.value());
}
GetElementResponse resp;
grpc::ClientContext ctx;
{
mutex_lock l(mu_);
active_contexts_.insert(&ctx);
}
grpc::Status s = stub_->GetElement(&ctx, req, &resp);
{
mutex_lock l(mu_);
active_contexts_.erase(&ctx);
}
if (!s.ok()) {
return grpc_util::WrapError("Failed to get element", s);
}
end_of_sequence = resp.end_of_sequence();
if (!end_of_sequence) {
element = std::move(*resp.mutable_compressed_element());
}
return Status::OK();
return client_->GetElement(task_id, consumer_index, round_index, element,
end_of_sequence);
}
Status DataServiceWorkerClient::EnsureInitialized() {
mutex_lock l(mu_);
if (stub_) {
if (client_) {
return Status::OK();
}
std::shared_ptr<grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(-1);
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
stub_ = WorkerService::NewStub(channel);
TF_RETURN_IF_ERROR(DataTransferClient::Build(
transfer_protocol_, {protocol_, address_}, &client_));
return Status::OK();
}
void DataServiceWorkerClient::TryCancel() {
mutex_lock l(mu_);
cancelled_ = true;
for (const auto& ctx : active_contexts_) {
ctx->TryCancel();
}
}
void DataServiceWorkerClient::TryCancel() { client_->TryCancel(); }
Status CreateDataServiceDispatcherClient(
const std::string& address, const std::string& protocol,
@ -320,8 +366,10 @@ Status CreateDataServiceDispatcherClient(
Status CreateDataServiceWorkerClient(
const std::string& address, const std::string& protocol,
const std::string& transfer_protocol,
std::unique_ptr<DataServiceWorkerClient>& out) {
auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol);
auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol,
transfer_protocol);
TF_RETURN_IF_ERROR(client->Initialize());
out = std::move(client);
return Status::OK();

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "grpcpp/impl/codegen/client_context.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
@ -82,6 +83,7 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
// tasks it should delete. This is stored into `new_tasks` and
// `tasks_to_delete`.
Status WorkerHeartbeat(const std::string& worker_address,
const std::string& transfer_address,
const std::vector<int64>& current_tasks,
std::vector<TaskDef>& new_tasks,
std::vector<int64>& tasks_to_delete);
@ -138,8 +140,10 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
class DataServiceWorkerClient : public DataServiceClientBase {
public:
DataServiceWorkerClient(const std::string& address,
const std::string& protocol)
: DataServiceClientBase(address, protocol) {}
const std::string& protocol,
const std::string& transfer_protocol)
: DataServiceClientBase(address, protocol),
transfer_protocol_(transfer_protocol) {}
// Fetches the next element for the specified task_id. The optional
// `consumer_index` and `round_index` must be specified for tasks which use
@ -158,16 +162,11 @@ class DataServiceWorkerClient : public DataServiceClientBase {
Status EnsureInitialized() override;
private:
const std::string transfer_protocol_;
mutex mu_;
// Initialization is guarded by `mu_`, but using the stub does not require
// holding `mu_`
std::unique_ptr<WorkerService::Stub> stub_;
// Set of all currently active clients contexts. Used to support
// cancellation.
absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
// Indicates that the client has been cancelled, so no further requests should
// be accepted.
bool cancelled_ GUARDED_BY(mu_) = false;
std::unique_ptr<DataTransferClient> client_;
};
// Creates and initializes a new tf.data service dispatcher client.
@ -178,6 +177,7 @@ Status CreateDataServiceDispatcherClient(
// Creates and initializes a new tf.data service worker client.
Status CreateDataServiceWorkerClient(
const std::string& address, const std::string& protocol,
const std::string& transfer_protocol,
std::unique_ptr<DataServiceWorkerClient>& out);
} // namespace data

View File

@ -0,0 +1,110 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/data_transfer.h"

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace data {
namespace {

// Returns the single mutex that every Register()/Build() call in this file
// acquires before touching either factory registry.
mutex* get_lock() {
  static mutex registry_mu(LINKER_INITIALIZED);
  return &registry_mu;
}

// Maps a transfer protocol name to the factory that builds its server.
using DataTransferServerFactories =
    std::unordered_map<std::string,
                       std::function<std::shared_ptr<DataTransferServer>(
                           DataTransferServer::GetElementT)>>;

// Returns the server factory registry. The map is heap-allocated on first use
// and never deleted.
DataTransferServerFactories& transfer_server_factories() {
  static DataTransferServerFactories* const registry =
      new DataTransferServerFactories();
  return *registry;
}

// Maps a transfer protocol name to the factory that builds its client.
using DataTransferClientFactories =
    std::unordered_map<std::string, DataTransferClient::FactoryT>;

// Returns the client factory registry. The map is heap-allocated on first use
// and never deleted.
DataTransferClientFactories& transfer_client_factories() {
  static DataTransferClientFactories* const registry =
      new DataTransferClientFactories();
  return *registry;
}

}  // namespace
// Adds `factory` to the server registry under `name`. If `name` is already
// present, the existing entry is left in place and an error is logged.
void DataTransferServer::Register(
    std::string name,
    std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory) {
  mutex_lock registry_lock(*get_lock());
  auto& registry = transfer_server_factories();
  const bool newly_registered = registry.insert({name, factory}).second;
  if (newly_registered) {
    return;
  }
  LOG(ERROR)
      << "Two data transfer server factories are being registered with name "
      << name << ". Which one gets used is undefined.";
}
// Builds a DataTransferServer with the factory registered under `name`,
// passing it `get_element`, and stores the result in `*out`.
//
// Returns NotFound (listing the registered names) if no factory is registered
// under `name`, and Internal if the factory returns a null server. `*out` is
// written only on success.
Status DataTransferServer::Build(std::string name, GetElementT get_element,
                                 std::shared_ptr<DataTransferServer>* out) {
  std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory;
  bool found = false;
  std::vector<string> available_names;
  {
    // Hold the registry lock only while reading the map. The factory is
    // copied out and invoked after the lock is released, so constructing a
    // server neither blocks other Register()/Build() calls nor re-enters this
    // lock if the factory itself uses the registry.
    mutex_lock l(*get_lock());
    const auto& factories = transfer_server_factories();
    auto it = factories.find(name);
    if (it != factories.end()) {
      factory = it->second;
      found = true;
    } else {
      available_names.reserve(factories.size());
      for (const auto& entry : factories) {
        available_names.push_back(entry.first);
      }
    }
  }
  if (found) {
    std::shared_ptr<DataTransferServer> server =
        factory(std::move(get_element));
    if (server == nullptr) {
      // Report the problem here instead of handing the caller a null server
      // together with an OK status.
      return errors::Internal("Data transfer server factory registered for ",
                              "name ", name, " returned a null server.");
    }
    *out = std::move(server);
    return Status::OK();
  }
  return errors::NotFound(
      "No data transfer server factory has been registered for name ", name,
      ". The available names are: [ ", absl::StrJoin(available_names, ", "),
      " ]");
}
// Adds `factory` to the client registry under `name`. If `name` is already
// present, the existing entry is left in place and an error is logged.
void DataTransferClient::Register(std::string name, FactoryT factory) {
  mutex_lock registry_lock(*get_lock());
  auto& registry = transfer_client_factories();
  const bool newly_registered = registry.insert({name, factory}).second;
  if (newly_registered) {
    return;
  }
  LOG(ERROR)
      << "Two data transfer client factories are being registered with name "
      << name << ". Which one gets used is undefined.";
}
// Builds a DataTransferClient with the factory registered under `name`,
// passing it `config`; the factory stores the client in `*out`.
//
// Returns the factory's status, or NotFound (listing the registered names) if
// no factory is registered under `name`.
Status DataTransferClient::Build(std::string name, Config config,
                                 std::unique_ptr<DataTransferClient>* out) {
  FactoryT factory;
  bool found = false;
  std::vector<string> available_names;
  {
    // Hold the registry lock only while reading the map. The factory is
    // copied out and invoked after the lock is released, so client
    // construction neither blocks other Register()/Build() calls nor
    // re-enters this lock if the factory itself uses the registry.
    mutex_lock l(*get_lock());
    const auto& factories = transfer_client_factories();
    auto it = factories.find(name);
    if (it != factories.end()) {
      factory = it->second;
      found = true;
    } else {
      available_names.reserve(factories.size());
      for (const auto& entry : factories) {
        available_names.push_back(entry.first);
      }
    }
  }
  if (found) {
    return factory(std::move(config), out);
  }
  return errors::NotFound(
      "No data transfer client factory has been registered for name ", name,
      ". The available names are: [ ", absl::StrJoin(available_names, ", "),
      " ]");
}
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,87 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_
#include <functional>
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace data {
// Client for communicating with the tf.data service transfer server.
// Implementations are registered by name with Register() and instantiated
// with Build().
class DataTransferClient {
public:
// Parameters handed to a client factory.
struct Config {
// Non-owning view: the string it refers to must stay alive for as long as
// this Config is in use.
// NOTE(review): the gRPC factory passes this to the credentials factory;
// other transports may interpret it differently -- confirm per transport.
absl::string_view protocol;
// Address of the transfer server the client should talk to.
std::string address;
};
// Factory signature: receives the Config and stores the newly created client
// in the output pointer, returning a non-OK status on failure.
using FactoryT =
std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>;
virtual ~DataTransferClient() = default;
// Fetches the next element for the specified task_id. The element's
// compressed tensors will be stored in `element`. If no element is available,
// `end_of_sequence` will be `true`, and `element` will be left unchanged.
virtual Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
tensorflow::data::CompressedElement& element,
bool& end_of_sequence) = 0;
// Makes a best effort to cancel all outstanding calls in progress for the
// client, and causes further calls to return Cancelled status.
virtual void TryCancel() = 0;
// Registers a DataTransferClient factory under `name`. If a factory is
// already registered under `name`, an error is logged and the earlier factory
// is kept.
static void Register(std::string name, FactoryT factory);
// Builds a DataTransferClient from the factory registered under `name`.
// Returns NotFound if no factory is registered under `name`; otherwise
// returns the factory's status.
static Status Build(std::string name, Config config,
std::unique_ptr<DataTransferClient>* out);
};
// Server for communicating with the tf.data service transfer client.
// Implementations are registered by name with Register() and instantiated
// with Build().
class DataTransferServer {
public:
// Callback that handles one GetElement request by filling in the response.
// A server factory receives it when the server is built.
using GetElementT =
std::function<Status(const GetElementRequest*, GetElementResponse*)>;
virtual ~DataTransferServer() = default;
// Starts the server. Once this returns successfully, the server should be
// available for requests.
virtual Status Start() = 0;
// Returns the port that this server is listening on.
virtual int get_port() = 0;
// Registers a DataTransferServer factory under `name`. If a factory is
// already registered under `name`, an error is logged and the earlier factory
// is kept.
static void Register(
std::string name,
std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory);
// Builds a DataTransferServer from the factory registered with `name`,
// passing `get_element` to that factory. Returns NotFound if no factory is
// registered under `name`.
static Status Build(std::string name, GetElementT get_element,
std::shared_ptr<DataTransferServer>* out);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_

View File

@ -0,0 +1,54 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace data {
namespace {
// DataTransferServer stub that records, through a caller-owned flag, whether
// Start() has been invoked. It does not listen on any port.
class TestDataTransferServer : public DataTransferServer {
 public:
  // `start_called` must outlive this server; Start() writes through it.
  explicit TestDataTransferServer(bool* start_called)
      : start_called_(start_called) {}

  // Always reports port 0.
  int get_port() override { return 0; }

  // Sets the caller's flag and reports success.
  Status Start() override {
    *start_called_ = true;
    return Status::OK();
  }

 private:
  bool* start_called_;
};
// Registers a factory under "test", builds a server through the registry, and
// checks that the built server is the one the factory produced: Start() must
// flip the flag, and Build() alone must not.
TEST(DataTransferTest, RegisterDataTransferServerBuilder) {
  bool start_invoked = false;
  auto make_test_server = [&start_invoked](auto /*get_element*/) {
    return std::make_shared<TestDataTransferServer>(&start_invoked);
  };
  DataTransferServer::Register("test", make_test_server);

  std::shared_ptr<DataTransferServer> built;
  TF_ASSERT_OK(DataTransferServer::Build("test", {}, &built));
  EXPECT_FALSE(start_invoked);

  TF_ASSERT_OK(built->Start());
  EXPECT_TRUE(start_invoked);
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -14,6 +14,7 @@ message TaskProgress {
message WorkerHeartbeatRequest {
string worker_address = 1;
string transfer_address = 3;
repeated int64 current_tasks = 2;
}

View File

@ -217,6 +217,8 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat(
}
Update update;
update.mutable_register_worker()->set_worker_address(worker_address);
update.mutable_register_worker()->set_transfer_address(
request->transfer_address());
TF_RETURN_IF_ERROR(Apply(update));
TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks));
@ -584,6 +586,9 @@ Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
create_task->set_task_id(task_id);
create_task->set_job_id(job->job_id);
create_task->set_worker_address(worker_address);
std::shared_ptr<const Worker> worker;
TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
create_task->set_transfer_address(worker->transfer_address);
TF_RETURN_IF_ERROR(Apply(update));
TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
return Status::OK();
@ -686,6 +691,7 @@ Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
for (const auto& task : tasks) {
TaskInfo* task_info = response->mutable_task_info()->Add();
task_info->set_worker_address(task->worker_address);
task_info->set_transfer_address(task->transfer_address);
task_info->set_task_id(task->task_id);
task_info->set_job_id(job->job_id);
}

View File

@ -74,7 +74,8 @@ void DispatcherState::RegisterWorker(
const RegisterWorkerUpdate& register_worker) {
std::string address = register_worker.worker_address();
DCHECK(!workers_.contains(address));
workers_[address] = std::make_shared<Worker>(address);
workers_[address] =
std::make_shared<Worker>(address, register_worker.transfer_address());
tasks_by_worker_[address] =
absl::flat_hash_map<int64, std::shared_ptr<Task>>();
}
@ -146,7 +147,8 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
DCHECK_EQ(task, nullptr);
auto& job = jobs_[create_task.job_id()];
DCHECK_NE(job, nullptr);
task = std::make_shared<Task>(task_id, job, create_task.worker_address());
task = std::make_shared<Task>(task_id, job, create_task.worker_address(),
create_task.transfer_address());
tasks_by_job_[create_task.job_id()].push_back(task);
tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);

View File

@ -69,9 +69,12 @@ class DispatcherState {
// A worker registered with the dispatcher.
struct Worker {
explicit Worker(const std::string& address) : address(address) {}
explicit Worker(const std::string& address,
const std::string& transfer_address)
: address(address), transfer_address(transfer_address) {}
const std::string address;
const std::string transfer_address;
};
// A key for identifying a named job. The key contains a user-specified name,
@ -128,12 +131,17 @@ class DispatcherState {
struct Task {
explicit Task(int64 task_id, const std::shared_ptr<Job>& job,
const std::string& worker_address)
: task_id(task_id), job(job), worker_address(worker_address) {}
const std::string& worker_address,
const std::string& transfer_address)
: task_id(task_id),
job(job),
worker_address(worker_address),
transfer_address(transfer_address) {}
const int64 task_id;
const std::shared_ptr<Job> job;
const std::string worker_address;
const std::string transfer_address;
bool finished = false;
};

View File

@ -31,8 +31,9 @@ GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config,
VLOG(1) << "Registered data service worker";
}
Status GrpcWorkerImpl::Start(const std::string& worker_address) {
return impl_.Start(worker_address);
Status GrpcWorkerImpl::Start(const std::string& worker_address,
const std::string& transfer_address) {
return impl_.Start(worker_address, transfer_address);
}
#define HANDLER(method) \

View File

@ -33,7 +33,16 @@ class GrpcWorkerImpl : public WorkerService::Service {
::grpc::ServerBuilder& server_builder);
~GrpcWorkerImpl() override {}
Status Start(const std::string& worker_address);
Status Start(const std::string& worker_address,
const std::string& transfer_address);
// Returns a callback that forwards GetElement requests to the wrapped worker
// implementation. The callback captures `this`, so it must not be invoked
// after this GrpcWorkerImpl has been destroyed.
std::function<Status(const GetElementRequest*, GetElementResponse*)>
get_element_getter() {
  auto forward_to_impl = [this](const GetElementRequest* req,
                                GetElementResponse* resp) {
    return impl_.GetElement(req, resp);
  };
  return forward_to_impl;
}
#define HANDLER(method) \
::grpc::Status method(::grpc::ServerContext* context, \

View File

@ -27,6 +27,7 @@ message RegisterDatasetUpdate {
message RegisterWorkerUpdate {
string worker_address = 1;
string transfer_address = 2;
}
message NamedJobKeyDef {
@ -71,6 +72,7 @@ message CreateTaskUpdate {
int64 task_id = 1;
int64 job_id = 2;
string worker_address = 4;
string transfer_address = 6;
}
message FinishTaskUpdate {

View File

@ -127,14 +127,27 @@ void WorkerGrpcDataServer::AddDataServiceToBuilder(
}
Status WorkerGrpcDataServer::StartServiceInternal() {
std::string worker_address = config_.worker_address();
if (worker_address.empty()) {
worker_address = absl::StrCat("localhost:", kPortPlaceholder);
std::string base_address = config_.worker_address();
if (base_address.empty()) {
base_address = absl::StrCat("localhost:", kPortPlaceholder);
}
std::string resolved_address = str_util::StringReplace(
worker_address, kPortPlaceholder, absl::StrCat(bound_port()),
std::string worker_address = str_util::StringReplace(
base_address, kPortPlaceholder, absl::StrCat(bound_port()),
/*replace_all=*/false);
TF_RETURN_IF_ERROR(service_->Start(resolved_address));
std::string transfer_address = worker_address;
std::string transfer_protocol = config_.data_transfer_protocol();
if (!transfer_protocol.empty()) {
TF_RETURN_IF_ERROR(DataTransferServer::Build(
transfer_protocol, service_->get_element_getter(), &transfer_server_));
TF_RETURN_IF_ERROR(transfer_server_->Start());
LOG(INFO) << "Data transfer server started at 0.0.0.0:"
<< transfer_server_->get_port();
transfer_address =
str_util::StringReplace(base_address, kPortPlaceholder,
absl::StrCat(transfer_server_->get_port()),
/*replace_all=*/false);
}
TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address));
return Status::OK();
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "grpcpp/server.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
@ -109,6 +110,7 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
const experimental::WorkerConfig config_;
// Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
GrpcWorkerImpl* service_;
std::shared_ptr<DataTransferServer> transfer_server_;
};
// Creates a dispatch tf.data server and stores it in `out_server`.

View File

@ -64,9 +64,11 @@ DataServiceWorkerImpl::~DataServiceWorkerImpl() {
heartbeat_cv_.notify_one();
}
Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
Status DataServiceWorkerImpl::Start(const std::string& worker_address,
const std::string& transfer_address) {
VLOG(3) << "Starting tf.data service worker at address " << worker_address;
worker_address_ = worker_address;
transfer_address_ = transfer_address;
dispatcher_ = absl::make_unique<DataServiceDispatcherClient>(
config_.dispatcher_address(), config_.protocol());
@ -356,8 +358,9 @@ Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
}
std::vector<TaskDef> new_tasks;
std::vector<int64> tasks_to_delete;
TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
worker_address_, current_tasks, new_tasks, tasks_to_delete));
TF_RETURN_IF_ERROR(
dispatcher_->WorkerHeartbeat(worker_address_, transfer_address_,
current_tasks, new_tasks, tasks_to_delete));
mutex_lock l(mu_);
for (const auto& task : new_tasks) {
Status s = ProcessTaskInternal(task);

View File

@ -41,7 +41,8 @@ class DataServiceWorkerImpl {
// constructor because the worker may be binding to port `0`, in which case
// the address isn't known until the worker has started and decided which port
// to bind to.
Status Start(const std::string& worker_address);
Status Start(const std::string& worker_address,
const std::string& transfer_address);
// See worker.proto for API documentation.
@ -81,6 +82,7 @@ class DataServiceWorkerImpl {
const experimental::WorkerConfig config_;
// The worker's own address.
std::string worker_address_;
std::string transfer_address_;
std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
mutex mu_;

View File

@ -51,6 +51,8 @@ namespace data {
/* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
/* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
/* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
/* static */ constexpr const char* const
DataServiceDatasetOp::kDataTransferProtocol;
/* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
/* static */ constexpr const char* const DataServiceDatasetOp::kConsumerIndex;
/* static */ constexpr const char* const DataServiceDatasetOp::kNumConsumers;
@ -80,8 +82,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int op_version, int64 dataset_id,
ProcessingMode processing_mode, const std::string& address,
const std::string& protocol, const std::string& job_name,
absl::optional<int64> consumer_index,
const std::string& protocol,
const std::string& data_transfer_protocol,
const std::string& job_name, absl::optional<int64> consumer_index,
absl::optional<int64> num_consumers, int64 max_outstanding_requests,
int64 task_refresh_interval_ms, IterationCounter* iteration_counter,
bool owns_resource, ResourceHandle iteration_counter_handle,
@ -93,6 +96,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
processing_mode_(processing_mode),
address_(address),
protocol_(protocol),
data_transfer_protocol_(data_transfer_protocol),
job_name_(job_name),
consumer_index_(consumer_index),
num_consumers_(num_consumers),
@ -164,6 +168,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));
inputs.push_back(protocol);
AttrValue data_transfer_protocol;
b->BuildAttrValue(data_transfer_protocol_, &data_transfer_protocol);
Node* job_name;
TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));
inputs.push_back(job_name);
@ -195,11 +202,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
b->BuildAttrValue(task_refresh_interval_ms_,
&task_refresh_interval_hint_ms);
TF_RETURN_IF_ERROR(
b->AddDataset(this, inputs,
{std::make_pair(kTaskRefreshIntervalHintMs,
task_refresh_interval_hint_ms)},
output));
TF_RETURN_IF_ERROR(b->AddDataset(
this, inputs,
{std::make_pair(kTaskRefreshIntervalHintMs,
task_refresh_interval_hint_ms),
std::make_pair(kDataTransferProtocol, data_transfer_protocol)},
output));
return Status::OK();
}
@ -459,8 +467,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}
TaskInfo& task_info = it->second;
std::unique_ptr<DataServiceWorkerClient> worker;
Status s = CreateDataServiceWorkerClient(task_info.worker_address(),
dataset()->protocol_, worker);
Status s = CreateDataServiceWorkerClient(
task_info.transfer_address(), dataset()->protocol_,
dataset()->data_transfer_protocol_, worker);
if (!s.ok()) {
status_ = s;
get_next_cv_.notify_all();
@ -743,6 +752,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
const ProcessingMode processing_mode_;
const tstring address_;
const tstring protocol_;
const tstring data_transfer_protocol_;
const tstring job_name_;
const absl::optional<int64> consumer_index_;
const absl::optional<int64> num_consumers_;
@ -765,6 +775,11 @@ DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx)
}
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
if (ctx->HasAttr(kDataTransferProtocol)) {
OP_REQUIRES_OK(
ctx, ctx->GetAttr(kDataTransferProtocol, &data_transfer_protocol_));
}
if (data_transfer_protocol_.empty()) data_transfer_protocol_ = "grpc";
auto& op_name = ctx->def().op();
if (op_name == kDataServiceDatasetV1) {
op_version_ = 1;
@ -859,11 +874,12 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ",
model::kAutotune));
*output = new Dataset(
ctx, op_version_, dataset_id, processing_mode, address, protocol,
job_name, consumer_index, num_consumers, max_outstanding_requests,
task_refresh_interval_hint_ms_, iteration_counter, owns_resource,
iteration_counter_handle, output_types_, output_shapes_);
*output = new Dataset(ctx, op_version_, dataset_id, processing_mode, address,
protocol, data_transfer_protocol_, job_name,
consumer_index, num_consumers, max_outstanding_requests,
task_refresh_interval_hint_ms_, iteration_counter,
owns_resource, iteration_counter_handle, output_types_,
output_shapes_);
}
REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),

View File

@ -51,6 +51,8 @@ class DataServiceDatasetOp : public DatasetOpKernel {
static constexpr const char* const kProcessingMode = "processing_mode";
static constexpr const char* const kAddress = "address";
static constexpr const char* const kProtocol = "protocol";
static constexpr const char* const kDataTransferProtocol =
"data_transfer_protocol";
static constexpr const char* const kJobName = "job_name";
static constexpr const char* const kConsumerIndex = "consumer_index";
static constexpr const char* const kNumConsumers = "num_consumers";
@ -76,6 +78,7 @@ class DataServiceDatasetOp : public DatasetOpKernel {
int64 task_refresh_interval_hint_ms_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
std::string data_transfer_protocol_;
};
} // namespace data

View File

@ -1184,6 +1184,7 @@ REGISTER_OP("DataServiceDataset")
.Attr("task_refresh_interval_hint_ms: int = -1")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("data_transfer_protocol: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
@ -1203,6 +1204,7 @@ REGISTER_OP("DataServiceDatasetV2")
.Attr("task_refresh_interval_hint_ms: int = -1")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("data_transfer_protocol: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);

View File

@ -40,4 +40,6 @@ message WorkerConfig {
// How long to retry requests to the dispatcher before giving up and reporting
// an error.
int64 dispatcher_timeout_ms = 6;
// The protocol for the worker to use when transferring data to clients.
string data_transfer_protocol = 7;
}

View File

@ -59,6 +59,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
processing_mode,
address,
protocol,
data_transfer_protocol,
job_name=None,
consumer_index=None,
num_consumers=None,
@ -76,6 +77,8 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
address: The tf.data service address, e.g. "localhost:5000".
protocol: The protocol to use for communicating with the tf.data service,
e.g. "grpc".
data_transfer_protocol: The protocol to use for transferring data with the
tf.data service, e.g. "grpc".
job_name: (Optional.) The name of the job. This argument makes it possible
for multiple datasets to share the same job. The default behavior is
that the dataset creates anonymous, exclusively owned jobs.
@ -138,6 +141,10 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
# represented by scalar DT_VARIANTs.
self._element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
compat_kwargs = {}
# Forward `data_transfer_protocol` to the op only when the caller supplied
# one; when it is None the keyword is omitted from the op call entirely.
if data_transfer_protocol is not None:
compat_kwargs["data_transfer_protocol"] = data_transfer_protocol
if num_consumers is None:
variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
dataset_id=self._dataset_id,
@ -149,6 +156,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
iteration_counter=gen_experimental_dataset_ops
.dummy_iteration_counter(),
**compat_kwargs,
**self._flat_structure)
else:
variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2(
@ -163,6 +171,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
iteration_counter=gen_experimental_dataset_ops
.dummy_iteration_counter(),
**compat_kwargs,
**self._flat_structure)
super(_DataServiceDatasetV2, self).__init__(variant_tensor)
@ -175,15 +184,16 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
"""A `Dataset` that executes its input through the tf.data service."""
@functools.wraps(_DataServiceDatasetV2.__init__)
def __init__(self, dataset_id, processing_mode, address, protocol, job_name,
consumer_index, num_consumers, max_outstanding_requests,
task_refresh_interval_hint_ms):
def __init__(self, dataset_id, processing_mode, address, protocol,
data_transfer_protocol, job_name, consumer_index, num_consumers,
max_outstanding_requests, task_refresh_interval_hint_ms):
self._wrapped = _DataServiceDatasetV2(
dataset_id=dataset_id,
processing_mode=processing_mode,
address=address,
protocol=protocol,
data_transfer_protocol=data_transfer_protocol,
job_name=job_name,
consumer_index=consumer_index,
num_consumers=num_consumers,
@ -233,7 +243,8 @@ def _from_dataset_id(processing_mode,
consumer_index=None,
num_consumers=None,
max_outstanding_requests=None,
task_refresh_interval_hint_ms=None):
task_refresh_interval_hint_ms=None,
data_transfer_protocol=None):
"""Creates a dataset which reads data from the tf.data service.
This transformation is similar to `from_dataset_id`, but supports additional
@ -274,6 +285,8 @@ def _from_dataset_id(processing_mode,
`max_outstanding_requests` of memory.
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
dispatcher for task changes.
data_transfer_protocol: (Optional.) The protocol to use for transferring
data with the tf.data service.
Returns:
A `tf.data.Dataset` which reads from the tf.data service.
@ -294,6 +307,7 @@ def _from_dataset_id(processing_mode,
processing_mode=processing_mode,
address=address,
protocol=protocol,
data_transfer_protocol=data_transfer_protocol,
job_name=job_name,
consumer_index=consumer_index,
num_consumers=num_consumers,
@ -317,7 +331,8 @@ def _distribute(processing_mode,
consumer_index=None,
num_consumers=None,
max_outstanding_requests=None,
task_refresh_interval_hint_ms=None):
task_refresh_interval_hint_ms=None,
data_transfer_protocol=None):
"""A transformation that moves dataset processing to the tf.data service.
This transformation is similar to `distribute`, but supports additional
@ -352,6 +367,8 @@ def _distribute(processing_mode,
`max_outstanding_requests` of memory.
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
dispatcher for task changes.
data_transfer_protocol: (Optional.) The protocol to use for transferring
data with the tf.data service.
Returns:
Dataset: A `Dataset` of the elements produced by the data service.
@ -369,7 +386,8 @@ def _distribute(processing_mode,
consumer_index=consumer_index,
num_consumers=num_consumers,
max_outstanding_requests=max_outstanding_requests,
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
data_transfer_protocol=data_transfer_protocol)
return _apply_fn
@ -380,7 +398,8 @@ def distribute(processing_mode,
job_name=None,
consumer_index=None,
num_consumers=None,
max_outstanding_requests=None):
max_outstanding_requests=None,
data_transfer_protocol=None):
"""A transformation that moves dataset processing to the tf.data service.
When you iterate over a dataset containing the `distribute` transformation,
@ -580,6 +599,8 @@ def distribute(processing_mode,
requested at the same time. You can use this option to control the amount
of memory used, since `distribute` won't use more than `element_size` *
`max_outstanding_requests` of memory.
data_transfer_protocol: (Optional.) The protocol to use for transferring
data with the tf.data service, e.g. "grpc".
Returns:
Dataset: A `Dataset` of the elements produced by the data service.
@ -590,7 +611,8 @@ def distribute(processing_mode,
job_name=job_name,
consumer_index=consumer_index,
num_consumers=num_consumers,
max_outstanding_requests=max_outstanding_requests)
max_outstanding_requests=max_outstanding_requests,
data_transfer_protocol=data_transfer_protocol)
@tf_export("data.experimental.service.register_dataset")

View File

@ -307,7 +307,8 @@ class WorkerServer(object):
port=config.port,
protocol=config.protocol,
heartbeat_interval_ms=config.heartbeat_interval_ms,
dispatcher_timeout_ms=config.dispatcher_timeout_ms)
dispatcher_timeout_ms=config.dispatcher_timeout_ms,
# NOTE(review): hard-coded to None rather than read from `config`; confirm
# whether the worker config is meant to expose a data transfer protocol.
data_transfer_protocol=None)
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
config_proto.SerializeToString())
if start:

View File

@ -10,7 +10,7 @@ tf_module {
}
member_method {
name: "distribute"
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'data_transfer_protocol\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_dataset_id"

View File

@ -1006,11 +1006,11 @@ tf_module {
}
member_method {
name: "DataServiceDataset"
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], "
}
member_method {
name: "DataServiceDatasetV2"
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], "
}
member_method {
name: "DatasetCardinality"

View File

@ -18,7 +18,7 @@ tf_module {
}
member_method {
name: "distribute"
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'data_transfer_protocol\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_dataset_id"

View File

@ -1006,11 +1006,11 @@ tf_module {
}
member_method {
name: "DataServiceDataset"
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], "
}
member_method {
name: "DataServiceDatasetV2"
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], "
}
member_method {
name: "DatasetCardinality"