Add the tests of node_name() and type_string() for TakeDatasetOpTest
This commit is contained in:
parent
301dfdac4c
commit
3b23e0ee4f
@ -80,8 +80,7 @@ TestCase TakeLessTestCase() {
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
|
||||
/*expected_output_dtypes*/ {DT_INT64},
|
||||
/*expected_output_shapes*/
|
||||
{PartialTensorShape({1})},
|
||||
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||
/*expected_cardinality*/ 4,
|
||||
/*breakpoints*/ {0, 2, 5}};
|
||||
}
|
||||
@ -104,8 +103,7 @@ TestCase TakeMoreTestCase() {
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
||||
/*expected_output_dtypes*/ {DT_INT64},
|
||||
/*expected_output_shapes*/
|
||||
{PartialTensorShape({1})},
|
||||
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||
/*expected_cardinality*/ 10,
|
||||
/*breakpoints*/ {0, 2, 5, 11}};
|
||||
}
|
||||
@ -128,8 +126,7 @@ TestCase TakeAllTestCase() {
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
||||
/*expected_output_dtypes*/ {DT_INT64},
|
||||
/*expected_output_shapes*/
|
||||
{PartialTensorShape({1})},
|
||||
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||
/*expected_cardinality*/ -1,
|
||||
/*breakpoints*/ {0, 2, 5, 11}};
|
||||
}
|
||||
@ -140,20 +137,18 @@ TestCase TakeNothingTestCase() {
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
|
||||
/*count*/ 0,
|
||||
/*expected_outputs*/
|
||||
{},
|
||||
/*expected_outputs*/ {},
|
||||
/*expected_output_dtypes*/ {DT_INT64},
|
||||
/*expected_output_shapes*/
|
||||
{PartialTensorShape({1})},
|
||||
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||
/*expected_cardinality*/ 0,
|
||||
/*breakpoints*/ {0, 2, 5, 11}};
|
||||
}
|
||||
|
||||
class ParametrizedTakeDatasetOpTest
|
||||
class ParameterizedTakeDatasetOpTest
|
||||
: public TakeDatasetOpTest,
|
||||
public ::testing::WithParamInterface<TestCase> {};
|
||||
|
||||
TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, GetNext) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -204,7 +199,37 @@ TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
|
||||
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||
}
|
||||
|
||||
TEST_F(TakeDatasetOpTest, DatasetName) {
|
||||
TEST_F(TakeDatasetOpTest, DatasetNodeName) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
const TestCase &test_case = TakeLessTestCase();
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs_for_take_dataset;
|
||||
inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor);
|
||||
inputs_for_take_dataset.emplace_back(&count);
|
||||
|
||||
std::unique_ptr<OpKernel> take_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&take_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> take_dataset_context;
|
||||
TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(),
|
||||
&inputs_for_take_dataset,
|
||||
&take_dataset_context));
|
||||
DatasetBase *take_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(),
|
||||
take_dataset_context.get(), &take_dataset));
|
||||
core::ScopedUnref scoped_unref(take_dataset);
|
||||
|
||||
EXPECT_EQ(take_dataset->node_name(), kNodeName);
|
||||
}
|
||||
|
||||
TEST_F(TakeDatasetOpTest, DatasetTypeString) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -234,7 +259,7 @@ TEST_F(TakeDatasetOpTest, DatasetName) {
|
||||
EXPECT_EQ(take_dataset->type_string(), kOpName);
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -265,7 +290,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
||||
test_case.expected_output_dtypes));
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputShapes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -296,7 +321,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
|
||||
test_case.expected_output_shapes));
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTakeDatasetOpTest, Cardinality) {
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -361,7 +386,7 @@ TEST_F(TakeDatasetOpTest, DatasetSave) {
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -399,7 +424,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
||||
test_case.expected_output_dtypes));
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputShapes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -437,7 +462,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
|
||||
test_case.expected_output_shapes));
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -478,7 +503,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
@ -550,7 +575,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParametrizedTakeDatasetOpTest,
|
||||
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParameterizedTakeDatasetOpTest,
|
||||
::testing::ValuesIn(std::vector<TestCase>(
|
||||
{TakeLessTestCase(), TakeMoreTestCase(),
|
||||
TakeAllTestCase(), TakeNothingTestCase()})));
|
||||
|
Loading…
Reference in New Issue
Block a user