Add the tests of node_name() and type_string() for TakeDatasetOpTest
This commit is contained in:
parent
301dfdac4c
commit
3b23e0ee4f
@ -80,8 +80,7 @@ TestCase TakeLessTestCase() {
|
|||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
|
||||||
/*expected_output_dtypes*/ {DT_INT64},
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
/*expected_output_shapes*/
|
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||||
{PartialTensorShape({1})},
|
|
||||||
/*expected_cardinality*/ 4,
|
/*expected_cardinality*/ 4,
|
||||||
/*breakpoints*/ {0, 2, 5}};
|
/*breakpoints*/ {0, 2, 5}};
|
||||||
}
|
}
|
||||||
@ -104,8 +103,7 @@ TestCase TakeMoreTestCase() {
|
|||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
||||||
/*expected_output_dtypes*/ {DT_INT64},
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
/*expected_output_shapes*/
|
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||||
{PartialTensorShape({1})},
|
|
||||||
/*expected_cardinality*/ 10,
|
/*expected_cardinality*/ 10,
|
||||||
/*breakpoints*/ {0, 2, 5, 11}};
|
/*breakpoints*/ {0, 2, 5, 11}};
|
||||||
}
|
}
|
||||||
@ -128,8 +126,7 @@ TestCase TakeAllTestCase() {
|
|||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
||||||
/*expected_output_dtypes*/ {DT_INT64},
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
/*expected_output_shapes*/
|
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||||
{PartialTensorShape({1})},
|
|
||||||
/*expected_cardinality*/ -1,
|
/*expected_cardinality*/ -1,
|
||||||
/*breakpoints*/ {0, 2, 5, 11}};
|
/*breakpoints*/ {0, 2, 5, 11}};
|
||||||
}
|
}
|
||||||
@ -140,20 +137,18 @@ TestCase TakeNothingTestCase() {
|
|||||||
{DatasetOpsTestBase::CreateTensor<int64>(
|
{DatasetOpsTestBase::CreateTensor<int64>(
|
||||||
TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
|
TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
|
||||||
/*count*/ 0,
|
/*count*/ 0,
|
||||||
/*expected_outputs*/
|
/*expected_outputs*/ {},
|
||||||
{},
|
|
||||||
/*expected_output_dtypes*/ {DT_INT64},
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
/*expected_output_shapes*/
|
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||||
{PartialTensorShape({1})},
|
|
||||||
/*expected_cardinality*/ 0,
|
/*expected_cardinality*/ 0,
|
||||||
/*breakpoints*/ {0, 2, 5, 11}};
|
/*breakpoints*/ {0, 2, 5, 11}};
|
||||||
}
|
}
|
||||||
|
|
||||||
class ParametrizedTakeDatasetOpTest
|
class ParameterizedTakeDatasetOpTest
|
||||||
: public TakeDatasetOpTest,
|
: public TakeDatasetOpTest,
|
||||||
public ::testing::WithParamInterface<TestCase> {};
|
public ::testing::WithParamInterface<TestCase> {};
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
|
TEST_P(ParameterizedTakeDatasetOpTest, GetNext) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -204,7 +199,37 @@ TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
|
|||||||
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TakeDatasetOpTest, DatasetName) {
|
TEST_F(TakeDatasetOpTest, DatasetNodeName) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
const TestCase &test_case = TakeLessTestCase();
|
||||||
|
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||||
|
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||||
|
&tensor_slice_dataset_tensor));
|
||||||
|
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs_for_take_dataset;
|
||||||
|
inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor);
|
||||||
|
inputs_for_take_dataset.emplace_back(&count);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> take_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&take_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> take_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(),
|
||||||
|
&inputs_for_take_dataset,
|
||||||
|
&take_dataset_context));
|
||||||
|
DatasetBase *take_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(),
|
||||||
|
take_dataset_context.get(), &take_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(take_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(take_dataset->node_name(), kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TakeDatasetOpTest, DatasetTypeString) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -234,7 +259,7 @@ TEST_F(TakeDatasetOpTest, DatasetName) {
|
|||||||
EXPECT_EQ(take_dataset->type_string(), kOpName);
|
EXPECT_EQ(take_dataset->type_string(), kOpName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -265,7 +290,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
|||||||
test_case.expected_output_dtypes));
|
test_case.expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
|
TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -296,7 +321,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
|
|||||||
test_case.expected_output_shapes));
|
test_case.expected_output_shapes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, Cardinality) {
|
TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -361,7 +386,7 @@ TEST_F(TakeDatasetOpTest, DatasetSave) {
|
|||||||
TF_ASSERT_OK(writer.Flush());
|
TF_ASSERT_OK(writer.Flush());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -399,7 +424,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
|||||||
test_case.expected_output_dtypes));
|
test_case.expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
|
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -437,7 +462,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
|
|||||||
test_case.expected_output_shapes));
|
test_case.expected_output_shapes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -478,7 +503,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
|
TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -550,7 +575,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParametrizedTakeDatasetOpTest,
|
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParameterizedTakeDatasetOpTest,
|
||||||
::testing::ValuesIn(std::vector<TestCase>(
|
::testing::ValuesIn(std::vector<TestCase>(
|
||||||
{TakeLessTestCase(), TakeMoreTestCase(),
|
{TakeLessTestCase(), TakeMoreTestCase(),
|
||||||
TakeAllTestCase(), TakeNothingTestCase()})));
|
TakeAllTestCase(), TakeNothingTestCase()})));
|
||||||
|
Loading…
Reference in New Issue
Block a user