Add the tests of node_name() and type_string() for TakeDatasetOpTest

This commit is contained in:
Fei Hu 2019-03-15 16:48:35 -07:00
parent 301dfdac4c
commit 3b23e0ee4f

View File

@ -80,8 +80,7 @@ TestCase TakeLessTestCase() {
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}), DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})}, DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
/*expected_output_dtypes*/ {DT_INT64}, /*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ /*expected_output_shapes*/ {PartialTensorShape({1})},
{PartialTensorShape({1})},
/*expected_cardinality*/ 4, /*expected_cardinality*/ 4,
/*breakpoints*/ {0, 2, 5}}; /*breakpoints*/ {0, 2, 5}};
} }
@ -104,8 +103,7 @@ TestCase TakeMoreTestCase() {
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}), DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})}, DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
/*expected_output_dtypes*/ {DT_INT64}, /*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ /*expected_output_shapes*/ {PartialTensorShape({1})},
{PartialTensorShape({1})},
/*expected_cardinality*/ 10, /*expected_cardinality*/ 10,
/*breakpoints*/ {0, 2, 5, 11}}; /*breakpoints*/ {0, 2, 5, 11}};
} }
@ -128,8 +126,7 @@ TestCase TakeAllTestCase() {
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}), DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})}, DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
/*expected_output_dtypes*/ {DT_INT64}, /*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ /*expected_output_shapes*/ {PartialTensorShape({1})},
{PartialTensorShape({1})},
/*expected_cardinality*/ -1, /*expected_cardinality*/ -1,
/*breakpoints*/ {0, 2, 5, 11}}; /*breakpoints*/ {0, 2, 5, 11}};
} }
@ -140,20 +137,18 @@ TestCase TakeNothingTestCase() {
{DatasetOpsTestBase::CreateTensor<int64>( {DatasetOpsTestBase::CreateTensor<int64>(
TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
/*count*/ 0, /*count*/ 0,
/*expected_outputs*/ /*expected_outputs*/ {},
{},
/*expected_output_dtypes*/ {DT_INT64}, /*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ /*expected_output_shapes*/ {PartialTensorShape({1})},
{PartialTensorShape({1})},
/*expected_cardinality*/ 0, /*expected_cardinality*/ 0,
/*breakpoints*/ {0, 2, 5, 11}}; /*breakpoints*/ {0, 2, 5, 11}};
} }
class ParametrizedTakeDatasetOpTest class ParameterizedTakeDatasetOpTest
: public TakeDatasetOpTest, : public TakeDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {}; public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParametrizedTakeDatasetOpTest, GetNext) { TEST_P(ParameterizedTakeDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -204,7 +199,37 @@ TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
} }
TEST_F(TakeDatasetOpTest, DatasetName) { TEST_F(TakeDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = TakeLessTestCase();
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
gtl::InlinedVector<TensorValue, 4> inputs_for_take_dataset;
inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor);
inputs_for_take_dataset.emplace_back(&count);
std::unique_ptr<OpKernel> take_dataset_kernel;
TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&take_dataset_kernel));
std::unique_ptr<OpKernelContext> take_dataset_context;
TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(),
&inputs_for_take_dataset,
&take_dataset_context));
DatasetBase *take_dataset;
TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(),
take_dataset_context.get(), &take_dataset));
core::ScopedUnref scoped_unref(take_dataset);
EXPECT_EQ(take_dataset->node_name(), kNodeName);
}
TEST_F(TakeDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -234,7 +259,7 @@ TEST_F(TakeDatasetOpTest, DatasetName) {
EXPECT_EQ(take_dataset->type_string(), kOpName); EXPECT_EQ(take_dataset->type_string(), kOpName);
} }
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) { TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -265,7 +290,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
test_case.expected_output_dtypes)); test_case.expected_output_dtypes));
} }
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) { TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -296,7 +321,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
test_case.expected_output_shapes)); test_case.expected_output_shapes));
} }
TEST_P(ParametrizedTakeDatasetOpTest, Cardinality) { TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -361,7 +386,7 @@ TEST_F(TakeDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(writer.Flush()); TF_ASSERT_OK(writer.Flush());
} }
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) { TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -399,7 +424,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
test_case.expected_output_dtypes)); test_case.expected_output_dtypes));
} }
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) { TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -437,7 +462,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
test_case.expected_output_shapes)); test_case.expected_output_shapes));
} }
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) { TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -478,7 +503,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
} }
} }
TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) { TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -550,7 +575,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
} }
} }
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParametrizedTakeDatasetOpTest, INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParameterizedTakeDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>( ::testing::ValuesIn(std::vector<TestCase>(
{TakeLessTestCase(), TakeMoreTestCase(), {TakeLessTestCase(), TakeMoreTestCase(),
TakeAllTestCase(), TakeNothingTestCase()}))); TakeAllTestCase(), TakeNothingTestCase()})));