diff --git a/tensorflow/core/kernels/data/map_dataset_op_test.cc b/tensorflow/core/kernels/data/map_dataset_op_test.cc index b0d17ab2865..457743c220b 100644 --- a/tensorflow/core/kernels/data/map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/map_dataset_op_test.cc @@ -49,7 +49,7 @@ class MapDatasetOpTest : public DatasetOpsTestBase { FunctionDefHelper::AttrValueWrapper func = FunctionDefHelper::FunctionRef(func_name, {{"T", DT_INT64}}); - map_node_def_ = test::function::NDef( + NodeDef map_dataset_node_def = test::function::NDef( kNodeName, kOpName, {input_dataset}, {{"f", func}, {"Targuments", {}}, @@ -58,136 +58,198 @@ class MapDatasetOpTest : public DatasetOpsTestBase { gtl::ArraySlice{tensorflow::DataTypeToEnum::value}}, {"use_inter_op_parallelism", true}, {"preserve_cardinality", false}}); - TF_CHECK_OK(CreateOpKernel(map_node_def_, map_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(map_dataset_node_def, map_kernel)); return Status::OK(); } // Creates a new MapDataset op kernel context. Status CreateMapDatasetContext( - DatasetBase* const input_dataset, OpKernel* const map_kernel, + OpKernel* const map_kernel, gtl::InlinedVector* inputs, std::unique_ptr* map_context) { - map_inputs_.clear(); - // Save the input dataset into a variant tensor as the input of MapDataset. - Tensor dataset_tensor(DT_VARIANT, TensorShape({})); - TF_RETURN_IF_ERROR( - StoreDatasetInVariantTensor(input_dataset, &dataset_tensor)); - Variant variant = dataset_tensor.scalar()(); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &map_inputs_, map_kernel->input_types(), TensorShape({}), {variant})); - input_dataset->Ref(); - TF_RETURN_IF_ERROR( - CreateOpKernelContext(map_kernel, &map_inputs_, map_context)); - TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, map_inputs_)); + TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(map_kernel, inputs, map_context)); return Status::OK(); } - - private: - NodeDef map_node_def_; - gtl::InlinedVector map_inputs_; }; -struct GetNextTestParams { - explicit GetNextTestParams(int64 input_start, int64 input_end, - int64 input_step, string input_func_name, - std::vector input_expected_values, - std::vector input_func_lib) - : start(input_start), - end(input_end), - step(input_step), - func_name(std::move(input_func_name)), - expected_values(std::move(input_expected_values)), - func_lib(std::move(input_func_lib)) {} - +struct TestCase { int64 start; int64 end; int64 step; string func_name; - std::vector expected_values; std::vector func_lib; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; }; -struct DatasetGetNextTest : MapDatasetOpTest, - ::testing::WithParamInterface {}; +TestCase TestCase1() { + return {/*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*func_name*/ "XTimesTwo", + /*func_lib*/ {test::function::XTimesTwo()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {18})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} -TEST_P(DatasetGetNextTest, GetNext) { +TestCase TestCase2() { + return {/*start*/ 10, + /*end*/ 0, + /*step*/ -3, + /*func_name*/ "XAddX", + /*func_lib*/ {test::function::XAddX()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {20}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {14}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// In this test case, the function `XTimesFour()` will call `XTimesTwo()`, so +// both of them are added to the function library. +TestCase TestCase3() { + return { + /*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*func_name*/ "XTimesFour", + /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {24}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {36})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +class ParameterizedMapDatasetOpTest + : public MapDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedMapDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; - GetNextTestParams test_params = GetParam(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scored_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), test_params.func_name, &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); + bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector out_tensors; while (!end_of_sequence) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } } - - EXPECT_EQ(out_tensors.size(), test_params.expected_values.size()); - for (size_t i = 0; i < out_tensors.size(); ++i) { - int64 actual_value = out_tensors[i].flat()(0); - int64 expect_value = test_params.expected_values[i]; - EXPECT_EQ(actual_value, expect_value); - } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -INSTANTIATE_TEST_CASE_P( - MapDatasetOpTest, DatasetGetNextTest, - ::testing::Values( - GetNextTestParams( - 0, 10, 3, "XTimesTwo", std::vector{0, 6, 12, 18}, - std::vector{test::function::XTimesTwo()}), - GetNextTestParams(0, 10, 3, "XAddX", std::vector{0, 6, 12, 18}, - std::vector{test::function::XAddX()}), - GetNextTestParams( - 10, 0, -3, "XTimesFour", std::vector{40, 28, 16, 4}, - std::vector{test::function::XTimesTwo(), - test::function::XTimesFour()}))); - -TEST_F(MapDatasetOpTest, DatasetName) { +TEST_F(MapDatasetOpTest, DatasetNodeName) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(map_dataset); + + EXPECT_EQ(map_dataset->node_name(), kNodeName); +} + +TEST_F(MapDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); + + std::unique_ptr map_dataset_kernel; + TF_ASSERT_OK(CreateMapDatasetOpKernel( + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); + DatasetBase* map_dataset; + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); EXPECT_EQ(map_dataset->type_string(), kOpName); @@ -195,138 +257,125 @@ TEST_F(MapDatasetOpTest, DatasetName) { TEST_F(MapDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(map_dataset->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(map_dataset->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(MapDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < map_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE( - map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(map_dataset->output_shapes(), + test_case.expected_output_shapes)); } -struct CardinalityTestParams { - explicit CardinalityTestParams(int64 input_start, int64 input_end, - int64 input_step, - int input_expected_cardinality) - : start(input_start), - end(input_end), - step(input_step), - expected_cardinality(input_expected_cardinality) {} - - int64 start; - int64 end; - int64 step; - int expected_cardinality; -}; - -struct DatasetCardinalityTest - : MapDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetCardinalityTest, Cardinality) { +TEST_P(ParameterizedMapDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; - CardinalityTestParams test_params = GetParam(); - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - EXPECT_EQ(map_dataset->Cardinality(), test_params.expected_cardinality); + EXPECT_EQ(map_dataset->Cardinality(), test_case.expected_cardinality); } -INSTANTIATE_TEST_CASE_P(MapDatasetOpTest, DatasetCardinalityTest, - ::testing::Values(CardinalityTestParams(0, 10, 1, 10), - CardinalityTestParams(0, 10, 3, 4), - CardinalityTestParams(10, 0, -3, 4))); - -TEST_F(MapDatasetOpTest, DatasetSave) { +TEST_P(ParameterizedMapDatasetOpTest, DatasetSave) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr serialization_context; @@ -338,101 +387,114 @@ TEST_F(MapDatasetOpTest, DatasetSave) { } TEST_F(MapDatasetOpTest, IteratorOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(iterator->output_dtypes(), expected_dtypes); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(MapDatasetOpTest, IteratorOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < map_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); } TEST_F(MapDatasetOpTest, IteratorOutputPrefix) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); @@ -440,95 +502,79 @@ TEST_F(MapDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), "Iterator::Map"); } -struct RoundtripTestParams { - explicit RoundtripTestParams(int64 input_start, int64 input_end, - int64 input_step, int input_breakpoint, - int64 input_expected_value, - string input_func_name, - std::vector input_func_lib) - : start(input_start), - end(input_end), - step(input_step), - breakpoint(input_breakpoint), - expected_value(input_expected_value), - func_name(std::move(input_func_name)), - func_lib(std::move(input_func_lib)) {} - - int64 start; - int64 end; - int64 step; - int breakpoint; - int64 expected_value; - string func_name; - std::vector func_lib; -}; - -struct IteratorRoundtripTest - : MapDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorRoundtripTest, Roundtrip) { +TEST_P(ParameterizedMapDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; - RoundtripTestParams test_params = GetParam(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), test_params.func_name, &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector out_tensors; + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; - for (int i = 0; i < test_params.breakpoint; i++) { - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - } + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader)); - std::unique_ptr serialization_context; - TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); - VariantTensorData data; - VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); - TF_ASSERT_OK(writer.Flush()); - VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - EXPECT_EQ(out_tensors.back().flat()(0), test_params.expected_value); + while (cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_cardinality) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); + } + } } -INSTANTIATE_TEST_CASE_P( - MapDatasetOpTest, IteratorRoundtripTest, - ::testing::Values(RoundtripTestParams(0, 10, 2, 0, 0, "XTimesTwo", - std::vector{ - test::function::XTimesTwo()}), - RoundtripTestParams(0, 10, 2, 4, 16, "XAddX", - std::vector{ - test::function::XAddX()}), - RoundtripTestParams(0, 10, 2, 6, 32, "XTimesFour", - std::vector{ - test::function::XTimesTwo(), - test::function::XTimesFour()}))); +INSTANTIATE_TEST_SUITE_P(MapDatasetOpTest, ParameterizedMapDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3()}))); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/range_dataset_op_test.cc b/tensorflow/core/kernels/data/range_dataset_op_test.cc index bfe091fd524..dd589265a74 100644 --- a/tensorflow/core/kernels/data/range_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/range_dataset_op_test.cc @@ -13,237 +13,324 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/kernels/data/dataset_test_base.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" -#include "tensorflow/core/kernels/data/iterator_ops.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { namespace { +constexpr char kNodeName[] = "range_dataset"; constexpr char kOpName[] = "RangeDataset"; class RangeDatasetOpTest : public DatasetOpsTestBase { protected: // Creates a new RangeDataset op kernel context. Status CreateRangeDatasetContext( - int64 start, int64 end, int64 step, OpKernel* const range_kernel, + OpKernel* const range_kernel, + gtl::InlinedVector* const inputs, std::unique_ptr* range_context) { - inputs_.clear(); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {start})); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {end})); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {step})); - + TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, *inputs)); TF_RETURN_IF_ERROR( - CreateOpKernelContext(range_kernel, &inputs_, range_context)); - TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, inputs_)); + CreateOpKernelContext(range_kernel, inputs, range_context)); return Status::OK(); } - - private: - gtl::InlinedVector inputs_; }; -struct GetNextTestParams { - explicit GetNextTestParams(int64 input_start, int64 input_end, - int64 input_step) - : start(input_start), end(input_end), step(input_step) {} - +struct TestCase { int64 start; int64 end; int64 step; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; }; -struct DatasetGetNextTest : RangeDatasetOpTest, - ::testing::WithParamInterface {}; +TestCase PositiveStepTestCase() { + return {/*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 4}}; +} -TEST_P(DatasetGetNextTest, GetNext) { +TestCase NegativeStepTestCase() { + return {/*start*/ 10, + /*end*/ 0, + /*step*/ -3, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {7}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 4}}; +} + +TestCase ZeroStepTestCase() { + return {/*start*/ 0, + /*end*/ 10, + /*step*/ 0, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {}, + /*expected_output_shapes*/ {}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} + +class ParameterizedRangeDatasetOpTest + : public RangeDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedRangeDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; - GetNextTestParams params = GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector out_tensors; while (!end_of_sequence) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } } - std::vector expected_values; - for (int i = params.start; (params.end - i) * params.step > 0; - i = i + params.step) { - expected_values.reserve(1); - expected_values.emplace_back(i); - } - EXPECT_EQ(out_tensors.size(), expected_values.size()); - for (size_t i = 0; i < out_tensors.size(); ++i) { - int64 actual_value = out_tensors[i].flat()(0); - int64 expect_value = expected_values[i]; - EXPECT_EQ(actual_value, expect_value); - } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetGetNextTest, - ::testing::Values(GetNextTestParams(0, 10, 1), - GetNextTestParams(0, 10, 3), - GetNextTestParams(10, 0, -1), - GetNextTestParams(10, 0, -3))); - -TEST_F(RangeDatasetOpTest, DatasetName) { - int64 start = 0, end = 10, step = 1; +TEST_F(RangeDatasetOpTest, ZeroStep) { int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = ZeroStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + EXPECT_EQ(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(RangeDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; + TF_ASSERT_OK( + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); + core::ScopedUnref scoped_unref(range_dataset); + + EXPECT_EQ(range_dataset->node_name(), kNodeName); +} + +TEST_F(RangeDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; + TF_ASSERT_OK( + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); EXPECT_EQ(range_dataset->type_string(), kOpName); } TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(range_dataset->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(range_dataset->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(RangeDatasetOpTest, DatasetOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < range_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE( - range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(range_dataset->output_shapes(), + test_case.expected_output_shapes)); } -struct CardinalityTestParams { - explicit CardinalityTestParams(int64 input_start, int64 input_end, - int64 input_step, - int input_expected_cardinality) - : start(input_start), - end(input_end), - step(input_step), - expected_cardinality(input_expected_cardinality) {} - - int64 start; - int64 end; - int64 step; - int expected_cardinality; -}; - -struct DatasetCardinalityTest - : RangeDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetCardinalityTest, Cardinality) { +TEST_P(ParameterizedRangeDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; - CardinalityTestParams params = GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - EXPECT_EQ(range_dataset->Cardinality(), params.expected_cardinality); + EXPECT_EQ(range_dataset->Cardinality(), test_case.expected_cardinality); } -INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetCardinalityTest, - ::testing::Values(CardinalityTestParams(0, 10, 1, 10), - CardinalityTestParams(0, 10, 3, 4), - CardinalityTestParams(10, 0, -3, 4))); - TEST_F(RangeDatasetOpTest, DatasetSave) { int64 thread_num = 2, cpu_num = 2; - int start = 0, end = 10, step = 1; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr serialization_context; @@ -256,81 +343,105 @@ TEST_F(RangeDatasetOpTest, DatasetSave) { } TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(iterator->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(RangeDatasetOpTest, IteratorOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < range_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); } TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); @@ -338,83 +449,77 @@ TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), "Iterator::Range"); } -struct RoundtripTestParams { - explicit RoundtripTestParams(int64 input_start, int64 input_end, - int64 input_step, int input_breakpoint) - : start(input_start), - end(input_end), - step(input_step), - breakpoint(input_breakpoint) {} - - int64 start; - int64 end; - int64 step; - int breakpoint; -}; - -struct IteratorRoundtripTest - : RangeDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorRoundtripTest, Roundtrip) { +TEST_P(ParameterizedRangeDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; - RoundtripTestParams params = GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector out_tensors; + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; - int64 cur_val = params.start - params.step; - for (int i = 0; i < params.breakpoint; i++) { - if (!end_of_sequence) { + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader)); + + while (cur_iteration <= breakpoint) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); - cur_val = ((params.end - cur_val - params.step) * params.step > 0) - ? cur_val + params.step - : cur_val; + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_cardinality) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); } } - - std::unique_ptr serialization_context; - TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); - VariantTensorData data; - VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); - TF_ASSERT_OK(writer.Flush()); - VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - int64 expect_next = ((params.end - cur_val - params.step) * params.step > 0) - ? cur_val + params.step - : cur_val; - EXPECT_EQ(out_tensors.back().flat()(0), expect_next); } -INSTANTIATE_TEST_CASE_P( - RangeDatasetOpTest, IteratorRoundtripTest, - ::testing::Values( - RoundtripTestParams(0, 10, 2, 0), // unused_iterator - RoundtripTestParams(0, 10, 2, 4), // fully_used_iterator_increase - RoundtripTestParams(10, 0, -2, 4), // fully_used_iterator_decrease - RoundtripTestParams(0, 10, 2, 6))); // exhausted_iterator +INSTANTIATE_TEST_SUITE_P( + RangeDatasetOpTest, ParameterizedRangeDatasetOpTest, + ::testing::ValuesIn(std::vector({PositiveStepTestCase(), + NegativeStepTestCase()}))); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/take_dataset_op_test.cc b/tensorflow/core/kernels/data/take_dataset_op_test.cc index d8c68472ec0..afe22726552 100644 --- a/tensorflow/core/kernels/data/take_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/take_dataset_op_test.cc @@ -80,8 +80,7 @@ TestCase TakeLessTestCase() { DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2}), DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ - {PartialTensorShape({1})}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, /*expected_cardinality*/ 4, /*breakpoints*/ {0, 2, 5}}; } @@ -104,8 +103,7 @@ TestCase TakeMoreTestCase() { DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ - {PartialTensorShape({1})}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, /*expected_cardinality*/ 10, /*breakpoints*/ {0, 2, 5, 11}}; } @@ -128,8 +126,7 @@ TestCase TakeAllTestCase() { DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ - {PartialTensorShape({1})}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, /*expected_cardinality*/ -1, /*breakpoints*/ {0, 2, 5, 11}}; } @@ -140,20 +137,18 @@ TestCase TakeNothingTestCase() { {DatasetOpsTestBase::CreateTensor( TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, /*count*/ 0, - /*expected_outputs*/ - {}, + /*expected_outputs*/ {}, /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ - {PartialTensorShape({1})}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, /*expected_cardinality*/ 0, /*breakpoints*/ {0, 2, 5, 11}}; } -class ParametrizedTakeDatasetOpTest +class ParameterizedTakeDatasetOpTest : public TakeDatasetOpTest, public ::testing::WithParamInterface {}; -TEST_P(ParametrizedTakeDatasetOpTest, GetNext) { +TEST_P(ParameterizedTakeDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -204,7 +199,37 @@ TEST_P(ParametrizedTakeDatasetOpTest, GetNext) { EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -TEST_F(TakeDatasetOpTest, DatasetName) { +TEST_F(TakeDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = TakeLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + EXPECT_EQ(take_dataset->node_name(), kNodeName); +} + +TEST_F(TakeDatasetOpTest, DatasetTypeString) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -234,7 +259,7 @@ TEST_F(TakeDatasetOpTest, DatasetName) { EXPECT_EQ(take_dataset->type_string(), kOpName); } -TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) { +TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -265,7 +290,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) { test_case.expected_output_dtypes)); } -TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) { +TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -296,7 +321,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) { test_case.expected_output_shapes)); } -TEST_P(ParametrizedTakeDatasetOpTest, Cardinality) { +TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -361,7 +386,7 @@ TEST_F(TakeDatasetOpTest, DatasetSave) { TF_ASSERT_OK(writer.Flush()); } -TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) { +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -399,7 +424,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) { test_case.expected_output_dtypes)); } -TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) { +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -437,7 +462,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) { test_case.expected_output_shapes)); } -TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) { +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputPrefix) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -478,7 +503,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) { } } -TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) { +TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -550,7 +575,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) { } } -INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParametrizedTakeDatasetOpTest, +INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParameterizedTakeDatasetOpTest, ::testing::ValuesIn(std::vector( {TakeLessTestCase(), TakeMoreTestCase(), TakeAllTestCase(), TakeNothingTestCase()})));