diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 85ef6579bc0..c851af9a5c4 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -383,6 +383,14 @@ int64 GetTotalBytes(const std::vector<Tensor>& element) {
   return total_bytes;
 }
 
+std::string FullName(const std::string& prefix, const std::string& name) {
+  if (str_util::StrContains(name, kColon)) {
+    LOG(ERROR) << name << " should not contain " << kColon;
+  }
+
+  return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
+}
+
 Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                    DatasetBase** out_dataset) {
   if (!(tensor.dtype() == DT_VARIANT &&
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 3f2bb372150..110fb53702d 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -121,6 +121,10 @@ class IteratorStateWriter {
   virtual ~IteratorStateWriter() {}
 };
 
+// Generates a full name key for iterator checkpointing. All keys generated for
+// iterator checkpoints should go through this function.
+std::string FullName(const std::string& prefix, const std::string& name);
+
 // Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
 class GraphDefBuilderWrapper {
  public:
@@ -972,12 +976,7 @@ class DatasetBaseIterator : public IteratorBase {
                   bool* end_of_sequence, int* num_skipped);
 
   string full_name(const string& name) const {
-    if (str_util::StrContains(name, kColon)) {
-      LOG(ERROR) << name << " should not contain " << kColon;
-    }
-
-    return strings::StrCat(kFullNameRandomHex, kPipe, params_.prefix, kColon,
-                           name);
+    return FullName(params_.prefix, name);
   }
 
   // Returns a map of key-value pairs to included in the TraceMe string.
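A minimal sketch (not part of the patch) of the checkpoint key that the shared helper above produces, assuming the `FullName` declaration from dataset.h; the prefix and name used here are made-up examples:

// Sketch only: "Iterator::Range" and "i" are hypothetical checkpoint names.
#include <string>

#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {

std::string ExampleCheckpointKey() {
  // Produces a key of the form "<random-hex>|<prefix>:<name>", e.g.
  // "60d899aa0d8ce4351e7c3b419e92d25b|Iterator::Range:i" (the hex constant is
  // the one exercised by DatasetTest.FullName below). A name that itself
  // contains ':' triggers the LOG(ERROR) branch, since ':' separates the
  // prefix from the name.
  return FullName("Iterator::Range", "i");
}

}  // namespace data
}  // namespace tensorflow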
diff --git a/tensorflow/core/framework/dataset_test.cc b/tensorflow/core/framework/dataset_test.cc
index 9dbb3be7faf..e471b441ce2 100644
--- a/tensorflow/core/framework/dataset_test.cc
+++ b/tensorflow/core/framework/dataset_test.cc
@@ -27,6 +27,11 @@ TEST(DatasetTest, RegisterDatasetOp) {
   EXPECT_FALSE(data::DatasetOpRegistry::IsRegistered("InvalidDatasetOp"));
 }
 
+TEST(DatasetTest, FullName) {
+  EXPECT_EQ(data::FullName("prefix", "name"),
+            "60d899aa0d8ce4351e7c3b419e92d25b|prefix:name");
+}
+
 enum DataTypeTest {
   _tf_int_32,
   _tf_int_64,
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index f9ed541a937..c8ebea599c8 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -26,7 +26,7 @@ cc_library(
         ":map_dataset_op",
         ":name_utils",
         ":range_dataset_op",
-        ":split_providers",
+        ":split_utils",
         ":take_dataset_op",
         ":tensor_slice_dataset_op",
        "//tensorflow/core:core_cpu",
@@ -738,31 +738,6 @@ tf_cc_test(
     ],
 )
 
-cc_library(
-    name = "split_providers",
-    srcs = ["split_providers.cc"],
-    hdrs = ["split_providers.h"],
-    deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core/platform:errors",
-    ],
-)
-
-tf_cc_test(
-    name = "split_providers_test",
-    size = "small",
-    srcs = ["split_providers_test.cc"],
-    deps = [
-        ":dataset_test_base",
-        ":dataset_utils",
-        ":split_providers",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core/framework:tensor_testutil",
-    ],
-)
-
 tf_kernel_library(
     name = "take_dataset_op",
     srcs = ["take_dataset_op.cc"],
@@ -923,6 +898,32 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "split_utils",
+    srcs = ["split_utils.cc"],
+    hdrs = ["split_utils.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/kernels/data:iterator_ops",
+    ],
+)
+
+tf_cc_test(
+    name = "split_utils_test",
+    size = "small",
+    srcs = ["split_utils_test.cc"],
+    deps = [
+        ":dataset_test_base",
+        ":dataset_utils",
+        ":split_utils",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/framework:tensor_testutil",
+    ],
+)
+
 tf_kernel_library(
     name = "tensor_dataset_op",
     srcs = ["tensor_dataset_op.cc"],
@@ -955,6 +956,7 @@ tf_kernel_library(
     deps = [
         ":dataset_utils",
         ":name_utils",
+        ":split_utils",
        "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc
index 5f7edb9f18b..83c7673fc0c 100644
--- a/tensorflow/core/kernels/data/dataset_test_base.cc
+++ b/tensorflow/core/kernels/data/dataset_test_base.cc
@@ -62,7 +62,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/map_dataset_op.h" #include "tensorflow/core/kernels/data/name_utils.h" #include "tensorflow/core/kernels/data/range_dataset_op.h" -#include "tensorflow/core/kernels/data/split_providers.h" +#include "tensorflow/core/kernels/data/split_utils.h" #include "tensorflow/core/kernels/data/take_dataset_op.h" #include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc index 95d6a7e93cc..cc2b8dfa2cd 100644 --- a/tensorflow/core/kernels/data/dataset_utils_test.cc +++ b/tensorflow/core/kernels/data/dataset_utils_test.cc @@ -62,9 +62,7 @@ class DatasetHashUtilsTest : public ::testing::Test { } }; -string full_name(string key) { - return strings::StrCat(kFullNameRandomHex, kPipe, "Iterator:", key); -} +string full_name(string key) { return FullName("Iterator:", key); } TEST(DatasetUtilsTest, MatchesAnyVersion) { EXPECT_TRUE(MatchesAnyVersion("BatchDataset", "BatchDataset")); diff --git a/tensorflow/core/kernels/data/split_providers_test.cc b/tensorflow/core/kernels/data/split_providers_test.cc deleted file mode 100644 index 06d62754ed7..00000000000 --- a/tensorflow/core/kernels/data/split_providers_test.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/kernels/data/split_providers.h" - -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/kernels/data/dataset_test_base.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace data { -namespace { -constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b"; -constexpr char kPipe[] = "|"; -constexpr char kColon[] = ":"; -constexpr char kSplits[] = "splits"; -constexpr char kSplitsSize[] = "splits_size"; - -std::string full_name(const std::string& name) { - return strings::StrCat(kFullNameRandomHex, kPipe, "test", kColon, name); -} - -Status SaveAndRestore(SplitProvider* split_provider) { - VariantTensorDataWriter writer; - TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer)); - std::vector variants; - writer.GetData(&variants); - VariantTensorDataReader reader(variants); - TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader)); - return Status::OK(); -} - -// A split provider that provides pre-defined splits. 
-class TestSplitProvider : public SplitProvider {
- public:
-  explicit TestSplitProvider(std::vector<Tensor> splits) : splits_(splits) {}
-
-  Status GetNext(Tensor* split, bool* end_of_splits) override {
-    *end_of_splits = i_ >= splits_.size();
-    if (*end_of_splits) {
-      return Status::OK();
-    }
-    *split = splits_[i_++];
-    return Status::OK();
-  }
-
-  Status Reset() override {
-    i_ = 0;
-    return Status::OK();
-  }
-
-  Status Save(std::function<std::string(std::string)> full_name,
-              IteratorStateWriter* writer) override {
-    TF_RETURN_IF_ERROR(
-        writer->WriteScalar(full_name(kSplitsSize), splits_.size()));
-    for (int i = 0; i < splits_.size(); ++i) {
-      TF_RETURN_IF_ERROR(writer->WriteTensor(
-          full_name(absl::StrCat(kSplits, "[", i, "]")), splits_[i]));
-    }
-    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i_"), i_));
-    return Status::OK();
-  }
-
-  Status Restore(std::function<std::string(std::string)> full_name,
-                 IteratorStateReader* reader) override {
-    int64 splits_size;
-    TF_RETURN_IF_ERROR(
-        reader->ReadScalar(full_name(kSplitsSize), &splits_size));
-    splits_.clear();
-    for (int i = 0; i < splits_size; ++i) {
-      splits_.emplace_back();
-      TF_RETURN_IF_ERROR(reader->ReadTensor(
-          full_name(absl::StrCat(kSplits, "[", i, "]")), &splits_.back()));
-    }
-    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i_"), &i_));
-    return Status::OK();
-  }
-
- private:
-  std::vector<Tensor> splits_;
-  int64 i_ = 0;
-};
-
-Status CheckOutput(ShardingSplitProvider& split_provider,
-                   std::vector<Tensor> expected) {
-  int64 next = 0;
-  bool end_of_splits = false;
-  while (!end_of_splits) {
-    Tensor split;
-    TF_RETURN_IF_ERROR(split_provider.GetNext(&split, &end_of_splits));
-    if (!end_of_splits) {
-      test::ExpectEqual(split, expected[next++]);
-    }
-  }
-  EXPECT_EQ(next, expected.size());
-  return Status::OK();
-}
-
-TEST(ShardingSplitProvider, TwoWayShardZero) {
-  auto base = std::make_shared<TestSplitProvider>(
-      CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
-  ShardingSplitProvider split_provider(2, 0, base);
-  TF_EXPECT_OK(CheckOutput(split_provider,
-                           CreateTensors<int64>(TensorShape({}), {{0}, {2}})));
-}
-
-TEST(ShardingSplitProvider, TwoWayShardOne) {
-  auto base = std::make_shared<TestSplitProvider>(
-      CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
-  ShardingSplitProvider split_provider(2, 1, base);
-  TF_EXPECT_OK(CheckOutput(split_provider,
-                           CreateTensors<int64>(TensorShape({}), {{1}, {3}})));
-}
-
-TEST(ShardingSplitProvider, ThreeWayShardOne) {
-  auto base = std::make_shared<TestSplitProvider>(
-      CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
-  ShardingSplitProvider split_provider(3, 1, base);
-  TF_EXPECT_OK(CheckOutput(split_provider,
-                           CreateTensors<int64>(TensorShape({}), {{1}, {4}})));
-}
-
-TEST(ShardingSplitProvider, Empty) {
-  auto base = std::make_shared<TestSplitProvider>(
-      CreateTensors<int64>(TensorShape({}), {{0}}));
-  ShardingSplitProvider split_provider(2, 1, base);
-  TF_EXPECT_OK(
-      CheckOutput(split_provider, CreateTensors<int64>(TensorShape({}), {})));
-}
-
-TEST(ShardingSplitProvider, SaveAndRestore) {
-  auto base = std::make_shared<TestSplitProvider>(
-      CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
-  std::vector<Tensor> expected =
-      CreateTensors<int64>(TensorShape({}), {{1}, {4}});
-  ShardingSplitProvider split_provider(3, 1, base);
-  for (int i = 0; i < expected.size(); ++i) {
-    TF_ASSERT_OK(SaveAndRestore(&split_provider));
-    Tensor split;
-    bool end_of_splits = true;
-    TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
-    EXPECT_FALSE(end_of_splits);
-    test::ExpectEqual(split, expected[i]);
-  }
-  TF_ASSERT_OK(SaveAndRestore(&split_provider));
-  Tensor split;
-  bool end_of_splits = false;
-  TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
-  EXPECT_TRUE(end_of_splits);
-}
-
-}  // namespace
-}  // namespace data
-}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/split_providers.cc b/tensorflow/core/kernels/data/split_utils.cc
similarity index 73%
rename from tensorflow/core/kernels/data/split_providers.cc
rename to tensorflow/core/kernels/data/split_utils.cc
index d39bf4865f9..def079169db 100644
--- a/tensorflow/core/kernels/data/split_providers.cc
+++ b/tensorflow/core/kernels/data/split_utils.cc
@@ -12,19 +12,51 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/kernels/data/split_providers.h"
-
-#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/kernels/data/split_utils.h"
 
 namespace tensorflow {
 namespace data {
-
 namespace {
 constexpr char kNumToSkip[] = "num_to_skip";
 constexpr char kSplitProvider[] = "split_provider";
 constexpr char kSlash[] = "/";
+constexpr char kIndex[] = "index";
 }  // namespace
 
+IndexSplitProvider::IndexSplitProvider(int64 n) : i_(0), n_(n) {}
+
+Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
+  mutex_lock l(mu_);
+  if (i_ >= n_) {
+    *end_of_splits = true;
+    return Status::OK();
+  }
+  *end_of_splits = false;
+  *split = Tensor(DT_INT64, TensorShape{});
+  split->scalar<int64>()() = i_++;
+  return Status::OK();
+}
+
+Status IndexSplitProvider::Reset() {
+  mutex_lock l(mu_);
+  i_ = 0;
+  return Status::OK();
+}
+
+Status IndexSplitProvider::Save(
+    std::function<std::string(std::string)> full_name,
+    IteratorStateWriter* writer) {
+  mutex_lock l(mu_);
+  return writer->WriteScalar(full_name(kIndex), i_);
+}
+
+Status IndexSplitProvider::Restore(
+    std::function<std::string(std::string)> full_name,
+    IteratorStateReader* reader) {
+  mutex_lock l(mu_);
+  return reader->ReadScalar(full_name(kIndex), &i_);
+}
+
 ShardingSplitProvider::ShardingSplitProvider(
     int64 num_shards, int64 shard_index,
     std::shared_ptr<SplitProvider> split_provider)
diff --git a/tensorflow/core/kernels/data/split_providers.h b/tensorflow/core/kernels/data/split_utils.h
similarity index 68%
rename from tensorflow/core/kernels/data/split_providers.h
rename to tensorflow/core/kernels/data/split_utils.h
index e9c02dfd9bf..82fd4e8c0a4 100644
--- a/tensorflow/core/kernels/data/split_providers.h
+++ b/tensorflow/core/kernels/data/split_utils.h
@@ -12,14 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
-#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_
 
 #include "tensorflow/core/framework/dataset.h"
 
 namespace tensorflow {
 namespace data {
 
+// A class which produces splits for a dataset of size N that can be indexed
+// into.
+class IndexSplitProvider : public SplitProvider {
+ public:
+  explicit IndexSplitProvider(int64 n);
+  Status GetNext(Tensor* split, bool* end_of_splits) override;
+  Status Reset() override;
+  Status Save(std::function<std::string(std::string)> full_name,
+              IteratorStateWriter* writer) override;
+  Status Restore(std::function<std::string(std::string)> full_name,
+                 IteratorStateReader* reader) override;
+
+ private:
+  mutex mu_;
+  int64 i_ GUARDED_BY(mu_);
+  const int64 n_;
+};
+
 // A SplitProvider which wraps another split provider, but drops all splits
 // where `index != shard_index % num_shards`
 class ShardingSplitProvider : public SplitProvider {
@@ -45,4 +64,4 @@ class ShardingSplitProvider : public SplitProvider {
 }  // namespace data
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_
diff --git a/tensorflow/core/kernels/data/split_utils_test.cc b/tensorflow/core/kernels/data/split_utils_test.cc
new file mode 100644
index 00000000000..651f6c1c894
--- /dev/null
+++ b/tensorflow/core/kernels/data/split_utils_test.cc
@@ -0,0 +1,142 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/split_utils.h"
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+std::string full_name(const std::string& name) {
+  return FullName("test", name);
+}
+
+Status SaveAndRestore(SplitProvider* split_provider) {
+  VariantTensorDataWriter writer;
+  TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer));
+  std::vector<const VariantTensorData*> variants;
+  writer.GetData(&variants);
+  VariantTensorDataReader reader(variants);
+  TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader));
+  return Status::OK();
+}
+
+Status CheckOutput(SplitProvider* split_provider,
+                   std::vector<Tensor> expected) {
+  int64 next = 0;
+  bool end_of_splits = false;
+  while (!end_of_splits) {
+    Tensor split;
+    TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
+    if (!end_of_splits) {
+      test::ExpectEqual(split, expected[next++]);
+    }
+  }
+  EXPECT_EQ(next, expected.size());
+  return Status::OK();
+}
+
+TEST(IndexSplitProviderTest, Empty) {
+  IndexSplitProvider split_provider(0);
+  TF_EXPECT_OK(
+      CheckOutput(&split_provider, CreateTensors<int64>(TensorShape({}), {})));
+}
+
+TEST(IndexSplitProviderTest, One) {
+  IndexSplitProvider split_provider(1);
+  TF_EXPECT_OK(CheckOutput(&split_provider,
+                           CreateTensors<int64>(TensorShape({}), {{0}})));
+}
+
+TEST(IndexSplitProviderTest, Three) {
+  IndexSplitProvider split_provider(3);
+  TF_EXPECT_OK(CheckOutput(
+      &split_provider, CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}})));
+}
+
+TEST(IndexSplitProviderTest, SaveAndRestore) {
+  IndexSplitProvider split_provider(4);
+  std::vector<Tensor> expected =
+      CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}});
+  for (int i = 0; i < expected.size(); ++i) {
+    TF_ASSERT_OK(SaveAndRestore(&split_provider));
+    Tensor split;
+    bool end_of_splits = true;
+    TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
+    EXPECT_FALSE(end_of_splits);
+    test::ExpectEqual(split, expected[i]);
+  }
+  TF_ASSERT_OK(SaveAndRestore(&split_provider));
+  Tensor split;
+  bool end_of_splits = false;
+  TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
+  EXPECT_TRUE(end_of_splits);
+}
+
+TEST(ShardingSplitProviderTest, TwoWayShardZero) {
+  auto base = std::make_shared<IndexSplitProvider>(4);
+  ShardingSplitProvider split_provider(2, 0, base);
+  TF_EXPECT_OK(CheckOutput(&split_provider,
+                           CreateTensors<int64>(TensorShape({}), {{0}, {2}})));
+}
+
+TEST(ShardingSplitProviderTest, TwoWayShardOne) {
+  auto base = std::make_shared<IndexSplitProvider>(4);
+  ShardingSplitProvider split_provider(2, 1, base);
+  TF_EXPECT_OK(CheckOutput(&split_provider,
+                           CreateTensors<int64>(TensorShape({}), {{1}, {3}})));
+}
+
+TEST(ShardingSplitProviderTest, ThreeWayShardOne) {
+  auto base = std::make_shared<IndexSplitProvider>(6);
+  ShardingSplitProvider split_provider(3, 1, base);
+  TF_EXPECT_OK(CheckOutput(&split_provider,
+                           CreateTensors<int64>(TensorShape({}), {{1}, {4}})));
+}
+
+TEST(ShardingSplitProviderTest, Empty) {
+  auto base = std::make_shared<IndexSplitProvider>(1);
+  ShardingSplitProvider split_provider(2, 1, base);
+  TF_EXPECT_OK(
+      CheckOutput(&split_provider, CreateTensors<int64>(TensorShape({}), {})));
+}
+
+TEST(ShardingSplitProviderTest, SaveAndRestore) {
+  auto base = std::make_shared<IndexSplitProvider>(6);
+  std::vector<Tensor> expected =
+      CreateTensors<int64>(TensorShape({}), {{1}, {4}});
+  ShardingSplitProvider split_provider(3, 1, base);
+  for (int i = 0; i < expected.size(); ++i) {
+    TF_ASSERT_OK(SaveAndRestore(&split_provider));
+    Tensor split;
+    bool end_of_splits = true;
+    TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
+    EXPECT_FALSE(end_of_splits);
+    test::ExpectEqual(split, expected[i]);
+  }
+  TF_ASSERT_OK(SaveAndRestore(&split_provider));
+  Tensor split;
+  bool end_of_splits = false;
+  TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
+  EXPECT_TRUE(end_of_splits);
+}
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 2ab713259d1..2a8f422d90a 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/name_utils.h" +#include "tensorflow/core/kernels/data/split_utils.h" #include "tensorflow/core/util/batch_util.h" namespace tensorflow { @@ -57,6 +58,13 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { this, name_utils::IteratorPrefix(kDatasetType, prefix)}); } + Status MakeSplitProvider( + std::unique_ptr* split_provider) const override { + *split_provider = + absl::make_unique(tensors_[0].dim_size(0)); + return Status::OK(); + } + const DataTypeVector& output_dtypes() const override { return dtypes_; } const std::vector& output_shapes() const override { @@ -103,24 +111,26 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - i_(0), - n_(params.dataset->tensors_[0].dim_size(0)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + split_provider_ = ctx->split_provider(); + if (split_provider_ == nullptr) { + split_provider_ = std::make_shared( + dataset()->tensors_[0].dim_size(0)); + } + return Status::OK(); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { - int64 index = 0; - { - mutex_lock l(mu_); - if (i_ < n_) { - index = i_; - ++i_; - } else { - *end_of_sequence = true; - return Status::OK(); - } + Tensor split; + TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence)); + if (*end_of_sequence) { + return Status::OK(); } + int64 index = split.scalar()(); out_tensors->clear(); out_tensors->reserve(dataset()->tensors_.size()); for (size_t i = 0; i < dataset()->tensors_.size(); ++i) { @@ -142,22 +152,18 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_)); - return Status::OK(); + return split_provider_->Save( + [this](const std::string& key) { return full_name(key); }, writer); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i_)); - return Status::OK(); + return split_provider_->Restore( + [this](const std::string& key) { return full_name(key); }, reader); } private: - mutex mu_; - int64 i_ TF_GUARDED_BY(mu_); - const int64 n_; + std::shared_ptr split_provider_; }; const std::vector tensors_; diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc index a42ac083ba2..f9bccd45173 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc @@ -319,6 +319,30 @@ INSTANTIATE_TEST_SUITE_P( TensorSliceDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest, ::testing::ValuesIn(IteratorSaveAndRestoreTestCases())); +TEST_F(TensorSliceDatasetOpTest, SplitProvider) { + auto params = TensorSliceDatasetParams( + CreateTensors(TensorShape({7}), {{6, 2, 3, 8, 7, 0, 10}}), + kNodeName); + TF_ASSERT_OK(InitializeRuntime(params)); + TF_EXPECT_OK(CheckSplitProviderFullIteration( + params, CreateTensors(TensorShape({}), + {{6}, {2}, {3}, {8}, {7}, {0}, {10}}))); + TF_EXPECT_OK(CheckSplitProviderShardedIteration( + params, /*num_shards=*/3, /*shard_index=*/1, + 
+      CreateTensors<int64>(TensorShape({}), {{2}, {7}})));
+}
+
+TEST_F(TensorSliceDatasetOpTest, SplitProviderEmpty) {
+  auto params = TensorSliceDatasetParams(
+      CreateTensors<int64>(TensorShape({0}), {{}}), kNodeName);
+  TF_ASSERT_OK(InitializeRuntime(params));
+  TF_EXPECT_OK(CheckSplitProviderFullIteration(
+      params, CreateTensors<int64>(TensorShape({}), {})));
+  TF_EXPECT_OK(CheckSplitProviderShardedIteration(
+      params, /*num_shards=*/3, /*shard_index=*/1,
+      CreateTensors<int64>(TensorShape({}), {})));
+}
+
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
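A minimal usage sketch (not part of the patch) showing how the two split providers introduced in split_utils.h compose, assuming the declarations above; the function name and the n=8 range are invented for illustration:

// Sketch only: drains a 2-way sharded view (shard 1) of the index range [0, 8).
#include <memory>
#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/split_utils.h"

namespace tensorflow {
namespace data {

Status DrainShardOneExample(std::vector<Tensor>* out) {
  // IndexSplitProvider hands out scalar int64 splits 0, 1, ..., n-1;
  // ShardingSplitProvider keeps only the splits belonging to shard_index.
  auto base = std::make_shared<IndexSplitProvider>(/*n=*/8);
  ShardingSplitProvider sharded(/*num_shards=*/2, /*shard_index=*/1, base);
  while (true) {
    Tensor split;
    bool end_of_splits = false;
    TF_RETURN_IF_ERROR(sharded.GetNext(&split, &end_of_splits));
    if (end_of_splits) break;
    // With the parameters above this collects the splits {1}, {3}, {5}, {7},
    // mirroring the ShardingSplitProviderTest cases in split_utils_test.cc.
    out->push_back(split);
  }
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow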