Merge pull request #31592 from feihugis:Refactoring_Dataset_Tests

PiperOrigin-RevId: 263659657
This commit is contained in:
TensorFlower Gardener 2019-08-15 15:59:43 -07:00
commit 4a186b5ed4
4 changed files with 268 additions and 212 deletions

View File

@ -19,7 +19,6 @@ namespace {
constexpr char kNodeName[] = "batch_dataset_v2";
constexpr int kOpVersion = 2;
constexpr char kIteratorPrefix[] = "Iterator";
class BatchDatasetParams : public DatasetParams {
public:
@ -36,8 +35,7 @@ class BatchDatasetParams : public DatasetParams {
parallel_copy(parallel_copy) {}
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
if (input_dataset.NumElements() == 0 ||
input_dataset.dtype() != DT_VARIANT) {
if (!IsDatasetTensor(input_dataset)) {
return errors::Internal(
"The input dataset is not populated as the dataset tensor yet.");
}
@ -67,7 +65,7 @@ class BatchDatasetOpTest : public DatasetOpsTestBaseV2<BatchDatasetParams> {
&batch_dataset_params->input_dataset));
// Create the dataset kernel.
TF_RETURN_IF_ERROR(
CreateBatchDatasetOpKernel(*batch_dataset_params, &dataset_kernel_));
MakeDatasetOpKernel(*batch_dataset_params, &dataset_kernel_));
// Create the inputs for the dataset op.
gtl::InlinedVector<TensorValue, 4> inputs;
TF_RETURN_IF_ERROR(batch_dataset_params->MakeInputs(&inputs));
@ -82,16 +80,17 @@ class BatchDatasetOpTest : public DatasetOpsTestBaseV2<BatchDatasetParams> {
TF_RETURN_IF_ERROR(
CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
// Create the iterator.
TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(),
kIteratorPrefix, &iterator_));
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
iterator_ctx_.get(), batch_dataset_params->iterator_prefix,
&iterator_));
return Status::OK();
}
protected:
// Creates a new `BatchDataset` op kernel.
Status CreateBatchDatasetOpKernel(
Status MakeDatasetOpKernel(
const BatchDatasetParams& dataset_params,
std::unique_ptr<OpKernel>* batch_dataset_op_kernel) {
std::unique_ptr<OpKernel>* batch_dataset_op_kernel) override {
name_utils::OpNameParams params;
params.op_version = kOpVersion;
NodeDef node_def = test::function::NDef(
@ -202,10 +201,6 @@ BatchDatasetParams InvalidBatchSizeBatchDatasetParams() {
/*node_name=*/kNodeName};
}
class ParameterizedGetNextTest : public BatchDatasetOpTest,
public ::testing::WithParamInterface<
GetNextTestCase<BatchDatasetParams>> {};
std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
return {{/*dataset_params=*/BatchDatasetParams1(),
/*expected_outputs=*/
@ -236,17 +231,8 @@ std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
/*expected_outputs=*/{}}};
}
TEST_P(ParameterizedGetNextTest, GetNext) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(
CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true));
}
INSTANTIATE_TEST_SUITE_P(
BatchDatasetOpTest, ParameterizedGetNextTest,
::testing::ValuesIn(
std::vector<GetNextTestCase<BatchDatasetParams>>(GetNextTestCases())));
ITERATOR_GET_NEXT_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
GetNextTestCases())
TEST_F(BatchDatasetOpTest, DatasetNodeName) {
auto batch_dataset_params = BatchDatasetParams1();
@ -269,11 +255,6 @@ TEST_F(BatchDatasetOpTest, DatasetOutputDtypes) {
TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
}
class ParameterizedDatasetOutputShapesTest
: public BatchDatasetOpTest,
public ::testing::WithParamInterface<
DatasetOutputShapesTestCase<BatchDatasetParams>> {};
std::vector<DatasetOutputShapesTestCase<BatchDatasetParams>>
DatasetOutputShapesTestCases() {
return {{/*dataset_params=*/BatchDatasetParams1(),
@ -292,22 +273,8 @@ DatasetOutputShapesTestCases() {
/*expected_output_shapes=*/{PartialTensorShape({4})}}};
}
TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes));
}
INSTANTIATE_TEST_SUITE_P(
BatchDatasetOpTest, ParameterizedDatasetOutputShapesTest,
::testing::ValuesIn(
std::vector<DatasetOutputShapesTestCase<BatchDatasetParams>>(
DatasetOutputShapesTestCases())));
class ParameterizedCardinalityTest
: public BatchDatasetOpTest,
public ::testing::WithParamInterface<
CardinalityTestCase<BatchDatasetParams>> {};
DATASET_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
DatasetOutputShapesTestCases())
std::vector<CardinalityTestCase<BatchDatasetParams>> CardinalityTestCases() {
return {
@ -320,16 +287,8 @@ std::vector<CardinalityTestCase<BatchDatasetParams>> CardinalityTestCases() {
{/*dataset_params=*/BatchDatasetParams7(), /*expected_cardinality=*/0}};
}
TEST_P(ParameterizedCardinalityTest, Cardinality) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality));
}
INSTANTIATE_TEST_SUITE_P(
BatchDatasetOpTest, ParameterizedCardinalityTest,
::testing::ValuesIn(std::vector<CardinalityTestCase<BatchDatasetParams>>(
CardinalityTestCases())));
DATASET_CARDINALITY_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
CardinalityTestCases())
TEST_F(BatchDatasetOpTest, IteratorOutputDtypes) {
auto batch_dataset_params = BatchDatasetParams1();
@ -337,11 +296,6 @@ TEST_F(BatchDatasetOpTest, IteratorOutputDtypes) {
TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
}
class ParameterizedIteratorOutputShapesTest
: public BatchDatasetOpTest,
public ::testing::WithParamInterface<
IteratorOutputShapesTestCase<BatchDatasetParams>> {};
std::vector<IteratorOutputShapesTestCase<BatchDatasetParams>>
IteratorOutputShapesTestCases() {
return {{/*dataset_params=*/BatchDatasetParams1(),
@ -360,17 +314,8 @@ IteratorOutputShapesTestCases() {
/*expected_output_shapes=*/{PartialTensorShape({4})}}};
}
TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes));
}
INSTANTIATE_TEST_SUITE_P(
BatchDatasetOpTest, ParameterizedIteratorOutputShapesTest,
::testing::ValuesIn(
std::vector<IteratorOutputShapesTestCase<BatchDatasetParams>>(
IteratorOutputShapesTestCases())));
ITERATOR_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
IteratorOutputShapesTestCases())
TEST_F(BatchDatasetOpTest, IteratorOutputPrefix) {
auto batch_dataset_params = BatchDatasetParams1();
@ -378,14 +323,10 @@ TEST_F(BatchDatasetOpTest, IteratorOutputPrefix) {
name_utils::IteratorPrefixParams params;
params.op_version = kOpVersion;
TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
BatchDatasetOp::kDatasetType, kIteratorPrefix, params)));
BatchDatasetOp::kDatasetType, batch_dataset_params.iterator_prefix,
params)));
}
class ParameterizedIteratorSaveAndRestoreTest
: public BatchDatasetOpTest,
public ::testing::WithParamInterface<
IteratorSaveAndRestoreTestCase<BatchDatasetParams>> {};
std::vector<IteratorSaveAndRestoreTestCase<BatchDatasetParams>>
IteratorSaveAndRestoreTestCases() {
return {{/*dataset_params=*/BatchDatasetParams1(),
@ -424,18 +365,8 @@ IteratorSaveAndRestoreTestCases() {
/*expected_outputs=*/{}}};
}
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(CheckIteratorSaveAndRestore(
kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints));
}
INSTANTIATE_TEST_SUITE_P(
BatchDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
::testing::ValuesIn(
std::vector<IteratorSaveAndRestoreTestCase<BatchDatasetParams>>(
IteratorSaveAndRestoreTestCases())));
ITERATOR_SAVE_AND_RESTORE_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
IteratorSaveAndRestoreTestCases())
TEST_F(BatchDatasetOpTest, InvalidBatchSize) {
auto batch_dataset_params = InvalidBatchSizeBatchDatasetParams();

View File

@ -47,6 +47,7 @@ namespace data {
constexpr int kDefaultCPUNum = 2;
constexpr int kDefaultThreadNum = 2;
constexpr char kDefaultIteratorPrefix[] = "Iterator";
enum class CompressionType { ZLIB = 0, GZIP = 1, RAW = 2, UNCOMPRESSED = 3 };
@ -109,13 +110,19 @@ class DatasetParams {
output_shapes(std::move(output_shapes)),
node_name(std::move(node_name)) {}
virtual ~DatasetParams() {}
virtual Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
virtual ~DatasetParams() {}
bool IsDatasetTensor(const Tensor& tensor) {
return tensor.dtype() == DT_VARIANT &&
TensorShapeUtils::IsScalar(tensor.shape());
}
DataTypeVector output_dtypes;
std::vector<PartialTensorShape> output_shapes;
string node_name;
string iterator_prefix = kDefaultIteratorPrefix;
};
class RangeDatasetParams : public DatasetParams {
@ -130,6 +137,12 @@ class RangeDatasetParams : public DatasetParams {
stop(CreateTensor<int64>(TensorShape({}), {stop})),
step(CreateTensor<int64>(TensorShape({}), {step})) {}
RangeDatasetParams(int64 start, int64 stop, int64 step)
: DatasetParams({DT_INT64}, {PartialTensorShape({})}, ""),
start(CreateTensor<int64>(TensorShape({}), {start})),
stop(CreateTensor<int64>(TensorShape({}), {stop})),
step(CreateTensor<int64>(TensorShape({}), {step})) {}
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
*inputs = {TensorValue(&start), TensorValue(&stop), TensorValue(&step)};
return Status::OK();
@ -471,8 +484,205 @@ class DatasetOpsTestBaseV2 : public DatasetOpsTestBase {
public:
// Initializes the required members for running the unit tests.
virtual Status Initialize(T* dataset_params) = 0;
virtual Status MakeDatasetOpKernel(
const T& dataset_params, std::unique_ptr<OpKernel>* dataset_kernel) = 0;
};
#define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \
test_case_generator) \
class ParameterizedGetNextTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
GetNextTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedGetNextTest, GetNext) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckIteratorGetNext(test_case.expected_outputs, \
/*compare_order=*/true)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedGetNextTest, \
::testing::ValuesIn(std::vector<GetNextTestCase<dataset_params_class>>( \
test_case_generator)));
#define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \
test_case_generator) \
class ParameterizedDatasetNodeNameTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
DatasetNodeNameTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckDatasetNodeName(test_case.expected_node_name)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedDatasetNodeNameTest, \
::testing::ValuesIn( \
std::vector<DatasetNodeNameTestCase<dataset_params_class>>( \
test_case_generator)));
#define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class, \
dataset_params_class, test_case_generator) \
class ParameterizedDatasetTypeStringTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
DatasetTypeStringTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK( \
CheckDatasetTypeString(test_case.expected_dataset_type_string)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedDatasetTypeStringTest, \
::testing::ValuesIn( \
std::vector<DatasetTypeStringTestCase<dataset_params_class>>( \
test_case_generator)));
#define DATASET_OUTPUT_DTYPES_TEST_P( \
dataset_op_test_class, dataset_params_class, test_case_generator) \
\
class ParameterizedDatasetOutputDtypesTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
DatasetOutputDtypesTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedDatasetOutputDtypesTest, \
::testing::ValuesIn( \
std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>( \
test_case_generator)));
#define DATASET_OUTPUT_SHAPES_TEST_P( \
dataset_op_test_class, dataset_params_class, test_case_generator) \
\
class ParameterizedDatasetOutputShapesTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
DatasetOutputShapesTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedDatasetOutputShapesTest, \
::testing::ValuesIn( \
std::vector<DatasetOutputShapesTestCase<dataset_params_class>>( \
test_case_generator)));
#define DATASET_CARDINALITY_TEST_P(dataset_op_test_class, \
dataset_params_class, test_case_generator) \
\
class ParameterizedCardinalityTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
CardinalityTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedCardinalityTest, Cardinality) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedCardinalityTest, \
::testing::ValuesIn( \
std::vector<CardinalityTestCase<dataset_params_class>>( \
test_case_generator)));
#define ITERATOR_OUTPUT_DTYPES_TEST_P( \
dataset_op_test_class, dataset_params_class, test_case_generator) \
class ParameterizedIteratorOutputDtypesTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
IteratorOutputDtypesTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedIteratorOutputDtypesTest, \
::testing::ValuesIn( \
std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>( \
test_case_generator)));
#define ITERATOR_OUTPUT_SHAPES_TEST_P( \
dataset_op_test_class, dataset_params_class, test_case_generator) \
class ParameterizedIteratorOutputShapesTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
IteratorOutputShapesTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedIteratorOutputShapesTest, \
::testing::ValuesIn( \
std::vector<IteratorOutputShapesTestCase<dataset_params_class>>( \
test_case_generator)));
#define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class, \
test_case_generator) \
class ParameterizedIteratorPrefixTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
IteratorPrefixTestCase<dataset_params_class>> {}; \
\
TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckIteratorPrefix(test_case.expected_iterator_prefix)); \
} \
\
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedIteratorPrefixTest, \
::testing::ValuesIn( \
std::vector<IteratorPrefixTestCase<dataset_params_class>>( \
test_case_generator)));
#define ITERATOR_SAVE_AND_RESTORE_TEST_P( \
dataset_op_test_class, dataset_params_class, test_case_generator) \
class ParameterizedIteratorSaveAndRestoreTest \
: public dataset_op_test_class, \
public ::testing::WithParamInterface< \
IteratorSaveAndRestoreTestCase<dataset_params_class>> {}; \
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) { \
auto test_case = GetParam(); \
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
TF_ASSERT_OK(CheckIteratorSaveAndRestore( \
test_case.dataset_params.iterator_prefix, test_case.expected_outputs, \
test_case.breakpoints)); \
} \
INSTANTIATE_TEST_SUITE_P( \
dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest, \
::testing::ValuesIn( \
std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>( \
test_case_generator)));
} // namespace data
} // namespace tensorflow

View File

@ -21,11 +21,10 @@ namespace data {
namespace {
constexpr char kNodeName[] = "map_dataset";
constexpr char kIteratorPrefix[] = "Iterator";
class MapDatasetParams : public DatasetParams {
public:
MapDatasetParams(int64 start, int64 stop, int64 step,
MapDatasetParams(RangeDatasetParams range_dataset_params,
std::vector<Tensor> other_arguments,
FunctionDefHelper::AttrValueWrapper func,
std::vector<FunctionDef> func_lib,
@ -35,8 +34,7 @@ class MapDatasetParams : public DatasetParams {
string node_name)
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
std::move(node_name)),
range_dataset_params(start, stop, step, {DT_INT64},
{PartialTensorShape({})}, ""),
range_dataset_params(std::move(range_dataset_params)),
other_arguments(std::move(other_arguments)),
func(std::move(func)),
func_lib(std::move(func_lib)),
@ -45,8 +43,7 @@ class MapDatasetParams : public DatasetParams {
preserve_cardinality(preserve_cardinality) {}
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
if (input_dataset.NumElements() == 0 ||
input_dataset.dtype() != DT_VARIANT) {
if (!IsDatasetTensor(input_dataset)) {
return tensorflow::errors::Internal(
"The input dataset is not populated as the dataset tensor yet.");
}
@ -75,7 +72,7 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2<MapDatasetParams> {
InitFunctionLibraryRuntime(map_dataset_params->func_lib, cpu_num_));
TF_RETURN_IF_ERROR(
CreateMapDatasetOpKernel(*map_dataset_params, &dataset_kernel_));
MakeDatasetOpKernel(*map_dataset_params, &dataset_kernel_));
TF_RETURN_IF_ERROR(
MakeRangeDataset(map_dataset_params->range_dataset_params,
&map_dataset_params->input_dataset));
@ -87,15 +84,15 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2<MapDatasetParams> {
CreateDataset(dataset_kernel_.get(), dataset_ctx_.get(), &dataset_));
TF_RETURN_IF_ERROR(
CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(),
kIteratorPrefix, &iterator_));
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
iterator_ctx_.get(), map_dataset_params->iterator_prefix, &iterator_));
return Status::OK();
}
protected:
// Creates a new MapDataset op kernel.
Status CreateMapDatasetOpKernel(const MapDatasetParams& map_dataset_params,
std::unique_ptr<OpKernel>* map_kernel) {
Status MakeDatasetOpKernel(const MapDatasetParams& map_dataset_params,
std::unique_ptr<OpKernel>* map_kernel) override {
NodeDef map_dataset_node_def = test::function::NDef(
map_dataset_params.node_name,
name_utils::OpName(MapDatasetOp::kDatasetType),
@ -114,9 +111,7 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2<MapDatasetParams> {
};
MapDatasetParams MapDatasetParams1() {
return {/*start=*/0,
/*stop=*/10,
/*step=*/3,
return {{/*start=*/0, /*stop=*/10, /*step=*/3},
/*other_arguments=*/{},
/*func=*/
FunctionDefHelper::FunctionRef("XTimesTwo", {{"T", DT_INT64}}),
@ -130,9 +125,7 @@ MapDatasetParams MapDatasetParams1() {
}
MapDatasetParams MapDatasetParams2() {
return {/*start=*/10,
/*stop=*/0,
/*step=*/-3,
return {{/*start=*/10, /*stop=*/0, /*step=*/-3},
/*other_arguments=*/{},
/*func=*/
FunctionDefHelper::FunctionRef("XAddX", {{"T", DT_INT64}}),
@ -149,9 +142,7 @@ MapDatasetParams MapDatasetParams2() {
// both of them are added to the function library.
MapDatasetParams MapDatasetParams3() {
return {
/*start=*/0,
/*stop=*/10,
/*step=*/3,
{/*start=*/0, /*stop=*/10, /*step=*/3},
/*other_arguments=*/{},
/*func=*/
FunctionDefHelper::FunctionRef("XTimesFour", {{"T", DT_INT64}}),
@ -164,11 +155,6 @@ MapDatasetParams MapDatasetParams3() {
/*node_name=*/kNodeName};
}
class ParameterizedGetNextTest
: public MapDatasetOpTest,
public ::testing::WithParamInterface<GetNextTestCase<MapDatasetParams>> {
};
std::vector<GetNextTestCase<MapDatasetParams>> GetNextTestCases() {
return {{/*dataset_params=*/MapDatasetParams1(),
/*expected_outputs=*/
@ -181,17 +167,16 @@ std::vector<GetNextTestCase<MapDatasetParams>> GetNextTestCases() {
CreateTensors<int64>(TensorShape({}), {{0}, {12}, {24}, {36}})}};
}
TEST_P(ParameterizedGetNextTest, GetNext) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(
CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true));
ITERATOR_GET_NEXT_TEST_P(MapDatasetOpTest, MapDatasetParams, GetNextTestCases())
std::vector<DatasetNodeNameTestCase<MapDatasetParams>>
DatasetNodeNameTestCases() {
return {{/*dataset_params=*/MapDatasetParams1(),
/*expected_node_name=*/kNodeName}};
}
INSTANTIATE_TEST_SUITE_P(
MapDatasetOpTest, ParameterizedGetNextTest,
::testing::ValuesIn(
std::vector<GetNextTestCase<MapDatasetParams>>(GetNextTestCases())));
DATASET_NODE_NAME_TEST_P(MapDatasetOpTest, MapDatasetParams,
DatasetNodeNameTestCases())
TEST_F(MapDatasetOpTest, DatasetNodeName) {
auto dataset_params = MapDatasetParams1();
@ -218,27 +203,14 @@ TEST_F(MapDatasetOpTest, DatasetOutputShapes) {
TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})}));
}
class ParameterizedCardinalityTest
: public MapDatasetOpTest,
public ::testing::WithParamInterface<
CardinalityTestCase<MapDatasetParams>> {};
std::vector<CardinalityTestCase<MapDatasetParams>> CardinalityTestCases() {
return {{/*dataset_params=*/MapDatasetParams1(), /*expected_cardinality=*/4},
{/*dataset_params=*/MapDatasetParams2(), /*expected_cardinality=*/4},
{/*dataset_params=*/MapDatasetParams3(), /*expected_cardinality=*/4}};
}
TEST_P(ParameterizedCardinalityTest, Cardinality) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality));
}
INSTANTIATE_TEST_SUITE_P(
MapDatasetOpTest, ParameterizedCardinalityTest,
::testing::ValuesIn(std::vector<CardinalityTestCase<MapDatasetParams>>(
CardinalityTestCases())));
DATASET_CARDINALITY_TEST_P(MapDatasetOpTest, MapDatasetParams,
CardinalityTestCases())
TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
auto dataset_params = MapDatasetParams1();
@ -255,15 +227,10 @@ TEST_F(MapDatasetOpTest, IteratorOutputShapes) {
TEST_F(MapDatasetOpTest, IteratorPrefix) {
auto dataset_params = MapDatasetParams1();
TF_ASSERT_OK(Initialize(&dataset_params));
TF_ASSERT_OK(CheckIteratorPrefix(
name_utils::IteratorPrefix(MapDatasetOp::kDatasetType, kIteratorPrefix)));
TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
MapDatasetOp::kDatasetType, dataset_params.iterator_prefix)));
}
class ParameterizedIteratorSaveAndRestoreTest
: public MapDatasetOpTest,
public ::testing::WithParamInterface<
IteratorSaveAndRestoreTestCase<MapDatasetParams>> {};
std::vector<IteratorSaveAndRestoreTestCase<MapDatasetParams>>
IteratorSaveAndRestoreTestCases() {
return {{/*dataset_params=*/MapDatasetParams1(),
@ -280,18 +247,8 @@ IteratorSaveAndRestoreTestCases() {
CreateTensors<int64>(TensorShape({}), {{0}, {12}, {24}, {36}})}};
}
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(CheckIteratorSaveAndRestore(
kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints));
}
INSTANTIATE_TEST_SUITE_P(
MapDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
::testing::ValuesIn(
std::vector<IteratorSaveAndRestoreTestCase<MapDatasetParams>>(
IteratorSaveAndRestoreTestCases())));
ITERATOR_SAVE_AND_RESTORE_TEST_P(MapDatasetOpTest, MapDatasetParams,
IteratorSaveAndRestoreTestCases())
} // namespace
} // namespace data

View File

@ -21,7 +21,6 @@ namespace data {
namespace {
constexpr char kNodeName[] = "range_dataset";
constexpr char kIteratorPrefix[] = "Iterator";
class RangeDatasetOpTest : public DatasetOpsTestBaseV2<RangeDatasetParams> {
public:
@ -30,7 +29,7 @@ class RangeDatasetOpTest : public DatasetOpsTestBaseV2<RangeDatasetParams> {
TF_RETURN_IF_ERROR(InitFunctionLibraryRuntime({}, cpu_num_));
TF_RETURN_IF_ERROR(
CreateRangeDatasetOpKernel(*range_dataset_params, &dataset_kernel_));
MakeDatasetOpKernel(*range_dataset_params, &dataset_kernel_));
gtl::InlinedVector<TensorValue, 4> inputs;
TF_RETURN_IF_ERROR(range_dataset_params->MakeInputs(&inputs));
TF_RETURN_IF_ERROR(
@ -39,16 +38,16 @@ class RangeDatasetOpTest : public DatasetOpsTestBaseV2<RangeDatasetParams> {
CreateDataset(dataset_kernel_.get(), dataset_ctx_.get(), &dataset_));
TF_RETURN_IF_ERROR(
CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(),
kIteratorPrefix, &iterator_));
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
iterator_ctx_.get(), range_dataset_params->iterator_prefix,
&iterator_));
return Status::OK();
}
protected:
// Creates a new `BatchDataset` op kernel.
Status CreateRangeDatasetOpKernel(
Status MakeDatasetOpKernel(
const RangeDatasetParams& dataset_params,
std::unique_ptr<OpKernel>* range_dataset_op_kernel) {
std::unique_ptr<OpKernel>* range_dataset_op_kernel) override {
NodeDef node_def = test::function::NDef(
dataset_params.node_name,
name_utils::OpName(RangeDatasetOp::kDatasetType),
@ -87,10 +86,6 @@ RangeDatasetParams ZeroStepRangeDatasetParams() {
/*node_name=*/kNodeName};
}
class ParameterizedGetNextTest : public RangeDatasetOpTest,
public ::testing::WithParamInterface<
GetNextTestCase<RangeDatasetParams>> {};
std::vector<GetNextTestCase<RangeDatasetParams>> GetNextTestCases() {
return {{/*dataset_params=*/PositiveStepRangeDatasetParams(),
/*expected_outputs=*/
@ -100,17 +95,8 @@ std::vector<GetNextTestCase<RangeDatasetParams>> GetNextTestCases() {
CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})}};
}
TEST_P(ParameterizedGetNextTest, GetNext) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(
CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true));
}
INSTANTIATE_TEST_SUITE_P(
RangeDatasetOpTest, ParameterizedGetNextTest,
::testing::ValuesIn(
std::vector<GetNextTestCase<RangeDatasetParams>>(GetNextTestCases())));
ITERATOR_GET_NEXT_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
GetNextTestCases())
TEST_F(RangeDatasetOpTest, DatasetNodeName) {
auto range_dataset_params = PositiveStepRangeDatasetParams();
@ -137,11 +123,6 @@ TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})}));
}
class ParameterizedCardinalityTest
: public RangeDatasetOpTest,
public ::testing::WithParamInterface<
CardinalityTestCase<RangeDatasetParams>> {};
std::vector<CardinalityTestCase<RangeDatasetParams>> CardinalityTestCases() {
return {{/*dataset_params=*/PositiveStepRangeDatasetParams(),
/*expected_cardinality=*/4},
@ -149,16 +130,8 @@ std::vector<CardinalityTestCase<RangeDatasetParams>> CardinalityTestCases() {
/*expected_cardinality=*/4}};
}
TEST_P(ParameterizedCardinalityTest, Cardinality) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality));
}
INSTANTIATE_TEST_SUITE_P(
RangeDatasetOpTest, ParameterizedCardinalityTest,
::testing::ValuesIn(std::vector<CardinalityTestCase<RangeDatasetParams>>(
CardinalityTestCases())));
DATASET_CARDINALITY_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
CardinalityTestCases())
TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
auto range_dataset_params = PositiveStepRangeDatasetParams();
@ -176,14 +149,9 @@ TEST_F(RangeDatasetOpTest, IteratorPrefix) {
auto range_dataset_params = PositiveStepRangeDatasetParams();
TF_ASSERT_OK(Initialize(&range_dataset_params));
TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
RangeDatasetOp::kDatasetType, kIteratorPrefix)));
RangeDatasetOp::kDatasetType, range_dataset_params.iterator_prefix)));
}
class ParameterizedIteratorSaveAndRestoreTest
: public RangeDatasetOpTest,
public ::testing::WithParamInterface<
IteratorSaveAndRestoreTestCase<RangeDatasetParams>> {};
std::vector<IteratorSaveAndRestoreTestCase<RangeDatasetParams>>
IteratorSaveAndRestoreTestCases() {
return {{/*dataset_params=*/PositiveStepRangeDatasetParams(),
@ -196,18 +164,8 @@ IteratorSaveAndRestoreTestCases() {
CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})}};
}
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
auto test_case = GetParam();
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
TF_ASSERT_OK(CheckIteratorSaveAndRestore(
kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints));
}
INSTANTIATE_TEST_SUITE_P(
RangeDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
::testing::ValuesIn(
std::vector<IteratorSaveAndRestoreTestCase<RangeDatasetParams>>(
IteratorSaveAndRestoreTestCases())));
ITERATOR_SAVE_AND_RESTORE_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
IteratorSaveAndRestoreTestCases())
TEST_F(RangeDatasetOpTest, ZeroStep) {
auto range_dataset_params = ZeroStepRangeDatasetParams();