Merge pull request #30886 from feihugis:Refactor_AssertNextDatasetOp
PiperOrigin-RevId: 259145082
This commit is contained in:
commit
488b385a7d
@ -30,6 +30,7 @@ cc_library(
|
|||||||
":iterator_ops",
|
":iterator_ops",
|
||||||
":name_utils",
|
":name_utils",
|
||||||
":range_dataset_op",
|
":range_dataset_op",
|
||||||
|
":take_dataset_op",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
@ -274,6 +274,46 @@ Status DatasetOpsTestBase::CreateTensorSliceDataset(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a `RangeDataset` dataset as a variant tensor.
|
||||||
|
Status DatasetOpsTestBase::MakeRangeDataset(
|
||||||
|
const Tensor& start, const Tensor& stop, const Tensor& step,
|
||||||
|
const DataTypeVector& output_types,
|
||||||
|
const std::vector<PartialTensorShape>& output_shapes,
|
||||||
|
Tensor* range_dataset) {
|
||||||
|
GraphConstructorOptions graph_opts;
|
||||||
|
graph_opts.allow_internal_ops = true;
|
||||||
|
graph_opts.expect_device_spec = false;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
RunFunction(test::function::MakeRangeDataset(),
|
||||||
|
/*attrs*/
|
||||||
|
{{RangeDatasetOp::kOutputTypes, output_types},
|
||||||
|
{RangeDatasetOp::kOutputShapes, output_shapes}},
|
||||||
|
/*inputs*/ {start, stop, step}, graph_opts,
|
||||||
|
/*rets*/ {range_dataset}));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a `TakeDataset` dataset as a variant tensor.
|
||||||
|
Status DatasetOpsTestBase::MakeTakeDataset(
|
||||||
|
const Tensor& input_dataset, int64 count,
|
||||||
|
const DataTypeVector& output_types,
|
||||||
|
const std::vector<PartialTensorShape>& output_shapes,
|
||||||
|
Tensor* take_dataset) {
|
||||||
|
GraphConstructorOptions graph_opts;
|
||||||
|
graph_opts.allow_internal_ops = true;
|
||||||
|
graph_opts.expect_device_spec = false;
|
||||||
|
|
||||||
|
Tensor count_tensor = CreateTensor<int64>(TensorShape({}), {count});
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
RunFunction(test::function::MakeTakeDataset(),
|
||||||
|
/*attrs*/
|
||||||
|
{{TakeDatasetOp::kOutputTypes, output_types},
|
||||||
|
{TakeDatasetOp::kOutputShapes, output_shapes}},
|
||||||
|
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
|
||||||
|
/*rets*/ {take_dataset}));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status DatasetOpsTestBase::CreateOpKernel(
|
Status DatasetOpsTestBase::CreateOpKernel(
|
||||||
const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
|
const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
|
||||||
OpKernel* kernel;
|
OpKernel* kernel;
|
||||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
||||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||||
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
||||||
|
#include "tensorflow/core/kernels/data/take_dataset_op.h"
|
||||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||||
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
|
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
|
||||||
@ -177,6 +178,19 @@ class DatasetOpsTestBase : public ::testing::Test {
|
|||||||
std::vector<Tensor>* const components,
|
std::vector<Tensor>* const components,
|
||||||
DatasetBase** tensor_slice_dataset);
|
DatasetBase** tensor_slice_dataset);
|
||||||
|
|
||||||
|
// Creates a `RangeDataset` dataset as a variant tensor.
|
||||||
|
Status MakeRangeDataset(const Tensor& start, const Tensor& stop,
|
||||||
|
const Tensor& step,
|
||||||
|
const DataTypeVector& output_types,
|
||||||
|
const std::vector<PartialTensorShape>& output_shapes,
|
||||||
|
Tensor* range_dataset);
|
||||||
|
|
||||||
|
// Creates a `TakeDataset` dataset as a variant tensor.
|
||||||
|
Status MakeTakeDataset(const Tensor& input_dataset, int64 count,
|
||||||
|
const DataTypeVector& output_types,
|
||||||
|
const std::vector<PartialTensorShape>& output_shapes,
|
||||||
|
Tensor* take_dataset);
|
||||||
|
|
||||||
// Fetches the dataset from the operation context.
|
// Fetches the dataset from the operation context.
|
||||||
Status GetDatasetFromContext(OpKernelContext* context, int output_index,
|
Status GetDatasetFromContext(OpKernelContext* context, int output_index,
|
||||||
DatasetBase** const dataset);
|
DatasetBase** const dataset);
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
load(
|
load(
|
||||||
"//tensorflow:tensorflow.bzl",
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_cc_test",
|
||||||
"tf_kernel_library",
|
"tf_kernel_library",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -16,9 +17,27 @@ exports_files(["LICENSE"])
|
|||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "assert_next_dataset_op",
|
name = "assert_next_dataset_op",
|
||||||
srcs = ["assert_next_dataset_op.cc"],
|
srcs = ["assert_next_dataset_op.cc"],
|
||||||
|
hdrs = ["assert_next_dataset_op.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core/kernels/data:name_utils",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "assert_next_dataset_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["assert_next_dataset_op_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":assert_next_dataset_op",
|
||||||
|
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -12,38 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/dataset.h"
|
#include "tensorflow/core/framework/dataset.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace {
|
|
||||||
|
|
||||||
// See documentation in ../ops/dataset_ops.cc for a high-level
|
/* static */ constexpr const char* const AssertNextDatasetOp::kInputDataset;
|
||||||
// description of the following op.
|
/* static */ constexpr const char* const AssertNextDatasetOp::kDatasetType;
|
||||||
class AssertNextDatasetOp : public UnaryDatasetOpKernel {
|
/* static */ constexpr const char* const AssertNextDatasetOp::kTransformations;
|
||||||
public:
|
/* static */ constexpr const char* const AssertNextDatasetOp::kOutputTypes;
|
||||||
explicit AssertNextDatasetOp(OpKernelConstruction* ctx)
|
/* static */ constexpr const char* const AssertNextDatasetOp::kOutputShapes;
|
||||||
: UnaryDatasetOpKernel(ctx) {
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
class AssertNextDatasetOp::Dataset : public DatasetBase {
|
||||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
|
||||||
DatasetBase** output) override {
|
|
||||||
std::vector<string> transformations;
|
|
||||||
OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
|
|
||||||
&transformations));
|
|
||||||
*output =
|
|
||||||
new Dataset(ctx, input, transformations, output_types_, output_shapes_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
class Dataset : public DatasetBase {
|
|
||||||
public:
|
public:
|
||||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||||
const std::vector<string>& transformations,
|
const std::vector<string>& transformations,
|
||||||
@ -61,19 +48,17 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
|
|
||||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||||
const string& prefix) const override {
|
const string& prefix) const override {
|
||||||
return absl::make_unique<Iterator>(
|
return absl::make_unique<Iterator>(Iterator::Params{
|
||||||
Iterator::Params{this, strings::StrCat(prefix, "::AssertNext")});
|
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
|
||||||
}
|
}
|
||||||
|
|
||||||
const DataTypeVector& output_dtypes() const override {
|
const DataTypeVector& output_dtypes() const override { return output_types_; }
|
||||||
return output_types_;
|
|
||||||
}
|
|
||||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||||
return output_shapes_;
|
return output_shapes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
string DebugString() const override {
|
string DebugString() const override {
|
||||||
return "AssertNextDatasetOp::Dataset";
|
return name_utils::DatasetDebugString(kDatasetType);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||||
@ -86,8 +71,8 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||||
Node* transformations_node = nullptr;
|
Node* transformations_node = nullptr;
|
||||||
TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
|
TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
|
||||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
TF_RETURN_IF_ERROR(
|
||||||
this, {input_graph_node, transformations_node}, output));
|
b->AddDataset(this, {input_graph_node, transformations_node}, output));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,12 +134,24 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
const std::vector<string> transformations_;
|
const std::vector<string> transformations_;
|
||||||
const DataTypeVector output_types_;
|
const DataTypeVector output_types_;
|
||||||
const std::vector<PartialTensorShape> output_shapes_;
|
const std::vector<PartialTensorShape> output_shapes_;
|
||||||
};
|
|
||||||
|
|
||||||
DataTypeVector output_types_;
|
|
||||||
std::vector<PartialTensorShape> output_shapes_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
AssertNextDatasetOp::AssertNextDatasetOp(OpKernelConstruction* ctx)
|
||||||
|
: UnaryDatasetOpKernel(ctx) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void AssertNextDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||||
|
DatasetBase** output) {
|
||||||
|
std::vector<string> transformations;
|
||||||
|
OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, kTransformations,
|
||||||
|
&transformations));
|
||||||
|
*output =
|
||||||
|
new Dataset(ctx, input, transformations, output_types_, output_shapes_);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
|
||||||
AssertNextDatasetOp);
|
AssertNextDatasetOp);
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
@ -0,0 +1,49 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/dataset.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
|
||||||
|
// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level
|
||||||
|
// description of the following op.
|
||||||
|
|
||||||
|
class AssertNextDatasetOp : public UnaryDatasetOpKernel {
|
||||||
|
public:
|
||||||
|
static constexpr const char* const kDatasetType = "AssertNext";
|
||||||
|
static constexpr const char* const kInputDataset = "input_dataset";
|
||||||
|
static constexpr const char* const kTransformations = "transformations";
|
||||||
|
static constexpr const char* const kOutputTypes = "output_types";
|
||||||
|
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||||
|
|
||||||
|
explicit AssertNextDatasetOp(OpKernelConstruction* ctx);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||||
|
DatasetBase** output) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Dataset;
|
||||||
|
DataTypeVector output_types_;
|
||||||
|
std::vector<PartialTensorShape> output_shapes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_
|
@ -0,0 +1,667 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kNodeName[] = "assert_next_dataset";
|
||||||
|
|
||||||
|
struct RangeDatasetParams {
|
||||||
|
int start;
|
||||||
|
int stop;
|
||||||
|
int step;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TakeDatasetParams {
|
||||||
|
int count;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AssertNextDatasetOpTest : public DatasetOpsTestBase {
|
||||||
|
protected:
|
||||||
|
// Creates a new `AssertNextDataset` op kernel.
|
||||||
|
Status CreateAssertNextDatasetOpKernel(
|
||||||
|
const DataTypeVector& output_types,
|
||||||
|
const std::vector<PartialTensorShape>& output_shapes,
|
||||||
|
std::unique_ptr<OpKernel>* assert_next_dataset_op_kernel) {
|
||||||
|
NodeDef node_def = test::function::NDef(
|
||||||
|
kNodeName, name_utils::OpName(AssertNextDatasetOp::kDatasetType),
|
||||||
|
{AssertNextDatasetOp::kInputDataset,
|
||||||
|
AssertNextDatasetOp::kTransformations},
|
||||||
|
{{AssertNextDatasetOp::kOutputTypes, output_types},
|
||||||
|
{AssertNextDatasetOp::kOutputShapes, output_shapes}});
|
||||||
|
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, assert_next_dataset_op_kernel));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new `AssertNextDataset` op kernel context.
|
||||||
|
Status CreateAssertNextDatasetContext(
|
||||||
|
OpKernel* const op_kernel,
|
||||||
|
gtl::InlinedVector<TensorValue, 4>* const inputs,
|
||||||
|
std::unique_ptr<OpKernelContext>* context) {
|
||||||
|
TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
|
||||||
|
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new `RangeAndTakeDataset` tensor.
|
||||||
|
Status MakeRangeAndTakeDatasetTensor(
|
||||||
|
const RangeDatasetParams& range_dataset_params,
|
||||||
|
const TakeDatasetParams& take_dataset_params,
|
||||||
|
Tensor* range_and_take_dataset_tensor) {
|
||||||
|
Tensor range_dataset_tensor;
|
||||||
|
Tensor start =
|
||||||
|
CreateTensor<int64>(TensorShape({}), {range_dataset_params.start});
|
||||||
|
Tensor stop =
|
||||||
|
CreateTensor<int64>(TensorShape({}), {range_dataset_params.stop});
|
||||||
|
Tensor step =
|
||||||
|
CreateTensor<int64>(TensorShape({}), {range_dataset_params.step});
|
||||||
|
TF_RETURN_IF_ERROR(MakeRangeDataset(start, stop, step, {DT_INT64},
|
||||||
|
{PartialTensorShape({})},
|
||||||
|
&range_dataset_tensor));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(MakeTakeDataset(
|
||||||
|
range_dataset_tensor, take_dataset_params.count, {DT_INT64},
|
||||||
|
{PartialTensorShape({})}, range_and_take_dataset_tensor));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TestCase {
|
||||||
|
RangeDatasetParams range_dataset_params;
|
||||||
|
TakeDatasetParams take_dataset_params;
|
||||||
|
Tensor transformations;
|
||||||
|
std::vector<Tensor> expected_outputs;
|
||||||
|
DataTypeVector expected_output_dtypes;
|
||||||
|
std::vector<PartialTensorShape> expected_output_shapes;
|
||||||
|
int64 expected_cardinality;
|
||||||
|
std::vector<int> breakpoints;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test case 1 : assert one transformation.
|
||||||
|
TestCase TestCase1() {
|
||||||
|
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
|
||||||
|
/*take_dataset_params*/ {/*count*/ 3},
|
||||||
|
/*transformations*/
|
||||||
|
DatasetOpsTestBase::CreateTensor<string>(
|
||||||
|
TensorShape({1}), {TakeDatasetOp::kDatasetType}),
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 3,
|
||||||
|
/*breakpoints*/ {0, 2, 5}};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 2 : assert two transformations.
|
||||||
|
TestCase TestCase2() {
|
||||||
|
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
|
||||||
|
/*take_dataset_params*/ {/*count*/ 3},
|
||||||
|
/*transformations*/
|
||||||
|
DatasetOpsTestBase::CreateTensor<string>(
|
||||||
|
TensorShape({2}),
|
||||||
|
{TakeDatasetOp::kDatasetType, RangeDatasetOp::kDatasetType}),
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 3,
|
||||||
|
/*breakpoints*/ {0, 2, 5}};
|
||||||
|
}
|
||||||
|
|
||||||
|
TestCase AssertNextInvalid() {
|
||||||
|
return {
|
||||||
|
/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
|
||||||
|
/*take_dataset_params*/ {/*count*/ 3},
|
||||||
|
/*transformations*/
|
||||||
|
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"Whoops"}),
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 3,
|
||||||
|
/*breakpoints*/ {0, 2, 5}};
|
||||||
|
}
|
||||||
|
|
||||||
|
TestCase AssertNextShort() {
|
||||||
|
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
|
||||||
|
/*take_dataset_params*/ {/*count*/ 3},
|
||||||
|
/*transformations*/
|
||||||
|
DatasetOpsTestBase::CreateTensor<string>(
|
||||||
|
TensorShape({3}), {TakeDatasetOp::kDatasetType,
|
||||||
|
RangeDatasetOp::kDatasetType, "Whoops"}),
|
||||||
|
/*expected_outputs*/
|
||||||
|
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||||
|
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||||
|
/*expected_output_dtypes*/ {DT_INT64},
|
||||||
|
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||||
|
/*expected_cardinality*/ 3,
|
||||||
|
/*breakpoints*/ {0, 2, 5}};
|
||||||
|
}
|
||||||
|
|
||||||
|
class ParameterizedAssertNextDatasetOpTest
|
||||||
|
: public AssertNextDatasetOpTest,
|
||||||
|
public ::testing::WithParamInterface<TestCase> {};
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, GetNext) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||||
|
&iterator_context));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
string iterator_prefix = name_utils::IteratorPrefix(
|
||||||
|
TakeDatasetOp::kDatasetType,
|
||||||
|
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||||
|
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||||
|
iterator_prefix, &iterator));
|
||||||
|
|
||||||
|
bool end_of_sequence = false;
|
||||||
|
std::vector<Tensor> out_tensors;
|
||||||
|
while (!end_of_sequence) {
|
||||||
|
std::vector<Tensor> next;
|
||||||
|
TF_EXPECT_OK(
|
||||||
|
iterator->GetNext(iterator_context.get(), &next, &end_of_sequence));
|
||||||
|
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
|
||||||
|
/*compare_order*/ true));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AssertNextDatasetOpTest, DatasetNodeName) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = TestCase1();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(assert_next_dataset->node_name(), kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetTypeString) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(assert_next_dataset->type_string(),
|
||||||
|
name_utils::OpName(AssertNextDatasetOp::kDatasetType));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetOutputDtypes) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
TF_EXPECT_OK(VerifyTypesMatch(assert_next_dataset->output_dtypes(),
|
||||||
|
test_case.expected_output_dtypes));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetOutputShapes) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
TF_EXPECT_OK(VerifyShapesCompatible(assert_next_dataset->output_shapes(),
|
||||||
|
test_case.expected_output_shapes));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, Cardinality) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
EXPECT_EQ(assert_next_dataset->Cardinality(), test_case.expected_cardinality);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetSave) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
std::unique_ptr<SerializationContext> serialization_context;
|
||||||
|
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||||
|
VariantTensorData data;
|
||||||
|
VariantTensorDataWriter writer(&data);
|
||||||
|
TF_ASSERT_OK(assert_next_dataset->Save(serialization_context.get(), &writer));
|
||||||
|
TF_ASSERT_OK(writer.Flush());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputDtypes) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||||
|
&iterator_context));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
string iterator_prefix = name_utils::IteratorPrefix(
|
||||||
|
TakeDatasetOp::kDatasetType,
|
||||||
|
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||||
|
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||||
|
iterator_prefix, &iterator));
|
||||||
|
|
||||||
|
TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
|
||||||
|
test_case.expected_output_dtypes));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputShapes) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||||
|
&iterator_context));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
string iterator_prefix = name_utils::IteratorPrefix(
|
||||||
|
TakeDatasetOp::kDatasetType,
|
||||||
|
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||||
|
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||||
|
iterator_prefix, &iterator));
|
||||||
|
|
||||||
|
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
|
||||||
|
test_case.expected_output_shapes));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputPrefix) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||||
|
&iterator_context));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
string iterator_prefix = name_utils::IteratorPrefix(
|
||||||
|
TakeDatasetOp::kDatasetType,
|
||||||
|
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||||
|
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||||
|
iterator_prefix, &iterator));
|
||||||
|
|
||||||
|
EXPECT_EQ(iterator->prefix(),
|
||||||
|
name_utils::IteratorPrefix(AssertNextDatasetOp::kDatasetType,
|
||||||
|
iterator_prefix));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedAssertNextDatasetOpTest, Roundtrip) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TestCase test_case = GetParam();
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||||
|
test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||||
|
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||||
|
&iterator_context));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
string iterator_prefix = name_utils::IteratorPrefix(
|
||||||
|
TakeDatasetOp::kDatasetType,
|
||||||
|
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||||
|
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||||
|
iterator_prefix, &iterator));
|
||||||
|
|
||||||
|
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||||
|
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||||
|
bool end_of_sequence = false;
|
||||||
|
std::vector<Tensor> out_tensors;
|
||||||
|
int cur_iteration = 0;
|
||||||
|
const std::vector<int>& breakpoints = test_case.breakpoints;
|
||||||
|
for (int breakpoint : breakpoints) {
|
||||||
|
VariantTensorData data;
|
||||||
|
VariantTensorDataWriter writer(&data);
|
||||||
|
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
|
||||||
|
TF_EXPECT_OK(writer.Flush());
|
||||||
|
VariantTensorDataReader reader(&data);
|
||||||
|
TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader,
|
||||||
|
iterator_prefix, *assert_next_dataset,
|
||||||
|
&iterator));
|
||||||
|
|
||||||
|
while (cur_iteration <= breakpoint) {
|
||||||
|
std::vector<Tensor> next;
|
||||||
|
TF_EXPECT_OK(
|
||||||
|
iterator->GetNext(iterator_context.get(), &next, &end_of_sequence));
|
||||||
|
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
|
||||||
|
++cur_iteration;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
|
||||||
|
/*compare_order*/ true));
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
AssertNextDatasetOpTest, ParameterizedAssertNextDatasetOpTest,
|
||||||
|
::testing::ValuesIn(std::vector<TestCase>({TestCase1(), TestCase2()})));
|
||||||
|
|
||||||
|
TEST_F(AssertNextDatasetOpTest, InvalidArguments) {
|
||||||
|
int thread_num = 2, cpu_num = 2;
|
||||||
|
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||||
|
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||||
|
|
||||||
|
std::vector<TestCase> test_cases = {AssertNextInvalid(), AssertNextShort()};
|
||||||
|
for (TestCase test_case : test_cases) {
|
||||||
|
Tensor range_and_take_dataset_tensor;
|
||||||
|
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||||
|
test_case.take_dataset_params,
|
||||||
|
&range_and_take_dataset_tensor));
|
||||||
|
|
||||||
|
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||||
|
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(
|
||||||
|
test_case.expected_output_dtypes, test_case.expected_output_shapes,
|
||||||
|
&assert_next_dataset_kernel));
|
||||||
|
Tensor transformations = test_case.transformations;
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||||
|
{TensorValue(&range_and_take_dataset_tensor),
|
||||||
|
TensorValue(&transformations)});
|
||||||
|
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
CreateAssertNextDatasetContext(assert_next_dataset_kernel.get(),
|
||||||
|
&inputs, &assert_next_dataset_context));
|
||||||
|
|
||||||
|
DatasetBase* assert_next_dataset;
|
||||||
|
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||||
|
assert_next_dataset_context.get(),
|
||||||
|
&assert_next_dataset));
|
||||||
|
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorContext> iterator_context;
|
||||||
|
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||||
|
&iterator_context));
|
||||||
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
|
string iterator_prefix = name_utils::IteratorPrefix(
|
||||||
|
TakeDatasetOp::kDatasetType,
|
||||||
|
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||||
|
EXPECT_EQ(
|
||||||
|
assert_next_dataset
|
||||||
|
->MakeIterator(iterator_context.get(), iterator_prefix, &iterator)
|
||||||
|
.code(),
|
||||||
|
tensorflow::error::INVALID_ARGUMENT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
@ -50,45 +50,6 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
|
|||||||
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
|
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a `RangeDataset` dataset as a variant tensor.
|
|
||||||
Status MakeRangeDataset(const Tensor& start, const Tensor& stop,
|
|
||||||
const Tensor& step,
|
|
||||||
const DataTypeVector& output_types,
|
|
||||||
const std::vector<PartialTensorShape>& output_shapes,
|
|
||||||
Tensor* range_dataset) {
|
|
||||||
GraphConstructorOptions graph_opts;
|
|
||||||
graph_opts.allow_internal_ops = true;
|
|
||||||
graph_opts.expect_device_spec = false;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
RunFunction(test::function::MakeRangeDataset(),
|
|
||||||
/*attrs*/
|
|
||||||
{{RangeDatasetOp::kOutputTypes, output_types},
|
|
||||||
{RangeDatasetOp::kOutputShapes, output_shapes}},
|
|
||||||
/*inputs*/ {start, stop, step}, graph_opts,
|
|
||||||
/*rets*/ {range_dataset}));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a `TakeDataset` dataset as a variant tensor.
|
|
||||||
Status MakeTakeDataset(const Tensor& input_dataset, int64 count,
|
|
||||||
const DataTypeVector& output_types,
|
|
||||||
const std::vector<PartialTensorShape>& output_shapes,
|
|
||||||
Tensor* take_dataset) {
|
|
||||||
GraphConstructorOptions graph_opts;
|
|
||||||
graph_opts.allow_internal_ops = true;
|
|
||||||
graph_opts.expect_device_spec = false;
|
|
||||||
|
|
||||||
Tensor count_tensor = CreateTensor<int64>(TensorShape({}), {count});
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
RunFunction(test::function::MakeTakeDataset(),
|
|
||||||
/*attrs*/
|
|
||||||
{{TakeDatasetOp::kOutputTypes, output_types},
|
|
||||||
{TakeDatasetOp::kOutputShapes, output_shapes}},
|
|
||||||
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
|
|
||||||
/*rets*/ {take_dataset}));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(OptimizeDatasetOpTest, NoopElimination) {
|
TEST_F(OptimizeDatasetOpTest, NoopElimination) {
|
||||||
|
Loading…
Reference in New Issue
Block a user