Merge pull request #30886 from feihugis:Refactor_AssertNextDatasetOp

PiperOrigin-RevId: 259145082
This commit is contained in:
TensorFlower Gardener 2019-07-20 14:11:58 -07:00
commit 488b385a7d
8 changed files with 900 additions and 152 deletions

View File

@ -30,6 +30,7 @@ cc_library(
":iterator_ops", ":iterator_ops",
":name_utils", ":name_utils",
":range_dataset_op", ":range_dataset_op",
":take_dataset_op",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",

View File

@ -274,6 +274,46 @@ Status DatasetOpsTestBase::CreateTensorSliceDataset(
return Status::OK(); return Status::OK();
} }
// Create a `RangeDataset` dataset as a variant tensor.
Status DatasetOpsTestBase::MakeRangeDataset(
const Tensor& start, const Tensor& stop, const Tensor& step,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
Tensor* range_dataset) {
GraphConstructorOptions graph_opts;
graph_opts.allow_internal_ops = true;
graph_opts.expect_device_spec = false;
TF_RETURN_IF_ERROR(
RunFunction(test::function::MakeRangeDataset(),
/*attrs*/
{{RangeDatasetOp::kOutputTypes, output_types},
{RangeDatasetOp::kOutputShapes, output_shapes}},
/*inputs*/ {start, stop, step}, graph_opts,
/*rets*/ {range_dataset}));
return Status::OK();
}
// Create a `TakeDataset` dataset as a variant tensor.
Status DatasetOpsTestBase::MakeTakeDataset(
const Tensor& input_dataset, int64 count,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
Tensor* take_dataset) {
GraphConstructorOptions graph_opts;
graph_opts.allow_internal_ops = true;
graph_opts.expect_device_spec = false;
Tensor count_tensor = CreateTensor<int64>(TensorShape({}), {count});
TF_RETURN_IF_ERROR(
RunFunction(test::function::MakeTakeDataset(),
/*attrs*/
{{TakeDatasetOp::kOutputTypes, output_types},
{TakeDatasetOp::kOutputShapes, output_shapes}},
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
/*rets*/ {take_dataset}));
return Status::OK();
}
Status DatasetOpsTestBase::CreateOpKernel( Status DatasetOpsTestBase::CreateOpKernel(
const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) { const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
OpKernel* kernel; OpKernel* kernel;

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/iterator_ops.h" #include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/data/name_utils.h" #include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/range_dataset_op.h" #include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "tensorflow/core/kernels/data/take_dataset_op.h"
#include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
@ -177,6 +178,19 @@ class DatasetOpsTestBase : public ::testing::Test {
std::vector<Tensor>* const components, std::vector<Tensor>* const components,
DatasetBase** tensor_slice_dataset); DatasetBase** tensor_slice_dataset);
// Creates a `RangeDataset` dataset as a variant tensor.
Status MakeRangeDataset(const Tensor& start, const Tensor& stop,
const Tensor& step,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
Tensor* range_dataset);
// Creates a `TakeDataset` dataset as a variant tensor.
Status MakeTakeDataset(const Tensor& input_dataset, int64 count,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
Tensor* take_dataset);
// Fetches the dataset from the operation context. // Fetches the dataset from the operation context.
Status GetDatasetFromContext(OpKernelContext* context, int output_index, Status GetDatasetFromContext(OpKernelContext* context, int output_index,
DatasetBase** const dataset); DatasetBase** const dataset);

View File

@ -3,6 +3,7 @@
load( load(
"//tensorflow:tensorflow.bzl", "//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_kernel_library", "tf_kernel_library",
) )
@ -16,9 +17,27 @@ exports_files(["LICENSE"])
tf_kernel_library( tf_kernel_library(
name = "assert_next_dataset_op", name = "assert_next_dataset_op",
srcs = ["assert_next_dataset_op.cc"], srcs = ["assert_next_dataset_op.cc"],
hdrs = ["assert_next_dataset_op.h"],
deps = [ deps = [
"//tensorflow/core:experimental_dataset_ops_op_lib", "//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core/kernels/data:name_utils",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "assert_next_dataset_op_test",
size = "small",
srcs = ["assert_next_dataset_op_test.cc"],
deps = [
":assert_next_dataset_op",
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels/data:dataset_test_base",
"//third_party/eigen3", "//third_party/eigen3",
], ],
) )

View File

@ -12,38 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h"
#include <map> #include <map>
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level /* static */ constexpr const char* const AssertNextDatasetOp::kInputDataset;
// description of the following op. /* static */ constexpr const char* const AssertNextDatasetOp::kDatasetType;
class AssertNextDatasetOp : public UnaryDatasetOpKernel { /* static */ constexpr const char* const AssertNextDatasetOp::kTransformations;
public: /* static */ constexpr const char* const AssertNextDatasetOp::kOutputTypes;
explicit AssertNextDatasetOp(OpKernelConstruction* ctx) /* static */ constexpr const char* const AssertNextDatasetOp::kOutputShapes;
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
protected: class AssertNextDatasetOp::Dataset : public DatasetBase {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
std::vector<string> transformations;
OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
&transformations));
*output =
new Dataset(ctx, input, transformations, output_types_, output_shapes_);
}
private:
class Dataset : public DatasetBase {
public: public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& transformations, const std::vector<string>& transformations,
@ -61,19 +48,17 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal( std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override { const string& prefix) const override {
return absl::make_unique<Iterator>( return absl::make_unique<Iterator>(Iterator::Params{
Iterator::Params{this, strings::StrCat(prefix, "::AssertNext")}); this, name_utils::IteratorPrefix(kDatasetType, prefix)});
} }
const DataTypeVector& output_dtypes() const override { const DataTypeVector& output_dtypes() const override { return output_types_; }
return output_types_;
}
const std::vector<PartialTensorShape>& output_shapes() const override { const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_; return output_shapes_;
} }
string DebugString() const override { string DebugString() const override {
return "AssertNextDatasetOp::Dataset"; return name_utils::DatasetDebugString(kDatasetType);
} }
int64 Cardinality() const override { return input_->Cardinality(); } int64 Cardinality() const override { return input_->Cardinality(); }
@ -86,8 +71,8 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* transformations_node = nullptr; Node* transformations_node = nullptr;
TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
TF_RETURN_IF_ERROR(b->AddDataset( TF_RETURN_IF_ERROR(
this, {input_graph_node, transformations_node}, output)); b->AddDataset(this, {input_graph_node, transformations_node}, output));
return Status::OK(); return Status::OK();
} }
@ -149,12 +134,24 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
const std::vector<string> transformations_; const std::vector<string> transformations_;
const DataTypeVector output_types_; const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_; const std::vector<PartialTensorShape> output_shapes_;
};
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
}; };
AssertNextDatasetOp::AssertNextDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
void AssertNextDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) {
std::vector<string> transformations;
OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, kTransformations,
&transformations));
*output =
new Dataset(ctx, input, transformations, output_types_, output_shapes_);
}
namespace {
REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
AssertNextDatasetOp); AssertNextDatasetOp);
REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER(

View File

@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level
// description of the following op.
class AssertNextDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "AssertNext";
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kTransformations = "transformations";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
explicit AssertNextDatasetOp(OpKernelConstruction* ctx);
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override;
private:
class Dataset;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_

View File

@ -0,0 +1,667 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNodeName[] = "assert_next_dataset";
struct RangeDatasetParams {
int start;
int stop;
int step;
};
struct TakeDatasetParams {
int count;
};
class AssertNextDatasetOpTest : public DatasetOpsTestBase {
protected:
// Creates a new `AssertNextDataset` op kernel.
Status CreateAssertNextDatasetOpKernel(
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
std::unique_ptr<OpKernel>* assert_next_dataset_op_kernel) {
NodeDef node_def = test::function::NDef(
kNodeName, name_utils::OpName(AssertNextDatasetOp::kDatasetType),
{AssertNextDatasetOp::kInputDataset,
AssertNextDatasetOp::kTransformations},
{{AssertNextDatasetOp::kOutputTypes, output_types},
{AssertNextDatasetOp::kOutputShapes, output_shapes}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, assert_next_dataset_op_kernel));
return Status::OK();
}
// Creates a new `AssertNextDataset` op kernel context.
Status CreateAssertNextDatasetContext(
OpKernel* const op_kernel,
gtl::InlinedVector<TensorValue, 4>* const inputs,
std::unique_ptr<OpKernelContext>* context) {
TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK();
}
// Creates a new `RangeAndTakeDataset` tensor.
Status MakeRangeAndTakeDatasetTensor(
const RangeDatasetParams& range_dataset_params,
const TakeDatasetParams& take_dataset_params,
Tensor* range_and_take_dataset_tensor) {
Tensor range_dataset_tensor;
Tensor start =
CreateTensor<int64>(TensorShape({}), {range_dataset_params.start});
Tensor stop =
CreateTensor<int64>(TensorShape({}), {range_dataset_params.stop});
Tensor step =
CreateTensor<int64>(TensorShape({}), {range_dataset_params.step});
TF_RETURN_IF_ERROR(MakeRangeDataset(start, stop, step, {DT_INT64},
{PartialTensorShape({})},
&range_dataset_tensor));
TF_RETURN_IF_ERROR(MakeTakeDataset(
range_dataset_tensor, take_dataset_params.count, {DT_INT64},
{PartialTensorShape({})}, range_and_take_dataset_tensor));
return Status::OK();
}
};
struct TestCase {
RangeDatasetParams range_dataset_params;
TakeDatasetParams take_dataset_params;
Tensor transformations;
std::vector<Tensor> expected_outputs;
DataTypeVector expected_output_dtypes;
std::vector<PartialTensorShape> expected_output_shapes;
int64 expected_cardinality;
std::vector<int> breakpoints;
};
// Test case 1 : assert one transformation.
TestCase TestCase1() {
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
/*take_dataset_params*/ {/*count*/ 3},
/*transformations*/
DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1}), {TakeDatasetOp::kDatasetType}),
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 3,
/*breakpoints*/ {0, 2, 5}};
}
// Test case 2 : assert two transformations.
TestCase TestCase2() {
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
/*take_dataset_params*/ {/*count*/ 3},
/*transformations*/
DatasetOpsTestBase::CreateTensor<string>(
TensorShape({2}),
{TakeDatasetOp::kDatasetType, RangeDatasetOp::kDatasetType}),
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 3,
/*breakpoints*/ {0, 2, 5}};
}
TestCase AssertNextInvalid() {
return {
/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
/*take_dataset_params*/ {/*count*/ 3},
/*transformations*/
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"Whoops"}),
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 3,
/*breakpoints*/ {0, 2, 5}};
}
TestCase AssertNextShort() {
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
/*take_dataset_params*/ {/*count*/ 3},
/*transformations*/
DatasetOpsTestBase::CreateTensor<string>(
TensorShape({3}), {TakeDatasetOp::kDatasetType,
RangeDatasetOp::kDatasetType, "Whoops"}),
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
/*expected_output_dtypes*/ {DT_INT64},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ 3,
/*breakpoints*/ {0, 2, 5}};
}
class ParameterizedAssertNextDatasetOpTest
: public AssertNextDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParameterizedAssertNextDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
string iterator_prefix = name_utils::IteratorPrefix(
TakeDatasetOp::kDatasetType,
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
iterator_prefix, &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
while (!end_of_sequence) {
std::vector<Tensor> next;
TF_EXPECT_OK(
iterator->GetNext(iterator_context.get(), &next, &end_of_sequence));
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*compare_order*/ true));
}
TEST_F(AssertNextDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = TestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
EXPECT_EQ(assert_next_dataset->node_name(), kNodeName);
}
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
EXPECT_EQ(assert_next_dataset->type_string(),
name_utils::OpName(AssertNextDatasetOp::kDatasetType));
}
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
TF_EXPECT_OK(VerifyTypesMatch(assert_next_dataset->output_dtypes(),
test_case.expected_output_dtypes));
}
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
TF_EXPECT_OK(VerifyShapesCompatible(assert_next_dataset->output_shapes(),
test_case.expected_output_shapes));
}
TEST_P(ParameterizedAssertNextDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
EXPECT_EQ(assert_next_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(assert_next_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
string iterator_prefix = name_utils::IteratorPrefix(
TakeDatasetOp::kDatasetType,
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
iterator_prefix, &iterator));
TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
test_case.expected_output_dtypes));
}
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
string iterator_prefix = name_utils::IteratorPrefix(
TakeDatasetOp::kDatasetType,
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
iterator_prefix, &iterator));
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
test_case.expected_output_shapes));
}
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
string iterator_prefix = name_utils::IteratorPrefix(
TakeDatasetOp::kDatasetType,
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
iterator_prefix, &iterator));
EXPECT_EQ(iterator->prefix(),
name_utils::IteratorPrefix(AssertNextDatasetOp::kDatasetType,
iterator_prefix));
}
TEST_P(ParameterizedAssertNextDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
string iterator_prefix = name_utils::IteratorPrefix(
TakeDatasetOp::kDatasetType,
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
iterator_prefix, &iterator));
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_iteration = 0;
const std::vector<int>& breakpoints = test_case.breakpoints;
for (int breakpoint : breakpoints) {
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader,
iterator_prefix, *assert_next_dataset,
&iterator));
while (cur_iteration <= breakpoint) {
std::vector<Tensor> next;
TF_EXPECT_OK(
iterator->GetNext(iterator_context.get(), &next, &end_of_sequence));
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
++cur_iteration;
}
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*compare_order*/ true));
}
INSTANTIATE_TEST_SUITE_P(
AssertNextDatasetOpTest, ParameterizedAssertNextDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>({TestCase1(), TestCase2()})));
TEST_F(AssertNextDatasetOpTest, InvalidArguments) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::vector<TestCase> test_cases = {AssertNextInvalid(), AssertNextShort()};
for (TestCase test_case : test_cases) {
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(
test_case.expected_output_dtypes, test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(
CreateAssertNextDatasetContext(assert_next_dataset_kernel.get(),
&inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
&iterator_context));
std::unique_ptr<IteratorBase> iterator;
string iterator_prefix = name_utils::IteratorPrefix(
TakeDatasetOp::kDatasetType,
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
EXPECT_EQ(
assert_next_dataset
->MakeIterator(iterator_context.get(), iterator_prefix, &iterator)
.code(),
tensorflow::error::INVALID_ARGUMENT);
}
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -50,45 +50,6 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK(); return Status::OK();
} }
// Create a `RangeDataset` dataset as a variant tensor.
Status MakeRangeDataset(const Tensor& start, const Tensor& stop,
const Tensor& step,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
Tensor* range_dataset) {
GraphConstructorOptions graph_opts;
graph_opts.allow_internal_ops = true;
graph_opts.expect_device_spec = false;
TF_RETURN_IF_ERROR(
RunFunction(test::function::MakeRangeDataset(),
/*attrs*/
{{RangeDatasetOp::kOutputTypes, output_types},
{RangeDatasetOp::kOutputShapes, output_shapes}},
/*inputs*/ {start, stop, step}, graph_opts,
/*rets*/ {range_dataset}));
return Status::OK();
}
// Create a `TakeDataset` dataset as a variant tensor.
Status MakeTakeDataset(const Tensor& input_dataset, int64 count,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
Tensor* take_dataset) {
GraphConstructorOptions graph_opts;
graph_opts.allow_internal_ops = true;
graph_opts.expect_device_spec = false;
Tensor count_tensor = CreateTensor<int64>(TensorShape({}), {count});
TF_RETURN_IF_ERROR(
RunFunction(test::function::MakeTakeDataset(),
/*attrs*/
{{TakeDatasetOp::kOutputTypes, output_types},
{TakeDatasetOp::kOutputShapes, output_shapes}},
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
/*rets*/ {take_dataset}));
return Status::OK();
}
}; };
TEST_F(OptimizeDatasetOpTest, NoopElimination) { TEST_F(OptimizeDatasetOpTest, NoopElimination) {