Add the tests of node_name() and type_string() for TakeDatasetOpTest

This commit is contained in:
Fei Hu 2019-03-15 16:48:35 -07:00
parent 301dfdac4c
commit 3b23e0ee4f

View File

@ -80,8 +80,7 @@ TestCase TakeLessTestCase() {
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/
{PartialTensorShape({1})},
/*expected_output_shapes*/ {PartialTensorShape({1})},
/*expected_cardinality*/ 4,
/*breakpoints*/ {0, 2, 5}};
}
@ -104,8 +103,7 @@ TestCase TakeMoreTestCase() {
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/
{PartialTensorShape({1})},
/*expected_output_shapes*/ {PartialTensorShape({1})},
/*expected_cardinality*/ 10,
/*breakpoints*/ {0, 2, 5, 11}};
}
@ -128,8 +126,7 @@ TestCase TakeAllTestCase() {
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/
{PartialTensorShape({1})},
/*expected_output_shapes*/ {PartialTensorShape({1})},
/*expected_cardinality*/ -1,
/*breakpoints*/ {0, 2, 5, 11}};
}
@ -140,20 +137,18 @@ TestCase TakeNothingTestCase() {
{DatasetOpsTestBase::CreateTensor<int64>(
TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
/*count*/ 0,
/*expected_outputs*/
{},
/*expected_outputs*/ {},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/
{PartialTensorShape({1})},
/*expected_output_shapes*/ {PartialTensorShape({1})},
/*expected_cardinality*/ 0,
/*breakpoints*/ {0, 2, 5, 11}};
}
class ParametrizedTakeDatasetOpTest
class ParameterizedTakeDatasetOpTest
: public TakeDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
TEST_P(ParameterizedTakeDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -204,7 +199,37 @@ TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
}
TEST_F(TakeDatasetOpTest, DatasetName) {
TEST_F(TakeDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = TakeLessTestCase();
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
gtl::InlinedVector<TensorValue, 4> inputs_for_take_dataset;
inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor);
inputs_for_take_dataset.emplace_back(&count);
std::unique_ptr<OpKernel> take_dataset_kernel;
TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&take_dataset_kernel));
std::unique_ptr<OpKernelContext> take_dataset_context;
TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(),
&inputs_for_take_dataset,
&take_dataset_context));
DatasetBase *take_dataset;
TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(),
take_dataset_context.get(), &take_dataset));
core::ScopedUnref scoped_unref(take_dataset);
EXPECT_EQ(take_dataset->node_name(), kNodeName);
}
TEST_F(TakeDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -234,7 +259,7 @@ TEST_F(TakeDatasetOpTest, DatasetName) {
EXPECT_EQ(take_dataset->type_string(), kOpName);
}
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -265,7 +290,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
test_case.expected_output_dtypes));
}
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -296,7 +321,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
test_case.expected_output_shapes));
}
TEST_P(ParametrizedTakeDatasetOpTest, Cardinality) {
TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -361,7 +386,7 @@ TEST_F(TakeDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -399,7 +424,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
test_case.expected_output_dtypes));
}
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -437,7 +462,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
test_case.expected_output_shapes));
}
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -478,7 +503,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
}
}
TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -550,7 +575,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
}
}
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParametrizedTakeDatasetOpTest,
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParameterizedTakeDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{TakeLessTestCase(), TakeMoreTestCase(),
TakeAllTestCase(), TakeNothingTestCase()})));