diff --git a/tensorflow/core/kernels/data/batch_dataset_op_test.cc b/tensorflow/core/kernels/data/batch_dataset_op_test.cc index c954326925c..cce73a41ca4 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op_test.cc @@ -19,7 +19,6 @@ namespace { constexpr char kNodeName[] = "batch_dataset_v2"; constexpr int kOpVersion = 2; -constexpr char kIteratorPrefix[] = "Iterator"; class BatchDatasetParams : public DatasetParams { public: @@ -36,8 +35,7 @@ class BatchDatasetParams : public DatasetParams { parallel_copy(parallel_copy) {} Status MakeInputs(gtl::InlinedVector* inputs) override { - if (input_dataset.NumElements() == 0 || - input_dataset.dtype() != DT_VARIANT) { + if (!IsDatasetTensor(input_dataset)) { return errors::Internal( "The input dataset is not populated as the dataset tensor yet."); } @@ -67,7 +65,7 @@ class BatchDatasetOpTest : public DatasetOpsTestBaseV2 { &batch_dataset_params->input_dataset)); // Create the dataset kernel. TF_RETURN_IF_ERROR( - CreateBatchDatasetOpKernel(*batch_dataset_params, &dataset_kernel_)); + MakeDatasetOpKernel(*batch_dataset_params, &dataset_kernel_)); // Create the inputs for the dataset op. gtl::InlinedVector inputs; TF_RETURN_IF_ERROR(batch_dataset_params->MakeInputs(&inputs)); @@ -82,16 +80,17 @@ class BatchDatasetOpTest : public DatasetOpsTestBaseV2 { TF_RETURN_IF_ERROR( CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_)); // Create the iterator. - TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(), - kIteratorPrefix, &iterator_)); + TF_RETURN_IF_ERROR(dataset_->MakeIterator( + iterator_ctx_.get(), batch_dataset_params->iterator_prefix, + &iterator_)); return Status::OK(); } protected: // Creates a new `BatchDataset` op kernel. - Status CreateBatchDatasetOpKernel( + Status MakeDatasetOpKernel( const BatchDatasetParams& dataset_params, - std::unique_ptr* batch_dataset_op_kernel) { + std::unique_ptr* batch_dataset_op_kernel) override { name_utils::OpNameParams params; params.op_version = kOpVersion; NodeDef node_def = test::function::NDef( @@ -202,10 +201,6 @@ BatchDatasetParams InvalidBatchSizeBatchDatasetParams() { /*node_name=*/kNodeName}; } -class ParameterizedGetNextTest : public BatchDatasetOpTest, - public ::testing::WithParamInterface< - GetNextTestCase> {}; - std::vector> GetNextTestCases() { return {{/*dataset_params=*/BatchDatasetParams1(), /*expected_outputs=*/ @@ -236,17 +231,8 @@ std::vector> GetNextTestCases() { /*expected_outputs=*/{}}}; } -TEST_P(ParameterizedGetNextTest, GetNext) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK( - CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true)); -} - -INSTANTIATE_TEST_SUITE_P( - BatchDatasetOpTest, ParameterizedGetNextTest, - ::testing::ValuesIn( - std::vector>(GetNextTestCases()))); +ITERATOR_GET_NEXT_TEST_P(BatchDatasetOpTest, BatchDatasetParams, + GetNextTestCases()) TEST_F(BatchDatasetOpTest, DatasetNodeName) { auto batch_dataset_params = BatchDatasetParams1(); @@ -269,11 +255,6 @@ TEST_F(BatchDatasetOpTest, DatasetOutputDtypes) { TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64})); } -class ParameterizedDatasetOutputShapesTest - : public BatchDatasetOpTest, - public ::testing::WithParamInterface< - DatasetOutputShapesTestCase> {}; - std::vector> DatasetOutputShapesTestCases() { return {{/*dataset_params=*/BatchDatasetParams1(), @@ -292,22 +273,8 @@ DatasetOutputShapesTestCases() { /*expected_output_shapes=*/{PartialTensorShape({4})}}}; } -TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes)); -} - -INSTANTIATE_TEST_SUITE_P( - BatchDatasetOpTest, ParameterizedDatasetOutputShapesTest, - ::testing::ValuesIn( - std::vector>( - DatasetOutputShapesTestCases()))); - -class ParameterizedCardinalityTest - : public BatchDatasetOpTest, - public ::testing::WithParamInterface< - CardinalityTestCase> {}; +DATASET_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams, + DatasetOutputShapesTestCases()) std::vector> CardinalityTestCases() { return { @@ -320,16 +287,8 @@ std::vector> CardinalityTestCases() { {/*dataset_params=*/BatchDatasetParams7(), /*expected_cardinality=*/0}}; } -TEST_P(ParameterizedCardinalityTest, Cardinality) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); -} - -INSTANTIATE_TEST_SUITE_P( - BatchDatasetOpTest, ParameterizedCardinalityTest, - ::testing::ValuesIn(std::vector>( - CardinalityTestCases()))); +DATASET_CARDINALITY_TEST_P(BatchDatasetOpTest, BatchDatasetParams, + CardinalityTestCases()) TEST_F(BatchDatasetOpTest, IteratorOutputDtypes) { auto batch_dataset_params = BatchDatasetParams1(); @@ -337,11 +296,6 @@ TEST_F(BatchDatasetOpTest, IteratorOutputDtypes) { TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64})); } -class ParameterizedIteratorOutputShapesTest - : public BatchDatasetOpTest, - public ::testing::WithParamInterface< - IteratorOutputShapesTestCase> {}; - std::vector> IteratorOutputShapesTestCases() { return {{/*dataset_params=*/BatchDatasetParams1(), @@ -360,17 +314,8 @@ IteratorOutputShapesTestCases() { /*expected_output_shapes=*/{PartialTensorShape({4})}}}; } -TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes)); -} - -INSTANTIATE_TEST_SUITE_P( - BatchDatasetOpTest, ParameterizedIteratorOutputShapesTest, - ::testing::ValuesIn( - std::vector>( - IteratorOutputShapesTestCases()))); +ITERATOR_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams, + IteratorOutputShapesTestCases()) TEST_F(BatchDatasetOpTest, IteratorOutputPrefix) { auto batch_dataset_params = BatchDatasetParams1(); @@ -378,14 +323,10 @@ TEST_F(BatchDatasetOpTest, IteratorOutputPrefix) { name_utils::IteratorPrefixParams params; params.op_version = kOpVersion; TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix( - BatchDatasetOp::kDatasetType, kIteratorPrefix, params))); + BatchDatasetOp::kDatasetType, batch_dataset_params.iterator_prefix, + params))); } -class ParameterizedIteratorSaveAndRestoreTest - : public BatchDatasetOpTest, - public ::testing::WithParamInterface< - IteratorSaveAndRestoreTestCase> {}; - std::vector> IteratorSaveAndRestoreTestCases() { return {{/*dataset_params=*/BatchDatasetParams1(), @@ -424,18 +365,8 @@ IteratorSaveAndRestoreTestCases() { /*expected_outputs=*/{}}}; } -TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK(CheckIteratorSaveAndRestore( - kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints)); -} - -INSTANTIATE_TEST_SUITE_P( - BatchDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest, - ::testing::ValuesIn( - std::vector>( - IteratorSaveAndRestoreTestCases()))); +ITERATOR_SAVE_AND_RESTORE_TEST_P(BatchDatasetOpTest, BatchDatasetParams, + IteratorSaveAndRestoreTestCases()) TEST_F(BatchDatasetOpTest, InvalidBatchSize) { auto batch_dataset_params = InvalidBatchSizeBatchDatasetParams(); diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h index e813f046f1e..55ad8677701 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.h +++ b/tensorflow/core/kernels/data/dataset_test_base.h @@ -47,6 +47,7 @@ namespace data { constexpr int kDefaultCPUNum = 2; constexpr int kDefaultThreadNum = 2; +constexpr char kDefaultIteratorPrefix[] = "Iterator"; enum class CompressionType { ZLIB = 0, GZIP = 1, RAW = 2, UNCOMPRESSED = 3 }; @@ -109,13 +110,19 @@ class DatasetParams { output_shapes(std::move(output_shapes)), node_name(std::move(node_name)) {} + virtual ~DatasetParams() {} + virtual Status MakeInputs(gtl::InlinedVector* inputs) = 0; - virtual ~DatasetParams() {} + bool IsDatasetTensor(const Tensor& tensor) { + return tensor.dtype() == DT_VARIANT && + TensorShapeUtils::IsScalar(tensor.shape()); + } DataTypeVector output_dtypes; std::vector output_shapes; string node_name; + string iterator_prefix = kDefaultIteratorPrefix; }; class RangeDatasetParams : public DatasetParams { @@ -130,6 +137,12 @@ class RangeDatasetParams : public DatasetParams { stop(CreateTensor(TensorShape({}), {stop})), step(CreateTensor(TensorShape({}), {step})) {} + RangeDatasetParams(int64 start, int64 stop, int64 step) + : DatasetParams({DT_INT64}, {PartialTensorShape({})}, ""), + start(CreateTensor(TensorShape({}), {start})), + stop(CreateTensor(TensorShape({}), {stop})), + step(CreateTensor(TensorShape({}), {step})) {} + Status MakeInputs(gtl::InlinedVector* inputs) override { *inputs = {TensorValue(&start), TensorValue(&stop), TensorValue(&step)}; return Status::OK(); @@ -471,8 +484,205 @@ class DatasetOpsTestBaseV2 : public DatasetOpsTestBase { public: // Initializes the required members for running the unit tests. virtual Status Initialize(T* dataset_params) = 0; + + virtual Status MakeDatasetOpKernel( + const T& dataset_params, std::unique_ptr* dataset_kernel) = 0; }; +#define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \ + test_case_generator) \ + class ParameterizedGetNextTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + GetNextTestCase> {}; \ + \ + TEST_P(ParameterizedGetNextTest, GetNext) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckIteratorGetNext(test_case.expected_outputs, \ + /*compare_order=*/true)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedGetNextTest, \ + ::testing::ValuesIn(std::vector>( \ + test_case_generator))); + +#define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \ + test_case_generator) \ + class ParameterizedDatasetNodeNameTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + DatasetNodeNameTestCase> {}; \ + \ + TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetNodeName(test_case.expected_node_name)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedDatasetNodeNameTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + +#define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_case_generator) \ + class ParameterizedDatasetTypeStringTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + DatasetTypeStringTestCase> {}; \ + \ + TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK( \ + CheckDatasetTypeString(test_case.expected_dataset_type_string)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedDatasetTypeStringTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + +#define DATASET_OUTPUT_DTYPES_TEST_P( \ + dataset_op_test_class, dataset_params_class, test_case_generator) \ + \ + class ParameterizedDatasetOutputDtypesTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + DatasetOutputDtypesTestCase> {}; \ + \ + TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedDatasetOutputDtypesTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + +#define DATASET_OUTPUT_SHAPES_TEST_P( \ + dataset_op_test_class, dataset_params_class, test_case_generator) \ + \ + class ParameterizedDatasetOutputShapesTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + DatasetOutputShapesTestCase> {}; \ + \ + TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedDatasetOutputShapesTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + +#define DATASET_CARDINALITY_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_case_generator) \ + \ + class ParameterizedCardinalityTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + CardinalityTestCase> {}; \ + \ + TEST_P(ParameterizedCardinalityTest, Cardinality) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedCardinalityTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + +#define ITERATOR_OUTPUT_DTYPES_TEST_P( \ + dataset_op_test_class, dataset_params_class, test_case_generator) \ + class ParameterizedIteratorOutputDtypesTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + IteratorOutputDtypesTestCase> {}; \ + \ + TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedIteratorOutputDtypesTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + +#define ITERATOR_OUTPUT_SHAPES_TEST_P( \ + dataset_op_test_class, dataset_params_class, test_case_generator) \ + class ParameterizedIteratorOutputShapesTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + IteratorOutputShapesTestCase> {}; \ + \ + TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedIteratorOutputShapesTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + +#define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class, \ + test_case_generator) \ + class ParameterizedIteratorPrefixTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + IteratorPrefixTestCase> {}; \ + \ + TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckIteratorPrefix(test_case.expected_iterator_prefix)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedIteratorPrefixTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + +#define ITERATOR_SAVE_AND_RESTORE_TEST_P( \ + dataset_op_test_class, dataset_params_class, test_case_generator) \ + class ParameterizedIteratorSaveAndRestoreTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + IteratorSaveAndRestoreTestCase> {}; \ + TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \ + TF_ASSERT_OK(CheckIteratorSaveAndRestore( \ + test_case.dataset_params.iterator_prefix, test_case.expected_outputs, \ + test_case.breakpoints)); \ + } \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_case_generator))); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/map_dataset_op_test.cc b/tensorflow/core/kernels/data/map_dataset_op_test.cc index 378f3b3e33f..3c0a635ab00 100644 --- a/tensorflow/core/kernels/data/map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/map_dataset_op_test.cc @@ -21,11 +21,10 @@ namespace data { namespace { constexpr char kNodeName[] = "map_dataset"; -constexpr char kIteratorPrefix[] = "Iterator"; class MapDatasetParams : public DatasetParams { public: - MapDatasetParams(int64 start, int64 stop, int64 step, + MapDatasetParams(RangeDatasetParams range_dataset_params, std::vector other_arguments, FunctionDefHelper::AttrValueWrapper func, std::vector func_lib, @@ -35,8 +34,7 @@ class MapDatasetParams : public DatasetParams { string node_name) : DatasetParams(std::move(output_dtypes), std::move(output_shapes), std::move(node_name)), - range_dataset_params(start, stop, step, {DT_INT64}, - {PartialTensorShape({})}, ""), + range_dataset_params(std::move(range_dataset_params)), other_arguments(std::move(other_arguments)), func(std::move(func)), func_lib(std::move(func_lib)), @@ -45,8 +43,7 @@ class MapDatasetParams : public DatasetParams { preserve_cardinality(preserve_cardinality) {} Status MakeInputs(gtl::InlinedVector* inputs) override { - if (input_dataset.NumElements() == 0 || - input_dataset.dtype() != DT_VARIANT) { + if (!IsDatasetTensor(input_dataset)) { return tensorflow::errors::Internal( "The input dataset is not populated as the dataset tensor yet."); } @@ -75,7 +72,7 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2 { InitFunctionLibraryRuntime(map_dataset_params->func_lib, cpu_num_)); TF_RETURN_IF_ERROR( - CreateMapDatasetOpKernel(*map_dataset_params, &dataset_kernel_)); + MakeDatasetOpKernel(*map_dataset_params, &dataset_kernel_)); TF_RETURN_IF_ERROR( MakeRangeDataset(map_dataset_params->range_dataset_params, &map_dataset_params->input_dataset)); @@ -87,15 +84,15 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2 { CreateDataset(dataset_kernel_.get(), dataset_ctx_.get(), &dataset_)); TF_RETURN_IF_ERROR( CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_)); - TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(), - kIteratorPrefix, &iterator_)); + TF_RETURN_IF_ERROR(dataset_->MakeIterator( + iterator_ctx_.get(), map_dataset_params->iterator_prefix, &iterator_)); return Status::OK(); } protected: // Creates a new MapDataset op kernel. - Status CreateMapDatasetOpKernel(const MapDatasetParams& map_dataset_params, - std::unique_ptr* map_kernel) { + Status MakeDatasetOpKernel(const MapDatasetParams& map_dataset_params, + std::unique_ptr* map_kernel) override { NodeDef map_dataset_node_def = test::function::NDef( map_dataset_params.node_name, name_utils::OpName(MapDatasetOp::kDatasetType), @@ -114,9 +111,7 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2 { }; MapDatasetParams MapDatasetParams1() { - return {/*start=*/0, - /*stop=*/10, - /*step=*/3, + return {{/*start=*/0, /*stop=*/10, /*step=*/3}, /*other_arguments=*/{}, /*func=*/ FunctionDefHelper::FunctionRef("XTimesTwo", {{"T", DT_INT64}}), @@ -130,9 +125,7 @@ MapDatasetParams MapDatasetParams1() { } MapDatasetParams MapDatasetParams2() { - return {/*start=*/10, - /*stop=*/0, - /*step=*/-3, + return {{/*start=*/10, /*stop=*/0, /*step=*/-3}, /*other_arguments=*/{}, /*func=*/ FunctionDefHelper::FunctionRef("XAddX", {{"T", DT_INT64}}), @@ -149,9 +142,7 @@ MapDatasetParams MapDatasetParams2() { // both of them are added to the function library. MapDatasetParams MapDatasetParams3() { return { - /*start=*/0, - /*stop=*/10, - /*step=*/3, + {/*start=*/0, /*stop=*/10, /*step=*/3}, /*other_arguments=*/{}, /*func=*/ FunctionDefHelper::FunctionRef("XTimesFour", {{"T", DT_INT64}}), @@ -164,11 +155,6 @@ MapDatasetParams MapDatasetParams3() { /*node_name=*/kNodeName}; } -class ParameterizedGetNextTest - : public MapDatasetOpTest, - public ::testing::WithParamInterface> { -}; - std::vector> GetNextTestCases() { return {{/*dataset_params=*/MapDatasetParams1(), /*expected_outputs=*/ @@ -181,17 +167,16 @@ std::vector> GetNextTestCases() { CreateTensors(TensorShape({}), {{0}, {12}, {24}, {36}})}}; } -TEST_P(ParameterizedGetNextTest, GetNext) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK( - CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true)); +ITERATOR_GET_NEXT_TEST_P(MapDatasetOpTest, MapDatasetParams, GetNextTestCases()) + +std::vector> +DatasetNodeNameTestCases() { + return {{/*dataset_params=*/MapDatasetParams1(), + /*expected_node_name=*/kNodeName}}; } -INSTANTIATE_TEST_SUITE_P( - MapDatasetOpTest, ParameterizedGetNextTest, - ::testing::ValuesIn( - std::vector>(GetNextTestCases()))); +DATASET_NODE_NAME_TEST_P(MapDatasetOpTest, MapDatasetParams, + DatasetNodeNameTestCases()) TEST_F(MapDatasetOpTest, DatasetNodeName) { auto dataset_params = MapDatasetParams1(); @@ -218,27 +203,14 @@ TEST_F(MapDatasetOpTest, DatasetOutputShapes) { TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})})); } -class ParameterizedCardinalityTest - : public MapDatasetOpTest, - public ::testing::WithParamInterface< - CardinalityTestCase> {}; - std::vector> CardinalityTestCases() { return {{/*dataset_params=*/MapDatasetParams1(), /*expected_cardinality=*/4}, {/*dataset_params=*/MapDatasetParams2(), /*expected_cardinality=*/4}, {/*dataset_params=*/MapDatasetParams3(), /*expected_cardinality=*/4}}; } -TEST_P(ParameterizedCardinalityTest, Cardinality) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); -} - -INSTANTIATE_TEST_SUITE_P( - MapDatasetOpTest, ParameterizedCardinalityTest, - ::testing::ValuesIn(std::vector>( - CardinalityTestCases()))); +DATASET_CARDINALITY_TEST_P(MapDatasetOpTest, MapDatasetParams, + CardinalityTestCases()) TEST_F(MapDatasetOpTest, IteratorOutputDtypes) { auto dataset_params = MapDatasetParams1(); @@ -255,15 +227,10 @@ TEST_F(MapDatasetOpTest, IteratorOutputShapes) { TEST_F(MapDatasetOpTest, IteratorPrefix) { auto dataset_params = MapDatasetParams1(); TF_ASSERT_OK(Initialize(&dataset_params)); - TF_ASSERT_OK(CheckIteratorPrefix( - name_utils::IteratorPrefix(MapDatasetOp::kDatasetType, kIteratorPrefix))); + TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix( + MapDatasetOp::kDatasetType, dataset_params.iterator_prefix))); } -class ParameterizedIteratorSaveAndRestoreTest - : public MapDatasetOpTest, - public ::testing::WithParamInterface< - IteratorSaveAndRestoreTestCase> {}; - std::vector> IteratorSaveAndRestoreTestCases() { return {{/*dataset_params=*/MapDatasetParams1(), @@ -280,18 +247,8 @@ IteratorSaveAndRestoreTestCases() { CreateTensors(TensorShape({}), {{0}, {12}, {24}, {36}})}}; } -TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK(CheckIteratorSaveAndRestore( - kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints)); -} - -INSTANTIATE_TEST_SUITE_P( - MapDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest, - ::testing::ValuesIn( - std::vector>( - IteratorSaveAndRestoreTestCases()))); +ITERATOR_SAVE_AND_RESTORE_TEST_P(MapDatasetOpTest, MapDatasetParams, + IteratorSaveAndRestoreTestCases()) } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/range_dataset_op_test.cc b/tensorflow/core/kernels/data/range_dataset_op_test.cc index 2688c30fab1..62f621fd838 100644 --- a/tensorflow/core/kernels/data/range_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/range_dataset_op_test.cc @@ -21,7 +21,6 @@ namespace data { namespace { constexpr char kNodeName[] = "range_dataset"; -constexpr char kIteratorPrefix[] = "Iterator"; class RangeDatasetOpTest : public DatasetOpsTestBaseV2 { public: @@ -30,7 +29,7 @@ class RangeDatasetOpTest : public DatasetOpsTestBaseV2 { TF_RETURN_IF_ERROR(InitFunctionLibraryRuntime({}, cpu_num_)); TF_RETURN_IF_ERROR( - CreateRangeDatasetOpKernel(*range_dataset_params, &dataset_kernel_)); + MakeDatasetOpKernel(*range_dataset_params, &dataset_kernel_)); gtl::InlinedVector inputs; TF_RETURN_IF_ERROR(range_dataset_params->MakeInputs(&inputs)); TF_RETURN_IF_ERROR( @@ -39,16 +38,16 @@ class RangeDatasetOpTest : public DatasetOpsTestBaseV2 { CreateDataset(dataset_kernel_.get(), dataset_ctx_.get(), &dataset_)); TF_RETURN_IF_ERROR( CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_)); - TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(), - kIteratorPrefix, &iterator_)); + TF_RETURN_IF_ERROR(dataset_->MakeIterator( + iterator_ctx_.get(), range_dataset_params->iterator_prefix, + &iterator_)); return Status::OK(); } protected: - // Creates a new `BatchDataset` op kernel. - Status CreateRangeDatasetOpKernel( + Status MakeDatasetOpKernel( const RangeDatasetParams& dataset_params, - std::unique_ptr* range_dataset_op_kernel) { + std::unique_ptr* range_dataset_op_kernel) override { NodeDef node_def = test::function::NDef( dataset_params.node_name, name_utils::OpName(RangeDatasetOp::kDatasetType), @@ -87,10 +86,6 @@ RangeDatasetParams ZeroStepRangeDatasetParams() { /*node_name=*/kNodeName}; } -class ParameterizedGetNextTest : public RangeDatasetOpTest, - public ::testing::WithParamInterface< - GetNextTestCase> {}; - std::vector> GetNextTestCases() { return {{/*dataset_params=*/PositiveStepRangeDatasetParams(), /*expected_outputs=*/ @@ -100,17 +95,8 @@ std::vector> GetNextTestCases() { CreateTensors(TensorShape({}), {{10}, {7}, {4}, {1}})}}; } -TEST_P(ParameterizedGetNextTest, GetNext) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK( - CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true)); -} - -INSTANTIATE_TEST_SUITE_P( - RangeDatasetOpTest, ParameterizedGetNextTest, - ::testing::ValuesIn( - std::vector>(GetNextTestCases()))); +ITERATOR_GET_NEXT_TEST_P(RangeDatasetOpTest, RangeDatasetParams, + GetNextTestCases()) TEST_F(RangeDatasetOpTest, DatasetNodeName) { auto range_dataset_params = PositiveStepRangeDatasetParams(); @@ -137,11 +123,6 @@ TEST_F(RangeDatasetOpTest, DatasetOutputShapes) { TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})})); } -class ParameterizedCardinalityTest - : public RangeDatasetOpTest, - public ::testing::WithParamInterface< - CardinalityTestCase> {}; - std::vector> CardinalityTestCases() { return {{/*dataset_params=*/PositiveStepRangeDatasetParams(), /*expected_cardinality=*/4}, @@ -149,16 +130,8 @@ std::vector> CardinalityTestCases() { /*expected_cardinality=*/4}}; } -TEST_P(ParameterizedCardinalityTest, Cardinality) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); -} - -INSTANTIATE_TEST_SUITE_P( - RangeDatasetOpTest, ParameterizedCardinalityTest, - ::testing::ValuesIn(std::vector>( - CardinalityTestCases()))); +DATASET_CARDINALITY_TEST_P(RangeDatasetOpTest, RangeDatasetParams, + CardinalityTestCases()) TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) { auto range_dataset_params = PositiveStepRangeDatasetParams(); @@ -176,14 +149,9 @@ TEST_F(RangeDatasetOpTest, IteratorPrefix) { auto range_dataset_params = PositiveStepRangeDatasetParams(); TF_ASSERT_OK(Initialize(&range_dataset_params)); TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix( - RangeDatasetOp::kDatasetType, kIteratorPrefix))); + RangeDatasetOp::kDatasetType, range_dataset_params.iterator_prefix))); } -class ParameterizedIteratorSaveAndRestoreTest - : public RangeDatasetOpTest, - public ::testing::WithParamInterface< - IteratorSaveAndRestoreTestCase> {}; - std::vector> IteratorSaveAndRestoreTestCases() { return {{/*dataset_params=*/PositiveStepRangeDatasetParams(), @@ -196,18 +164,8 @@ IteratorSaveAndRestoreTestCases() { CreateTensors(TensorShape({}), {{10}, {7}, {4}, {1}})}}; } -TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) { - auto test_case = GetParam(); - TF_ASSERT_OK(Initialize(&test_case.dataset_params)); - TF_ASSERT_OK(CheckIteratorSaveAndRestore( - kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints)); -} - -INSTANTIATE_TEST_SUITE_P( - RangeDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest, - ::testing::ValuesIn( - std::vector>( - IteratorSaveAndRestoreTestCases()))); +ITERATOR_SAVE_AND_RESTORE_TEST_P(RangeDatasetOpTest, RangeDatasetParams, + IteratorSaveAndRestoreTestCases()) TEST_F(RangeDatasetOpTest, ZeroStep) { auto range_dataset_params = ZeroStepRangeDatasetParams();