Refactor TensorSliceDatasetOpTest
This commit is contained in:
parent
578739fb89
commit
a280e8644b
@ -46,10 +46,10 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
|
|||||||
components.emplace_back(strings::StrCat("component_", i));
|
components.emplace_back(strings::StrCat("component_", i));
|
||||||
}
|
}
|
||||||
|
|
||||||
node_def_ = test::function::NDef(
|
NodeDef node_def = test::function::NDef(
|
||||||
kNodeName, kOpName, components,
|
kNodeName, kOpName, components,
|
||||||
{{"Toutput_types", dtypes}, {"output_shapes", shapes}});
|
{{"Toutput_types", dtypes}, {"output_shapes", shapes}});
|
||||||
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel));
|
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_dataset_kernel));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,9 +63,6 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase {
|
|||||||
CreateOpKernelContext(tensor_dataset_kernel, inputs, context));
|
CreateOpKernelContext(tensor_dataset_kernel, inputs, context));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
NodeDef node_def_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TestCase {
|
struct TestCase {
|
||||||
@ -74,28 +71,31 @@ struct TestCase {
|
|||||||
std::vector<int> breakpoints;
|
std::vector<int> breakpoints;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<TestCase> TestCases() {
|
TestCase PlainTensorTestCase() {
|
||||||
return {
|
return {/*components*/
|
||||||
// A single tuple of tensors.
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
|
||||||
{{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
|
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
|
||||||
{1, 2, 3, 4}),
|
{1, 2, 3, 4}),
|
||||||
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
|
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
|
||||||
{37.0, 38.0}),
|
{37.0, 38.0}),
|
||||||
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
|
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
|
||||||
{"a", "b"})}}, // components
|
{"a", "b"})},
|
||||||
{{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
|
||||||
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
|
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
|
||||||
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
|
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
|
||||||
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
|
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
|
||||||
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}),
|
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"b"})},
|
||||||
{"b"})}}, // expected_outputs
|
/*breakpoints*/ {0, 1, 3}};
|
||||||
{{0, 1, 3}}}, // breakpoints
|
}
|
||||||
// Nested tensors
|
|
||||||
{{{DatasetOpsTestBase::CreateTensor<Variant>(
|
TestCase NestedTensorTestCase() {
|
||||||
|
return {
|
||||||
|
/*components*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<Variant>(
|
||||||
TensorShape({2, 1}),
|
TensorShape({2, 1}),
|
||||||
{DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
|
{DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 2}),
|
||||||
{1.0, 2.0, 3.0, 4.0}),
|
{1.0, 2.0, 3.0, 4.0}),
|
||||||
@ -106,9 +106,10 @@ std::vector<TestCase> TestCases() {
|
|||||||
TensorShape({1, 2}), {"a", "b"}),
|
TensorShape({1, 2}), {"a", "b"}),
|
||||||
DatasetOpsTestBase::CreateTensor<string>(
|
DatasetOpsTestBase::CreateTensor<string>(
|
||||||
TensorShape({1, 2}), {"c", "d"})}),
|
TensorShape({1, 2}), {"c", "d"})}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 3}),
|
||||||
TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})}}, // components
|
{1, 2, 3, 4, 5, 6})},
|
||||||
{{DatasetOpsTestBase::CreateTensor<Variant>(
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<Variant>(
|
||||||
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
|
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<double>(
|
||||||
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
|
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
|
||||||
DatasetOpsTestBase::CreateTensor<Variant>(
|
DatasetOpsTestBase::CreateTensor<Variant>(
|
||||||
@ -121,40 +122,41 @@ std::vector<TestCase> TestCases() {
|
|||||||
DatasetOpsTestBase::CreateTensor<Variant>(
|
DatasetOpsTestBase::CreateTensor<Variant>(
|
||||||
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
|
TensorShape({1}), {DatasetOpsTestBase::CreateTensor<string>(
|
||||||
TensorShape({1, 2}), {"c", "d"})}),
|
TensorShape({1, 2}), {"c", "d"})}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3}), {4, 5, 6})},
|
||||||
TensorShape({3}), {4, 5, 6})}}, // expected_outputs
|
/*breakpoints*/ {0, 1, 2}};
|
||||||
{{0, 1, 2}}} // breakpoints
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, GetNext) {
|
class ParameterizedTensorSliceDatasetOpTest
|
||||||
int thread_num = 2, cpu_num = 2;
|
: public TensorSliceDatasetOpTest,
|
||||||
for (auto &test_case : TestCases()) {
|
public ::testing::WithParamInterface<TestCase> {};
|
||||||
std::vector<Tensor> components = test_case.components;
|
|
||||||
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
|
|
||||||
size_t num_tensors_per_slice = components.size();
|
|
||||||
|
|
||||||
|
TEST_P(ParameterizedTensorSliceDatasetOpTest, GetNext) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestCase &test_case = GetParam();
|
||||||
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes;
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes;
|
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
for (auto &component : components) {
|
for (auto &component : components) {
|
||||||
inputs.push_back(&component);
|
inputs.emplace_back(&component);
|
||||||
dtypes.push_back(component.dtype());
|
dtypes.emplace_back(component.dtype());
|
||||||
}
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
shapes.emplace_back(expected_outputs[i].shape());
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
|
TF_ASSERT_OK(
|
||||||
tensor_slice_dataset_kernel.get(), &inputs,
|
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||||
&tensor_slice_dataset_context));
|
&inputs, &tensor_slice_dataset_context));
|
||||||
DatasetBase *tensor_slice_dataset;
|
DatasetBase *tensor_slice_dataset;
|
||||||
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||||
tensor_slice_dataset_context.get(),
|
tensor_slice_dataset_context.get(),
|
||||||
@ -175,13 +177,11 @@ TEST_F(TensorSliceDatasetOpTest, GetNext) {
|
|||||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
||||||
&end_of_sequence));
|
&end_of_sequence));
|
||||||
for (int i = 0; i < out_tensors.size(); ++i) {
|
for (int i = 0; i < out_tensors.size(); ++i) {
|
||||||
EXPECT_LT(i + num_tensors_per_slice * cur_slice,
|
EXPECT_LT(i + num_tensors_per_slice * cur_slice, expected_outputs.size());
|
||||||
expected_outputs.size());
|
|
||||||
if (out_tensors[i].dtype() == DT_VARIANT) {
|
if (out_tensors[i].dtype() == DT_VARIANT) {
|
||||||
// Currently `ExpectEqual()` does not support the variant tensor
|
// Currently `ExpectEqual()` does not support the variant tensor
|
||||||
// yet, so we manually cast the variant to numeric/string tensor.
|
// yet, so we manually cast the variant to numeric/string tensor.
|
||||||
const Tensor *output =
|
const Tensor *output = out_tensors[i].scalar<Variant>()().get<Tensor>();
|
||||||
out_tensors[i].scalar<Variant>()().get<Tensor>();
|
|
||||||
const Tensor *expected_output =
|
const Tensor *expected_output =
|
||||||
expected_outputs[i + num_tensors_per_slice * cur_slice]
|
expected_outputs[i + num_tensors_per_slice * cur_slice]
|
||||||
.scalar<Variant>()()
|
.scalar<Variant>()()
|
||||||
@ -197,19 +197,63 @@ TEST_F(TensorSliceDatasetOpTest, GetNext) {
|
|||||||
cur_slice++;
|
cur_slice++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, DatasetName) {
|
TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
|
const TestCase &test_case = PlainTensorTestCase();
|
||||||
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes({DT_INT64, DT_INT64});
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
PartialTensorShape({2})};
|
for (auto &component : components) {
|
||||||
|
inputs.emplace_back(&component);
|
||||||
|
dtypes.emplace_back(component.dtype());
|
||||||
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
|
}
|
||||||
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
|
&tensor_slice_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||||
|
&inputs, &tensor_slice_dataset_context));
|
||||||
|
DatasetBase *tensor_slice_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||||
|
tensor_slice_dataset_context.get(),
|
||||||
|
&tensor_slice_dataset));
|
||||||
|
core::ScopedUnref scored_unref(tensor_slice_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(tensor_slice_dataset->node_name(), kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestCase &test_case = PlainTensorTestCase();
|
||||||
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
|
std::vector<Tensor> components = test_case.components;
|
||||||
|
DataTypeVector dtypes;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
for (auto &component : components) {
|
||||||
|
inputs.emplace_back(&component);
|
||||||
|
dtypes.emplace_back(component.dtype());
|
||||||
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
|
}
|
||||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
@ -226,34 +270,33 @@ TEST_F(TensorSliceDatasetOpTest, DatasetName) {
|
|||||||
EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName);
|
EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, DatasetOutputDtypes) {
|
TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputDtypes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
for (auto &test_case : TestCases()) {
|
|
||||||
std::vector<Tensor> components = test_case.components;
|
|
||||||
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
|
|
||||||
size_t num_tensors_per_slice = components.size();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestCase &test_case = GetParam();
|
||||||
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes;
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes;
|
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
for (auto &component : components) {
|
for (auto &component : components) {
|
||||||
inputs.emplace_back(&component);
|
inputs.emplace_back(&component);
|
||||||
dtypes.emplace_back(component.dtype());
|
dtypes.emplace_back(component.dtype());
|
||||||
}
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
shapes.emplace_back(expected_outputs[i].shape());
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
|
TF_ASSERT_OK(
|
||||||
tensor_slice_dataset_kernel.get(), &inputs,
|
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||||
&tensor_slice_dataset_context));
|
&inputs, &tensor_slice_dataset_context));
|
||||||
DatasetBase *tensor_slice_dataset;
|
DatasetBase *tensor_slice_dataset;
|
||||||
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||||
tensor_slice_dataset_context.get(),
|
tensor_slice_dataset_context.get(),
|
||||||
@ -267,25 +310,24 @@ TEST_F(TensorSliceDatasetOpTest, DatasetOutputDtypes) {
|
|||||||
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
|
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) {
|
TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
for (auto &test_case : TestCases()) {
|
|
||||||
std::vector<Tensor> components = test_case.components;
|
|
||||||
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
|
|
||||||
size_t num_tensors_per_slice = components.size();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestCase &test_case = GetParam();
|
||||||
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes;
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes;
|
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
for (auto &component : components) {
|
for (auto &component : components) {
|
||||||
inputs.emplace_back(&component);
|
inputs.emplace_back(&component);
|
||||||
dtypes.emplace_back(component.dtype());
|
dtypes.emplace_back(component.dtype());
|
||||||
}
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
shapes.emplace_back(expected_outputs[i].shape());
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
@ -293,9 +335,9 @@ TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) {
|
|||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
|
TF_ASSERT_OK(
|
||||||
tensor_slice_dataset_kernel.get(), &inputs,
|
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||||
&tensor_slice_dataset_context));
|
&inputs, &tensor_slice_dataset_context));
|
||||||
DatasetBase *tensor_slice_dataset;
|
DatasetBase *tensor_slice_dataset;
|
||||||
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||||
tensor_slice_dataset_context.get(),
|
tensor_slice_dataset_context.get(),
|
||||||
@ -311,25 +353,24 @@ TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) {
|
|||||||
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
|
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, Cardinality) {
|
TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
for (auto &test_case : TestCases()) {
|
|
||||||
std::vector<Tensor> components = test_case.components;
|
|
||||||
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
|
|
||||||
size_t num_tensors_per_slice = components.size();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestCase &test_case = GetParam();
|
||||||
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes;
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes;
|
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
for (auto &component : components) {
|
for (auto &component : components) {
|
||||||
inputs.emplace_back(&component);
|
inputs.emplace_back(&component);
|
||||||
dtypes.emplace_back(component.dtype());
|
dtypes.emplace_back(component.dtype());
|
||||||
}
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
shapes.emplace_back(expected_outputs[i].shape());
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
@ -337,18 +378,16 @@ TEST_F(TensorSliceDatasetOpTest, Cardinality) {
|
|||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
|
TF_ASSERT_OK(
|
||||||
tensor_slice_dataset_kernel.get(), &inputs,
|
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||||
&tensor_slice_dataset_context));
|
&inputs, &tensor_slice_dataset_context));
|
||||||
DatasetBase *tensor_slice_dataset;
|
DatasetBase *tensor_slice_dataset;
|
||||||
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||||
tensor_slice_dataset_context.get(),
|
tensor_slice_dataset_context.get(),
|
||||||
&tensor_slice_dataset));
|
&tensor_slice_dataset));
|
||||||
core::ScopedUnref scored_unref(tensor_slice_dataset);
|
core::ScopedUnref scored_unref(tensor_slice_dataset);
|
||||||
|
|
||||||
EXPECT_EQ(tensor_slice_dataset->Cardinality(),
|
EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0));
|
||||||
inputs[0].tensor->dim_size(0));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
|
TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
|
||||||
@ -356,12 +395,21 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
|
|||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
|
const TestCase &test_case = PlainTensorTestCase();
|
||||||
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes({DT_INT64, DT_INT64});
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
PartialTensorShape({2})};
|
for (auto &component : components) {
|
||||||
|
inputs.emplace_back(&component);
|
||||||
|
dtypes.emplace_back(component.dtype());
|
||||||
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
|
}
|
||||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
@ -384,34 +432,33 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
|
|||||||
TF_ASSERT_OK(writer.Flush());
|
TF_ASSERT_OK(writer.Flush());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, IteratorOutputDtypes) {
|
TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
for (auto &test_case : TestCases()) {
|
|
||||||
std::vector<Tensor> components = test_case.components;
|
|
||||||
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
|
|
||||||
size_t num_tensors_per_slice = components.size();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestCase &test_case = GetParam();
|
||||||
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes;
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes;
|
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
for (auto &component : components) {
|
for (auto &component : components) {
|
||||||
inputs.emplace_back(&component);
|
inputs.emplace_back(&component);
|
||||||
dtypes.emplace_back(component.dtype());
|
dtypes.emplace_back(component.dtype());
|
||||||
}
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
shapes.emplace_back(expected_outputs[i].shape());
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
|
TF_ASSERT_OK(
|
||||||
tensor_slice_dataset_kernel.get(), &inputs,
|
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||||
&tensor_slice_dataset_context));
|
&inputs, &tensor_slice_dataset_context));
|
||||||
DatasetBase *tensor_slice_dataset;
|
DatasetBase *tensor_slice_dataset;
|
||||||
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||||
tensor_slice_dataset_context.get(),
|
tensor_slice_dataset_context.get(),
|
||||||
@ -431,36 +478,34 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputDtypes) {
|
|||||||
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
|
EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, IteratorOutputShapes) {
|
TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
for (auto &test_case : TestCases()) {
|
|
||||||
std::vector<Tensor> components = test_case.components;
|
|
||||||
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
|
|
||||||
size_t num_tensors_per_slice = components.size();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestCase &test_case = GetParam();
|
||||||
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes;
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes;
|
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
for (auto &component : components) {
|
for (auto &component : components) {
|
||||||
inputs.emplace_back(&component);
|
inputs.emplace_back(&component);
|
||||||
dtypes.emplace_back(component.dtype());
|
dtypes.emplace_back(component.dtype());
|
||||||
}
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
shapes.emplace_back(expected_outputs[i].shape());
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
|
TF_ASSERT_OK(
|
||||||
tensor_slice_dataset_kernel.get(), &inputs,
|
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||||
&tensor_slice_dataset_context));
|
&inputs, &tensor_slice_dataset_context));
|
||||||
DatasetBase *tensor_slice_dataset;
|
DatasetBase *tensor_slice_dataset;
|
||||||
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||||
tensor_slice_dataset_context.get(),
|
tensor_slice_dataset_context.get(),
|
||||||
@ -481,19 +526,27 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputShapes) {
|
|||||||
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
|
produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
|
TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
Tensor t1 = CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4});
|
const TestCase &test_case = PlainTensorTestCase();
|
||||||
Tensor t2 = CreateTensor<int64>(TensorShape({2, 2}), {5, 6, 7, 8});
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs = {&t1, &t2};
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes({DT_INT64, DT_INT64});
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes = {PartialTensorShape({2}),
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
PartialTensorShape({2})};
|
for (auto &component : components) {
|
||||||
|
inputs.emplace_back(&component);
|
||||||
|
dtypes.emplace_back(component.dtype());
|
||||||
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
|
}
|
||||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
@ -516,35 +569,33 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
|
|||||||
EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice");
|
EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
|
TEST_P(ParameterizedTensorSliceDatasetOpTest, Roundtrip) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
for (auto &test_case : TestCases()) {
|
|
||||||
std::vector<Tensor> components = test_case.components;
|
|
||||||
std::vector<Tensor> expected_outputs = test_case.expected_outputs;
|
|
||||||
std::vector<int> breakpoints = test_case.breakpoints;
|
|
||||||
size_t num_tensors_per_slice = components.size();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
const TestCase &test_case = GetParam();
|
||||||
|
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||||
|
std::vector<Tensor> components = test_case.components;
|
||||||
DataTypeVector dtypes;
|
DataTypeVector dtypes;
|
||||||
std::vector<PartialTensorShape> shapes;
|
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
for (auto &component : components) {
|
for (auto &component : components) {
|
||||||
inputs.emplace_back(&component);
|
inputs.emplace_back(&component);
|
||||||
dtypes.emplace_back(component.dtype());
|
dtypes.emplace_back(component.dtype());
|
||||||
}
|
}
|
||||||
|
size_t num_tensors_per_slice = components.size();
|
||||||
|
std::vector<PartialTensorShape> shapes;
|
||||||
|
shapes.reserve(num_tensors_per_slice);
|
||||||
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||||
shapes.emplace_back(expected_outputs[i].shape());
|
shapes.emplace_back(expected_outputs[i].shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||||
&tensor_slice_dataset_kernel));
|
&tensor_slice_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||||
TF_ASSERT_OK(CreateTensorSliceDatasetContext(
|
TF_ASSERT_OK(
|
||||||
tensor_slice_dataset_kernel.get(), &inputs,
|
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||||
&tensor_slice_dataset_context));
|
&inputs, &tensor_slice_dataset_context));
|
||||||
DatasetBase *tensor_slice_dataset;
|
DatasetBase *tensor_slice_dataset;
|
||||||
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||||
tensor_slice_dataset_context.get(),
|
tensor_slice_dataset_context.get(),
|
||||||
@ -564,7 +615,7 @@ TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
|
|||||||
bool end_of_sequence = false;
|
bool end_of_sequence = false;
|
||||||
int64 num_slices = inputs[0].tensor->dim_size(0);
|
int64 num_slices = inputs[0].tensor->dim_size(0);
|
||||||
std::vector<Tensor> out_tensors;
|
std::vector<Tensor> out_tensors;
|
||||||
|
const std::vector<int> &breakpoints = test_case.breakpoints;
|
||||||
for (int breakpoint : breakpoints) {
|
for (int breakpoint : breakpoints) {
|
||||||
while (cur_iteration < breakpoint) {
|
while (cur_iteration < breakpoint) {
|
||||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
||||||
@ -580,8 +631,7 @@ TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
|
|||||||
const Tensor *output =
|
const Tensor *output =
|
||||||
out_tensors[i].scalar<Variant>()().get<Tensor>();
|
out_tensors[i].scalar<Variant>()().get<Tensor>();
|
||||||
const Tensor *expected_output =
|
const Tensor *expected_output =
|
||||||
expected_outputs[i +
|
expected_outputs[i + num_tensors_per_slice * (cur_iteration - 1)]
|
||||||
num_tensors_per_slice * (cur_iteration - 1)]
|
|
||||||
.scalar<Variant>()()
|
.scalar<Variant>()()
|
||||||
.get<Tensor>();
|
.get<Tensor>();
|
||||||
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
|
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
|
||||||
@ -603,7 +653,11 @@ TEST_F(TensorSliceDatasetOpTest, Roundtrip) {
|
|||||||
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
|
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest,
|
||||||
|
ParameterizedTensorSliceDatasetOpTest,
|
||||||
|
::testing::ValuesIn(std::vector<TestCase>(
|
||||||
|
{PlainTensorTestCase(), NestedTensorTestCase()})));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
Loading…
Reference in New Issue
Block a user