Refactor TensorSliceDatasetOpTest

This commit is contained in:
Fei Hu 2019-03-18 21:31:24 -07:00
parent 578739fb89
commit a280e8644b

View File

@ -46,10 +46,10 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
components.emplace_back(strings::StrCat("component_", i)); components.emplace_back(strings::StrCat("component_", i));
} }
node_def_ = test::function::NDef( NodeDef node_def = test::function::NDef(
kNodeName, kOpName, components, kNodeName, kOpName, components,
{{"Toutput_types", dtypes}, {"output_shapes", shapes}}); {{"Toutput_types", dtypes}, {"output_shapes", shapes}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel)); TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_dataset_kernel));
return Status::OK(); return Status::OK();
} }
@ -63,9 +63,6 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
CreateOpKernelContext(tensor_dataset_kernel, inputs, context)); CreateOpKernelContext(tensor_dataset_kernel, inputs, context));
return Status::OK(); return Status::OK();
} }
private:
NodeDef node_def_;
}; };
struct TestCase { struct TestCase {
@ -74,142 +71,189 @@ struct TestCase {
std::vector<int> breakpoints; std::vector<int> breakpoints;
}; };
std::vector<TestCase> TestCases() { TestCase PlainTensorTestCase() {
return {/*components*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
{1, 2, 3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
{37.0, 38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
{"a", "b"})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"b"})},
/*breakpoints*/ {0, 1, 3}};
}
TestCase NestedTensorTestCase() {
return { return {
// A single tuple of tensors. /*components*/
{{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}), {DatasetOpsTestBase::CreateTensor<Variant>(
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}), TensorShape({2, 1}),
{1, 2, 3, 4}), {DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}), {1.0, 2.0, 3.0, 4.0}),
{37.0, 38.0}), DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}), {5.0, 6.0, 7.0, 8.0})}),
{"a", "b"})}}, // components DatasetOpsTestBase::CreateTensor<Variant>(
{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}), TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor<string>(
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}), TensorShape({1, 2}), {"a", "b"}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}), DatasetOpsTestBase::CreateTensor<string>(
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}), TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}), DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 3}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}), {1, 2, 3, 4, 5, 6})},
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}), /*expected_outputs*/
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {DatasetOpsTestBase::CreateTensor<Variant>(
{"b"})}}, // expected_outputs TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
{{0, 1, 3}}}, // breakpoints TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
// Nested tensors DatasetOpsTestBase::CreateTensor<Variant>(
{{{DatasetOpsTestBase::CreateTensor<Variant>( TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({2, 1}), TensorShape({1, 2}), {"a", "b"})}),
{DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}), DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
{1.0, 2.0, 3.0, 4.0}), DatasetOpsTestBase::CreateTensor<Variant>(
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}), TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
{5.0, 6.0, 7.0, 8.0})}), TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>( DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor<string>( TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"}), TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<string>( DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {4, 5, 6})},
TensorShape({1, 2}), {"c", "d"})}), /*breakpoints*/ {0, 1, 2}};
DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})}}, // components
{{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3}), {4, 5, 6})}}, // expected_outputs
{{0, 1, 2}}} // breakpoints
};
} }
TEST_F(TensorSliceDatasetOpTest, GetNext) { class ParameterizedTensorSliceDatasetOpTest
int thread_num = 2, cpu_num = 2; : public TensorSliceDatasetOpTest,
for (auto &test_case : TestCases()) { public ::testing::WithParamInterface<TestCase> {};
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num)); TEST_P(ParameterizedTensorSliceDatasetOpTest, GetNext) {
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
dtypes.push_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_slice = 0;
while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
for (int i = 0; i < out_tensors.size(); ++i) {
EXPECT_LT(i + num_tensors_per_slice * cur_slice,
expected_outputs.size());
if (out_tensors[i].dtype() == DT_VARIANT) {
// Currently `ExpectEqual()` does not support the variant tensor
// yet, so we manually cast the variant to numeric/string tensor.
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i + num_tensors_per_slice * cur_slice]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i],
expected_outputs[i + num_tensors_per_slice * cur_slice]));
}
}
out_tensors.clear();
cur_slice++;
}
}
}
TEST_F(TensorSliceDatasetOpTest, DatasetName) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4}); const TestCase &test_case = GetParam();
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8}); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2}; std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes({DT_INT64, DT_INT64}); DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}), gtl::InlinedVector<TensorValue, 4> inputs;
PartialTensorShape({2})}; for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_slice = 0;
while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
for (int i = 0; i < out_tensors.size(); ++i) {
EXPECT_LT(i + num_tensors_per_slice * cur_slice, expected_outputs.size());
if (out_tensors[i].dtype() == DT_VARIANT) {
// Currently `ExpectEqual()` does not support the variant tensor
// yet, so we manually cast the variant to numeric/string tensor.
const Tensor *output = out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i + num_tensors_per_slice * cur_slice]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i],
expected_outputs[i + num_tensors_per_slice * cur_slice]));
}
}
out_tensors.clear();
cur_slice++;
}
}
TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
EXPECT_EQ(tensor_slice_dataset->node_name(), kNodeName);
}
TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel; std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel)); &tensor_slice_dataset_kernel));
@ -226,129 +270,124 @@ TEST_F(TensorSliceDatasetOpTest, DatasetName) {
EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName); EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName);
} }
TEST_F(TensorSliceDatasetOpTest, DatasetOutputDtypes) { TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) { TF_ASSERT_OK(InitThreadPool(thread_num));
std::vector<Tensor> components = test_case.components; TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num)); const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes; const DataTypeVector produced_output_dtypes =
std::vector<PartialTensorShape> shapes; tensor_slice_dataset->output_dtypes();
gtl::InlinedVector<TensorValue, 4> inputs; EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (auto &component : components) { for (int i = 0; i < num_tensors_per_slice; ++i) {
inputs.emplace_back(&component); EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
const DataTypeVector produced_output_dtypes =
tensor_slice_dataset->output_dtypes();
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
} }
} }
TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) { TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) { TF_ASSERT_OK(InitThreadPool(thread_num));
std::vector<Tensor> components = test_case.components; TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num)); const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes; const std::vector<PartialTensorShape> produced_output_shapes =
std::vector<PartialTensorShape> shapes; tensor_slice_dataset->output_shapes();
gtl::InlinedVector<TensorValue, 4> inputs; std::vector<PartialTensorShape> expected_output_shapes;
for (auto &component : components) { EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
inputs.emplace_back(&component); for (int i = 0; i < num_tensors_per_slice; ++i) {
dtypes.emplace_back(component.dtype()); EXPECT_TRUE(
} produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
const std::vector<PartialTensorShape> produced_output_shapes =
tensor_slice_dataset->output_shapes();
std::vector<PartialTensorShape> expected_output_shapes;
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
} }
} }
TEST_F(TensorSliceDatasetOpTest, Cardinality) { TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) { TF_ASSERT_OK(InitThreadPool(thread_num));
std::vector<Tensor> components = test_case.components; TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num)); const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes; DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes; gtl::InlinedVector<TensorValue, 4> inputs;
gtl::InlinedVector<TensorValue, 4> inputs; for (auto &component : components) {
for (auto &component : components) { inputs.emplace_back(&component);
inputs.emplace_back(&component); dtypes.emplace_back(component.dtype());
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
EXPECT_EQ(tensor_slice_dataset->Cardinality(),
inputs[0].tensor->dim_size(0));
} }
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0));
} }
TEST_F(TensorSliceDatasetOpTest, DatasetSave) { TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
@ -356,12 +395,21 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4}); const TestCase &test_case = PlainTensorTestCase();
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8}); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2}; std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes({DT_INT64, DT_INT64}); DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}), gtl::InlinedVector<TensorValue, 4> inputs;
PartialTensorShape({2})}; for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel; std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel)); &tensor_slice_dataset_kernel));
@ -384,102 +432,98 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(writer.Flush()); TF_ASSERT_OK(writer.Flush());
} }
TEST_F(TensorSliceDatasetOpTest, IteratorOutputDtypes) { TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) { TF_ASSERT_OK(InitThreadPool(thread_num));
std::vector<Tensor> components = test_case.components; TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num)); const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes; std::unique_ptr<IteratorContext> iterator_context;
std::vector<PartialTensorShape> shapes; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
gtl::InlinedVector<TensorValue, 4> inputs; &iterator_context));
for (auto &component : components) { std::unique_ptr<IteratorBase> iterator;
inputs.emplace_back(&component); TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
dtypes.emplace_back(component.dtype()); "Iterator", &iterator));
} const DataTypeVector produced_output_dtypes = iterator->output_dtypes();
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel; EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, for (int i = 0; i < num_tensors_per_slice; ++i) {
&tensor_slice_dataset_kernel)); EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const DataTypeVector produced_output_dtypes = iterator->output_dtypes();
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
} }
} }
TEST_F(TensorSliceDatasetOpTest, IteratorOutputShapes) { TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) { TF_ASSERT_OK(InitThreadPool(thread_num));
std::vector<Tensor> components = test_case.components; TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num)); const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes; std::unique_ptr<IteratorContext> iterator_context;
std::vector<PartialTensorShape> shapes; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
gtl::InlinedVector<TensorValue, 4> inputs; &iterator_context));
for (auto &component : components) { std::unique_ptr<IteratorBase> iterator;
inputs.emplace_back(&component); TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
dtypes.emplace_back(component.dtype()); "Iterator", &iterator));
} const std::vector<PartialTensorShape> produced_output_shapes =
for (int i = 0; i < num_tensors_per_slice; ++i) { iterator->output_shapes();
shapes.emplace_back(expected_outputs[i].shape()); EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
} for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel; produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const std::vector<PartialTensorShape> produced_output_shapes =
iterator->output_shapes();
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
} }
} }
@ -488,12 +532,21 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4}); const TestCase &test_case = PlainTensorTestCase();
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8}); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2}; std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes({DT_INT64, DT_INT64}); DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}), gtl::InlinedVector<TensorValue, 4> inputs;
PartialTensorShape({2})}; for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel; std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel)); &tensor_slice_dataset_kernel));
@ -516,95 +569,96 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice"); EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice");
} }
TEST_F(TensorSliceDatasetOpTest, Roundtrip) { TEST_P(ParameterizedTensorSliceDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) { TF_ASSERT_OK(InitThreadPool(thread_num));
std::vector<Tensor> components = test_case.components; TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
std::vector<int> breakpoints = test_case.breakpoints;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num)); const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);
DataTypeVector dtypes; std::unique_ptr<IteratorContext> iterator_context;
std::vector<PartialTensorShape> shapes; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
gtl::InlinedVector<TensorValue, 4> inputs; &iterator_context));
for (auto &component : components) { std::unique_ptr<IteratorBase> iterator;
inputs.emplace_back(&component); TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
dtypes.emplace_back(component.dtype()); "Iterator", &iterator));
} std::unique_ptr<SerializationContext> serialization_context;
for (int i = 0; i < num_tensors_per_slice; ++i) { TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
shapes.emplace_back(expected_outputs[i].shape());
int cur_iteration = 0;
bool end_of_sequence = false;
int64 num_slices = inputs[0].tensor->dim_size(0);
std::vector<Tensor> out_tensors;
const std::vector<int> &breakpoints = test_case.breakpoints;
for (int breakpoint : breakpoints) {
while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
cur_iteration++;
} }
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel; if (breakpoint == 0) {
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, EXPECT_FALSE(end_of_sequence);
&tensor_slice_dataset_kernel)); } else if (breakpoint <= num_slices) {
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context; for (int i = 0; i < out_tensors.size(); ++i) {
TF_ASSERT_OK(CreateTensorSliceDatasetContext( if (out_tensors[i].dtype() == DT_VARIANT) {
tensor_slice_dataset_kernel.get(), &inputs, const Tensor *output =
&tensor_slice_dataset_context)); out_tensors[i].scalar<Variant>()().get<Tensor>();
DatasetBase *tensor_slice_dataset; const Tensor *expected_output =
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), expected_outputs[i + num_tensors_per_slice * (cur_iteration - 1)]
tensor_slice_dataset_context.get(), .scalar<Variant>()()
&tensor_slice_dataset)); .get<Tensor>();
core::ScopedUnref scored_unref(tensor_slice_dataset); TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
std::unique_ptr<IteratorContext> iterator_context; TF_EXPECT_OK(ExpectEqual(
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), out_tensors[i], expected_outputs[i + num_tensors_per_slice *
&iterator_context)); (cur_iteration - 1)]));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
int cur_iteration = 0;
bool end_of_sequence = false;
int64 num_slices = inputs[0].tensor->dim_size(0);
std::vector<Tensor> out_tensors;
for (int breakpoint : breakpoints) {
while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
cur_iteration++;
}
if (breakpoint == 0) {
EXPECT_FALSE(end_of_sequence);
} else if (breakpoint <= num_slices) {
for (int i = 0; i < out_tensors.size(); ++i) {
if (out_tensors[i].dtype() == DT_VARIANT) {
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i +
num_tensors_per_slice * (cur_iteration - 1)]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i], expected_outputs[i + num_tensors_per_slice *
(cur_iteration - 1)]));
}
} }
} else {
EXPECT_TRUE(end_of_sequence);
} }
} else {
VariantTensorData data; EXPECT_TRUE(end_of_sequence);
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
} }
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
} }
} }
INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest,
ParameterizedTensorSliceDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{PlainTensorTestCase(), NestedTensorTestCase()})));
} // namespace } // namespace
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow