[tf.data service] Add a DataTransferClient/Server abstraction to support transfer methods other than gRPC
PiperOrigin-RevId: 352868792 Change-Id: I47cfd8473a917ddbd60490fe7cd6a9ba7e1ba34f
This commit is contained in:
parent
2263b49a6e
commit
3e3364c6f1
@ -30,6 +30,8 @@
|
||||
for synchronous training workloads where example sizes vary. With strict
|
||||
round robin reads, users can guarantee that consumers get similar-sized
|
||||
examples in the same step.
|
||||
* tf.data service supports custom data transfer protocols (other than
|
||||
gRPC).
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
|
@ -20,6 +20,8 @@ load(
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package_group(name = "data_transfer_visibility")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:internal",
|
||||
@ -53,6 +55,9 @@ tf_proto_library(
|
||||
":common_proto",
|
||||
"//tensorflow/core/data:dataset_proto",
|
||||
],
|
||||
visibility = [
|
||||
":data_transfer_visibility",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -86,6 +91,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":credentials_factory",
|
||||
":data_transfer",
|
||||
":dispatcher_cc_grpc_proto",
|
||||
":grpc_util",
|
||||
":worker_cc_grpc_proto",
|
||||
@ -124,6 +130,37 @@ tf_cc_test(
|
||||
] + tf_protos_profiler_service(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "data_transfer",
|
||||
srcs = ["data_transfer.cc"],
|
||||
hdrs = ["data_transfer.h"],
|
||||
visibility = [
|
||||
":data_transfer_visibility",
|
||||
],
|
||||
deps = [
|
||||
":worker_proto_cc",
|
||||
"//tensorflow/core/data:dataset_proto_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:mutex",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "data_transfer_test",
|
||||
srcs = ["data_transfer_test.cc"],
|
||||
deps = [
|
||||
":data_transfer",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dataset_store",
|
||||
srcs = ["dataset_store.cc"],
|
||||
@ -335,6 +372,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":credentials_factory",
|
||||
":data_transfer",
|
||||
":grpc_dispatcher_impl",
|
||||
":grpc_util",
|
||||
":grpc_worker_impl",
|
||||
|
@ -30,6 +30,8 @@ message TaskDef {
|
||||
message TaskInfo {
|
||||
// The address of the worker processing the task.
|
||||
string worker_address = 1;
|
||||
// The transfer address of the worker processing the task.
|
||||
string transfer_address = 4;
|
||||
// The task id.
|
||||
int64 task_id = 2;
|
||||
// The id of the job that the task is part of.
|
||||
|
@ -31,6 +31,7 @@ namespace data {
|
||||
namespace {
|
||||
constexpr const char kParallelEpochs[] = "parallel_epochs";
|
||||
constexpr const char kDistributedEpoch[] = "distributed_epoch";
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) {
|
||||
@ -57,11 +58,13 @@ std::string ProcessingModeToString(ProcessingMode mode) {
|
||||
}
|
||||
|
||||
Status DataServiceDispatcherClient::WorkerHeartbeat(
|
||||
const std::string& worker_address, const std::vector<int64>& current_tasks,
|
||||
std::vector<TaskDef>& new_tasks, std::vector<int64>& tasks_to_delete) {
|
||||
const std::string& worker_address, const std::string& transfer_address,
|
||||
const std::vector<int64>& current_tasks, std::vector<TaskDef>& new_tasks,
|
||||
std::vector<int64>& tasks_to_delete) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
WorkerHeartbeatRequest req;
|
||||
req.set_worker_address(worker_address);
|
||||
req.set_transfer_address(transfer_address);
|
||||
for (int64 task : current_tasks) {
|
||||
req.add_current_tasks(task);
|
||||
}
|
||||
@ -244,69 +247,112 @@ Status DataServiceDispatcherClient::EnsureInitialized() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class GrpcDataTransferClient : public DataTransferClient {
|
||||
public:
|
||||
GrpcDataTransferClient(std::shared_ptr<grpc::ChannelCredentials> credentials,
|
||||
std::string address) {
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(-1);
|
||||
auto channel = grpc::CreateCustomChannel(address, credentials, args);
|
||||
stub_ = WorkerService::NewStub(channel);
|
||||
}
|
||||
|
||||
Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
CompressedElement& element,
|
||||
bool& end_of_sequence) override {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (cancelled_) {
|
||||
return errors::Cancelled("Client was cancelled.");
|
||||
}
|
||||
}
|
||||
GetElementRequest req;
|
||||
req.set_task_id(task_id);
|
||||
if (consumer_index.has_value()) {
|
||||
req.set_consumer_index(consumer_index.value());
|
||||
}
|
||||
if (round_index.has_value()) {
|
||||
req.set_round_index(round_index.value());
|
||||
}
|
||||
GetElementResponse resp;
|
||||
grpc::ClientContext ctx;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
active_contexts_.insert(&ctx);
|
||||
}
|
||||
grpc::Status s = stub_->GetElement(&ctx, req, &resp);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
active_contexts_.erase(&ctx);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to get element", s);
|
||||
}
|
||||
end_of_sequence = resp.end_of_sequence();
|
||||
if (!end_of_sequence) {
|
||||
element = std::move(*resp.mutable_compressed_element());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void TryCancel() override {
|
||||
mutex_lock l(mu_);
|
||||
cancelled_ = true;
|
||||
for (const auto& ctx : active_contexts_) {
|
||||
ctx->TryCancel();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
std::unique_ptr<WorkerService::Stub> stub_;
|
||||
// Set of all currently active clients contexts. Used to support
|
||||
// cancellation.
|
||||
absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
|
||||
// Indicates that the client has been cancelled, so no further requests should
|
||||
// be accepted.
|
||||
bool cancelled_ GUARDED_BY(mu_) = false;
|
||||
};
|
||||
|
||||
class GrpcTransferClientRegistrar {
|
||||
public:
|
||||
GrpcTransferClientRegistrar() {
|
||||
DataTransferClient::Register(
|
||||
"grpc", [](DataTransferClient::Config config,
|
||||
std::unique_ptr<DataTransferClient>* out) {
|
||||
std::shared_ptr<grpc::ChannelCredentials> credentials;
|
||||
TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
|
||||
config.protocol, &credentials));
|
||||
*out = std::make_unique<GrpcDataTransferClient>(credentials,
|
||||
config.address);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
};
|
||||
static GrpcTransferClientRegistrar registrar;
|
||||
|
||||
Status DataServiceWorkerClient::GetElement(int64 task_id,
|
||||
absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
CompressedElement& element,
|
||||
bool& end_of_sequence) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (cancelled_) {
|
||||
return errors::Cancelled("Client was cancelled.");
|
||||
}
|
||||
}
|
||||
GetElementRequest req;
|
||||
req.set_task_id(task_id);
|
||||
if (consumer_index.has_value()) {
|
||||
req.set_consumer_index(consumer_index.value());
|
||||
}
|
||||
if (round_index.has_value()) {
|
||||
req.set_round_index(round_index.value());
|
||||
}
|
||||
GetElementResponse resp;
|
||||
grpc::ClientContext ctx;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
active_contexts_.insert(&ctx);
|
||||
}
|
||||
grpc::Status s = stub_->GetElement(&ctx, req, &resp);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
active_contexts_.erase(&ctx);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to get element", s);
|
||||
}
|
||||
end_of_sequence = resp.end_of_sequence();
|
||||
if (!end_of_sequence) {
|
||||
element = std::move(*resp.mutable_compressed_element());
|
||||
}
|
||||
return Status::OK();
|
||||
return client_->GetElement(task_id, consumer_index, round_index, element,
|
||||
end_of_sequence);
|
||||
}
|
||||
|
||||
Status DataServiceWorkerClient::EnsureInitialized() {
|
||||
mutex_lock l(mu_);
|
||||
if (stub_) {
|
||||
if (client_) {
|
||||
return Status::OK();
|
||||
}
|
||||
std::shared_ptr<grpc::ChannelCredentials> credentials;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(-1);
|
||||
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
|
||||
stub_ = WorkerService::NewStub(channel);
|
||||
TF_RETURN_IF_ERROR(DataTransferClient::Build(
|
||||
transfer_protocol_, {protocol_, address_}, &client_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DataServiceWorkerClient::TryCancel() {
|
||||
mutex_lock l(mu_);
|
||||
cancelled_ = true;
|
||||
for (const auto& ctx : active_contexts_) {
|
||||
ctx->TryCancel();
|
||||
}
|
||||
}
|
||||
void DataServiceWorkerClient::TryCancel() { client_->TryCancel(); }
|
||||
|
||||
Status CreateDataServiceDispatcherClient(
|
||||
const std::string& address, const std::string& protocol,
|
||||
@ -320,8 +366,10 @@ Status CreateDataServiceDispatcherClient(
|
||||
|
||||
Status CreateDataServiceWorkerClient(
|
||||
const std::string& address, const std::string& protocol,
|
||||
const std::string& transfer_protocol,
|
||||
std::unique_ptr<DataServiceWorkerClient>& out) {
|
||||
auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol);
|
||||
auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol,
|
||||
transfer_protocol);
|
||||
TF_RETURN_IF_ERROR(client->Initialize());
|
||||
out = std::move(client);
|
||||
return Status::OK();
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "grpcpp/impl/codegen/client_context.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/data/service/data_transfer.h"
|
||||
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
@ -82,6 +83,7 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
|
||||
// tasks it should delete. This is stored into `new_tasks` and
|
||||
// `tasks_to_delete`.
|
||||
Status WorkerHeartbeat(const std::string& worker_address,
|
||||
const std::string& transfer_address,
|
||||
const std::vector<int64>& current_tasks,
|
||||
std::vector<TaskDef>& new_tasks,
|
||||
std::vector<int64>& tasks_to_delete);
|
||||
@ -138,8 +140,10 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
|
||||
class DataServiceWorkerClient : public DataServiceClientBase {
|
||||
public:
|
||||
DataServiceWorkerClient(const std::string& address,
|
||||
const std::string& protocol)
|
||||
: DataServiceClientBase(address, protocol) {}
|
||||
const std::string& protocol,
|
||||
const std::string& transfer_protocol)
|
||||
: DataServiceClientBase(address, protocol),
|
||||
transfer_protocol_(transfer_protocol) {}
|
||||
|
||||
// Fetches the next element for the specified task_id. The optional
|
||||
// `consumer_index` and `round_index` must be specified for tasks which use
|
||||
@ -158,16 +162,11 @@ class DataServiceWorkerClient : public DataServiceClientBase {
|
||||
Status EnsureInitialized() override;
|
||||
|
||||
private:
|
||||
const std::string transfer_protocol_;
|
||||
mutex mu_;
|
||||
// Initialization is guarded by `mu_`, but using the stub does not require
|
||||
// holding `mu_`
|
||||
std::unique_ptr<WorkerService::Stub> stub_;
|
||||
// Set of all currently active clients contexts. Used to support
|
||||
// cancellation.
|
||||
absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
|
||||
// Indicates that the client has been cancelled, so no further requests should
|
||||
// be accepted.
|
||||
bool cancelled_ GUARDED_BY(mu_) = false;
|
||||
std::unique_ptr<DataTransferClient> client_;
|
||||
};
|
||||
|
||||
// Creates and initializes a new tf.data service dispatcher client.
|
||||
@ -178,6 +177,7 @@ Status CreateDataServiceDispatcherClient(
|
||||
// Creates and initializes a new tf.data service worker client.
|
||||
Status CreateDataServiceWorkerClient(
|
||||
const std::string& address, const std::string& protocol,
|
||||
const std::string& transfer_protocol,
|
||||
std::unique_ptr<DataServiceWorkerClient>& out);
|
||||
|
||||
} // namespace data
|
||||
|
110
tensorflow/core/data/service/data_transfer.cc
Normal file
110
tensorflow/core/data/service/data_transfer.cc
Normal file
@ -0,0 +1,110 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/data/service/data_transfer.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
namespace {
|
||||
mutex* get_lock() {
|
||||
static mutex lock(LINKER_INITIALIZED);
|
||||
return &lock;
|
||||
}
|
||||
|
||||
using DataTransferServerFactories =
|
||||
std::unordered_map<std::string,
|
||||
std::function<std::shared_ptr<DataTransferServer>(
|
||||
DataTransferServer::GetElementT)>>;
|
||||
DataTransferServerFactories& transfer_server_factories() {
|
||||
static auto& factories = *new DataTransferServerFactories();
|
||||
return factories;
|
||||
}
|
||||
|
||||
using DataTransferClientFactories =
|
||||
std::unordered_map<std::string, DataTransferClient::FactoryT>;
|
||||
DataTransferClientFactories& transfer_client_factories() {
|
||||
static auto& factories = *new DataTransferClientFactories();
|
||||
return factories;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void DataTransferServer::Register(
|
||||
std::string name,
|
||||
std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory) {
|
||||
mutex_lock l(*get_lock());
|
||||
if (!transfer_server_factories().insert({name, factory}).second) {
|
||||
LOG(ERROR)
|
||||
<< "Two data transfer server factories are being registered with name "
|
||||
<< name << ". Which one gets used is undefined.";
|
||||
}
|
||||
}
|
||||
|
||||
// Looks up the server factory registered under `name` and stores the server
// it creates in `out`. Returns NotFound, listing the registered names, when no
// factory has been registered for `name`.
Status DataTransferServer::Build(std::string name, GetElementT get_element,
                                 std::shared_ptr<DataTransferServer>* out) {
  std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory;
  bool found = false;
  std::vector<string> available_names;
  {
    // Hold the registry lock only for the lookup. The factory is copied out
    // so that it runs without the lock held; a factory that itself calls
    // Register() or Build() would otherwise deadlock on this mutex.
    mutex_lock l(*get_lock());
    auto it = transfer_server_factories().find(name);
    if (it != transfer_server_factories().end()) {
      factory = it->second;
      found = true;
    } else {
      for (const auto& entry : transfer_server_factories()) {
        available_names.push_back(entry.first);
      }
    }
  }

  if (found) {
    *out = factory(get_element);
    return Status::OK();
  }

  return errors::NotFound(
      "No data transfer server factory has been registered for name ", name,
      ". The available names are: [ ", absl::StrJoin(available_names, ", "),
      " ]");
}
|
||||
|
||||
void DataTransferClient::Register(std::string name, FactoryT factory) {
|
||||
mutex_lock l(*get_lock());
|
||||
if (!transfer_client_factories().insert({name, factory}).second) {
|
||||
LOG(ERROR)
|
||||
<< "Two data transfer client factories are being registered with name "
|
||||
<< name << ". Which one gets used is undefined.";
|
||||
}
|
||||
}
|
||||
|
||||
// Looks up the client factory registered under `name` and invokes it with
// `config`, returning the factory's status; the factory is responsible for
// filling in `out`. Returns NotFound, listing the registered names, when no
// factory has been registered for `name`.
Status DataTransferClient::Build(std::string name, Config config,
                                 std::unique_ptr<DataTransferClient>* out) {
  FactoryT factory;
  bool found = false;
  std::vector<string> available_names;
  {
    // Hold the registry lock only for the lookup. The factory is copied out
    // so that it runs without the lock held: client construction does not
    // serialize other Register()/Build() calls, and a factory that re-enters
    // the registry cannot deadlock on this mutex.
    mutex_lock l(*get_lock());
    auto it = transfer_client_factories().find(name);
    if (it != transfer_client_factories().end()) {
      factory = it->second;
      found = true;
    } else {
      for (const auto& entry : transfer_client_factories()) {
        available_names.push_back(entry.first);
      }
    }
  }

  if (found) {
    return factory(config, out);
  }

  return errors::NotFound(
      "No data transfer client factory has been registered for name ", name,
      ". The available names are: [ ", absl::StrJoin(available_names, ", "),
      " ]");
}
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
87
tensorflow/core/data/service/data_transfer.h
Normal file
87
tensorflow/core/data/service/data_transfer.h
Normal file
@ -0,0 +1,87 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_
|
||||
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/data/dataset.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// Client for communicating with the tf.data service transfer server.
|
||||
class DataTransferClient {
|
||||
public:
|
||||
struct Config {
|
||||
absl::string_view protocol;
|
||||
std::string address;
|
||||
};
|
||||
using FactoryT =
|
||||
std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>;
|
||||
virtual ~DataTransferClient() = default;
|
||||
|
||||
// Fetches the next element for the specified task_id. The element's
|
||||
// compressed tensors will be stored in `element`. If no element is available,
|
||||
// `end_of_sequence` will be `true`, and `element` will be left unchanged.
|
||||
virtual Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
tensorflow::data::CompressedElement& element,
|
||||
bool& end_of_sequence) = 0;
|
||||
|
||||
// Makes a best effort to cancel all outstanding calls in progress for the
|
||||
// client, and causes further calls to return Cancelled status.
|
||||
virtual void TryCancel() = 0;
|
||||
|
||||
// Registers a DataTransferClient factory under `name`.
|
||||
static void Register(std::string name, FactoryT factory);
|
||||
|
||||
// Builds a DataTransferClient from the factory registered under `name`.
|
||||
static Status Build(std::string name, Config config,
|
||||
std::unique_ptr<DataTransferClient>* out);
|
||||
};
|
||||
|
||||
// Server for communicating with the tf.data service transfer client.
|
||||
class DataTransferServer {
|
||||
public:
|
||||
using GetElementT =
|
||||
std::function<Status(const GetElementRequest*, GetElementResponse*)>;
|
||||
virtual ~DataTransferServer() = default;
|
||||
|
||||
// Starts DataTransferServer, it should be available for requests afterwards.
|
||||
virtual Status Start() = 0;
|
||||
|
||||
// Return the port that this server is listening on.
|
||||
virtual int get_port() = 0;
|
||||
|
||||
// Register a DataTransferServer factory under `name`.
|
||||
static void Register(
|
||||
std::string name,
|
||||
std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory);
|
||||
|
||||
// Builds a DataTransferServer from the factory registered with `name`.
|
||||
static Status Build(std::string name, GetElementT get_element,
|
||||
std::shared_ptr<DataTransferServer>* out);
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_
|
54
tensorflow/core/data/service/data_transfer_test.cc
Normal file
54
tensorflow/core/data/service/data_transfer_test.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/data/service/data_transfer.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
class TestDataTransferServer : public DataTransferServer {
|
||||
public:
|
||||
explicit TestDataTransferServer(bool* called) : called_(called) {}
|
||||
Status Start() override {
|
||||
*called_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
int get_port() override { return 0; }
|
||||
|
||||
private:
|
||||
bool* called_;
|
||||
};
|
||||
|
||||
// Checks that a factory registered under "test" is the one Build() uses, and
// that building the server does not start it.
TEST(DataTransferTest, RegisterDataTransferServerBuilder) {
  bool started = false;
  // The factory ignores its GetElement callback and wires the server to
  // `started`.
  auto make_test_server = [&started](auto /*get_element*/) {
    return std::make_shared<TestDataTransferServer>(&started);
  };
  DataTransferServer::Register("test", make_test_server);

  std::shared_ptr<DataTransferServer> built_server;
  TF_ASSERT_OK(DataTransferServer::Build("test", {}, &built_server));
  EXPECT_FALSE(started);

  TF_ASSERT_OK(built_server->Start());
  EXPECT_TRUE(started);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -14,6 +14,7 @@ message TaskProgress {
|
||||
|
||||
message WorkerHeartbeatRequest {
|
||||
string worker_address = 1;
|
||||
string transfer_address = 3;
|
||||
repeated int64 current_tasks = 2;
|
||||
}
|
||||
|
||||
|
@ -217,6 +217,8 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat(
|
||||
}
|
||||
Update update;
|
||||
update.mutable_register_worker()->set_worker_address(worker_address);
|
||||
update.mutable_register_worker()->set_transfer_address(
|
||||
request->transfer_address());
|
||||
TF_RETURN_IF_ERROR(Apply(update));
|
||||
TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
|
||||
TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks));
|
||||
@ -584,6 +586,9 @@ Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
|
||||
create_task->set_task_id(task_id);
|
||||
create_task->set_job_id(job->job_id);
|
||||
create_task->set_worker_address(worker_address);
|
||||
std::shared_ptr<const Worker> worker;
|
||||
TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
|
||||
create_task->set_transfer_address(worker->transfer_address);
|
||||
TF_RETURN_IF_ERROR(Apply(update));
|
||||
TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
|
||||
return Status::OK();
|
||||
@ -686,6 +691,7 @@ Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
|
||||
for (const auto& task : tasks) {
|
||||
TaskInfo* task_info = response->mutable_task_info()->Add();
|
||||
task_info->set_worker_address(task->worker_address);
|
||||
task_info->set_transfer_address(task->transfer_address);
|
||||
task_info->set_task_id(task->task_id);
|
||||
task_info->set_job_id(job->job_id);
|
||||
}
|
||||
|
@ -74,7 +74,8 @@ void DispatcherState::RegisterWorker(
|
||||
const RegisterWorkerUpdate& register_worker) {
|
||||
std::string address = register_worker.worker_address();
|
||||
DCHECK(!workers_.contains(address));
|
||||
workers_[address] = std::make_shared<Worker>(address);
|
||||
workers_[address] =
|
||||
std::make_shared<Worker>(address, register_worker.transfer_address());
|
||||
tasks_by_worker_[address] =
|
||||
absl::flat_hash_map<int64, std::shared_ptr<Task>>();
|
||||
}
|
||||
@ -146,7 +147,8 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
|
||||
DCHECK_EQ(task, nullptr);
|
||||
auto& job = jobs_[create_task.job_id()];
|
||||
DCHECK_NE(job, nullptr);
|
||||
task = std::make_shared<Task>(task_id, job, create_task.worker_address());
|
||||
task = std::make_shared<Task>(task_id, job, create_task.worker_address(),
|
||||
create_task.transfer_address());
|
||||
tasks_by_job_[create_task.job_id()].push_back(task);
|
||||
tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
|
||||
next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
|
||||
|
@ -69,9 +69,12 @@ class DispatcherState {
|
||||
|
||||
// A worker registered with the dispatcher.
|
||||
struct Worker {
|
||||
explicit Worker(const std::string& address) : address(address) {}
|
||||
explicit Worker(const std::string& address,
|
||||
const std::string& transfer_address)
|
||||
: address(address), transfer_address(transfer_address) {}
|
||||
|
||||
const std::string address;
|
||||
const std::string transfer_address;
|
||||
};
|
||||
|
||||
// A key for identifying a named job. The key contains a user-specified name,
|
||||
@ -128,12 +131,17 @@ class DispatcherState {
|
||||
|
||||
struct Task {
|
||||
explicit Task(int64 task_id, const std::shared_ptr<Job>& job,
|
||||
const std::string& worker_address)
|
||||
: task_id(task_id), job(job), worker_address(worker_address) {}
|
||||
const std::string& worker_address,
|
||||
const std::string& transfer_address)
|
||||
: task_id(task_id),
|
||||
job(job),
|
||||
worker_address(worker_address),
|
||||
transfer_address(transfer_address) {}
|
||||
|
||||
const int64 task_id;
|
||||
const std::shared_ptr<Job> job;
|
||||
const std::string worker_address;
|
||||
const std::string transfer_address;
|
||||
bool finished = false;
|
||||
};
|
||||
|
||||
|
@ -31,8 +31,9 @@ GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config,
|
||||
VLOG(1) << "Registered data service worker";
|
||||
}
|
||||
|
||||
Status GrpcWorkerImpl::Start(const std::string& worker_address) {
|
||||
return impl_.Start(worker_address);
|
||||
Status GrpcWorkerImpl::Start(const std::string& worker_address,
|
||||
const std::string& transfer_address) {
|
||||
return impl_.Start(worker_address, transfer_address);
|
||||
}
|
||||
|
||||
#define HANDLER(method) \
|
||||
|
@ -33,7 +33,16 @@ class GrpcWorkerImpl : public WorkerService::Service {
|
||||
::grpc::ServerBuilder& server_builder);
|
||||
~GrpcWorkerImpl() override {}
|
||||
|
||||
Status Start(const std::string& worker_address);
|
||||
Status Start(const std::string& worker_address,
|
||||
const std::string& transfer_address);
|
||||
|
||||
std::function<Status(const GetElementRequest*, GetElementResponse*)>
|
||||
get_element_getter() {
|
||||
return
|
||||
[this](const GetElementRequest* request, GetElementResponse* response) {
|
||||
return impl_.GetElement(request, response);
|
||||
};
|
||||
}
|
||||
|
||||
#define HANDLER(method) \
|
||||
::grpc::Status method(::grpc::ServerContext* context, \
|
||||
|
@ -27,6 +27,7 @@ message RegisterDatasetUpdate {
|
||||
|
||||
message RegisterWorkerUpdate {
|
||||
string worker_address = 1;
|
||||
string transfer_address = 2;
|
||||
}
|
||||
|
||||
message NamedJobKeyDef {
|
||||
@ -71,6 +72,7 @@ message CreateTaskUpdate {
|
||||
int64 task_id = 1;
|
||||
int64 job_id = 2;
|
||||
string worker_address = 4;
|
||||
string transfer_address = 6;
|
||||
}
|
||||
|
||||
message FinishTaskUpdate {
|
||||
|
@ -127,14 +127,27 @@ void WorkerGrpcDataServer::AddDataServiceToBuilder(
|
||||
}
|
||||
|
||||
Status WorkerGrpcDataServer::StartServiceInternal() {
|
||||
std::string worker_address = config_.worker_address();
|
||||
if (worker_address.empty()) {
|
||||
worker_address = absl::StrCat("localhost:", kPortPlaceholder);
|
||||
std::string base_address = config_.worker_address();
|
||||
if (base_address.empty()) {
|
||||
base_address = absl::StrCat("localhost:", kPortPlaceholder);
|
||||
}
|
||||
std::string resolved_address = str_util::StringReplace(
|
||||
worker_address, kPortPlaceholder, absl::StrCat(bound_port()),
|
||||
std::string worker_address = str_util::StringReplace(
|
||||
base_address, kPortPlaceholder, absl::StrCat(bound_port()),
|
||||
/*replace_all=*/false);
|
||||
TF_RETURN_IF_ERROR(service_->Start(resolved_address));
|
||||
std::string transfer_address = worker_address;
|
||||
std::string transfer_protocol = config_.data_transfer_protocol();
|
||||
if (!transfer_protocol.empty()) {
|
||||
TF_RETURN_IF_ERROR(DataTransferServer::Build(
|
||||
transfer_protocol, service_->get_element_getter(), &transfer_server_));
|
||||
TF_RETURN_IF_ERROR(transfer_server_->Start());
|
||||
LOG(INFO) << "Data transfer server started at 0.0.0.0:"
|
||||
<< transfer_server_->get_port();
|
||||
transfer_address =
|
||||
str_util::StringReplace(base_address, kPortPlaceholder,
|
||||
absl::StrCat(transfer_server_->get_port()),
|
||||
/*replace_all=*/false);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "grpcpp/server.h"
|
||||
#include "grpcpp/server_builder.h"
|
||||
#include "tensorflow/core/data/service/data_transfer.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
|
||||
#include "tensorflow/core/protobuf/service_config.pb.h"
|
||||
@ -109,6 +110,7 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
|
||||
const experimental::WorkerConfig config_;
|
||||
// Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
|
||||
GrpcWorkerImpl* service_;
|
||||
std::shared_ptr<DataTransferServer> transfer_server_;
|
||||
};
|
||||
|
||||
// Creates a dispatch tf.data server and stores it in `out_server`.
|
||||
|
@ -64,9 +64,11 @@ DataServiceWorkerImpl::~DataServiceWorkerImpl() {
|
||||
heartbeat_cv_.notify_one();
|
||||
}
|
||||
|
||||
Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
|
||||
Status DataServiceWorkerImpl::Start(const std::string& worker_address,
|
||||
const std::string& transfer_address) {
|
||||
VLOG(3) << "Starting tf.data service worker at address " << worker_address;
|
||||
worker_address_ = worker_address;
|
||||
transfer_address_ = transfer_address;
|
||||
|
||||
dispatcher_ = absl::make_unique<DataServiceDispatcherClient>(
|
||||
config_.dispatcher_address(), config_.protocol());
|
||||
@ -356,8 +358,9 @@ Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
|
||||
}
|
||||
std::vector<TaskDef> new_tasks;
|
||||
std::vector<int64> tasks_to_delete;
|
||||
TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat(
|
||||
worker_address_, current_tasks, new_tasks, tasks_to_delete));
|
||||
TF_RETURN_IF_ERROR(
|
||||
dispatcher_->WorkerHeartbeat(worker_address_, transfer_address_,
|
||||
current_tasks, new_tasks, tasks_to_delete));
|
||||
mutex_lock l(mu_);
|
||||
for (const auto& task : new_tasks) {
|
||||
Status s = ProcessTaskInternal(task);
|
||||
|
@ -41,7 +41,8 @@ class DataServiceWorkerImpl {
|
||||
// constructor because the worker may be binding to port `0`, in which case
|
||||
// the address isn't known until the worker has started and decided which port
|
||||
// to bind to.
|
||||
Status Start(const std::string& worker_address);
|
||||
Status Start(const std::string& worker_address,
|
||||
const std::string& transfer_address);
|
||||
|
||||
// See worker.proto for API documentation.
|
||||
|
||||
@ -81,6 +82,7 @@ class DataServiceWorkerImpl {
|
||||
const experimental::WorkerConfig config_;
|
||||
// The worker's own address.
|
||||
std::string worker_address_;
|
||||
std::string transfer_address_;
|
||||
std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
|
||||
|
||||
mutex mu_;
|
||||
|
@ -51,6 +51,8 @@ namespace data {
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
|
||||
/* static */ constexpr const char* const
|
||||
DataServiceDatasetOp::kDataTransferProtocol;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kConsumerIndex;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kNumConsumers;
|
||||
@ -80,8 +82,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, int op_version, int64 dataset_id,
|
||||
ProcessingMode processing_mode, const std::string& address,
|
||||
const std::string& protocol, const std::string& job_name,
|
||||
absl::optional<int64> consumer_index,
|
||||
const std::string& protocol,
|
||||
const std::string& data_transfer_protocol,
|
||||
const std::string& job_name, absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> num_consumers, int64 max_outstanding_requests,
|
||||
int64 task_refresh_interval_ms, IterationCounter* iteration_counter,
|
||||
bool owns_resource, ResourceHandle iteration_counter_handle,
|
||||
@ -93,6 +96,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
processing_mode_(processing_mode),
|
||||
address_(address),
|
||||
protocol_(protocol),
|
||||
data_transfer_protocol_(data_transfer_protocol),
|
||||
job_name_(job_name),
|
||||
consumer_index_(consumer_index),
|
||||
num_consumers_(num_consumers),
|
||||
@ -164,6 +168,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));
|
||||
inputs.push_back(protocol);
|
||||
|
||||
AttrValue data_transfer_protocol;
|
||||
b->BuildAttrValue(data_transfer_protocol_, &data_transfer_protocol);
|
||||
|
||||
Node* job_name;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));
|
||||
inputs.push_back(job_name);
|
||||
@ -195,11 +202,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
b->BuildAttrValue(task_refresh_interval_ms_,
|
||||
&task_refresh_interval_hint_ms);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, inputs,
|
||||
{std::make_pair(kTaskRefreshIntervalHintMs,
|
||||
task_refresh_interval_hint_ms)},
|
||||
output));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, inputs,
|
||||
{std::make_pair(kTaskRefreshIntervalHintMs,
|
||||
task_refresh_interval_hint_ms),
|
||||
std::make_pair(kDataTransferProtocol, data_transfer_protocol)},
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -459,8 +467,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
TaskInfo& task_info = it->second;
|
||||
std::unique_ptr<DataServiceWorkerClient> worker;
|
||||
Status s = CreateDataServiceWorkerClient(task_info.worker_address(),
|
||||
dataset()->protocol_, worker);
|
||||
Status s = CreateDataServiceWorkerClient(
|
||||
task_info.transfer_address(), dataset()->protocol_,
|
||||
dataset()->data_transfer_protocol_, worker);
|
||||
if (!s.ok()) {
|
||||
status_ = s;
|
||||
get_next_cv_.notify_all();
|
||||
@ -743,6 +752,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
const ProcessingMode processing_mode_;
|
||||
const tstring address_;
|
||||
const tstring protocol_;
|
||||
const tstring data_transfer_protocol_;
|
||||
const tstring job_name_;
|
||||
const absl::optional<int64> consumer_index_;
|
||||
const absl::optional<int64> num_consumers_;
|
||||
@ -765,6 +775,11 @@ DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx)
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
|
||||
if (ctx->HasAttr(kDataTransferProtocol)) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr(kDataTransferProtocol, &data_transfer_protocol_));
|
||||
}
|
||||
if (data_transfer_protocol_.empty()) data_transfer_protocol_ = "grpc";
|
||||
auto& op_name = ctx->def().op();
|
||||
if (op_name == kDataServiceDatasetV1) {
|
||||
op_version_ = 1;
|
||||
@ -859,11 +874,12 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ",
|
||||
model::kAutotune));
|
||||
|
||||
*output = new Dataset(
|
||||
ctx, op_version_, dataset_id, processing_mode, address, protocol,
|
||||
job_name, consumer_index, num_consumers, max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms_, iteration_counter, owns_resource,
|
||||
iteration_counter_handle, output_types_, output_shapes_);
|
||||
*output = new Dataset(ctx, op_version_, dataset_id, processing_mode, address,
|
||||
protocol, data_transfer_protocol_, job_name,
|
||||
consumer_index, num_consumers, max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms_, iteration_counter,
|
||||
owns_resource, iteration_counter_handle, output_types_,
|
||||
output_shapes_);
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),
|
||||
|
@ -51,6 +51,8 @@ class DataServiceDatasetOp : public DatasetOpKernel {
|
||||
static constexpr const char* const kProcessingMode = "processing_mode";
|
||||
static constexpr const char* const kAddress = "address";
|
||||
static constexpr const char* const kProtocol = "protocol";
|
||||
static constexpr const char* const kDataTransferProtocol =
|
||||
"data_transfer_protocol";
|
||||
static constexpr const char* const kJobName = "job_name";
|
||||
static constexpr const char* const kConsumerIndex = "consumer_index";
|
||||
static constexpr const char* const kNumConsumers = "num_consumers";
|
||||
@ -76,6 +78,7 @@ class DataServiceDatasetOp : public DatasetOpKernel {
|
||||
int64 task_refresh_interval_hint_ms_;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
std::string data_transfer_protocol_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
@ -1184,6 +1184,7 @@ REGISTER_OP("DataServiceDataset")
|
||||
.Attr("task_refresh_interval_hint_ms: int = -1")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("data_transfer_protocol: string = ''")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
@ -1203,6 +1204,7 @@ REGISTER_OP("DataServiceDatasetV2")
|
||||
.Attr("task_refresh_interval_hint_ms: int = -1")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("data_transfer_protocol: string = ''")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
|
@ -40,4 +40,6 @@ message WorkerConfig {
|
||||
// How long to retry requests to the dispatcher before giving up and reporting
|
||||
// an error.
|
||||
int64 dispatcher_timeout_ms = 6;
|
||||
// The protocol for the worker to use when transferring data to clients.
|
||||
string data_transfer_protocol = 7;
|
||||
}
|
||||
|
@ -59,6 +59,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
processing_mode,
|
||||
address,
|
||||
protocol,
|
||||
data_transfer_protocol,
|
||||
job_name=None,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
@ -76,6 +77,8 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
address: The tf.data service address, e.g. "localhost:5000".
|
||||
protocol: The protocol to use for communicating with the tf.data service,
|
||||
e.g. "grpc".
|
||||
data_transfer_protocol: The protocol to use for transferring data with the
|
||||
tf.data service, e.g. "grpc".
|
||||
job_name: (Optional.) The name of the job. This argument makes it possible
|
||||
for multiple datasets to share the same job. The default behavior is
|
||||
that the dataset creates anonymous, exclusively owned jobs.
|
||||
@ -138,6 +141,10 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
# represented by scalar DT_VARIANTs.
|
||||
self._element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
|
||||
|
||||
compat_kwargs = {}
|
||||
if data_transfer_protocol is not None:
|
||||
compat_kwargs["data_transfer_protocol"] = data_transfer_protocol
|
||||
|
||||
if num_consumers is None:
|
||||
variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
|
||||
dataset_id=self._dataset_id,
|
||||
@ -149,6 +156,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
|
||||
iteration_counter=gen_experimental_dataset_ops
|
||||
.dummy_iteration_counter(),
|
||||
**compat_kwargs,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2(
|
||||
@ -163,6 +171,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
|
||||
iteration_counter=gen_experimental_dataset_ops
|
||||
.dummy_iteration_counter(),
|
||||
**compat_kwargs,
|
||||
**self._flat_structure)
|
||||
super(_DataServiceDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@ -175,15 +184,16 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
|
||||
"""A `Dataset` that executes its input through the tf.data service."""
|
||||
|
||||
@functools.wraps(_DataServiceDatasetV2.__init__)
|
||||
def __init__(self, dataset_id, processing_mode, address, protocol, job_name,
|
||||
consumer_index, num_consumers, max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms):
|
||||
def __init__(self, dataset_id, processing_mode, address, protocol,
|
||||
data_transfer_protocol, job_name, consumer_index, num_consumers,
|
||||
max_outstanding_requests, task_refresh_interval_hint_ms):
|
||||
|
||||
self._wrapped = _DataServiceDatasetV2(
|
||||
dataset_id=dataset_id,
|
||||
processing_mode=processing_mode,
|
||||
address=address,
|
||||
protocol=protocol,
|
||||
data_transfer_protocol=data_transfer_protocol,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
@ -233,7 +243,8 @@ def _from_dataset_id(processing_mode,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None,
|
||||
task_refresh_interval_hint_ms=None):
|
||||
task_refresh_interval_hint_ms=None,
|
||||
data_transfer_protocol=None):
|
||||
"""Creates a dataset which reads data from the tf.data service.
|
||||
|
||||
This transformation is similar to `from_dataset_id`, but supports additional
|
||||
@ -274,6 +285,8 @@ def _from_dataset_id(processing_mode,
|
||||
`max_outstanding_requests` of memory.
|
||||
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
|
||||
dispatcher for task changes.
|
||||
data_transfer_protocol: (Optional.) The protocol to use for transferring
|
||||
data with the tf.data service.
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` which reads from the tf.data service.
|
||||
@ -294,6 +307,7 @@ def _from_dataset_id(processing_mode,
|
||||
processing_mode=processing_mode,
|
||||
address=address,
|
||||
protocol=protocol,
|
||||
data_transfer_protocol=data_transfer_protocol,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
@ -317,7 +331,8 @@ def _distribute(processing_mode,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None,
|
||||
task_refresh_interval_hint_ms=None):
|
||||
task_refresh_interval_hint_ms=None,
|
||||
data_transfer_protocol=None):
|
||||
"""A transformation that moves dataset processing to the tf.data service.
|
||||
|
||||
This transformation is similar to `distribute`, but supports additional
|
||||
@ -352,6 +367,8 @@ def _distribute(processing_mode,
|
||||
`max_outstanding_requests` of memory.
|
||||
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
|
||||
dispatcher for task changes.
|
||||
data_transfer_protocol: (Optional.) The protocol to use for transferring
|
||||
data with the tf.data service.
|
||||
|
||||
Returns:
|
||||
Dataset: A `Dataset` of the elements produced by the data service.
|
||||
@ -369,7 +386,8 @@ def _distribute(processing_mode,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
max_outstanding_requests=max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
|
||||
data_transfer_protocol=data_transfer_protocol)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
@ -380,7 +398,8 @@ def distribute(processing_mode,
|
||||
job_name=None,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None):
|
||||
max_outstanding_requests=None,
|
||||
data_transfer_protocol=None):
|
||||
"""A transformation that moves dataset processing to the tf.data service.
|
||||
|
||||
When you iterate over a dataset containing the `distribute` transformation,
|
||||
@ -580,6 +599,8 @@ def distribute(processing_mode,
|
||||
requested at the same time. You can use this option to control the amount
|
||||
of memory used, since `distribute` won't use more than `element_size` *
|
||||
`max_outstanding_requests` of memory.
|
||||
data_transfer_protocol: (Optional.) The protocol to use for transferring
|
||||
data with the tf.data service, e.g. "grpc".
|
||||
|
||||
Returns:
|
||||
Dataset: A `Dataset` of the elements produced by the data service.
|
||||
@ -590,7 +611,8 @@ def distribute(processing_mode,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
max_outstanding_requests=max_outstanding_requests)
|
||||
max_outstanding_requests=max_outstanding_requests,
|
||||
data_transfer_protocol=data_transfer_protocol)
|
||||
|
||||
|
||||
@tf_export("data.experimental.service.register_dataset")
|
||||
|
@ -307,7 +307,8 @@ class WorkerServer(object):
|
||||
port=config.port,
|
||||
protocol=config.protocol,
|
||||
heartbeat_interval_ms=config.heartbeat_interval_ms,
|
||||
dispatcher_timeout_ms=config.dispatcher_timeout_ms)
|
||||
dispatcher_timeout_ms=config.dispatcher_timeout_ms,
|
||||
data_transfer_protocol=None)
|
||||
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
|
||||
config_proto.SerializeToString())
|
||||
if start:
|
||||
|
@ -10,7 +10,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "distribute"
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'data_transfer_protocol\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_dataset_id"
|
||||
|
@ -1006,11 +1006,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "DataServiceDataset"
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DataServiceDatasetV2"
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DatasetCardinality"
|
||||
|
@ -18,7 +18,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "distribute"
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'data_transfer_protocol\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_dataset_id"
|
||||
|
@ -1006,11 +1006,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "DataServiceDataset"
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DataServiceDatasetV2"
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DatasetCardinality"
|
||||
|
Loading…
Reference in New Issue
Block a user