[tf.data] Implement splitting for tensor_slices_dataset.

PiperOrigin-RevId: 332081114
Change-Id: Ief684874407b1603a9eedad28acd91c5cc04f7c6
This commit is contained in:
Andrew Audibert 2020-09-16 13:52:59 -07:00 committed by TensorFlower Gardener
parent 3ef3e633e5
commit 0e4d30319a
12 changed files with 300 additions and 232 deletions

View File

@ -383,6 +383,14 @@ int64 GetTotalBytes(const std::vector<Tensor>& element) {
return total_bytes;
}
// Builds the fully-qualified checkpoint key "<random-hex>|<prefix>:<name>".
// All iterator-checkpoint keys should be produced through this helper so the
// format stays uniform. If `name` itself contains the reserved ':' delimiter
// an error is logged, but the key is still produced.
std::string FullName(const std::string& prefix, const std::string& name) {
  const bool has_reserved_delimiter = str_util::StrContains(name, kColon);
  if (has_reserved_delimiter) {
    LOG(ERROR) << name << " should not contain " << kColon;
  }
  return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
}
Status GetDatasetFromVariantTensor(const Tensor& tensor,
DatasetBase** out_dataset) {
if (!(tensor.dtype() == DT_VARIANT &&

View File

@ -121,6 +121,10 @@ class IteratorStateWriter {
virtual ~IteratorStateWriter() {}
};
// Generates a full name key for iterator checkpointing. All keys generated for
// iterator checkpoints should go through this function.
std::string FullName(const std::string& prefix, const std::string& name);
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
public:
@ -972,12 +976,7 @@ class DatasetBaseIterator : public IteratorBase {
bool* end_of_sequence, int* num_skipped);
// Returns the checkpoint key for `name`, scoped to this iterator's prefix.
// Delegates to FullName() so that every checkpoint key is built in one place.
// (The fused pre-refactor body — an inline StrCat duplicate followed by an
// unreachable second return — has been removed.)
string full_name(const string& name) const {
  return FullName(params_.prefix, name);
}
// Returns a map of key-value pairs to included in the TraceMe string.

View File

@ -27,6 +27,11 @@ TEST(DatasetTest, RegisterDatasetOp) {
EXPECT_FALSE(data::DatasetOpRegistry::IsRegistered("InvalidDatasetOp"));
}
// Verifies the exact checkpoint-key format produced by data::FullName:
// "<fixed random hex>|<prefix>:<name>". The hex constant is part of the
// on-disk checkpoint format, so this string must not change.
TEST(DatasetTest, FullName) {
EXPECT_EQ(data::FullName("prefix", "name"),
"60d899aa0d8ce4351e7c3b419e92d25b|prefix:name");
}
enum DataTypeTest {
_tf_int_32,
_tf_int_64,

View File

@ -26,7 +26,7 @@ cc_library(
":map_dataset_op",
":name_utils",
":range_dataset_op",
":split_providers",
":split_utils",
":take_dataset_op",
":tensor_slice_dataset_op",
"//tensorflow/core:core_cpu",
@ -738,31 +738,6 @@ tf_cc_test(
],
)
# Library of SplitProvider implementations (index- and sharding-based).
# NOTE(review): this target is being renamed to "split_utils" in this change.
cc_library(
name = "split_providers",
srcs = ["split_providers.cc"],
hdrs = ["split_providers.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core/platform:errors",
],
)
# Unit tests for :split_providers.
# NOTE(review): replaced by "split_utils_test" as part of the rename.
tf_cc_test(
name = "split_providers_test",
size = "small",
srcs = ["split_providers_test.cc"],
deps = [
":dataset_test_base",
":dataset_utils",
":split_providers",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/framework:tensor_testutil",
],
)
tf_kernel_library(
name = "take_dataset_op",
srcs = ["take_dataset_op.cc"],
@ -923,6 +898,32 @@ tf_cc_test(
],
)
# SplitProvider implementations shared by dataset kernels (IndexSplitProvider,
# ShardingSplitProvider). Renamed from "split_providers".
cc_library(
name = "split_utils",
srcs = ["split_utils.cc"],
hdrs = ["split_utils.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels/data:iterator_ops",
],
)
# Unit tests for :split_utils.
tf_cc_test(
name = "split_utils_test",
size = "small",
srcs = ["split_utils_test.cc"],
deps = [
":dataset_test_base",
":dataset_utils",
":split_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/framework:tensor_testutil",
],
)
tf_kernel_library(
name = "tensor_dataset_op",
srcs = ["tensor_dataset_op.cc"],
@ -955,6 +956,7 @@ tf_kernel_library(
deps = [
":dataset_utils",
":name_utils",
":split_utils",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:graph",

View File

@ -62,7 +62,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/map_dataset_op.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "tensorflow/core/kernels/data/split_providers.h"
#include "tensorflow/core/kernels/data/split_utils.h"
#include "tensorflow/core/kernels/data/take_dataset_op.h"
#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
#include "tensorflow/core/lib/core/status_test_util.h"

View File

@ -62,9 +62,7 @@ class DatasetHashUtilsTest : public ::testing::Test {
}
};
// Test helper: builds the checkpoint key `key` would get under the
// "Iterator:" prefix, via the shared FullName() helper.
// (The fused pre-refactor definition — a second `full_name` built with
// strings::StrCat — was a redefinition and has been removed.)
string full_name(string key) { return FullName("Iterator:", key); }
TEST(DatasetUtilsTest, MatchesAnyVersion) {
EXPECT_TRUE(MatchesAnyVersion("BatchDataset", "BatchDataset"));

View File

@ -1,167 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/split_providers.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";
constexpr char kSplits[] = "splits";
constexpr char kSplitsSize[] = "splits_size";
// Test helper: produces "<hex>|test:<name>", the key format an iterator whose
// prefix is "test" would generate for `name`.
std::string full_name(const std::string& name) {
  std::string key(kFullNameRandomHex);
  key += kPipe;
  key += "test";
  key += kColon;
  key += name;
  return key;
}
// Round-trips `split_provider`'s state through an in-memory save/restore
// cycle (VariantTensorDataWriter -> VariantTensorDataReader), returning the
// first error encountered. Used to exercise checkpointing mid-iteration.
Status SaveAndRestore(SplitProvider* split_provider) {
VariantTensorDataWriter writer;
TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer));
std::vector<const VariantTensorData*> variants;
writer.GetData(&variants);
VariantTensorDataReader reader(variants);
TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader));
return Status::OK();
}
// A split provider that provides pre-defined splits.
class TestSplitProvider : public SplitProvider {
public:
// Takes the fixed list of splits to serve, in order.
explicit TestSplitProvider(std::vector<Tensor> splits) : splits_(splits) {}
// Serves the next pre-defined split; sets *end_of_splits once exhausted.
Status GetNext(Tensor* split, bool* end_of_splits) override {
*end_of_splits = i_ >= splits_.size();
if (*end_of_splits) {
return Status::OK();
}
*split = splits_[i_++];
return Status::OK();
}
// Rewinds to the first split.
Status Reset() override {
i_ = 0;
return Status::OK();
}
// Writes the split count, each split tensor ("splits[<i>]"), and the
// current position under keys scoped by `full_name`.
Status Save(std::function<std::string(std::string)> full_name,
IteratorStateWriter* writer) override {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kSplitsSize), splits_.size()));
for (int i = 0; i < splits_.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(absl::StrCat(kSplits, "[", i, "]")), splits_[i]));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i_"), i_));
return Status::OK();
}
// Inverse of Save(): rebuilds splits_ and the position from the reader.
// Note this replaces, not appends to, any existing splits.
Status Restore(std::function<std::string(std::string)> full_name,
IteratorStateReader* reader) override {
int64 splits_size;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(kSplitsSize), &splits_size));
splits_.clear();
for (int i = 0; i < splits_size; ++i) {
splits_.emplace_back();
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(absl::StrCat(kSplits, "[", i, "]")), &splits_.back()));
}
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i_"), &i_));
return Status::OK();
}
private:
std::vector<Tensor> splits_;
// Index of the next split to serve.
int64 i_ = 0;
};
// Drains `split_provider` and verifies it yields exactly the splits in
// `expected`, in order. Takes `expected` by const reference (the original
// copied a vector<Tensor> per call) and guards against reading past the end
// of `expected` if the provider over-produces (previously out-of-bounds).
Status CheckOutput(ShardingSplitProvider& split_provider,
                   const std::vector<Tensor>& expected) {
  int64 next = 0;
  bool end_of_splits = false;
  while (!end_of_splits) {
    Tensor split;
    TF_RETURN_IF_ERROR(split_provider.GetNext(&split, &end_of_splits));
    if (!end_of_splits) {
      // Fail (instead of indexing out of bounds) on an over-producing
      // provider.
      EXPECT_LT(next, static_cast<int64>(expected.size()));
      if (next >= static_cast<int64>(expected.size())) break;
      test::ExpectEqual(split, expected[next++]);
    }
  }
  EXPECT_EQ(next, expected.size());
  return Status::OK();
}
// Shard 0 of 2 over {0,1,2,3}: keeps the splits at positions 0 and 2.
TEST(ShardingSplitProvider, TwoWayShardZero) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
ShardingSplitProvider split_provider(2, 0, base);
TF_EXPECT_OK(CheckOutput(split_provider,
CreateTensors<int64>(TensorShape({}), {{0}, {2}})));
}
// Shard 1 of 2 over {0,1,2,3}: keeps the splits at positions 1 and 3.
TEST(ShardingSplitProvider, TwoWayShardOne) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
ShardingSplitProvider split_provider(2, 1, base);
TF_EXPECT_OK(CheckOutput(split_provider,
CreateTensors<int64>(TensorShape({}), {{1}, {3}})));
}
// Shard 1 of 3 over six splits: keeps positions 1 and 4.
TEST(ShardingSplitProvider, ThreeWayShardOne) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
ShardingSplitProvider split_provider(3, 1, base);
TF_EXPECT_OK(CheckOutput(split_provider,
CreateTensors<int64>(TensorShape({}), {{1}, {4}})));
}
// A shard with no matching positions yields no splits at all.
TEST(ShardingSplitProvider, Empty) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}}));
ShardingSplitProvider split_provider(2, 1, base);
TF_EXPECT_OK(
CheckOutput(split_provider, CreateTensors<int64>(TensorShape({}), {})));
}
// Save/restore between every GetNext() call must not change the sequence of
// splits produced, including the final end-of-splits signal.
TEST(ShardingSplitProvider, SaveAndRestore) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
std::vector<Tensor> expected =
CreateTensors<int64>(TensorShape({}), {{1}, {4}});
ShardingSplitProvider split_provider(3, 1, base);
for (int i = 0; i < expected.size(); ++i) {
TF_ASSERT_OK(SaveAndRestore(&split_provider));
Tensor split;
// Initialized true so the EXPECT_FALSE below proves GetNext() wrote it.
bool end_of_splits = true;
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
EXPECT_FALSE(end_of_splits);
test::ExpectEqual(split, expected[i]);
}
// One more round-trip at the end: the provider must still report
// end-of-splits after restore.
TF_ASSERT_OK(SaveAndRestore(&split_provider));
Tensor split;
bool end_of_splits = false;
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
EXPECT_TRUE(end_of_splits);
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -12,19 +12,51 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/split_providers.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/kernels/data/split_utils.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNumToSkip[] = "num_to_skip";
constexpr char kSplitProvider[] = "split_provider";
constexpr char kSlash[] = "/";
constexpr char kIndex[] = "index";
} // namespace
// Creates a provider that will emit the indices [0, n) in ascending order.
IndexSplitProvider::IndexSplitProvider(int64 n) : i_(0), n_(n) {}
// Hands out the next index in [0, n_) as a scalar int64 Tensor; once the
// range is exhausted, signals end-of-splits and leaves `split` untouched.
// Thread-safe: the position is advanced under mu_.
Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  const bool exhausted = (i_ >= n_);
  *end_of_splits = exhausted;
  if (exhausted) {
    return Status::OK();
  }
  Tensor next_split(DT_INT64, TensorShape{});
  next_split.scalar<int64>()() = i_++;
  *split = std::move(next_split);
  return Status::OK();
}
// Rewinds the provider so the next GetNext() call yields index 0 again.
Status IndexSplitProvider::Reset() {
mutex_lock l(mu_);
i_ = 0;
return Status::OK();
}
// Checkpoints the provider by writing the next index under the
// caller-scoped key for kIndex. n_ is not saved: it is reconstructed from
// the dataset when the provider is recreated.
Status IndexSplitProvider::Save(
std::function<std::string(std::string)> full_name,
IteratorStateWriter* writer) {
mutex_lock l(mu_);
return writer->WriteScalar(full_name(kIndex), i_);
}
// Inverse of Save(): reads the next index back from the checkpoint under the
// same caller-scoped key.
Status IndexSplitProvider::Restore(
std::function<std::string(std::string)> full_name,
IteratorStateReader* reader) {
mutex_lock l(mu_);
return reader->ReadScalar(full_name(kIndex), &i_);
}
ShardingSplitProvider::ShardingSplitProvider(
int64 num_shards, int64 shard_index,
std::shared_ptr<SplitProvider> split_provider)

View File

@ -12,14 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
// A class which produces splits for a dataset of size N that can be indexed
// into.
// Produces the splits 0, 1, ..., n-1 as scalar int64 Tensors. Suitable for
// splitting any dataset whose elements can be looked up by index.
class IndexSplitProvider : public SplitProvider {
 public:
  explicit IndexSplitProvider(int64 n);
  Status GetNext(Tensor* split, bool* end_of_splits) override;
  Status Reset() override;
  Status Save(std::function<std::string(std::string)> full_name,
              IteratorStateWriter* writer) override;
  Status Restore(std::function<std::string(std::string)> full_name,
                 IteratorStateReader* reader) override;

 private:
  mutex mu_;
  // Next index to hand out. TF_GUARDED_BY (not the bare GUARDED_BY macro)
  // matches the TF_-prefixed thread annotations used elsewhere in this
  // change (e.g. tensor_slice_dataset_op.cc).
  int64 i_ TF_GUARDED_BY(mu_);
  const int64 n_;
};
// A SplitProvider which wraps another split provider, but drops all splits
// where `index != shard_index % num_shards`
class ShardingSplitProvider : public SplitProvider {
@ -45,4 +64,4 @@ class ShardingSplitProvider : public SplitProvider {
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
#endif // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_

View File

@ -0,0 +1,142 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/split_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace data {
namespace {
// Test helper: checkpoint key for `name` under the fixed prefix "test".
std::string full_name(const std::string& name) {
  static constexpr char kTestPrefix[] = "test";
  return FullName(kTestPrefix, name);
}
// Round-trips `split_provider`'s state through an in-memory save/restore
// cycle (VariantTensorDataWriter -> VariantTensorDataReader), returning the
// first error encountered. Used to exercise checkpointing mid-iteration.
Status SaveAndRestore(SplitProvider* split_provider) {
VariantTensorDataWriter writer;
TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer));
std::vector<const VariantTensorData*> variants;
writer.GetData(&variants);
VariantTensorDataReader reader(variants);
TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader));
return Status::OK();
}
// Drains `split_provider` and verifies it yields exactly the splits in
// `expected`, in order. Takes `expected` by const reference (the original
// copied a vector<Tensor> per call) and guards against reading past the end
// of `expected` if the provider over-produces (previously out-of-bounds).
Status CheckOutput(SplitProvider* split_provider,
                   const std::vector<Tensor>& expected) {
  int64 next = 0;
  bool end_of_splits = false;
  while (!end_of_splits) {
    Tensor split;
    TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
    if (!end_of_splits) {
      // Fail (instead of indexing out of bounds) on an over-producing
      // provider.
      EXPECT_LT(next, static_cast<int64>(expected.size()));
      if (next >= static_cast<int64>(expected.size())) break;
      test::ExpectEqual(split, expected[next++]);
    }
  }
  EXPECT_EQ(next, expected.size());
  return Status::OK();
}
// n == 0: the provider yields no splits.
TEST(IndexSplitProviderTest, Empty) {
IndexSplitProvider split_provider(0);
TF_EXPECT_OK(
CheckOutput(&split_provider, CreateTensors<int64>(TensorShape({}), {})));
}
// n == 1: exactly the single split {0}.
TEST(IndexSplitProviderTest, One) {
IndexSplitProvider split_provider(1);
TF_EXPECT_OK(CheckOutput(&split_provider,
CreateTensors<int64>(TensorShape({}), {{0}})));
}
// n == 3: indices come out in ascending order starting at 0.
TEST(IndexSplitProviderTest, Three) {
IndexSplitProvider split_provider(3);
TF_EXPECT_OK(CheckOutput(
&split_provider, CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}})));
}
// Save/restore between every GetNext() call must not change the sequence of
// splits produced, including the final end-of-splits signal.
TEST(IndexSplitProviderTest, SaveAndRestore) {
IndexSplitProvider split_provider(4);
std::vector<Tensor> expected =
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}});
for (int i = 0; i < expected.size(); ++i) {
TF_ASSERT_OK(SaveAndRestore(&split_provider));
Tensor split;
// Initialized true so the EXPECT_FALSE below proves GetNext() wrote it.
bool end_of_splits = true;
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
EXPECT_FALSE(end_of_splits);
test::ExpectEqual(split, expected[i]);
}
// One more round-trip at the end: must still report end-of-splits.
TF_ASSERT_OK(SaveAndRestore(&split_provider));
Tensor split;
bool end_of_splits = false;
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
EXPECT_TRUE(end_of_splits);
}
// Shard 0 of 2 over indices 0..3: keeps positions 0 and 2.
TEST(ShardingSplitProviderTest, TwoWayShardZero) {
auto base = std::make_shared<IndexSplitProvider>(4);
ShardingSplitProvider split_provider(2, 0, base);
TF_EXPECT_OK(CheckOutput(&split_provider,
CreateTensors<int64>(TensorShape({}), {{0}, {2}})));
}
// Shard 1 of 2 over indices 0..3: keeps positions 1 and 3.
TEST(ShardingSplitProviderTest, TwoWayShardOne) {
auto base = std::make_shared<IndexSplitProvider>(4);
ShardingSplitProvider split_provider(2, 1, base);
TF_EXPECT_OK(CheckOutput(&split_provider,
CreateTensors<int64>(TensorShape({}), {{1}, {3}})));
}
// Shard 1 of 3 over indices 0..5: keeps positions 1 and 4.
TEST(ShardingSplitProviderTest, ThreeWayShardOne) {
auto base = std::make_shared<IndexSplitProvider>(6);
ShardingSplitProvider split_provider(3, 1, base);
TF_EXPECT_OK(CheckOutput(&split_provider,
CreateTensors<int64>(TensorShape({}), {{1}, {4}})));
}
// A shard with no matching positions yields no splits at all.
TEST(ShardingSplitProviderTest, Empty) {
auto base = std::make_shared<IndexSplitProvider>(1);
ShardingSplitProvider split_provider(2, 1, base);
TF_EXPECT_OK(
CheckOutput(&split_provider, CreateTensors<int64>(TensorShape({}), {})));
}
// Save/restore between every GetNext() call must not change the sharded
// sequence of splits, including the final end-of-splits signal.
TEST(ShardingSplitProviderTest, SaveAndRestore) {
auto base = std::make_shared<IndexSplitProvider>(6);
std::vector<Tensor> expected =
CreateTensors<int64>(TensorShape({}), {{1}, {4}});
ShardingSplitProvider split_provider(3, 1, base);
for (int i = 0; i < expected.size(); ++i) {
TF_ASSERT_OK(SaveAndRestore(&split_provider));
Tensor split;
// Initialized true so the EXPECT_FALSE below proves GetNext() wrote it.
bool end_of_splits = true;
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
EXPECT_FALSE(end_of_splits);
test::ExpectEqual(split, expected[i]);
}
// One more round-trip at the end: must still report end-of-splits.
TF_ASSERT_OK(SaveAndRestore(&split_provider));
Tensor split;
bool end_of_splits = false;
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
EXPECT_TRUE(end_of_splits);
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/split_utils.h"
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
@ -57,6 +58,13 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}
// Supports input splitting by exposing one split per slice: an
// IndexSplitProvider over [0, N), where N is the size of dimension 0 of the
// first component tensor.
// NOTE(review): assumes every component shares the same dimension-0 size —
// presumably validated when the dataset is constructed; confirm.
Status MakeSplitProvider(
std::unique_ptr<SplitProvider>* split_provider) const override {
*split_provider =
absl::make_unique<IndexSplitProvider>(tensors_[0].dim_size(0));
return Status::OK();
}
const DataTypeVector& output_dtypes() const override { return dtypes_; }
const std::vector<PartialTensorShape>& output_shapes() const override {
@ -103,24 +111,26 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
class Iterator : public DatasetIterator<Dataset> {
public:
// The iterator no longer tracks its position (i_/n_) directly; all progress
// state lives in the split provider, which is resolved in Initialize().
// (The fused pre-refactor member-initializer list — a second i_(0)/n_(...)
// clause that made this a syntax error — has been removed.)
explicit Iterator(const Params& params)
    : DatasetIterator<Dataset>(params) {}
// Uses the split provider supplied through the IteratorContext when present
// (e.g. when iteration is being coordinated externally); otherwise falls
// back to a private IndexSplitProvider covering every slice.
Status Initialize(IteratorContext* ctx) override {
split_provider_ = ctx->split_provider();
if (split_provider_ == nullptr) {
split_provider_ = std::make_shared<IndexSplitProvider>(
dataset()->tensors_[0].dim_size(0));
}
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
int64 index = 0;
{
mutex_lock l(mu_);
if (i_ < n_) {
index = i_;
++i_;
} else {
*end_of_sequence = true;
return Status::OK();
}
Tensor split;
TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
if (*end_of_sequence) {
return Status::OK();
}
int64 index = split.scalar<int64>()();
out_tensors->clear();
out_tensors->reserve(dataset()->tensors_.size());
for (size_t i = 0; i < dataset()->tensors_.size(); ++i) {
@ -142,22 +152,18 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
// Checkpoints iteration progress by delegating to the split provider, which
// records its own position under this iterator's full_name() key prefix.
// (The fused pre-refactor body — a kCurIndex write followed by an early
// return that made the delegation unreachable — has been removed; the
// provider handles its own locking.)
Status SaveInternal(SerializationContext* ctx,
                    IteratorStateWriter* writer) override {
  return split_provider_->Save(
      [this](const std::string& key) { return full_name(key); }, writer);
}
// Restores iteration progress by delegating to the split provider, mirroring
// SaveInternal(). (The fused pre-refactor body — a kCurIndex read followed
// by an early return that made the delegation unreachable — has been
// removed; the provider handles its own locking.)
Status RestoreInternal(IteratorContext* ctx,
                       IteratorStateReader* reader) override {
  return split_provider_->Restore(
      [this](const std::string& key) { return full_name(key); }, reader);
}
private:
mutex mu_;
int64 i_ TF_GUARDED_BY(mu_);
const int64 n_;
std::shared_ptr<SplitProvider> split_provider_;
};
const std::vector<Tensor> tensors_;

View File

@ -319,6 +319,30 @@ INSTANTIATE_TEST_SUITE_P(
TensorSliceDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
::testing::ValuesIn(IteratorSaveAndRestoreTestCases()));
// End-to-end splitting over a 7-slice dataset: full iteration visits every
// slice in order, and shard 1 of 3 visits the slices at positions 1 and 4
// (values 2 and 7).
TEST_F(TensorSliceDatasetOpTest, SplitProvider) {
auto params = TensorSliceDatasetParams(
CreateTensors<int64>(TensorShape({7}), {{6, 2, 3, 8, 7, 0, 10}}),
kNodeName);
TF_ASSERT_OK(InitializeRuntime(params));
TF_EXPECT_OK(CheckSplitProviderFullIteration(
params, CreateTensors<int64>(TensorShape({}),
{{6}, {2}, {3}, {8}, {7}, {0}, {10}})));
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
params, /*num_shards=*/3, /*shard_index=*/1,
CreateTensors<int64>(TensorShape({}), {{2}, {7}})));
}
// Degenerate case: a dataset with zero slices yields no splits, both for
// full iteration and for any shard.
TEST_F(TensorSliceDatasetOpTest, SplitProviderEmpty) {
auto params = TensorSliceDatasetParams(
CreateTensors<int64>(TensorShape({0}), {{}}), kNodeName);
TF_ASSERT_OK(InitializeRuntime(params));
TF_EXPECT_OK(CheckSplitProviderFullIteration(
params, CreateTensors<int64>(TensorShape({}), {})));
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
params, /*num_shards=*/3, /*shard_index=*/1,
CreateTensors<int64>(TensorShape({}), {})));
}
} // namespace
} // namespace data
} // namespace tensorflow