diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 83252bfcbd8..a5f41b6dcae 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -30,6 +30,7 @@ cc_library( ":iterator_ops", ":name_utils", ":range_dataset_op", + ":take_dataset_op", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 8c9d775444f..2a5f03edf16 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -274,6 +274,46 @@ Status DatasetOpsTestBase::CreateTensorSliceDataset( return Status::OK(); } +// Create a `RangeDataset` dataset as a variant tensor. +Status DatasetOpsTestBase::MakeRangeDataset( + const Tensor& start, const Tensor& stop, const Tensor& step, + const DataTypeVector& output_types, + const std::vector& output_shapes, + Tensor* range_dataset) { + GraphConstructorOptions graph_opts; + graph_opts.allow_internal_ops = true; + graph_opts.expect_device_spec = false; + TF_RETURN_IF_ERROR( + RunFunction(test::function::MakeRangeDataset(), + /*attrs*/ + {{RangeDatasetOp::kOutputTypes, output_types}, + {RangeDatasetOp::kOutputShapes, output_shapes}}, + /*inputs*/ {start, stop, step}, graph_opts, + /*rets*/ {range_dataset})); + return Status::OK(); +} + +// Create a `TakeDataset` dataset as a variant tensor. +Status DatasetOpsTestBase::MakeTakeDataset( + const Tensor& input_dataset, int64 count, + const DataTypeVector& output_types, + const std::vector& output_shapes, + Tensor* take_dataset) { + GraphConstructorOptions graph_opts; + graph_opts.allow_internal_ops = true; + graph_opts.expect_device_spec = false; + + Tensor count_tensor = CreateTensor(TensorShape({}), {count}); + TF_RETURN_IF_ERROR( + RunFunction(test::function::MakeTakeDataset(), + /*attrs*/ + {{TakeDatasetOp::kOutputTypes, output_types}, + {TakeDatasetOp::kOutputShapes, output_shapes}}, + /*inputs*/ {input_dataset, count_tensor}, graph_opts, + /*rets*/ {take_dataset})); + return Status::OK(); +} + Status DatasetOpsTestBase::CreateOpKernel( const NodeDef& node_def, std::unique_ptr* op_kernel) { OpKernel* kernel; diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h index 75a221e2782..427cccac9f9 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.h +++ b/tensorflow/core/kernels/data/dataset_test_base.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/iterator_ops.h" #include "tensorflow/core/kernels/data/name_utils.h" #include "tensorflow/core/kernels/data/range_dataset_op.h" +#include "tensorflow/core/kernels/data/take_dataset_op.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" @@ -177,6 +178,19 @@ class DatasetOpsTestBase : public ::testing::Test { std::vector* const components, DatasetBase** tensor_slice_dataset); + // Creates a `RangeDataset` dataset as a variant tensor. + Status MakeRangeDataset(const Tensor& start, const Tensor& stop, + const Tensor& step, + const DataTypeVector& output_types, + const std::vector& output_shapes, + Tensor* range_dataset); + + // Creates a `TakeDataset` dataset as a variant tensor. + Status MakeTakeDataset(const Tensor& input_dataset, int64 count, + const DataTypeVector& output_types, + const std::vector& output_shapes, + Tensor* take_dataset); + // Fetches the dataset from the operation context. Status GetDatasetFromContext(OpKernelContext* context, int output_index, DatasetBase** const dataset); diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index d16f580d1c5..2ff370e92a6 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -3,6 +3,7 @@ load( "//tensorflow:tensorflow.bzl", + "tf_cc_test", "tf_kernel_library", ) @@ -16,9 +17,27 @@ exports_files(["LICENSE"]) tf_kernel_library( name = "assert_next_dataset_op", srcs = ["assert_next_dataset_op.cc"], + hdrs = ["assert_next_dataset_op.h"], deps = [ "//tensorflow/core:experimental_dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core/kernels/data:name_utils", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "assert_next_dataset_op_test", + size = "small", + srcs = ["assert_next_dataset_op_test.cc"], + deps = [ + ":assert_next_dataset_op", + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels/data:dataset_test_base", "//third_party/eigen3", ], ) diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc index b84d813c023..592d8db8281 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -12,149 +12,146 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h" + #include #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/name_utils.h" namespace tensorflow { namespace data { -namespace { -// See documentation in ../ops/dataset_ops.cc for a high-level -// description of the following op. -class AssertNextDatasetOp : public UnaryDatasetOpKernel { +/* static */ constexpr const char* const AssertNextDatasetOp::kInputDataset; +/* static */ constexpr const char* const AssertNextDatasetOp::kDatasetType; +/* static */ constexpr const char* const AssertNextDatasetOp::kTransformations; +/* static */ constexpr const char* const AssertNextDatasetOp::kOutputTypes; +/* static */ constexpr const char* const AssertNextDatasetOp::kOutputShapes; + +class AssertNextDatasetOp::Dataset : public DatasetBase { public: - explicit AssertNextDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const std::vector& transformations, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + transformations_(transformations), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); } + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique(Iterator::Params{ + this, name_utils::IteratorPrefix(kDatasetType, prefix)}); + } + + const DataTypeVector& output_dtypes() const override { return output_types_; } + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return name_utils::DatasetDebugString(kDatasetType); + } + + int64 Cardinality() const override { return input_->Cardinality(); } + protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) override { - std::vector transformations; - OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "transformations", - &transformations)); - *output = - new Dataset(ctx, input, transformations, output_types_, output_shapes_); + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* transformations_node = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node, transformations_node}, output)); + return Status::OK(); } private: - class Dataset : public DatasetBase { + class Iterator : public DatasetIterator { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const std::vector& transformations, - const DataTypeVector& output_types, - const std::vector& output_shapes) - : DatasetBase(DatasetContext(ctx)), - input_(input), - transformations_(transformations), - output_types_(output_types), - output_shapes_(output_shapes) { - input_->Ref(); + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + std::vector tokens = + absl::StrSplit(prefix(), ':', absl::SkipEmpty()); + if (dataset()->transformations_.size() > tokens.size() - 2) { + return errors::InvalidArgument( + "Asserted next ", dataset()->transformations_.size(), + " transformations but encountered only ", tokens.size() - 2, "."); + } + int n = tokens.size(); + for (size_t i = 0; i < dataset()->transformations_.size(); ++i) { + if (dataset()->transformations_[i] != tokens[n - 2 - i]) { + return errors::InvalidArgument( + "Asserted ", dataset()->transformations_[i], + " transformation at offset ", i, " but encountered ", + tokens[n - 2 - i], " transformation instead."); + } + } + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } - ~Dataset() override { input_->Unref(); } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::AssertNext")}); + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } - const DataTypeVector& output_dtypes() const override { - return output_types_; - } - const std::vector& output_shapes() const override { - return output_shapes_; - } - - string DebugString() const override { - return "AssertNextDatasetOp::Dataset"; - } - - int64 Cardinality() const override { return input_->Cardinality(); } - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - Node* transformations_node = nullptr; - TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); - TF_RETURN_IF_ERROR(b->AddDataset( - this, {input_graph_node, transformations_node}, output)); + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return Status::OK(); } private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params) {} - - Status Initialize(IteratorContext* ctx) override { - std::vector tokens = - absl::StrSplit(prefix(), ':', absl::SkipEmpty()); - if (dataset()->transformations_.size() > tokens.size() - 2) { - return errors::InvalidArgument( - "Asserted next ", dataset()->transformations_.size(), - " transformations but encountered only ", tokens.size() - 2, "."); - } - int n = tokens.size(); - for (size_t i = 0; i < dataset()->transformations_.size(); ++i) { - if (dataset()->transformations_[i] != tokens[n - 2 - i]) { - return errors::InvalidArgument( - "Asserted ", dataset()->transformations_[i], - " transformation at offset ", i, " but encountered ", - tokens[n - 2 - i], " transformation instead."); - } - } - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return Status::OK(); - } - - private: - std::unique_ptr input_impl_; - }; - - const DatasetBase* input_; - const std::vector transformations_; - const DataTypeVector output_types_; - const std::vector output_shapes_; + std::unique_ptr input_impl_; }; - DataTypeVector output_types_; - std::vector output_shapes_; + const DatasetBase* input_; + const std::vector transformations_; + const DataTypeVector output_types_; + const std::vector output_shapes_; }; +AssertNextDatasetOp::AssertNextDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); +} + +void AssertNextDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) { + std::vector transformations; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, kTransformations, + &transformations)); + *output = + new Dataset(ctx, input, transformations, output_types_, output_shapes_); +} + +namespace { REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU), AssertNextDatasetOp); REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h new file mode 100644 index 00000000000..aae2e80323e --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level +// description of the following op. + +class AssertNextDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "AssertNext"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kTransformations = "transformations"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit AssertNextDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc new file mode 100644 index 00000000000..e256d5ba008 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc @@ -0,0 +1,667 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h" + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "assert_next_dataset"; + +struct RangeDatasetParams { + int start; + int stop; + int step; +}; + +struct TakeDatasetParams { + int count; +}; + +class AssertNextDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates a new `AssertNextDataset` op kernel. + Status CreateAssertNextDatasetOpKernel( + const DataTypeVector& output_types, + const std::vector& output_shapes, + std::unique_ptr* assert_next_dataset_op_kernel) { + NodeDef node_def = test::function::NDef( + kNodeName, name_utils::OpName(AssertNextDatasetOp::kDatasetType), + {AssertNextDatasetOp::kInputDataset, + AssertNextDatasetOp::kTransformations}, + {{AssertNextDatasetOp::kOutputTypes, output_types}, + {AssertNextDatasetOp::kOutputShapes, output_shapes}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, assert_next_dataset_op_kernel)); + return Status::OK(); + } + + // Creates a new `AssertNextDataset` op kernel context. + Status CreateAssertNextDatasetContext( + OpKernel* const op_kernel, + gtl::InlinedVector* const inputs, + std::unique_ptr* context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } + + // Creates a new `RangeAndTakeDataset` tensor. + Status MakeRangeAndTakeDatasetTensor( + const RangeDatasetParams& range_dataset_params, + const TakeDatasetParams& take_dataset_params, + Tensor* range_and_take_dataset_tensor) { + Tensor range_dataset_tensor; + Tensor start = + CreateTensor(TensorShape({}), {range_dataset_params.start}); + Tensor stop = + CreateTensor(TensorShape({}), {range_dataset_params.stop}); + Tensor step = + CreateTensor(TensorShape({}), {range_dataset_params.step}); + TF_RETURN_IF_ERROR(MakeRangeDataset(start, stop, step, {DT_INT64}, + {PartialTensorShape({})}, + &range_dataset_tensor)); + + TF_RETURN_IF_ERROR(MakeTakeDataset( + range_dataset_tensor, take_dataset_params.count, {DT_INT64}, + {PartialTensorShape({})}, range_and_take_dataset_tensor)); + return Status::OK(); + } +}; + +struct TestCase { + RangeDatasetParams range_dataset_params; + TakeDatasetParams take_dataset_params; + Tensor transformations; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +// Test case 1 : assert one transformation. +TestCase TestCase1() { + return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1}, + /*take_dataset_params*/ {/*count*/ 3}, + /*transformations*/ + DatasetOpsTestBase::CreateTensor( + TensorShape({1}), {TakeDatasetOp::kDatasetType}), + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 2, 5}}; +} + +// Test case 2 : assert two transformations. +TestCase TestCase2() { + return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1}, + /*take_dataset_params*/ {/*count*/ 3}, + /*transformations*/ + DatasetOpsTestBase::CreateTensor( + TensorShape({2}), + {TakeDatasetOp::kDatasetType, RangeDatasetOp::kDatasetType}), + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase AssertNextInvalid() { + return { + /*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1}, + /*take_dataset_params*/ {/*count*/ 3}, + /*transformations*/ + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"Whoops"}), + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase AssertNextShort() { + return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1}, + /*take_dataset_params*/ {/*count*/ 3}, + /*transformations*/ + DatasetOpsTestBase::CreateTensor( + TensorShape({3}), {TakeDatasetOp::kDatasetType, + RangeDatasetOp::kDatasetType, "Whoops"}), + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 2, 5}}; +} + +class ParameterizedAssertNextDatasetOpTest + : public AssertNextDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedAssertNextDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + string iterator_prefix = name_utils::IteratorPrefix( + TakeDatasetOp::kDatasetType, + name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator")); + TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(), + iterator_prefix, &iterator)); + + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_context.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +TEST_F(AssertNextDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + EXPECT_EQ(assert_next_dataset->node_name(), kNodeName); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + EXPECT_EQ(assert_next_dataset->type_string(), + name_utils::OpName(AssertNextDatasetOp::kDatasetType)); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(assert_next_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(assert_next_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + EXPECT_EQ(assert_next_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + std::unique_ptr serialization_context; + TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(assert_next_dataset->Save(serialization_context.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + string iterator_prefix = name_utils::IteratorPrefix( + TakeDatasetOp::kDatasetType, + name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator")); + TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(), + iterator_prefix, &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + string iterator_prefix = name_utils::IteratorPrefix( + TakeDatasetOp::kDatasetType, + name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator")); + TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(), + iterator_prefix, &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + string iterator_prefix = name_utils::IteratorPrefix( + TakeDatasetOp::kDatasetType, + name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator")); + TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(), + iterator_prefix, &iterator)); + + EXPECT_EQ(iterator->prefix(), + name_utils::IteratorPrefix(AssertNextDatasetOp::kDatasetType, + iterator_prefix)); +} + +TEST_P(ParameterizedAssertNextDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK(CreateAssertNextDatasetContext( + assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + string iterator_prefix = name_utils::IteratorPrefix( + TakeDatasetOp::kDatasetType, + name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator")); + TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(), + iterator_prefix, &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader, + iterator_prefix, *assert_next_dataset, + &iterator)); + + while (cur_iteration <= breakpoint) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_context.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + ++cur_iteration; + } + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +INSTANTIATE_TEST_SUITE_P( + AssertNextDatasetOpTest, ParameterizedAssertNextDatasetOpTest, + ::testing::ValuesIn(std::vector({TestCase1(), TestCase2()}))); + +TEST_F(AssertNextDatasetOpTest, InvalidArguments) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::vector test_cases = {AssertNextInvalid(), AssertNextShort()}; + for (TestCase test_case : test_cases) { + Tensor range_and_take_dataset_tensor; + TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params, + test_case.take_dataset_params, + &range_and_take_dataset_tensor)); + + std::unique_ptr assert_next_dataset_kernel; + TF_ASSERT_OK(CreateAssertNextDatasetOpKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &assert_next_dataset_kernel)); + Tensor transformations = test_case.transformations; + gtl::InlinedVector inputs( + {TensorValue(&range_and_take_dataset_tensor), + TensorValue(&transformations)}); + std::unique_ptr assert_next_dataset_context; + TF_ASSERT_OK( + CreateAssertNextDatasetContext(assert_next_dataset_kernel.get(), + &inputs, &assert_next_dataset_context)); + + DatasetBase* assert_next_dataset; + TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(), + assert_next_dataset_context.get(), + &assert_next_dataset)); + core::ScopedUnref scoped_unref(assert_next_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + string iterator_prefix = name_utils::IteratorPrefix( + TakeDatasetOp::kDatasetType, + name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator")); + EXPECT_EQ( + assert_next_dataset + ->MakeIterator(iterator_context.get(), iterator_prefix, &iterator) + .code(), + tensorflow::error::INVALID_ARGUMENT); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optimize_dataset_op_test.cc b/tensorflow/core/kernels/data/optimize_dataset_op_test.cc index 94dda91dbef..4469c6eebf7 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op_test.cc @@ -50,45 +50,6 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return Status::OK(); } - - // Create a `RangeDataset` dataset as a variant tensor. - Status MakeRangeDataset(const Tensor& start, const Tensor& stop, - const Tensor& step, - const DataTypeVector& output_types, - const std::vector& output_shapes, - Tensor* range_dataset) { - GraphConstructorOptions graph_opts; - graph_opts.allow_internal_ops = true; - graph_opts.expect_device_spec = false; - TF_RETURN_IF_ERROR( - RunFunction(test::function::MakeRangeDataset(), - /*attrs*/ - {{RangeDatasetOp::kOutputTypes, output_types}, - {RangeDatasetOp::kOutputShapes, output_shapes}}, - /*inputs*/ {start, stop, step}, graph_opts, - /*rets*/ {range_dataset})); - return Status::OK(); - } - - // Create a `TakeDataset` dataset as a variant tensor. - Status MakeTakeDataset(const Tensor& input_dataset, int64 count, - const DataTypeVector& output_types, - const std::vector& output_shapes, - Tensor* take_dataset) { - GraphConstructorOptions graph_opts; - graph_opts.allow_internal_ops = true; - graph_opts.expect_device_spec = false; - - Tensor count_tensor = CreateTensor(TensorShape({}), {count}); - TF_RETURN_IF_ERROR( - RunFunction(test::function::MakeTakeDataset(), - /*attrs*/ - {{TakeDatasetOp::kOutputTypes, output_types}, - {TakeDatasetOp::kOutputShapes, output_shapes}}, - /*inputs*/ {input_dataset, count_tensor}, graph_opts, - /*rets*/ {take_dataset})); - return Status::OK(); - } }; TEST_F(OptimizeDatasetOpTest, NoopElimination) {