Merge pull request #26860 from feihugis:Refactor_Dataset_Tests

PiperOrigin-RevId: 239862595
This commit is contained in:
TensorFlower Gardener 2019-03-22 15:19:37 -07:00
commit d395bd4682
6 changed files with 1524 additions and 993 deletions

View File

@ -48,10 +48,10 @@ class ConcatenateDatasetOpTest : public DatasetOpsTestBase {
const DataTypeVector &output_types, const DataTypeVector &output_types,
const std::vector<PartialTensorShape> &output_shapes, const std::vector<PartialTensorShape> &output_shapes,
std::unique_ptr<OpKernel> *op_kernel) { std::unique_ptr<OpKernel> *op_kernel) {
node_def_ = test::function::NDef( NodeDef node_def = test::function::NDef(
kNodeName, kOpName, {"input_dataset", "another_dataset"}, kNodeName, kOpName, {"input_dataset", "another_dataset"},
{{"output_types", output_types}, {"output_shapes", output_shapes}}); {{"output_types", output_types}, {"output_shapes", output_shapes}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
return Status::OK(); return Status::OK();
} }
@ -64,12 +64,9 @@ class ConcatenateDatasetOpTest : public DatasetOpsTestBase {
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK(); return Status::OK();
} }
private:
NodeDef node_def_;
}; };
struct TestParam { struct TestCase {
std::vector<std::vector<Tensor>> input_tensors; std::vector<std::vector<Tensor>> input_tensors;
std::vector<Tensor> expected_outputs; std::vector<Tensor> expected_outputs;
DataTypeVector expected_output_dtypes; DataTypeVector expected_output_dtypes;
@ -77,8 +74,9 @@ struct TestParam {
int64 expected_cardinality; int64 expected_cardinality;
std::vector<int> breakpoints; std::vector<int> breakpoints;
}; };
TestParam TestCase1() {
// Test case 1: same shape. // Test case 1: same shape.
TestCase SameShapeTestCase() {
return {/*input_tensors*/ return {/*input_tensors*/
{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2},
{1, 2, 3, 4}), {1, 2, 3, 4}),
@ -104,8 +102,8 @@ TestParam TestCase1() {
/*breakpoints*/ {0, 2, 5}}; /*breakpoints*/ {0, 2, 5}};
} }
TestParam TestCase2() { // Test case 2: different shape.
// Test case 2: different shape. TestCase DifferentShapeTestCase() {
return { return {
/*input_tensors*/ /*input_tensors*/
{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3}, {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 3},
@ -131,64 +129,59 @@ TestParam TestCase2() {
/*breakpoints*/ {0, 2, 5}}; /*breakpoints*/ {0, 2, 5}};
} }
class ConcatenateDatasetOpTestHelper : public ConcatenateDatasetOpTest { // Test case 3: different dtypes
public: TestCase DifferentDtypeTestCase() {
~ConcatenateDatasetOpTestHelper() override { return {/*input_tensors*/ {{DatasetOpsTestBase::CreateTensor<int64>(
if (dataset_) dataset_->Unref(); TensorShape({2, 2}), {1, 2, 3, 4})},
} {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}},
/*expected_outputs*/ {},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({2})},
/*expected_cardinality*/ 0,
/*breakpoints*/ {}};
}
protected: class ParameterizedConcatenateDatasetOpTest
Status CreateDatasetFromTestCase(const TestParam &test_case) { : public ConcatenateDatasetOpTest,
std::vector<Tensor> tensor_slice_dataset_tensors; public ::testing::WithParamInterface<TestCase> {};
TF_RETURN_IF_ERROR(CreateTensorSliceDatasetTensors(
test_case.input_tensors, &tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
TF_RETURN_IF_ERROR(CreateConcatenateDatasetKernel(
test_case.expected_output_dtypes, test_case.expected_output_shapes,
&dataset_kernel_));
TF_RETURN_IF_ERROR(CreateConcatenateDatasetContext(
dataset_kernel_.get(), &inputs, &dataset_kernel_ctx_));
TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(),
dataset_kernel_ctx_.get(), &dataset_));
return Status::OK();
}
Status CreateIteratorFromTestCase(const TestParam &test_case) { TEST_P(ParameterizedConcatenateDatasetOpTest, GetNext) {
TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case));
TF_RETURN_IF_ERROR(
CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_));
TF_RETURN_IF_ERROR(
dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_));
return Status::OK();
}
std::unique_ptr<OpKernel> dataset_kernel_;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx_;
DatasetBase *dataset_ = nullptr; // owned by this class.
std::unique_ptr<IteratorContext> iterator_ctx_;
std::unique_ptr<IteratorBase> iterator_;
};
class ParameterizedDatasetTest
: public ConcatenateDatasetOpTestHelper,
public ::testing::WithParamInterface<TestParam> {};
TEST_P(ParameterizedDatasetTest, GetNext) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); const TestCase &test_case = GetParam();
std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator",
&iterator));
auto expected_outputs_it = test_case.expected_outputs.begin(); auto expected_outputs_it = test_case.expected_outputs.begin();
bool end_of_sequence = false; bool end_of_sequence = false;
std::vector<Tensor> out_tensors; std::vector<Tensor> out_tensors;
while (!end_of_sequence) { while (!end_of_sequence) {
TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, TF_EXPECT_OK(
&end_of_sequence)); iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence));
if (!end_of_sequence) { if (!end_of_sequence) {
for (const auto &tensor : out_tensors) { for (const auto &tensor : out_tensors) {
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
@ -200,113 +193,334 @@ TEST_P(ParameterizedDatasetTest, GetNext) {
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
} }
TEST_F(ConcatenateDatasetOpTestHelper, DifferentDtypes) { TEST_F(ConcatenateDatasetOpTest, DifferentDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TestParam test_case_with_different_dtypes = { const TestCase &test_case = DifferentDtypeTestCase();
/*input_tensors*/ { std::vector<Tensor> tensor_slice_dataset_tensors;
{CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4})}, TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
{CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}}, &tensor_slice_dataset_tensors));
/*expected_outputs*/ {}, gtl::InlinedVector<TensorValue, 4> inputs;
/*expected_output_dtypes*/ {DT_INT64}, for (auto &tensor : tensor_slice_dataset_tensors) {
/*expected_output_shapes*/ {PartialTensorShape({2})}, inputs.emplace_back(&tensor);
/*expected_cardinality*/ 0, }
/*breakpoints*/ {}}; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
EXPECT_EQ(CreateDatasetFromTestCase(test_case_with_different_dtypes).code(), test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
EXPECT_EQ(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset)
.code(),
tensorflow::error::INVALID_ARGUMENT); tensorflow::error::INVALID_ARGUMENT);
} }
TEST_F(ConcatenateDatasetOpTestHelper, DatasetName) { TEST_F(ConcatenateDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
EXPECT_EQ(dataset_->type_string(), kOpName); const TestCase &test_case = SameShapeTestCase();
std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
EXPECT_EQ(concatenate_dataset->node_name(), kNodeName);
} }
TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { TEST_F(ConcatenateDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); const TestCase &test_case = SameShapeTestCase();
TF_EXPECT_OK(VerifyTypesMatch(dataset_->output_dtypes(), std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
EXPECT_EQ(concatenate_dataset->type_string(), kOpName);
}
TEST_P(ParameterizedConcatenateDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
TF_EXPECT_OK(VerifyTypesMatch(concatenate_dataset->output_dtypes(),
test_case.expected_output_dtypes)); test_case.expected_output_dtypes));
} }
TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { TEST_P(ParameterizedConcatenateDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); const TestCase &test_case = GetParam();
TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(), std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
TF_EXPECT_OK(VerifyShapesCompatible(concatenate_dataset->output_shapes(),
test_case.expected_output_shapes)); test_case.expected_output_shapes));
} }
TEST_P(ParameterizedDatasetTest, Cardinality) { TEST_P(ParameterizedConcatenateDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
EXPECT_EQ(dataset_->Cardinality(), GetParam().expected_cardinality); const TestCase &test_case = GetParam();
std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
EXPECT_EQ(concatenate_dataset->Cardinality(), test_case.expected_cardinality);
} }
TEST_F(ConcatenateDatasetOpTestHelper, DatasetSave) { TEST_F(ConcatenateDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
const TestCase &test_case = SameShapeTestCase();
std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
std::unique_ptr<SerializationContext> serialization_ctx; std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data; VariantTensorData data;
VariantTensorDataWriter writer(&data); VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &writer)); TF_ASSERT_OK(concatenate_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush()); TF_ASSERT_OK(writer.Flush());
} }
TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { TEST_P(ParameterizedConcatenateDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); const TestCase &test_case = GetParam();
TF_EXPECT_OK(VerifyTypesMatch(iterator_->output_dtypes(), std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator",
&iterator));
TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
test_case.expected_output_dtypes)); test_case.expected_output_dtypes));
} }
TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { TEST_P(ParameterizedConcatenateDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); const TestCase &test_case = GetParam();
TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(), std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator",
&iterator));
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
test_case.expected_output_shapes)); test_case.expected_output_shapes));
} }
TEST_F(ConcatenateDatasetOpTestHelper, IteratorOutputPrefix) { TEST_F(ConcatenateDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateIteratorFromTestCase(TestCase1()));
EXPECT_EQ(iterator_->prefix(), "Iterator::Concatenate"); const TestCase &test_case = SameShapeTestCase();
std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator",
&iterator));
EXPECT_EQ(iterator->prefix(), "Iterator::Concatenate");
} }
TEST_P(ParameterizedDatasetTest, Roundtrip) { TEST_P(ParameterizedConcatenateDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam(); const TestCase &test_case = GetParam();
auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator",
&iterator));
std::unique_ptr<SerializationContext> serialization_ctx; std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
@ -314,18 +528,19 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
bool end_of_sequence = false; bool end_of_sequence = false;
std::vector<Tensor> out_tensors; std::vector<Tensor> out_tensors;
int cur_iteration = 0; int cur_iteration = 0;
auto expected_outputs_it = test_case.expected_outputs.begin();
std::vector<int> breakpoints = GetParam().breakpoints; std::vector<int> breakpoints = GetParam().breakpoints;
for (int breakpoint : breakpoints) { for (int breakpoint : breakpoints) {
VariantTensorData data; VariantTensorData data;
VariantTensorDataWriter writer(&data); VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer)); TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush()); TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data); VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader)); TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
while (cur_iteration < breakpoint) { while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
&end_of_sequence)); &end_of_sequence));
if (!end_of_sequence) { if (!end_of_sequence) {
for (auto &tensor : out_tensors) { for (auto &tensor : out_tensors) {
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
@ -336,7 +551,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
cur_iteration++; cur_iteration++;
} }
if (breakpoint >= dataset_->Cardinality()) { if (breakpoint >= concatenate_dataset->Cardinality()) {
EXPECT_TRUE(end_of_sequence); EXPECT_TRUE(end_of_sequence);
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
} else { } else {
@ -345,9 +560,10 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
} }
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(ConcatenateDatasetOpTest,
ConcatenateDatasetOpTest, ParameterizedDatasetTest, ParameterizedConcatenateDatasetOpTest,
::testing::ValuesIn(std::vector<TestParam>({TestCase1(), TestCase2()}))); ::testing::ValuesIn(std::vector<TestCase>(
{SameShapeTestCase(), DifferentShapeTestCase()})));
} // namespace } // namespace
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -41,10 +41,10 @@ class RepeatDatasetOpTest : public DatasetOpsTestBase {
const DataTypeVector &output_types, const DataTypeVector &output_types,
const std::vector<PartialTensorShape> &output_shapes, const std::vector<PartialTensorShape> &output_shapes,
std::unique_ptr<OpKernel> *op_kernel) { std::unique_ptr<OpKernel> *op_kernel) {
node_def_ = test::function::NDef( NodeDef node_def = test::function::NDef(
kNodeName, kOpName, {"input_dataset", "count"}, kNodeName, kOpName, {"input_dataset", "count"},
{{"output_types", output_types}, {"output_shapes", output_shapes}}); {{"output_types", output_types}, {"output_shapes", output_shapes}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
return Status::OK(); return Status::OK();
} }
@ -56,9 +56,6 @@ class RepeatDatasetOpTest : public DatasetOpsTestBase {
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK(); return Status::OK();
} }
private:
NodeDef node_def_;
}; };
struct TestCase { struct TestCase {
@ -123,11 +120,11 @@ TestCase ForeverRepeatTestCase() {
/*breakpoints*/ {0, 1, 3}}; /*breakpoints*/ {0, 1, 3}};
} }
class ParameterizedDatasetTest class ParameterizedDatasetOpTest
: public RepeatDatasetOpTest, : public RepeatDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {}; public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParameterizedDatasetTest, GetNext) { TEST_P(ParameterizedDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -199,7 +196,38 @@ TEST_P(ParameterizedDatasetTest, GetNext) {
} }
} }
TEST_F(RepeatDatasetOpTest, DatasetName) { TEST_F(RepeatDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = FiniteRepeatTestCase();
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
gtl::InlinedVector<TensorValue, 4> inputs_for_repeat_dataset;
inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor);
inputs_for_repeat_dataset.emplace_back(&count);
std::unique_ptr<OpKernel> repeat_dataset_kernel;
TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&repeat_dataset_kernel));
std::unique_ptr<OpKernelContext> repeat_dataset_context;
TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(),
&inputs_for_repeat_dataset,
&repeat_dataset_context));
DatasetBase *repeat_dataset;
TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(),
repeat_dataset_context.get(), &repeat_dataset));
core::ScopedUnref scoped_unref(repeat_dataset);
EXPECT_EQ(repeat_dataset->node_name(), kNodeName);
}
TEST_F(RepeatDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -230,7 +258,7 @@ TEST_F(RepeatDatasetOpTest, DatasetName) {
EXPECT_EQ(repeat_dataset->type_string(), kOpName); EXPECT_EQ(repeat_dataset->type_string(), kOpName);
} }
TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { TEST_P(ParameterizedDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -260,7 +288,7 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
test_case.expected_output_dtypes)); test_case.expected_output_dtypes));
} }
TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { TEST_P(ParameterizedDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -290,7 +318,7 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
test_case.expected_output_shapes)); test_case.expected_output_shapes));
} }
TEST_P(ParameterizedDatasetTest, Cardinality) { TEST_P(ParameterizedDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -355,7 +383,7 @@ TEST_F(RepeatDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(writer.Flush()); TF_ASSERT_OK(writer.Flush());
} }
TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { TEST_P(ParameterizedDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -392,7 +420,7 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
test_case.expected_output_dtypes)); test_case.expected_output_dtypes));
} }
TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { TEST_P(ParameterizedDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -429,7 +457,7 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
test_case.expected_output_shapes)); test_case.expected_output_shapes));
} }
TEST_P(ParameterizedDatasetTest, IteratorOutputPrefix) { TEST_P(ParameterizedDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -471,7 +499,7 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputPrefix) {
} }
} }
TEST_P(ParameterizedDatasetTest, Roundtrip) { TEST_P(ParameterizedDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
@ -552,7 +580,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
} }
} }
INSTANTIATE_TEST_SUITE_P(RepeatDatasetOpTest, ParameterizedDatasetTest, INSTANTIATE_TEST_SUITE_P(RepeatDatasetOpTest, ParameterizedDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>( ::testing::ValuesIn(std::vector<TestCase>(
{FiniteRepeatTestCase(), EmptyRepeatTestCase(), {FiniteRepeatTestCase(), EmptyRepeatTestCase(),
ForeverRepeatTestCase()}))); ForeverRepeatTestCase()})));

View File

@ -13,19 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h" #include "tensorflow/core/kernels/data/dataset_test_base.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
@ -39,10 +27,10 @@ class SparseTensorSliceDatasetOpTest : public DatasetOpsTestBase {
// Creates a new SparseTensorSliceDataset op kernel. // Creates a new SparseTensorSliceDataset op kernel.
Status CreateSparseTensorSliceDatasetKernel( Status CreateSparseTensorSliceDatasetKernel(
DataType tvalues, std::unique_ptr<OpKernel> *op_kernel) { DataType tvalues, std::unique_ptr<OpKernel> *op_kernel) {
node_def_ = test::function::NDef(kNodeName, kOpName, NodeDef node_def = test::function::NDef(
{"indices", "values", "dense_shape"}, kNodeName, kOpName, {"indices", "values", "dense_shape"},
{{"Tvalues", tvalues}}); {{"Tvalues", tvalues}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
return Status::OK(); return Status::OK();
} }
@ -54,9 +42,6 @@ class SparseTensorSliceDatasetOpTest : public DatasetOpsTestBase {
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK(); return Status::OK();
} }
private:
NodeDef node_def_;
}; };
struct SparseTensorParam { struct SparseTensorParam {
@ -71,123 +56,180 @@ struct TestCase {
std::vector<int> breakpoints; std::vector<int> breakpoints;
}; };
std::vector<TestCase> TestCases() { TestCase TwoDimsTestCase() {
return { return {
{{{DatasetOpsTestBase::CreateTensor<int64>({2, 2}, {0, 0, 1, 1})}, /*input_sparse_tensor*/
{DatasetOpsTestBase::CreateTensor<int32>({2}, {888, 999})}, {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({2, 2},
{DatasetOpsTestBase::CreateTensor<int64>({2}, {2, 2})}}, {0, 0, 1, 1}),
{{{DatasetOpsTestBase::CreateTensor<int64>({1, 1}, {0})}, /*values*/ DatasetOpsTestBase::CreateTensor<int32>({2}, {888, 999}),
{DatasetOpsTestBase::CreateTensor<int32>({1}, {888})}, /*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({2}, {2, 2})},
{DatasetOpsTestBase::CreateTensor<int64>({1}, {2})}}, /*expected_outputs*/
{{DatasetOpsTestBase::CreateTensor<int64>({1, 1}, {1})}, {{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 1}, {0}),
{DatasetOpsTestBase::CreateTensor<int32>({1}, {999})}, /*values*/ DatasetOpsTestBase::CreateTensor<int32>({1}, {888}),
{DatasetOpsTestBase::CreateTensor<int64>({1}, {2})}}}, /*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({1}, {2})},
{0, 1, 2}}, // 2-D sparse tensor {/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 1}, {1}),
{{{DatasetOpsTestBase::CreateTensor<int64>({2, 3}, {0, 0, 0, 1, 1, 1})}, /*values*/ DatasetOpsTestBase::CreateTensor<int32>({1}, {999}),
{DatasetOpsTestBase::CreateTensor<double>({2}, {888.0, 999.0})}, /*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({1}, {2})}},
{DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})}}, /*breakpoints*/ {0, 1, 2}};
{{{DatasetOpsTestBase::CreateTensor<int64>({1, 2}, {0, 0})},
{DatasetOpsTestBase::CreateTensor<double>({1}, {888.0})},
{DatasetOpsTestBase::CreateTensor<int64>({2}, {2, 2})}},
{{DatasetOpsTestBase::CreateTensor<int64>({1, 2}, {1, 1})},
{DatasetOpsTestBase::CreateTensor<double>({1}, {999.0})},
{DatasetOpsTestBase::CreateTensor<int64>({2}, {2, 2})}}},
{0, 1, 2}}, // 3-D sparse tensor
{{{DatasetOpsTestBase::CreateTensor<int64>({2, 4},
{0, 0, 0, 0, 1, 1, 1, 1})},
{DatasetOpsTestBase::CreateTensor<string>({2}, {"a", "b"})},
{DatasetOpsTestBase::CreateTensor<int64>({4}, {3, 2, 2, 2})}},
{{{DatasetOpsTestBase::CreateTensor<int64>({1, 3}, {0, 0, 0})},
{DatasetOpsTestBase::CreateTensor<string>({1}, {"a"})},
{DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})}},
{{DatasetOpsTestBase::CreateTensor<int64>({1, 3}, {1, 1, 1})},
{DatasetOpsTestBase::CreateTensor<string>({1}, {"b"})},
{DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})}},
{{DatasetOpsTestBase::CreateTensor<int64>({0, 3}, {})},
{DatasetOpsTestBase::CreateTensor<string>({0}, {})},
{DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})}}},
{0, 1, 3}}, // 4-D sparse tensor
{{{DatasetOpsTestBase::CreateTensor<int64>(
{2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1})},
{DatasetOpsTestBase::CreateTensor<int32>({2}, {888, 999})},
{DatasetOpsTestBase::CreateTensor<int64>({5}, {3, 2, 2, 2, 2})}},
{{{DatasetOpsTestBase::CreateTensor<int64>({1, 4}, {0, 0, 0, 0})},
{DatasetOpsTestBase::CreateTensor<int32>({1}, {888})},
{DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})}},
{{DatasetOpsTestBase::CreateTensor<int64>({1, 4}, {1, 1, 1, 1})},
{DatasetOpsTestBase::CreateTensor<int32>({1}, {999})},
{DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})}},
{{DatasetOpsTestBase::CreateTensor<int64>({0, 4}, {})},
{DatasetOpsTestBase::CreateTensor<int32>({0}, {})},
{DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})}}},
{0, 1, 3}} // 5-D sparse tensor
};
} }
TEST_F(SparseTensorSliceDatasetOpTest, GetNext) { TestCase ThreeDimsTestCase() {
return {
/*input_sparse_tensor*/
{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({2, 3},
{0, 0, 0, 1, 1, 1}),
/*values*/ DatasetOpsTestBase::CreateTensor<double>({2}, {888.0, 999.0}),
/*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})},
/*expected_outputs*/
{{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 2}, {0, 0}),
/*values*/ DatasetOpsTestBase::CreateTensor<double>({1}, {888.0}),
/*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({2}, {2, 2})},
{{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 2}, {1, 1})},
{/*values*/ DatasetOpsTestBase::CreateTensor<double>({1}, {999.0})},
{/*dense_shape*/ DatasetOpsTestBase::CreateTensor<int64>({2},
{2, 2})}}},
/*breakpoints*/ {0, 1, 2}};
}
TestCase FourDimsTestCase() {
return {
/*input_sparse_tensor*/
{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>(
{2, 4}, {0, 0, 0, 0, 1, 1, 1, 1}),
/*values*/ DatasetOpsTestBase::CreateTensor<string>({2}, {"a", "b"}),
/*dense_shape*/
DatasetOpsTestBase::CreateTensor<int64>({4}, {3, 2, 2, 2})},
/*expected_outputs*/
{{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 3}, {0, 0, 0}),
/*values*/ DatasetOpsTestBase::CreateTensor<string>({1}, {"a"}),
/*dense_shape*/
DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})},
{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 3}, {1, 1, 1}),
/*values*/ DatasetOpsTestBase::CreateTensor<string>({1}, {"b"}),
/*dense_shape*/
DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})},
{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({0, 3}, {}),
/*values*/ DatasetOpsTestBase::CreateTensor<string>({0}, {}),
/*dense_shape*/
DatasetOpsTestBase::CreateTensor<int64>({3}, {2, 2, 2})}},
/*breakpoints*/ {0, 1, 3}};
}
TestCase FiveDimsTestCase() {
return {/*input_sparse_tensor*/
{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>(
{2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}),
/*values*/ DatasetOpsTestBase::CreateTensor<int32>({2}, {888, 999}),
/*dense_shape*/
DatasetOpsTestBase::CreateTensor<int64>({5}, {3, 2, 2, 2, 2})},
/*expected_outputs*/
{{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 4},
{0, 0, 0, 0}),
/*values*/ DatasetOpsTestBase::CreateTensor<int32>({1}, {888}),
/*dense_shape*/
DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})},
{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({1, 4},
{1, 1, 1, 1}),
/*values*/ DatasetOpsTestBase::CreateTensor<int32>({1}, {999}),
/*dense_shape*/
DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})},
{/*indices*/ DatasetOpsTestBase::CreateTensor<int64>({0, 4}, {}),
/*values*/ DatasetOpsTestBase::CreateTensor<int32>({0}, {}),
/*dense_shape*/
DatasetOpsTestBase::CreateTensor<int64>({4}, {2, 2, 2, 2})}},
/*breakpoints*/ {0, 1, 3}};
}
class ParameterizedSparseTensorSliceDatasetOpTest
: public SparseTensorSliceDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
for (auto &test_case : TestCases()) { const TestCase &test_case = GetParam();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype();
DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.indices, &input_sparse_tensor.values, &input_sparse_tensor.dense_shape};
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
DatasetBase *dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&dataset));
core::ScopedUnref scoped_unref(dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(
CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_slice = 0;
while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
&end_of_sequence));
if (!end_of_sequence) {
TF_EXPECT_OK(
ExpectEqual(out_tensors[0], expected_outputs[cur_slice].indices));
TF_EXPECT_OK(
ExpectEqual(out_tensors[1], expected_outputs[cur_slice].values));
TF_EXPECT_OK(ExpectEqual(out_tensors[2],
expected_outputs[cur_slice].dense_shape));
cur_slice++;
}
}
}
}
TEST_F(SparseTensorSliceDatasetOpTest, DatasetName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
int N = 2;
const int NDIM = 2;
Tensor indices = CreateTensor<int64>(TensorShape({N, NDIM}), {0, 0, 1, 1});
Tensor values = CreateTensor<int32>(TensorShape({N}), {888, 999});
Tensor dense_shape = CreateTensor<int64>(TensorShape({NDIM}), {5, 5});
gtl::InlinedVector<TensorValue, 4> inputs = {&indices, &values, &dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
DatasetBase *dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
core::ScopedUnref scoped_unref(dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
auto expected_outputs_it = expected_outputs.begin();
while (!end_of_sequence) {
TF_EXPECT_OK(
iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence));
if (!end_of_sequence) {
TF_EXPECT_OK(ExpectEqual(out_tensors[0], expected_outputs_it->indices));
TF_EXPECT_OK(ExpectEqual(out_tensors[1], expected_outputs_it->values));
TF_EXPECT_OK(
ExpectEqual(out_tensors[2], expected_outputs_it->dense_shape));
expected_outputs_it++;
}
}
EXPECT_EQ(expected_outputs_it, expected_outputs.end());
}
TEST_F(SparseTensorSliceDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = TwoDimsTestCase();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
DataType tvalues = input_sparse_tensor.values.dtype();
gtl::InlinedVector<TensorValue, 4> inputs = {
&input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
DatasetBase *dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
core::ScopedUnref scoped_unref(dataset);
EXPECT_EQ(dataset->node_name(), kNodeName);
}
TEST_F(SparseTensorSliceDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = TwoDimsTestCase();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
DataType tvalues = input_sparse_tensor.values.dtype();
gtl::InlinedVector<TensorValue, 4> inputs = {
&input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
@ -199,99 +241,90 @@ TEST_F(SparseTensorSliceDatasetOpTest, DatasetName) {
EXPECT_EQ(dataset->type_string(), kOpName); EXPECT_EQ(dataset->type_string(), kOpName);
} }
TEST_F(SparseTensorSliceDatasetOpTest, DatasetOutputDtypes) { TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
for (auto &test_case : TestCases()) { const TestCase &test_case = GetParam();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype();
DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.indices, &input_sparse_tensor.values, &input_sparse_tensor.dense_shape};
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); DatasetBase *dataset;
DatasetBase *dataset; TF_ASSERT_OK(
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
&dataset)); core::ScopedUnref scoped_unref(dataset);
core::ScopedUnref scoped_unref(dataset);
DataTypeVector expected_output_dtypes = { DataTypeVector expected_output_dtypes = {
expected_outputs[0].indices.dtype(), expected_outputs[0].values.dtype(), expected_outputs[0].indices.dtype(), expected_outputs[0].values.dtype(),
expected_outputs[0].dense_shape.dtype()}; expected_outputs[0].dense_shape.dtype()};
TF_EXPECT_OK( TF_EXPECT_OK(
VerifyTypesMatch(dataset->output_dtypes(), expected_output_dtypes)); VerifyTypesMatch(dataset->output_dtypes(), expected_output_dtypes));
}
} }
TEST_F(SparseTensorSliceDatasetOpTest, DatasetOutputShapes) { TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
for (auto &test_case : TestCases()) { const TestCase &test_case = GetParam();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype();
DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.indices, &input_sparse_tensor.values, &input_sparse_tensor.dense_shape};
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); DatasetBase *dataset;
DatasetBase *dataset; TF_ASSERT_OK(
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
&dataset)); core::ScopedUnref scoped_unref(dataset);
core::ScopedUnref scoped_unref(dataset);
std::vector<PartialTensorShape> expected_output_shapes = { std::vector<PartialTensorShape> expected_output_shapes = {
expected_outputs[0].indices.shape(), expected_outputs[0].values.shape(), expected_outputs[0].indices.shape(), expected_outputs[0].values.shape(),
expected_outputs[0].dense_shape.shape()}; expected_outputs[0].dense_shape.shape()};
TF_EXPECT_OK(VerifyShapesCompatible(dataset->output_shapes(), TF_EXPECT_OK(
expected_output_shapes)); VerifyShapesCompatible(dataset->output_shapes(), expected_output_shapes));
}
} }
TEST_F(SparseTensorSliceDatasetOpTest, Cardinality) { TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
for (auto &test_case : TestCases()) { const TestCase &test_case = TwoDimsTestCase();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype();
DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.indices, &input_sparse_tensor.values, &input_sparse_tensor.dense_shape};
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); DatasetBase *dataset;
DatasetBase *dataset; TF_ASSERT_OK(
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
&dataset)); core::ScopedUnref scoped_unref(dataset);
core::ScopedUnref scoped_unref(dataset);
EXPECT_EQ(dataset->Cardinality(), expected_outputs.size()); EXPECT_EQ(dataset->Cardinality(), expected_outputs.size());
}
} }
TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) { TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) {
@ -299,15 +332,16 @@ TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
int N = 2; const TestCase &test_case = TwoDimsTestCase();
const int NDIM = 2; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
Tensor indices = CreateTensor<int64>(TensorShape({N, NDIM}), {0, 0, 1, 1}); std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
Tensor values = CreateTensor<int32>(TensorShape({N}), {888, 999}); DataType tvalues = input_sparse_tensor.values.dtype();
Tensor dense_shape = CreateTensor<int64>(TensorShape({NDIM}), {5, 5}); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = {&indices, &values, &dense_shape}; &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
@ -324,82 +358,74 @@ TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(writer.Flush()); TF_ASSERT_OK(writer.Flush());
} }
TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputDtypes) { TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
for (auto &test_case : TestCases()) { const TestCase &test_case = GetParam();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype();
DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.indices, &input_sparse_tensor.values, &input_sparse_tensor.dense_shape};
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); DatasetBase *dataset;
DatasetBase *dataset; TF_ASSERT_OK(
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
&dataset)); core::ScopedUnref scoped_unref(dataset);
core::ScopedUnref scoped_unref(dataset);
std::unique_ptr<IteratorContext> iterator_ctx; std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK( TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); std::unique_ptr<IteratorBase> iterator;
std::unique_ptr<IteratorBase> iterator; TF_ASSERT_OK(
TF_ASSERT_OK( dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); DataTypeVector expected_output_dtypes = {
DataTypeVector expected_output_dtypes = { expected_outputs[0].indices.dtype(), expected_outputs[0].values.dtype(),
expected_outputs[0].indices.dtype(), expected_outputs[0].values.dtype(), expected_outputs[0].dense_shape.dtype()};
expected_outputs[0].dense_shape.dtype()}; TF_EXPECT_OK(
TF_EXPECT_OK( VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes));
VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes));
}
} }
TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputShapes) { TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
for (auto &test_case : TestCases()) { const TestCase &test_case = GetParam();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype();
DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.indices, &input_sparse_tensor.values, &input_sparse_tensor.dense_shape};
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); DatasetBase *dataset;
DatasetBase *dataset; TF_ASSERT_OK(
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
&dataset)); core::ScopedUnref scoped_unref(dataset);
core::ScopedUnref scoped_unref(dataset);
std::unique_ptr<IteratorContext> iterator_ctx; std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK( TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); std::unique_ptr<IteratorBase> iterator;
std::unique_ptr<IteratorBase> iterator; TF_ASSERT_OK(
TF_ASSERT_OK( dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); std::vector<PartialTensorShape> expected_output_shapes = {
std::vector<PartialTensorShape> expected_output_shapes = { expected_outputs[0].indices.shape(), expected_outputs[0].values.shape(),
expected_outputs[0].indices.shape(), expected_outputs[0].values.shape(), expected_outputs[0].dense_shape.shape()};
expected_outputs[0].dense_shape.shape()}; TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), expected_output_shapes));
expected_output_shapes));
}
} }
TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) { TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) {
@ -407,15 +433,16 @@ TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) {
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
int N = 2; const TestCase &test_case = TwoDimsTestCase();
const int NDIM = 2; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
Tensor indices = CreateTensor<int64>(TensorShape({N, NDIM}), {0, 0, 1, 1}); std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
Tensor values = CreateTensor<int32>(TensorShape({N}), {888, 999}); DataType tvalues = input_sparse_tensor.values.dtype();
Tensor dense_shape = CreateTensor<int64>(TensorShape({NDIM}), {5, 5}); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = {&indices, &values, &dense_shape}; &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
@ -432,79 +459,81 @@ TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) {
EXPECT_EQ(iterator->prefix(), strings::StrCat("Iterator::SparseTensorSlice")); EXPECT_EQ(iterator->prefix(), strings::StrCat("Iterator::SparseTensorSlice"));
} }
TEST_F(SparseTensorSliceDatasetOpTest, Roundtrip) { TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
for (auto &test_case : TestCases()) { const TestCase &test_case = GetParam();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
test_case.expected_outputs; std::vector<int> breakpoints = test_case.breakpoints;
std::vector<int> breakpoints = test_case.breakpoints; DataType tvalues = input_sparse_tensor.values.dtype();
DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector<TensorValue, 4> inputs = {
gtl::InlinedVector<TensorValue, 4> inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values,
&input_sparse_tensor.indices, &input_sparse_tensor.values, &input_sparse_tensor.dense_shape};
&input_sparse_tensor.dense_shape};
std::unique_ptr<OpKernel> dataset_kernel; std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); DatasetBase *dataset;
DatasetBase *dataset; TF_ASSERT_OK(
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
&dataset)); core::ScopedUnref scoped_unref(dataset);
core::ScopedUnref scoped_unref(dataset);
std::unique_ptr<IteratorContext> iterator_ctx; std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK( TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); std::unique_ptr<IteratorBase> iterator;
std::unique_ptr<IteratorBase> iterator; TF_ASSERT_OK(
TF_ASSERT_OK( dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_ctx; std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
int cur_iteration = 0; int cur_iteration = 0;
bool end_of_sequence = false; bool end_of_sequence = false;
int64 num_slices = input_sparse_tensor.dense_shape.dim_size(0); int64 num_slices = input_sparse_tensor.dense_shape.dim_size(0);
std::vector<Tensor> out_tensors; std::vector<Tensor> out_tensors;
for (int breakpoint : breakpoints) { for (int breakpoint : breakpoints) {
while (cur_iteration < breakpoint) { while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
&end_of_sequence)); &end_of_sequence));
cur_iteration++; cur_iteration++;
}
if (breakpoint == 0) {
EXPECT_FALSE(end_of_sequence);
} else if (breakpoint <= num_slices) {
for (int i = 0; i < out_tensors.size(); ++i) {
TF_EXPECT_OK(ExpectEqual(
out_tensors[0], expected_outputs[cur_iteration - 1].indices));
TF_EXPECT_OK(ExpectEqual(out_tensors[1],
expected_outputs[cur_iteration - 1].values));
TF_EXPECT_OK(ExpectEqual(
out_tensors[2], expected_outputs[cur_iteration - 1].dense_shape));
}
} else {
EXPECT_TRUE(end_of_sequence);
}
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_ctx.get(), &reader));
} }
if (breakpoint == 0) {
EXPECT_FALSE(end_of_sequence);
} else if (breakpoint <= num_slices) {
for (int i = 0; i < out_tensors.size(); ++i) {
TF_EXPECT_OK(ExpectEqual(out_tensors[0],
expected_outputs[cur_iteration - 1].indices));
TF_EXPECT_OK(ExpectEqual(out_tensors[1],
expected_outputs[cur_iteration - 1].values));
TF_EXPECT_OK(ExpectEqual(
out_tensors[2], expected_outputs[cur_iteration - 1].dense_shape));
}
} else {
EXPECT_TRUE(end_of_sequence);
}
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_ctx.get(), &reader));
} }
} }
INSTANTIATE_TEST_SUITE_P(SparseTensorSliceDatasetOpTest,
ParameterizedSparseTensorSliceDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{TwoDimsTestCase(), ThreeDimsTestCase(),
FourDimsTestCase(), FiveDimsTestCase()})));
} // namespace } // namespace
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -38,10 +38,10 @@ class TakeDatasetOpTest : public DatasetOpsTestBase {
const DataTypeVector &output_types, const DataTypeVector &output_types,
const std::vector<PartialTensorShape> &output_shapes, const std::vector<PartialTensorShape> &output_shapes,
std::unique_ptr<OpKernel> *op_kernel) { std::unique_ptr<OpKernel> *op_kernel) {
node_def_ = test::function::NDef( NodeDef node_def = test::function::NDef(
kNodeName, kOpName, {"input_dataset", "count"}, kNodeName, kOpName, {"input_dataset", "count"},
{{"output_types", output_types}, {"output_shapes", output_shapes}}); {{"output_types", output_types}, {"output_shapes", output_shapes}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
return Status::OK(); return Status::OK();
} }
@ -53,9 +53,6 @@ class TakeDatasetOpTest : public DatasetOpsTestBase {
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK(); return Status::OK();
} }
private:
NodeDef node_def_;
}; };
struct TestCase { struct TestCase {

File diff suppressed because it is too large Load Diff

View File

@ -58,10 +58,10 @@ class ZipDatasetOpTest : public DatasetOpsTestBase {
// Create the placeholder names for the input components of `ZipDataset`. // Create the placeholder names for the input components of `ZipDataset`.
input_datasets.emplace_back(strings::StrCat("input_dataset_", i)); input_datasets.emplace_back(strings::StrCat("input_dataset_", i));
} }
node_def_ = test::function::NDef( NodeDef node_def = test::function::NDef(
kNodeName, kOpName, input_datasets, kNodeName, kOpName, input_datasets,
{{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}}); {{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
return Status::OK(); return Status::OK();
} }
@ -74,9 +74,6 @@ class ZipDatasetOpTest : public DatasetOpsTestBase {
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK(); return Status::OK();
} }
private:
NodeDef node_def_;
}; };
struct TestParam { struct TestParam {
@ -85,8 +82,8 @@ struct TestParam {
std::vector<int> breakpoints; std::vector<int> breakpoints;
}; };
// Test case 1: the input datasets with same number of outputs.
TestParam TestCase1() { TestParam TestCase1() {
// Test case 1: the input datasets with same number of outputs.
return {/*input_range_dataset_params*/ return {/*input_range_dataset_params*/
{RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}}, {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}},
/*expected_outputs*/ /*expected_outputs*/
@ -99,8 +96,8 @@ TestParam TestCase1() {
/*breakpoints*/ {0, 1, 4}}; /*breakpoints*/ {0, 1, 4}};
} }
// Test case 2: the input datasets with different number of outputs.
TestParam TestCase2() { TestParam TestCase2() {
// Test case 2: the input datasets with different number of outputs.
return {/*input_range_dataset_params*/ return {/*input_range_dataset_params*/
{RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}}, {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}},
/*expected_outputs*/ /*expected_outputs*/
@ -113,67 +110,48 @@ TestParam TestCase2() {
/*breakpoints*/ {0, 1, 4}}; /*breakpoints*/ {0, 1, 4}};
} }
class ZipDatasetOpTestHelper : public ZipDatasetOpTest { class ParameterizedZipDatasetOpTest
public: : public ZipDatasetOpTest,
~ZipDatasetOpTestHelper() override {
if (dataset_) dataset_->Unref();
}
protected:
Status CreateDatasetFromTestCase(const TestParam &test_case) {
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_RETURN_IF_ERROR(CreateRangeDatasetTensors(
test_case.input_range_dataset_params, &range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_RETURN_IF_ERROR(CreateZipDatasetKernel({DT_INT64},
{{num_tensors_per_slice}},
inputs.size(), &dataset_kernel_));
TF_RETURN_IF_ERROR(CreateZipDatasetContext(dataset_kernel_.get(), &inputs,
&dataset_kernel_ctx_));
TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(),
dataset_kernel_ctx_.get(), &dataset_));
return Status::OK();
}
Status CreateIteratorFromTestCase(const TestParam &test_case) {
TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case));
TF_RETURN_IF_ERROR(
CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_));
TF_RETURN_IF_ERROR(
dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_));
return Status::OK();
}
std::unique_ptr<OpKernel> dataset_kernel_;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx_;
DatasetBase *dataset_ = nullptr; // owned by this class.
std::unique_ptr<IteratorContext> iterator_ctx_;
std::unique_ptr<IteratorBase> iterator_;
};
class ParameterizedDatasetTest
: public ZipDatasetOpTestHelper,
public ::testing::WithParamInterface<TestParam> {}; public ::testing::WithParamInterface<TestParam> {};
TEST_P(ParameterizedDatasetTest, GetNext) { TEST_P(ParameterizedZipDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam(); const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
auto expected_outputs_it = test_case.expected_outputs.begin(); auto expected_outputs_it = test_case.expected_outputs.begin();
bool end_of_sequence = false; bool end_of_sequence = false;
std::vector<Tensor> out_tensors; std::vector<Tensor> out_tensors;
while (!end_of_sequence) { while (!end_of_sequence) {
TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, TF_EXPECT_OK(
&end_of_sequence)); iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence));
if (!end_of_sequence) { if (!end_of_sequence) {
for (const auto &tensor : out_tensors) { for (const auto &tensor : out_tensors) {
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
@ -185,22 +163,92 @@ TEST_P(ParameterizedDatasetTest, GetNext) {
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
} }
TEST_F(ZipDatasetOpTestHelper, DatasetName) { TEST_F(ZipDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
EXPECT_EQ(dataset_->type_string(), kOpName); const TestParam &test_case = TestCase1();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
EXPECT_EQ(zip_dataset->node_name(), kNodeName);
} }
TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { TEST_F(ZipDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
const TestParam &test_case = TestCase1();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size(); int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
EXPECT_EQ(zip_dataset->type_string(), kOpName);
}
TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
DataTypeVector expected_output_dtypes; DataTypeVector expected_output_dtypes;
expected_output_dtypes.reserve(num_tensors_per_slice); expected_output_dtypes.reserve(num_tensors_per_slice);
@ -209,16 +257,35 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
} }
TF_EXPECT_OK( TF_EXPECT_OK(
VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes)); VerifyTypesMatch(zip_dataset->output_dtypes(), expected_output_dtypes));
} }
TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam(); const TestParam &test_case = GetParam();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size(); int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
std::vector<PartialTensorShape> expected_output_shapes; std::vector<PartialTensorShape> expected_output_shapes;
expected_output_shapes.reserve(num_tensors_per_slice); expected_output_shapes.reserve(num_tensors_per_slice);
@ -226,43 +293,107 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape()); expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
} }
TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(), TF_EXPECT_OK(VerifyShapesCompatible(zip_dataset->output_shapes(),
expected_output_shapes)); expected_output_shapes));
} }
TEST_P(ParameterizedDatasetTest, Cardinality) { TEST_P(ParameterizedZipDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
EXPECT_EQ(dataset_->Cardinality(), const TestParam &test_case = GetParam();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
EXPECT_EQ(zip_dataset->Cardinality(),
test_case.expected_outputs.size() / num_tensors_per_slice); test_case.expected_outputs.size() / num_tensors_per_slice);
} }
TEST_F(ZipDatasetOpTestHelper, DatasetSave) { TEST_F(ZipDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
const TestParam &test_case = TestCase1();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
std::unique_ptr<SerializationContext> serialization_ctx; std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data; VariantTensorData data;
VariantTensorDataWriter writer(&data); VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &writer)); TF_ASSERT_OK(zip_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush()); TF_ASSERT_OK(writer.Flush());
} }
TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam(); const TestParam &test_case = GetParam();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size(); int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
DataTypeVector expected_output_dtypes; DataTypeVector expected_output_dtypes;
expected_output_dtypes.reserve(num_tensors_per_slice); expected_output_dtypes.reserve(num_tensors_per_slice);
@ -271,16 +402,40 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
} }
TF_EXPECT_OK( TF_EXPECT_OK(
VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes)); VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes));
} }
TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam(); const TestParam &test_case = GetParam();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size(); int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
std::vector<PartialTensorShape> expected_output_shapes; std::vector<PartialTensorShape> expected_output_shapes;
expected_output_shapes.reserve(num_tensors_per_slice); expected_output_shapes.reserve(num_tensors_per_slice);
@ -288,43 +443,95 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape()); expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape());
} }
TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(), TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
expected_output_shapes)); expected_output_shapes));
} }
TEST_F(ZipDatasetOpTestHelper, IteratorOutputPrefix) { TEST_F(ZipDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateIteratorFromTestCase(TestCase1()));
EXPECT_EQ(iterator_->prefix(), "Iterator::Zip"); const TestParam &test_case = TestCase1();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
EXPECT_EQ(iterator->prefix(), "Iterator::Zip");
} }
TEST_P(ParameterizedDatasetTest, Roundtrip) { TEST_P(ParameterizedZipDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam(); const TestParam &test_case = GetParam();
auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector<Tensor> range_dataset_tensors;
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_ctx; std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false; bool end_of_sequence = false;
std::vector<Tensor> out_tensors; std::vector<Tensor> out_tensors;
auto expected_outputs_it = test_case.expected_outputs.begin();
int cur_iteration = 0; int cur_iteration = 0;
for (int breakpoint : test_case.breakpoints) { for (int breakpoint : test_case.breakpoints) {
VariantTensorData data; VariantTensorData data;
VariantTensorDataWriter writer(&data); VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer)); TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush()); TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data); VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader)); TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
while (cur_iteration < breakpoint) { while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
&end_of_sequence)); &end_of_sequence));
if (!end_of_sequence) { if (!end_of_sequence) {
for (auto &tensor : out_tensors) { for (auto &tensor : out_tensors) {
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
@ -335,7 +542,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
cur_iteration++; cur_iteration++;
} }
if (breakpoint >= dataset_->Cardinality()) { if (breakpoint >= zip_dataset->Cardinality()) {
EXPECT_TRUE(end_of_sequence); EXPECT_TRUE(end_of_sequence);
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
} else { } else {
@ -345,7 +552,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
ZipDatasetOpTest, ParameterizedDatasetTest, ZipDatasetOpTest, ParameterizedZipDatasetOpTest,
::testing::ValuesIn(std::vector<TestParam>({TestCase1(), TestCase2()}))); ::testing::ValuesIn(std::vector<TestParam>({TestCase1(), TestCase2()})));
} // namespace } // namespace