Refactor MapDatasetOpTest

This commit is contained in:
Fei Hu 2019-03-15 16:46:59 -07:00
parent 8dc85de492
commit 301dfdac4c

View File

@ -49,7 +49,7 @@ class MapDatasetOpTest : public DatasetOpsTestBase {
FunctionDefHelper::AttrValueWrapper func = FunctionDefHelper::AttrValueWrapper func =
FunctionDefHelper::FunctionRef(func_name, {{"T", DT_INT64}}); FunctionDefHelper::FunctionRef(func_name, {{"T", DT_INT64}});
map_node_def_ = test::function::NDef( NodeDef map_dataset_node_def = test::function::NDef(
kNodeName, kOpName, {input_dataset}, kNodeName, kOpName, {input_dataset},
{{"f", func}, {{"f", func},
{"Targuments", {}}, {"Targuments", {}},
@ -58,136 +58,198 @@ class MapDatasetOpTest : public DatasetOpsTestBase {
gtl::ArraySlice<DataType>{tensorflow::DataTypeToEnum<T>::value}}, gtl::ArraySlice<DataType>{tensorflow::DataTypeToEnum<T>::value}},
{"use_inter_op_parallelism", true}, {"use_inter_op_parallelism", true},
{"preserve_cardinality", false}}); {"preserve_cardinality", false}});
TF_CHECK_OK(CreateOpKernel(map_node_def_, map_kernel)); TF_RETURN_IF_ERROR(CreateOpKernel(map_dataset_node_def, map_kernel));
return Status::OK(); return Status::OK();
} }
// Creates a new MapDataset op kernel context. // Creates a new MapDataset op kernel context.
Status CreateMapDatasetContext( Status CreateMapDatasetContext(
DatasetBase* const input_dataset, OpKernel* const map_kernel, OpKernel* const map_kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
std::unique_ptr<OpKernelContext>* map_context) { std::unique_ptr<OpKernelContext>* map_context) {
map_inputs_.clear(); TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, *inputs));
// Save the input dataset into a variant tensor as the input of MapDataset. TF_RETURN_IF_ERROR(CreateOpKernelContext(map_kernel, inputs, map_context));
Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
TF_RETURN_IF_ERROR(
StoreDatasetInVariantTensor(input_dataset, &dataset_tensor));
Variant variant = dataset_tensor.scalar<Variant>()();
TF_RETURN_IF_ERROR(AddDatasetInputFromArray<Variant>(
&map_inputs_, map_kernel->input_types(), TensorShape({}), {variant}));
input_dataset->Ref();
TF_RETURN_IF_ERROR(
CreateOpKernelContext(map_kernel, &map_inputs_, map_context));
TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, map_inputs_));
return Status::OK(); return Status::OK();
} }
private:
NodeDef map_node_def_;
gtl::InlinedVector<TensorValue, 4> map_inputs_;
}; };
struct GetNextTestParams { struct TestCase {
explicit GetNextTestParams(int64 input_start, int64 input_end,
int64 input_step, string input_func_name,
std::vector<int64> input_expected_values,
std::vector<FunctionDef> input_func_lib)
: start(input_start),
end(input_end),
step(input_step),
func_name(std::move(input_func_name)),
expected_values(std::move(input_expected_values)),
func_lib(std::move(input_func_lib)) {}
int64 start; int64 start;
int64 end; int64 end;
int64 step; int64 step;
string func_name; string func_name;
std::vector<int64> expected_values;
std::vector<FunctionDef> func_lib; std::vector<FunctionDef> func_lib;
std::vector<Tensor> expected_outputs;
DataTypeVector expected_output_dtypes;
std::vector<PartialTensorShape> expected_output_shapes;
int64 expected_cardinality;
std::vector<int> breakpoints;
}; };
struct DatasetGetNextTest : MapDatasetOpTest, TestCase TestCase1() {
::testing::WithParamInterface<GetNextTestParams> {}; return {/*start*/ 0,
/*end*/ 10,
/*step*/ 3,
/*func_name*/ "XTimesTwo",
/*func_lib*/ {test::function::XTimesTwo()},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {18})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 4,
/*breakpoints*/ {0, 1, 5}};
}
TEST_P(DatasetGetNextTest, GetNext) { TestCase TestCase2() {
return {/*start*/ 10,
/*end*/ 0,
/*step*/ -3,
/*func_name*/ "XAddX",
/*func_lib*/ {test::function::XAddX()},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {20}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {14}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 4,
/*breakpoints*/ {0, 1, 5}};
}
// In this test case, the function `XTimesFour()` will call `XTimesTwo()`, so
// both of them are added to the function library.
TestCase TestCase3() {
return {
/*start*/ 0,
/*end*/ 10,
/*step*/ 3,
/*func_name*/ "XTimesFour",
/*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {24}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {36})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 4,
/*breakpoints*/ {0, 1, 5}};
}
class ParameterizedMapDatasetOpTest
: public MapDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParameterizedMapDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
GetNextTestParams test_params = GetParam(); TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end, TF_ASSERT_OK(CreateRangeDataset<int64>(
test_params.step, "range", test_case.start, test_case.end, test_case.step, "range", &range_dataset));
&range_dataset)); Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
core::ScopedUnref scored_unref_range_dataset(range_dataset); // The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), test_params.func_name, &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK( TF_ASSERT_OK(
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
bool end_of_sequence = false; bool end_of_sequence = false;
auto expected_outputs_it = test_case.expected_outputs.begin();
std::vector<Tensor> out_tensors; std::vector<Tensor> out_tensors;
while (!end_of_sequence) { while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence)); &end_of_sequence));
if (!end_of_sequence) {
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
expected_outputs_it++;
}
} }
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
EXPECT_EQ(out_tensors.size(), test_params.expected_values.size());
for (size_t i = 0; i < out_tensors.size(); ++i) {
int64 actual_value = out_tensors[i].flat<int64>()(0);
int64 expect_value = test_params.expected_values[i];
EXPECT_EQ(actual_value, expect_value);
}
} }
INSTANTIATE_TEST_CASE_P( TEST_F(MapDatasetOpTest, DatasetNodeName) {
MapDatasetOpTest, DatasetGetNextTest,
::testing::Values(
GetNextTestParams(
0, 10, 3, "XTimesTwo", std::vector<int64>{0, 6, 12, 18},
std::vector<FunctionDef>{test::function::XTimesTwo()}),
GetNextTestParams(0, 10, 3, "XAddX", std::vector<int64>{0, 6, 12, 18},
std::vector<FunctionDef>{test::function::XAddX()}),
GetNextTestParams(
10, 0, -3, "XTimesFour", std::vector<int64>{40, 28, 16, 4},
std::vector<FunctionDef>{test::function::XTimesTwo(),
test::function::XTimesFour()})));
TEST_F(MapDatasetOpTest, DatasetName) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
int64 start = 0, end = 10, step = 1; TestCase test_case = TestCase1();
FunctionDef func_def = test::function::XTimesTwo();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK( TF_ASSERT_OK(
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset)); StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), func_def.signature().name(), &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
EXPECT_EQ(map_dataset->node_name(), kNodeName);
}
TEST_F(MapDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = TestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK(CreateMapDatasetContext(
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset;
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
EXPECT_EQ(map_dataset->type_string(), kOpName); EXPECT_EQ(map_dataset->type_string(), kOpName);
@ -195,138 +257,125 @@ TEST_F(MapDatasetOpTest, DatasetName) {
TEST_F(MapDatasetOpTest, DatasetOutputDtypes) { TEST_F(MapDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
int64 start = 0, end = 10, step = 1; TestCase test_case = TestCase1();
FunctionDef func_def = test::function::XTimesTwo();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK( TF_ASSERT_OK(
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset)); StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), func_def.signature().name(), &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
DataTypeVector expected_dtypes({DT_INT64}); TF_EXPECT_OK(VerifyTypesMatch(map_dataset->output_dtypes(),
EXPECT_EQ(map_dataset->output_dtypes(), expected_dtypes); test_case.expected_output_dtypes));
} }
TEST_F(MapDatasetOpTest, DatasetOutputShapes) { TEST_F(MapDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
int64 start = 0, end = 10, step = 1; TestCase test_case = TestCase1();
FunctionDef func_def = test::function::XTimesTwo();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK( TF_ASSERT_OK(
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset)); StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), func_def.signature().name(), &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})}); TF_EXPECT_OK(VerifyShapesCompatible(map_dataset->output_shapes(),
EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size()); test_case.expected_output_shapes));
for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
EXPECT_TRUE(
map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
}
} }
struct CardinalityTestParams { TEST_P(ParameterizedMapDatasetOpTest, Cardinality) {
explicit CardinalityTestParams(int64 input_start, int64 input_end,
int64 input_step,
int input_expected_cardinality)
: start(input_start),
end(input_end),
step(input_step),
expected_cardinality(input_expected_cardinality) {}
int64 start;
int64 end;
int64 step;
int expected_cardinality;
};
struct DatasetCardinalityTest
: MapDatasetOpTest,
::testing::WithParamInterface<CardinalityTestParams> {};
TEST_P(DatasetCardinalityTest, Cardinality) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
CardinalityTestParams test_params = GetParam(); TestCase test_case = GetParam();
FunctionDef func_def = test::function::XTimesTwo();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end, TF_ASSERT_OK(CreateRangeDataset<int64>(
test_params.step, "range", test_case.start, test_case.end, test_case.step, "range", &range_dataset));
&range_dataset)); Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); // The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), func_def.signature().name(), &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
EXPECT_EQ(map_dataset->Cardinality(), test_params.expected_cardinality); EXPECT_EQ(map_dataset->Cardinality(), test_case.expected_cardinality);
} }
INSTANTIATE_TEST_CASE_P(MapDatasetOpTest, DatasetCardinalityTest, TEST_P(ParameterizedMapDatasetOpTest, DatasetSave) {
::testing::Values(CardinalityTestParams(0, 10, 1, 10),
CardinalityTestParams(0, 10, 3, 4),
CardinalityTestParams(10, 0, -3, 4)));
TEST_F(MapDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
int64 start = 0, end = 10, step = 1; TestCase test_case = GetParam();
FunctionDef func_def = test::function::XTimesTwo();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK( TF_ASSERT_OK(
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset)); StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), func_def.signature().name(), &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
std::unique_ptr<SerializationContext> serialization_context; std::unique_ptr<SerializationContext> serialization_context;
@ -338,101 +387,114 @@ TEST_F(MapDatasetOpTest, DatasetSave) {
} }
TEST_F(MapDatasetOpTest, IteratorOutputDtypes) { TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
FunctionDef func_def = test::function::XTimesTwo(); TestCase test_case = TestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK( TF_ASSERT_OK(
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset)); StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), func_def.signature().name(), &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK( TF_ASSERT_OK(
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
DataTypeVector expected_dtypes({DT_INT64});
EXPECT_EQ(iterator->output_dtypes(), expected_dtypes); TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
test_case.expected_output_dtypes));
} }
TEST_F(MapDatasetOpTest, IteratorOutputShapes) { TEST_F(MapDatasetOpTest, IteratorOutputShapes) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
FunctionDef func_def = test::function::XTimesTwo(); TestCase test_case = TestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK( TF_ASSERT_OK(
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset)); StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), func_def.signature().name(), &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK( TF_ASSERT_OK(
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})}); TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); test_case.expected_output_shapes));
for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
}
} }
TEST_F(MapDatasetOpTest, IteratorOutputPrefix) { TEST_F(MapDatasetOpTest, IteratorOutputPrefix) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
FunctionDef func_def = test::function::XTimesTwo(); TestCase test_case = TestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK( TF_ASSERT_OK(
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset)); StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), func_def.signature().name(), &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK( TF_ASSERT_OK(
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
@ -440,95 +502,79 @@ TEST_F(MapDatasetOpTest, IteratorOutputPrefix) {
EXPECT_EQ(iterator->prefix(), "Iterator::Map"); EXPECT_EQ(iterator->prefix(), "Iterator::Map");
} }
struct RoundtripTestParams { TEST_P(ParameterizedMapDatasetOpTest, Roundtrip) {
explicit RoundtripTestParams(int64 input_start, int64 input_end,
int64 input_step, int input_breakpoint,
int64 input_expected_value,
string input_func_name,
std::vector<FunctionDef> input_func_lib)
: start(input_start),
end(input_end),
step(input_step),
breakpoint(input_breakpoint),
expected_value(input_expected_value),
func_name(std::move(input_func_name)),
func_lib(std::move(input_func_lib)) {}
int64 start;
int64 end;
int64 step;
int breakpoint;
int64 expected_value;
string func_name;
std::vector<FunctionDef> func_lib;
};
struct IteratorRoundtripTest
: MapDatasetOpTest,
::testing::WithParamInterface<RoundtripTestParams> {};
TEST_P(IteratorRoundtripTest, Roundtrip) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
RoundtripTestParams test_params = GetParam(); TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset; DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end, TF_ASSERT_OK(CreateRangeDataset<int64>(
test_params.step, "range", test_case.start, test_case.end, test_case.step, "range", &range_dataset));
&range_dataset)); Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
core::ScopedUnref scoped_unref_range_dataset(range_dataset); // The ownership of range_dataset is transfered to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_kernel; std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>( TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
range_dataset->node_name(), test_params.func_name, &map_kernel)); range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_context; std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK( TF_ASSERT_OK(CreateMapDatasetContext(
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset; DatasetBase* map_dataset;
TF_ASSERT_OK( TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset); core::ScopedUnref scoped_unref_map_dataset(map_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK( TF_ASSERT_OK(
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
std::vector<Tensor> out_tensors; std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false; bool end_of_sequence = false;
for (int i = 0; i < test_params.breakpoint; i++) { std::vector<Tensor> out_tensors;
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, int cur_iteration = 0;
&end_of_sequence)); auto expected_outputs_it = test_case.expected_outputs.begin();
} const std::vector<int>& breakpoints = test_case.breakpoints;
for (int breakpoint : breakpoints) {
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader));
std::unique_ptr<SerializationContext> serialization_context; while (cur_iteration <= breakpoint) {
TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
VariantTensorData data; &end_of_sequence));
VariantTensorDataWriter writer(&data); if (!end_of_sequence) {
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
TF_ASSERT_OK(writer.Flush()); TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
VariantTensorDataReader reader(&data); expected_outputs_it++;
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); }
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, cur_iteration++;
&end_of_sequence)); }
EXPECT_EQ(out_tensors.back().flat<int64>()(0), test_params.expected_value);
if (breakpoint >= test_case.expected_cardinality) {
EXPECT_TRUE(end_of_sequence);
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
} else {
EXPECT_FALSE(end_of_sequence);
}
}
} }
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_SUITE_P(MapDatasetOpTest, ParameterizedMapDatasetOpTest,
MapDatasetOpTest, IteratorRoundtripTest, ::testing::ValuesIn(std::vector<TestCase>(
::testing::Values(RoundtripTestParams(0, 10, 2, 0, 0, "XTimesTwo", {TestCase1(), TestCase2(), TestCase3()})));
std::vector<FunctionDef>{
test::function::XTimesTwo()}),
RoundtripTestParams(0, 10, 2, 4, 16, "XAddX",
std::vector<FunctionDef>{
test::function::XAddX()}),
RoundtripTestParams(0, 10, 2, 6, 32, "XTimesFour",
std::vector<FunctionDef>{
test::function::XTimesTwo(),
test::function::XTimesFour()})));
} // namespace } // namespace
} // namespace data } // namespace data