Refactor ZipDatasetOpTest

author Fei Hu 2019-03-18 21:31:54 -07:00
parent a280e8644b
commit a588b1c97e
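
The refactor replaces the stateful `ZipDatasetOpTestHelper` fixture with value-parameterized tests: each `TEST_P` body builds its own kernel, context, dataset, and iterator from the `TestParam` returned by `GetParam()`. Below is a minimal, self-contained sketch of that GoogleTest pattern only; the struct fields, fixture name, and cases are illustrative stand-ins, not the TensorFlow test code.

// Illustrative sketch of a value-parameterized GoogleTest suite; all names
// here are hypothetical and unrelated to the TensorFlow sources.
#include <vector>

#include "gtest/gtest.h"

namespace {

// Stand-in for a per-case parameter struct such as the diff's `TestParam`.
struct ExampleParam {
  int start;
  int stop;
  int expected_count;
};

class ParameterizedExampleTest
    : public ::testing::TestWithParam<ExampleParam> {};

TEST_P(ParameterizedExampleTest, CountsRangeElements) {
  const ExampleParam& test_case = GetParam();
  // Count the integers in [start, stop) and compare against the expectation.
  int count = 0;
  for (int i = test_case.start; i < test_case.stop; ++i) ++count;
  EXPECT_EQ(count, test_case.expected_count);
}

INSTANTIATE_TEST_SUITE_P(
    ExampleSuite, ParameterizedExampleTest,
    ::testing::ValuesIn(std::vector<ExampleParam>({{0, 3, 3}, {10, 15, 5}})));

}  // namespace

With this shape, adding a scenario only means appending another entry to the `ValuesIn` list, which is the role `TestCase1()` and `TestCase2()` play in the `INSTANTIATE_TEST_SUITE_P` call at the end of the diff below.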


@@ -58,10 +58,10 @@ class ZipDatasetOpTest : public DatasetOpsTestBase {
       // Create the placeholder names for the input components of `ZipDataset`.
       input_datasets.emplace_back(strings::StrCat("input_dataset_", i));
     }
-    node_def_ = test::function::NDef(
+    NodeDef node_def = test::function::NDef(
         kNodeName, kOpName, input_datasets,
         {{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}});
-    TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel));
+    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
     return Status::OK();
   }
@@ -74,9 +74,6 @@ class ZipDatasetOpTest : public DatasetOpsTestBase {
     TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
     return Status::OK();
   }
-
- private:
-  NodeDef node_def_;
 };
 
 struct TestParam {
@@ -85,8 +82,8 @@ struct TestParam {
   std::vector<int> breakpoints;
 };
 
+TestParam TestCase1() {
 // Test case 1: the input datasets with same number of outputs.
-TestParam TestCase1() {
   return {/*input_range_dataset_params*/
           {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}},
           /*expected_outputs*/
@@ -99,8 +96,8 @@ TestParam TestCase1() {
           /*breakpoints*/ {0, 1, 4}};
 }
 
+TestParam TestCase2() {
 // Test case 2: the input datasets with different number of outputs.
-TestParam TestCase2() {
   return {/*input_range_dataset_params*/
           {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}},
           /*expected_outputs*/
@@ -113,67 +110,48 @@ TestParam TestCase2() {
           /*breakpoints*/ {0, 1, 4}};
 }
 
-class ZipDatasetOpTestHelper : public ZipDatasetOpTest {
- public:
-  ~ZipDatasetOpTestHelper() override {
-    if (dataset_) dataset_->Unref();
-  }
-
- protected:
-  Status CreateDatasetFromTestCase(const TestParam &test_case) {
-    std::vector<Tensor> range_dataset_tensors;
-    range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
-    TF_RETURN_IF_ERROR(CreateRangeDatasetTensors(
-        test_case.input_range_dataset_params, &range_dataset_tensors));
-    gtl::InlinedVector<TensorValue, 4> inputs;
-    inputs.reserve(range_dataset_tensors.size());
-    for (auto &tensor : range_dataset_tensors) {
-      inputs.emplace_back(&tensor);
-    }
-    int num_tensors_per_slice = test_case.input_range_dataset_params.size();
-    TF_RETURN_IF_ERROR(CreateZipDatasetKernel({DT_INT64},
-                                              {{num_tensors_per_slice}},
-                                              inputs.size(), &dataset_kernel_));
-    TF_RETURN_IF_ERROR(CreateZipDatasetContext(dataset_kernel_.get(), &inputs,
-                                               &dataset_kernel_ctx_));
-    TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(),
-                                     dataset_kernel_ctx_.get(), &dataset_));
-    return Status::OK();
-  }
-
-  Status CreateIteratorFromTestCase(const TestParam &test_case) {
-    TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case));
-    TF_RETURN_IF_ERROR(
-        CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_));
-    TF_RETURN_IF_ERROR(
-        dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_));
-    return Status::OK();
-  }
-
-  std::unique_ptr<OpKernel> dataset_kernel_;
-  std::unique_ptr<OpKernelContext> dataset_kernel_ctx_;
-  DatasetBase *dataset_ = nullptr;  // owned by this class.
-  std::unique_ptr<IteratorContext> iterator_ctx_;
-  std::unique_ptr<IteratorBase> iterator_;
-};
-
-class ParameterizedDatasetTest
-    : public ZipDatasetOpTestHelper,
-      public ::testing::WithParamInterface<TestParam> {};
-
-TEST_P(ParameterizedDatasetTest, GetNext) {
-  int thread_num = 2, cpu_num = 2;
-  TF_ASSERT_OK(InitThreadPool(thread_num));
-  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-  const TestParam &test_case = GetParam();
-  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
+class ParameterizedZipDatasetOpTest
+    : public ZipDatasetOpTest,
+      public ::testing::WithParamInterface<TestParam> {};
+
+TEST_P(ParameterizedZipDatasetOpTest, GetNext) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  const TestParam &test_case = GetParam();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
+  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(
+      zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
 
   auto expected_outputs_it = test_case.expected_outputs.begin();
   bool end_of_sequence = false;
   std::vector<Tensor> out_tensors;
   while (!end_of_sequence) {
-    TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
-                                    &end_of_sequence));
+    TF_EXPECT_OK(
+        iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence));
     if (!end_of_sequence) {
       for (const auto &tensor : out_tensors) {
         EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
@@ -185,22 +163,92 @@ TEST_P(ParameterizedDatasetTest, GetNext) {
   EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
 }
 
-TEST_F(ZipDatasetOpTestHelper, DatasetName) {
+TEST_F(ZipDatasetOpTest, DatasetNodeName) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-  TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
-  EXPECT_EQ(dataset_->type_string(), kOpName);
+
+  const TestParam &test_case = TestCase1();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
+  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
+  EXPECT_EQ(zip_dataset->node_name(), kNodeName);
 }
 
-TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
+TEST_F(ZipDatasetOpTest, DatasetTypeString) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-  const TestParam &test_case = GetParam();
+
+  const TestParam &test_case = TestCase1();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
   int num_tensors_per_slice = test_case.input_range_dataset_params.size();
-  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
+  EXPECT_EQ(zip_dataset->type_string(), kOpName);
+}
+
+TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputDtypes) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  const TestParam &test_case = GetParam();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
+  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
 
   DataTypeVector expected_output_dtypes;
   expected_output_dtypes.reserve(num_tensors_per_slice);
@@ -209,16 +257,35 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
   }
 
   TF_EXPECT_OK(
-      VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
+      VerifyTypesMatch(zip_dataset->output_dtypes(), expected_output_dtypes));
 }
 
-TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
+TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputShapes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
   const TestParam &test_case = GetParam();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
   int num_tensors_per_slice = test_case.input_range_dataset_params.size();
-  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
 
   std::vector<PartialTensorShape> expected_output_shapes;
   expected_output_shapes.reserve(num_tensors_per_slice);
@@ -226,43 +293,107 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
     expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
   }
 
-  TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
+  TF_EXPECT_OK(VerifyShapesCompatible(zip_dataset->output_shapes(),
                                       expected_output_shapes));
 }
 
-TEST_P(ParameterizedDatasetTest, Cardinality) {
+TEST_P(ParameterizedZipDatasetOpTest, Cardinality) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-  const TestParam &test_case = GetParam();
-  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
-  TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
-  EXPECT_EQ(dataset_->Cardinality(),
+
+  const TestParam &test_case = GetParam();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
+  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
+  EXPECT_EQ(zip_dataset->Cardinality(),
             test_case.expected_outputs.size() / num_tensors_per_slice);
 }
 
-TEST_F(ZipDatasetOpTestHelper, DatasetSave) {
+TEST_F(ZipDatasetOpTest, DatasetSave) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-  TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
+
+  const TestParam &test_case = TestCase1();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
+  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
 
   std::unique_ptr<SerializationContext> serialization_ctx;
   TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
   VariantTensorData data;
   VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &writer));
+  TF_ASSERT_OK(zip_dataset->Save(serialization_ctx.get(), &writer));
   TF_ASSERT_OK(writer.Flush());
 }
 
-TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
+TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
   const TestParam &test_case = GetParam();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
   int num_tensors_per_slice = test_case.input_range_dataset_params.size();
-  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(
+      zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
 
   DataTypeVector expected_output_dtypes;
   expected_output_dtypes.reserve(num_tensors_per_slice);
@@ -271,16 +402,40 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
   }
 
   TF_EXPECT_OK(
-      VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
+      VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes));
 }
 
-TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
+TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputShapes) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
   const TestParam &test_case = GetParam();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
   int num_tensors_per_slice = test_case.input_range_dataset_params.size();
-  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(
+      zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
 
   std::vector<PartialTensorShape> expected_output_shapes;
   expected_output_shapes.reserve(num_tensors_per_slice);
@@ -288,42 +443,94 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
     expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
   }
 
-  TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
+  TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
                                       expected_output_shapes));
 }
 
-TEST_F(ZipDatasetOpTestHelper, IteratorOutputPrefix) {
+TEST_F(ZipDatasetOpTest, IteratorOutputPrefix) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
-  TF_ASSERT_OK(CreateIteratorFromTestCase(TestCase1()));
-  EXPECT_EQ(iterator_->prefix(), "Iterator::Zip");
+
+  const TestParam &test_case = TestCase1();
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
+  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(
+      zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
+  EXPECT_EQ(iterator->prefix(), "Iterator::Zip");
 }
 
-TEST_P(ParameterizedDatasetTest, Roundtrip) {
+TEST_P(ParameterizedZipDatasetOpTest, Roundtrip) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
   const TestParam &test_case = GetParam();
-  auto expected_outputs_it = test_case.expected_outputs.begin();
-  TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
+  std::vector<Tensor> range_dataset_tensors;
+  range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
+  TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
+                                         &range_dataset_tensors));
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  inputs.reserve(range_dataset_tensors.size());
+  for (auto &tensor : range_dataset_tensors) {
+    inputs.emplace_back(&tensor);
+  }
+  std::unique_ptr<OpKernel> dataset_kernel;
+  int num_tensors_per_slice = test_case.input_range_dataset_params.size();
+  TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
+                                      inputs.size(), &dataset_kernel));
+  std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
+  TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
+                                       &dataset_kernel_ctx));
+  DatasetBase *zip_dataset;
+  TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
+                             &zip_dataset));
+  core::ScopedUnref scoped_unref(zip_dataset);
+  std::unique_ptr<IteratorContext> iterator_ctx;
+  TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
+  std::unique_ptr<IteratorBase> iterator;
+  TF_ASSERT_OK(
+      zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
 
   std::unique_ptr<SerializationContext> serialization_ctx;
   TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
 
   bool end_of_sequence = false;
   std::vector<Tensor> out_tensors;
+  auto expected_outputs_it = test_case.expected_outputs.begin();
   int cur_iteration = 0;
   for (int breakpoint : test_case.breakpoints) {
     VariantTensorData data;
     VariantTensorDataWriter writer(&data);
-    TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer));
+    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
     TF_EXPECT_OK(writer.Flush());
     VariantTensorDataReader reader(&data);
-    TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader));
+    TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
     while (cur_iteration < breakpoint) {
-      TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
+      TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
                                      &end_of_sequence));
       if (!end_of_sequence) {
         for (auto &tensor : out_tensors) {
@@ -335,7 +542,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
       cur_iteration++;
     }
 
-    if (breakpoint >= dataset_->Cardinality()) {
+    if (breakpoint >= zip_dataset->Cardinality()) {
       EXPECT_TRUE(end_of_sequence);
       EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
     } else {
@@ -345,7 +552,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
 }
 
 INSTANTIATE_TEST_SUITE_P(
-    ZipDatasetOpTest, ParameterizedDatasetTest,
+    ZipDatasetOpTest, ParameterizedZipDatasetOpTest,
     ::testing::ValuesIn(std::vector<TestParam>({TestCase1(), TestCase2()})));
 
 }  // namespace