Merge pull request #26772 from feihugis:Refactor_Dataset_Tests
PiperOrigin-RevId: 239061831
This commit is contained in:
commit
b7a36dec6b
@ -49,7 +49,7 @@ class MapDatasetOpTest : public DatasetOpsTestBase {
|
|||||||
FunctionDefHelper::AttrValueWrapper func =
|
FunctionDefHelper::AttrValueWrapper func =
|
||||||
FunctionDefHelper::FunctionRef(func_name, {{"T", DT_INT64}});
|
FunctionDefHelper::FunctionRef(func_name, {{"T", DT_INT64}});
|
||||||
|
|
||||||
map_node_def_ = test::function::NDef(
|
NodeDef map_dataset_node_def = test::function::NDef(
|
||||||
kNodeName, kOpName, {input_dataset},
|
kNodeName, kOpName, {input_dataset},
|
||||||
{{"f", func},
|
{{"f", func},
|
||||||
{"Targuments", {}},
|
{"Targuments", {}},
|
||||||
@ -58,136 +58,198 @@ class MapDatasetOpTest : public DatasetOpsTestBase {
|
|||||||
gtl::ArraySlice<DataType>{tensorflow::DataTypeToEnum<T>::value}},
|
gtl::ArraySlice<DataType>{tensorflow::DataTypeToEnum<T>::value}},
|
||||||
{"use_inter_op_parallelism", true},
|
{"use_inter_op_parallelism", true},
|
||||||
{"preserve_cardinality", false}});
|
{"preserve_cardinality", false}});
|
||||||
TF_CHECK_OK(CreateOpKernel(map_node_def_, map_kernel));
|
TF_RETURN_IF_ERROR(CreateOpKernel(map_dataset_node_def, map_kernel));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new MapDataset op kernel context.
|
// Creates a new MapDataset op kernel context.
|
||||||
Status CreateMapDatasetContext(
|
Status CreateMapDatasetContext(
|
||||||
DatasetBase* const input_dataset, OpKernel* const map_kernel,
|
OpKernel* const map_kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
|
||||||
std::unique_ptr<OpKernelContext>* map_context) {
|
std::unique_ptr<OpKernelContext>* map_context) {
|
||||||
map_inputs_.clear();
|
TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, *inputs));
|
||||||
// Save the input dataset into a variant tensor as the input of MapDataset.
|
TF_RETURN_IF_ERROR(CreateOpKernelContext(map_kernel, inputs, map_context));
|
||||||
Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
StoreDatasetInVariantTensor(input_dataset, &dataset_tensor));
|
|
||||||
Variant variant = dataset_tensor.scalar<Variant>()();
|
|
||||||
TF_RETURN_IF_ERROR(AddDatasetInputFromArray<Variant>(
|
|
||||||
&map_inputs_, map_kernel->input_types(), TensorShape({}), {variant}));
|
|
||||||
input_dataset->Ref();
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
CreateOpKernelContext(map_kernel, &map_inputs_, map_context));
|
|
||||||
TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, map_inputs_));
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
NodeDef map_node_def_;
|
|
||||||
gtl::InlinedVector<TensorValue, 4> map_inputs_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct GetNextTestParams {
|
struct TestCase {
|
||||||
explicit GetNextTestParams(int64 input_start, int64 input_end,
|
|
||||||
int64 input_step, string input_func_name,
|
|
||||||
std::vector<int64> input_expected_values,
|
|
||||||
std::vector<FunctionDef> input_func_lib)
|
|
||||||
: start(input_start),
|
|
||||||
end(input_end),
|
|
||||||
step(input_step),
|
|
||||||
func_name(std::move(input_func_name)),
|
|
||||||
expected_values(std::move(input_expected_values)),
|
|
||||||
func_lib(std::move(input_func_lib)) {}
|
|
||||||
|
|
||||||
int64 start;
|
int64 start;
|
||||||
int64 end;
|
int64 end;
|
||||||
int64 step;
|
int64 step;
|
||||||
string func_name;
|
string func_name;
|
||||||
std::vector<int64> expected_values;
|
|
||||||
std::vector<FunctionDef> func_lib;
|
std::vector<FunctionDef> func_lib;
|
||||||
|
std::vector<Tensor> expected_outputs;
|
||||||
|
DataTypeVector expected_output_dtypes;
|
||||||
|
std::vector<PartialTensorShape> expected_output_shapes;
|
||||||
|
int64 expected_cardinality;
|
||||||
|
std::vector<int> breakpoints;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DatasetGetNextTest : MapDatasetOpTest,
|
TestCase TestCase1() {
|
||||||
::testing::WithParamInterface<GetNextTestParams> {};
|
return {/*start*/ 0,
|
||||||
|
/*end*/ 10,
|
||||||
|
/*step*/ 3,
|
||||||
|
/*func_name*/ "XTimesTwo",
|
||||||
|
/*func_lib*/ {test::function::XTimesTwo()},
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {18})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 4,
|
||||||
|
/*breakpoints*/ {0, 1, 5}};
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(DatasetGetNextTest, GetNext) {
|
TestCase TestCase2() {
|
||||||
|
return {/*start*/ 10,
|
||||||
|
/*end*/ 0,
|
||||||
|
/*step*/ -3,
|
||||||
|
/*func_name*/ "XAddX",
|
||||||
|
/*func_lib*/ {test::function::XAddX()},
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {20}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {14}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {8}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 4,
|
||||||
|
/*breakpoints*/ {0, 1, 5}};
|
||||||
|
}
|
||||||
|
|
||||||
|
// In this test case, the function `XTimesFour()` will call `XTimesTwo()`, so
|
||||||
|
// both of them are added to the function library.
|
||||||
|
TestCase TestCase3() {
|
||||||
|
return {
|
||||||
|
/*start*/ 0,
|
||||||
|
/*end*/ 10,
|
||||||
|
/*step*/ 3,
|
||||||
|
/*func_name*/ "XTimesFour",
|
||||||
|
/*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()},
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {12}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {24}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {36})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 4,
|
||||||
|
/*breakpoints*/ {0, 1, 5}};
|
||||||
|
}
|
||||||
|
|
||||||
|
class ParameterizedMapDatasetOpTest
|
||||||
|
: public MapDatasetOpTest,
|
||||||
|
public ::testing::WithParamInterface<TestCase> {};
|
||||||
|
|
||||||
|
TEST_P(ParameterizedMapDatasetOpTest, GetNext) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
GetNextTestParams test_params = GetParam();
|
TestCase test_case = GetParam();
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end,
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
test_params.step, "range",
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
&range_dataset));
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
core::ScopedUnref scored_unref_range_dataset(range_dataset);
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), test_params.func_name, &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
||||||
|
|
||||||
bool end_of_sequence = false;
|
bool end_of_sequence = false;
|
||||||
|
auto expected_outputs_it = test_case.expected_outputs.begin();
|
||||||
std::vector<Tensor> out_tensors;
|
std::vector<Tensor> out_tensors;
|
||||||
while (!end_of_sequence) {
|
while (!end_of_sequence) {
|
||||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
||||||
&end_of_sequence));
|
&end_of_sequence));
|
||||||
|
if (!end_of_sequence) {
|
||||||
|
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
|
TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
|
||||||
|
expected_outputs_it++;
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_EQ(out_tensors.size(), test_params.expected_values.size());
|
|
||||||
for (size_t i = 0; i < out_tensors.size(); ++i) {
|
|
||||||
int64 actual_value = out_tensors[i].flat<int64>()(0);
|
|
||||||
int64 expect_value = test_params.expected_values[i];
|
|
||||||
EXPECT_EQ(actual_value, expect_value);
|
|
||||||
}
|
}
|
||||||
|
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
TEST_F(MapDatasetOpTest, DatasetNodeName) {
|
||||||
MapDatasetOpTest, DatasetGetNextTest,
|
|
||||||
::testing::Values(
|
|
||||||
GetNextTestParams(
|
|
||||||
0, 10, 3, "XTimesTwo", std::vector<int64>{0, 6, 12, 18},
|
|
||||||
std::vector<FunctionDef>{test::function::XTimesTwo()}),
|
|
||||||
GetNextTestParams(0, 10, 3, "XAddX", std::vector<int64>{0, 6, 12, 18},
|
|
||||||
std::vector<FunctionDef>{test::function::XAddX()}),
|
|
||||||
GetNextTestParams(
|
|
||||||
10, 0, -3, "XTimesFour", std::vector<int64>{40, 28, 16, 4},
|
|
||||||
std::vector<FunctionDef>{test::function::XTimesTwo(),
|
|
||||||
test::function::XTimesFour()})));
|
|
||||||
|
|
||||||
TEST_F(MapDatasetOpTest, DatasetName) {
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
int64 start = 0, end = 10, step = 1;
|
TestCase test_case = TestCase1();
|
||||||
FunctionDef func_def = test::function::XTimesTwo();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), func_def.signature().name(), &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
|
map_dataset_context.get(), &map_dataset));
|
||||||
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(map_dataset->node_name(), kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MapDatasetOpTest, DatasetTypeString) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = TestCase1();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
|
DatasetBase* map_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
EXPECT_EQ(map_dataset->type_string(), kOpName);
|
EXPECT_EQ(map_dataset->type_string(), kOpName);
|
||||||
@ -195,138 +257,125 @@ TEST_F(MapDatasetOpTest, DatasetName) {
|
|||||||
|
|
||||||
TEST_F(MapDatasetOpTest, DatasetOutputDtypes) {
|
TEST_F(MapDatasetOpTest, DatasetOutputDtypes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
int64 start = 0, end = 10, step = 1;
|
TestCase test_case = TestCase1();
|
||||||
FunctionDef func_def = test::function::XTimesTwo();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), func_def.signature().name(), &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
DataTypeVector expected_dtypes({DT_INT64});
|
TF_EXPECT_OK(VerifyTypesMatch(map_dataset->output_dtypes(),
|
||||||
EXPECT_EQ(map_dataset->output_dtypes(), expected_dtypes);
|
test_case.expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MapDatasetOpTest, DatasetOutputShapes) {
|
TEST_F(MapDatasetOpTest, DatasetOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
int64 start = 0, end = 10, step = 1;
|
TestCase test_case = TestCase1();
|
||||||
FunctionDef func_def = test::function::XTimesTwo();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), func_def.signature().name(), &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
|
TF_EXPECT_OK(VerifyShapesCompatible(map_dataset->output_shapes(),
|
||||||
EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size());
|
test_case.expected_output_shapes));
|
||||||
for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
|
|
||||||
EXPECT_TRUE(
|
|
||||||
map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CardinalityTestParams {
|
TEST_P(ParameterizedMapDatasetOpTest, Cardinality) {
|
||||||
explicit CardinalityTestParams(int64 input_start, int64 input_end,
|
|
||||||
int64 input_step,
|
|
||||||
int input_expected_cardinality)
|
|
||||||
: start(input_start),
|
|
||||||
end(input_end),
|
|
||||||
step(input_step),
|
|
||||||
expected_cardinality(input_expected_cardinality) {}
|
|
||||||
|
|
||||||
int64 start;
|
|
||||||
int64 end;
|
|
||||||
int64 step;
|
|
||||||
int expected_cardinality;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DatasetCardinalityTest
|
|
||||||
: MapDatasetOpTest,
|
|
||||||
::testing::WithParamInterface<CardinalityTestParams> {};
|
|
||||||
|
|
||||||
TEST_P(DatasetCardinalityTest, Cardinality) {
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
CardinalityTestParams test_params = GetParam();
|
TestCase test_case = GetParam();
|
||||||
FunctionDef func_def = test::function::XTimesTwo();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end,
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
test_params.step, "range",
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
&range_dataset));
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), func_def.signature().name(), &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
EXPECT_EQ(map_dataset->Cardinality(), test_params.expected_cardinality);
|
EXPECT_EQ(map_dataset->Cardinality(), test_case.expected_cardinality);
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(MapDatasetOpTest, DatasetCardinalityTest,
|
TEST_P(ParameterizedMapDatasetOpTest, DatasetSave) {
|
||||||
::testing::Values(CardinalityTestParams(0, 10, 1, 10),
|
|
||||||
CardinalityTestParams(0, 10, 3, 4),
|
|
||||||
CardinalityTestParams(10, 0, -3, 4)));
|
|
||||||
|
|
||||||
TEST_F(MapDatasetOpTest, DatasetSave) {
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
int64 start = 0, end = 10, step = 1;
|
TestCase test_case = GetParam();
|
||||||
FunctionDef func_def = test::function::XTimesTwo();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), func_def.signature().name(), &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
std::unique_ptr<SerializationContext> serialization_context;
|
std::unique_ptr<SerializationContext> serialization_context;
|
||||||
@ -338,101 +387,114 @@ TEST_F(MapDatasetOpTest, DatasetSave) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
|
TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
FunctionDef func_def = test::function::XTimesTwo();
|
TestCase test_case = TestCase1();
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), func_def.signature().name(), &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
||||||
DataTypeVector expected_dtypes({DT_INT64});
|
|
||||||
EXPECT_EQ(iterator->output_dtypes(), expected_dtypes);
|
TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
|
||||||
|
test_case.expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MapDatasetOpTest, IteratorOutputShapes) {
|
TEST_F(MapDatasetOpTest, IteratorOutputShapes) {
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
FunctionDef func_def = test::function::XTimesTwo();
|
TestCase test_case = TestCase1();
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), func_def.signature().name(), &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
||||||
|
|
||||||
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
|
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
|
||||||
EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size());
|
test_case.expected_output_shapes));
|
||||||
for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
|
|
||||||
EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MapDatasetOpTest, IteratorOutputPrefix) {
|
TEST_F(MapDatasetOpTest, IteratorOutputPrefix) {
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
FunctionDef func_def = test::function::XTimesTwo();
|
TestCase test_case = TestCase1();
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), func_def.signature().name(), &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
||||||
@ -440,95 +502,79 @@ TEST_F(MapDatasetOpTest, IteratorOutputPrefix) {
|
|||||||
EXPECT_EQ(iterator->prefix(), "Iterator::Map");
|
EXPECT_EQ(iterator->prefix(), "Iterator::Map");
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RoundtripTestParams {
|
TEST_P(ParameterizedMapDatasetOpTest, Roundtrip) {
|
||||||
explicit RoundtripTestParams(int64 input_start, int64 input_end,
|
|
||||||
int64 input_step, int input_breakpoint,
|
|
||||||
int64 input_expected_value,
|
|
||||||
string input_func_name,
|
|
||||||
std::vector<FunctionDef> input_func_lib)
|
|
||||||
: start(input_start),
|
|
||||||
end(input_end),
|
|
||||||
step(input_step),
|
|
||||||
breakpoint(input_breakpoint),
|
|
||||||
expected_value(input_expected_value),
|
|
||||||
func_name(std::move(input_func_name)),
|
|
||||||
func_lib(std::move(input_func_lib)) {}
|
|
||||||
|
|
||||||
int64 start;
|
|
||||||
int64 end;
|
|
||||||
int64 step;
|
|
||||||
int breakpoint;
|
|
||||||
int64 expected_value;
|
|
||||||
string func_name;
|
|
||||||
std::vector<FunctionDef> func_lib;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct IteratorRoundtripTest
|
|
||||||
: MapDatasetOpTest,
|
|
||||||
::testing::WithParamInterface<RoundtripTestParams> {};
|
|
||||||
|
|
||||||
TEST_P(IteratorRoundtripTest, Roundtrip) {
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
RoundtripTestParams test_params = GetParam();
|
TestCase test_case = GetParam();
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||||
|
|
||||||
DatasetBase* range_dataset;
|
DatasetBase* range_dataset;
|
||||||
TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end,
|
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||||
test_params.step, "range",
|
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||||
&range_dataset));
|
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
core::ScopedUnref scoped_unref_range_dataset(range_dataset);
|
// The ownership of range_dataset is transfered to DatasetVariantWrapper,
|
||||||
|
// which will handle the release of memory.
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||||
|
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> map_kernel;
|
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||||
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
|
||||||
range_dataset->node_name(), test_params.func_name, &map_kernel));
|
range_dataset->node_name(), test_case.func_name, &map_dataset_kernel));
|
||||||
std::unique_ptr<OpKernelContext> map_context;
|
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||||
CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
|
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||||
DatasetBase* map_dataset;
|
DatasetBase* map_dataset;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
map_dataset_context.get(), &map_dataset));
|
||||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(map_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
||||||
|
|
||||||
std::vector<Tensor> out_tensors;
|
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||||
|
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||||
bool end_of_sequence = false;
|
bool end_of_sequence = false;
|
||||||
for (int i = 0; i < test_params.breakpoint; i++) {
|
std::vector<Tensor> out_tensors;
|
||||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
int cur_iteration = 0;
|
||||||
&end_of_sequence));
|
auto expected_outputs_it = test_case.expected_outputs.begin();
|
||||||
}
|
const std::vector<int>& breakpoints = test_case.breakpoints;
|
||||||
|
for (int breakpoint : breakpoints) {
|
||||||
std::unique_ptr<SerializationContext> serialization_context;
|
|
||||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
|
||||||
VariantTensorData data;
|
VariantTensorData data;
|
||||||
VariantTensorDataWriter writer(&data);
|
VariantTensorDataWriter writer(&data);
|
||||||
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
|
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
|
||||||
TF_ASSERT_OK(writer.Flush());
|
TF_EXPECT_OK(writer.Flush());
|
||||||
VariantTensorDataReader reader(&data);
|
VariantTensorDataReader reader(&data);
|
||||||
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
|
TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader));
|
||||||
|
|
||||||
|
while (cur_iteration <= breakpoint) {
|
||||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
||||||
&end_of_sequence));
|
&end_of_sequence));
|
||||||
EXPECT_EQ(out_tensors.back().flat<int64>()(0), test_params.expected_value);
|
if (!end_of_sequence) {
|
||||||
|
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
|
TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
|
||||||
|
expected_outputs_it++;
|
||||||
|
}
|
||||||
|
cur_iteration++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (breakpoint >= test_case.expected_cardinality) {
|
||||||
|
EXPECT_TRUE(end_of_sequence);
|
||||||
|
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
|
} else {
|
||||||
|
EXPECT_FALSE(end_of_sequence);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_SUITE_P(MapDatasetOpTest, ParameterizedMapDatasetOpTest,
|
||||||
MapDatasetOpTest, IteratorRoundtripTest,
|
::testing::ValuesIn(std::vector<TestCase>(
|
||||||
::testing::Values(RoundtripTestParams(0, 10, 2, 0, 0, "XTimesTwo",
|
{TestCase1(), TestCase2(), TestCase3()})));
|
||||||
std::vector<FunctionDef>{
|
|
||||||
test::function::XTimesTwo()}),
|
|
||||||
RoundtripTestParams(0, 10, 2, 4, 16, "XAddX",
|
|
||||||
std::vector<FunctionDef>{
|
|
||||||
test::function::XAddX()}),
|
|
||||||
RoundtripTestParams(0, 10, 2, 6, 32, "XTimesFour",
|
|
||||||
std::vector<FunctionDef>{
|
|
||||||
test::function::XTimesTwo(),
|
|
||||||
test::function::XTimesFour()})));
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
@ -13,237 +13,324 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/framework/dataset.h"
|
|
||||||
#include "tensorflow/core/framework/fake_input.h"
|
|
||||||
#include "tensorflow/core/framework/function_testlib.h"
|
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
|
||||||
#include "tensorflow/core/framework/variant.h"
|
|
||||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
|
||||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
|
||||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
|
||||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kNodeName[] = "range_dataset";
|
||||||
constexpr char kOpName[] = "RangeDataset";
|
constexpr char kOpName[] = "RangeDataset";
|
||||||
|
|
||||||
class RangeDatasetOpTest : public DatasetOpsTestBase {
|
class RangeDatasetOpTest : public DatasetOpsTestBase {
|
||||||
protected:
|
protected:
|
||||||
// Creates a new RangeDataset op kernel context.
|
// Creates a new RangeDataset op kernel context.
|
||||||
Status CreateRangeDatasetContext(
|
Status CreateRangeDatasetContext(
|
||||||
int64 start, int64 end, int64 step, OpKernel* const range_kernel,
|
OpKernel* const range_kernel,
|
||||||
|
gtl::InlinedVector<TensorValue, 4>* const inputs,
|
||||||
std::unique_ptr<OpKernelContext>* range_context) {
|
std::unique_ptr<OpKernelContext>* range_context) {
|
||||||
inputs_.clear();
|
TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, *inputs));
|
||||||
TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
|
|
||||||
&inputs_, range_kernel->input_types(), TensorShape({}), {start}));
|
|
||||||
TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
|
|
||||||
&inputs_, range_kernel->input_types(), TensorShape({}), {end}));
|
|
||||||
TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
|
|
||||||
&inputs_, range_kernel->input_types(), TensorShape({}), {step}));
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
CreateOpKernelContext(range_kernel, &inputs_, range_context));
|
CreateOpKernelContext(range_kernel, inputs, range_context));
|
||||||
TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, inputs_));
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct GetNextTestParams {
|
struct TestCase {
|
||||||
explicit GetNextTestParams(int64 input_start, int64 input_end,
|
|
||||||
int64 input_step)
|
|
||||||
: start(input_start), end(input_end), step(input_step) {}
|
|
||||||
|
|
||||||
int64 start;
|
int64 start;
|
||||||
int64 end;
|
int64 end;
|
||||||
int64 step;
|
int64 step;
|
||||||
|
std::vector<Tensor> expected_outputs;
|
||||||
|
DataTypeVector expected_output_dtypes;
|
||||||
|
std::vector<PartialTensorShape> expected_output_shapes;
|
||||||
|
int64 expected_cardinality;
|
||||||
|
std::vector<int> breakpoints;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DatasetGetNextTest : RangeDatasetOpTest,
|
TestCase PositiveStepTestCase() {
|
||||||
::testing::WithParamInterface<GetNextTestParams> {};
|
return {/*start*/ 0,
|
||||||
|
/*end*/ 10,
|
||||||
|
/*step*/ 3,
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {9})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 4,
|
||||||
|
/*breakpoints*/ {0, 1, 4}};
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(DatasetGetNextTest, GetNext) {
|
TestCase NegativeStepTestCase() {
|
||||||
|
return {/*start*/ 10,
|
||||||
|
/*end*/ 0,
|
||||||
|
/*step*/ -3,
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 4,
|
||||||
|
/*breakpoints*/ {0, 1, 4}};
|
||||||
|
}
|
||||||
|
|
||||||
|
TestCase ZeroStepTestCase() {
|
||||||
|
return {/*start*/ 0,
|
||||||
|
/*end*/ 10,
|
||||||
|
/*step*/ 0,
|
||||||
|
/*expected_outputs*/ {},
|
||||||
|
/*expected_output_dtypes*/ {},
|
||||||
|
/*expected_output_shapes*/ {},
|
||||||
|
/*expected_cardinality*/ 0,
|
||||||
|
/*breakpoints*/ {}};
|
||||||
|
}
|
||||||
|
|
||||||
|
class ParameterizedRangeDatasetOpTest
|
||||||
|
: public RangeDatasetOpTest,
|
||||||
|
public ::testing::WithParamInterface<TestCase> {};
|
||||||
|
|
||||||
|
TEST_P(ParameterizedRangeDatasetOpTest, GetNext) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
GetNextTestParams params = GetParam();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = GetParam();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step,
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
range_kernel.get(), &range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||||
&iterator));
|
&iterator));
|
||||||
|
|
||||||
bool end_of_sequence = false;
|
bool end_of_sequence = false;
|
||||||
|
auto expected_outputs_it = test_case.expected_outputs.begin();
|
||||||
std::vector<Tensor> out_tensors;
|
std::vector<Tensor> out_tensors;
|
||||||
while (!end_of_sequence) {
|
while (!end_of_sequence) {
|
||||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
||||||
&end_of_sequence));
|
&end_of_sequence));
|
||||||
|
if (!end_of_sequence) {
|
||||||
|
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
|
TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
|
||||||
|
expected_outputs_it++;
|
||||||
}
|
}
|
||||||
std::vector<int> expected_values;
|
|
||||||
for (int i = params.start; (params.end - i) * params.step > 0;
|
|
||||||
i = i + params.step) {
|
|
||||||
expected_values.reserve(1);
|
|
||||||
expected_values.emplace_back(i);
|
|
||||||
}
|
|
||||||
EXPECT_EQ(out_tensors.size(), expected_values.size());
|
|
||||||
for (size_t i = 0; i < out_tensors.size(); ++i) {
|
|
||||||
int64 actual_value = out_tensors[i].flat<int64>()(0);
|
|
||||||
int64 expect_value = expected_values[i];
|
|
||||||
EXPECT_EQ(actual_value, expect_value);
|
|
||||||
}
|
}
|
||||||
|
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetGetNextTest,
|
TEST_F(RangeDatasetOpTest, ZeroStep) {
|
||||||
::testing::Values(GetNextTestParams(0, 10, 1),
|
|
||||||
GetNextTestParams(0, 10, 3),
|
|
||||||
GetNextTestParams(10, 0, -1),
|
|
||||||
GetNextTestParams(10, 0, -3)));
|
|
||||||
|
|
||||||
TEST_F(RangeDatasetOpTest, DatasetName) {
|
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = ZeroStepTestCase();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
&range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
EXPECT_EQ(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset)
|
||||||
|
.code(),
|
||||||
|
tensorflow::error::INVALID_ARGUMENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RangeDatasetOpTest, DatasetNodeName) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
TestCase test_case = PositiveStepTestCase();
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(range_dataset->node_name(), kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RangeDatasetOpTest, DatasetTypeString) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
TestCase test_case = PositiveStepTestCase();
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
EXPECT_EQ(range_dataset->type_string(), kOpName);
|
EXPECT_EQ(range_dataset->type_string(), kOpName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) {
|
TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) {
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = PositiveStepTestCase();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
&range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
DataTypeVector expected_dtypes({DT_INT64});
|
TF_EXPECT_OK(VerifyTypesMatch(range_dataset->output_dtypes(),
|
||||||
EXPECT_EQ(range_dataset->output_dtypes(), expected_dtypes);
|
test_case.expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
|
TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = PositiveStepTestCase();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
&range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
|
TF_EXPECT_OK(VerifyShapesCompatible(range_dataset->output_shapes(),
|
||||||
EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size());
|
test_case.expected_output_shapes));
|
||||||
for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
|
|
||||||
EXPECT_TRUE(
|
|
||||||
range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CardinalityTestParams {
|
TEST_P(ParameterizedRangeDatasetOpTest, Cardinality) {
|
||||||
explicit CardinalityTestParams(int64 input_start, int64 input_end,
|
|
||||||
int64 input_step,
|
|
||||||
int input_expected_cardinality)
|
|
||||||
: start(input_start),
|
|
||||||
end(input_end),
|
|
||||||
step(input_step),
|
|
||||||
expected_cardinality(input_expected_cardinality) {}
|
|
||||||
|
|
||||||
int64 start;
|
|
||||||
int64 end;
|
|
||||||
int64 step;
|
|
||||||
int expected_cardinality;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DatasetCardinalityTest
|
|
||||||
: RangeDatasetOpTest,
|
|
||||||
::testing::WithParamInterface<CardinalityTestParams> {};
|
|
||||||
|
|
||||||
TEST_P(DatasetCardinalityTest, Cardinality) {
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
CardinalityTestParams params = GetParam();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = GetParam();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step,
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
range_kernel.get(), &range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
EXPECT_EQ(range_dataset->Cardinality(), params.expected_cardinality);
|
EXPECT_EQ(range_dataset->Cardinality(), test_case.expected_cardinality);
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetCardinalityTest,
|
|
||||||
::testing::Values(CardinalityTestParams(0, 10, 1, 10),
|
|
||||||
CardinalityTestParams(0, 10, 3, 4),
|
|
||||||
CardinalityTestParams(10, 0, -3, 4)));
|
|
||||||
|
|
||||||
TEST_F(RangeDatasetOpTest, DatasetSave) {
|
TEST_F(RangeDatasetOpTest, DatasetSave) {
|
||||||
int64 thread_num = 2, cpu_num = 2;
|
int64 thread_num = 2, cpu_num = 2;
|
||||||
int start = 0, end = 10, step = 1;
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = PositiveStepTestCase();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
&range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
std::unique_ptr<SerializationContext> serialization_context;
|
std::unique_ptr<SerializationContext> serialization_context;
|
||||||
@ -256,81 +343,105 @@ TEST_F(RangeDatasetOpTest, DatasetSave) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
|
TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = PositiveStepTestCase();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
&range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||||
&iterator));
|
&iterator));
|
||||||
|
|
||||||
DataTypeVector expected_dtypes({DT_INT64});
|
TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
|
||||||
EXPECT_EQ(iterator->output_dtypes(), expected_dtypes);
|
test_case.expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
|
TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = PositiveStepTestCase();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
&range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||||
&iterator));
|
&iterator));
|
||||||
|
|
||||||
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
|
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
|
||||||
EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size());
|
test_case.expected_output_shapes));
|
||||||
for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
|
|
||||||
EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
|
TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
|
||||||
int64 start = 0, end = 10, step = 1;
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = PositiveStepTestCase();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
&range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||||
&iterator));
|
&iterator));
|
||||||
@ -338,83 +449,77 @@ TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
|
|||||||
EXPECT_EQ(iterator->prefix(), "Iterator::Range");
|
EXPECT_EQ(iterator->prefix(), "Iterator::Range");
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RoundtripTestParams {
|
TEST_P(ParameterizedRangeDatasetOpTest, Roundtrip) {
|
||||||
explicit RoundtripTestParams(int64 input_start, int64 input_end,
|
|
||||||
int64 input_step, int input_breakpoint)
|
|
||||||
: start(input_start),
|
|
||||||
end(input_end),
|
|
||||||
step(input_step),
|
|
||||||
breakpoint(input_breakpoint) {}
|
|
||||||
|
|
||||||
int64 start;
|
|
||||||
int64 end;
|
|
||||||
int64 step;
|
|
||||||
int breakpoint;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct IteratorRoundtripTest
|
|
||||||
: RangeDatasetOpTest,
|
|
||||||
::testing::WithParamInterface<RoundtripTestParams> {};
|
|
||||||
|
|
||||||
TEST_P(IteratorRoundtripTest, Roundtrip) {
|
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
RoundtripTestParams params = GetParam();
|
|
||||||
|
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
std::unique_ptr<OpKernel> range_kernel;
|
TestCase test_case = GetParam();
|
||||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
std::unique_ptr<OpKernelContext> range_context;
|
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
|
||||||
TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step,
|
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
|
||||||
range_kernel.get(), &range_context));
|
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
|
||||||
DatasetBase* range_dataset;
|
inputs.emplace_back(&start);
|
||||||
|
inputs.emplace_back(&end);
|
||||||
|
inputs.emplace_back(&step);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||||
|
&range_dataset_context));
|
||||||
|
DatasetBase* range_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||||
|
range_dataset_context.get(), &range_dataset));
|
||||||
core::ScopedUnref scoped_unref(range_dataset);
|
core::ScopedUnref scoped_unref(range_dataset);
|
||||||
|
|
||||||
std::unique_ptr<IteratorContext> iterator_context;
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
|
TF_ASSERT_OK(
|
||||||
|
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||||
&iterator));
|
&iterator));
|
||||||
|
|
||||||
std::vector<Tensor> out_tensors;
|
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||||
|
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||||
bool end_of_sequence = false;
|
bool end_of_sequence = false;
|
||||||
int64 cur_val = params.start - params.step;
|
std::vector<Tensor> out_tensors;
|
||||||
for (int i = 0; i < params.breakpoint; i++) {
|
int cur_iteration = 0;
|
||||||
if (!end_of_sequence) {
|
auto expected_outputs_it = test_case.expected_outputs.begin();
|
||||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
const std::vector<int>& breakpoints = test_case.breakpoints;
|
||||||
&end_of_sequence));
|
for (int breakpoint : breakpoints) {
|
||||||
cur_val = ((params.end - cur_val - params.step) * params.step > 0)
|
|
||||||
? cur_val + params.step
|
|
||||||
: cur_val;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<SerializationContext> serialization_context;
|
|
||||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
|
||||||
VariantTensorData data;
|
VariantTensorData data;
|
||||||
VariantTensorDataWriter writer(&data);
|
VariantTensorDataWriter writer(&data);
|
||||||
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
|
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
|
||||||
TF_ASSERT_OK(writer.Flush());
|
TF_EXPECT_OK(writer.Flush());
|
||||||
VariantTensorDataReader reader(&data);
|
VariantTensorDataReader reader(&data);
|
||||||
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
|
TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader));
|
||||||
|
|
||||||
|
while (cur_iteration <= breakpoint) {
|
||||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
||||||
&end_of_sequence));
|
&end_of_sequence));
|
||||||
int64 expect_next = ((params.end - cur_val - params.step) * params.step > 0)
|
if (!end_of_sequence) {
|
||||||
? cur_val + params.step
|
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
: cur_val;
|
TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
|
||||||
EXPECT_EQ(out_tensors.back().flat<int64>()(0), expect_next);
|
expected_outputs_it++;
|
||||||
|
}
|
||||||
|
cur_iteration++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (breakpoint >= test_case.expected_cardinality) {
|
||||||
|
EXPECT_TRUE(end_of_sequence);
|
||||||
|
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
|
} else {
|
||||||
|
EXPECT_FALSE(end_of_sequence);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
RangeDatasetOpTest, IteratorRoundtripTest,
|
RangeDatasetOpTest, ParameterizedRangeDatasetOpTest,
|
||||||
::testing::Values(
|
::testing::ValuesIn(std::vector<TestCase>({PositiveStepTestCase(),
|
||||||
RoundtripTestParams(0, 10, 2, 0), // unused_iterator
|
NegativeStepTestCase()})));
|
||||||
RoundtripTestParams(0, 10, 2, 4), // fully_used_iterator_increase
|
|
||||||
RoundtripTestParams(10, 0, -2, 4), // fully_used_iterator_decrease
|
|
||||||
RoundtripTestParams(0, 10, 2, 6))); // exhausted_iterator
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
@ -80,8 +80,7 @@ TestCase TakeLessTestCase() {
|
|||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
|
||||||
/*expected_output_dtypes*/ {DT_INT64},
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
/*expected_output_shapes*/
|
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||||
{PartialTensorShape({1})},
|
|
||||||
/*expected_cardinality*/ 4,
|
/*expected_cardinality*/ 4,
|
||||||
/*breakpoints*/ {0, 2, 5}};
|
/*breakpoints*/ {0, 2, 5}};
|
||||||
}
|
}
|
||||||
@ -104,8 +103,7 @@ TestCase TakeMoreTestCase() {
|
|||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
||||||
/*expected_output_dtypes*/ {DT_INT64},
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
/*expected_output_shapes*/
|
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||||
{PartialTensorShape({1})},
|
|
||||||
/*expected_cardinality*/ 10,
|
/*expected_cardinality*/ 10,
|
||||||
/*breakpoints*/ {0, 2, 5, 11}};
|
/*breakpoints*/ {0, 2, 5, 11}};
|
||||||
}
|
}
|
||||||
@ -128,8 +126,7 @@ TestCase TakeAllTestCase() {
|
|||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
|
||||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
|
||||||
/*expected_output_dtypes*/ {DT_INT64},
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
/*expected_output_shapes*/
|
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||||
{PartialTensorShape({1})},
|
|
||||||
/*expected_cardinality*/ -1,
|
/*expected_cardinality*/ -1,
|
||||||
/*breakpoints*/ {0, 2, 5, 11}};
|
/*breakpoints*/ {0, 2, 5, 11}};
|
||||||
}
|
}
|
||||||
@ -140,20 +137,18 @@ TestCase TakeNothingTestCase() {
|
|||||||
{DatasetOpsTestBase::CreateTensor<int64>(
|
{DatasetOpsTestBase::CreateTensor<int64>(
|
||||||
TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
|
TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
|
||||||
/*count*/ 0,
|
/*count*/ 0,
|
||||||
/*expected_outputs*/
|
/*expected_outputs*/ {},
|
||||||
{},
|
|
||||||
/*expected_output_dtypes*/ {DT_INT64},
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
/*expected_output_shapes*/
|
/*expected_output_shapes*/ {PartialTensorShape({1})},
|
||||||
{PartialTensorShape({1})},
|
|
||||||
/*expected_cardinality*/ 0,
|
/*expected_cardinality*/ 0,
|
||||||
/*breakpoints*/ {0, 2, 5, 11}};
|
/*breakpoints*/ {0, 2, 5, 11}};
|
||||||
}
|
}
|
||||||
|
|
||||||
class ParametrizedTakeDatasetOpTest
|
class ParameterizedTakeDatasetOpTest
|
||||||
: public TakeDatasetOpTest,
|
: public TakeDatasetOpTest,
|
||||||
public ::testing::WithParamInterface<TestCase> {};
|
public ::testing::WithParamInterface<TestCase> {};
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
|
TEST_P(ParameterizedTakeDatasetOpTest, GetNext) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -204,7 +199,37 @@ TEST_P(ParametrizedTakeDatasetOpTest, GetNext) {
|
|||||||
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TakeDatasetOpTest, DatasetName) {
|
TEST_F(TakeDatasetOpTest, DatasetNodeName) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
const TestCase &test_case = TakeLessTestCase();
|
||||||
|
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||||
|
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||||
|
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||||
|
&tensor_slice_dataset_tensor));
|
||||||
|
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs_for_take_dataset;
|
||||||
|
inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor);
|
||||||
|
inputs_for_take_dataset.emplace_back(&count);
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> take_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&take_dataset_kernel));
|
||||||
|
std::unique_ptr<OpKernelContext> take_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(),
|
||||||
|
&inputs_for_take_dataset,
|
||||||
|
&take_dataset_context));
|
||||||
|
DatasetBase *take_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(),
|
||||||
|
take_dataset_context.get(), &take_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(take_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(take_dataset->node_name(), kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TakeDatasetOpTest, DatasetTypeString) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -234,7 +259,7 @@ TEST_F(TakeDatasetOpTest, DatasetName) {
|
|||||||
EXPECT_EQ(take_dataset->type_string(), kOpName);
|
EXPECT_EQ(take_dataset->type_string(), kOpName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -265,7 +290,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputDtypes) {
|
|||||||
test_case.expected_output_dtypes));
|
test_case.expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
|
TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -296,7 +321,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, DatasetOutputShapes) {
|
|||||||
test_case.expected_output_shapes));
|
test_case.expected_output_shapes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, Cardinality) {
|
TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -361,7 +386,7 @@ TEST_F(TakeDatasetOpTest, DatasetSave) {
|
|||||||
TF_ASSERT_OK(writer.Flush());
|
TF_ASSERT_OK(writer.Flush());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -399,7 +424,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
|||||||
test_case.expected_output_dtypes));
|
test_case.expected_output_dtypes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
|
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputShapes) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -437,7 +462,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputShapes) {
|
|||||||
test_case.expected_output_shapes));
|
test_case.expected_output_shapes));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -478,7 +503,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, IteratorOutputPrefix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
|
TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) {
|
||||||
int thread_num = 2, cpu_num = 2;
|
int thread_num = 2, cpu_num = 2;
|
||||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
@ -550,7 +575,7 @@ TEST_P(ParametrizedTakeDatasetOpTest, Roundtrip) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParametrizedTakeDatasetOpTest,
|
INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParameterizedTakeDatasetOpTest,
|
||||||
::testing::ValuesIn(std::vector<TestCase>(
|
::testing::ValuesIn(std::vector<TestCase>(
|
||||||
{TakeLessTestCase(), TakeMoreTestCase(),
|
{TakeLessTestCase(), TakeMoreTestCase(),
|
||||||
TakeAllTestCase(), TakeNothingTestCase()})));
|
TakeAllTestCase(), TakeNothingTestCase()})));
|
||||||
|
Loading…
Reference in New Issue
Block a user