diff --git a/tensorflow/core/kernels/data/take_dataset_op_test.cc b/tensorflow/core/kernels/data/take_dataset_op_test.cc index d8c68472ec0..afe22726552 100644 --- a/tensorflow/core/kernels/data/take_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/take_dataset_op_test.cc @@ -80,8 +80,7 @@ TestCase TakeLessTestCase() { DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2}), DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ - {PartialTensorShape({1})}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, /*expected_cardinality*/ 4, /*breakpoints*/ {0, 2, 5}}; } @@ -104,8 +103,7 @@ TestCase TakeMoreTestCase() { DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ - {PartialTensorShape({1})}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, /*expected_cardinality*/ 10, /*breakpoints*/ {0, 2, 5, 11}}; } @@ -128,8 +126,7 @@ TestCase TakeAllTestCase() { DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ - {PartialTensorShape({1})}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, /*expected_cardinality*/ -1, /*breakpoints*/ {0, 2, 5, 11}}; } @@ -140,20 +137,18 @@ TestCase TakeNothingTestCase() { {DatasetOpsTestBase::CreateTensor( TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, /*count*/ 0, - /*expected_outputs*/ - {}, + /*expected_outputs*/ {}, /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ - {PartialTensorShape({1})}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, /*expected_cardinality*/ 0, /*breakpoints*/ {0, 2, 5, 11}}; } -class ParametrizedTakeDatasetOpTest +class ParameterizedTakeDatasetOpTest : public TakeDatasetOpTest, public ::testing::WithParamInterface {}; -TEST_P(ParametrizedTakeDatasetOpTest, GetNext) { +TEST_P(ParameterizedTakeDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -204,7 +199,37 @@ TEST_P(ParametrizedTakeDatasetOpTest, GetNext) { EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -TEST_F(TakeDatasetOpTest, DatasetName) { +TEST_F(TakeDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = TakeLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + EXPECT_EQ(take_dataset->node_name(), kNodeName); +} + +TEST_F(TakeDatasetOpTest, DatasetTypeString) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -234,7 +259,7 @@ TEST_F(TakeDatasetOpTest, DatasetName) { EXPECT_EQ(take_dataset->type_string(), kOpName); } -TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) { +TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -265,7 +290,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) { test_case.expected_output_dtypes)); } -TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) { +TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -296,7 +321,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) { test_case.expected_output_shapes)); } -TEST_P(ParametrizedTakeDatasetOpTest, Cardinality) { +TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -361,7 +386,7 @@ TEST_F(TakeDatasetOpTest, DatasetSave) { TF_ASSERT_OK(writer.Flush()); } -TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) { +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -399,7 +424,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) { test_case.expected_output_dtypes)); } -TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) { +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -437,7 +462,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) { test_case.expected_output_shapes)); } -TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) { +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputPrefix) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -478,7 +503,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) { } } -TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) { +TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -550,7 +575,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) { } } -INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParametrizedTakeDatasetOpTest, +INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParameterizedTakeDatasetOpTest, ::testing::ValuesIn(std::vector( {TakeLessTestCase(), TakeMoreTestCase(), TakeAllTestCase(), TakeNothingTestCase()})));