Merge pull request #27969 from feihugis:Test_ShuffleDataset

PiperOrigin-RevId: 245486029
This commit is contained in:
TensorFlower Gardener 2019-04-26 15:50:41 -07:00
commit a2b7b5d5a0
7 changed files with 990 additions and 14 deletions

View File

@ -667,6 +667,24 @@ tf_kernel_library(
], ],
) )
tf_cc_test(
name = "shuffle_dataset_op_test",
size = "small",
srcs = ["shuffle_dataset_op_test.cc"],
deps = [
"shuffle_dataset_op",
":dataset_test_base",
":dataset_utils",
":iterator_ops",
":range_dataset_op",
"//tensorflow/core:framework",
"//tensorflow/core:ptr_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_kernel_library( tf_kernel_library(
name = "sparse_tensor_slice_dataset_op", name = "sparse_tensor_slice_dataset_op",
srcs = ["sparse_tensor_slice_dataset_op.cc"], srcs = ["sparse_tensor_slice_dataset_op.cc"],

View File

@ -18,12 +18,39 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
template <typename T>
Status IsEqual(const Tensor& t1, const Tensor& t2) {
if (t1.dtype() != t2.dtype()) {
return tensorflow::errors::Internal(
"Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
" vs. ", DataTypeString(t2.dtype()));
}
if (!t1.IsSameSize(t2)) {
return tensorflow::errors::Internal(
"Two tensors have different shapes: ", t1.shape().DebugString(),
" vs. ", t2.shape().DebugString());
}
auto flat_t1 = t1.flat<T>();
auto flat_t2 = t2.flat<T>();
auto length = flat_t1.size();
for (int i = 0; i < length; ++i) {
if (flat_t1(i) != flat_t2(i)) {
return tensorflow::errors::Internal(
"Two tensors have different values "
"at [",
i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
}
}
return Status::OK();
}
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) { Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
EXPECT_EQ(a.dtype(), b.dtype());
switch (a.dtype()) { switch (a.dtype()) {
#define CASE(type) \ #define CASE(DT) \
case DataTypeToEnum<type>::value: \ case DataTypeToEnum<DT>::value: \
test::ExpectTensorEqual<type>(a, b); \ TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
break; break;
TF_CALL_NUMBER_TYPES(CASE); TF_CALL_NUMBER_TYPES(CASE);
TF_CALL_string(CASE); TF_CALL_string(CASE);
@ -36,7 +63,7 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
} }
template <typename T> template <typename T>
bool compare(Tensor t1, Tensor t2) { bool compare(const Tensor& t1, const Tensor& t2) {
auto flat_t1 = t1.flat<T>(); auto flat_t1 = t1.flat<T>();
auto flat_t2 = t2.flat<T>(); auto flat_t2 = t2.flat<T>();
auto length = std::min(flat_t1.size(), flat_t2.size()); auto length = std::min(flat_t1.size(), flat_t2.size());
@ -49,7 +76,7 @@ bool compare(Tensor t1, Tensor t2) {
Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors, Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
std::vector<Tensor> expected_tensors, std::vector<Tensor> expected_tensors,
bool expect_items_equal) { bool compare_order) {
if (produced_tensors.size() != expected_tensors.size()) { if (produced_tensors.size() != expected_tensors.size()) {
return Status(tensorflow::errors::Internal( return Status(tensorflow::errors::Internal(
"The two tensor vectors have different size (", produced_tensors.size(), "The two tensor vectors have different size (", produced_tensors.size(),
@ -64,7 +91,7 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
")")); ")"));
} }
if (expect_items_equal) { if (!compare_order) {
const DataType& dtype = produced_tensors[0].dtype(); const DataType& dtype = produced_tensors[0].dtype();
switch (dtype) { switch (dtype) {
#define CASE(DT) \ #define CASE(DT) \
@ -190,6 +217,7 @@ Status DatasetOpsTestBase::CreateIteratorContext(
OpKernelContext* const op_context, OpKernelContext* const op_context,
std::unique_ptr<IteratorContext>* iterator_context) { std::unique_ptr<IteratorContext>* iterator_context) {
IteratorContext::Params params(op_context); IteratorContext::Params params(op_context);
params.resource_mgr = op_context->resource_manager();
function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_); function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
params.function_handle_cache = function_handle_cache_.get(); params.function_handle_cache = function_handle_cache_.get();
*iterator_context = absl::make_unique<IteratorContext>(params); *iterator_context = absl::make_unique<IteratorContext>(params);
@ -228,6 +256,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices)); options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");
FunctionDefLibrary proto; FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef; for (const auto& fdef : flib) *(proto.add_function()) = fdef;
@ -269,6 +298,7 @@ Status DatasetOpsTestBase::CreateOpKernelContext(
step_container_ = step_container_ =
absl::make_unique<ScopedStepContainer>(0, [](const string&) {}); absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
params_->step_container = step_container_.get(); params_->step_container = step_container_.get();
params_->resource_manager = resource_mgr_.get();
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
slice_reader_cache_ = slice_reader_cache_ =
absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>(); absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();

View File

@ -52,11 +52,11 @@ class DatasetOpsTestBase : public ::testing::Test {
static Status ExpectEqual(const Tensor& a, const Tensor& b); static Status ExpectEqual(const Tensor& a, const Tensor& b);
// The method validates whether the two tensor vectors have the same tensors. // The method validates whether the two tensor vectors have the same tensors.
// If `expect_items_equal` is true, the method will only evaluate the two // If `compare_order` is false, the method will only evaluate whether the two
// vectors have the same elements regardless of order. // vectors have the same elements regardless of order.
static Status ExpectEqual(std::vector<Tensor> produced_tensors, static Status ExpectEqual(std::vector<Tensor> produced_tensors,
std::vector<Tensor> expected_tensors, std::vector<Tensor> expected_tensors,
bool expect_items_equal); bool compare_order);
// Creates a tensor with the specified dtype, shape, and value. // Creates a tensor with the specified dtype, shape, and value.
template <typename T> template <typename T>
@ -206,6 +206,7 @@ class DatasetOpsTestBase : public ::testing::Test {
std::function<void(std::function<void()>)> runner_; std::function<void(std::function<void()>)> runner_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ResourceMgr> resource_mgr_;
std::unique_ptr<OpKernelContext::Params> params_; std::unique_ptr<OpKernelContext::Params> params_;
std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper> std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
slice_reader_cache_; slice_reader_cache_;

View File

@ -494,7 +494,7 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, GetNext) {
} }
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*expect_items_equal*/ test_case.sloppy)); /*compare_order*/ !test_case.sloppy));
} }
TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) { TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
@ -949,7 +949,7 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Roundtrip) {
} }
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*expect_items_equal*/ test_case.sloppy)); /*compare_order*/ !test_case.sloppy));
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(

View File

@ -334,7 +334,7 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, GetNext) {
} }
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*expect_items_equal*/ test_case.sloppy)); /*compare_order*/ !test_case.sloppy));
} }
TEST_F(ParallelMapDatasetOpTest, DatasetNodeName) { TEST_F(ParallelMapDatasetOpTest, DatasetNodeName) {
@ -769,7 +769,7 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, Roundtrip) {
} }
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*expect_items_equal*/ test_case.sloppy)); /*compare_order*/ !test_case.sloppy));
} }
TEST_F(ParallelMapDatasetOpTest, InvalidNumParallelCalls) { TEST_F(ParallelMapDatasetOpTest, InvalidNumParallelCalls) {

View File

@ -63,7 +63,15 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
return input_->output_shapes(); return input_->output_shapes();
} }
int64 Cardinality() const override { return input_->Cardinality(); } int64 Cardinality() const override {
if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) {
return kInfiniteCardinality;
} else if (input_->Cardinality() == kUnknownCardinality) {
return kUnknownCardinality;
} else {
return input_->Cardinality() * count_;
}
}
protected: protected:
template <class T> template <class T>
@ -645,6 +653,10 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
int64 count; int64 count;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count)); OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
OP_REQUIRES(ctx, count > 0 || count == -1,
errors::InvalidArgument(
"count must be greater than zero or equal to -1."));
// By TensorFlow convention, if both seeds are 0, then shuffling should be // By TensorFlow convention, if both seeds are 0, then shuffling should be
// seeded non-deterministically. // seeded non-deterministically.
if (seed == 0 && seed2 == 0) { if (seed == 0 && seed2 == 0) {

View File

@ -0,0 +1,915 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kShuffleNodeName[] = "shuffle_dataset";
constexpr char kShuffleOpName[] = "ShuffleDataset";
constexpr char kShuffleAndRepeatNodeName[] = "shuffle_and_repeat_dataset";
constexpr char kShuffleAndRepeatOpName[] = "ShuffleAndRepeatDataset";
class ShuffleDatasetOpTest : public DatasetOpsTestBase {
protected:
// Creates a new `ShuffleDataset`/`ShuffleAndRepeatDataset` op kernel
Status CreateDatasetOpKernel(
int64 count, bool reshuffle_each_iteration,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
std::unique_ptr<OpKernel>* shuffle_dataset_kernel) {
NodeDef node_def;
if (count == 1) {
node_def = test::function::NDef(
kShuffleNodeName, kShuffleOpName,
{"input_dataset", "buffer_size", "seed", "seed2"},
{{"reshuffle_each_iteration", reshuffle_each_iteration},
{"output_types", output_types},
{"output_shapes", output_shapes}});
} else {
node_def = test::function::NDef(
kShuffleAndRepeatNodeName, kShuffleAndRepeatOpName,
{"input_dataset", "buffer_size", "seed", "seed2", "count"},
{{"output_types", output_types}, {"output_shapes", output_shapes}});
}
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, shuffle_dataset_kernel));
return Status::OK();
}
// Creates a new `ShuffleDataset`/`ShuffleAndRepeatDataset` op kernel context.
Status CreateDatasetContext(OpKernel* const op_kernel,
gtl::InlinedVector<TensorValue, 4>* const inputs,
std::unique_ptr<OpKernelContext>* context) {
TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK();
}
};
struct RangeDatasetParam {
int64 start;
int64 end;
int64 step;
};
struct TestCase {
RangeDatasetParam range_data_param;
Tensor buffer_size;
Tensor seed;
Tensor seed2;
Tensor count;
bool reshuffle_each_iteration;
std::vector<Tensor> expected_shuffle_outputs;
std::vector<Tensor> expected_reshuffle_outputs;
DataTypeVector expected_output_dtypes;
std::vector<PartialTensorShape> expected_output_shapes;
int64 expected_cardinality;
std::vector<int> breakpoints;
};
template <typename T>
std::vector<Tensor> ConvertToTensorVec(std::vector<T> values) {
std::vector<Tensor> tensors;
tensors.reserve(values.size());
for (auto& value : values) {
tensors.emplace_back(
DatasetOpsTestBase::CreateTensor<T>(TensorShape({}), {value}));
}
return tensors;
}
// Test case 1: test shuffle_dataset with reshuffle_each_iteration = false.
TestCase TestCase1() {
return {
/*range_data_param*/ {0, 10, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*reshuffle_each_iteration*/ false,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>({2, 3, 0, 5, 6, 4, 7, 8, 9, 1}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>({2, 3, 0, 5, 6, 4, 7, 8, 9, 1}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 10,
/*breakpoints*/ {0, 1, 9}};
}
// Test case 2: test shuffle_dataset with reshuffle_each_iteration = true.
TestCase TestCase2() {
return {
/*range_data_param*/ {0, 10, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*reshuffle_each_iteration*/ true,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>({2, 6, 1, 3, 9, 5, 0, 8, 7, 4}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>({1, 6, 0, 5, 2, 7, 4, 3, 9, 8}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 10,
/*breakpoints*/ {0, 1, 9}};
}
// Test case 3: similar with the test case 2 but a smaller buffer size than
// the input dataset.
TestCase TestCase3() {
return {
/*range_data_param*/ {0, 10, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*reshuffle_each_iteration*/ true,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>({0, 2, 1, 3, 5, 6, 4, 7, 8, 9}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>({1, 0, 2, 3, 4, 5, 6, 7, 9, 8}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 10,
/*breakpoints*/ {0, 1, 9}};
}
// Test case 4: similar with the test case 2 but has different seeds.
TestCase TestCase4() {
return {
/*range_data_param*/ {0, 10, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*reshuffle_each_iteration*/ true,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>({3, 0, 8, 1, 5, 4, 7, 2, 6, 9}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>({4, 6, 9, 0, 1, 8, 2, 7, 3, 5}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 10,
/*breakpoints*/ {0, 1, 9}};
}
// Test case 5: test shuffle_dataset with buffer_size = 1 &
// reshuffle_each_iteration = true.
TestCase TestCase5() {
return {
/*range_data_param*/ {0, 10, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*reshuffle_each_iteration*/ true,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 10,
/*breakpoints*/ {0, 1, 9}};
}
// Test case 6: test shuffle_dataset with an empty input dataset.
TestCase TestCase6() {
return {
/*range_data_param*/ {0, 0, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*reshuffle_each_iteration*/ true,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>({}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>({}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 0,
/*breakpoints*/ {0, 1, 9}};
}
// Test case 7: test shuffle_and_repeat_dataset with buffer_size = 10 &
// count = 2.
TestCase TestCase7() {
return {
/*range_data_param*/ {0, 10, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*reshuffle_each_iteration*/ false,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>(
{9, 0, 8, 6, 1, 3, 7, 2, 4, 5, 4, 3, 0, 5, 8, 2, 6, 9, 7, 1}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>(
{9, 0, 8, 6, 1, 3, 7, 2, 4, 5, 4, 3, 0, 5, 8, 2, 6, 9, 7, 1}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 20,
/*breakpoints*/ {0, 5, 22}};
}
// Test case 8: test shuffle_and_repeat_dataset with buffer_size = 10 &
// count = -1
TestCase TestCase8() {
return {
/*range_data_param*/ {0, 3, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
/*reshuffle_each_iteration*/ false,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>(
{2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2, 1, 0}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>(
{2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2, 1, 0}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ kInfiniteCardinality,
/*breakpoints*/ {0, 5, 20}};
}
TestCase InvalidBufferSizeTestCaseForShuffleDataset() {
return {
/*range_data_param*/ {0, 10, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*reshuffle_each_iteration*/ true,
/*expected_shuffle_outputs*/ ConvertToTensorVec<int64>({}),
/*expected_reshuffle_outputs*/ ConvertToTensorVec<int64>({}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 0,
/*breakpoints*/ {0, 1, 9}};
}
TestCase InvalidBufferSizeTestCaseForShuffleAndRepeatDataset() {
return {
/*range_data_param*/ {0, 10, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {-1}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*reshuffle_each_iteration*/ true,
/*expected_shuffle_outputs*/ ConvertToTensorVec<int64>({}),
/*expected_reshuffle_outputs*/ ConvertToTensorVec<int64>({}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 0,
/*breakpoints*/ {0, 1, 9}};
}
TestCase InvalidCountTestCaseForShuffleAndRepeatDataset() {
return {
/*range_data_param*/ {0, 3, 1},
/*buffer_size*/
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
/*seed*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
/*seed2*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
/*count*/ DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
/*reshuffle_each_iteration*/ false,
/*expected_shuffle_outputs*/
ConvertToTensorVec<int64>({}),
/*expected_reshuffle_outputs*/
ConvertToTensorVec<int64>({}),
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 0,
/*breakpoints*/ {0, 5, 20}};
}
class ParameterizedShuffleDatasetOpTest
: public ShuffleDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParameterizedShuffleDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> shuffled_out_tensors;
while (!end_of_sequence) {
std::vector<Tensor> next;
TF_EXPECT_OK(
iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
shuffled_out_tensors.insert(shuffled_out_tensors.end(), next.begin(),
next.end());
// For the forever-repeat case, we test only a finite number of steps of
// the infinite sequence.
if (count_value == -1 && shuffled_out_tensors.size() ==
test_case.expected_shuffle_outputs.size()) {
break;
}
}
// Reshuffle the dataset.
end_of_sequence = false;
TF_ASSERT_OK(
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
std::vector<Tensor> reshuffled_out_tensors;
while (!end_of_sequence) {
std::vector<Tensor> next;
TF_EXPECT_OK(
iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
reshuffled_out_tensors.insert(reshuffled_out_tensors.end(), next.begin(),
next.end());
// For the forever-repeat case, we test only a finite number of steps of
// the infinite sequence.
if (count_value == -1 && reshuffled_out_tensors.size() ==
test_case.expected_shuffle_outputs.size()) {
break;
}
}
TF_EXPECT_OK(ExpectEqual(shuffled_out_tensors,
test_case.expected_shuffle_outputs,
/*compare_order*/ true));
TF_EXPECT_OK(ExpectEqual(reshuffled_out_tensors,
test_case.expected_reshuffle_outputs,
/*compare_order*/ true));
}
TEST_P(ParameterizedShuffleDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
if (count_value == 1) {
EXPECT_EQ(dataset->node_name(), kShuffleNodeName);
} else {
EXPECT_EQ(dataset->node_name(), kShuffleAndRepeatNodeName);
}
}
TEST_P(ParameterizedShuffleDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
if (count_value == 1) {
EXPECT_EQ(dataset->type_string(), kShuffleOpName);
} else {
EXPECT_EQ(dataset->type_string(), kShuffleAndRepeatOpName);
}
}
TEST_P(ParameterizedShuffleDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
TF_EXPECT_OK(VerifyTypesMatch(dataset->output_dtypes(),
test_case.expected_output_dtypes));
}
TEST_P(ParameterizedShuffleDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
TF_EXPECT_OK(VerifyShapesCompatible(dataset->output_shapes(),
test_case.expected_output_shapes));
}
TEST_P(ParameterizedShuffleDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedShuffleDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
test_case.expected_output_dtypes));
}
TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
test_case.expected_output_shapes));
}
TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
if (count_value == 1) {
EXPECT_EQ(iterator->prefix(), "Iterator::Shuffle");
} else {
EXPECT_EQ(iterator->prefix(), "Iterator::ShuffleAndRepeat");
}
}
TEST_P(ParameterizedShuffleDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_iteration = 0;
const std::vector<int>& breakpoints = test_case.breakpoints;
for (int breakpoint : breakpoints) {
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*dataset, &iterator));
while (cur_iteration <= breakpoint) {
std::vector<Tensor> next;
TF_EXPECT_OK(
iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
cur_iteration++;
}
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_shuffle_outputs,
/*compare_order*/ true));
}
INSTANTIATE_TEST_SUITE_P(ShuffleDatasetOpTest,
ParameterizedShuffleDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{TestCase1(), TestCase2(), TestCase3(),
TestCase4(), TestCase5(), TestCase6(),
TestCase7(), TestCase8()})));
TEST_F(ShuffleDatasetOpTest, InvalidArguments) {
int thread_num = 2, cpu_num = 2;
std::vector<TestCase> test_cases = {
InvalidBufferSizeTestCaseForShuffleDataset(),
InvalidBufferSizeTestCaseForShuffleAndRepeatDataset(),
InvalidCountTestCaseForShuffleAndRepeatDataset()};
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
for (const auto& test_case : test_cases) {
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateDatasetOpKernel(
count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes, test_case.expected_output_shapes,
&dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{&range_dataset_tensor, &buffer_size, &seed, &seed2});
if (count_value != 1) inputs.push_back(&count);
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* shuffle_dataset;
EXPECT_EQ(CreateDataset(dataset_kernel.get(), dataset_context.get(),
&shuffle_dataset)
.code(),
tensorflow::error::INVALID_ARGUMENT);
}
}
} // namespace
} // namespace data
} // namespace tensorflow