diff --git a/tensorflow/core/kernels/data/map_dataset_op_test.cc b/tensorflow/core/kernels/data/map_dataset_op_test.cc index b0d17ab2865..457743c220b 100644 --- a/tensorflow/core/kernels/data/map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/map_dataset_op_test.cc @@ -49,7 +49,7 @@ class MapDatasetOpTest : public DatasetOpsTestBase { FunctionDefHelper::AttrValueWrapper func = FunctionDefHelper::FunctionRef(func_name, {{"T", DT_INT64}}); - map_node_def_ = test::function::NDef( + NodeDef map_dataset_node_def = test::function::NDef( kNodeName, kOpName, {input_dataset}, {{"f", func}, {"Targuments", {}}, @@ -58,136 +58,198 @@ class MapDatasetOpTest : public DatasetOpsTestBase { gtl::ArraySlice{tensorflow::DataTypeToEnum::value}}, {"use_inter_op_parallelism", true}, {"preserve_cardinality", false}}); - TF_CHECK_OK(CreateOpKernel(map_node_def_, map_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(map_dataset_node_def, map_kernel)); return Status::OK(); } // Creates a new MapDataset op kernel context. Status CreateMapDatasetContext( - DatasetBase* const input_dataset, OpKernel* const map_kernel, + OpKernel* const map_kernel, gtl::InlinedVector* inputs, std::unique_ptr* map_context) { - map_inputs_.clear(); - // Save the input dataset into a variant tensor as the input of MapDataset. - Tensor dataset_tensor(DT_VARIANT, TensorShape({})); - TF_RETURN_IF_ERROR( - StoreDatasetInVariantTensor(input_dataset, &dataset_tensor)); - Variant variant = dataset_tensor.scalar()(); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &map_inputs_, map_kernel->input_types(), TensorShape({}), {variant})); - input_dataset->Ref(); - TF_RETURN_IF_ERROR( - CreateOpKernelContext(map_kernel, &map_inputs_, map_context)); - TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, map_inputs_)); + TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(map_kernel, inputs, map_context)); return Status::OK(); } - - private: - NodeDef map_node_def_; - gtl::InlinedVector map_inputs_; }; -struct GetNextTestParams { - explicit GetNextTestParams(int64 input_start, int64 input_end, - int64 input_step, string input_func_name, - std::vector input_expected_values, - std::vector input_func_lib) - : start(input_start), - end(input_end), - step(input_step), - func_name(std::move(input_func_name)), - expected_values(std::move(input_expected_values)), - func_lib(std::move(input_func_lib)) {} - +struct TestCase { int64 start; int64 end; int64 step; string func_name; - std::vector expected_values; std::vector func_lib; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; }; -struct DatasetGetNextTest : MapDatasetOpTest, - ::testing::WithParamInterface {}; +TestCase TestCase1() { + return {/*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*func_name*/ "XTimesTwo", + /*func_lib*/ {test::function::XTimesTwo()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {18})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} -TEST_P(DatasetGetNextTest, GetNext) { +TestCase TestCase2() { + return {/*start*/ 10, + /*end*/ 0, + /*step*/ -3, + /*func_name*/ "XAddX", + /*func_lib*/ {test::function::XAddX()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {20}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {14}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// In this test case, the function `XTimesFour()` will call `XTimesTwo()`, so +// both of them are added to the function library. +TestCase TestCase3() { + return { + /*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*func_name*/ "XTimesFour", + /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {24}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {36})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +class ParameterizedMapDatasetOpTest + : public MapDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedMapDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; - GetNextTestParams test_params = GetParam(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scored_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), test_params.func_name, &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); + bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector out_tensors; while (!end_of_sequence) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } } - - EXPECT_EQ(out_tensors.size(), test_params.expected_values.size()); - for (size_t i = 0; i < out_tensors.size(); ++i) { - int64 actual_value = out_tensors[i].flat()(0); - int64 expect_value = test_params.expected_values[i]; - EXPECT_EQ(actual_value, expect_value); - } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -INSTANTIATE_TEST_CASE_P( - MapDatasetOpTest, DatasetGetNextTest, - ::testing::Values( - GetNextTestParams( - 0, 10, 3, "XTimesTwo", std::vector{0, 6, 12, 18}, - std::vector{test::function::XTimesTwo()}), - GetNextTestParams(0, 10, 3, "XAddX", std::vector{0, 6, 12, 18}, - std::vector{test::function::XAddX()}), - GetNextTestParams( - 10, 0, -3, "XTimesFour", std::vector{40, 28, 16, 4}, - std::vector{test::function::XTimesTwo(), - test::function::XTimesFour()}))); - -TEST_F(MapDatasetOpTest, DatasetName) { +TEST_F(MapDatasetOpTest, DatasetNodeName) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(map_dataset); + + EXPECT_EQ(map_dataset->node_name(), kNodeName); +} + +TEST_F(MapDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); + + std::unique_ptr map_dataset_kernel; + TF_ASSERT_OK(CreateMapDatasetOpKernel( + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); + DatasetBase* map_dataset; + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); EXPECT_EQ(map_dataset->type_string(), kOpName); @@ -195,138 +257,125 @@ TEST_F(MapDatasetOpTest, DatasetName) { TEST_F(MapDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(map_dataset->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(map_dataset->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(MapDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < map_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE( - map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(map_dataset->output_shapes(), + test_case.expected_output_shapes)); } -struct CardinalityTestParams { - explicit CardinalityTestParams(int64 input_start, int64 input_end, - int64 input_step, - int input_expected_cardinality) - : start(input_start), - end(input_end), - step(input_step), - expected_cardinality(input_expected_cardinality) {} - - int64 start; - int64 end; - int64 step; - int expected_cardinality; -}; - -struct DatasetCardinalityTest - : MapDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetCardinalityTest, Cardinality) { +TEST_P(ParameterizedMapDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; - CardinalityTestParams test_params = GetParam(); - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - EXPECT_EQ(map_dataset->Cardinality(), test_params.expected_cardinality); + EXPECT_EQ(map_dataset->Cardinality(), test_case.expected_cardinality); } -INSTANTIATE_TEST_CASE_P(MapDatasetOpTest, DatasetCardinalityTest, - ::testing::Values(CardinalityTestParams(0, 10, 1, 10), - CardinalityTestParams(0, 10, 3, 4), - CardinalityTestParams(10, 0, -3, 4))); - -TEST_F(MapDatasetOpTest, DatasetSave) { +TEST_P(ParameterizedMapDatasetOpTest, DatasetSave) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr serialization_context; @@ -338,101 +387,114 @@ TEST_F(MapDatasetOpTest, DatasetSave) { } TEST_F(MapDatasetOpTest, IteratorOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(iterator->output_dtypes(), expected_dtypes); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(MapDatasetOpTest, IteratorOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < map_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); } TEST_F(MapDatasetOpTest, IteratorOutputPrefix) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); @@ -440,95 +502,79 @@ TEST_F(MapDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), "Iterator::Map"); } -struct RoundtripTestParams { - explicit RoundtripTestParams(int64 input_start, int64 input_end, - int64 input_step, int input_breakpoint, - int64 input_expected_value, - string input_func_name, - std::vector input_func_lib) - : start(input_start), - end(input_end), - step(input_step), - breakpoint(input_breakpoint), - expected_value(input_expected_value), - func_name(std::move(input_func_name)), - func_lib(std::move(input_func_lib)) {} - - int64 start; - int64 end; - int64 step; - int breakpoint; - int64 expected_value; - string func_name; - std::vector func_lib; -}; - -struct IteratorRoundtripTest - : MapDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorRoundtripTest, Roundtrip) { +TEST_P(ParameterizedMapDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; - RoundtripTestParams test_params = GetParam(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transfered to DatasetVariantWrapper, + // which will handle the release of memory. + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), test_params.func_name, &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector out_tensors; + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; - for (int i = 0; i < test_params.breakpoint; i++) { - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - } + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader)); - std::unique_ptr serialization_context; - TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); - VariantTensorData data; - VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); - TF_ASSERT_OK(writer.Flush()); - VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - EXPECT_EQ(out_tensors.back().flat()(0), test_params.expected_value); + while (cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_cardinality) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); + } + } } -INSTANTIATE_TEST_CASE_P( - MapDatasetOpTest, IteratorRoundtripTest, - ::testing::Values(RoundtripTestParams(0, 10, 2, 0, 0, "XTimesTwo", - std::vector{ - test::function::XTimesTwo()}), - RoundtripTestParams(0, 10, 2, 4, 16, "XAddX", - std::vector{ - test::function::XAddX()}), - RoundtripTestParams(0, 10, 2, 6, 32, "XTimesFour", - std::vector{ - test::function::XTimesTwo(), - test::function::XTimesFour()}))); +INSTANTIATE_TEST_SUITE_P(MapDatasetOpTest, ParameterizedMapDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3()}))); } // namespace } // namespace data