Refactor RangeDatasetOpTest

This commit is contained in:
Fei Hu 2019-03-15 16:46:03 -07:00
parent c57401b145
commit 8dc85de492

View File

@ -13,237 +13,324 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h" #include "tensorflow/core/kernels/data/dataset_test_base.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
namespace { namespace {
constexpr char kNodeName[] = "range_dataset";
constexpr char kOpName[] = "RangeDataset"; constexpr char kOpName[] = "RangeDataset";
class RangeDatasetOpTest : public DatasetOpsTestBase { class RangeDatasetOpTest : public DatasetOpsTestBase {
protected: protected:
// Creates a new RangeDataset op kernel context. // Creates a new RangeDataset op kernel context.
Status CreateRangeDatasetContext( Status CreateRangeDatasetContext(
int64 start, int64 end, int64 step, OpKernel* const range_kernel, OpKernel* const range_kernel,
gtl::InlinedVector<TensorValue, 4>* const inputs,
std::unique_ptr<OpKernelContext>* range_context) { std::unique_ptr<OpKernelContext>* range_context) {
inputs_.clear(); TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, *inputs));
TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
&inputs_, range_kernel->input_types(), TensorShape({}), {start}));
TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
&inputs_, range_kernel->input_types(), TensorShape({}), {end}));
TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
&inputs_, range_kernel->input_types(), TensorShape({}), {step}));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
CreateOpKernelContext(range_kernel, &inputs_, range_context)); CreateOpKernelContext(range_kernel, inputs, range_context));
TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, inputs_));
return Status::OK(); return Status::OK();
} }
private:
gtl::InlinedVector<TensorValue, 4> inputs_;
}; };
struct GetNextTestParams { struct TestCase {
explicit GetNextTestParams(int64 input_start, int64 input_end,
int64 input_step)
: start(input_start), end(input_end), step(input_step) {}
int64 start; int64 start;
int64 end; int64 end;
int64 step; int64 step;
std::vector<Tensor> expected_outputs;
DataTypeVector expected_output_dtypes;
std::vector<PartialTensorShape> expected_output_shapes;
int64 expected_cardinality;
std::vector<int> breakpoints;
}; };
struct DatasetGetNextTest : RangeDatasetOpTest, TestCase PositiveStepTestCase() {
::testing::WithParamInterface<GetNextTestParams> {}; return {/*start*/ 0,
/*end*/ 10,
/*step*/ 3,
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {3}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {6}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {9})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 4,
/*breakpoints*/ {0, 1, 4}};
}
TEST_P(DatasetGetNextTest, GetNext) { TestCase NegativeStepTestCase() {
return {/*start*/ 10,
/*end*/ 0,
/*step*/ -3,
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {10}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {7}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {4}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 4,
/*breakpoints*/ {0, 1, 4}};
}
TestCase ZeroStepTestCase() {
return {/*start*/ 0,
/*end*/ 10,
/*step*/ 0,
/*expected_outputs*/ {},
/*expected_output_dtypes*/ {},
/*expected_output_shapes*/ {},
/*expected_cardinality*/ 0,
/*breakpoints*/ {}};
}
class ParameterizedRangeDatasetOpTest
: public RangeDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParameterizedRangeDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
GetNextTestParams params = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = GetParam();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
range_kernel.get(), &range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator)); &iterator));
bool end_of_sequence = false; bool end_of_sequence = false;
auto expected_outputs_it = test_case.expected_outputs.begin();
std::vector<Tensor> out_tensors; std::vector<Tensor> out_tensors;
while (!end_of_sequence) { while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence)); &end_of_sequence));
if (!end_of_sequence) {
EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
expected_outputs_it++;
} }
std::vector<int> expected_values;
for (int i = params.start; (params.end - i) * params.step > 0;
i = i + params.step) {
expected_values.reserve(1);
expected_values.emplace_back(i);
}
EXPECT_EQ(out_tensors.size(), expected_values.size());
for (size_t i = 0; i < out_tensors.size(); ++i) {
int64 actual_value = out_tensors[i].flat<int64>()(0);
int64 expect_value = expected_values[i];
EXPECT_EQ(actual_value, expect_value);
} }
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
} }
INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetGetNextTest, TEST_F(RangeDatasetOpTest, ZeroStep) {
::testing::Values(GetNextTestParams(0, 10, 1),
GetNextTestParams(0, 10, 3),
GetNextTestParams(10, 0, -1),
GetNextTestParams(10, 0, -3)));
TEST_F(RangeDatasetOpTest, DatasetName) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = ZeroStepTestCase();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
&range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
EXPECT_EQ(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset)
.code(),
tensorflow::error::INVALID_ARGUMENT);
}
TEST_F(RangeDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TestCase test_case = PositiveStepTestCase();
gtl::InlinedVector<TensorValue, 4> inputs;
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
EXPECT_EQ(range_dataset->node_name(), kNodeName);
}
TEST_F(RangeDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TestCase test_case = PositiveStepTestCase();
gtl::InlinedVector<TensorValue, 4> inputs;
Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(
CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
EXPECT_EQ(range_dataset->type_string(), kOpName); EXPECT_EQ(range_dataset->type_string(), kOpName);
} }
TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) { TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = PositiveStepTestCase();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
&range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
DataTypeVector expected_dtypes({DT_INT64}); TF_EXPECT_OK(VerifyTypesMatch(range_dataset->output_dtypes(),
EXPECT_EQ(range_dataset->output_dtypes(), expected_dtypes); test_case.expected_output_dtypes));
} }
TEST_F(RangeDatasetOpTest, DatasetOutputShapes) { TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = PositiveStepTestCase();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
&range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})}); TF_EXPECT_OK(VerifyShapesCompatible(range_dataset->output_shapes(),
EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size()); test_case.expected_output_shapes));
for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
EXPECT_TRUE(
range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
}
} }
struct CardinalityTestParams { TEST_P(ParameterizedRangeDatasetOpTest, Cardinality) {
explicit CardinalityTestParams(int64 input_start, int64 input_end,
int64 input_step,
int input_expected_cardinality)
: start(input_start),
end(input_end),
step(input_step),
expected_cardinality(input_expected_cardinality) {}
int64 start;
int64 end;
int64 step;
int expected_cardinality;
};
struct DatasetCardinalityTest
: RangeDatasetOpTest,
::testing::WithParamInterface<CardinalityTestParams> {};
TEST_P(DatasetCardinalityTest, Cardinality) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
CardinalityTestParams params = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = GetParam();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
range_kernel.get(), &range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
EXPECT_EQ(range_dataset->Cardinality(), params.expected_cardinality); EXPECT_EQ(range_dataset->Cardinality(), test_case.expected_cardinality);
} }
INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetCardinalityTest,
::testing::Values(CardinalityTestParams(0, 10, 1, 10),
CardinalityTestParams(0, 10, 3, 4),
CardinalityTestParams(10, 0, -3, 4)));
TEST_F(RangeDatasetOpTest, DatasetSave) { TEST_F(RangeDatasetOpTest, DatasetSave) {
int64 thread_num = 2, cpu_num = 2; int64 thread_num = 2, cpu_num = 2;
int start = 0, end = 10, step = 1;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = PositiveStepTestCase();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
&range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<SerializationContext> serialization_context; std::unique_ptr<SerializationContext> serialization_context;
@ -256,81 +343,105 @@ TEST_F(RangeDatasetOpTest, DatasetSave) {
} }
TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) { TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = PositiveStepTestCase();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
&range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator)); &iterator));
DataTypeVector expected_dtypes({DT_INT64}); TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
EXPECT_EQ(iterator->output_dtypes(), expected_dtypes); test_case.expected_output_dtypes));
} }
TEST_F(RangeDatasetOpTest, IteratorOutputShapes) { TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = PositiveStepTestCase();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
&range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator)); &iterator));
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})}); TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); test_case.expected_output_shapes));
for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
}
} }
TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) { TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
int64 start = 0, end = 10, step = 1;
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = PositiveStepTestCase();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
&range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator)); &iterator));
@ -338,83 +449,77 @@ TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
EXPECT_EQ(iterator->prefix(), "Iterator::Range"); EXPECT_EQ(iterator->prefix(), "Iterator::Range");
} }
struct RoundtripTestParams { TEST_P(ParameterizedRangeDatasetOpTest, Roundtrip) {
explicit RoundtripTestParams(int64 input_start, int64 input_end,
int64 input_step, int input_breakpoint)
: start(input_start),
end(input_end),
step(input_step),
breakpoint(input_breakpoint) {}
int64 start;
int64 end;
int64 step;
int breakpoint;
};
struct IteratorRoundtripTest
: RangeDatasetOpTest,
::testing::WithParamInterface<RoundtripTestParams> {};
TEST_P(IteratorRoundtripTest, Roundtrip) {
int thread_num = 2, cpu_num = 2; int thread_num = 2, cpu_num = 2;
RoundtripTestParams params = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> range_kernel; TestCase test_case = GetParam();
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel)); gtl::InlinedVector<TensorValue, 4> inputs;
std::unique_ptr<OpKernelContext> range_context; Tensor start = CreateTensor<int64>(TensorShape({}), {test_case.start});
TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, Tensor end = CreateTensor<int64>(TensorShape({}), {test_case.end});
range_kernel.get(), &range_context)); Tensor step = CreateTensor<int64>(TensorShape({}), {test_case.step});
DatasetBase* range_dataset; inputs.emplace_back(&start);
inputs.emplace_back(&end);
inputs.emplace_back(&step);
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK( TF_ASSERT_OK(
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); CreateRangeDatasetOpKernel<int64>(kNodeName, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset); core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context; std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator; std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator)); &iterator));
std::vector<Tensor> out_tensors; std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false; bool end_of_sequence = false;
int64 cur_val = params.start - params.step; std::vector<Tensor> out_tensors;
for (int i = 0; i < params.breakpoint; i++) { int cur_iteration = 0;
if (!end_of_sequence) { auto expected_outputs_it = test_case.expected_outputs.begin();
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, const std::vector<int>& breakpoints = test_case.breakpoints;
&end_of_sequence)); for (int breakpoint : breakpoints) {
cur_val = ((params.end - cur_val - params.step) * params.step > 0)
? cur_val + params.step
: cur_val;
}
}
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data; VariantTensorData data;
VariantTensorDataWriter writer(&data); VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush()); TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data); VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader));
while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence)); &end_of_sequence));
int64 expect_next = ((params.end - cur_val - params.step) * params.step > 0) if (!end_of_sequence) {
? cur_val + params.step EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end());
: cur_val; TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it));
EXPECT_EQ(out_tensors.back().flat<int64>()(0), expect_next); expected_outputs_it++;
}
cur_iteration++;
}
if (breakpoint >= test_case.expected_cardinality) {
EXPECT_TRUE(end_of_sequence);
EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
} else {
EXPECT_FALSE(end_of_sequence);
}
}
} }
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_SUITE_P(
RangeDatasetOpTest, IteratorRoundtripTest, RangeDatasetOpTest, ParameterizedRangeDatasetOpTest,
::testing::Values( ::testing::ValuesIn(std::vector<TestCase>({PositiveStepTestCase(),
RoundtripTestParams(0, 10, 2, 0), // unused_iterator NegativeStepTestCase()})));
RoundtripTestParams(0, 10, 2, 4), // fully_used_iterator_increase
RoundtripTestParams(10, 0, -2, 4), // fully_used_iterator_decrease
RoundtripTestParams(0, 10, 2, 6))); // exhausted_iterator
} // namespace } // namespace
} // namespace data } // namespace data