Refactor TensorSliceDatasetOpTest
parent 578739fb89
commit a280e8644b
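This diff moves the TensorSlice dataset tests from TEST_F bodies that loop over TestCases() to GoogleTest value-parameterized tests: each TestCase becomes a parameter of ParameterizedTensorSliceDatasetOpTest, registered once with INSTANTIATE_TEST_SUITE_P. A minimal, self-contained sketch of that pattern follows (plain GoogleTest, no TensorFlow dependencies; MiniTestCase, MiniParameterizedTest, and the factory names are illustrative placeholders, not identifiers from this commit):

#include <vector>

#include "gtest/gtest.h"

// Stand-in for the diff's TestCase struct: inputs plus expected outputs.
struct MiniTestCase {
  std::vector<int> components;
  std::vector<int> expected_outputs;
};

// One factory per case, mirroring PlainTensorTestCase()/NestedTensorTestCase().
MiniTestCase PlainCase() { return {{1, 2}, {1, 2}}; }
MiniTestCase NestedCase() { return {{3, 4}, {3, 4}}; }

// Deriving from TestWithParam makes each registered case its own test instance.
class MiniParameterizedTest : public ::testing::TestWithParam<MiniTestCase> {};

TEST_P(MiniParameterizedTest, GetNext) {
  const MiniTestCase& test_case = GetParam();
  EXPECT_EQ(test_case.components, test_case.expected_outputs);
}

// Registers both cases; gtest names them MiniSuite/MiniParameterizedTest.GetNext/0 and /1.
INSTANTIATE_TEST_SUITE_P(MiniSuite, MiniParameterizedTest,
                         ::testing::ValuesIn(std::vector<MiniTestCase>(
                             {PlainCase(), NestedCase()})));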
@@ -46,10 +46,10 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
components.emplace_back(strings::StrCat("component_", i));
}

node_def_ = test::function::NDef(
NodeDef node_def = test::function::NDef(
kNodeName, kOpName, components,
{{"Toutput_types", dtypes}, {"output_shapes", shapes}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel));
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_dataset_kernel));
return Status::OK();
}

@@ -63,9 +63,6 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
CreateOpKernelContext(tensor_dataset_kernel, inputs, context));
return Status::OK();
}

private:
NodeDef node_def_;
};

struct TestCase {
@@ -74,142 +71,189 @@ struct TestCase {
std::vector<int> breakpoints;
};

std::vector<TestCase> TestCases() {
TestCase PlainTensorTestCase() {
return {/*components*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
{1, 2, 3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
{37.0, 38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
{"a", "b"})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"b"})},
/*breakpoints*/ {0, 1, 3}};
}

TestCase NestedTensorTestCase() {
return {
// A single tuple of tensors.
{{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
{1, 2, 3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
{37.0, 38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
{"a", "b"})}}, // components
{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}),
{"b"})}}, // expected_outputs
{{0, 1, 3}}}, // breakpoints
// Nested tensors
{{{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}),
{DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
{1.0, 2.0, 3.0, 4.0}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
{5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"}),
DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})}}, // components
{{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3}), {4, 5, 6})}}, // expected_outputs
{{0, 1, 2}}} // breakpoints
};
/*components*/
{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}),
{DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
{1.0, 2.0, 3.0, 4.0}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
{5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"}),
DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 3}),
{1, 2, 3, 4, 5, 6})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"c", "d"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {4, 5, 6})},
/*breakpoints*/ {0, 1, 2}};
}

TEST_F(TensorSliceDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
class ParameterizedTensorSliceDatasetOpTest
: public TensorSliceDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};

TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
dtypes.push_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}

std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_slice = 0;

while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
for (int i = 0; i < out_tensors.size(); ++i) {
EXPECT_LT(i + num_tensors_per_slice * cur_slice,
expected_outputs.size());
if (out_tensors[i].dtype() == DT_VARIANT) {
// Currently `ExpectEqual()` does not support the variant tensor
// yet, so we manually cast the variant to numeric/string tensor.
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i + num_tensors_per_slice * cur_slice]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i],
expected_outputs[i + num_tensors_per_slice * cur_slice]));
}
}
out_tensors.clear();
cur_slice++;
}
}
}

TEST_F(TensorSliceDatasetOpTest, DatasetName) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
DataTypeVector dtypes({DT_INT64, DT_INT64});
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
PartialTensorShape({2})};
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_slice = 0;

while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
for (int i = 0; i < out_tensors.size(); ++i) {
EXPECT_LT(i + num_tensors_per_slice * cur_slice, expected_outputs.size());
if (out_tensors[i].dtype() == DT_VARIANT) {
// Currently `ExpectEqual()` does not support the variant tensor
// yet, so we manually cast the variant to numeric/string tensor.
const Tensor *output = out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i + num_tensors_per_slice * cur_slice]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i],
expected_outputs[i + num_tensors_per_slice * cur_slice]));
}
}
out_tensors.clear();
cur_slice++;
}
}

TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

EXPECT_EQ(tensor_slice_dataset->node_name(), kNodeName);
}

TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
@@ -226,129 +270,124 @@ TEST_F(TensorSliceDatasetOpTest, DatasetName) {
EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName);
}

TEST_F(TensorSliceDatasetOpTest, DatasetOutputDtypes) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}

std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

const DataTypeVector produced_output_dtypes =
tensor_slice_dataset->output_dtypes();
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
const DataTypeVector produced_output_dtypes =
tensor_slice_dataset->output_dtypes();
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
}

TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

const std::vector<PartialTensorShape> produced_output_shapes =
tensor_slice_dataset->output_shapes();
std::vector<PartialTensorShape> expected_output_shapes;
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
const std::vector<PartialTensorShape> produced_output_shapes =
tensor_slice_dataset->output_shapes();
std::vector<PartialTensorShape> expected_output_shapes;
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
}

TEST_F(TensorSliceDatasetOpTest, Cardinality) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

EXPECT_EQ(tensor_slice_dataset->Cardinality(),
inputs[0].tensor->dim_size(0));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0));
}

TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
@@ -356,12 +395,21 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
DataTypeVector dtypes({DT_INT64, DT_INT64});
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
PartialTensorShape({2})};
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
@@ -384,102 +432,98 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
TF_ASSERT_OK(writer.Flush());
}

TEST_F(TensorSliceDatasetOpTest, IteratorOutputDtypes) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const DataTypeVector produced_output_dtypes = iterator->output_dtypes();

std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const DataTypeVector produced_output_dtypes = iterator->output_dtypes();

EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
}
}

TEST_F(TensorSliceDatasetOpTest, IteratorOutputShapes) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}

std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const std::vector<PartialTensorShape> produced_output_shapes =
iterator->output_shapes();
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
const std::vector<PartialTensorShape> produced_output_shapes =
iterator->output_shapes();
EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
EXPECT_TRUE(
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
}
}

@@ -488,12 +532,21 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
DataTypeVector dtypes({DT_INT64, DT_INT64});
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
PartialTensorShape({2})};
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
@@ -516,95 +569,96 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice");
}

TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
TEST_P(ParameterizedTensorSliceDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
for (auto &test_case : TestCases()) {
std::vector<Tensor> components = test_case.components;
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
std::vector<int> breakpoints = test_case.breakpoints;
size_t num_tensors_per_slice = components.size();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

DataTypeVector dtypes;
std::vector<PartialTensorShape> shapes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));

int cur_iteration = 0;
bool end_of_sequence = false;
int64 num_slices = inputs[0].tensor->dim_size(0);
std::vector<Tensor> out_tensors;
const std::vector<int> &breakpoints = test_case.breakpoints;
for (int breakpoint : breakpoints) {
while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
cur_iteration++;
}

std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
tensor_slice_dataset_kernel.get(), &inputs,
&tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scored_unref(tensor_slice_dataset);

std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));

int cur_iteration = 0;
bool end_of_sequence = false;
int64 num_slices = inputs[0].tensor->dim_size(0);
std::vector<Tensor> out_tensors;

for (int breakpoint : breakpoints) {
while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
cur_iteration++;
}

if (breakpoint == 0) {
EXPECT_FALSE(end_of_sequence);
} else if (breakpoint <= num_slices) {
for (int i = 0; i < out_tensors.size(); ++i) {
if (out_tensors[i].dtype() == DT_VARIANT) {
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i +
num_tensors_per_slice * (cur_iteration - 1)]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i], expected_outputs[i + num_tensors_per_slice *
(cur_iteration - 1)]));
}
if (breakpoint == 0) {
EXPECT_FALSE(end_of_sequence);
} else if (breakpoint <= num_slices) {
for (int i = 0; i < out_tensors.size(); ++i) {
if (out_tensors[i].dtype() == DT_VARIANT) {
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
expected_outputs[i + num_tensors_per_slice * (cur_iteration - 1)]
.scalar<Variant>()()
.get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(
out_tensors[i], expected_outputs[i + num_tensors_per_slice *
(cur_iteration - 1)]));
}
} else {
EXPECT_TRUE(end_of_sequence);
}

VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
} else {
EXPECT_TRUE(end_of_sequence);
}

VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
}
}

INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest,
ParameterizedTensorSliceDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{PlainTensorTestCase(), NestedTensorTestCase()})));

} // namespace
} // namespace data
} // namespace tensorflow
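
A note on the variant-tensor comparison repeated in the tests above: ExpectEqual() does not handle DT_VARIANT tensors, so each DT_VARIANT scalar is first unwrapped to the Tensor it stores before comparing. A hedged sketch of that check factored into a helper (the helper name is hypothetical and assumes it lives inside the DatasetOpsTestBase-derived fixture so ExpectEqual() is in scope; it is not part of this commit):

// Hypothetical helper, for illustration only: unwraps DT_VARIANT scalars
// before delegating to the fixture's ExpectEqual() check.
void ExpectTensorsEqualMaybeVariant(const Tensor &actual,
                                    const Tensor &expected) {
  if (actual.dtype() == DT_VARIANT) {
    const Tensor *actual_inner = actual.scalar<Variant>()().get<Tensor>();
    const Tensor *expected_inner = expected.scalar<Variant>()().get<Tensor>();
    TF_EXPECT_OK(ExpectEqual(*actual_inner, *expected_inner));
  } else {
    TF_EXPECT_OK(ExpectEqual(actual, expected));
  }
}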