Refactor ZipDatasetOpTest
This commit is contained in:
parent
a280e8644b
commit
a588b1c97e
@ -58,10 +58,10 @@ class ZipDatasetOpTest : public DatasetOpsTestBase {
|
|||||||
// Create the placeholder names for the input components of `ZipDataset`.
|
// Create the placeholder names for the input components of `ZipDataset`.
|
||||||
input_datasets.emplace_back(strings::StrCat("input_dataset_", i));
|
input_datasets.emplace_back(strings::StrCat("input_dataset_", i));
|
||||||
}
|
}
|
||||||
node_def_ = test::function::NDef(
|
NodeDef node_def = test::function::NDef(
|
||||||
kNodeName, kOpName, input_datasets,
|
kNodeName, kOpName, input_datasets,
|
||||||
{{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}});
|
{{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}});
|
||||||
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel));
|
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,9 +74,6 @@ class ZipDatasetOpTest : public DatasetOpsTestBase {
|
|||||||
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
|
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
NodeDef node_def_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TestParam {
|
struct TestParam {
|
||||||
@ -85,8 +82,8 @@ struct TestParam {
|
|||||||
std::vector<int> breakpoints;
|
std::vector<int> breakpoints;
|
||||||
};
|
};
|
||||||
|
|
||||||
TestParam TestCase1() {
|
|
||||||
// Test case 1: the input datasets with same number of outputs.
|
// Test case 1: the input datasets with same number of outputs.
|
||||||
|
TestParam TestCase1() {
|
||||||
return {/*input_range_dataset_params*/
|
return {/*input_range_dataset_params*/
|
||||||
{RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}},
|
{RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}},
|
||||||
/*expected_outputs*/
|
/*expected_outputs*/
|
||||||
@ -99,8 +96,8 @@ TestParam TestCase1() {
|
|||||||
/*breakpoints*/ {0, 1, 4}};
|
/*breakpoints*/ {0, 1, 4}};
|
||||||
}
|
}
|
||||||
|
|
||||||
TestParam TestCase2() {
|
|
||||||
// Test case 2: the input datasets with different number of outputs.
|
// Test case 2: the input datasets with different number of outputs.
|
||||||
|
TestParam TestCase2() {
|
||||||
return {/*input_range_dataset_params*/
|
return {/*input_range_dataset_params*/
|
||||||
{RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}},
|
{RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}},
|
||||||
/*expected_outputs*/
|
/*expected_outputs*/
|
||||||
@ -113,67 +110,48 @@ TestParam TestCase2() {
|
|||||||
/*breakpoints*/ {0, 1, 4}};
|
/*breakpoints*/ {0, 1, 4}};
|
||||||
}
|
}
|
||||||
|
|
||||||
class ZipDatasetOpTestHelper : public ZipDatasetOpTest {
|
class ParameterizedZipDatasetOpTest
|
||||||
public:
|
: public ZipDatasetOpTest,
|
||||||
~ZipDatasetOpTestHelper() override {
|
public ::testing::WithParamInterface<TestParam> {};
|
||||||
if (dataset_) dataset_->Unref();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
TEST_P(ParameterizedZipDatasetOpTest, GetNext) {
|
||||||
Status CreateDatasetFromTestCase(const TestParam &test_case) {
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestParam &test_case = GetParam();
|
||||||
std::vector<Tensor> range_dataset_tensors;
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
TF_RETURN_IF_ERROR(CreateRangeDatasetTensors(
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
test_case.input_range_dataset_params, &range_dataset_tensors));
|
&range_dataset_tensors));
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
inputs.reserve(range_dataset_tensors.size());
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
for (auto &tensor : range_dataset_tensors) {
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
inputs.emplace_back(&tensor);
|
inputs.emplace_back(&tensor);
|
||||||
}
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
TF_RETURN_IF_ERROR(CreateZipDatasetKernel({DT_INT64},
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
{{num_tensors_per_slice}},
|
inputs.size(), &dataset_kernel));
|
||||||
inputs.size(), &dataset_kernel_));
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
TF_RETURN_IF_ERROR(CreateZipDatasetContext(dataset_kernel_.get(), &inputs,
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
&dataset_kernel_ctx_));
|
&dataset_kernel_ctx));
|
||||||
TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(),
|
DatasetBase *zip_dataset;
|
||||||
dataset_kernel_ctx_.get(), &dataset_));
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
return Status::OK();
|
&zip_dataset));
|
||||||
}
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
std::unique_ptr<IteratorContext> iterator_ctx;
|
||||||
Status CreateIteratorFromTestCase(const TestParam &test_case) {
|
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
|
||||||
TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case));
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_ASSERT_OK(
|
||||||
CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_));
|
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> dataset_kernel_;
|
|
||||||
std::unique_ptr<OpKernelContext> dataset_kernel_ctx_;
|
|
||||||
DatasetBase *dataset_ = nullptr; // owned by this class.
|
|
||||||
std::unique_ptr<IteratorContext> iterator_ctx_;
|
|
||||||
std::unique_ptr<IteratorBase> iterator_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ParameterizedDatasetTest
|
|
||||||
: public ZipDatasetOpTestHelper,
|
|
||||||
public ::testing::WithParamInterface<TestParam> {};
|
|
||||||
|
|
||||||
TEST_P(ParameterizedDatasetTest, GetNext) {
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
|
||||||
const TestParam &test_case = GetParam();
|
|
||||||
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
|
|
||||||
|
|
||||||
auto expected_outputs_it = test_case.expected_outputs.begin();
|
auto expected_outputs_it = test_case.expected_outputs.begin();
|
||||||
bool end_of_sequence = false;
|
bool end_of_sequence = false;
|
||||||
std::vector<Tensor> out_tensors;
|
std::vector<Tensor> out_tensors;
|
||||||
while (!end_of_sequence) {
|
while (!end_of_sequence) {
|
||||||
TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
|
TF_EXPECT_OK(
|
||||||
&end_of_sequence));
|
iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence));
|
||||||
if (!end_of_sequence) {
|
if (!end_of_sequence) {
|
||||||
for (const auto &tensor : out_tensors) {
|
for (const auto &tensor : out_tensors) {
|
||||||
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
|
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
@ -185,22 +163,92 @@ TEST_P(ParameterizedDatasetTest, GetNext) {
|
|||||||
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ZipDatasetOpTestHelper, DatasetName) {
|
TEST_F(ZipDatasetOpTest, DatasetNodeName) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
|
|
||||||
|
|
||||||
EXPECT_EQ(dataset_->type_string(), kOpName);
|
const TestParam &test_case = TestCase1();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(zip_dataset->node_name(), kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
|
TEST_F(ZipDatasetOpTest, DatasetTypeString) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
const TestParam &test_case = GetParam();
|
|
||||||
|
const TestParam &test_case = TestCase1();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(zip_dataset->type_string(), kOpName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputDtypes) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestParam &test_case = GetParam();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
|
||||||
DataTypeVector expected_output_dtypes;
|
DataTypeVector expected_output_dtypes;
|
||||||
expected_output_dtypes.reserve(num_tensors_per_slice);
|
expected_output_dtypes.reserve(num_tensors_per_slice);
|
||||||
@ -209,16 +257,35 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TF_EXPECT_OK(
|
TF_EXPECT_OK(
|
||||||
VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes));
|
VerifyTypesMatch(zip_dataset->output_dtypes(), expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
|
TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
const TestParam &test_case = GetParam();
|
const TestParam &test_case = GetParam();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
|
||||||
std::vector<PartialTensorShape> expected_output_shapes;
|
std::vector<PartialTensorShape> expected_output_shapes;
|
||||||
expected_output_shapes.reserve(num_tensors_per_slice);
|
expected_output_shapes.reserve(num_tensors_per_slice);
|
||||||
@ -226,43 +293,107 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
|
|||||||
expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
|
expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
|
TF_EXPECT_OK(VerifyShapesCompatible(zip_dataset->output_shapes(),
|
||||||
expected_output_shapes));
|
expected_output_shapes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParameterizedDatasetTest, Cardinality) {
|
TEST_P(ParameterizedZipDatasetOpTest, Cardinality) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
const TestParam &test_case = GetParam();
|
|
||||||
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
|
||||||
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
|
|
||||||
|
|
||||||
EXPECT_EQ(dataset_->Cardinality(),
|
const TestParam &test_case = GetParam();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(zip_dataset->Cardinality(),
|
||||||
test_case.expected_outputs.size() / num_tensors_per_slice);
|
test_case.expected_outputs.size() / num_tensors_per_slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ZipDatasetOpTestHelper, DatasetSave) {
|
TEST_F(ZipDatasetOpTest, DatasetSave) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
|
|
||||||
|
const TestParam &test_case = TestCase1();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
|
||||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||||
VariantTensorData data;
|
VariantTensorData data;
|
||||||
VariantTensorDataWriter writer(&data);
|
VariantTensorDataWriter writer(&data);
|
||||||
TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &writer));
|
TF_ASSERT_OK(zip_dataset->Save(serialization_ctx.get(), &writer));
|
||||||
TF_ASSERT_OK(writer.Flush());
|
TF_ASSERT_OK(writer.Flush());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
|
TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputDtypes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
const TestParam &test_case = GetParam();
|
const TestParam &test_case = GetParam();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
std::unique_ptr<IteratorContext> iterator_ctx;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
|
||||||
|
|
||||||
DataTypeVector expected_output_dtypes;
|
DataTypeVector expected_output_dtypes;
|
||||||
expected_output_dtypes.reserve(num_tensors_per_slice);
|
expected_output_dtypes.reserve(num_tensors_per_slice);
|
||||||
@ -271,16 +402,40 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TF_EXPECT_OK(
|
TF_EXPECT_OK(
|
||||||
VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes));
|
VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
|
TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
const TestParam &test_case = GetParam();
|
const TestParam &test_case = GetParam();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
std::unique_ptr<IteratorContext> iterator_ctx;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
|
||||||
|
|
||||||
std::vector<PartialTensorShape> expected_output_shapes;
|
std::vector<PartialTensorShape> expected_output_shapes;
|
||||||
expected_output_shapes.reserve(num_tensors_per_slice);
|
expected_output_shapes.reserve(num_tensors_per_slice);
|
||||||
@ -288,42 +443,94 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
|
|||||||
expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
|
expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
|
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
|
||||||
expected_output_shapes));
|
expected_output_shapes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ZipDatasetOpTestHelper, IteratorOutputPrefix) {
|
TEST_F(ZipDatasetOpTest, IteratorOutputPrefix) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
TF_ASSERT_OK(CreateIteratorFromTestCase(TestCase1()));
|
|
||||||
EXPECT_EQ(iterator_->prefix(), "Iterator::Zip");
|
const TestParam &test_case = TestCase1();
|
||||||
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
std::unique_ptr<IteratorContext> iterator_ctx;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
|
||||||
|
|
||||||
|
EXPECT_EQ(iterator->prefix(), "Iterator::Zip");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParameterizedDatasetTest, Roundtrip) {
|
TEST_P(ParameterizedZipDatasetOpTest, Roundtrip) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
const TestParam &test_case = GetParam();
|
const TestParam &test_case = GetParam();
|
||||||
auto expected_outputs_it = test_case.expected_outputs.begin();
|
std::vector<Tensor> range_dataset_tensors;
|
||||||
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
|
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||||
|
&range_dataset_tensors));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.reserve(range_dataset_tensors.size());
|
||||||
|
for (auto &tensor : range_dataset_tensors) {
|
||||||
|
inputs.emplace_back(&tensor);
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> dataset_kernel;
|
||||||
|
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||||
|
inputs.size(), &dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||||
|
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||||
|
&dataset_kernel_ctx));
|
||||||
|
DatasetBase *zip_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||||
|
&zip_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(zip_dataset);
|
||||||
|
std::unique_ptr<IteratorContext> iterator_ctx;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
|
||||||
|
|
||||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||||
|
|
||||||
bool end_of_sequence = false;
|
bool end_of_sequence = false;
|
||||||
std::vector<Tensor> out_tensors;
|
std::vector<Tensor> out_tensors;
|
||||||
|
auto expected_outputs_it = test_case.expected_outputs.begin();
|
||||||
int cur_iteration = 0;
|
int cur_iteration = 0;
|
||||||
for (int breakpoint : test_case.breakpoints) {
|
for (int breakpoint : test_case.breakpoints) {
|
||||||
VariantTensorData data;
|
VariantTensorData data;
|
||||||
VariantTensorDataWriter writer(&data);
|
VariantTensorDataWriter writer(&data);
|
||||||
TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer));
|
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
|
||||||
TF_EXPECT_OK(writer.Flush());
|
TF_EXPECT_OK(writer.Flush());
|
||||||
VariantTensorDataReader reader(&data);
|
VariantTensorDataReader reader(&data);
|
||||||
TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader));
|
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
|
||||||
|
|
||||||
while (cur_iteration < breakpoint) {
|
while (cur_iteration < breakpoint) {
|
||||||
TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
|
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
|
||||||
&end_of_sequence));
|
&end_of_sequence));
|
||||||
if (!end_of_sequence) {
|
if (!end_of_sequence) {
|
||||||
for (auto &tensor : out_tensors) {
|
for (auto &tensor : out_tensors) {
|
||||||
@ -335,7 +542,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
|
|||||||
cur_iteration++;
|
cur_iteration++;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (breakpoint >= dataset_->Cardinality()) {
|
if (breakpoint >= zip_dataset->Cardinality()) {
|
||||||
EXPECT_TRUE(end_of_sequence);
|
EXPECT_TRUE(end_of_sequence);
|
||||||
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
} else {
|
} else {
|
||||||
@ -345,7 +552,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
ZipDatasetOpTest, ParameterizedDatasetTest,
|
ZipDatasetOpTest, ParameterizedZipDatasetOpTest,
|
||||||
::testing::ValuesIn(std::vector<TestParam>({TestCase1(), TestCase2()})));
|
::testing::ValuesIn(std::vector<TestParam>({TestCase1(), TestCase2()})));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user