Merge pull request #31592 from feihugis:Refactoring_Dataset_Tests
PiperOrigin-RevId: 263659657
This commit is contained in:
commit
4a186b5ed4
@ -19,7 +19,6 @@ namespace {
|
||||
|
||||
constexpr char kNodeName[] = "batch_dataset_v2";
|
||||
constexpr int kOpVersion = 2;
|
||||
constexpr char kIteratorPrefix[] = "Iterator";
|
||||
|
||||
class BatchDatasetParams : public DatasetParams {
|
||||
public:
|
||||
@ -36,8 +35,7 @@ class BatchDatasetParams : public DatasetParams {
|
||||
parallel_copy(parallel_copy) {}
|
||||
|
||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
|
||||
if (input_dataset.NumElements() == 0 ||
|
||||
input_dataset.dtype() != DT_VARIANT) {
|
||||
if (!IsDatasetTensor(input_dataset)) {
|
||||
return errors::Internal(
|
||||
"The input dataset is not populated as the dataset tensor yet.");
|
||||
}
|
||||
@ -67,7 +65,7 @@ class BatchDatasetOpTest : public DatasetOpsTestBaseV2<BatchDatasetParams> {
|
||||
&batch_dataset_params->input_dataset));
|
||||
// Create the dataset kernel.
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateBatchDatasetOpKernel(*batch_dataset_params, &dataset_kernel_));
|
||||
MakeDatasetOpKernel(*batch_dataset_params, &dataset_kernel_));
|
||||
// Create the inputs for the dataset op.
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
TF_RETURN_IF_ERROR(batch_dataset_params->MakeInputs(&inputs));
|
||||
@ -82,16 +80,17 @@ class BatchDatasetOpTest : public DatasetOpsTestBaseV2<BatchDatasetParams> {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
|
||||
// Create the iterator.
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(),
|
||||
kIteratorPrefix, &iterator_));
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), batch_dataset_params->iterator_prefix,
|
||||
&iterator_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
// Creates a new `BatchDataset` op kernel.
|
||||
Status CreateBatchDatasetOpKernel(
|
||||
Status MakeDatasetOpKernel(
|
||||
const BatchDatasetParams& dataset_params,
|
||||
std::unique_ptr<OpKernel>* batch_dataset_op_kernel) {
|
||||
std::unique_ptr<OpKernel>* batch_dataset_op_kernel) override {
|
||||
name_utils::OpNameParams params;
|
||||
params.op_version = kOpVersion;
|
||||
NodeDef node_def = test::function::NDef(
|
||||
@ -202,10 +201,6 @@ BatchDatasetParams InvalidBatchSizeBatchDatasetParams() {
|
||||
/*node_name=*/kNodeName};
|
||||
}
|
||||
|
||||
class ParameterizedGetNextTest : public BatchDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
GetNextTestCase<BatchDatasetParams>> {};
|
||||
|
||||
std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
|
||||
return {{/*dataset_params=*/BatchDatasetParams1(),
|
||||
/*expected_outputs=*/
|
||||
@ -236,17 +231,8 @@ std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
|
||||
/*expected_outputs=*/{}}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedGetNextTest, GetNext) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(
|
||||
CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BatchDatasetOpTest, ParameterizedGetNextTest,
|
||||
::testing::ValuesIn(
|
||||
std::vector<GetNextTestCase<BatchDatasetParams>>(GetNextTestCases())));
|
||||
ITERATOR_GET_NEXT_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
|
||||
GetNextTestCases())
|
||||
|
||||
TEST_F(BatchDatasetOpTest, DatasetNodeName) {
|
||||
auto batch_dataset_params = BatchDatasetParams1();
|
||||
@ -269,11 +255,6 @@ TEST_F(BatchDatasetOpTest, DatasetOutputDtypes) {
|
||||
TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
|
||||
}
|
||||
|
||||
class ParameterizedDatasetOutputShapesTest
|
||||
: public BatchDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
DatasetOutputShapesTestCase<BatchDatasetParams>> {};
|
||||
|
||||
std::vector<DatasetOutputShapesTestCase<BatchDatasetParams>>
|
||||
DatasetOutputShapesTestCases() {
|
||||
return {{/*dataset_params=*/BatchDatasetParams1(),
|
||||
@ -292,22 +273,8 @@ DatasetOutputShapesTestCases() {
|
||||
/*expected_output_shapes=*/{PartialTensorShape({4})}}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BatchDatasetOpTest, ParameterizedDatasetOutputShapesTest,
|
||||
::testing::ValuesIn(
|
||||
std::vector<DatasetOutputShapesTestCase<BatchDatasetParams>>(
|
||||
DatasetOutputShapesTestCases())));
|
||||
|
||||
class ParameterizedCardinalityTest
|
||||
: public BatchDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
CardinalityTestCase<BatchDatasetParams>> {};
|
||||
DATASET_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
|
||||
DatasetOutputShapesTestCases())
|
||||
|
||||
std::vector<CardinalityTestCase<BatchDatasetParams>> CardinalityTestCases() {
|
||||
return {
|
||||
@ -320,16 +287,8 @@ std::vector<CardinalityTestCase<BatchDatasetParams>> CardinalityTestCases() {
|
||||
{/*dataset_params=*/BatchDatasetParams7(), /*expected_cardinality=*/0}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedCardinalityTest, Cardinality) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BatchDatasetOpTest, ParameterizedCardinalityTest,
|
||||
::testing::ValuesIn(std::vector<CardinalityTestCase<BatchDatasetParams>>(
|
||||
CardinalityTestCases())));
|
||||
DATASET_CARDINALITY_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
|
||||
CardinalityTestCases())
|
||||
|
||||
TEST_F(BatchDatasetOpTest, IteratorOutputDtypes) {
|
||||
auto batch_dataset_params = BatchDatasetParams1();
|
||||
@ -337,11 +296,6 @@ TEST_F(BatchDatasetOpTest, IteratorOutputDtypes) {
|
||||
TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
|
||||
}
|
||||
|
||||
class ParameterizedIteratorOutputShapesTest
|
||||
: public BatchDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
IteratorOutputShapesTestCase<BatchDatasetParams>> {};
|
||||
|
||||
std::vector<IteratorOutputShapesTestCase<BatchDatasetParams>>
|
||||
IteratorOutputShapesTestCases() {
|
||||
return {{/*dataset_params=*/BatchDatasetParams1(),
|
||||
@ -360,17 +314,8 @@ IteratorOutputShapesTestCases() {
|
||||
/*expected_output_shapes=*/{PartialTensorShape({4})}}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BatchDatasetOpTest, ParameterizedIteratorOutputShapesTest,
|
||||
::testing::ValuesIn(
|
||||
std::vector<IteratorOutputShapesTestCase<BatchDatasetParams>>(
|
||||
IteratorOutputShapesTestCases())));
|
||||
ITERATOR_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
|
||||
IteratorOutputShapesTestCases())
|
||||
|
||||
TEST_F(BatchDatasetOpTest, IteratorOutputPrefix) {
|
||||
auto batch_dataset_params = BatchDatasetParams1();
|
||||
@ -378,14 +323,10 @@ TEST_F(BatchDatasetOpTest, IteratorOutputPrefix) {
|
||||
name_utils::IteratorPrefixParams params;
|
||||
params.op_version = kOpVersion;
|
||||
TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
|
||||
BatchDatasetOp::kDatasetType, kIteratorPrefix, params)));
|
||||
BatchDatasetOp::kDatasetType, batch_dataset_params.iterator_prefix,
|
||||
params)));
|
||||
}
|
||||
|
||||
class ParameterizedIteratorSaveAndRestoreTest
|
||||
: public BatchDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
IteratorSaveAndRestoreTestCase<BatchDatasetParams>> {};
|
||||
|
||||
std::vector<IteratorSaveAndRestoreTestCase<BatchDatasetParams>>
|
||||
IteratorSaveAndRestoreTestCases() {
|
||||
return {{/*dataset_params=*/BatchDatasetParams1(),
|
||||
@ -424,18 +365,8 @@ IteratorSaveAndRestoreTestCases() {
|
||||
/*expected_outputs=*/{}}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(CheckIteratorSaveAndRestore(
|
||||
kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BatchDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
|
||||
::testing::ValuesIn(
|
||||
std::vector<IteratorSaveAndRestoreTestCase<BatchDatasetParams>>(
|
||||
IteratorSaveAndRestoreTestCases())));
|
||||
ITERATOR_SAVE_AND_RESTORE_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
|
||||
IteratorSaveAndRestoreTestCases())
|
||||
|
||||
TEST_F(BatchDatasetOpTest, InvalidBatchSize) {
|
||||
auto batch_dataset_params = InvalidBatchSizeBatchDatasetParams();
|
||||
|
@ -47,6 +47,7 @@ namespace data {
|
||||
|
||||
constexpr int kDefaultCPUNum = 2;
|
||||
constexpr int kDefaultThreadNum = 2;
|
||||
constexpr char kDefaultIteratorPrefix[] = "Iterator";
|
||||
|
||||
enum class CompressionType { ZLIB = 0, GZIP = 1, RAW = 2, UNCOMPRESSED = 3 };
|
||||
|
||||
@ -109,13 +110,19 @@ class DatasetParams {
|
||||
output_shapes(std::move(output_shapes)),
|
||||
node_name(std::move(node_name)) {}
|
||||
|
||||
virtual ~DatasetParams() {}
|
||||
|
||||
virtual Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
|
||||
|
||||
virtual ~DatasetParams() {}
|
||||
bool IsDatasetTensor(const Tensor& tensor) {
|
||||
return tensor.dtype() == DT_VARIANT &&
|
||||
TensorShapeUtils::IsScalar(tensor.shape());
|
||||
}
|
||||
|
||||
DataTypeVector output_dtypes;
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
string node_name;
|
||||
string iterator_prefix = kDefaultIteratorPrefix;
|
||||
};
|
||||
|
||||
class RangeDatasetParams : public DatasetParams {
|
||||
@ -130,6 +137,12 @@ class RangeDatasetParams : public DatasetParams {
|
||||
stop(CreateTensor<int64>(TensorShape({}), {stop})),
|
||||
step(CreateTensor<int64>(TensorShape({}), {step})) {}
|
||||
|
||||
RangeDatasetParams(int64 start, int64 stop, int64 step)
|
||||
: DatasetParams({DT_INT64}, {PartialTensorShape({})}, ""),
|
||||
start(CreateTensor<int64>(TensorShape({}), {start})),
|
||||
stop(CreateTensor<int64>(TensorShape({}), {stop})),
|
||||
step(CreateTensor<int64>(TensorShape({}), {step})) {}
|
||||
|
||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
|
||||
*inputs = {TensorValue(&start), TensorValue(&stop), TensorValue(&step)};
|
||||
return Status::OK();
|
||||
@ -471,8 +484,205 @@ class DatasetOpsTestBaseV2 : public DatasetOpsTestBase {
|
||||
public:
|
||||
// Initializes the required members for running the unit tests.
|
||||
virtual Status Initialize(T* dataset_params) = 0;
|
||||
|
||||
virtual Status MakeDatasetOpKernel(
|
||||
const T& dataset_params, std::unique_ptr<OpKernel>* dataset_kernel) = 0;
|
||||
};
|
||||
|
||||
#define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \
|
||||
test_case_generator) \
|
||||
class ParameterizedGetNextTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
GetNextTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedGetNextTest, GetNext) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckIteratorGetNext(test_case.expected_outputs, \
|
||||
/*compare_order=*/true)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedGetNextTest, \
|
||||
::testing::ValuesIn(std::vector<GetNextTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \
|
||||
test_case_generator) \
|
||||
class ParameterizedDatasetNodeNameTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
DatasetNodeNameTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckDatasetNodeName(test_case.expected_node_name)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedDatasetNodeNameTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<DatasetNodeNameTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class, \
|
||||
dataset_params_class, test_case_generator) \
|
||||
class ParameterizedDatasetTypeStringTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
DatasetTypeStringTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK( \
|
||||
CheckDatasetTypeString(test_case.expected_dataset_type_string)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedDatasetTypeStringTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<DatasetTypeStringTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define DATASET_OUTPUT_DTYPES_TEST_P( \
|
||||
dataset_op_test_class, dataset_params_class, test_case_generator) \
|
||||
\
|
||||
class ParameterizedDatasetOutputDtypesTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
DatasetOutputDtypesTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedDatasetOutputDtypesTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define DATASET_OUTPUT_SHAPES_TEST_P( \
|
||||
dataset_op_test_class, dataset_params_class, test_case_generator) \
|
||||
\
|
||||
class ParameterizedDatasetOutputShapesTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
DatasetOutputShapesTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedDatasetOutputShapesTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<DatasetOutputShapesTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define DATASET_CARDINALITY_TEST_P(dataset_op_test_class, \
|
||||
dataset_params_class, test_case_generator) \
|
||||
\
|
||||
class ParameterizedCardinalityTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
CardinalityTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedCardinalityTest, Cardinality) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedCardinalityTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<CardinalityTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define ITERATOR_OUTPUT_DTYPES_TEST_P( \
|
||||
dataset_op_test_class, dataset_params_class, test_case_generator) \
|
||||
class ParameterizedIteratorOutputDtypesTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
IteratorOutputDtypesTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedIteratorOutputDtypesTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define ITERATOR_OUTPUT_SHAPES_TEST_P( \
|
||||
dataset_op_test_class, dataset_params_class, test_case_generator) \
|
||||
class ParameterizedIteratorOutputShapesTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
IteratorOutputShapesTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedIteratorOutputShapesTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<IteratorOutputShapesTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class, \
|
||||
test_case_generator) \
|
||||
class ParameterizedIteratorPrefixTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
IteratorPrefixTestCase<dataset_params_class>> {}; \
|
||||
\
|
||||
TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckIteratorPrefix(test_case.expected_iterator_prefix)); \
|
||||
} \
|
||||
\
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedIteratorPrefixTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<IteratorPrefixTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
#define ITERATOR_SAVE_AND_RESTORE_TEST_P( \
|
||||
dataset_op_test_class, dataset_params_class, test_case_generator) \
|
||||
class ParameterizedIteratorSaveAndRestoreTest \
|
||||
: public dataset_op_test_class, \
|
||||
public ::testing::WithParamInterface< \
|
||||
IteratorSaveAndRestoreTestCase<dataset_params_class>> {}; \
|
||||
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) { \
|
||||
auto test_case = GetParam(); \
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params)); \
|
||||
TF_ASSERT_OK(CheckIteratorSaveAndRestore( \
|
||||
test_case.dataset_params.iterator_prefix, test_case.expected_outputs, \
|
||||
test_case.breakpoints)); \
|
||||
} \
|
||||
INSTANTIATE_TEST_SUITE_P( \
|
||||
dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest, \
|
||||
::testing::ValuesIn( \
|
||||
std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>( \
|
||||
test_case_generator)));
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -21,11 +21,10 @@ namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kNodeName[] = "map_dataset";
|
||||
constexpr char kIteratorPrefix[] = "Iterator";
|
||||
|
||||
class MapDatasetParams : public DatasetParams {
|
||||
public:
|
||||
MapDatasetParams(int64 start, int64 stop, int64 step,
|
||||
MapDatasetParams(RangeDatasetParams range_dataset_params,
|
||||
std::vector<Tensor> other_arguments,
|
||||
FunctionDefHelper::AttrValueWrapper func,
|
||||
std::vector<FunctionDef> func_lib,
|
||||
@ -35,8 +34,7 @@ class MapDatasetParams : public DatasetParams {
|
||||
string node_name)
|
||||
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
|
||||
std::move(node_name)),
|
||||
range_dataset_params(start, stop, step, {DT_INT64},
|
||||
{PartialTensorShape({})}, ""),
|
||||
range_dataset_params(std::move(range_dataset_params)),
|
||||
other_arguments(std::move(other_arguments)),
|
||||
func(std::move(func)),
|
||||
func_lib(std::move(func_lib)),
|
||||
@ -45,8 +43,7 @@ class MapDatasetParams : public DatasetParams {
|
||||
preserve_cardinality(preserve_cardinality) {}
|
||||
|
||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
|
||||
if (input_dataset.NumElements() == 0 ||
|
||||
input_dataset.dtype() != DT_VARIANT) {
|
||||
if (!IsDatasetTensor(input_dataset)) {
|
||||
return tensorflow::errors::Internal(
|
||||
"The input dataset is not populated as the dataset tensor yet.");
|
||||
}
|
||||
@ -75,7 +72,7 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2<MapDatasetParams> {
|
||||
InitFunctionLibraryRuntime(map_dataset_params->func_lib, cpu_num_));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateMapDatasetOpKernel(*map_dataset_params, &dataset_kernel_));
|
||||
MakeDatasetOpKernel(*map_dataset_params, &dataset_kernel_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
MakeRangeDataset(map_dataset_params->range_dataset_params,
|
||||
&map_dataset_params->input_dataset));
|
||||
@ -87,15 +84,15 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2<MapDatasetParams> {
|
||||
CreateDataset(dataset_kernel_.get(), dataset_ctx_.get(), &dataset_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(),
|
||||
kIteratorPrefix, &iterator_));
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), map_dataset_params->iterator_prefix, &iterator_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
// Creates a new MapDataset op kernel.
|
||||
Status CreateMapDatasetOpKernel(const MapDatasetParams& map_dataset_params,
|
||||
std::unique_ptr<OpKernel>* map_kernel) {
|
||||
Status MakeDatasetOpKernel(const MapDatasetParams& map_dataset_params,
|
||||
std::unique_ptr<OpKernel>* map_kernel) override {
|
||||
NodeDef map_dataset_node_def = test::function::NDef(
|
||||
map_dataset_params.node_name,
|
||||
name_utils::OpName(MapDatasetOp::kDatasetType),
|
||||
@ -114,9 +111,7 @@ class MapDatasetOpTest : public DatasetOpsTestBaseV2<MapDatasetParams> {
|
||||
};
|
||||
|
||||
MapDatasetParams MapDatasetParams1() {
|
||||
return {/*start=*/0,
|
||||
/*stop=*/10,
|
||||
/*step=*/3,
|
||||
return {{/*start=*/0, /*stop=*/10, /*step=*/3},
|
||||
/*other_arguments=*/{},
|
||||
/*func=*/
|
||||
FunctionDefHelper::FunctionRef("XTimesTwo", {{"T", DT_INT64}}),
|
||||
@ -130,9 +125,7 @@ MapDatasetParams MapDatasetParams1() {
|
||||
}
|
||||
|
||||
MapDatasetParams MapDatasetParams2() {
|
||||
return {/*start=*/10,
|
||||
/*stop=*/0,
|
||||
/*step=*/-3,
|
||||
return {{/*start=*/10, /*stop=*/0, /*step=*/-3},
|
||||
/*other_arguments=*/{},
|
||||
/*func=*/
|
||||
FunctionDefHelper::FunctionRef("XAddX", {{"T", DT_INT64}}),
|
||||
@ -149,9 +142,7 @@ MapDatasetParams MapDatasetParams2() {
|
||||
// both of them are added to the function library.
|
||||
MapDatasetParams MapDatasetParams3() {
|
||||
return {
|
||||
/*start=*/0,
|
||||
/*stop=*/10,
|
||||
/*step=*/3,
|
||||
{/*start=*/0, /*stop=*/10, /*step=*/3},
|
||||
/*other_arguments=*/{},
|
||||
/*func=*/
|
||||
FunctionDefHelper::FunctionRef("XTimesFour", {{"T", DT_INT64}}),
|
||||
@ -164,11 +155,6 @@ MapDatasetParams MapDatasetParams3() {
|
||||
/*node_name=*/kNodeName};
|
||||
}
|
||||
|
||||
class ParameterizedGetNextTest
|
||||
: public MapDatasetOpTest,
|
||||
public ::testing::WithParamInterface<GetNextTestCase<MapDatasetParams>> {
|
||||
};
|
||||
|
||||
std::vector<GetNextTestCase<MapDatasetParams>> GetNextTestCases() {
|
||||
return {{/*dataset_params=*/MapDatasetParams1(),
|
||||
/*expected_outputs=*/
|
||||
@ -181,17 +167,16 @@ std::vector<GetNextTestCase<MapDatasetParams>> GetNextTestCases() {
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {12}, {24}, {36}})}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedGetNextTest, GetNext) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(
|
||||
CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true));
|
||||
ITERATOR_GET_NEXT_TEST_P(MapDatasetOpTest, MapDatasetParams, GetNextTestCases())
|
||||
|
||||
std::vector<DatasetNodeNameTestCase<MapDatasetParams>>
|
||||
DatasetNodeNameTestCases() {
|
||||
return {{/*dataset_params=*/MapDatasetParams1(),
|
||||
/*expected_node_name=*/kNodeName}};
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
MapDatasetOpTest, ParameterizedGetNextTest,
|
||||
::testing::ValuesIn(
|
||||
std::vector<GetNextTestCase<MapDatasetParams>>(GetNextTestCases())));
|
||||
DATASET_NODE_NAME_TEST_P(MapDatasetOpTest, MapDatasetParams,
|
||||
DatasetNodeNameTestCases())
|
||||
|
||||
TEST_F(MapDatasetOpTest, DatasetNodeName) {
|
||||
auto dataset_params = MapDatasetParams1();
|
||||
@ -218,27 +203,14 @@ TEST_F(MapDatasetOpTest, DatasetOutputShapes) {
|
||||
TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})}));
|
||||
}
|
||||
|
||||
class ParameterizedCardinalityTest
|
||||
: public MapDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
CardinalityTestCase<MapDatasetParams>> {};
|
||||
|
||||
std::vector<CardinalityTestCase<MapDatasetParams>> CardinalityTestCases() {
|
||||
return {{/*dataset_params=*/MapDatasetParams1(), /*expected_cardinality=*/4},
|
||||
{/*dataset_params=*/MapDatasetParams2(), /*expected_cardinality=*/4},
|
||||
{/*dataset_params=*/MapDatasetParams3(), /*expected_cardinality=*/4}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedCardinalityTest, Cardinality) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
MapDatasetOpTest, ParameterizedCardinalityTest,
|
||||
::testing::ValuesIn(std::vector<CardinalityTestCase<MapDatasetParams>>(
|
||||
CardinalityTestCases())));
|
||||
DATASET_CARDINALITY_TEST_P(MapDatasetOpTest, MapDatasetParams,
|
||||
CardinalityTestCases())
|
||||
|
||||
TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
|
||||
auto dataset_params = MapDatasetParams1();
|
||||
@ -255,15 +227,10 @@ TEST_F(MapDatasetOpTest, IteratorOutputShapes) {
|
||||
TEST_F(MapDatasetOpTest, IteratorPrefix) {
|
||||
auto dataset_params = MapDatasetParams1();
|
||||
TF_ASSERT_OK(Initialize(&dataset_params));
|
||||
TF_ASSERT_OK(CheckIteratorPrefix(
|
||||
name_utils::IteratorPrefix(MapDatasetOp::kDatasetType, kIteratorPrefix)));
|
||||
TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
|
||||
MapDatasetOp::kDatasetType, dataset_params.iterator_prefix)));
|
||||
}
|
||||
|
||||
class ParameterizedIteratorSaveAndRestoreTest
|
||||
: public MapDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
IteratorSaveAndRestoreTestCase<MapDatasetParams>> {};
|
||||
|
||||
std::vector<IteratorSaveAndRestoreTestCase<MapDatasetParams>>
|
||||
IteratorSaveAndRestoreTestCases() {
|
||||
return {{/*dataset_params=*/MapDatasetParams1(),
|
||||
@ -280,18 +247,8 @@ IteratorSaveAndRestoreTestCases() {
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {12}, {24}, {36}})}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(CheckIteratorSaveAndRestore(
|
||||
kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
MapDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
|
||||
::testing::ValuesIn(
|
||||
std::vector<IteratorSaveAndRestoreTestCase<MapDatasetParams>>(
|
||||
IteratorSaveAndRestoreTestCases())));
|
||||
ITERATOR_SAVE_AND_RESTORE_TEST_P(MapDatasetOpTest, MapDatasetParams,
|
||||
IteratorSaveAndRestoreTestCases())
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
|
@ -21,7 +21,6 @@ namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kNodeName[] = "range_dataset";
|
||||
constexpr char kIteratorPrefix[] = "Iterator";
|
||||
|
||||
class RangeDatasetOpTest : public DatasetOpsTestBaseV2<RangeDatasetParams> {
|
||||
public:
|
||||
@ -30,7 +29,7 @@ class RangeDatasetOpTest : public DatasetOpsTestBaseV2<RangeDatasetParams> {
|
||||
TF_RETURN_IF_ERROR(InitFunctionLibraryRuntime({}, cpu_num_));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateRangeDatasetOpKernel(*range_dataset_params, &dataset_kernel_));
|
||||
MakeDatasetOpKernel(*range_dataset_params, &dataset_kernel_));
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
TF_RETURN_IF_ERROR(range_dataset_params->MakeInputs(&inputs));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -39,16 +38,16 @@ class RangeDatasetOpTest : public DatasetOpsTestBaseV2<RangeDatasetParams> {
|
||||
CreateDataset(dataset_kernel_.get(), dataset_ctx_.get(), &dataset_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(),
|
||||
kIteratorPrefix, &iterator_));
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), range_dataset_params->iterator_prefix,
|
||||
&iterator_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
// Creates a new `BatchDataset` op kernel.
|
||||
Status CreateRangeDatasetOpKernel(
|
||||
Status MakeDatasetOpKernel(
|
||||
const RangeDatasetParams& dataset_params,
|
||||
std::unique_ptr<OpKernel>* range_dataset_op_kernel) {
|
||||
std::unique_ptr<OpKernel>* range_dataset_op_kernel) override {
|
||||
NodeDef node_def = test::function::NDef(
|
||||
dataset_params.node_name,
|
||||
name_utils::OpName(RangeDatasetOp::kDatasetType),
|
||||
@ -87,10 +86,6 @@ RangeDatasetParams ZeroStepRangeDatasetParams() {
|
||||
/*node_name=*/kNodeName};
|
||||
}
|
||||
|
||||
class ParameterizedGetNextTest : public RangeDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
GetNextTestCase<RangeDatasetParams>> {};
|
||||
|
||||
std::vector<GetNextTestCase<RangeDatasetParams>> GetNextTestCases() {
|
||||
return {{/*dataset_params=*/PositiveStepRangeDatasetParams(),
|
||||
/*expected_outputs=*/
|
||||
@ -100,17 +95,8 @@ std::vector<GetNextTestCase<RangeDatasetParams>> GetNextTestCases() {
|
||||
CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedGetNextTest, GetNext) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(
|
||||
CheckIteratorGetNext(test_case.expected_outputs, /*compare_order=*/true));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
RangeDatasetOpTest, ParameterizedGetNextTest,
|
||||
::testing::ValuesIn(
|
||||
std::vector<GetNextTestCase<RangeDatasetParams>>(GetNextTestCases())));
|
||||
ITERATOR_GET_NEXT_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
|
||||
GetNextTestCases())
|
||||
|
||||
TEST_F(RangeDatasetOpTest, DatasetNodeName) {
|
||||
auto range_dataset_params = PositiveStepRangeDatasetParams();
|
||||
@ -137,11 +123,6 @@ TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
|
||||
TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})}));
|
||||
}
|
||||
|
||||
class ParameterizedCardinalityTest
|
||||
: public RangeDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
CardinalityTestCase<RangeDatasetParams>> {};
|
||||
|
||||
std::vector<CardinalityTestCase<RangeDatasetParams>> CardinalityTestCases() {
|
||||
return {{/*dataset_params=*/PositiveStepRangeDatasetParams(),
|
||||
/*expected_cardinality=*/4},
|
||||
@ -149,16 +130,8 @@ std::vector<CardinalityTestCase<RangeDatasetParams>> CardinalityTestCases() {
|
||||
/*expected_cardinality=*/4}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedCardinalityTest, Cardinality) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
RangeDatasetOpTest, ParameterizedCardinalityTest,
|
||||
::testing::ValuesIn(std::vector<CardinalityTestCase<RangeDatasetParams>>(
|
||||
CardinalityTestCases())));
|
||||
DATASET_CARDINALITY_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
|
||||
CardinalityTestCases())
|
||||
|
||||
TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
|
||||
auto range_dataset_params = PositiveStepRangeDatasetParams();
|
||||
@ -176,14 +149,9 @@ TEST_F(RangeDatasetOpTest, IteratorPrefix) {
|
||||
auto range_dataset_params = PositiveStepRangeDatasetParams();
|
||||
TF_ASSERT_OK(Initialize(&range_dataset_params));
|
||||
TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
|
||||
RangeDatasetOp::kDatasetType, kIteratorPrefix)));
|
||||
RangeDatasetOp::kDatasetType, range_dataset_params.iterator_prefix)));
|
||||
}
|
||||
|
||||
class ParameterizedIteratorSaveAndRestoreTest
|
||||
: public RangeDatasetOpTest,
|
||||
public ::testing::WithParamInterface<
|
||||
IteratorSaveAndRestoreTestCase<RangeDatasetParams>> {};
|
||||
|
||||
std::vector<IteratorSaveAndRestoreTestCase<RangeDatasetParams>>
|
||||
IteratorSaveAndRestoreTestCases() {
|
||||
return {{/*dataset_params=*/PositiveStepRangeDatasetParams(),
|
||||
@ -196,18 +164,8 @@ IteratorSaveAndRestoreTestCases() {
|
||||
CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})}};
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
|
||||
auto test_case = GetParam();
|
||||
TF_ASSERT_OK(Initialize(&test_case.dataset_params));
|
||||
TF_ASSERT_OK(CheckIteratorSaveAndRestore(
|
||||
kIteratorPrefix, test_case.expected_outputs, test_case.breakpoints));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
RangeDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
|
||||
::testing::ValuesIn(
|
||||
std::vector<IteratorSaveAndRestoreTestCase<RangeDatasetParams>>(
|
||||
IteratorSaveAndRestoreTestCases())));
|
||||
ITERATOR_SAVE_AND_RESTORE_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
|
||||
IteratorSaveAndRestoreTestCases())
|
||||
|
||||
TEST_F(RangeDatasetOpTest, ZeroStep) {
|
||||
auto range_dataset_params = ZeroStepRangeDatasetParams();
|
||||
|
Loading…
x
Reference in New Issue
Block a user