Refactor TensorSliceDatasetOpTest

This commit is contained in:
Fei Hu 2019-03-18 21:31:24 -07:00
parent 578739fb89
commit a280e8644b

View File

@ -46,10 +46,10 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
components.emplace_back(strings::StrCat("component_", i));
}
node_def_ = test::function::NDef(
NodeDef node_def = test::function::NDef(
kNodeName, kOpName, components,
{{"Toutput_types", dtypes}, {"output_shapes", shapes}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel));
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_dataset_kernel));
return Status::OK();
}
@ -63,9 +63,6 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
CreateOpKernelContext(tensor_dataset_kernel, inputs, context));
return Status::OK();
}
private:
NodeDef node_def_;
};
struct TestCase {
@ -74,142 +71,189 @@ struct TestCase {
std::vector<int> breakpoints;
};
std::vector<TestCase> TestCases() {
TestCase PlainTensorTestCase() {
return {/*components*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
{1, 2, 3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
{37.0, 38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
{"a", "b"})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"b"})},
/*breakpoints*/ {0, 1, 3}};
}
TestCase NestedTensorTestCase() {
return {
// A single tuple of tensors.
{{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
{1, 2, 3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
{37.0, 38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
{"a", "b"})}}, // components
{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}),
{"b"})}}, // expected_outputs
{{0, 1, 3}}}, // breakpoints
// Nested tensors
{{{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}),
{DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
{1.0, 2.0, 3.0, 4.0}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
{5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"}),
DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})}}, // components
{{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3}), {4, 5, 6})}}, // expected_outputs
{{0, 1, 2}}} // breakpoints
};
/*components*/
{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}),
{DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
{1.0, 2.0, 3.0, 4.0}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
{5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"}),
DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 3}),
{1, 2, 3, 4, 5, 6})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {4, 5, 6})},
/*breakpoints*/ {0, 1, 2}};
}
TEST_F(TensorSliceDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
class ParameterizedTensorSliceDatasetOpTest
: public TensorSliceDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
dtypes.push_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_slice = 0;
while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
for (int i = 0; i < out_tensors.size(); ++i) {
EXPECT_LT(i + num_tensors_per_slice * cur_slice,
expected_outputs.size());
if (out_tensors[i].dtype() == DT_VARIANT) {
// Currently `ExpectEqual()` does not support the variant tensor
// yet, so we manually cast the variant to numeric/string tensor.
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i + num_tensors_per_slice * cur_slice]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i],
expected_outputs[i + num_tensors_per_slice * cur_slice]));
}
}
out_tensors.clear();
cur_slice++;
}
}
}
TEST_F(TensorSliceDatasetOpTest, DatasetName) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
DataTypeVector dtypes({DT_INT64, DT_INT64});
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
PartialTensorShape({2})};
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_slice = 0;
while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
for (int i = 0; i < out_tensors.size(); ++i) {
EXPECT_LT(i + num_tensors_per_slice * cur_slice, expected_outputs.size());
if (out_tensors[i].dtype() == DT_VARIANT) {
// Currently `ExpectEqual()` does not support the variant tensor
// yet, so we manually cast the variant to numeric/string tensor.
const Tensor *output = out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i + num_tensors_per_slice * cur_slice]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i],
expected_outputs[i + num_tensors_per_slice * cur_slice]));
}
}
out_tensors.clear();
cur_slice++;
}
}
TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
EXPECT_EQ(tensor_slice_dataset->node_name(), kNodeName);
}
TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
@ -226,129 +270,124 @@ TEST_F(TensorSliceDatasetOpTest, DatasetName) {
EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName);
}
TEST_F(TensorSliceDatasetOpTest, DatasetOutputDtypes) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
const DataTypeVector produced_output_dtypes =
tensor_slice_dataset->output_dtypes();
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
const DataTypeVector produced_output_dtypes =
tensor_slice_dataset->output_dtypes();
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
}
TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
const std::vector<PartialTensorShape> produced_output_shapes =
tensor_slice_dataset->output_shapes();
std::vector<PartialTensorShape> expected_output_shapes;
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
const std::vector<PartialTensorShape> produced_output_shapes =
tensor_slice_dataset->output_shapes();
std::vector<PartialTensorShape> expected_output_shapes;
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
}
TEST_F(TensorSliceDatasetOpTest, Cardinality) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
EXPECT_EQ(tensor_slice_dataset->Cardinality(),
inputs[0].tensor->dim_size(0));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0));
}
TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
@ -356,12 +395,21 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
DataTypeVector dtypes({DT_INT64, DT_INT64});
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
PartialTensorShape({2})};
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
@ -384,102 +432,98 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(writer.Flush());
}
TEST_F(TensorSliceDatasetOpTest, IteratorOutputDtypes) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const DataTypeVector produced_output_dtypes = iterator->output_dtypes();
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const DataTypeVector produced_output_dtypes = iterator->output_dtypes();
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
}
TEST_F(TensorSliceDatasetOpTest, IteratorOutputShapes) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const std::vector<PartialTensorShape> produced_output_shapes =
iterator->output_shapes();
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const std::vector<PartialTensorShape> produced_output_shapes =
iterator->output_shapes();
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
}
@ -488,12 +532,21 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
DataTypeVector dtypes({DT_INT64, DT_INT64});
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
PartialTensorShape({2})};
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
@ -516,95 +569,96 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice");
}
TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
std::vector<int> breakpoints = test_case.breakpoints;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
int cur_iteration = 0;
bool end_of_sequence = false;
int64 num_slices = inputs[0].tensor->dim_size(0);
std::vector<Tensor> out_tensors;
const std::vector<int> &breakpoints = test_case.breakpoints;
for (int breakpoint : breakpoints) {
while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
cur_iteration++;
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
int cur_iteration = 0;
bool end_of_sequence = false;
int64 num_slices = inputs[0].tensor->dim_size(0);
std::vector<Tensor> out_tensors;
for (int breakpoint : breakpoints) {
while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
cur_iteration++;
}
if (breakpoint == 0) {
EXPECT_FALSE(end_of_sequence);
} else if (breakpoint <= num_slices) {
for (int i = 0; i < out_tensors.size(); ++i) {
if (out_tensors[i].dtype() == DT_VARIANT) {
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i +
num_tensors_per_slice * (cur_iteration - 1)]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i], expected_outputs[i + num_tensors_per_slice *
(cur_iteration - 1)]));
}
if (breakpoint == 0) {
EXPECT_FALSE(end_of_sequence);
} else if (breakpoint <= num_slices) {
for (int i = 0; i < out_tensors.size(); ++i) {
if (out_tensors[i].dtype() == DT_VARIANT) {
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i + num_tensors_per_slice * (cur_iteration - 1)]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i], expected_outputs[i + num_tensors_per_slice *
(cur_iteration - 1)]));
}
} else {
EXPECT_TRUE(end_of_sequence);
}
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
} else {
EXPECT_TRUE(end_of_sequence);
}
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
}
}
INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest,
ParameterizedTensorSliceDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{PlainTensorTestCase(), NestedTensorTestCase()})));
} // namespace
} // namespace data
} // namespace tensorflow