Refactor TensorSliceDatasetOpTest

Fei Hu 2019-03-18 21:31:24 -07:00
parent 578739fb89
commit a280e8644b


@@ -46,10 +46,10 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
       components.emplace_back(strings::StrCat("component_", i));
     }
-    node_def_ = test::function::NDef(
+    NodeDef node_def = test::function::NDef(
         kNodeName, kOpName, components,
         {{"Toutput_types", dtypes}, {"output_shapes", shapes}});
-    TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel));
+    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_dataset_kernel));
     return Status::OK();
   }
@@ -63,9 +63,6 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
         CreateOpKernelContext(tensor_dataset_kernel, inputs, context));
     return Status::OK();
   }
-
- private:
-  NodeDef node_def_;
 };
 
 struct TestCase {
@@ -74,28 +71,31 @@ struct TestCase {
   std::vector<int> breakpoints;
 };
 
-std::vector<TestCase> TestCases() {
-  return {
-      // A single tuple of tensors.
-      {{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
+TestCase PlainTensorTestCase() {
+  return {/*components*/
+          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
                                                   {1, 2, 3, 4}),
           DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
                                                    {37.0, 38.0}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
-                                                   {"a", "b"})}},  // components
-       {{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
+                                                   {"a", "b"})},
+          /*expected_outputs*/
+          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
           DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
           DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
-          DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}),
-                                                   {"b"})}},  // expected_outputs
-       {{0, 1, 3}}},  // breakpoints
-      // Nested tensors
-      {{{DatasetOpsTestBase::CreateTensor<Variant>(
+          DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"b"})},
+          /*breakpoints*/ {0, 1, 3}};
+}
+
+TestCase NestedTensorTestCase() {
+  return {
+      /*components*/
+      {DatasetOpsTestBase::CreateTensor<Variant>(
           TensorShape({2, 1}),
           {DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
                                                     {1.0, 2.0, 3.0, 4.0}),
@@ -106,9 +106,10 @@ std::vector<TestCase> TestCases() {
                TensorShape({1, 2}), {"a", "b"}),
            DatasetOpsTestBase::CreateTensor<string>(
                TensorShape({1, 2}), {"c", "d"})}),
-       DatasetOpsTestBase::CreateTensor<int64>(
-           TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})}},  // components
-      {{DatasetOpsTestBase::CreateTensor<Variant>(
+       DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 3}),
+                                               {1, 2, 3, 4, 5, 6})},
+      /*expected_outputs*/
+      {DatasetOpsTestBase::CreateTensor<Variant>(
            TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
                                   TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
        DatasetOpsTestBase::CreateTensor<Variant>(
@@ -121,40 +122,41 @@ std::vector<TestCase> TestCases() {
       DatasetOpsTestBase::CreateTensor<Variant>(
           TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
                                  TensorShape({1, 2}), {"c", "d"})}),
-      DatasetOpsTestBase::CreateTensor<int64>(
-          TensorShape({3}), {4, 5, 6})}},  // expected_outputs
-      {{0, 1, 2}}}  // breakpoints
-  };
+      DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {4, 5, 6})},
+      /*breakpoints*/ {0, 1, 2}};
 }
 
-TEST_F(TensorSliceDatasetOpTest, GetNext) {
-  int thread_num = 2, cpu_num = 2;
-  for (auto &test_case : TestCases()) {
-    std::vector<Tensor> components = test_case.components;
-    std::vector<Tensor> expected_outputs = test_case.expected_outputs;
-    size_t num_tensors_per_slice = components.size();
+class ParameterizedTensorSliceDatasetOpTest
+    : public TensorSliceDatasetOpTest,
+      public ::testing::WithParamInterface<TestCase> {};
 
+TEST_P(ParameterizedTensorSliceDatasetOpTest, GetNext) {
+  int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+  const TestCase &test_case = GetParam();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
   DataTypeVector dtypes;
-  std::vector<PartialTensorShape> shapes;
   gtl::InlinedVector<TensorValue, 4> inputs;
   for (auto &component : components) {
-    inputs.push_back(&component);
-    dtypes.push_back(component.dtype());
+    inputs.emplace_back(&component);
+    dtypes.emplace_back(component.dtype());
   }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     shapes.emplace_back(expected_outputs[i].shape());
   }
   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
   std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
-  TF_ASSERT_OK(CreateTensorSliceDatasetContext(
-      tensor_slice_dataset_kernel.get(), &inputs,
-      &tensor_slice_dataset_context));
+  TF_ASSERT_OK(
+      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
+                                      &inputs, &tensor_slice_dataset_context));
   DatasetBase *tensor_slice_dataset;
   TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
                              tensor_slice_dataset_context.get(),
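
For reference, this is the generic googletest value-parameterized pattern the change switches to. It is a minimal standalone sketch: the ExampleCase struct, test names, and factory functions below are illustrative stand-ins, not part of this commit. The production fixture above inherits ::testing::WithParamInterface<TestCase> next to its existing DatasetOpsTestBase-derived base; TestWithParam<T> used here is the single-base shorthand for the same thing.

#include <vector>
#include "gtest/gtest.h"

// Illustrative parameter type playing the role of TestCase above.
struct ExampleCase {
  std::vector<int> input;
  int expected_sum;
};

ExampleCase SmallCase() { return {{1, 2, 3}, 6}; }
ExampleCase EmptyCase() { return {{}, 0}; }

class ExampleParamTest : public ::testing::TestWithParam<ExampleCase> {};

// Each TEST_P body reads the current parameter through GetParam().
TEST_P(ExampleParamTest, SumsInput) {
  const ExampleCase &test_case = GetParam();
  int sum = 0;
  for (int v : test_case.input) sum += v;
  EXPECT_EQ(sum, test_case.expected_sum);
}

// A single INSTANTIATE_TEST_SUITE_P call runs every TEST_P in the fixture
// once per supplied value, replacing the hand-written for-loop over cases.
INSTANTIATE_TEST_SUITE_P(ExampleCases, ExampleParamTest,
                         ::testing::ValuesIn(std::vector<ExampleCase>(
                             {SmallCase(), EmptyCase()})));
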
@@ -175,13 +177,11 @@ TEST_F(TensorSliceDatasetOpTest, GetNext) {
     TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
                                    &end_of_sequence));
     for (int i = 0; i < out_tensors.size(); ++i) {
-      EXPECT_LT(i + num_tensors_per_slice * cur_slice,
-                expected_outputs.size());
+      EXPECT_LT(i + num_tensors_per_slice * cur_slice, expected_outputs.size());
       if (out_tensors[i].dtype() == DT_VARIANT) {
         // Currently `ExpectEqual()` does not support the variant tensor
         // yet, so we manually cast the variant to numeric/string tensor.
-        const Tensor *output =
-            out_tensors[i].scalar<Variant>()().get<Tensor>();
+        const Tensor *output = out_tensors[i].scalar<Variant>()().get<Tensor>();
         const Tensor *expected_output =
             expected_outputs[i + num_tensors_per_slice * cur_slice]
                 .scalar<Variant>()()
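
The comment in the hunk above refers to comparing variant-wrapped tensors by hand. A minimal sketch of that wrapping and unwrapping, assuming only TensorFlow's Tensor and Variant headers (the helper names are illustrative, not part of the test):

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"

namespace tensorflow {

// Stores `inner` in a scalar DT_VARIANT tensor; the nested test cases above
// build their expected outputs this way.
Tensor WrapInVariant(const Tensor &inner) {
  Tensor wrapper(DT_VARIANT, TensorShape({}));
  wrapper.scalar<Variant>()() = inner;
  return wrapper;
}

// Recovers the wrapped tensor, the same cast the test performs before calling
// ExpectEqual(); get<Tensor>() returns nullptr if the variant holds another type.
const Tensor *UnwrapVariant(const Tensor &wrapper) {
  return wrapper.scalar<Variant>()().get<Tensor>();
}

}  // namespace tensorflow
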
@@ -196,20 +196,64 @@ TEST_F(TensorSliceDatasetOpTest, GetNext) {
     out_tensors.clear();
     cur_slice++;
   }
-  }
 }
 
-TEST_F(TensorSliceDatasetOpTest, DatasetName) {
+TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) {
   int thread_num = 2, cpu_num = 2;
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
-  Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
-  gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
-  DataTypeVector dtypes({DT_INT64, DT_INT64});
-  std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
-                                            PartialTensorShape({2})};
+  const TestCase &test_case = PlainTensorTestCase();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
+  DataTypeVector dtypes;
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  for (auto &component : components) {
+    inputs.emplace_back(&component);
+    dtypes.emplace_back(component.dtype());
+  }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
+  for (int i = 0; i < num_tensors_per_slice; ++i) {
+    shapes.emplace_back(expected_outputs[i].shape());
+  }
+  std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
+  TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
+                                              &tensor_slice_dataset_kernel));
+  std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
+  TF_ASSERT_OK(
+      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
+                                      &inputs, &tensor_slice_dataset_context));
+  DatasetBase *tensor_slice_dataset;
+  TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
+                             tensor_slice_dataset_context.get(),
+                             &tensor_slice_dataset));
+  core::ScopedUnref scored_unref(tensor_slice_dataset);
+
+  EXPECT_EQ(tensor_slice_dataset->node_name(), kNodeName);
+}
+
+TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) {
+  int thread_num = 2, cpu_num = 2;
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  const TestCase &test_case = PlainTensorTestCase();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
+  DataTypeVector dtypes;
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  for (auto &component : components) {
+    inputs.emplace_back(&component);
+    dtypes.emplace_back(component.dtype());
+  }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
+  for (int i = 0; i < num_tensors_per_slice; ++i) {
+    shapes.emplace_back(expected_outputs[i].shape());
+  }
   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
@@ -226,34 +270,33 @@ TEST_F(TensorSliceDatasetOpTest, DatasetName) {
   EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName);
 }
 
-TEST_F(TensorSliceDatasetOpTest, DatasetOutputDtypes) {
+TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
-  for (auto &test_case : TestCases()) {
-    std::vector<Tensor> components = test_case.components;
-    std::vector<Tensor> expected_outputs = test_case.expected_outputs;
-    size_t num_tensors_per_slice = components.size();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+  const TestCase &test_case = GetParam();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
   DataTypeVector dtypes;
-  std::vector<PartialTensorShape> shapes;
   gtl::InlinedVector<TensorValue, 4> inputs;
   for (auto &component : components) {
     inputs.emplace_back(&component);
     dtypes.emplace_back(component.dtype());
   }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     shapes.emplace_back(expected_outputs[i].shape());
   }
   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
   std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
-  TF_ASSERT_OK(CreateTensorSliceDatasetContext(
-      tensor_slice_dataset_kernel.get(), &inputs,
-      &tensor_slice_dataset_context));
+  TF_ASSERT_OK(
+      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
+                                      &inputs, &tensor_slice_dataset_context));
   DatasetBase *tensor_slice_dataset;
   TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
                              tensor_slice_dataset_context.get(),
@@ -266,26 +309,25 @@ TEST_F(TensorSliceDatasetOpTest, DatasetOutputDtypes) {
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
   }
-  }
 }
 
-TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) {
+TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputShapes) {
   int thread_num = 2, cpu_num = 2;
-  for (auto &test_case : TestCases()) {
-    std::vector<Tensor> components = test_case.components;
-    std::vector<Tensor> expected_outputs = test_case.expected_outputs;
-    size_t num_tensors_per_slice = components.size();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+  const TestCase &test_case = GetParam();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
   DataTypeVector dtypes;
-  std::vector<PartialTensorShape> shapes;
   gtl::InlinedVector<TensorValue, 4> inputs;
   for (auto &component : components) {
     inputs.emplace_back(&component);
     dtypes.emplace_back(component.dtype());
   }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     shapes.emplace_back(expected_outputs[i].shape());
   }
@@ -293,9 +335,9 @@ TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) {
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
   std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
-  TF_ASSERT_OK(CreateTensorSliceDatasetContext(
-      tensor_slice_dataset_kernel.get(), &inputs,
-      &tensor_slice_dataset_context));
+  TF_ASSERT_OK(
+      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
+                                      &inputs, &tensor_slice_dataset_context));
   DatasetBase *tensor_slice_dataset;
   TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
                              tensor_slice_dataset_context.get(),
@@ -310,26 +352,25 @@ TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) {
     EXPECT_TRUE(
         produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
   }
-  }
 }
 
-TEST_F(TensorSliceDatasetOpTest, Cardinality) {
+TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) {
   int thread_num = 2, cpu_num = 2;
-  for (auto &test_case : TestCases()) {
-    std::vector<Tensor> components = test_case.components;
-    std::vector<Tensor> expected_outputs = test_case.expected_outputs;
-    size_t num_tensors_per_slice = components.size();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+  const TestCase &test_case = GetParam();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
   DataTypeVector dtypes;
-  std::vector<PartialTensorShape> shapes;
   gtl::InlinedVector<TensorValue, 4> inputs;
   for (auto &component : components) {
     inputs.emplace_back(&component);
     dtypes.emplace_back(component.dtype());
   }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     shapes.emplace_back(expected_outputs[i].shape());
   }
@@ -337,18 +378,16 @@ TEST_F(TensorSliceDatasetOpTest, Cardinality) {
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
   std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
-  TF_ASSERT_OK(CreateTensorSliceDatasetContext(
-      tensor_slice_dataset_kernel.get(), &inputs,
-      &tensor_slice_dataset_context));
+  TF_ASSERT_OK(
+      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
+                                      &inputs, &tensor_slice_dataset_context));
   DatasetBase *tensor_slice_dataset;
   TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
                              tensor_slice_dataset_context.get(),
                              &tensor_slice_dataset));
   core::ScopedUnref scored_unref(tensor_slice_dataset);
 
-    EXPECT_EQ(tensor_slice_dataset->Cardinality(),
-              inputs[0].tensor->dim_size(0));
-  }
+  EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0));
 }
 
 TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
@@ -356,12 +395,21 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
-  Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
-  gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
-  DataTypeVector dtypes({DT_INT64, DT_INT64});
-  std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
-                                            PartialTensorShape({2})};
+  const TestCase &test_case = PlainTensorTestCase();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
+  DataTypeVector dtypes;
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  for (auto &component : components) {
+    inputs.emplace_back(&component);
+    dtypes.emplace_back(component.dtype());
+  }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
+  for (int i = 0; i < num_tensors_per_slice; ++i) {
+    shapes.emplace_back(expected_outputs[i].shape());
+  }
   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
@@ -384,34 +432,33 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
   TF_ASSERT_OK(writer.Flush());
 }
 
-TEST_F(TensorSliceDatasetOpTest, IteratorOutputDtypes) {
+TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) {
   int thread_num = 2, cpu_num = 2;
-  for (auto &test_case : TestCases()) {
-    std::vector<Tensor> components = test_case.components;
-    std::vector<Tensor> expected_outputs = test_case.expected_outputs;
-    size_t num_tensors_per_slice = components.size();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+  const TestCase &test_case = GetParam();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
   DataTypeVector dtypes;
-  std::vector<PartialTensorShape> shapes;
   gtl::InlinedVector<TensorValue, 4> inputs;
   for (auto &component : components) {
     inputs.emplace_back(&component);
     dtypes.emplace_back(component.dtype());
   }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     shapes.emplace_back(expected_outputs[i].shape());
   }
   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
   std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
-  TF_ASSERT_OK(CreateTensorSliceDatasetContext(
-      tensor_slice_dataset_kernel.get(), &inputs,
-      &tensor_slice_dataset_context));
+  TF_ASSERT_OK(
+      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
                                      &inputs, &tensor_slice_dataset_context));
   DatasetBase *tensor_slice_dataset;
   TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
                              tensor_slice_dataset_context.get(),
@@ -430,37 +477,35 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputDtypes) {
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
   }
-  }
 }
 
-TEST_F(TensorSliceDatasetOpTest, IteratorOutputShapes) {
+TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputShapes) {
   int thread_num = 2, cpu_num = 2;
-  for (auto &test_case : TestCases()) {
-    std::vector<Tensor> components = test_case.components;
-    std::vector<Tensor> expected_outputs = test_case.expected_outputs;
-    size_t num_tensors_per_slice = components.size();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+  const TestCase &test_case = GetParam();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
   DataTypeVector dtypes;
-  std::vector<PartialTensorShape> shapes;
   gtl::InlinedVector<TensorValue, 4> inputs;
   for (auto &component : components) {
     inputs.emplace_back(&component);
     dtypes.emplace_back(component.dtype());
   }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     shapes.emplace_back(expected_outputs[i].shape());
   }
   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
   std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
-  TF_ASSERT_OK(CreateTensorSliceDatasetContext(
-      tensor_slice_dataset_kernel.get(), &inputs,
-      &tensor_slice_dataset_context));
+  TF_ASSERT_OK(
+      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
                                      &inputs, &tensor_slice_dataset_context));
   DatasetBase *tensor_slice_dataset;
   TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
                              tensor_slice_dataset_context.get(),
@@ -480,7 +525,6 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputShapes) {
     EXPECT_TRUE(
         produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
   }
-  }
 }
 
 TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
@@ -488,12 +532,21 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
-  Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
-  gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
-  DataTypeVector dtypes({DT_INT64, DT_INT64});
-  std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
-                                            PartialTensorShape({2})};
+  const TestCase &test_case = PlainTensorTestCase();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
+  DataTypeVector dtypes;
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  for (auto &component : components) {
+    inputs.emplace_back(&component);
+    dtypes.emplace_back(component.dtype());
+  }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
+  for (int i = 0; i < num_tensors_per_slice; ++i) {
+    shapes.emplace_back(expected_outputs[i].shape());
+  }
   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
@@ -516,35 +569,33 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
   EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice");
 }
 
-TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
+TEST_P(ParameterizedTensorSliceDatasetOpTest, Roundtrip) {
   int thread_num = 2, cpu_num = 2;
-  for (auto &test_case : TestCases()) {
-    std::vector<Tensor> components = test_case.components;
-    std::vector<Tensor> expected_outputs = test_case.expected_outputs;
-    std::vector<int> breakpoints = test_case.breakpoints;
-    size_t num_tensors_per_slice = components.size();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+  const TestCase &test_case = GetParam();
+  const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
+  std::vector<Tensor> components = test_case.components;
   DataTypeVector dtypes;
-  std::vector<PartialTensorShape> shapes;
   gtl::InlinedVector<TensorValue, 4> inputs;
   for (auto &component : components) {
     inputs.emplace_back(&component);
     dtypes.emplace_back(component.dtype());
   }
+  size_t num_tensors_per_slice = components.size();
+  std::vector<PartialTensorShape> shapes;
+  shapes.reserve(num_tensors_per_slice);
   for (int i = 0; i < num_tensors_per_slice; ++i) {
     shapes.emplace_back(expected_outputs[i].shape());
   }
   std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
   TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
                                               &tensor_slice_dataset_kernel));
   std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
-  TF_ASSERT_OK(CreateTensorSliceDatasetContext(
-      tensor_slice_dataset_kernel.get(), &inputs,
-      &tensor_slice_dataset_context));
+  TF_ASSERT_OK(
+      CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
                                      &inputs, &tensor_slice_dataset_context));
   DatasetBase *tensor_slice_dataset;
   TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
                              tensor_slice_dataset_context.get(),
@@ -564,7 +615,7 @@ TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
   bool end_of_sequence = false;
   int64 num_slices = inputs[0].tensor->dim_size(0);
   std::vector<Tensor> out_tensors;
-
+  const std::vector<int> &breakpoints = test_case.breakpoints;
   for (int breakpoint : breakpoints) {
     while (cur_iteration < breakpoint) {
       TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
@@ -580,8 +631,7 @@ TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
           const Tensor *output =
               out_tensors[i].scalar<Variant>()().get<Tensor>();
           const Tensor *expected_output =
-              expected_outputs[i +
-                               num_tensors_per_slice * (cur_iteration - 1)]
+              expected_outputs[i + num_tensors_per_slice * (cur_iteration - 1)]
                   .scalar<Variant>()()
                   .get<Tensor>();
           TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
@@ -602,9 +652,13 @@ TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
     VariantTensorDataReader reader(&data);
     TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
   }
-  }
 }
 
+INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest,
+                         ParameterizedTensorSliceDatasetOpTest,
+                         ::testing::ValuesIn(std::vector<TestCase>(
+                             {PlainTensorTestCase(), NestedTensorTestCase()})));
+
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow