From 8dc85de492592bca6eb08aa18e07cd1920385dad Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Fri, 15 Mar 2019 16:46:03 -0700 Subject: [PATCH] Refactor RangeDatasetOpTest --- .../kernels/data/range_dataset_op_test.cc | 579 +++++++++++------- 1 file changed, 342 insertions(+), 237 deletions(-) diff --git a/tensorflow/core/kernels/data/range_dataset_op_test.cc b/tensorflow/core/kernels/data/range_dataset_op_test.cc index bfe091fd524..dd589265a74 100644 --- a/tensorflow/core/kernels/data/range_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/range_dataset_op_test.cc @@ -13,237 +13,324 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/kernels/data/dataset_test_base.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" -#include "tensorflow/core/kernels/data/iterator_ops.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { namespace { +constexpr char kNodeName[] = "range_dataset"; constexpr char kOpName[] = "RangeDataset"; class RangeDatasetOpTest : public DatasetOpsTestBase { protected: // Creates a new RangeDataset op kernel context. Status CreateRangeDatasetContext( - int64 start, int64 end, int64 step, OpKernel* const range_kernel, + OpKernel* const range_kernel, + gtl::InlinedVector* const inputs, std::unique_ptr* range_context) { - inputs_.clear(); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {start})); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {end})); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {step})); - + TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, *inputs)); TF_RETURN_IF_ERROR( - CreateOpKernelContext(range_kernel, &inputs_, range_context)); - TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, inputs_)); + CreateOpKernelContext(range_kernel, inputs, range_context)); return Status::OK(); } - - private: - gtl::InlinedVector inputs_; }; -struct GetNextTestParams { - explicit GetNextTestParams(int64 input_start, int64 input_end, - int64 input_step) - : start(input_start), end(input_end), step(input_step) {} - +struct TestCase { int64 start; int64 end; int64 step; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; }; -struct DatasetGetNextTest : RangeDatasetOpTest, - ::testing::WithParamInterface {}; +TestCase PositiveStepTestCase() { + return {/*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 4}}; +} -TEST_P(DatasetGetNextTest, GetNext) { +TestCase NegativeStepTestCase() { + return {/*start*/ 10, + /*end*/ 0, + /*step*/ -3, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {7}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 4}}; +} + +TestCase ZeroStepTestCase() { + return {/*start*/ 0, + /*end*/ 10, + /*step*/ 0, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {}, + /*expected_output_shapes*/ {}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} + +class ParameterizedRangeDatasetOpTest + : public RangeDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedRangeDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; - GetNextTestParams params = GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector out_tensors; while (!end_of_sequence) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } } - std::vector expected_values; - for (int i = params.start; (params.end - i) * params.step > 0; - i = i + params.step) { - expected_values.reserve(1); - expected_values.emplace_back(i); - } - EXPECT_EQ(out_tensors.size(), expected_values.size()); - for (size_t i = 0; i < out_tensors.size(); ++i) { - int64 actual_value = out_tensors[i].flat()(0); - int64 expect_value = expected_values[i]; - EXPECT_EQ(actual_value, expect_value); - } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetGetNextTest, - ::testing::Values(GetNextTestParams(0, 10, 1), - GetNextTestParams(0, 10, 3), - GetNextTestParams(10, 0, -1), - GetNextTestParams(10, 0, -3))); - -TEST_F(RangeDatasetOpTest, DatasetName) { - int64 start = 0, end = 10, step = 1; +TEST_F(RangeDatasetOpTest, ZeroStep) { int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = ZeroStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + EXPECT_EQ(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(RangeDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; + TF_ASSERT_OK( + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); + core::ScopedUnref scoped_unref(range_dataset); + + EXPECT_EQ(range_dataset->node_name(), kNodeName); +} + +TEST_F(RangeDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; + TF_ASSERT_OK( + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); EXPECT_EQ(range_dataset->type_string(), kOpName); } TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(range_dataset->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(range_dataset->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(RangeDatasetOpTest, DatasetOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < range_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE( - range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(range_dataset->output_shapes(), + test_case.expected_output_shapes)); } -struct CardinalityTestParams { - explicit CardinalityTestParams(int64 input_start, int64 input_end, - int64 input_step, - int input_expected_cardinality) - : start(input_start), - end(input_end), - step(input_step), - expected_cardinality(input_expected_cardinality) {} - - int64 start; - int64 end; - int64 step; - int expected_cardinality; -}; - -struct DatasetCardinalityTest - : RangeDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetCardinalityTest, Cardinality) { +TEST_P(ParameterizedRangeDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; - CardinalityTestParams params = GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - EXPECT_EQ(range_dataset->Cardinality(), params.expected_cardinality); + EXPECT_EQ(range_dataset->Cardinality(), test_case.expected_cardinality); } -INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetCardinalityTest, - ::testing::Values(CardinalityTestParams(0, 10, 1, 10), - CardinalityTestParams(0, 10, 3, 4), - CardinalityTestParams(10, 0, -3, 4))); - TEST_F(RangeDatasetOpTest, DatasetSave) { int64 thread_num = 2, cpu_num = 2; - int start = 0, end = 10, step = 1; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr serialization_context; @@ -256,81 +343,105 @@ TEST_F(RangeDatasetOpTest, DatasetSave) { } TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(iterator->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(RangeDatasetOpTest, IteratorOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < range_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); } TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); @@ -338,83 +449,77 @@ TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), "Iterator::Range"); } -struct RoundtripTestParams { - explicit RoundtripTestParams(int64 input_start, int64 input_end, - int64 input_step, int input_breakpoint) - : start(input_start), - end(input_end), - step(input_step), - breakpoint(input_breakpoint) {} - - int64 start; - int64 end; - int64 step; - int breakpoint; -}; - -struct IteratorRoundtripTest - : RangeDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorRoundtripTest, Roundtrip) { +TEST_P(ParameterizedRangeDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; - RoundtripTestParams params = GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector out_tensors; + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; - int64 cur_val = params.start - params.step; - for (int i = 0; i < params.breakpoint; i++) { - if (!end_of_sequence) { + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader)); + + while (cur_iteration <= breakpoint) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); - cur_val = ((params.end - cur_val - params.step) * params.step > 0) - ? cur_val + params.step - : cur_val; + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_cardinality) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); } } - - std::unique_ptr serialization_context; - TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); - VariantTensorData data; - VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); - TF_ASSERT_OK(writer.Flush()); - VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - int64 expect_next = ((params.end - cur_val - params.step) * params.step > 0) - ? cur_val + params.step - : cur_val; - EXPECT_EQ(out_tensors.back().flat()(0), expect_next); } -INSTANTIATE_TEST_CASE_P( - RangeDatasetOpTest, IteratorRoundtripTest, - ::testing::Values( - RoundtripTestParams(0, 10, 2, 0), // unused_iterator - RoundtripTestParams(0, 10, 2, 4), // fully_used_iterator_increase - RoundtripTestParams(10, 0, -2, 4), // fully_used_iterator_decrease - RoundtripTestParams(0, 10, 2, 6))); // exhausted_iterator +INSTANTIATE_TEST_SUITE_P( + RangeDatasetOpTest, ParameterizedRangeDatasetOpTest, + ::testing::ValuesIn(std::vector({PositiveStepTestCase(), + NegativeStepTestCase()}))); } // namespace } // namespace data