[tf.data] Implement splitting for tensor_slices_dataset.
PiperOrigin-RevId: 332081114 Change-Id: Ief684874407b1603a9eedad28acd91c5cc04f7c6
This commit is contained in:
parent
3ef3e633e5
commit
0e4d30319a
@ -383,6 +383,14 @@ int64 GetTotalBytes(const std::vector<Tensor>& element) {
|
||||
return total_bytes;
|
||||
}
|
||||
|
||||
std::string FullName(const std::string& prefix, const std::string& name) {
|
||||
if (str_util::StrContains(name, kColon)) {
|
||||
LOG(ERROR) << name << " should not contain " << kColon;
|
||||
}
|
||||
|
||||
return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
|
||||
}
|
||||
|
||||
Status GetDatasetFromVariantTensor(const Tensor& tensor,
|
||||
DatasetBase** out_dataset) {
|
||||
if (!(tensor.dtype() == DT_VARIANT &&
|
||||
|
@ -121,6 +121,10 @@ class IteratorStateWriter {
|
||||
virtual ~IteratorStateWriter() {}
|
||||
};
|
||||
|
||||
// Generates a full name key for iterator checkpointing. All keys generated for
|
||||
// iterator checkpoints should go through this function.
|
||||
std::string FullName(const std::string& prefix, const std::string& name);
|
||||
|
||||
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
|
||||
class GraphDefBuilderWrapper {
|
||||
public:
|
||||
@ -972,12 +976,7 @@ class DatasetBaseIterator : public IteratorBase {
|
||||
bool* end_of_sequence, int* num_skipped);
|
||||
|
||||
string full_name(const string& name) const {
|
||||
if (str_util::StrContains(name, kColon)) {
|
||||
LOG(ERROR) << name << " should not contain " << kColon;
|
||||
}
|
||||
|
||||
return strings::StrCat(kFullNameRandomHex, kPipe, params_.prefix, kColon,
|
||||
name);
|
||||
return FullName(params_.prefix, name);
|
||||
}
|
||||
|
||||
// Returns a map of key-value pairs to included in the TraceMe string.
|
||||
|
@ -27,6 +27,11 @@ TEST(DatasetTest, RegisterDatasetOp) {
|
||||
EXPECT_FALSE(data::DatasetOpRegistry::IsRegistered("InvalidDatasetOp"));
|
||||
}
|
||||
|
||||
TEST(DatasetTest, FullName) {
|
||||
EXPECT_EQ(data::FullName("prefix", "name"),
|
||||
"60d899aa0d8ce4351e7c3b419e92d25b|prefix:name");
|
||||
}
|
||||
|
||||
enum DataTypeTest {
|
||||
_tf_int_32,
|
||||
_tf_int_64,
|
||||
|
@ -26,7 +26,7 @@ cc_library(
|
||||
":map_dataset_op",
|
||||
":name_utils",
|
||||
":range_dataset_op",
|
||||
":split_providers",
|
||||
":split_utils",
|
||||
":take_dataset_op",
|
||||
":tensor_slice_dataset_op",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -738,31 +738,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "split_providers",
|
||||
srcs = ["split_providers.cc"],
|
||||
hdrs = ["split_providers.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "split_providers_test",
|
||||
size = "small",
|
||||
srcs = ["split_providers_test.cc"],
|
||||
deps = [
|
||||
":dataset_test_base",
|
||||
":dataset_utils",
|
||||
":split_providers",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/framework:tensor_testutil",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "take_dataset_op",
|
||||
srcs = ["take_dataset_op.cc"],
|
||||
@ -923,6 +898,32 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "split_utils",
|
||||
srcs = ["split_utils.cc"],
|
||||
hdrs = ["split_utils.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "split_utils_test",
|
||||
size = "small",
|
||||
srcs = ["split_utils_test.cc"],
|
||||
deps = [
|
||||
":dataset_test_base",
|
||||
":dataset_utils",
|
||||
":split_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/framework:tensor_testutil",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "tensor_dataset_op",
|
||||
srcs = ["tensor_dataset_op.cc"],
|
||||
@ -955,6 +956,7 @@ tf_kernel_library(
|
||||
deps = [
|
||||
":dataset_utils",
|
||||
":name_utils",
|
||||
":split_utils",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
|
@ -62,7 +62,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/data/map_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/split_providers.h"
|
||||
#include "tensorflow/core/kernels/data/split_utils.h"
|
||||
#include "tensorflow/core/kernels/data/take_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
@ -62,9 +62,7 @@ class DatasetHashUtilsTest : public ::testing::Test {
|
||||
}
|
||||
};
|
||||
|
||||
string full_name(string key) {
|
||||
return strings::StrCat(kFullNameRandomHex, kPipe, "Iterator:", key);
|
||||
}
|
||||
string full_name(string key) { return FullName("Iterator:", key); }
|
||||
|
||||
TEST(DatasetUtilsTest, MatchesAnyVersion) {
|
||||
EXPECT_TRUE(MatchesAnyVersion("BatchDataset", "BatchDataset"));
|
||||
|
@ -1,167 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/split_providers.h"
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
|
||||
constexpr char kPipe[] = "|";
|
||||
constexpr char kColon[] = ":";
|
||||
constexpr char kSplits[] = "splits";
|
||||
constexpr char kSplitsSize[] = "splits_size";
|
||||
|
||||
std::string full_name(const std::string& name) {
|
||||
return strings::StrCat(kFullNameRandomHex, kPipe, "test", kColon, name);
|
||||
}
|
||||
|
||||
Status SaveAndRestore(SplitProvider* split_provider) {
|
||||
VariantTensorDataWriter writer;
|
||||
TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer));
|
||||
std::vector<const VariantTensorData*> variants;
|
||||
writer.GetData(&variants);
|
||||
VariantTensorDataReader reader(variants);
|
||||
TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A split provider that provides pre-defined splits.
|
||||
class TestSplitProvider : public SplitProvider {
|
||||
public:
|
||||
explicit TestSplitProvider(std::vector<Tensor> splits) : splits_(splits) {}
|
||||
|
||||
Status GetNext(Tensor* split, bool* end_of_splits) override {
|
||||
*end_of_splits = i_ >= splits_.size();
|
||||
if (*end_of_splits) {
|
||||
return Status::OK();
|
||||
}
|
||||
*split = splits_[i_++];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Reset() override {
|
||||
i_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Save(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kSplitsSize), splits_.size()));
|
||||
for (int i = 0; i < splits_.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(absl::StrCat(kSplits, "[", i, "]")), splits_[i]));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i_"), i_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Restore(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateReader* reader) override {
|
||||
int64 splits_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name(kSplitsSize), &splits_size));
|
||||
splits_.clear();
|
||||
for (int i = 0; i < splits_size; ++i) {
|
||||
splits_.emplace_back();
|
||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||
full_name(absl::StrCat(kSplits, "[", i, "]")), &splits_.back()));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i_"), &i_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Tensor> splits_;
|
||||
int64 i_ = 0;
|
||||
};
|
||||
|
||||
Status CheckOutput(ShardingSplitProvider& split_provider,
|
||||
std::vector<Tensor> expected) {
|
||||
int64 next = 0;
|
||||
bool end_of_splits = false;
|
||||
while (!end_of_splits) {
|
||||
Tensor split;
|
||||
TF_RETURN_IF_ERROR(split_provider.GetNext(&split, &end_of_splits));
|
||||
if (!end_of_splits) {
|
||||
test::ExpectEqual(split, expected[next++]);
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(next, expected.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, TwoWayShardZero) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
|
||||
ShardingSplitProvider split_provider(2, 0, base);
|
||||
TF_EXPECT_OK(CheckOutput(split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {2}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, TwoWayShardOne) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
|
||||
ShardingSplitProvider split_provider(2, 1, base);
|
||||
TF_EXPECT_OK(CheckOutput(split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {3}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, ThreeWayShardOne) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
|
||||
ShardingSplitProvider split_provider(3, 1, base);
|
||||
TF_EXPECT_OK(CheckOutput(split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {4}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, Empty) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}}));
|
||||
ShardingSplitProvider split_provider(2, 1, base);
|
||||
TF_EXPECT_OK(
|
||||
CheckOutput(split_provider, CreateTensors<int64>(TensorShape({}), {})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, SaveAndRestore) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
|
||||
std::vector<Tensor> expected =
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {4}});
|
||||
ShardingSplitProvider split_provider(3, 1, base);
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
TF_ASSERT_OK(SaveAndRestore(&split_provider));
|
||||
Tensor split;
|
||||
bool end_of_splits = true;
|
||||
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
|
||||
EXPECT_FALSE(end_of_splits);
|
||||
test::ExpectEqual(split, expected[i]);
|
||||
}
|
||||
TF_ASSERT_OK(SaveAndRestore(&split_provider));
|
||||
Tensor split;
|
||||
bool end_of_splits = false;
|
||||
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
|
||||
EXPECT_TRUE(end_of_splits);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -12,19 +12,51 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/split_providers.h"
|
||||
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/kernels/data/split_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
namespace {
|
||||
constexpr char kNumToSkip[] = "num_to_skip";
|
||||
constexpr char kSplitProvider[] = "split_provider";
|
||||
constexpr char kSlash[] = "/";
|
||||
constexpr char kIndex[] = "index";
|
||||
} // namespace
|
||||
|
||||
IndexSplitProvider::IndexSplitProvider(int64 n) : i_(0), n_(n) {}
|
||||
|
||||
Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
|
||||
mutex_lock l(mu_);
|
||||
if (i_ >= n_) {
|
||||
*end_of_splits = true;
|
||||
return Status::OK();
|
||||
}
|
||||
*end_of_splits = false;
|
||||
*split = Tensor(DT_INT64, TensorShape{});
|
||||
split->scalar<int64>()() = i_++;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IndexSplitProvider::Reset() {
|
||||
mutex_lock l(mu_);
|
||||
i_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IndexSplitProvider::Save(
|
||||
std::function<std::string(std::string)> full_name,
|
||||
IteratorStateWriter* writer) {
|
||||
mutex_lock l(mu_);
|
||||
return writer->WriteScalar(full_name(kIndex), i_);
|
||||
}
|
||||
|
||||
Status IndexSplitProvider::Restore(
|
||||
std::function<std::string(std::string)> full_name,
|
||||
IteratorStateReader* reader) {
|
||||
mutex_lock l(mu_);
|
||||
return reader->ReadScalar(full_name(kIndex), &i_);
|
||||
}
|
||||
|
||||
ShardingSplitProvider::ShardingSplitProvider(
|
||||
int64 num_shards, int64 shard_index,
|
||||
std::shared_ptr<SplitProvider> split_provider)
|
@ -12,14 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// A class which produces splits for a dataset of size N that can be indexed
|
||||
// into.
|
||||
class IndexSplitProvider : public SplitProvider {
|
||||
public:
|
||||
explicit IndexSplitProvider(int64 n);
|
||||
Status GetNext(Tensor* split, bool* end_of_splits) override;
|
||||
Status Reset() override;
|
||||
Status Save(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateWriter* writer) override;
|
||||
Status Restore(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateReader* reader) override;
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
int64 i_ GUARDED_BY(mu_);
|
||||
const int64 n_;
|
||||
};
|
||||
|
||||
// A SplitProvider which wraps another split provider, but drops all splits
|
||||
// where `index != shard_index % num_shards`
|
||||
class ShardingSplitProvider : public SplitProvider {
|
||||
@ -45,4 +64,4 @@ class ShardingSplitProvider : public SplitProvider {
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_UTILS_H_
|
142
tensorflow/core/kernels/data/split_utils_test.cc
Normal file
142
tensorflow/core/kernels/data/split_utils_test.cc
Normal file
@ -0,0 +1,142 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/split_utils.h"
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
std::string full_name(const std::string& name) {
|
||||
return FullName("test", name);
|
||||
}
|
||||
|
||||
Status SaveAndRestore(SplitProvider* split_provider) {
|
||||
VariantTensorDataWriter writer;
|
||||
TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer));
|
||||
std::vector<const VariantTensorData*> variants;
|
||||
writer.GetData(&variants);
|
||||
VariantTensorDataReader reader(variants);
|
||||
TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckOutput(SplitProvider* split_provider,
|
||||
std::vector<Tensor> expected) {
|
||||
int64 next = 0;
|
||||
bool end_of_splits = false;
|
||||
while (!end_of_splits) {
|
||||
Tensor split;
|
||||
TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
|
||||
if (!end_of_splits) {
|
||||
test::ExpectEqual(split, expected[next++]);
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(next, expected.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST(IndexSplitProviderTest, Empty) {
|
||||
IndexSplitProvider split_provider(0);
|
||||
TF_EXPECT_OK(
|
||||
CheckOutput(&split_provider, CreateTensors<int64>(TensorShape({}), {})));
|
||||
}
|
||||
|
||||
TEST(IndexSplitProviderTest, One) {
|
||||
IndexSplitProvider split_provider(1);
|
||||
TF_EXPECT_OK(CheckOutput(&split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{0}})));
|
||||
}
|
||||
|
||||
TEST(IndexSplitProviderTest, Three) {
|
||||
IndexSplitProvider split_provider(3);
|
||||
TF_EXPECT_OK(CheckOutput(
|
||||
&split_provider, CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}})));
|
||||
}
|
||||
|
||||
TEST(IndexSplitProviderTest, SaveAndRestore) {
|
||||
IndexSplitProvider split_provider(4);
|
||||
std::vector<Tensor> expected =
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}});
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
TF_ASSERT_OK(SaveAndRestore(&split_provider));
|
||||
Tensor split;
|
||||
bool end_of_splits = true;
|
||||
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
|
||||
EXPECT_FALSE(end_of_splits);
|
||||
test::ExpectEqual(split, expected[i]);
|
||||
}
|
||||
TF_ASSERT_OK(SaveAndRestore(&split_provider));
|
||||
Tensor split;
|
||||
bool end_of_splits = false;
|
||||
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
|
||||
EXPECT_TRUE(end_of_splits);
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProviderTest, TwoWayShardZero) {
|
||||
auto base = std::make_shared<IndexSplitProvider>(4);
|
||||
ShardingSplitProvider split_provider(2, 0, base);
|
||||
TF_EXPECT_OK(CheckOutput(&split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {2}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProviderTest, TwoWayShardOne) {
|
||||
auto base = std::make_shared<IndexSplitProvider>(4);
|
||||
ShardingSplitProvider split_provider(2, 1, base);
|
||||
TF_EXPECT_OK(CheckOutput(&split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {3}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProviderTest, ThreeWayShardOne) {
|
||||
auto base = std::make_shared<IndexSplitProvider>(6);
|
||||
ShardingSplitProvider split_provider(3, 1, base);
|
||||
TF_EXPECT_OK(CheckOutput(&split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {4}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProviderTest, Empty) {
|
||||
auto base = std::make_shared<IndexSplitProvider>(1);
|
||||
ShardingSplitProvider split_provider(2, 1, base);
|
||||
TF_EXPECT_OK(
|
||||
CheckOutput(&split_provider, CreateTensors<int64>(TensorShape({}), {})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProviderTest, SaveAndRestore) {
|
||||
auto base = std::make_shared<IndexSplitProvider>(6);
|
||||
std::vector<Tensor> expected =
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {4}});
|
||||
ShardingSplitProvider split_provider(3, 1, base);
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
TF_ASSERT_OK(SaveAndRestore(&split_provider));
|
||||
Tensor split;
|
||||
bool end_of_splits = true;
|
||||
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
|
||||
EXPECT_FALSE(end_of_splits);
|
||||
test::ExpectEqual(split, expected[i]);
|
||||
}
|
||||
TF_ASSERT_OK(SaveAndRestore(&split_provider));
|
||||
Tensor split;
|
||||
bool end_of_splits = false;
|
||||
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
|
||||
EXPECT_TRUE(end_of_splits);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/kernels/data/split_utils.h"
|
||||
#include "tensorflow/core/util/batch_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -57,6 +58,13 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
|
||||
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
|
||||
}
|
||||
|
||||
Status MakeSplitProvider(
|
||||
std::unique_ptr<SplitProvider>* split_provider) const override {
|
||||
*split_provider =
|
||||
absl::make_unique<IndexSplitProvider>(tensors_[0].dim_size(0));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override { return dtypes_; }
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
@ -103,24 +111,26 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params),
|
||||
i_(0),
|
||||
n_(params.dataset->tensors_[0].dim_size(0)) {}
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
split_provider_ = ctx->split_provider();
|
||||
if (split_provider_ == nullptr) {
|
||||
split_provider_ = std::make_shared<IndexSplitProvider>(
|
||||
dataset()->tensors_[0].dim_size(0));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
int64 index = 0;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (i_ < n_) {
|
||||
index = i_;
|
||||
++i_;
|
||||
} else {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
Tensor split;
|
||||
TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
|
||||
if (*end_of_sequence) {
|
||||
return Status::OK();
|
||||
}
|
||||
int64 index = split.scalar<int64>()();
|
||||
out_tensors->clear();
|
||||
out_tensors->reserve(dataset()->tensors_.size());
|
||||
for (size_t i = 0; i < dataset()->tensors_.size(); ++i) {
|
||||
@ -142,22 +152,18 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
|
||||
return Status::OK();
|
||||
return split_provider_->Save(
|
||||
[this](const std::string& key) { return full_name(key); }, writer);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i_));
|
||||
return Status::OK();
|
||||
return split_provider_->Restore(
|
||||
[this](const std::string& key) { return full_name(key); }, reader);
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
int64 i_ TF_GUARDED_BY(mu_);
|
||||
const int64 n_;
|
||||
std::shared_ptr<SplitProvider> split_provider_;
|
||||
};
|
||||
|
||||
const std::vector<Tensor> tensors_;
|
||||
|
@ -319,6 +319,30 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TensorSliceDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
|
||||
::testing::ValuesIn(IteratorSaveAndRestoreTestCases()));
|
||||
|
||||
TEST_F(TensorSliceDatasetOpTest, SplitProvider) {
|
||||
auto params = TensorSliceDatasetParams(
|
||||
CreateTensors<int64>(TensorShape({7}), {{6, 2, 3, 8, 7, 0, 10}}),
|
||||
kNodeName);
|
||||
TF_ASSERT_OK(InitializeRuntime(params));
|
||||
TF_EXPECT_OK(CheckSplitProviderFullIteration(
|
||||
params, CreateTensors<int64>(TensorShape({}),
|
||||
{{6}, {2}, {3}, {8}, {7}, {0}, {10}})));
|
||||
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
|
||||
params, /*num_shards=*/3, /*shard_index=*/1,
|
||||
CreateTensors<int64>(TensorShape({}), {{2}, {7}})));
|
||||
}
|
||||
|
||||
TEST_F(TensorSliceDatasetOpTest, SplitProviderEmpty) {
|
||||
auto params = TensorSliceDatasetParams(
|
||||
CreateTensors<int64>(TensorShape({0}), {{}}), kNodeName);
|
||||
TF_ASSERT_OK(InitializeRuntime(params));
|
||||
TF_EXPECT_OK(CheckSplitProviderFullIteration(
|
||||
params, CreateTensors<int64>(TensorShape({}), {})));
|
||||
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
|
||||
params, /*num_shards=*/3, /*shard_index=*/1,
|
||||
CreateTensors<int64>(TensorShape({}), {})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user