From 655d6d3cca3e301fc48b80f605e481ac9b2b6bc5 Mon Sep 17 00:00:00 2001
From: Andrew Audibert
Date: Thu, 30 Apr 2020 12:48:16 -0700
Subject: [PATCH] [tf.data service] Support __iter__ for tf.data service
 datasets.

Now we can iterate through tf.data service datasets with
`for elem in dataset`, and `distributed_dataset.repeat()` will work
correctly.

This CL removes the previous method of iteration via
CreateJob/MakeDataServiceIterator. It wasn't yet made public, so it is OK to
remove the old ops.

PiperOrigin-RevId: 309281077
Change-Id: I9531f7d2834ce6669f15896d8c830d23d8277b13
---
 .../api_def/base_api/api_def_CreateJob.pbtxt  |   5 -
 .../api_def_MakeDataServiceIterator.pbtxt     |   5 -
 tensorflow/core/data/service/data_service.cc  |  28 +++
 tensorflow/core/data/service/data_service.h   |   7 +
 .../core/data/service/data_service_test.cc    |  24 +++
 tensorflow/core/data/service/worker_impl.cc   |   3 +
 tensorflow/core/framework/dataset.h           |  36 ----
 tensorflow/core/framework/dataset_test.cc     |  23 ---
 .../experimental/data_service_dataset_op.cc   |  59 ++++--
 .../experimental/data_service_dataset_op.h    |   2 +
 .../data/experimental/data_service_ops.cc     |  66 ------
 .../data/experimental/data_service_ops.h      |  35 ----
 tensorflow/core/kernels/data/iterator_ops.cc  |  12 +-
 tensorflow/core/kernels/data/iterator_ops.h   |   7 +-
 .../ops/compat/ops_history_v1/CreateJob.pbtxt |  23 ---
 .../MakeDataServiceIterator.pbtxt             |  16 --
 .../ops/compat/ops_history_v2/CreateJob.pbtxt |  23 ---
 .../MakeDataServiceIterator.pbtxt             |  16 --
 .../core/ops/experimental_dataset_ops.cc      |  16 +-
 .../data/experimental/ops/data_service_ops.py | 191 +++++------------
 .../kernel_tests/data_service_ops_test.py     | 101 +++++----
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |  10 +-
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |  10 +-
 23 files changed, 211 insertions(+), 507 deletions(-)
 delete mode 100644 tensorflow/core/api_def/base_api/api_def_CreateJob.pbtxt
 delete mode 100644 tensorflow/core/api_def/base_api/api_def_MakeDataServiceIterator.pbtxt
 delete mode 100644 tensorflow/core/ops/compat/ops_history_v1/CreateJob.pbtxt
 delete mode 100644 tensorflow/core/ops/compat/ops_history_v1/MakeDataServiceIterator.pbtxt
 delete mode 100644 tensorflow/core/ops/compat/ops_history_v2/CreateJob.pbtxt
 delete mode 100644 tensorflow/core/ops/compat/ops_history_v2/MakeDataServiceIterator.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_CreateJob.pbtxt b/tensorflow/core/api_def/base_api/api_def_CreateJob.pbtxt
deleted file mode 100644
index f6e41e58897..00000000000
--- a/tensorflow/core/api_def/base_api/api_def_CreateJob.pbtxt
+++ /dev/null
@@ -1,5 +0,0 @@
-op {
-  graph_op_name: "CreateJob"
-  visibility: HIDDEN
-  summary: "Creates a tf.data service job."
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_MakeDataServiceIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_MakeDataServiceIterator.pbtxt
deleted file mode 100644
index 0d516687ebc..00000000000
--- a/tensorflow/core/api_def/base_api/api_def_MakeDataServiceIterator.pbtxt
+++ /dev/null
@@ -1,5 +0,0 @@
-op {
-  graph_op_name: "MakeDataServiceIterator"
-  visibility: HIDDEN
-  summary: "Creates an iterator for reading from the tf.data service."
-}
diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc
index f961683a775..688e3214a47 100644
--- a/tensorflow/core/data/service/data_service.cc
+++ b/tensorflow/core/data/service/data_service.cc
@@ -26,6 +26,34 @@ limitations under the License.
 namespace tensorflow {
 namespace data {
+namespace {
+constexpr const char kParallelEpochs[] = "parallel_epochs";
+constexpr const char kOneEpoch[] = "one_epoch";
+}  // namespace
+
+Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode) {
+  if (s == kParallelEpochs) {
+    *mode = ProcessingMode::PARALLEL_EPOCHS;
+  } else if (s == kOneEpoch) {
+    *mode = ProcessingMode::ONE_EPOCH;
+  } else {
+    return errors::InvalidArgument("Unrecognized processing mode: ", s);
+  }
+  return Status::OK();
+}
+
+std::string ProcessingModeToString(ProcessingMode mode) {
+  switch (mode) {
+    case ProcessingMode::PARALLEL_EPOCHS:
+      return kParallelEpochs;
+    case ProcessingMode::ONE_EPOCH:
+      return kOneEpoch;
+    default:
+      DCHECK(false);
+      return "Unknown";
+  }
+}
+
 Status DataServiceMasterClient::CreateJob(int64 dataset_id,
                                           ProcessingMode processing_mode,
                                           int64* job_id) {
diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h
index c54c0c33390..009e6d25f60 100644
--- a/tensorflow/core/data/service/data_service.h
+++ b/tensorflow/core/data/service/data_service.h
@@ -33,6 +33,13 @@ enum class ProcessingMode : int64 {
   ONE_EPOCH = 1,
 };

+// Parses a string representing a processing mode and stores the result in
+// *mode. Returns an InvalidArgument status if the string is not recognized.
+Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode);
+
+// Converts a processing mode to its corresponding string.
+std::string ProcessingModeToString(ProcessingMode mode);
+
 // Base class for data service clients. Data service clients are
 // thread-compatible, requiring external synchronization when used from
 // multiple threads.
diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc
index f4c3c0e13e7..73a46bad3d0 100644
--- a/tensorflow/core/data/service/data_service_test.cc
+++ b/tensorflow/core/data/service/data_service_test.cc
@@ -38,6 +38,30 @@ namespace data {
 namespace {
 constexpr const char kProtocol[] = "grpc+local";

+TEST(DataService, ParseParallelEpochsProcessingMode) {
+  ProcessingMode mode;
+  TF_ASSERT_OK(ParseProcessingMode("parallel_epochs", &mode));
+  EXPECT_EQ(mode, ProcessingMode::PARALLEL_EPOCHS);
+}
+
+TEST(DataService, ParseOneEpochProcessingMode) {
+  ProcessingMode mode;
+  TF_ASSERT_OK(ParseProcessingMode("one_epoch", &mode));
+  EXPECT_EQ(mode, ProcessingMode::ONE_EPOCH);
+}
+
+TEST(DataService, ParseInvalidProcessingMode) {
+  ProcessingMode mode;
+  Status s = ParseProcessingMode("invalid", &mode);
+  EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
+}
+
+TEST(DataService, ProcessingModeToString) {
+  EXPECT_EQ("parallel_epochs",
+            ProcessingModeToString(ProcessingMode::PARALLEL_EPOCHS));
+  EXPECT_EQ("one_epoch", ProcessingModeToString(ProcessingMode::ONE_EPOCH));
+}
+
 Status CheckWorkerOutput(const std::string& worker_address, int64 task_id,
                          std::vector<std::vector<Tensor>> expected_output) {
   DataServiceWorkerClient worker(worker_address, kProtocol);
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index b0d3275313a..7395244a569 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -117,11 +117,13 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
   }
   std::unique_ptr<standalone::Iterator>& iter = it->second.iterator;
   if (iter == nullptr) {
+    VLOG(3) << "Task " << request->task_id() << " is already finished";
     response->set_end_of_sequence(true);
     return Status::OK();
   }
TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence)); if (end_of_sequence) { + VLOG(3) << "Reached end_of_sequence for task " << request->task_id(); // Release iterator memory and leave a null entry as a tombstone. iter.reset(); pending_completed_tasks_.push_back(request->task_id()); @@ -130,6 +132,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, } if (!end_of_sequence) { + VLOG(3) << "Producing an element for task " << request->task_id(); TF_RETURN_IF_ERROR(service_util::Compress( outputs, response->mutable_compressed_element())); } diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 1d9559d1878..3635cf7c4ba 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -292,36 +292,6 @@ class Runner { static Runner* get(); }; -// A token for reading from a tf.data service job. -class JobToken { - public: - JobToken() : is_empty_(true) {} - - explicit JobToken(int64 job_id) : job_id_(job_id), is_empty_(false) {} - - bool is_empty() const { return is_empty_; } - int64 job_id() const { return job_id_; } - string TypeName() const { return "tensorflow::JobToken"; } - void Encode(VariantTensorData* data) const { - Tensor job_id = Tensor(DT_INT64, TensorShape({})); - job_id.scalar()() = job_id_; - *(data->add_tensors()) = job_id; - - Tensor is_empty = Tensor(DT_BOOL, TensorShape({})); - is_empty.scalar()() = is_empty_; - *(data->add_tensors()) = is_empty; - } - bool Decode(const VariantTensorData& data) { - job_id_ = data.tensors(0).scalar()(); - is_empty_ = data.tensors(1).scalar()(); - return true; - } - - private: - int64 job_id_; - bool is_empty_; -}; - // A cut-down version of `OpKernelContext` for running computations in // iterators. Note that we cannot simply use `OpKernelContext` here because we // might run computation in an iterator whose lifetime is not nested within the @@ -342,7 +312,6 @@ class IteratorContext { env(ctx->env()), flr(ctx->flr()), function_handle_cache(ctx->function_handle_cache()), - job_token(ctx->job_token()), resource_mgr(ctx->resource_mgr()), model(ctx->model()), runner(*(ctx->runner())), @@ -401,9 +370,6 @@ class IteratorContext { // A FunctionHandleCache that owns all the function handles. Not owned. FunctionHandleCache* function_handle_cache = nullptr; - // A token for reading data from a tf.data service job. - JobToken job_token; - // A resource manager for storing dataset-related state, e.g. random // seeds or cached tensors. Not owned. 
ResourceMgr* resource_mgr = nullptr; @@ -453,8 +419,6 @@ class IteratorContext { return params_.function_handle_cache; } - const JobToken& job_token() { return params_.job_token; } - ResourceMgr* resource_mgr() { return params_.resource_mgr; } const std::shared_ptr& model() { return params_.model; } diff --git a/tensorflow/core/framework/dataset_test.cc b/tensorflow/core/framework/dataset_test.cc index 49a4763e8cb..9dbb3be7faf 100644 --- a/tensorflow/core/framework/dataset_test.cc +++ b/tensorflow/core/framework/dataset_test.cc @@ -91,27 +91,4 @@ INSTANTIATE_TEST_SUITE_P( {_tf_string_, tensor_strs, static_cast(sizeof(str) + str.size()) /*bytes*/}})); -TEST(DatasetTest, JobServiceTokenIsEmpty) { - data::JobToken token; - EXPECT_TRUE(token.is_empty()); -} - -TEST(DatasetTest, JobTokenHoldsJobId) { - int64 job_id = 5; - data::JobToken token(job_id); - EXPECT_EQ(job_id, token.job_id()); - EXPECT_FALSE(token.is_empty()); -} - -TEST(DatasetTest, JobTokenEncodeDecode) { - int64 job_id = 5; - data::JobToken token(job_id); - VariantTensorData data; - token.Encode(&data); - data::JobToken decoded; - decoded.Decode(data); - EXPECT_FALSE(token.is_empty()); - EXPECT_EQ(job_id, token.job_id()); -} - } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index 815468d98a3..80425215121 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -43,6 +43,8 @@ namespace tensorflow { namespace data { /* static */ constexpr const char* const DataServiceDatasetOp::kDatasetType; +/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetId; +/* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode; /* static */ constexpr const char* const DataServiceDatasetOp::kAddress; /* static */ constexpr const char* const DataServiceDatasetOp::kProtocol; /* static */ constexpr const char* const @@ -50,6 +52,7 @@ namespace data { /* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes; /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes; +namespace { // Once we've spent `kRetryTimeoutMicros` in `GetNextInternal`, we will wait for // the current attempt to complete and perform no more retries. const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes. @@ -57,6 +60,8 @@ const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes. // Default interval between task list refreshes. const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second. +} // namespace + // Dataset for reading data from the tf.data service non-deterministically. // // This dataset interleaves dataset elements produced by multiple tf.data @@ -64,11 +69,14 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second. // to read from (in case workers are added or removed). 
class DataServiceDatasetOp::Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const std::string& address, + Dataset(OpKernelContext* ctx, int64 dataset_id, + ProcessingMode processing_mode, const std::string& address, const std::string& protocol, int64 max_outstanding_requests, int64 task_refresh_interval_ms, const DataTypeVector& output_types, const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), + dataset_id_(dataset_id), + processing_mode_(processing_mode), address_(address), protocol_(protocol), max_outstanding_requests_(max_outstanding_requests), @@ -102,6 +110,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { + Node* dataset_id; + TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_, &dataset_id)); + + Node* processing_mode; + tstring processing_mode_str = ProcessingModeToString(processing_mode_); + TF_RETURN_IF_ERROR(b->AddScalar(processing_mode_str, &processing_mode)); + Node* address; TF_RETURN_IF_ERROR(b->AddScalar(address_, &address)); @@ -117,7 +132,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { &task_refresh_interval_hint_ms); TF_RETURN_IF_ERROR( - b->AddDataset(this, {address, protocol, max_outstanding_requests}, + b->AddDataset(this, + {dataset_id, processing_mode, address, protocol, + max_outstanding_requests}, {std::make_pair(kTaskRefreshIntervalHintMs, task_refresh_interval_hint_ms)}, output)); @@ -132,6 +149,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { ~Iterator() override { mutex_lock l(mu_); + VLOG(1) << "Destroying data service dataset iterator for job id " + << job_id_; cancelled_ = true; cv_.notify_all(); // Thread destructors will block until the threads finish, no need to wait @@ -141,14 +160,10 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { Status Initialize(IteratorContext* ctx) override { VLOG(3) << "Connecting to " << dataset()->address_ << " in data service dataset op"; - if (ctx->job_token().is_empty()) { - return errors::FailedPrecondition( - "Expected a job token, but none found. 
To iterate over a dataset " - "containing a `distribute` transformation, call `create_job`, " - "which will return a job token that you should then use to iterate " - "over the dataset via `create_iterator(dataset, job_token).`"); - } - job_id_ = ctx->job_token().job_id(); + DataServiceMasterClient master(dataset()->address_, dataset()->protocol_); + TF_RETURN_IF_ERROR(master.CreateJob( + dataset()->dataset_id_, dataset()->processing_mode_, &job_id_)); + VLOG(1) << "Created data service job with id " << job_id_; return Status::OK(); } @@ -175,6 +190,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { return Status::OK(); } DCHECK(!results_.empty()); + *end_of_sequence = false; out_tensors->swap(results_.front()); results_.pop(); cv_.notify_all(); @@ -287,6 +303,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { continue; } if (!task_ids.contains(task_thread->task_id)) { + VLOG(3) << "Marking removed task thread " << task_thread->task_id + << " as finished"; task_thread->end_of_sequence = true; } ++it; @@ -315,6 +333,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { { mutex_lock l(mu_); if (task_handler->end_of_sequence) { + VLOG(3) << "Task thread " << task_handler->task_id + << " reached end_of_sequence"; return; } outstanding_requests_--; @@ -427,6 +447,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { std::unique_ptr task_thread_manager_ GUARDED_BY(mu_); }; + const int64 dataset_id_; + const ProcessingMode processing_mode_; const tstring address_; const tstring protocol_; const int64 max_outstanding_requests_; @@ -448,6 +470,16 @@ DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx) void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { + int64 dataset_id; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id)); + + tstring processing_mode_str; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str)); + ProcessingMode processing_mode; + OP_REQUIRES_OK(ctx, + ParseProcessingMode(processing_mode_str, &processing_mode)); + tstring address; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address)); OP_REQUIRES(ctx, !address.empty(), @@ -468,9 +500,10 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ", model::kAutotune)); - *output = new Dataset(ctx, address, protocol, max_outstanding_requests, - task_refresh_interval_hint_ms_, output_types_, - output_shapes_); + *output = + new Dataset(ctx, dataset_id, processing_mode, address, protocol, + max_outstanding_requests, task_refresh_interval_hint_ms_, + output_types_, output_shapes_); } REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h index d51cb8c861c..d64ca92bc64 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h @@ -24,6 +24,8 @@ namespace data { class DataServiceDatasetOp : public DatasetOpKernel { public: static constexpr const char* const kDatasetType = "DataService"; + static constexpr const char* const kDatasetId = "dataset_id"; + static constexpr const char* const kProcessingMode = "processing_mode"; static constexpr const char* const kAddress = "address"; static constexpr const char* const kProtocol = "protocol"; static 
constexpr const char* const kMaxOutstandingRequests = diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.cc b/tensorflow/core/kernels/data/experimental/data_service_ops.cc index fa3a1a51c1e..c6a54baad64 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_ops.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_ops.cc @@ -22,18 +22,6 @@ limitations under the License. namespace tensorflow { namespace data { -namespace { -Status ParseProcessingMode(const tstring& s, ProcessingMode* mode) { - if (s == "parallel_epochs") { - *mode = ProcessingMode::PARALLEL_EPOCHS; - } else if (s == "one_epoch") { - *mode = ProcessingMode::ONE_EPOCH; - } else { - return errors::InvalidArgument("Unrecognized processing mode: ", s); - } - return Status::OK(); -} -} // namespace RegisterDatasetOp::RegisterDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -75,62 +63,8 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) { output_dataset_id() = dataset_id; } -CreateJobOp::CreateJobOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - -void CreateJobOp::Compute(OpKernelContext* ctx) { - int64 dataset_id; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id)); - - tstring address; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address)); - OP_REQUIRES(ctx, !address.empty(), - errors::InvalidArgument(kAddress, " must be non-empty.")); - - tstring protocol; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol)); - OP_REQUIRES(ctx, !protocol.empty(), - errors::InvalidArgument(kProtocol, " must be non-empty.")); - - tstring processing_mode_str; - OP_REQUIRES_OK( - ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str)); - ProcessingMode processing_mode; - OP_REQUIRES_OK(ctx, - ParseProcessingMode(processing_mode_str, &processing_mode)); - - DataServiceMasterClient client(address, protocol); - int64 job_id; - OP_REQUIRES_OK(ctx, client.CreateJob(dataset_id, processing_mode, &job_id)); - - JobToken token(job_id); - Tensor* output; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &output)); - auto output_token = output->tensor(); - output_token() = token; -} - -Status MakeDataServiceIteratorOp::DoCompute(OpKernelContext* ctx) { - DatasetBase* dataset; - TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset)); - - const Tensor* token_tensor; - TF_RETURN_IF_ERROR(ctx->input(kJobToken, &token_tensor)); - JobToken token = *token_tensor->scalar()().get(); - - IteratorResource* iterator_resource; - TF_RETURN_IF_ERROR( - LookupResource(ctx, HandleFromInput(ctx, 2), &iterator_resource)); - - core::ScopedUnref unref_iterator(iterator_resource); - - return iterator_resource->SetIteratorFromDataset(ctx, dataset, token); -} - REGISTER_KERNEL_BUILDER(Name("RegisterDataset").Device(DEVICE_CPU), RegisterDatasetOp); -REGISTER_KERNEL_BUILDER(Name("CreateJob").Device(DEVICE_CPU), CreateJobOp); -REGISTER_KERNEL_BUILDER(Name("MakeDataServiceIterator").Device(DEVICE_CPU), - MakeDataServiceIteratorOp); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.h b/tensorflow/core/kernels/data/experimental/data_service_ops.h index ebbcb39d0a3..b7d66938ae6 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_ops.h +++ b/tensorflow/core/kernels/data/experimental/data_service_ops.h @@ -44,41 +44,6 @@ class RegisterDatasetOp : public OpKernel { SerializationContext::ExternalStatePolicy external_state_policy_; }; -// Creates a token for 
reading from the tf.data service. -// -// The dataset_id input identifies which dataset to create a token for. -// The address and protocol inputs are used to connect to the tf.data service -// master. -// The processing_mode defines how the tf.data service should produce data for -// the token. -class CreateJobOp : public OpKernel { - public: - static constexpr const char* const kDatasetId = "dataset_id"; - static constexpr const char* const kAddress = "address"; - static constexpr const char* const kProtocol = "protocol"; - static constexpr const char* const kProcessingMode = "processing_mode"; - - explicit CreateJobOp(OpKernelConstruction* ctx); - - void Compute(OpKernelContext* ctx) override; -}; - -// Creates a new iterator for iterating over a tf.data service dataset. -// -// The epoch_id input identifies which epoch to read from. Multiple iterators -// may read from the same epoch, causing the elements of the epoch to be split -// across all iterators. -class MakeDataServiceIteratorOp : public MakeIteratorOp { - public: - static constexpr const char* const kJobToken = "job_token"; - - explicit MakeDataServiceIteratorOp(OpKernelConstruction* ctx) - : MakeIteratorOp(ctx) {} - - protected: - Status DoCompute(OpKernelContext* ctx) override; -}; - } // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_ diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 21fa5bf6ac2..fde4adf26d7 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -166,8 +166,7 @@ Status IteratorResource::Restore(OpKernelContext* ctx, } Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, - DatasetBase* dataset, - JobToken job_token) { + DatasetBase* dataset) { std::shared_ptr new_state; { tf_shared_lock l(mu_); @@ -180,7 +179,6 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, IteratorContext::Params params(ctx); params.flr = new_state->flr; params.function_handle_cache = new_state->function_handle_cache.get(); - params.job_token = job_token; params.resource_mgr = &new_state->resource_mgr; params.thread_factory = unbounded_thread_pool_.get_thread_factory(); params.thread_pool = &unbounded_thread_pool_; @@ -532,9 +530,7 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) { TF_RETURN_IF_ERROR( LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); core::ScopedUnref unref_iterator(iterator_resource); - JobToken empty_token; - return iterator_resource->SetIteratorFromDataset(ctx, dataset, - /*job_token=*/empty_token); + return iterator_resource->SetIteratorFromDataset(ctx, dataset); } void DeleteIteratorOp::Compute(OpKernelContext* ctx) { @@ -855,9 +851,7 @@ class OneShotIteratorOp : public AsyncOpKernel { // factory function. 
DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); - JobToken empty_token; - TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset( - ctx, dataset, /*job_token=*/empty_token)); + TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset)); (*iterator)->Ref(); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index df99a38b516..fcee8ca20c0 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -69,14 +69,9 @@ class IteratorResource : public ResourceBase { // Creates an iterator for `dataset`, and associates the iterator with this // iterator resource. // - // The `job_token` will be passed through the IteratorContext when - // creating the iterator. This token is used to read from a tf.data service - // job. - // // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`, // or `Restore`. - Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset, - JobToken job_token); + Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset); string DebugString() const override { return "Iterator resource"; } diff --git a/tensorflow/core/ops/compat/ops_history_v1/CreateJob.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/CreateJob.pbtxt deleted file mode 100644 index 56c68a8c6df..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v1/CreateJob.pbtxt +++ /dev/null @@ -1,23 +0,0 @@ -op { - name: "CreateJob" - input_arg { - name: "dataset_id" - type: DT_INT64 - } - input_arg { - name: "address" - type: DT_STRING - } - input_arg { - name: "protocol" - type: DT_STRING - } - input_arg { - name: "processing_mode" - type: DT_STRING - } - output_arg { - name: "job_token" - type: DT_VARIANT - } -} diff --git a/tensorflow/core/ops/compat/ops_history_v1/MakeDataServiceIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/MakeDataServiceIterator.pbtxt deleted file mode 100644 index e2061ad3a57..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v1/MakeDataServiceIterator.pbtxt +++ /dev/null @@ -1,16 +0,0 @@ -op { - name: "MakeDataServiceIterator" - input_arg { - name: "dataset" - type: DT_VARIANT - } - input_arg { - name: "job_token" - type: DT_VARIANT - } - input_arg { - name: "iterator" - type: DT_RESOURCE - } - is_stateful: true -} diff --git a/tensorflow/core/ops/compat/ops_history_v2/CreateJob.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CreateJob.pbtxt deleted file mode 100644 index 56c68a8c6df..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v2/CreateJob.pbtxt +++ /dev/null @@ -1,23 +0,0 @@ -op { - name: "CreateJob" - input_arg { - name: "dataset_id" - type: DT_INT64 - } - input_arg { - name: "address" - type: DT_STRING - } - input_arg { - name: "protocol" - type: DT_STRING - } - input_arg { - name: "processing_mode" - type: DT_STRING - } - output_arg { - name: "job_token" - type: DT_VARIANT - } -} diff --git a/tensorflow/core/ops/compat/ops_history_v2/MakeDataServiceIterator.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/MakeDataServiceIterator.pbtxt deleted file mode 100644 index e2061ad3a57..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v2/MakeDataServiceIterator.pbtxt +++ /dev/null @@ -1,16 +0,0 @@ -op { - name: "MakeDataServiceIterator" - input_arg { - name: "dataset" - type: DT_VARIANT - } - input_arg { - name: "job_token" - type: DT_VARIANT - } - input_arg { - name: "iterator" - type: DT_RESOURCE - } - is_stateful: true 
-} diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 480626a2465..cd9585ab3d5 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -1043,6 +1043,8 @@ REGISTER_OP("ExperimentalUniqueDataset") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("DataServiceDataset") + .Input("dataset_id: int64") + .Input("processing_mode: string") .Input("address: string") .Input("protocol: string") .Input("max_outstanding_requests: int64") @@ -1061,18 +1063,4 @@ REGISTER_OP("RegisterDataset") .Attr("external_state_policy: int") .SetShapeFn(shape_inference::ScalarShape); -REGISTER_OP("CreateJob") - .Input("dataset_id: int64") - .Input("address: string") - .Input("protocol: string") - .Input("processing_mode: string") - .Output("job_token: variant") - .SetShapeFn(shape_inference::ScalarShape); - -REGISTER_OP("MakeDataServiceIterator") - .Input("dataset: variant") - .Input("job_token: variant") - .Input("iterator: resource") - .SetShapeFn(shape_inference::NoOutputs); - } // namespace tensorflow diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index f8e9ac15723..b5b54e3e94a 100644 --- a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -24,8 +24,6 @@ import six from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_experimental_dataset_ops @@ -36,10 +34,10 @@ class ProcessingMode(object): @staticmethod def validate(mode): - """Raises a TypeError if the given object is not a valid processing mode.""" + """Raises a ValueError if the given object is not a valid processing mode.""" valid_modes = [ProcessingMode.PARALLEL_EPOCHS] if mode not in valid_modes: - raise TypeError( + raise ValueError( "{0} is not a valid processing mode. Valid modes: {1}".format( mode, valid_modes)) @@ -50,6 +48,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): def __init__(self, input_dataset, dataset_id, + processing_mode, address, protocol, max_outstanding_requests=None, @@ -60,6 +59,9 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): input_dataset: The input dataset, which should be registered with the tf.data service under `dataset_id`. dataset_id: The dataset id for the dataset to read from. + processing_mode: A string specifying the policy for how data should be + processed by tf.data workers. Currently, the only supported value is + "parallel_epochs". address: The tf.data service address, e.g. "localhost:5000". protocol: The protocol to use for communicating with the tf.data service, e.g. "grpc". 
@@ -77,13 +79,10 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
       task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

     self._element_spec = input_dataset.element_spec
-    self._dataset_id = dataset_id
-    self._address = address
-    self._protocol = protocol
-    self._max_outstanding_requests = max_outstanding_requests
-    self._task_refresh_interval_hint_ms = task_refresh_interval_hint_ms

     variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
+        dataset_id=dataset_id,
+        processing_mode=processing_mode,
         address=address,
         protocol=protocol,
         max_outstanding_requests=max_outstanding_requests,
@@ -91,18 +90,6 @@
         **self._flat_structure)
     super(_DataServiceDatasetV2, self).__init__(variant_tensor)

-  @property
-  def dataset_id(self):
-    return self._dataset_id
-
-  @property
-  def address(self):
-    return self._address
-
-  @property
-  def protocol(self):
-    return self._protocol
-
   @property
   def element_spec(self):
     return self._element_spec
@@ -112,30 +99,20 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
   """A `Dataset` that executes its input through the tf.data service."""

   @functools.wraps(_DataServiceDatasetV2.__init__)
-  def __init__(self, input_dataset, dataset_id, address, protocol,
-               max_outstanding_requests, task_refresh_interval_hint_ms):
+  def __init__(self, input_dataset, dataset_id, processing_mode, address,
+               protocol, max_outstanding_requests,
+               task_refresh_interval_hint_ms):
     self._wrapped = _DataServiceDatasetV2(
         input_dataset=input_dataset,
         dataset_id=dataset_id,
+        processing_mode=processing_mode,
         address=address,
         protocol=protocol,
         max_outstanding_requests=max_outstanding_requests,
         task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
     super(_DataServiceDatasetV1, self).__init__(self._wrapped)

-  @property
-  def dataset_id(self):
-    return self._wrapped.dataset_id
-
-  @property
-  def address(self):
-    return self._wrapped.address
-
-  @property
-  def protocol(self):
-    return self._wrapped.protocol
-

 if tf2.enabled():
   _DataServiceDataset = _DataServiceDatasetV2
@@ -143,7 +120,8 @@ else:
   _DataServiceDataset = _DataServiceDatasetV1


-def _distribute(service,
+def _distribute(processing_mode,
+                service,
                 max_outstanding_requests=None,
                 task_refresh_interval_hint_ms=None):
   """A transformation that moves dataset processing to the tf.data service.
@@ -152,6 +130,9 @@
   parameters which we do not yet want to add to the public Python API.

   Args:
+    processing_mode: A string specifying the policy for how data should be
+      processed by tf.data workers. Currently, the only supported value is
+      "parallel_epochs".
     service: A string indicating how to connect to the tf.data service. The
       string should be in the format <protocol>://<address>, e.g.
       grpc://localhost:5000.
@@ -165,6 +146,7 @@
   Returns:
     Dataset: A `Dataset` of the elements produced by the data service.
   """
+  ProcessingMode.validate(processing_mode)
   if not isinstance(service, six.string_types):
     raise ValueError(
         "service must be a string, but service was of type {0}. service={1}"
@@ -197,6 +179,7 @@
     return _DataServiceDataset(
         input_dataset=dataset,
         dataset_id=dataset_id,
+        processing_mode=processing_mode,
         address=address,
         protocol=protocol,
         max_outstanding_requests=max_outstanding_requests,
@@ -205,52 +188,9 @@
   return _apply_fn


-def distribute(service, max_outstanding_requests=None):
+def distribute(processing_mode, service, max_outstanding_requests=None):
   """A transformation that moves dataset processing to the tf.data service.

-  ```
-  dataset = tf.data.Dataset.range(10)
-  dataset = dataset.map(lambda x: x*x)
-  dataset = dataset.apply(
-      tf.data.experimental.service.distribute("grpc://dataservice:5000"))
-  dataset = dataset.map(lambda x: x+10)
-
-  job_token = tf.data.experimental.service.create_job(dataset)
-  it = tf.data.experimental.service.create_iterator(dataset, job_token)
-  for element in it:
-    # process element
-  ```
-
-  In the above example, the first two lines (before the call to `distribute`)
-  will be executed on tf.data workers, and the elements provided over
-  RPC. The remaining transformations (after the call to `distribute`) will be
-  executed locally.
-
-  The token returned from `create_job` may be used to create multiple
-  coordinated iterators which consume data from the same job.
-
-  Args:
-    service: A string indicating how to connect to the tf.data service. The
-      string should be in the format <protocol>://<address>, e.g.
-      grpc://localhost:5000.
-    max_outstanding_requests: (Optional.) A limit on how many elements may be
-      requested at the same time. You can use this option to control the amount
-      of memory used, since `distribute` won't use more than `element_size` *
-      `max_outstanding_requests` of memory.
-
-  Returns:
-    Dataset: A `Dataset` of the elements produced by the data service.
-  """
-  return _distribute(service, max_outstanding_requests)
-
-
-def create_job(dataset, processing_mode):
-  """Creates a job for reading a dataset through the tf.data service.
-
-  The returned token can be used to create iterators for consuming data from
-  the job. `processing_mode` controls what data will be produced. Iterators
-  created from the same token will consume from the same job.
-
   The `processing_mode` argument controls how data is processed by the tf.data
   service. Currently, the only supported mode is "parallel_epochs".
@@ -263,77 +203,40 @@
   to randomly shuffle your dataset, so that different tf.data workers will
   iterate through the dataset in different orders.

-  In the future, we plan to add additional epoch modes. For example, we will add
+  In the future, there will be additional epoch modes. For example, a
   "one_epoch" mode which partitions the dataset across the tf.data workers, so
   that the consumers see each element of the dataset only once.

+  ```
+  dataset = tf.data.Dataset.range(10)
+  dataset = dataset.map(lambda x: x*x)
+  dataset = dataset.apply(
+      tf.data.experimental.service.distribute("parallel_epochs",
+                                              "grpc://dataservice:5000"))
+  dataset = dataset.map(lambda x: x+10)
+
+  for element in dataset:
+    # process element
+  ```
+
+  In the above example, the first two lines (before the call to `distribute`)
+  will be executed on tf.data workers, and the elements provided over
+  RPC. The remaining transformations (after the call to `distribute`) will be
+  executed locally.
+
   Args:
-    dataset: A `tf.data.Dataset` to create a job for. The dataset must contain a
-      single `distribute` transformation.
     processing_mode: A string specifying the policy for how data should be
       processed by tf.data workers. Currently, the only supported value is
       "parallel_epochs".
+    service: A string indicating how to connect to the tf.data service. The
+      string should be in the format <protocol>://<address>
, e.g. + grpc://localhost:5000. + max_outstanding_requests: (Optional.) A limit on how many elements may be + requested at the same time. You can use this option to control the amount + of memory used, since `distribute` won't use more than `element_size` * + `max_outstanding_requests` of memory. Returns: - A token for reading from the created tf.data service job. To read using the - token, call `create_iterator(dataset, token)` - - Raises: - ValueError: If the dataset contains no calls to `distribute` or more than 1 - call to `distribute`. + Dataset: A `Dataset` of the elements produced by the data service. """ - datasets = _find_data_service_datasets(dataset) - if len(datasets) > 1: - raise ValueError( - "Datasets containing multiple calls to .distribute(...) are " + - "not supported") - if not datasets: - raise ValueError( - "Dataset does not contain any distribute() transformations") - ProcessingMode.validate(processing_mode) - data_service_dataset = datasets[0] - return gen_experimental_dataset_ops.create_job( - data_service_dataset.dataset_id, data_service_dataset.address, - data_service_dataset.protocol, processing_mode) - - -def create_iterator(dataset, job_token): - """Creates an iterator for reading from the tf.data service. - - Args: - dataset: A `tf.data.Dataset` object. - job_token: A token generated by `create_job`. - - Returns: - A dataset iterator. - - Raises: - RuntimeError: If called outside of a function in graph mode. - """ - if context.executing_eagerly() or ops.inside_function(): - return iterator_ops.OwnedIterator(dataset, job_token=job_token) - else: - raise RuntimeError("create_iterator() is only supported inside of " - "tf.function or when eager execution is enabled.") - - -def _find_data_service_datasets(dataset): - """Produces a list of all data service datasets in the given dataset. - - Args: - dataset: A `tf.data.Dataset`. - - Returns: - A list of all data service datasets. 
- """ - result = [] - to_check = [dataset] - while to_check: - d = to_check.pop() - if isinstance(d, dataset_ops.DatasetV1Adapter): - d = d._dataset # pylint: disable=protected-access - if isinstance(d, _DataServiceDatasetV1) or isinstance( - d, _DataServiceDatasetV2): - result.append(d) - to_check.extend(d._inputs()) # pylint: disable=protected-access - return result + return _distribute(processing_mode, service, max_outstanding_requests) diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py index 55fad6f7b7e..b6e963959e4 100644 --- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py @@ -40,7 +40,8 @@ PROTOCOL = "grpc" def _make_distributed_dataset(dataset, service): """Creates a distributed dataset with a short task refresh interval.""" return dataset.apply( - data_service_ops._distribute(service, task_refresh_interval_hint_ms=20)) + data_service_ops._distribute( + "parallel_epochs", service, task_refresh_interval_hint_ms=20)) class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): @@ -65,27 +66,35 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): return self._master.target - @combinations.generate(test_base.eager_only_combinations()) - def testMultipleEpochs(self): - service = self.create_cluster(1) - ds = dataset_ops.Dataset.range(3) - ds = _make_distributed_dataset(ds, service) - for _ in range(10): - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") - it = data_service_ops.create_iterator(ds, token) - self.assertEqual(list(range(3)), [t.numpy() for t in it]) - @combinations.generate(test_base.eager_only_combinations()) def testDistributeBasic(self): num_elements = 10 service = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) ds = _make_distributed_dataset(ds, service) - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") - it = data_service_ops.create_iterator(ds, token) - results = [t.numpy() for t in it] + results = [elem.numpy() for elem in ds] self.assertEqual(list(range(num_elements)), results) + @combinations.generate(test_base.eager_only_combinations()) + def testMultipleEpochs(self): + num_elements = 3 + service = self.create_cluster(1) + ds = dataset_ops.Dataset.range(num_elements) + ds = _make_distributed_dataset(ds, service) + for _ in range(10): + self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds]) + + @combinations.generate(test_base.eager_only_combinations()) + def testRepeatedDataset(self): + num_elements = 10 + num_repetitions = 5 + service = self.create_cluster(1) + ds = dataset_ops.Dataset.range(num_elements) + ds = _make_distributed_dataset(ds, service) + ds = ds.repeat(num_repetitions) + self.assertDatasetProduces( + ds, expected_output=num_repetitions * list(range(num_elements))) + @combinations.generate(test_base.eager_only_combinations()) def testConcurrentEpoch(self): num_elements = 10 @@ -96,9 +105,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): for _ in range(num_datasets): ds = dataset_ops.Dataset.range(num_elements) ds = _make_distributed_dataset(ds, service) - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") - it = data_service_ops.create_iterator(ds, token) - iterators.append(it) + iterators.append(iter(ds)) results.append([]) for _ in range(num_elements): @@ -110,6 +117,7 @@ class 
DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.eager_only_combinations()) def testSharedEpoch(self): + self.skipTest("Not yet implemented") num_elements = 10 num_iterators = 3 service = self.create_cluster(1) @@ -117,9 +125,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): ds = _make_distributed_dataset(ds, service) result = [] iterators = [] - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") for _ in range(num_iterators): - iterators.append(data_service_ops.create_iterator(ds, token)) + iterators.append(iter(ds)) # Alternate reading between the iterators. for _ in range(2): @@ -140,9 +147,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): service = self.create_cluster(num_workers) ds = dataset_ops.Dataset.range(num_elements) ds = _make_distributed_dataset(ds, service) - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") - iterator = data_service_ops.create_iterator(ds, token) - results = [elem.numpy() for elem in iterator] + results = [elem.numpy() for elem in ds] self.assertCountEqual(num_workers * list(range(num_elements)), results) @combinations.generate(test_base.eager_only_combinations()) @@ -154,8 +159,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): num_elements = 100 ds = dataset_ops.Dataset.range(num_elements) ds = _make_distributed_dataset(ds, self._master.target) - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") - iterator = data_service_ops.create_iterator(ds, token) + iterator = iter(ds) results = [] # Read halfway through the dataset. for _ in range(num_elements // 2): @@ -184,8 +188,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): num_elements = 100 ds = dataset_ops.Dataset.range(num_elements) ds = _make_distributed_dataset(ds, self._master.target) - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") - iterator = data_service_ops.create_iterator(ds, token) + iterator = iter(ds) # Read halfway through the dataset. 
midpoint = num_elements // 2 for i in range(midpoint): @@ -219,12 +222,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def f(): ds = dataset_ops.Dataset.range(num_elements) ds = _make_distributed_dataset(ds, service) - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") - it = data_service_ops.create_iterator(ds, token) result = tensor_array_ops.TensorArray( dtypes.int64, size=num_workers * num_elements, dynamic_size=True) i = 0 - for elem in it: + for elem in ds: result = result.write(i, elem) i += 1 return result.stack() @@ -243,9 +244,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): service = self.create_cluster(3) ds = _make_distributed_dataset(ds, service) - token = data_service_ops.create_job(ds, processing_mode="parallel_epochs") - iterator = data_service_ops.create_iterator(ds, token) - next(iterator) + next(iter(ds)) @combinations.generate( combinations.times( @@ -262,27 +261,6 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): with self.assertRaises(errors.FailedPreconditionError): self.run_stateful(distribute_options.ExternalStatePolicy.FAIL) - @combinations.generate(test_base.eager_only_combinations()) - def testNoDistributeCalls(self): - ds = dataset_ops.Dataset.range(1) - with self.assertRaisesWithLiteralMatch( - ValueError, - "Dataset does not contain any distribute() transformations"): - data_service_ops.create_job(ds, processing_mode="parallel_epochs") - - @combinations.generate(test_base.eager_only_combinations()) - def testMultipleDistributeCalls(self): - service = self.create_cluster(1) - ds1 = dataset_ops.Dataset.range(1) - ds1 = _make_distributed_dataset(ds1, service) - ds2 = dataset_ops.Dataset.range(1) - ds2 = _make_distributed_dataset(ds2, service) - ds = dataset_ops.Dataset.zip((ds1, ds2)) - with self.assertRaisesWithLiteralMatch( - ValueError, "Datasets containing multiple calls to .distribute(...) 
" - "are not supported"): - data_service_ops.create_job(ds, processing_mode="parallel_epochs") - @combinations.generate(test_base.eager_only_combinations()) def testDistributeFromInterleave(self): service = self.create_cluster(1) @@ -302,14 +280,27 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def testDistributeNonStringAddresses(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesRegex(ValueError, "service must be a string"): - ds = ds.apply(data_service_ops.distribute(service=1)) + ds = ds.apply( + data_service_ops.distribute( + processing_mode="parallel_epochs", service=1)) @combinations.generate(test_base.eager_only_combinations()) def testDistributeEmptyAddress(self): ds = dataset_ops.Dataset.range(10) with self.assertRaisesWithLiteralMatch(ValueError, "service must not be empty"): - ds = ds.apply(data_service_ops.distribute(service="")) + ds = ds.apply( + data_service_ops.distribute( + processing_mode="parallel_epochs", service="")) + + @combinations.generate(test_base.eager_only_combinations()) + def testDistributeInvalidProcessingMode(self): + ds = dataset_ops.Dataset.range(10) + with self.assertRaisesRegex(ValueError, + "invalid is not a valid processing mode"): + ds = ds.apply( + data_service_ops.distribute( + processing_mode="invalid", service="grpc://localhost:5000")) if __name__ == "__main__": diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 68de2d3b478..b9eb914ac9f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -840,10 +840,6 @@ tf_module { name: "CountUpTo" argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "CreateJob" - argspec: "args=[\'dataset_id\', \'address\', \'protocol\', \'processing_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "CreateSummaryDbWriter" argspec: "args=[\'writer\', \'db_uri\', \'experiment_name\', \'run_name\', \'user_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -938,7 +934,7 @@ tf_module { } member_method { name: "DataServiceDataset" - argspec: "args=[\'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " } member_method { name: "DatasetCardinality" @@ -2200,10 +2196,6 @@ tf_module { name: "Lu" argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } - member_method { - name: "MakeDataServiceIterator" - argspec: "args=[\'dataset\', \'job_token\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "MakeIterator" argspec: "args=[\'dataset\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 68de2d3b478..b9eb914ac9f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -840,10 +840,6 @@ tf_module { name: "CountUpTo" argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "CreateJob" - argspec: "args=[\'dataset_id\', \'address\', \'protocol\', \'processing_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "CreateSummaryDbWriter" argspec: "args=[\'writer\', \'db_uri\', \'experiment_name\', \'run_name\', \'user_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -938,7 +934,7 @@ tf_module { } member_method { name: "DataServiceDataset" - argspec: "args=[\'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " } member_method { name: "DatasetCardinality" @@ -2200,10 +2196,6 @@ tf_module { name: "Lu" argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } - member_method { - name: "MakeDataServiceIterator" - argspec: "args=[\'dataset\', \'job_token\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "MakeIterator" argspec: "args=[\'dataset\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
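
Example of the usage pattern this patch enables (a minimal sketch, not part of
the diff above: it assumes TF 2.x eager execution and a tf.data service master
with at least one worker already running; the address and element counts are
illustrative):

```
import tensorflow as tf

# Illustrative address of a running tf.data service master.
service = "grpc://localhost:5000"

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x * x)  # Executed on tf.data workers.
# `distribute` now takes the processing mode up front; there is no longer a
# separate CreateJob / MakeDataServiceIterator step.
dataset = dataset.apply(
    tf.data.experimental.service.distribute("parallel_epochs", service))
dataset = dataset.map(lambda x: x + 10)  # Executed locally, after the RPC hop.

# Plain `for` iteration now works; each new iterator creates its own job.
for element in dataset:
  print(element.numpy())

# `repeat` also composes correctly, since every pass over the distributed
# dataset is backed by a fresh iterator (and therefore a fresh job).
for element in dataset.repeat(2):
  print(element.numpy())
```

Passing any mode other than "parallel_epochs" at this revision raises a
ValueError from ProcessingMode.validate, as exercised by
testDistributeInvalidProcessingMode.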