From a58bfeef159f4465064f0a02e4a958e075a0a44c Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Wed, 16 Sep 2020 12:32:23 -0700 Subject: [PATCH] [tf.data] Implement DatasetBase::InputDatasets() for core datasets. PiperOrigin-RevId: 332063322 Change-Id: Ib776466a1f2844b98e8257b28eadb26af9abb9ed --- .../core/kernels/data/batch_dataset_op.cc | 5 +++++ .../core/kernels/data/cache_dataset_ops.cc | 10 ++++++++++ .../kernels/data/concatenate_dataset_op.cc | 6 ++++++ .../assert_cardinality_dataset_op.cc | 5 +++++ .../experimental/assert_next_dataset_op.cc | 5 +++++ .../choose_fastest_branch_dataset_op.cc | 10 ++++++++++ .../experimental/choose_fastest_dataset_op.cc | 8 ++++++++ .../dense_to_sparse_batch_dataset_op.cc | 6 ++++++ .../group_by_window_dataset_op.cc | 6 ++++++ .../experimental/ignore_errors_dataset_op.cc | 6 ++++++ .../data/experimental/lmdb_dataset_op.cc | 4 ++++ .../experimental/map_and_batch_dataset_op.cc | 5 +++++ .../experimental/matching_files_dataset_op.cc | 5 +++++ .../non_serializable_dataset_op.cc | 6 ++++++ .../parallel_interleave_dataset_op.cc | 5 +++++ .../experimental/parse_example_dataset_op.cc | 6 ++++++ .../data/experimental/random_dataset_op.cc | 4 ++++ .../data/experimental/sampling_dataset_op.cc | 5 +++++ .../data/experimental/scan_dataset_op.cc | 6 ++++++ .../set_stats_aggregator_dataset_op.cc | 6 ++++++ .../data/experimental/sleep_dataset_op.cc | 6 ++++++ .../experimental/sliding_window_dataset_op.cc | 6 ++++++ .../data/experimental/snapshot_dataset_op.cc | 14 +++++++++++++ .../data/experimental/snapshot_util.cc | 9 +++++++++ .../data/experimental/sql_dataset_op.cc | 5 +++++ .../data/experimental/stats_dataset_ops.cc | 6 ++++++ .../experimental/take_while_dataset_op.cc | 6 ++++++ .../experimental/threadpool_dataset_op.cc | 20 +++++++++++++++++++ .../data/experimental/unbatch_dataset_op.cc | 6 ++++++ .../data/experimental/unique_dataset_op.cc | 5 +++++ .../core/kernels/data/filter_dataset_op.cc | 5 +++++ 
.../data/fixed_length_record_dataset_op.cc | 4 ++++ .../core/kernels/data/flat_map_dataset_op.cc | 5 +++++ .../core/kernels/data/generator_dataset_op.cc | 4 ++++ .../kernels/data/interleave_dataset_op.cc | 5 +++++ .../core/kernels/data/map_dataset_op.cc | 5 +++++ .../kernels/data/padded_batch_dataset_op.cc | 5 +++++ .../data/parallel_interleave_dataset_op.cc | 5 +++++ .../kernels/data/parallel_map_dataset_op.cc | 5 +++++ .../core/kernels/data/prefetch_dataset_op.cc | 5 +++++ .../core/kernels/data/repeat_dataset_op.cc | 8 ++++++++ .../core/kernels/data/shard_dataset_op.cc | 5 +++++ .../core/kernels/data/shuffle_dataset_op.cc | 11 +++++++++- .../core/kernels/data/skip_dataset_op.cc | 5 +++++ .../data/sparse_tensor_slice_dataset_op.cc | 4 ++++ .../core/kernels/data/take_dataset_op.cc | 6 ++++++ .../core/kernels/data/take_dataset_op.h | 2 ++ .../core/kernels/data/tensor_dataset_op.cc | 4 ++++ .../kernels/data/tensor_slice_dataset_op.cc | 4 ++++ .../core/kernels/data/text_line_dataset_op.cc | 4 ++++ .../core/kernels/data/tf_record_dataset_op.cc | 4 ++++ .../core/kernels/data/window_dataset.cc | 4 ++++ .../core/kernels/data/window_dataset_op.cc | 5 +++++ .../core/kernels/data/zip_dataset_op.cc | 7 +++++++ 54 files changed, 322 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 96c7e036e03..7ea39dfe709 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -117,6 +117,11 @@ class BatchDatasetOp::Dataset : public DatasetBase { return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 
0 : 1); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index f60001b0055..c9883f9c938 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -102,6 +102,11 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -680,6 +685,11 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index 34faafeb178..ffe15248c0e 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -85,6 +85,12 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { return n1 + n2; } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + inputs->push_back(to_concatenate_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(input_->CheckExternalState()); return to_concatenate_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc 
b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc index 1dd38dcaa04..30d0f9405f7 100644 --- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc @@ -67,6 +67,11 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { int64 Cardinality() const override { return cardinality_; } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc index 7348b342c6a..3898bb4e705 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -64,6 +64,11 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc index 8772f21ef8f..1cb2564d3a0 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc @@ -55,6 +55,10 @@ class WrapperDataset : public DatasetBase { string DebugString() const override { return "WrapperDataset"; } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: @@ -245,6 
+249,12 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { return static_cast<double>(n) * ratio_numerator_ / ratio_denominator_; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { for (const auto& captured_func : captured_funcs_) { TF_RETURN_IF_ERROR(captured_func->CheckExternalState()); diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc index 6ab72d85a99..3fff7bc6f16 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc @@ -158,6 +158,14 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { int64 Cardinality() const override { return cardinality_; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + for (const auto& input : inputs_) { + inputs->push_back(input); + } + return Status::OK(); + } + Status CheckExternalState() const override { for (const auto& input : inputs_) { TF_RETURN_IF_ERROR(input->CheckExternalState()); diff --git a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc index d09922988b9..c0070dca9f7 100644 --- a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc @@ -120,6 +120,12 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { return n / batch_size_ + (n % batch_size_ == 0 ? 
0 : 1); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index 0a6df24d40a..a629292e2ed 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -107,6 +107,12 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return "GroupByWindowDatasetOp::Dataset"; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState()); TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState()); diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc index 27d3b7cd45b..8b2745ee526 100644 --- a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc @@ -66,6 +66,12 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc index 7cfa74e6516..31763c89544 100644 --- a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc @@ 
-54,6 +54,10 @@ class LMDBDatasetOp::Dataset : public DatasetBase { string DebugString() const override { return "LMDBDatasetOp::Dataset"; } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index fdc63bdb913..5cc72ba853e 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -137,6 +137,11 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index 90a61d72597..a606e76008c 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -82,6 +82,11 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { return "MatchingFilesDatasetOp::Dataset"; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc index 1e752931157..cb0be0e0bbf 100644 --- a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc +++ 
b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc @@ -69,6 +69,12 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel { return "NonSerializableDatasetOp::Dataset"; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index 9c344e01c6a..0e15015efee 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -147,6 +147,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType, params); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc index 3002987b621..16cf7fe6416 100644 --- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc @@ -270,6 +270,12 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc 
b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc index 460c18ce7a3..ee90d8cb603 100644 --- a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc @@ -64,6 +64,10 @@ class RandomDatasetOp::Dataset : public DatasetBase { int64 Cardinality() const override { return kInfiniteCardinality; } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc index 00869eea85c..6f271e4912b 100644 --- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc @@ -70,6 +70,11 @@ class SamplingDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc index 723f32311d0..eee635ffa7b 100644 --- a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc @@ -114,6 +114,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { } } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc 
b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc index e96de29d759..ab4f58e3c5c 100644 --- a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc @@ -138,6 +138,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc index f2195804cfd..d21039c4078 100644 --- a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc @@ -68,6 +68,12 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc index 04ebd5bfd34..523259dde73 100644 --- a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc @@ -108,6 +108,12 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { return n / window_shift_; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git 
a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index 83f21fee10e..ce21903f2f0 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -122,6 +122,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { int64 Cardinality() const override; + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override; + Status CheckExternalState() const override; protected: @@ -327,6 +329,12 @@ int64 SnapshotDatasetV2Op::Dataset::Cardinality() const { return input_->Cardinality(); } +Status SnapshotDatasetV2Op::Dataset::InputDatasets( + std::vector<const DatasetBase*>* inputs) const { + inputs->push_back(input_); + return Status::OK(); +} + Status SnapshotDatasetV2Op::Dataset::CheckExternalState() const { return input_->CheckExternalState(); } @@ -1034,6 +1042,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc index 9e936974c83..33ce9956cbc 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc @@ -363,6 +363,10 @@ class Reader::Dataset : public DatasetBase { return "snapshot_util::Reader::Dataset"; } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: @@ -483,6 +487,11 @@ class Reader::NestedDataset : public DatasetBase { return "snapshot_util::Reader::NestedDataset"; } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) 
const override { + inputs->clear(); + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc index f6720aa1c88..4b1b99a120f 100644 --- a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc @@ -103,6 +103,11 @@ class SqlDatasetOp : public DatasetOpKernel { string DebugString() const override { return "SqlDatasetOp::Dataset"; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc index 08d208fc340..1aa179acdd3 100644 --- a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc @@ -78,6 +78,12 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc index fd4b4fccb7e..fd7eedc4cf0 100644 --- a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc @@ -82,6 +82,12 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return kUnknownCardinality; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); 
+ return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index a9c682a426b..111d7b2fec2 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -173,6 +173,12 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -301,6 +307,13 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->clear(); + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -423,6 +436,13 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->clear(); + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc index e813de70931..7d5810f7818 100644 --- a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc @@ -84,6 +84,12 @@ class UnbatchDatasetOp : public 
UnaryDatasetOpKernel { return kUnknownCardinality; } + Status InputDatasets( + std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc index a4319234082..eeb1655970f 100644 --- a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc @@ -54,6 +54,11 @@ class UniqueDatasetOp::Dataset : public DatasetBase { return strings::StrCat("UniqueDatasetOp::Dataset"); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 1301aed3cb4..b93f19e58e3 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -75,6 +75,11 @@ class FilterDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index 468a22261d5..2b75483a7a5 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc @@ -93,6 +93,10 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { return 
name_utils::DatasetDebugString(kDatasetType, params); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index eba5097a1bb..ab0eb18abda 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -75,6 +75,11 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index fcdbe4ab9a5..8d841cf9f60 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -74,6 +74,10 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(init_func_->CheckExternalState()); TF_RETURN_IF_ERROR(next_func_->CheckExternalState()); diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 0a795c1cf82..cbe1caeb0b0 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -87,6 +87,11 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const 
override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index d34e4f2b041..3626c0bbf89 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -80,6 +80,11 @@ class MapDatasetOp::Dataset : public DatasetBase { } } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index fd0a1855206..805954a5179 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -119,6 +119,11 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 
0 : 1); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 90dd5337c1d..583a3cc509c 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -219,6 +219,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { ParallelInterleaveDatasetOp::kDatasetType, params); } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index b0c4a6589cc..87ea4531d5d 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -120,6 +120,11 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { } } + Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 20b78ba14ad..4a55514ffd1 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -88,6 +88,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { int64 Cardinality() const override { return input_->Cardinality(); } + Status 
InputDatasets(std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index dd6a0e9d03e..76dbff1744d 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -89,6 +89,11 @@ class RepeatDatasetOp::Dataset : public DatasetBase { return count_ * n; } + Status InputDatasets(std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -158,6 +163,9 @@ class RepeatDatasetOp::Dataset : public DatasetBase { return Status::OK(); } ++i_; + if (ctx->split_provider()) { + TF_RETURN_IF_ERROR(ctx->split_provider()->Reset()); + } TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); } diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc index d54ea63099b..43c4b79db06 100644 --- a/tensorflow/core/kernels/data/shard_dataset_op.cc +++ b/tensorflow/core/kernels/data/shard_dataset_op.cc @@ -84,6 +84,11 @@ class ShardDatasetOp::Dataset : public DatasetBase { return n / num_shards_ + (index_ < n % num_shards_ ? 
1 : 0); } + Status InputDatasets(std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 7b696371049..4df9dcefcf0 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -117,6 +117,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { } } + Status InputDatasets(std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -182,7 +187,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { data_produced_ = true; break; } - if (!data_produced_ && this->dataset()->count_ == -1) { + if (ctx->split_provider() == nullptr && !data_produced_ && + this->dataset()->count_ == -1) { // If we encounter the end of sequence without producing data, we // terminate the iteration immediately. (Otherwise, this iterator // would loop infinitely and never produce a value.) @@ -192,6 +198,9 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { epoch_++; int64 n = slices_.back()->end; slices_.push_back(absl::make_unique(n, n)); + if (ctx->split_provider()) { + TF_RETURN_IF_ERROR(ctx->split_provider()->Reset()); + } TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator( ctx, this, this->prefix(), &input_impl_)); } diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 9f1e99cd915..897c7b6b7e4 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -75,6 +75,11 @@ class SkipDatasetOp::Dataset : public DatasetBase { return count_ < 0 ? 
0 : std::max(int64{0}, n - count_); } + Status InputDatasets(std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index 1e3ed53d6c6..9efc9fddf58 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -56,6 +56,10 @@ class Dataset : public DatasetBase { int64 Cardinality() const override { return sparse_tensor_.shape()[0]; } + Status InputDatasets(std::vector* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 627467f291b..bfafcaa7aa1 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -74,6 +74,12 @@ int64 TakeDataset::Cardinality() const { return std::min(n, count_); } +Status TakeDataset::InputDatasets( + std::vector* inputs) const { + inputs->push_back(input_); + return Status::OK(); +} + Status TakeDataset::CheckExternalState() const { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/take_dataset_op.h b/tensorflow/core/kernels/data/take_dataset_op.h index 03f8ff662a7..2b85e74e7f1 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.h +++ b/tensorflow/core/kernels/data/take_dataset_op.h @@ -40,6 +40,8 @@ class TakeDataset : public DatasetBase { int64 Cardinality() const override; + Status InputDatasets(std::vector* inputs) const override; + Status CheckExternalState() const override; protected: diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc 
index 78cc06a54c5..84b8e0bd435 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -64,6 +64,10 @@ class TensorDatasetOp::Dataset : public DatasetBase { int64 Cardinality() const override { return 1LL; } + Status InputDatasets(std::vector* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index e4f27f55327..2ab713259d1 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -69,6 +69,10 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { int64 Cardinality() const override { return tensors_[0].dim_size(0); } + Status InputDatasets(std::vector* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/text_line_dataset_op.cc b/tensorflow/core/kernels/data/text_line_dataset_op.cc index 550a859093d..8851ec7995e 100644 --- a/tensorflow/core/kernels/data/text_line_dataset_op.cc +++ b/tensorflow/core/kernels/data/text_line_dataset_op.cc @@ -70,6 +70,10 @@ class TextLineDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } + Status InputDatasets(std::vector* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc index c6387a49f46..0de7f9100b1 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc @@ -85,6 +85,10 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { return 
name_utils::DatasetDebugString(kDatasetType); } + Status InputDatasets(std::vector* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index 0c156baec89..42c2fc7656c 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -67,6 +67,10 @@ class WindowDataset : public DatasetBase { string DebugString() const override { return kWindowDataset; } + Status InputDatasets(std::vector* inputs) const override { + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 35437a9231c..4e239d0895c 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -105,6 +105,11 @@ class WindowDatasetOp::Dataset : public DatasetBase { return cardinality; } + Status InputDatasets(std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index b59dc2c3a22..0ac9f17839b 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -87,6 +87,13 @@ class ZipDatasetOp::Dataset : public DatasetBase { return result; } + Status InputDatasets(std::vector* inputs) const override { + for (const auto& input : inputs_) { + inputs->push_back(input); + } + return Status::OK(); + } + Status CheckExternalState() const override { for (const auto& input : inputs_) { TF_RETURN_IF_ERROR(input->CheckExternalState());