diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 2c4e615b193..85ef6579bc0 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/resource.h" @@ -427,6 +428,33 @@ Status DatasetBase::MakeIterator( return s; } +Status DatasetBase::MakeSplitProvider( + std::unique_ptr* split_provider) const { + std::vector inputs; + Status s = InputDatasets(&inputs); + if (errors::IsUnimplemented(s)) { + return errors::Unimplemented( + "Cannot create a split provider for dataset of type ", type_string(), + ", because the dataset implements neither `InputDatasets` nor " + "`MakeSplitProvider`."); + } + // Propagate any other error rather than misreporting it as an arity issue. + TF_RETURN_IF_ERROR(s); + if (inputs.size() != 1) { + return errors::Unimplemented( + "Cannot create a split provider for dataset of type ", type_string(), + ", because the dataset is not unary (having arity ", inputs.size(), + "), and no custom implementation of `MakeSplitProvider` is defined."); + } + return inputs[0]->MakeSplitProvider(split_provider); +} + +Status DatasetBase::InputDatasets( + std::vector* inputs) const { + return errors::Unimplemented("InputDatasets not implemented for ", + type_string()); +} + Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset( SerializationContext* ctx, const DatasetBase* dataset, Node** output) { Status status = dataset->AsGraphDefInternal(ctx, this, output); diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 8c35b1909ca..3f2bb372150 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -295,6 +295,24 @@ class Runner { static 
Runner* get(); }; +// A class which provides a sequence of splits. Iterators created with a split +// provider will iterate over only the splits provided by the split provider. +class SplitProvider { + public: + virtual ~SplitProvider() {} + // Stores the next split in `*split`, setting `*end_of_splits` to true if + // there were no splits left. + virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0; + // Resets the split provider to its beginning. + virtual Status Reset() = 0; + // Saves the state of this split provider. + virtual Status Save(std::function full_name, + IteratorStateWriter* writer) = 0; + // Restores the state of this split provider. + virtual Status Restore(std::function full_name, + IteratorStateReader* reader) = 0; +}; + // A cut-down version of `OpKernelContext` for running computations in // iterators. Note that we cannot simply use `OpKernelContext` here because we // might run computation in an iterator whose lifetime is not nested within the @@ -319,6 +337,7 @@ class IteratorContext { model(ctx->model()), runner(*(ctx->runner())), runner_threadpool_size(ctx->runner_threadpool_size()), + split_provider(ctx->split_provider()), stats_aggregator(ctx->stats_aggregator()), thread_factory(ctx->thread_factory()), thread_pool(ctx->thread_pool()) {} @@ -386,6 +405,9 @@ class IteratorContext { // Number of threads used for executing user-defined functions. int32 runner_threadpool_size = 0; + // An optional split provider indicating which splits to process. + std::shared_ptr split_provider = nullptr; + // The `StatsAggregator` object to record statistics about the iterator. 
std::shared_ptr stats_aggregator = nullptr; @@ -432,6 +454,10 @@ class IteratorContext { int32 runner_threadpool_size() { return params_.runner_threadpool_size; } + std::shared_ptr split_provider() { + return params_.split_provider; + } + std::shared_ptr stats_aggregator() { return params_.stats_aggregator; } @@ -802,6 +828,12 @@ class DatasetBase : public core::RefCounted { return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator); } + // Returns a split provider which partitions the dataset's data into splits + // and provides them in a sequence. The split provider is stored in + // `*split_provider`. + virtual Status MakeSplitProvider( + std::unique_ptr* split_provider) const; + // Returns a vector of DataType values, representing the respective // element types of each tuple component in the outputs of this // dataset. @@ -824,6 +856,14 @@ class DatasetBase : public core::RefCounted { // A human-readable debug string for this dataset. virtual string DebugString() const = 0; + // Stores the dataset's input datasets in `*inputs`. The pointers stored in + // `*inputs` are borrowed. The only valid non-ok return status is + // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset + // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a + // default implementation of `MakeSplitProvider` when there is a single input + // dataset. + virtual Status InputDatasets(std::vector* inputs) const; + // Indicates whether the dataset depends on any external state which would // prevent it from being serializable. 
If so, the method returns // `errors::FailedPrecondition` with a message that identifies the external diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 1365f8a1d31..f9ed541a937 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -26,6 +26,7 @@ cc_library( ":map_dataset_op", ":name_utils", ":range_dataset_op", + ":split_providers", ":take_dataset_op", ":tensor_slice_dataset_op", "//tensorflow/core:core_cpu", @@ -737,6 +738,31 @@ tf_cc_test( ], ) +cc_library( + name = "split_providers", + srcs = ["split_providers.cc"], + hdrs = ["split_providers.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core/platform:errors", + ], +) + +tf_cc_test( + name = "split_providers_test", + size = "small", + srcs = ["split_providers_test.cc"], + deps = [ + ":dataset_test_base", + ":dataset_utils", + ":split_providers", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/framework:tensor_testutil", + ], +) + tf_kernel_library( name = "take_dataset_op", srcs = ["take_dataset_op.cc"], @@ -813,6 +839,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 14af07fe494..5f7edb9f18b 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -62,6 +62,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/map_dataset_op.h" #include "tensorflow/core/kernels/data/name_utils.h" #include "tensorflow/core/kernels/data/range_dataset_op.h" +#include "tensorflow/core/kernels/data/split_providers.h" #include "tensorflow/core/kernels/data/take_dataset_op.h" #include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -606,6 +607,47 @@ Status DatasetOpsTestBase::CheckIteratorGetNext( return Status::OK(); } +Status DatasetOpsTestBase::CheckSplitProviderFullIteration( + const DatasetParams& params, const std::vector& expected_outputs) { + std::unique_ptr dataset; + TF_RETURN_IF_ERROR(MakeDataset(params, &dataset)); + std::unique_ptr split_provider; + TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProvider(&split_provider)); + std::unique_ptr iterator; + TF_RETURN_IF_ERROR( + MakeIterator(params, *dataset, std::move(split_provider), &iterator)); + TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs, + /*compare_order=*/true)); + return Status::OK(); +} + +Status DatasetOpsTestBase::CheckSplitProviderShardedIteration( + const DatasetParams& params, int64 num_shards, int64 shard_index, + const std::vector& expected_outputs) { + std::unique_ptr dataset; + TF_RETURN_IF_ERROR(MakeDataset(params, &dataset)); + std::unique_ptr split_provider; + TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProvider(&split_provider)); + split_provider = absl::make_unique( + num_shards, shard_index, std::move(split_provider)); + std::unique_ptr iterator_ctx; + TF_RETURN_IF_ERROR( + CreateIteratorContext(dataset->op_kernel_context(), &iterator_ctx)); + IteratorContext::Params iterator_params(iterator_ctx.get()); + iterator_params.split_provider = std::move(split_provider); + iterator_ctx = absl::make_unique(iterator_params); + int mid_breakpoint = expected_outputs.size() / 2; + int near_end_breakpoint = expected_outputs.size() - 1; + int end_breakpoint = expected_outputs.size(); + 
TF_RETURN_IF_ERROR(CheckIteratorSaveAndRestore( + dataset->dataset(), iterator_ctx.get(), params.iterator_prefix(), + expected_outputs, + /*breakpoints=*/ + {0, mid_breakpoint, near_end_breakpoint, end_breakpoint}, + /*compare_order=*/true)); + return Status::OK(); +} + Status DatasetOpsTestBase::CheckDatasetNodeName( const string& expected_dataset_node_name) { EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name); @@ -658,11 +700,13 @@ Status DatasetOpsTestBase::CheckIteratorPrefix( } Status DatasetOpsTestBase::CheckIteratorSaveAndRestore( - const string& iterator_prefix, const std::vector& expected_outputs, + DatasetBase* dataset, IteratorContext* iterator_ctx, + const std::string& iterator_prefix, + const std::vector& expected_outputs, const std::vector& breakpoints, bool compare_order) { std::unique_ptr iterator; - TF_RETURN_IF_ERROR(dataset_->MakeIterator( - iterator_ctx_.get(), /*parent=*/nullptr, iterator_prefix, &iterator)); + TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr, + iterator_prefix, &iterator)); std::unique_ptr serialization_ctx; TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; @@ -674,33 +718,31 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore( std::vector data; writer.GetData(&data); VariantTensorDataReader reader(data); - TF_EXPECT_OK(RestoreIterator(iterator_ctx_.get(), &reader, iterator_prefix, - *dataset_, &iterator)); + TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix, + *dataset, &iterator)); while (cur_iteration <= breakpoint) { std::vector next; TF_RETURN_IF_ERROR( - iterator->GetNext(iterator_ctx_.get(), &next, &end_of_sequence)); + iterator->GetNext(iterator_ctx, &next, &end_of_sequence)); out_tensors.insert(out_tensors.end(), next.begin(), next.end()); cur_iteration++; } - - if (dataset_->Cardinality() == kUnknownCardinality) { - continue; - } - - if (dataset_->Cardinality() == kInfiniteCardinality || - breakpoint < 
dataset_->Cardinality()) { - EXPECT_FALSE(end_of_sequence); - } else { - EXPECT_TRUE(end_of_sequence); - } } TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs, /*compare_order=*/compare_order)); return Status::OK(); } +Status DatasetOpsTestBase::CheckIteratorSaveAndRestore( + const std::string& iterator_prefix, + const std::vector& expected_outputs, + const std::vector& breakpoints, bool compare_order) { + return CheckIteratorSaveAndRestore(dataset_, iterator_ctx_.get(), + iterator_prefix, expected_outputs, + breakpoints, compare_order); +} + Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) { if (initialized_) { return errors::Internal( @@ -793,10 +835,14 @@ Status DatasetOpsTestBase::MakeDataset( Status DatasetOpsTestBase::MakeIterator( const DatasetParams& dataset_params, const TestDataset& dataset, + std::unique_ptr split_provider, std::unique_ptr* iterator) { std::unique_ptr iterator_ctx; TF_RETURN_IF_ERROR( CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx)); + IteratorContext::Params iterator_params(iterator_ctx.get()); + iterator_params.split_provider = std::move(split_provider); + iterator_ctx = absl::make_unique(iterator_params); std::unique_ptr iterator_base; TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator( iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(), @@ -806,6 +852,13 @@ Status DatasetOpsTestBase::MakeIterator( return Status::OK(); } +Status DatasetOpsTestBase::MakeIterator( + const DatasetParams& dataset_params, const TestDataset& dataset, + std::unique_ptr* iterator) { + return MakeIterator(dataset_params, dataset, /*split_provider=*/nullptr, + iterator); +} + Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params, std::vector* outputs) { TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, ¶ms_, diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h index 0d07a93d4f2..c8680fa98b5 
100644 --- a/tensorflow/core/kernels/data/dataset_test_base.h +++ b/tensorflow/core/kernels/data/dataset_test_base.h @@ -517,6 +517,12 @@ class DatasetOpsTestBase : public ::testing::Test { Status MakeDataset(const DatasetParams& dataset_params, std::unique_ptr* dataset); + // Creates an iterator for the given dataset, using the specified split + // provider. + Status MakeIterator(const DatasetParams& dataset_params, + const TestDataset& dataset, + std::unique_ptr split_provider, + std::unique_ptr* iterator); // Creates an iterator for the given dataset. Status MakeIterator(const DatasetParams& dataset_params, const TestDataset& dataset, @@ -556,6 +562,18 @@ class DatasetOpsTestBase : public ::testing::Test { const std::vector& expected_outputs, bool compare_order); + // Checks that iterating through the dataset using a split provider produces + // the expected outputs. + Status CheckSplitProviderFullIteration( + const DatasetParams& params, const std::vector& expected_outputs); + + // Checks that iterating through the dataset using a sharded split provider + // with the given `num_shards` and `shard_index` produces the expected + // outputs. + Status CheckSplitProviderShardedIteration( + const DatasetParams& params, int64 num_shards, int64 shard_index, + const std::vector& expected_outputs); + // Checks `DatasetBase::node_name()`. Status CheckDatasetNodeName(const string& expected_dataset_node_name); @@ -583,9 +601,14 @@ class DatasetOpsTestBase : public ::testing::Test { // Checks `IteratorBase::prefix()`. Status CheckIteratorPrefix(const string& expected_iterator_prefix); - // Checks `IteratorBase::GetNext()`. 
Status CheckIteratorSaveAndRestore( - const string& iterator_prefix, + DatasetBase* dataset, IteratorContext* iterator_ctx, + const std::string& iterator_prefix, + const std::vector& expected_outputs, + const std::vector& breakpoints, bool compare_order); + + Status CheckIteratorSaveAndRestore( + const std::string& iterator_prefix, const std::vector& expected_outputs, const std::vector& breakpoints, bool compare_order); @@ -660,6 +683,7 @@ class DatasetOpsTestBase : public ::testing::Test { OpKernelContext* const op_context, std::unique_ptr* iterator_context); + // Creates a new iterator context for iterating the dataset. // Creates a new serialization context for serializing the dataset and // iterator. Status CreateSerializationContext( diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 37b7594564c..d325a3dcf66 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -99,6 +99,12 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { int64 Cardinality() const override { return input_->Cardinality(); } + Status InputDatasets( + std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + Status CheckExternalState() const override { return input_->CheckExternalState(); } diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index f6993ab2797..e0a80f1f0ee 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/kernels/data/range_dataset_op.h" +#include "absl/memory/memory.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" @@ -32,7 +33,97 @@ namespace data { /* static */ constexpr const char* const RangeDatasetOp::kOutputTypes; /* static */ constexpr const char* const RangeDatasetOp::kOutputShapes; +namespace { constexpr char kNext[] = "next"; +constexpr char kHasSplitProvider[] = "has_split_provider"; +constexpr char kSlash[] = "/"; +constexpr char kSplitProvider[] = "split_provider"; + +// Class which produces the elements of `range(start, stop, step)`. Threadsafe. +class RangeCounter { + public: + RangeCounter(int64 start, int64 stop, int64 step) + : start_(start), stop_(stop), step_(step), next_(start) {} + + // Returns the next value for the counter. Sets `*end_of_counter` to indicate + // whether the end of the counter was reached, in which case -1 is returned. + int64 GetNext(bool* end_of_counter) { + mutex_lock l(mu_); + if ((step_ > 0 && next_ >= stop_) || (step_ < 0 && next_ <= stop_)) { + *end_of_counter = true; + return -1; + } + *end_of_counter = false; + int64 result = next_; // 64-bit to avoid truncating large values. + next_ += step_; + return result; + } + + int64 Peek() const { + mutex_lock l(mu_); + return next_; + } + + void Reset() { + mutex_lock l(mu_); + next_ = start_; + } + + void SetNext(int64 value) { + mutex_lock l(mu_); + next_ = value; + } + + private: + const int64 start_; + const int64 stop_; + const int64 step_; + mutable mutex mu_; + int64 next_ TF_GUARDED_BY(mu_); +}; +} // namespace + +// Split provider where splits are individual outputs from RangeDataset. +// For example, the "splits" of range(0, 10, 2) will be {0, 2, 4, 6, 8}. +// The split tensors are scalars of type DT_INT64. 
+class RangeDatasetOp::RangeSplitProvider : public SplitProvider { + public: + RangeSplitProvider(int64 start, int64 stop, int64 step) + : counter_(start, stop, step) {} + + Status GetNext(Tensor* split, bool* end_of_splits) override { + int64 next = counter_.GetNext(end_of_splits); + if (*end_of_splits) { + return Status::OK(); + } + *split = Tensor(DT_INT64, TensorShape{}); + split->scalar()() = next; + return Status::OK(); + } + + Status Reset() override { + counter_.Reset(); + return Status::OK(); + } + + Status Save(std::function key_name_fn, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR( + writer->WriteScalar(key_name_fn(kNext), counter_.Peek())); + return Status::OK(); + } + + Status Restore(std::function key_name_fn, + IteratorStateReader* reader) override { + int64 next; + TF_RETURN_IF_ERROR(reader->ReadScalar(key_name_fn(kNext), &next)); + counter_.SetNext(next); + return Status::OK(); + } + + private: + RangeCounter counter_; +}; class RangeDatasetOp::Dataset : public DatasetBase { public: @@ -74,6 +165,18 @@ class RangeDatasetOp::Dataset : public DatasetBase { } } + Status MakeSplitProvider( + std::unique_ptr* split_provider) const override { + *split_provider = + absl::make_unique(start_, stop_, step_); + return Status::OK(); + } + + Status InputDatasets(std::vector* inputs) const override { + inputs->clear(); + return Status::OK(); + } + Status CheckExternalState() const override { return Status::OK(); } protected: @@ -93,24 +196,40 @@ class RangeDatasetOp::Dataset : public DatasetBase { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params) : DatasetIterator(params) { - next_ = params.dataset->start_; + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + split_provider_ = ctx->split_provider(); + if (!split_provider_) { + counter_ = absl::make_unique( + dataset()->start_, dataset()->stop_, dataset()->step_); + } + 
return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { - mutex_lock l(mu_); - if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) || - (dataset()->step_ < 0 && next_ <= dataset()->stop_)) { - *end_of_sequence = true; - return Status::OK(); + int64 value; + if (split_provider_ != nullptr) { + Tensor split; + TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence)); + if (*end_of_sequence) { + return Status::OK(); + } + value = split.scalar()(); + } else { + value = counter_->GetNext(end_of_sequence); + if (*end_of_sequence) { + return Status::OK(); + } } out_tensors->reserve(1); switch (dataset()->output_dtypes()[0]) { #define HANDLE_TYPE(type) \ case DataTypeToEnum::value: { \ - out_tensors->emplace_back(static_cast(next_)); \ + out_tensors->emplace_back(static_cast(value)); \ break; \ } TF_CALL_NUMBER_TYPES(HANDLE_TYPE); @@ -120,9 +239,6 @@ class RangeDatasetOp::Dataset : public DatasetBase { "Unsupported data type: ", DataTypeString(dataset()->output_dtypes()[0])); } - *end_of_sequence = false; - next_ += dataset()->step_; - return Status::OK(); } @@ -134,21 +250,44 @@ class RangeDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNext), next_)); + if (split_provider_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kHasSplitProvider), true)); + TF_RETURN_IF_ERROR(split_provider_->Save( + [this](const std::string& key) { + return SplitProviderKeyNameFn(key); + }, + writer)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kNext), counter_->Peek())); + } return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNext), &next_)); + if (reader->Contains(full_name(kHasSplitProvider))) { + 
TF_RETURN_IF_ERROR(split_provider_->Restore( + [this](const std::string& key) { + return SplitProviderKeyNameFn(key); + }, + reader)); + } else { + int64 next; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNext), &next)); + counter_->SetNext(next); + } return Status::OK(); } + std::string SplitProviderKeyNameFn(const std::string& key) { + return full_name(absl::StrCat(kSplitProvider, kSlash, key)); + } + private: - mutex mu_; - int64 next_ TF_GUARDED_BY(mu_); + std::unique_ptr counter_; + std::shared_ptr split_provider_; }; const int64 start_; diff --git a/tensorflow/core/kernels/data/range_dataset_op.h b/tensorflow/core/kernels/data/range_dataset_op.h index 077987b72e1..8e9891c5671 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.h +++ b/tensorflow/core/kernels/data/range_dataset_op.h @@ -36,6 +36,7 @@ class RangeDatasetOp : public DatasetOpKernel { private: class Dataset; + class RangeSplitProvider; DataTypeVector output_types_; }; diff --git a/tensorflow/core/kernels/data/range_dataset_op_test.cc b/tensorflow/core/kernels/data/range_dataset_op_test.cc index 13a027ecdc6..f8f12c36343 100644 --- a/tensorflow/core/kernels/data/range_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/range_dataset_op_test.cc @@ -141,6 +141,39 @@ TEST_F(RangeDatasetOpTest, ZeroStep) { tensorflow::error::INVALID_ARGUMENT); } +TEST_F(RangeDatasetOpTest, SplitProviderPositiveStep) { + auto params = RangeDatasetParams(/*start=*/0, /*stop=*/10, /*step=*/3, + /*output_dtypes=*/{DT_INT64}); + TF_ASSERT_OK(InitializeRuntime(params)); + TF_EXPECT_OK(CheckSplitProviderFullIteration( + params, CreateTensors(TensorShape({}), {{0}, {3}, {6}, {9}}))); + TF_EXPECT_OK(CheckSplitProviderShardedIteration( + params, /*num_shards=*/2, /*shard_index=*/1, + CreateTensors(TensorShape({}), {{3}, {9}}))); +} + +TEST_F(RangeDatasetOpTest, SplitProviderNegativeStep) { + auto params = RangeDatasetParams(/*start=*/10, /*stop=*/0, /*step=*/-3, + /*output_dtypes=*/{DT_INT64}); + 
TF_ASSERT_OK(InitializeRuntime(params)); + TF_EXPECT_OK(CheckSplitProviderFullIteration( + params, CreateTensors(TensorShape({}), {{10}, {7}, {4}, {1}}))); + TF_EXPECT_OK(CheckSplitProviderShardedIteration( + params, /*num_shards=*/2, /*shard_index=*/0, + CreateTensors(TensorShape({}), {{10}, {4}}))); +} + +TEST_F(RangeDatasetOpTest, SplitProviderEmpty) { + auto params = RangeDatasetParams(/*start=*/0, /*stop=*/0, /*step=*/1, + /*output_dtypes=*/{DT_INT64}); + TF_ASSERT_OK(InitializeRuntime(params)); + TF_EXPECT_OK(CheckSplitProviderFullIteration( + params, CreateTensors(TensorShape({}), {}))); + TF_EXPECT_OK(CheckSplitProviderShardedIteration( + params, /*num_shards=*/3, /*shard_index=*/2, + CreateTensors(TensorShape({}), {}))); +} + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/split_providers.cc b/tensorflow/core/kernels/data/split_providers.cc new file mode 100644 index 00000000000..d39bf4865f9 --- /dev/null +++ b/tensorflow/core/kernels/data/split_providers.cc @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/kernels/data/split_providers.h" + +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace data { + +namespace { +constexpr char kNumToSkip[] = "num_to_skip"; +constexpr char kSplitProvider[] = "split_provider"; +constexpr char kSlash[] = "/"; +} // namespace + +ShardingSplitProvider::ShardingSplitProvider( + int64 num_shards, int64 shard_index, + std::shared_ptr split_provider) + : num_shards_(num_shards), + shard_index_(shard_index), + split_provider_(split_provider), + num_to_skip_(shard_index_) {} + +Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { + mutex_lock l(mu_); + while (num_to_skip_ > 0) { + TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits)); + if (*end_of_splits) { + return Status::OK(); + } + num_to_skip_--; + } + num_to_skip_ = num_shards_ - 1; + TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits)); + return Status::OK(); +} + +Status ShardingSplitProvider::Reset() { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(split_provider_->Reset()); + num_to_skip_ = shard_index_; + return Status::OK(); +} + +Status ShardingSplitProvider::Save( + std::function full_name, + IteratorStateWriter* writer) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(split_provider_->Save( + [&](const std::string& key) { + return full_name(absl::StrCat(kSplitProvider, kSlash, key)); + }, + writer)); + return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_); +} + +Status ShardingSplitProvider::Restore( + std::function full_name, + IteratorStateReader* reader) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(split_provider_->Restore( + [&](const std::string& key) { + return full_name(absl::StrCat(kSplitProvider, kSlash, key)); + }, + reader)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_)); + return Status::OK(); +} + +} // namespace data +} // namespace tensorflow 
diff --git a/tensorflow/core/kernels/data/split_providers.h b/tensorflow/core/kernels/data/split_providers.h new file mode 100644 index 00000000000..e9c02dfd9bf --- /dev/null +++ b/tensorflow/core/kernels/data/split_providers.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +// A SplitProvider which wraps another split provider, but drops all splits +// where `index % num_shards != shard_index`. +class ShardingSplitProvider : public SplitProvider { + public: + ShardingSplitProvider(int64 num_shards, int64 shard_index, + std::shared_ptr split_provider); + + Status GetNext(Tensor* split, bool* end_of_splits) override; + Status Reset() override; + Status Save(std::function full_name, + IteratorStateWriter* writer) override; + Status Restore(std::function full_name, + IteratorStateReader* reader) override; + + private: + const int64 num_shards_; + const int64 shard_index_; + mutex mu_; + std::shared_ptr split_provider_ TF_GUARDED_BY(mu_); + int64 num_to_skip_ TF_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_ diff --git 
a/tensorflow/core/kernels/data/split_providers_test.cc b/tensorflow/core/kernels/data/split_providers_test.cc new file mode 100644 index 00000000000..06d62754ed7 --- /dev/null +++ b/tensorflow/core/kernels/data/split_providers_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/split_providers.h" + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace { +constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b"; +constexpr char kPipe[] = "|"; +constexpr char kColon[] = ":"; +constexpr char kSplits[] = "splits"; +constexpr char kSplitsSize[] = "splits_size"; + +std::string full_name(const std::string& name) { + return strings::StrCat(kFullNameRandomHex, kPipe, "test", kColon, name); +} + +Status SaveAndRestore(SplitProvider* split_provider) { + VariantTensorDataWriter writer; + TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer)); + std::vector variants; + writer.GetData(&variants); + VariantTensorDataReader reader(variants); + TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader)); + return 
Status::OK(); +} + +// A split provider that provides pre-defined splits. +class TestSplitProvider : public SplitProvider { + public: + explicit TestSplitProvider(std::vector splits) : splits_(splits) {} + + Status GetNext(Tensor* split, bool* end_of_splits) override { + *end_of_splits = i_ >= splits_.size(); + if (*end_of_splits) { + return Status::OK(); + } + *split = splits_[i_++]; + return Status::OK(); + } + + Status Reset() override { + i_ = 0; + return Status::OK(); + } + + Status Save(std::function full_name, + IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kSplitsSize), splits_.size())); + for (int i = 0; i < splits_.size(); ++i) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(absl::StrCat(kSplits, "[", i, "]")), splits_[i])); + } + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i_"), i_)); + return Status::OK(); + } + + Status Restore(std::function full_name, + IteratorStateReader* reader) override { + int64 splits_size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name(kSplitsSize), &splits_size)); + splits_.clear(); + for (int i = 0; i < splits_size; ++i) { + splits_.emplace_back(); + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(absl::StrCat(kSplits, "[", i, "]")), &splits_.back())); + } + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i_"), &i_)); + return Status::OK(); + } + + private: + std::vector splits_; + int64 i_ = 0; +}; + +Status CheckOutput(ShardingSplitProvider& split_provider, + std::vector expected) { + int64 next = 0; + bool end_of_splits = false; + while (!end_of_splits) { + Tensor split; + TF_RETURN_IF_ERROR(split_provider.GetNext(&split, &end_of_splits)); + if (!end_of_splits) { + test::ExpectEqual(split, expected[next++]); + } + } + EXPECT_EQ(next, expected.size()); + return Status::OK(); +} + +TEST(ShardingSplitProvider, TwoWayShardZero) { + auto base = std::make_shared( + CreateTensors(TensorShape({}), {{0}, {1}, {2}, {3}})); + ShardingSplitProvider 
split_provider(2, 0, base); + TF_EXPECT_OK(CheckOutput(split_provider, + CreateTensors(TensorShape({}), {{0}, {2}}))); +} + +TEST(ShardingSplitProvider, TwoWayShardOne) { + auto base = std::make_shared( + CreateTensors(TensorShape({}), {{0}, {1}, {2}, {3}})); + ShardingSplitProvider split_provider(2, 1, base); + TF_EXPECT_OK(CheckOutput(split_provider, + CreateTensors(TensorShape({}), {{1}, {3}}))); +} + +TEST(ShardingSplitProvider, ThreeWayShardOne) { + auto base = std::make_shared( + CreateTensors(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})); + ShardingSplitProvider split_provider(3, 1, base); + TF_EXPECT_OK(CheckOutput(split_provider, + CreateTensors(TensorShape({}), {{1}, {4}}))); +} + +TEST(ShardingSplitProvider, Empty) { + auto base = std::make_shared( + CreateTensors(TensorShape({}), {{0}})); + ShardingSplitProvider split_provider(2, 1, base); + TF_EXPECT_OK( + CheckOutput(split_provider, CreateTensors(TensorShape({}), {}))); +} + +TEST(ShardingSplitProvider, SaveAndRestore) { + auto base = std::make_shared( + CreateTensors(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})); + std::vector expected = + CreateTensors(TensorShape({}), {{1}, {4}}); + ShardingSplitProvider split_provider(3, 1, base); + for (int i = 0; i < expected.size(); ++i) { + TF_ASSERT_OK(SaveAndRestore(&split_provider)); + Tensor split; + bool end_of_splits = true; + TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits)); + EXPECT_FALSE(end_of_splits); + test::ExpectEqual(split, expected[i]); + } + TF_ASSERT_OK(SaveAndRestore(&split_provider)); + Tensor split; + bool end_of_splits = false; + TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits)); + EXPECT_TRUE(end_of_splits); +} + +} // namespace +} // namespace data +} // namespace tensorflow