From 3e3364c6f16486001e34b7b422890f3182803564 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 20 Jan 2021 13:54:40 -0800 Subject: [PATCH] [tf.data service] Add a DataTransferClient/Server abstraction to support transfer methods other than gRPC PiperOrigin-RevId: 352868792 Change-Id: I47cfd8473a917ddbd60490fe7cd6a9ba7e1ba34f --- RELEASE.md | 2 + tensorflow/core/data/service/BUILD | 38 +++++ tensorflow/core/data/service/common.proto | 2 + tensorflow/core/data/service/data_service.cc | 150 ++++++++++++------ tensorflow/core/data/service/data_service.h | 18 +-- tensorflow/core/data/service/data_transfer.cc | 110 +++++++++++++ tensorflow/core/data/service/data_transfer.h | 87 ++++++++++ .../core/data/service/data_transfer_test.cc | 54 +++++++ tensorflow/core/data/service/dispatcher.proto | 1 + .../core/data/service/dispatcher_impl.cc | 6 + .../core/data/service/dispatcher_state.cc | 6 +- .../core/data/service/dispatcher_state.h | 14 +- .../core/data/service/grpc_worker_impl.cc | 5 +- .../core/data/service/grpc_worker_impl.h | 11 +- tensorflow/core/data/service/journal.proto | 2 + tensorflow/core/data/service/server_lib.cc | 25 ++- tensorflow/core/data/service/server_lib.h | 2 + tensorflow/core/data/service/worker_impl.cc | 9 +- tensorflow/core/data/service/worker_impl.h | 4 +- .../experimental/data_service_dataset_op.cc | 44 +++-- .../experimental/data_service_dataset_op.h | 3 + .../core/ops/experimental_dataset_ops.cc | 2 + tensorflow/core/protobuf/service_config.proto | 2 + .../data/experimental/ops/data_service_ops.py | 38 ++++- .../data/experimental/service/server_lib.py | 3 +- ...tensorflow.data.experimental.service.pbtxt | 2 +- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 +- ...tensorflow.data.experimental.service.pbtxt | 2 +- .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 +- 29 files changed, 543 insertions(+), 107 deletions(-) create mode 100644 tensorflow/core/data/service/data_transfer.cc create mode 100644 tensorflow/core/data/service/data_transfer.h create mode 100644 tensorflow/core/data/service/data_transfer_test.cc diff --git a/RELEASE.md b/RELEASE.md index 437cab7215c..a285c0e1942 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -30,6 +30,8 @@ for synchronous training workloads where example sizes vary. With strict round robin reads, users can guarantee that consumers get similar-sized examples in the same step. + * tf.data service supports custom data transfer protocols (other than + gRPC). ## Bug Fixes and Other Changes diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 02282894033..44ac9e57b60 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -20,6 +20,8 @@ load( "tf_cc_test", ) +package_group(name = "data_transfer_visibility") + package( default_visibility = [ "//tensorflow:internal", @@ -53,6 +55,9 @@ tf_proto_library( ":common_proto", "//tensorflow/core/data:dataset_proto", ], + visibility = [ + ":data_transfer_visibility", + ], ) cc_library( @@ -86,6 +91,7 @@ cc_library( ], deps = [ ":credentials_factory", + ":data_transfer", ":dispatcher_cc_grpc_proto", ":grpc_util", ":worker_cc_grpc_proto", @@ -124,6 +130,37 @@ tf_cc_test( ] + tf_protos_profiler_service(), ) +cc_library( + name = "data_transfer", + srcs = ["data_transfer.cc"], + hdrs = ["data_transfer.h"], + visibility = [ + ":data_transfer_visibility", + ], + deps = [ + ":worker_proto_cc", + "//tensorflow/core/data:dataset_proto_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "data_transfer_test", + srcs = ["data_transfer_test.cc"], + deps = [ + ":data_transfer", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "dataset_store", srcs = ["dataset_store.cc"], @@ -335,6 +372,7 @@ cc_library( ], deps = [ ":credentials_factory", + ":data_transfer", ":grpc_dispatcher_impl", ":grpc_util", ":grpc_worker_impl", diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto index c48b6e3263b..8c81c6ff3e7 100644 --- a/tensorflow/core/data/service/common.proto +++ b/tensorflow/core/data/service/common.proto @@ -30,6 +30,8 @@ message TaskDef { message TaskInfo { // The address of the worker processing the task. string worker_address = 1; + // The transfer address of the worker processing the task. + string transfer_address = 4; // The task id. int64 task_id = 2; // The id of the job that the task is part of. diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc index eb1f3359815..a6cee593b47 100644 --- a/tensorflow/core/data/service/data_service.cc +++ b/tensorflow/core/data/service/data_service.cc @@ -31,6 +31,7 @@ namespace data { namespace { constexpr const char kParallelEpochs[] = "parallel_epochs"; constexpr const char kDistributedEpoch[] = "distributed_epoch"; + } // namespace Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) { @@ -57,11 +58,13 @@ std::string ProcessingModeToString(ProcessingMode mode) { } Status DataServiceDispatcherClient::WorkerHeartbeat( - const std::string& worker_address, const std::vector& current_tasks, - std::vector& new_tasks, std::vector& tasks_to_delete) { + const std::string& worker_address, const std::string& transfer_address, + const std::vector& current_tasks, std::vector& new_tasks, + std::vector& tasks_to_delete) { TF_RETURN_IF_ERROR(EnsureInitialized()); WorkerHeartbeatRequest req; req.set_worker_address(worker_address); + req.set_transfer_address(transfer_address); for (int64 task : current_tasks) { req.add_current_tasks(task); } @@ -244,69 +247,112 @@ Status DataServiceDispatcherClient::EnsureInitialized() { return Status::OK(); } +class GrpcDataTransferClient : public DataTransferClient { + public: + GrpcDataTransferClient(std::shared_ptr credentials, + std::string address) { + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(-1); + auto channel = grpc::CreateCustomChannel(address, credentials, args); + stub_ = WorkerService::NewStub(channel); + } + + Status GetElement(int64 task_id, absl::optional consumer_index, + absl::optional round_index, + CompressedElement& element, + bool& end_of_sequence) override { + { + mutex_lock l(mu_); + if (cancelled_) { + return errors::Cancelled("Client was cancelled."); + } + } + GetElementRequest req; + req.set_task_id(task_id); + if (consumer_index.has_value()) { + req.set_consumer_index(consumer_index.value()); + } + if (round_index.has_value()) { + req.set_round_index(round_index.value()); + } + GetElementResponse resp; + grpc::ClientContext ctx; + { + mutex_lock l(mu_); + active_contexts_.insert(&ctx); + } + grpc::Status s = stub_->GetElement(&ctx, req, &resp); + { + mutex_lock l(mu_); + active_contexts_.erase(&ctx); + } + if (!s.ok()) { + return grpc_util::WrapError("Failed to get element", s); + } + end_of_sequence = resp.end_of_sequence(); + if (!end_of_sequence) { + element = std::move(*resp.mutable_compressed_element()); + } + return Status::OK(); + } + + void TryCancel() override { + mutex_lock l(mu_); + cancelled_ = true; + for (const auto& ctx : active_contexts_) { + ctx->TryCancel(); + } + } + + private: + mutex mu_; + std::unique_ptr stub_; + // Set of all currently active clients contexts. Used to support + // cancellation. + absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_); + // Indicates that the client has been cancelled, so no further requests should + // be accepted. + bool cancelled_ GUARDED_BY(mu_) = false; +}; + +class GrpcTransferClientRegistrar { + public: + GrpcTransferClientRegistrar() { + DataTransferClient::Register( + "grpc", [](DataTransferClient::Config config, + std::unique_ptr* out) { + std::shared_ptr credentials; + TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials( + config.protocol, &credentials)); + *out = std::make_unique(credentials, + config.address); + return Status::OK(); + }); + } +}; +static GrpcTransferClientRegistrar registrar; + Status DataServiceWorkerClient::GetElement(int64 task_id, absl::optional consumer_index, absl::optional round_index, CompressedElement& element, bool& end_of_sequence) { TF_RETURN_IF_ERROR(EnsureInitialized()); - { - mutex_lock l(mu_); - if (cancelled_) { - return errors::Cancelled("Client was cancelled."); - } - } - GetElementRequest req; - req.set_task_id(task_id); - if (consumer_index.has_value()) { - req.set_consumer_index(consumer_index.value()); - } - if (round_index.has_value()) { - req.set_round_index(round_index.value()); - } - GetElementResponse resp; - grpc::ClientContext ctx; - { - mutex_lock l(mu_); - active_contexts_.insert(&ctx); - } - grpc::Status s = stub_->GetElement(&ctx, req, &resp); - { - mutex_lock l(mu_); - active_contexts_.erase(&ctx); - } - if (!s.ok()) { - return grpc_util::WrapError("Failed to get element", s); - } - end_of_sequence = resp.end_of_sequence(); - if (!end_of_sequence) { - element = std::move(*resp.mutable_compressed_element()); - } - return Status::OK(); + return client_->GetElement(task_id, consumer_index, round_index, element, + end_of_sequence); } Status DataServiceWorkerClient::EnsureInitialized() { mutex_lock l(mu_); - if (stub_) { + if (client_) { return Status::OK(); } - std::shared_ptr credentials; - TF_RETURN_IF_ERROR( - CredentialsFactory::CreateClientCredentials(protocol_, &credentials)); - grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(-1); - auto channel = grpc::CreateCustomChannel(address_, credentials, args); - stub_ = WorkerService::NewStub(channel); + TF_RETURN_IF_ERROR(DataTransferClient::Build( + transfer_protocol_, {protocol_, address_}, &client_)); return Status::OK(); } -void DataServiceWorkerClient::TryCancel() { - mutex_lock l(mu_); - cancelled_ = true; - for (const auto& ctx : active_contexts_) { - ctx->TryCancel(); - } -} +void DataServiceWorkerClient::TryCancel() { client_->TryCancel(); } Status CreateDataServiceDispatcherClient( const std::string& address, const std::string& protocol, @@ -320,8 +366,10 @@ Status CreateDataServiceDispatcherClient( Status CreateDataServiceWorkerClient( const std::string& address, const std::string& protocol, + const std::string& transfer_protocol, std::unique_ptr& out) { - auto client = absl::make_unique(address, protocol); + auto client = absl::make_unique(address, protocol, + transfer_protocol); TF_RETURN_IF_ERROR(client->Initialize()); out = std::move(client); return Status::OK(); diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h index 1ef926c28be..77b0f45906a 100644 --- a/tensorflow/core/data/service/data_service.h +++ b/tensorflow/core/data/service/data_service.h @@ -18,6 +18,7 @@ limitations under the License. #include "grpcpp/impl/codegen/client_context.h" #include "absl/container/flat_hash_set.h" +#include "tensorflow/core/data/service/data_transfer.h" #include "tensorflow/core/data/service/dispatcher.grpc.pb.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/framework/dataset.h" @@ -82,6 +83,7 @@ class DataServiceDispatcherClient : public DataServiceClientBase { // tasks it should delete. This is stored into `new_tasks` and // `tasks_to_delete`. Status WorkerHeartbeat(const std::string& worker_address, + const std::string& transfer_address, const std::vector& current_tasks, std::vector& new_tasks, std::vector& tasks_to_delete); @@ -138,8 +140,10 @@ class DataServiceDispatcherClient : public DataServiceClientBase { class DataServiceWorkerClient : public DataServiceClientBase { public: DataServiceWorkerClient(const std::string& address, - const std::string& protocol) - : DataServiceClientBase(address, protocol) {} + const std::string& protocol, + const std::string& transfer_protocol) + : DataServiceClientBase(address, protocol), + transfer_protocol_(transfer_protocol) {} // Fetches the next element for the specified task_id. The optional // `consumer_index` and `round_index` must be specified for tasks which use @@ -158,16 +162,11 @@ class DataServiceWorkerClient : public DataServiceClientBase { Status EnsureInitialized() override; private: + const std::string transfer_protocol_; mutex mu_; // Initialization is guarded by `mu_`, but using the stub does not require // holding `mu_` - std::unique_ptr stub_; - // Set of all currently active clients contexts. Used to support - // cancellation. - absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_); - // Indicates that the client has been cancelled, so no further requests should - // be accepted. - bool cancelled_ GUARDED_BY(mu_) = false; + std::unique_ptr client_; }; // Creates and initializes a new tf.data service dispatcher client. @@ -178,6 +177,7 @@ Status CreateDataServiceDispatcherClient( // Creates and initializes a new tf.data service worker client. Status CreateDataServiceWorkerClient( const std::string& address, const std::string& protocol, + const std::string& transfer_protocol, std::unique_ptr& out); } // namespace data diff --git a/tensorflow/core/data/service/data_transfer.cc b/tensorflow/core/data/service/data_transfer.cc new file mode 100644 index 00000000000..9685bfc3154 --- /dev/null +++ b/tensorflow/core/data/service/data_transfer.cc @@ -0,0 +1,110 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/data_transfer.h" + +#include + +#include "absl/strings/str_join.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace data { + +namespace { +mutex* get_lock() { + static mutex lock(LINKER_INITIALIZED); + return &lock; +} + +using DataTransferServerFactories = + std::unordered_map( + DataTransferServer::GetElementT)>>; +DataTransferServerFactories& transfer_server_factories() { + static auto& factories = *new DataTransferServerFactories(); + return factories; +} + +using DataTransferClientFactories = + std::unordered_map; +DataTransferClientFactories& transfer_client_factories() { + static auto& factories = *new DataTransferClientFactories(); + return factories; +} +} // namespace + +void DataTransferServer::Register( + std::string name, + std::function(GetElementT)> factory) { + mutex_lock l(*get_lock()); + if (!transfer_server_factories().insert({name, factory}).second) { + LOG(ERROR) + << "Two data transfer server factories are being registered with name " + << name << ". Which one gets used is undefined."; + } +} + +Status DataTransferServer::Build(std::string name, GetElementT get_element, + std::shared_ptr* out) { + mutex_lock l(*get_lock()); + auto it = transfer_server_factories().find(name); + if (it != transfer_server_factories().end()) { + *out = it->second(get_element); + return Status::OK(); + } + + std::vector available_names; + for (const auto& factory : transfer_server_factories()) { + available_names.push_back(factory.first); + } + + return errors::NotFound( + "No data transfer server factory has been registered for name ", name, + ". The available names are: [ ", absl::StrJoin(available_names, ", "), + " ]"); +} + +void DataTransferClient::Register(std::string name, FactoryT factory) { + mutex_lock l(*get_lock()); + if (!transfer_client_factories().insert({name, factory}).second) { + LOG(ERROR) + << "Two data transfer client factories are being registered with name " + << name << ". Which one gets used is undefined."; + } +} + +Status DataTransferClient::Build(std::string name, Config config, + std::unique_ptr* out) { + mutex_lock l(*get_lock()); + auto it = transfer_client_factories().find(name); + if (it != transfer_client_factories().end()) { + return it->second(config, out); + } + + std::vector available_names; + for (const auto& factory : transfer_client_factories()) { + available_names.push_back(factory.first); + } + + return errors::NotFound( + "No data transfer client factory has been registered for name ", name, + ". The available names are: [ ", absl::StrJoin(available_names, ", "), + " ]"); +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/data_transfer.h b/tensorflow/core/data/service/data_transfer.h new file mode 100644 index 00000000000..7fc03a5df88 --- /dev/null +++ b/tensorflow/core/data/service/data_transfer.h @@ -0,0 +1,87 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow/core/data/dataset.pb.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { + +// Client for communicating with the tf.data service transfer server. +class DataTransferClient { + public: + struct Config { + absl::string_view protocol; + std::string address; + }; + using FactoryT = + std::function*)>; + virtual ~DataTransferClient() = default; + + // Fetches the next element for the specified task_id. The element's + // compressed tensors will be stored in `element`. If no element is available, + // `end_of_sequence` will be `true`, and `element` will be left unchanged. + virtual Status GetElement(int64 task_id, absl::optional consumer_index, + absl::optional round_index, + tensorflow::data::CompressedElement& element, + bool& end_of_sequence) = 0; + + // Makes a best effort to cancel all outstanding calls in progress for the + // client, and causes further calls to return Cancelled status. + virtual void TryCancel() = 0; + + // Registers a DataTransferClient factory under `name`. + static void Register(std::string name, FactoryT factory); + + // Builds a DataTransferClient from the factory registered under `name`. + static Status Build(std::string name, Config config, + std::unique_ptr* out); +}; + +// Server for communicating with the tf.data service transfer client. +class DataTransferServer { + public: + using GetElementT = + std::function; + virtual ~DataTransferServer() = default; + + // Starts DataTransferServer, it should be available for requests afterwards. + virtual Status Start() = 0; + + // Return the port that this server is listening on. + virtual int get_port() = 0; + + // Register a DataTransferServer factory under `name`. + static void Register( + std::string name, + std::function(GetElementT)> factory); + + // Builds a DataTransferServer from the factory registered with `name`. + static Status Build(std::string name, GetElementT get_element, + std::shared_ptr* out); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ diff --git a/tensorflow/core/data/service/data_transfer_test.cc b/tensorflow/core/data/service/data_transfer_test.cc new file mode 100644 index 00000000000..9054451e885 --- /dev/null +++ b/tensorflow/core/data/service/data_transfer_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/data/service/data_transfer.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace { + +class TestDataTransferServer : public DataTransferServer { + public: + explicit TestDataTransferServer(bool* called) : called_(called) {} + Status Start() override { + *called_ = true; + return Status::OK(); + } + int get_port() override { return 0; } + + private: + bool* called_; +}; + +TEST(DataTransferTest, RegisterDataTransferServerBuilder) { + bool called = false; + DataTransferServer::Register("test", [&called](auto _) { + return std::make_shared(&called); + }); + + std::shared_ptr server; + TF_ASSERT_OK(DataTransferServer::Build("test", {}, &server)); + EXPECT_FALSE(called); + + TF_ASSERT_OK(server->Start()); + EXPECT_TRUE(called); +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/dispatcher.proto b/tensorflow/core/data/service/dispatcher.proto index 609605b31fa..5a7bc490d7f 100644 --- a/tensorflow/core/data/service/dispatcher.proto +++ b/tensorflow/core/data/service/dispatcher.proto @@ -14,6 +14,7 @@ message TaskProgress { message WorkerHeartbeatRequest { string worker_address = 1; + string transfer_address = 3; repeated int64 current_tasks = 2; } diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index 6cd68fe5c72..23bd16a7f8c 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -217,6 +217,8 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat( } Update update; update.mutable_register_worker()->set_worker_address(worker_address); + update.mutable_register_worker()->set_transfer_address( + request->transfer_address()); TF_RETURN_IF_ERROR(Apply(update)); TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address)); TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks)); @@ -584,6 +586,9 @@ Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr job, create_task->set_task_id(task_id); create_task->set_job_id(job->job_id); create_task->set_worker_address(worker_address); + std::shared_ptr worker; + TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker)); + create_task->set_transfer_address(worker->transfer_address); TF_RETURN_IF_ERROR(Apply(update)); TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task)); return Status::OK(); @@ -686,6 +691,7 @@ Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request, for (const auto& task : tasks) { TaskInfo* task_info = response->mutable_task_info()->Add(); task_info->set_worker_address(task->worker_address); + task_info->set_transfer_address(task->transfer_address); task_info->set_task_id(task->task_id); task_info->set_job_id(job->job_id); } diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc index ff3270e29b2..8940b24edcc 100644 --- a/tensorflow/core/data/service/dispatcher_state.cc +++ b/tensorflow/core/data/service/dispatcher_state.cc @@ -74,7 +74,8 @@ void DispatcherState::RegisterWorker( const RegisterWorkerUpdate& register_worker) { std::string address = register_worker.worker_address(); DCHECK(!workers_.contains(address)); - workers_[address] = std::make_shared(address); + workers_[address] = + std::make_shared(address, register_worker.transfer_address()); tasks_by_worker_[address] = absl::flat_hash_map>(); } @@ -146,7 +147,8 @@ void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) { DCHECK_EQ(task, nullptr); auto& job = jobs_[create_task.job_id()]; DCHECK_NE(job, nullptr); - task = std::make_shared(task_id, job, create_task.worker_address()); + task = std::make_shared(task_id, job, create_task.worker_address(), + create_task.transfer_address()); tasks_by_job_[create_task.job_id()].push_back(task); tasks_by_worker_[create_task.worker_address()][task->task_id] = task; next_available_task_id_ = std::max(next_available_task_id_, task_id + 1); diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h index 488c95526ca..f341fcf013b 100644 --- a/tensorflow/core/data/service/dispatcher_state.h +++ b/tensorflow/core/data/service/dispatcher_state.h @@ -69,9 +69,12 @@ class DispatcherState { // A worker registered with the dispatcher. struct Worker { - explicit Worker(const std::string& address) : address(address) {} + explicit Worker(const std::string& address, + const std::string& transfer_address) + : address(address), transfer_address(transfer_address) {} const std::string address; + const std::string transfer_address; }; // A key for identifying a named job. The key contains a user-specified name, @@ -128,12 +131,17 @@ class DispatcherState { struct Task { explicit Task(int64 task_id, const std::shared_ptr& job, - const std::string& worker_address) - : task_id(task_id), job(job), worker_address(worker_address) {} + const std::string& worker_address, + const std::string& transfer_address) + : task_id(task_id), + job(job), + worker_address(worker_address), + transfer_address(transfer_address) {} const int64 task_id; const std::shared_ptr job; const std::string worker_address; + const std::string transfer_address; bool finished = false; }; diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc index 3c3a81d0daf..a1a0a6f0a13 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.cc +++ b/tensorflow/core/data/service/grpc_worker_impl.cc @@ -31,8 +31,9 @@ GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config, VLOG(1) << "Registered data service worker"; } -Status GrpcWorkerImpl::Start(const std::string& worker_address) { - return impl_.Start(worker_address); +Status GrpcWorkerImpl::Start(const std::string& worker_address, + const std::string& transfer_address) { + return impl_.Start(worker_address, transfer_address); } #define HANDLER(method) \ diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h index c094d12d59d..23e1c40bb3d 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.h +++ b/tensorflow/core/data/service/grpc_worker_impl.h @@ -33,7 +33,16 @@ class GrpcWorkerImpl : public WorkerService::Service { ::grpc::ServerBuilder& server_builder); ~GrpcWorkerImpl() override {} - Status Start(const std::string& worker_address); + Status Start(const std::string& worker_address, + const std::string& transfer_address); + + std::function + get_element_getter() { + return + [this](const GetElementRequest* request, GetElementResponse* response) { + return impl_.GetElement(request, response); + }; + } #define HANDLER(method) \ ::grpc::Status method(::grpc::ServerContext* context, \ diff --git a/tensorflow/core/data/service/journal.proto b/tensorflow/core/data/service/journal.proto index 0fc07423822..0b94acd14e8 100644 --- a/tensorflow/core/data/service/journal.proto +++ b/tensorflow/core/data/service/journal.proto @@ -27,6 +27,7 @@ message RegisterDatasetUpdate { message RegisterWorkerUpdate { string worker_address = 1; + string transfer_address = 2; } message NamedJobKeyDef { @@ -71,6 +72,7 @@ message CreateTaskUpdate { int64 task_id = 1; int64 job_id = 2; string worker_address = 4; + string transfer_address = 6; } message FinishTaskUpdate { diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc index af940fe54a3..b942560e737 100644 --- a/tensorflow/core/data/service/server_lib.cc +++ b/tensorflow/core/data/service/server_lib.cc @@ -127,14 +127,27 @@ void WorkerGrpcDataServer::AddDataServiceToBuilder( } Status WorkerGrpcDataServer::StartServiceInternal() { - std::string worker_address = config_.worker_address(); - if (worker_address.empty()) { - worker_address = absl::StrCat("localhost:", kPortPlaceholder); + std::string base_address = config_.worker_address(); + if (base_address.empty()) { + base_address = absl::StrCat("localhost:", kPortPlaceholder); } - std::string resolved_address = str_util::StringReplace( - worker_address, kPortPlaceholder, absl::StrCat(bound_port()), + std::string worker_address = str_util::StringReplace( + base_address, kPortPlaceholder, absl::StrCat(bound_port()), /*replace_all=*/false); - TF_RETURN_IF_ERROR(service_->Start(resolved_address)); + std::string transfer_address = worker_address; + std::string transfer_protocol = config_.data_transfer_protocol(); + if (!transfer_protocol.empty()) { + TF_RETURN_IF_ERROR(DataTransferServer::Build( + transfer_protocol, service_->get_element_getter(), &transfer_server_)); + TF_RETURN_IF_ERROR(transfer_server_->Start()); + LOG(INFO) << "Data transfer server started at 0.0.0.0:" + << transfer_server_->get_port(); + transfer_address = + str_util::StringReplace(base_address, kPortPlaceholder, + absl::StrCat(transfer_server_->get_port()), + /*replace_all=*/false); + } + TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address)); return Status::OK(); } diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h index ed92097a45c..a173cbb6e49 100644 --- a/tensorflow/core/data/service/server_lib.h +++ b/tensorflow/core/data/service/server_lib.h @@ -18,6 +18,7 @@ limitations under the License. #include "grpcpp/server.h" #include "grpcpp/server_builder.h" +#include "tensorflow/core/data/service/data_transfer.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/profiler/rpc/profiler_service_impl.h" #include "tensorflow/core/protobuf/service_config.pb.h" @@ -109,6 +110,7 @@ class WorkerGrpcDataServer : public GrpcDataServerBase { const experimental::WorkerConfig config_; // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared. GrpcWorkerImpl* service_; + std::shared_ptr transfer_server_; }; // Creates a dispatch tf.data server and stores it in `out_server`. diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 5f1aa2930b2..73b24be3f14 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -64,9 +64,11 @@ DataServiceWorkerImpl::~DataServiceWorkerImpl() { heartbeat_cv_.notify_one(); } -Status DataServiceWorkerImpl::Start(const std::string& worker_address) { +Status DataServiceWorkerImpl::Start(const std::string& worker_address, + const std::string& transfer_address) { VLOG(3) << "Starting tf.data service worker at address " << worker_address; worker_address_ = worker_address; + transfer_address_ = transfer_address; dispatcher_ = absl::make_unique( config_.dispatcher_address(), config_.protocol()); @@ -356,8 +358,9 @@ Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) { } std::vector new_tasks; std::vector tasks_to_delete; - TF_RETURN_IF_ERROR(dispatcher_->WorkerHeartbeat( - worker_address_, current_tasks, new_tasks, tasks_to_delete)); + TF_RETURN_IF_ERROR( + dispatcher_->WorkerHeartbeat(worker_address_, transfer_address_, + current_tasks, new_tasks, tasks_to_delete)); mutex_lock l(mu_); for (const auto& task : new_tasks) { Status s = ProcessTaskInternal(task); diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h index 80eb5b756a4..c468f0ee1cd 100644 --- a/tensorflow/core/data/service/worker_impl.h +++ b/tensorflow/core/data/service/worker_impl.h @@ -41,7 +41,8 @@ class DataServiceWorkerImpl { // constructor because the worker may be binding to port `0`, in which case // the address isn't known until the worker has started and decided which port // to bind to. - Status Start(const std::string& worker_address); + Status Start(const std::string& worker_address, + const std::string& transfer_address); // See worker.proto for API documentation. @@ -81,6 +82,7 @@ class DataServiceWorkerImpl { const experimental::WorkerConfig config_; // The worker's own address. std::string worker_address_; + std::string transfer_address_; std::unique_ptr dispatcher_; mutex mu_; diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index 4574ffa2219..1520071915f 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -51,6 +51,8 @@ namespace data { /* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode; /* static */ constexpr const char* const DataServiceDatasetOp::kAddress; /* static */ constexpr const char* const DataServiceDatasetOp::kProtocol; +/* static */ constexpr const char* const + DataServiceDatasetOp::kDataTransferProtocol; /* static */ constexpr const char* const DataServiceDatasetOp::kJobName; /* static */ constexpr const char* const DataServiceDatasetOp::kConsumerIndex; /* static */ constexpr const char* const DataServiceDatasetOp::kNumConsumers; @@ -80,8 +82,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, int op_version, int64 dataset_id, ProcessingMode processing_mode, const std::string& address, - const std::string& protocol, const std::string& job_name, - absl::optional consumer_index, + const std::string& protocol, + const std::string& data_transfer_protocol, + const std::string& job_name, absl::optional consumer_index, absl::optional num_consumers, int64 max_outstanding_requests, int64 task_refresh_interval_ms, IterationCounter* iteration_counter, bool owns_resource, ResourceHandle iteration_counter_handle, @@ -93,6 +96,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { processing_mode_(processing_mode), address_(address), protocol_(protocol), + data_transfer_protocol_(data_transfer_protocol), job_name_(job_name), consumer_index_(consumer_index), num_consumers_(num_consumers), @@ -164,6 +168,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol)); inputs.push_back(protocol); + AttrValue data_transfer_protocol; + b->BuildAttrValue(data_transfer_protocol_, &data_transfer_protocol); + Node* job_name; TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name)); inputs.push_back(job_name); @@ -195,11 +202,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { b->BuildAttrValue(task_refresh_interval_ms_, &task_refresh_interval_hint_ms); - TF_RETURN_IF_ERROR( - b->AddDataset(this, inputs, - {std::make_pair(kTaskRefreshIntervalHintMs, - task_refresh_interval_hint_ms)}, - output)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, inputs, + {std::make_pair(kTaskRefreshIntervalHintMs, + task_refresh_interval_hint_ms), + std::make_pair(kDataTransferProtocol, data_transfer_protocol)}, + output)); return Status::OK(); } @@ -459,8 +467,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { } TaskInfo& task_info = it->second; std::unique_ptr worker; - Status s = CreateDataServiceWorkerClient(task_info.worker_address(), - dataset()->protocol_, worker); + Status s = CreateDataServiceWorkerClient( + task_info.transfer_address(), dataset()->protocol_, + dataset()->data_transfer_protocol_, worker); if (!s.ok()) { status_ = s; get_next_cv_.notify_all(); @@ -743,6 +752,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { const ProcessingMode processing_mode_; const tstring address_; const tstring protocol_; + const tstring data_transfer_protocol_; const tstring job_name_; const absl::optional consumer_index_; const absl::optional num_consumers_; @@ -765,6 +775,11 @@ DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx) } OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); + if (ctx->HasAttr(kDataTransferProtocol)) { + OP_REQUIRES_OK( + ctx, ctx->GetAttr(kDataTransferProtocol, &data_transfer_protocol_)); + } + if (data_transfer_protocol_.empty()) data_transfer_protocol_ = "grpc"; auto& op_name = ctx->def().op(); if (op_name == kDataServiceDatasetV1) { op_version_ = 1; @@ -859,11 +874,12 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ", model::kAutotune)); - *output = new Dataset( - ctx, op_version_, dataset_id, processing_mode, address, protocol, - job_name, consumer_index, num_consumers, max_outstanding_requests, - task_refresh_interval_hint_ms_, iteration_counter, owns_resource, - iteration_counter_handle, output_types_, output_shapes_); + *output = new Dataset(ctx, op_version_, dataset_id, processing_mode, address, + protocol, data_transfer_protocol_, job_name, + consumer_index, num_consumers, max_outstanding_requests, + task_refresh_interval_hint_ms_, iteration_counter, + owns_resource, iteration_counter_handle, output_types_, + output_shapes_); } REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h index efb013bccdc..ada54eae2b6 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h @@ -51,6 +51,8 @@ class DataServiceDatasetOp : public DatasetOpKernel { static constexpr const char* const kProcessingMode = "processing_mode"; static constexpr const char* const kAddress = "address"; static constexpr const char* const kProtocol = "protocol"; + static constexpr const char* const kDataTransferProtocol = + "data_transfer_protocol"; static constexpr const char* const kJobName = "job_name"; static constexpr const char* const kConsumerIndex = "consumer_index"; static constexpr const char* const kNumConsumers = "num_consumers"; @@ -76,6 +78,7 @@ class DataServiceDatasetOp : public DatasetOpKernel { int64 task_refresh_interval_hint_ms_; DataTypeVector output_types_; std::vector output_shapes_; + std::string data_transfer_protocol_; }; } // namespace data diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 8a41f809b69..84ab94688f4 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -1184,6 +1184,7 @@ REGISTER_OP("DataServiceDataset") .Attr("task_refresh_interval_hint_ms: int = -1") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("data_transfer_protocol: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); @@ -1203,6 +1204,7 @@ REGISTER_OP("DataServiceDatasetV2") .Attr("task_refresh_interval_hint_ms: int = -1") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("data_transfer_protocol: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); diff --git a/tensorflow/core/protobuf/service_config.proto b/tensorflow/core/protobuf/service_config.proto index 3dcd2cd48d0..c45f8526c85 100644 --- a/tensorflow/core/protobuf/service_config.proto +++ b/tensorflow/core/protobuf/service_config.proto @@ -40,4 +40,6 @@ message WorkerConfig { // How long to retry requests to the dispatcher before giving up and reporting // an error. int64 dispatcher_timeout_ms = 6; + // The protocol for the worker to use when transferring data to clients. + string data_transfer_protocol = 7; } diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index 620f0ef415e..73423fbb66f 100644 --- a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -59,6 +59,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): processing_mode, address, protocol, + data_transfer_protocol, job_name=None, consumer_index=None, num_consumers=None, @@ -76,6 +77,8 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): address: The tf.data service address, e.g. "localhost:5000". protocol: The protocol to use for communicating with the tf.data service, e.g. "grpc". + data_transfer_protocol: The protocol to use for transferring data with the + tf.data service, e.g. "grpc". job_name: (Optional.) The name of the job. This argument makes it possible for multiple datasets to share the same job. The default behavior is that the dataset creates anonymous, exclusively owned jobs. @@ -138,6 +141,10 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): # represented by scalar DT_VARIANTs. self._element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant) + compat_kwargs = {} + if data_transfer_protocol is not None: + compat_kwargs["data_transfer_protocol"] = data_transfer_protocol + if num_consumers is None: variant_tensor = gen_experimental_dataset_ops.data_service_dataset( dataset_id=self._dataset_id, @@ -149,6 +156,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): task_refresh_interval_hint_ms=task_refresh_interval_hint_ms, iteration_counter=gen_experimental_dataset_ops .dummy_iteration_counter(), + **compat_kwargs, **self._flat_structure) else: variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2( @@ -163,6 +171,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): task_refresh_interval_hint_ms=task_refresh_interval_hint_ms, iteration_counter=gen_experimental_dataset_ops .dummy_iteration_counter(), + **compat_kwargs, **self._flat_structure) super(_DataServiceDatasetV2, self).__init__(variant_tensor) @@ -175,15 +184,16 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter): """A `Dataset` that executes its input through the tf.data service.""" @functools.wraps(_DataServiceDatasetV2.__init__) - def __init__(self, dataset_id, processing_mode, address, protocol, job_name, - consumer_index, num_consumers, max_outstanding_requests, - task_refresh_interval_hint_ms): + def __init__(self, dataset_id, processing_mode, address, protocol, + data_transfer_protocol, job_name, consumer_index, num_consumers, + max_outstanding_requests, task_refresh_interval_hint_ms): self._wrapped = _DataServiceDatasetV2( dataset_id=dataset_id, processing_mode=processing_mode, address=address, protocol=protocol, + data_transfer_protocol=data_transfer_protocol, job_name=job_name, consumer_index=consumer_index, num_consumers=num_consumers, @@ -233,7 +243,8 @@ def _from_dataset_id(processing_mode, consumer_index=None, num_consumers=None, max_outstanding_requests=None, - task_refresh_interval_hint_ms=None): + task_refresh_interval_hint_ms=None, + data_transfer_protocol=None): """Creates a dataset which reads data from the tf.data service. This transformation is similar to `from_dataset_id`, but supports additional @@ -274,6 +285,8 @@ def _from_dataset_id(processing_mode, `max_outstanding_requests` of memory. task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the dispatcher for task changes. + data_transfer_protocol: (Optional.) The protocol to use for transferring + data with the tf.data service. Returns: A `tf.data.Dataset` which reads from the tf.data service. @@ -294,6 +307,7 @@ def _from_dataset_id(processing_mode, processing_mode=processing_mode, address=address, protocol=protocol, + data_transfer_protocol=data_transfer_protocol, job_name=job_name, consumer_index=consumer_index, num_consumers=num_consumers, @@ -317,7 +331,8 @@ def _distribute(processing_mode, consumer_index=None, num_consumers=None, max_outstanding_requests=None, - task_refresh_interval_hint_ms=None): + task_refresh_interval_hint_ms=None, + data_transfer_protocol=None): """A transformation that moves dataset processing to the tf.data service. This transformation is similar to `distribute`, but supports additional @@ -352,6 +367,8 @@ def _distribute(processing_mode, `max_outstanding_requests` of memory. task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the dispatcher for task changes. + data_transfer_protocol: (Optional.) The protocol to use for transferring + data with the tf.data service. Returns: Dataset: A `Dataset` of the elements produced by the data service. @@ -369,7 +386,8 @@ def _distribute(processing_mode, consumer_index=consumer_index, num_consumers=num_consumers, max_outstanding_requests=max_outstanding_requests, - task_refresh_interval_hint_ms=task_refresh_interval_hint_ms) + task_refresh_interval_hint_ms=task_refresh_interval_hint_ms, + data_transfer_protocol=data_transfer_protocol) return _apply_fn @@ -380,7 +398,8 @@ def distribute(processing_mode, job_name=None, consumer_index=None, num_consumers=None, - max_outstanding_requests=None): + max_outstanding_requests=None, + data_transfer_protocol=None): """A transformation that moves dataset processing to the tf.data service. When you iterate over a dataset containing the `distribute` transformation, @@ -580,6 +599,8 @@ def distribute(processing_mode, requested at the same time. You can use this option to control the amount of memory used, since `distribute` won't use more than `element_size` * `max_outstanding_requests` of memory. + data_transfer_protocol: (Optional.) The protocol to use for transferring + data with the tf.data service, e.g. "grpc". Returns: Dataset: A `Dataset` of the elements produced by the data service. @@ -590,7 +611,8 @@ def distribute(processing_mode, job_name=job_name, consumer_index=consumer_index, num_consumers=num_consumers, - max_outstanding_requests=max_outstanding_requests) + max_outstanding_requests=max_outstanding_requests, + data_transfer_protocol=data_transfer_protocol) @tf_export("data.experimental.service.register_dataset") diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py index 8e8e3f15f9c..9218bb34d21 100644 --- a/tensorflow/python/data/experimental/service/server_lib.py +++ b/tensorflow/python/data/experimental/service/server_lib.py @@ -307,7 +307,8 @@ class WorkerServer(object): port=config.port, protocol=config.protocol, heartbeat_interval_ms=config.heartbeat_interval_ms, - dispatcher_timeout_ms=config.dispatcher_timeout_ms) + dispatcher_timeout_ms=config.dispatcher_timeout_ms, + data_transfer_protocol=None) self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( config_proto.SerializeToString()) if start: diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.pbtxt index 35dc9254038..4dbd1e8f2d3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.pbtxt @@ -10,7 +10,7 @@ tf_module { } member_method { name: "distribute" - argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'data_transfer_protocol\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_dataset_id" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 2f1ca248a9b..0019be4397b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1006,11 +1006,11 @@ tf_module { } member_method { name: "DataServiceDataset" - argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], " } member_method { name: "DataServiceDatasetV2" - argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], " } member_method { name: "DatasetCardinality" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.pbtxt index d29349bb4fe..90a1447bc6c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.pbtxt @@ -18,7 +18,7 @@ tf_module { } member_method { name: "distribute" - argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'data_transfer_protocol\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "from_dataset_id" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 2f1ca248a9b..0019be4397b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1006,11 +1006,11 @@ tf_module { } member_method { name: "DataServiceDataset" - argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], " } member_method { name: "DataServiceDatasetV2" - argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'data_transfer_protocol\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\', \'None\'], " } member_method { name: "DatasetCardinality"