Merge pull request #30886 from feihugis:Refactor_AssertNextDatasetOp
PiperOrigin-RevId: 259145082
This commit is contained in:
commit
488b385a7d
@ -30,6 +30,7 @@ cc_library(
|
||||
":iterator_ops",
|
||||
":name_utils",
|
||||
":range_dataset_op",
|
||||
":take_dataset_op",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -274,6 +274,46 @@ Status DatasetOpsTestBase::CreateTensorSliceDataset(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create a `RangeDataset` dataset as a variant tensor.
|
||||
Status DatasetOpsTestBase::MakeRangeDataset(
|
||||
const Tensor& start, const Tensor& stop, const Tensor& step,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
Tensor* range_dataset) {
|
||||
GraphConstructorOptions graph_opts;
|
||||
graph_opts.allow_internal_ops = true;
|
||||
graph_opts.expect_device_spec = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunFunction(test::function::MakeRangeDataset(),
|
||||
/*attrs*/
|
||||
{{RangeDatasetOp::kOutputTypes, output_types},
|
||||
{RangeDatasetOp::kOutputShapes, output_shapes}},
|
||||
/*inputs*/ {start, stop, step}, graph_opts,
|
||||
/*rets*/ {range_dataset}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create a `TakeDataset` dataset as a variant tensor.
|
||||
Status DatasetOpsTestBase::MakeTakeDataset(
|
||||
const Tensor& input_dataset, int64 count,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
Tensor* take_dataset) {
|
||||
GraphConstructorOptions graph_opts;
|
||||
graph_opts.allow_internal_ops = true;
|
||||
graph_opts.expect_device_spec = false;
|
||||
|
||||
Tensor count_tensor = CreateTensor<int64>(TensorShape({}), {count});
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunFunction(test::function::MakeTakeDataset(),
|
||||
/*attrs*/
|
||||
{{TakeDatasetOp::kOutputTypes, output_types},
|
||||
{TakeDatasetOp::kOutputShapes, output_shapes}},
|
||||
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
|
||||
/*rets*/ {take_dataset}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::CreateOpKernel(
|
||||
const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
|
||||
OpKernel* kernel;
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/take_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
|
||||
@ -177,6 +178,19 @@ class DatasetOpsTestBase : public ::testing::Test {
|
||||
std::vector<Tensor>* const components,
|
||||
DatasetBase** tensor_slice_dataset);
|
||||
|
||||
// Creates a `RangeDataset` dataset as a variant tensor.
|
||||
Status MakeRangeDataset(const Tensor& start, const Tensor& stop,
|
||||
const Tensor& step,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
Tensor* range_dataset);
|
||||
|
||||
// Creates a `TakeDataset` dataset as a variant tensor.
|
||||
Status MakeTakeDataset(const Tensor& input_dataset, int64 count,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
Tensor* take_dataset);
|
||||
|
||||
// Fetches the dataset from the operation context.
|
||||
Status GetDatasetFromContext(OpKernelContext* context, int output_index,
|
||||
DatasetBase** const dataset);
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
"tf_kernel_library",
|
||||
)
|
||||
|
||||
@ -16,9 +17,27 @@ exports_files(["LICENSE"])
|
||||
tf_kernel_library(
|
||||
name = "assert_next_dataset_op",
|
||||
srcs = ["assert_next_dataset_op.cc"],
|
||||
hdrs = ["assert_next_dataset_op.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "assert_next_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["assert_next_dataset_op_test.cc"],
|
||||
deps = [
|
||||
":assert_next_dataset_op",
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
@ -12,149 +12,146 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
// See documentation in ../ops/dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
class AssertNextDatasetOp : public UnaryDatasetOpKernel {
|
||||
/* static */ constexpr const char* const AssertNextDatasetOp::kInputDataset;
|
||||
/* static */ constexpr const char* const AssertNextDatasetOp::kDatasetType;
|
||||
/* static */ constexpr const char* const AssertNextDatasetOp::kTransformations;
|
||||
/* static */ constexpr const char* const AssertNextDatasetOp::kOutputTypes;
|
||||
/* static */ constexpr const char* const AssertNextDatasetOp::kOutputShapes;
|
||||
|
||||
class AssertNextDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
explicit AssertNextDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
const std::vector<string>& transformations,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
transformations_(transformations),
|
||||
output_types_(output_types),
|
||||
output_shapes_(output_shapes) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return absl::make_unique<Iterator>(Iterator::Params{
|
||||
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override { return output_types_; }
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return output_shapes_;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
protected:
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override {
|
||||
std::vector<string> transformations;
|
||||
OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
|
||||
&transformations));
|
||||
*output =
|
||||
new Dataset(ctx, input, transformations, output_types_, output_shapes_);
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||
Node* transformations_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_graph_node, transformations_node}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
const std::vector<string>& transformations,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
transformations_(transformations),
|
||||
output_types_(output_types),
|
||||
output_shapes_(output_shapes) {
|
||||
input_->Ref();
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
std::vector<string> tokens =
|
||||
absl::StrSplit(prefix(), ':', absl::SkipEmpty());
|
||||
if (dataset()->transformations_.size() > tokens.size() - 2) {
|
||||
return errors::InvalidArgument(
|
||||
"Asserted next ", dataset()->transformations_.size(),
|
||||
" transformations but encountered only ", tokens.size() - 2, ".");
|
||||
}
|
||||
int n = tokens.size();
|
||||
for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
|
||||
if (dataset()->transformations_[i] != tokens[n - 2 - i]) {
|
||||
return errors::InvalidArgument(
|
||||
"Asserted ", dataset()->transformations_[i],
|
||||
" transformation at offset ", i, " but encountered ",
|
||||
tokens[n - 2 - i], " transformation instead.");
|
||||
}
|
||||
}
|
||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
~Dataset() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return absl::make_unique<Iterator>(
|
||||
Iterator::Params{this, strings::StrCat(prefix, "::AssertNext")});
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return output_types_;
|
||||
}
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return output_shapes_;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "AssertNextDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||
Node* transformations_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, {input_graph_node, transformations_node}, output));
|
||||
std::shared_ptr<model::Node> CreateNode(
|
||||
IteratorContext* ctx, model::Node::Args args) const override {
|
||||
return model::MakeKnownRatioNode(std::move(args),
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
std::vector<string> tokens =
|
||||
absl::StrSplit(prefix(), ':', absl::SkipEmpty());
|
||||
if (dataset()->transformations_.size() > tokens.size() - 2) {
|
||||
return errors::InvalidArgument(
|
||||
"Asserted next ", dataset()->transformations_.size(),
|
||||
" transformations but encountered only ", tokens.size() - 2, ".");
|
||||
}
|
||||
int n = tokens.size();
|
||||
for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
|
||||
if (dataset()->transformations_[i] != tokens[n - 2 - i]) {
|
||||
return errors::InvalidArgument(
|
||||
"Asserted ", dataset()->transformations_[i],
|
||||
" transformation at offset ", i, " but encountered ",
|
||||
tokens[n - 2 - i], " transformation instead.");
|
||||
}
|
||||
}
|
||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<model::Node> CreateNode(
|
||||
IteratorContext* ctx, model::Node::Args args) const override {
|
||||
return model::MakeKnownRatioNode(std::move(args),
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
};
|
||||
|
||||
const DatasetBase* input_;
|
||||
const std::vector<string> transformations_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
};
|
||||
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
const DatasetBase* input_;
|
||||
const std::vector<string> transformations_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
AssertNextDatasetOp::AssertNextDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
|
||||
}
|
||||
|
||||
void AssertNextDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) {
|
||||
std::vector<string> transformations;
|
||||
OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, kTransformations,
|
||||
&transformations));
|
||||
*output =
|
||||
new Dataset(ctx, input, transformations, output_types_, output_shapes_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
|
||||
AssertNextDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
|
||||
class AssertNextDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
static constexpr const char* const kDatasetType = "AssertNext";
|
||||
static constexpr const char* const kInputDataset = "input_dataset";
|
||||
static constexpr const char* const kTransformations = "transformations";
|
||||
static constexpr const char* const kOutputTypes = "output_types";
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
|
||||
explicit AssertNextDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
protected:
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override;
|
||||
|
||||
private:
|
||||
class Dataset;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_
|
@ -0,0 +1,667 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h"
|
||||
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kNodeName[] = "assert_next_dataset";
|
||||
|
||||
struct RangeDatasetParams {
|
||||
int start;
|
||||
int stop;
|
||||
int step;
|
||||
};
|
||||
|
||||
struct TakeDatasetParams {
|
||||
int count;
|
||||
};
|
||||
|
||||
class AssertNextDatasetOpTest : public DatasetOpsTestBase {
|
||||
protected:
|
||||
// Creates a new `AssertNextDataset` op kernel.
|
||||
Status CreateAssertNextDatasetOpKernel(
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
std::unique_ptr<OpKernel>* assert_next_dataset_op_kernel) {
|
||||
NodeDef node_def = test::function::NDef(
|
||||
kNodeName, name_utils::OpName(AssertNextDatasetOp::kDatasetType),
|
||||
{AssertNextDatasetOp::kInputDataset,
|
||||
AssertNextDatasetOp::kTransformations},
|
||||
{{AssertNextDatasetOp::kOutputTypes, output_types},
|
||||
{AssertNextDatasetOp::kOutputShapes, output_shapes}});
|
||||
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, assert_next_dataset_op_kernel));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a new `AssertNextDataset` op kernel context.
|
||||
Status CreateAssertNextDatasetContext(
|
||||
OpKernel* const op_kernel,
|
||||
gtl::InlinedVector<TensorValue, 4>* const inputs,
|
||||
std::unique_ptr<OpKernelContext>* context) {
|
||||
TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
|
||||
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a new `RangeAndTakeDataset` tensor.
|
||||
Status MakeRangeAndTakeDatasetTensor(
|
||||
const RangeDatasetParams& range_dataset_params,
|
||||
const TakeDatasetParams& take_dataset_params,
|
||||
Tensor* range_and_take_dataset_tensor) {
|
||||
Tensor range_dataset_tensor;
|
||||
Tensor start =
|
||||
CreateTensor<int64>(TensorShape({}), {range_dataset_params.start});
|
||||
Tensor stop =
|
||||
CreateTensor<int64>(TensorShape({}), {range_dataset_params.stop});
|
||||
Tensor step =
|
||||
CreateTensor<int64>(TensorShape({}), {range_dataset_params.step});
|
||||
TF_RETURN_IF_ERROR(MakeRangeDataset(start, stop, step, {DT_INT64},
|
||||
{PartialTensorShape({})},
|
||||
&range_dataset_tensor));
|
||||
|
||||
TF_RETURN_IF_ERROR(MakeTakeDataset(
|
||||
range_dataset_tensor, take_dataset_params.count, {DT_INT64},
|
||||
{PartialTensorShape({})}, range_and_take_dataset_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
struct TestCase {
|
||||
RangeDatasetParams range_dataset_params;
|
||||
TakeDatasetParams take_dataset_params;
|
||||
Tensor transformations;
|
||||
std::vector<Tensor> expected_outputs;
|
||||
DataTypeVector expected_output_dtypes;
|
||||
std::vector<PartialTensorShape> expected_output_shapes;
|
||||
int64 expected_cardinality;
|
||||
std::vector<int> breakpoints;
|
||||
};
|
||||
|
||||
// Test case 1 : assert one transformation.
|
||||
TestCase TestCase1() {
|
||||
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
|
||||
/*take_dataset_params*/ {/*count*/ 3},
|
||||
/*transformations*/
|
||||
DatasetOpsTestBase::CreateTensor<string>(
|
||||
TensorShape({1}), {TakeDatasetOp::kDatasetType}),
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||
/*expected_output_dtypes*/ {DT_INT64},
|
||||
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||
/*expected_cardinality*/ 3,
|
||||
/*breakpoints*/ {0, 2, 5}};
|
||||
}
|
||||
|
||||
// Test case 2 : assert two transformations.
|
||||
TestCase TestCase2() {
|
||||
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
|
||||
/*take_dataset_params*/ {/*count*/ 3},
|
||||
/*transformations*/
|
||||
DatasetOpsTestBase::CreateTensor<string>(
|
||||
TensorShape({2}),
|
||||
{TakeDatasetOp::kDatasetType, RangeDatasetOp::kDatasetType}),
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||
/*expected_output_dtypes*/ {DT_INT64},
|
||||
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||
/*expected_cardinality*/ 3,
|
||||
/*breakpoints*/ {0, 2, 5}};
|
||||
}
|
||||
|
||||
TestCase AssertNextInvalid() {
|
||||
return {
|
||||
/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
|
||||
/*take_dataset_params*/ {/*count*/ 3},
|
||||
/*transformations*/
|
||||
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"Whoops"}),
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||
/*expected_output_dtypes*/ {DT_INT64},
|
||||
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||
/*expected_cardinality*/ 3,
|
||||
/*breakpoints*/ {0, 2, 5}};
|
||||
}
|
||||
|
||||
TestCase AssertNextShort() {
|
||||
return {/*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
|
||||
/*take_dataset_params*/ {/*count*/ 3},
|
||||
/*transformations*/
|
||||
DatasetOpsTestBase::CreateTensor<string>(
|
||||
TensorShape({3}), {TakeDatasetOp::kDatasetType,
|
||||
RangeDatasetOp::kDatasetType, "Whoops"}),
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
|
||||
/*expected_output_dtypes*/ {DT_INT64},
|
||||
/*expected_output_shapes*/ {PartialTensorShape({})},
|
||||
/*expected_cardinality*/ 3,
|
||||
/*breakpoints*/ {0, 2, 5}};
|
||||
}
|
||||
|
||||
class ParameterizedAssertNextDatasetOpTest
|
||||
: public AssertNextDatasetOpTest,
|
||||
public ::testing::WithParamInterface<TestCase> {};
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, GetNext) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||
&iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
string iterator_prefix = name_utils::IteratorPrefix(
|
||||
TakeDatasetOp::kDatasetType,
|
||||
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||
iterator_prefix, &iterator));
|
||||
|
||||
bool end_of_sequence = false;
|
||||
std::vector<Tensor> out_tensors;
|
||||
while (!end_of_sequence) {
|
||||
std::vector<Tensor> next;
|
||||
TF_EXPECT_OK(
|
||||
iterator->GetNext(iterator_context.get(), &next, &end_of_sequence));
|
||||
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
|
||||
/*compare_order*/ true));
|
||||
}
|
||||
|
||||
TEST_F(AssertNextDatasetOpTest, DatasetNodeName) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = TestCase1();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
EXPECT_EQ(assert_next_dataset->node_name(), kNodeName);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetTypeString) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
EXPECT_EQ(assert_next_dataset->type_string(),
|
||||
name_utils::OpName(AssertNextDatasetOp::kDatasetType));
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
TF_EXPECT_OK(VerifyTypesMatch(assert_next_dataset->output_dtypes(),
|
||||
test_case.expected_output_dtypes));
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetOutputShapes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
TF_EXPECT_OK(VerifyShapesCompatible(assert_next_dataset->output_shapes(),
|
||||
test_case.expected_output_shapes));
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, Cardinality) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
EXPECT_EQ(assert_next_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(assert_next_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||
&iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
string iterator_prefix = name_utils::IteratorPrefix(
|
||||
TakeDatasetOp::kDatasetType,
|
||||
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||
iterator_prefix, &iterator));
|
||||
|
||||
TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
|
||||
test_case.expected_output_dtypes));
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputShapes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||
&iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
string iterator_prefix = name_utils::IteratorPrefix(
|
||||
TakeDatasetOp::kDatasetType,
|
||||
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||
iterator_prefix, &iterator));
|
||||
|
||||
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
|
||||
test_case.expected_output_shapes));
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputPrefix) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||
&iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
string iterator_prefix = name_utils::IteratorPrefix(
|
||||
TakeDatasetOp::kDatasetType,
|
||||
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||
iterator_prefix, &iterator));
|
||||
|
||||
EXPECT_EQ(iterator->prefix(),
|
||||
name_utils::IteratorPrefix(AssertNextDatasetOp::kDatasetType,
|
||||
iterator_prefix));
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, Roundtrip) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||
&iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
string iterator_prefix = name_utils::IteratorPrefix(
|
||||
TakeDatasetOp::kDatasetType,
|
||||
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||
TF_ASSERT_OK(assert_next_dataset->MakeIterator(iterator_context.get(),
|
||||
iterator_prefix, &iterator));
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
bool end_of_sequence = false;
|
||||
std::vector<Tensor> out_tensors;
|
||||
int cur_iteration = 0;
|
||||
const std::vector<int>& breakpoints = test_case.breakpoints;
|
||||
for (int breakpoint : breakpoints) {
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
|
||||
TF_EXPECT_OK(writer.Flush());
|
||||
VariantTensorDataReader reader(&data);
|
||||
TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader,
|
||||
iterator_prefix, *assert_next_dataset,
|
||||
&iterator));
|
||||
|
||||
while (cur_iteration <= breakpoint) {
|
||||
std::vector<Tensor> next;
|
||||
TF_EXPECT_OK(
|
||||
iterator->GetNext(iterator_context.get(), &next, &end_of_sequence));
|
||||
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
|
||||
++cur_iteration;
|
||||
}
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
|
||||
/*compare_order*/ true));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
AssertNextDatasetOpTest, ParameterizedAssertNextDatasetOpTest,
|
||||
::testing::ValuesIn(std::vector<TestCase>({TestCase1(), TestCase2()})));
|
||||
|
||||
TEST_F(AssertNextDatasetOpTest, InvalidArguments) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
std::vector<TestCase> test_cases = {AssertNextInvalid(), AssertNextShort()};
|
||||
for (TestCase test_case : test_cases) {
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(
|
||||
test_case.expected_output_dtypes, test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(
|
||||
CreateAssertNextDatasetContext(assert_next_dataset_kernel.get(),
|
||||
&inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(CreateIteratorContext(assert_next_dataset_context.get(),
|
||||
&iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
string iterator_prefix = name_utils::IteratorPrefix(
|
||||
TakeDatasetOp::kDatasetType,
|
||||
name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
|
||||
EXPECT_EQ(
|
||||
assert_next_dataset
|
||||
->MakeIterator(iterator_context.get(), iterator_prefix, &iterator)
|
||||
.code(),
|
||||
tensorflow::error::INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -50,45 +50,6 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
|
||||
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create a `RangeDataset` dataset as a variant tensor.
|
||||
Status MakeRangeDataset(const Tensor& start, const Tensor& stop,
|
||||
const Tensor& step,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
Tensor* range_dataset) {
|
||||
GraphConstructorOptions graph_opts;
|
||||
graph_opts.allow_internal_ops = true;
|
||||
graph_opts.expect_device_spec = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunFunction(test::function::MakeRangeDataset(),
|
||||
/*attrs*/
|
||||
{{RangeDatasetOp::kOutputTypes, output_types},
|
||||
{RangeDatasetOp::kOutputShapes, output_shapes}},
|
||||
/*inputs*/ {start, stop, step}, graph_opts,
|
||||
/*rets*/ {range_dataset}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create a `TakeDataset` dataset as a variant tensor.
|
||||
Status MakeTakeDataset(const Tensor& input_dataset, int64 count,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
Tensor* take_dataset) {
|
||||
GraphConstructorOptions graph_opts;
|
||||
graph_opts.allow_internal_ops = true;
|
||||
graph_opts.expect_device_spec = false;
|
||||
|
||||
Tensor count_tensor = CreateTensor<int64>(TensorShape({}), {count});
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunFunction(test::function::MakeTakeDataset(),
|
||||
/*attrs*/
|
||||
{{TakeDatasetOp::kOutputTypes, output_types},
|
||||
{TakeDatasetOp::kOutputShapes, output_shapes}},
|
||||
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
|
||||
/*rets*/ {take_dataset}));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(OptimizeDatasetOpTest, NoopElimination) {
|
||||
|
Loading…
Reference in New Issue
Block a user