Merge pull request #30723 from feihugis:Refactor_DatasetOps_8

PiperOrigin-RevId: 258254881
Author: TensorFlower Gardener
Date:   2019-07-15 16:17:01 -07:00
Commit: 2beeba1eb6
7 changed files with 958 additions and 259 deletions

View File: tensorflow/core/kernels/data/BUILD

@@ -972,18 +972,6 @@ tf_cc_test(
],
)
tf_kernel_library(
name = "reader_dataset_ops",
srcs = ["reader_dataset_ops.cc"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "text_line_dataset_op",
srcs = ["text_line_dataset_op.cc"],
@@ -1050,6 +1038,39 @@ tf_cc_test(
],
)
tf_kernel_library(
name = "tf_record_dataset_op",
srcs = ["tf_record_dataset_op.cc"],
hdrs = ["tf_record_dataset_op.h"],
deps = [
":name_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cc_test(
name = "tf_record_dataset_op_test",
size = "small",
srcs = ["tf_record_dataset_op_test.cc"],
deps = [
":dataset_test_base",
":dataset_utils",
":iterator_ops",
":tf_record_dataset_op",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_kernel_library(
name = "iterator_ops",
srcs = ["iterator_ops.cc"],
@@ -1242,7 +1263,6 @@ tf_kernel_library(
":parallel_map_dataset_op",
":prefetch_dataset_op",
":range_dataset_op",
":reader_dataset_ops",
":repeat_dataset_op",
":shard_dataset_op",
":shuffle_dataset_op",
@@ -1252,6 +1272,7 @@ tf_kernel_library(
":tensor_dataset_op",
":tensor_slice_dataset_op",
":text_line_dataset_op",
":tf_record_dataset_op",
":window_dataset_op",
":zip_dataset_op",
"//tensorflow/core:array_ops_op_lib",

View File: tensorflow/core/kernels/data/dataset_test_base.cc

@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/io/record_writer.h"
namespace tensorflow {
namespace data {
@@ -72,6 +73,7 @@ Status WriteDataToFile(const string& filename, const char* data,
zlib_compression_options);
TF_RETURN_IF_ERROR(out.Init());
TF_RETURN_IF_ERROR(out.Append(data));
TF_RETURN_IF_ERROR(out.Flush());
TF_RETURN_IF_ERROR(out.Close());
} else {
return tensorflow::errors::InvalidArgument(
@@ -84,6 +86,26 @@ Status WriteDataToFile(const string& filename, const char* data,
return Status::OK();
}
Status WriteDataToTFRecordFile(const string& filename,
const std::vector<absl::string_view>& records,
const CompressionParams& params) {
Env* env = Env::Default();
std::unique_ptr<WritableFile> file_writer;
TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
auto options = io::RecordWriterOptions::CreateRecordWriterOptions(
ToString(params.compression_type));
options.zlib_options.input_buffer_size = params.input_buffer_size;
io::RecordWriter record_writer(file_writer.get(), options);
for (const auto& record : records) {
TF_RETURN_IF_ERROR(record_writer.WriteRecord(record));
}
TF_RETURN_IF_ERROR(record_writer.Flush());
TF_RETURN_IF_ERROR(record_writer.Close());
TF_RETURN_IF_ERROR(file_writer->Flush());
TF_RETURN_IF_ERROR(file_writer->Close());
return Status::OK();
}
template <typename T>
Status IsEqual(const Tensor& t1, const Tensor& t2) {
if (t1.dtype() != t2.dtype()) {

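As a usage aside (not part of this change): below is a minimal sketch of how a test could call the new WriteDataToTFRecordFile helper, assuming the CompressionParams and CompressionType definitions from dataset_test_base.h in this same change; the function name and file path are illustrative only.

// Illustrative sketch only. Assumes dataset_test_base.h is included and that
// this code sits inside namespace tensorflow::data, like the tests below.
Status WriteExampleTFRecordFile() {
  // Three records, written ZLIB-compressed to a temporary TFRecord file.
  std::vector<absl::string_view> records = {"1", "22", "333"};
  CompressionParams params;
  params.compression_type = CompressionType::ZLIB;
  params.input_buffer_size = 10;
  return WriteDataToTFRecordFile(
      absl::StrCat(testing::TmpDir(), "/tf_record_example"), records, params);
}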
View File: tensorflow/core/kernels/data/dataset_test_base.h

@@ -71,6 +71,11 @@ Status WriteDataToFile(const string& filename, const char* data);
Status WriteDataToFile(const string& filename, const char* data,
const CompressionParams& params);
// Writes the input data into the TFRecord file with the specified compression.
Status WriteDataToTFRecordFile(const string& filename,
const std::vector<absl::string_view>& records,
const CompressionParams& params);
// Helpful functions to test Dataset op kernels.
class DatasetOpsTestBase : public ::testing::Test {
public:

View File: tensorflow/core/kernels/data/reader_dataset_ops.cc

@@ -1,246 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
namespace data {
namespace {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.
constexpr char kTFRecordDatasetName[] = "TFRecord";
class TFRecordDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
const Tensor* filenames_tensor;
OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
OP_REQUIRES(
ctx, filenames_tensor->dims() <= 1,
errors::InvalidArgument("`filenames` must be a scalar or a vector."));
std::vector<string> filenames;
filenames.reserve(filenames_tensor->NumElements());
for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
VLOG(2) << "Reading file: " << filenames_tensor->flat<string>()(i);
filenames.push_back(filenames_tensor->flat<string>()(i));
}
string compression_type;
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
&compression_type));
int64 buffer_size = -1;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
OP_REQUIRES(ctx, buffer_size >= 0,
errors::InvalidArgument(
"`buffer_size` must be >= 0 (0 == no buffering)"));
*output =
new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
const string& compression_type, int64 buffer_size)
: DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
compression_type_(compression_type),
options_(io::RecordReaderOptions::CreateRecordReaderOptions(
compression_type)) {
if (buffer_size > 0) {
options_.buffer_size = buffer_size;
}
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, strings::StrCat(prefix, "::", kTFRecordDatasetName)});
}
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
return *dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
return *shapes;
}
string DebugString() const override { return "TFRecordDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* filenames = nullptr;
TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
Node* compression_type = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
Node* buffer_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size));
TF_RETURN_IF_ERROR(b->AddDataset(
this, {filenames, compression_type, buffer_size}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
do {
// We are currently processing a file, so try to read the next record.
if (reader_) {
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
TensorShape({}));
Status s =
reader_->ReadRecord(&out_tensors->back().scalar<string>()());
if (s.ok()) {
metrics::RecordTFDataBytesRead(
kTFRecordDatasetName,
out_tensors->back().scalar<string>()().size());
*end_of_sequence = false;
return Status::OK();
}
out_tensors->pop_back();
if (!errors::IsOutOfRange(s)) {
// In case of other errors e.g., DataLoss, we still move forward
// the file index so that it works with ignore_errors.
// Otherwise the same file will repeat.
ResetStreamsLocked();
++current_file_index_;
return s;
}
// We have reached the end of the current file, so maybe
// move on to next file.
ResetStreamsLocked();
++current_file_index_;
}
// Iteration ends when there are no more files to process.
if (current_file_index_ == dataset()->filenames_.size()) {
*end_of_sequence = true;
return Status::OK();
}
TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
} while (true);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeSourceNode(std::move(args));
}
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
current_file_index_));
if (reader_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("offset"), reader_->TellOffset()));
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
ResetStreamsLocked();
int64 current_file_index;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
&current_file_index));
current_file_index_ = size_t(current_file_index);
if (reader->Contains(full_name("offset"))) {
int64 offset;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset"), &offset));
TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
TF_RETURN_IF_ERROR(reader_->SeekOffset(offset));
}
return Status::OK();
}
private:
// Sets up reader streams to read from the file at `current_file_index_`.
Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (current_file_index_ >= dataset()->filenames_.size()) {
return errors::InvalidArgument(
"current_file_index_:", current_file_index_,
" >= filenames_.size():", dataset()->filenames_.size());
}
// Actually move on to next file.
const string& next_filename =
dataset()->filenames_[current_file_index_];
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(next_filename, &file_));
reader_ = absl::make_unique<io::SequentialRecordReader>(
file_.get(), dataset()->options_);
return Status::OK();
}
// Resets all reader streams.
void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
reader_.reset();
file_.reset();
}
mutex mu_;
size_t current_file_index_ GUARDED_BY(mu_) = 0;
// `reader_` will borrow the object that `file_` points to, so
// we must destroy `reader_` before `file_`.
std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_);
std::unique_ptr<io::SequentialRecordReader> reader_ GUARDED_BY(mu_);
};
const std::vector<string> filenames_;
const string compression_type_;
io::RecordReaderOptions options_;
};
};
REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
TFRecordDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

View File: tensorflow/core/kernels/data/tf_record_dataset_op.cc

@@ -0,0 +1,251 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/tf_record_dataset_op.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
namespace data {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.
/* static */ constexpr const char* const TFRecordDatasetOp::kDatasetType;
/* static */ constexpr const char* const TFRecordDatasetOp::kFileNames;
/* static */ constexpr const char* const TFRecordDatasetOp::kCompressionType;
/* static */ constexpr const char* const TFRecordDatasetOp::kBufferSize;
constexpr char kCurrentFileIndex[] = "current_file_index";
constexpr char kOffset[] = "offset";
class TFRecordDatasetOp::Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
const string& compression_type, int64 buffer_size)
: DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
compression_type_(compression_type),
options_(io::RecordReaderOptions::CreateRecordReaderOptions(
compression_type)) {
if (buffer_size > 0) {
options_.buffer_size = buffer_size;
}
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
return *dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
return *shapes;
}
string DebugString() const override {
return name_utils::DatasetDebugString(kDatasetType);
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* filenames = nullptr;
TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
Node* compression_type = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
Node* buffer_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size));
TF_RETURN_IF_ERROR(b->AddDataset(
this, {filenames, compression_type, buffer_size}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
do {
// We are currently processing a file, so try to read the next record.
if (reader_) {
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
TensorShape({}));
Status s =
reader_->ReadRecord(&out_tensors->back().scalar<string>()());
if (s.ok()) {
metrics::RecordTFDataBytesRead(
kDatasetType, out_tensors->back().scalar<string>()().size());
*end_of_sequence = false;
return Status::OK();
}
out_tensors->pop_back();
if (!errors::IsOutOfRange(s)) {
// In case of other errors e.g., DataLoss, we still move forward
// the file index so that it works with ignore_errors.
// Otherwise the same file will repeat.
ResetStreamsLocked();
++current_file_index_;
return s;
}
// We have reached the end of the current file, so maybe move on to
// next file.
ResetStreamsLocked();
++current_file_index_;
}
// Iteration ends when there are no more files to process.
if (current_file_index_ == dataset()->filenames_.size()) {
*end_of_sequence = true;
return Status::OK();
}
TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
} while (true);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeSourceNode(std::move(args));
}
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex),
current_file_index_));
if (reader_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kOffset), reader_->TellOffset()));
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
ResetStreamsLocked();
int64 current_file_index;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentFileIndex),
&current_file_index));
current_file_index_ = size_t(current_file_index);
if (reader->Contains(full_name(kOffset))) {
int64 offset;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kOffset), &offset));
TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
TF_RETURN_IF_ERROR(reader_->SeekOffset(offset));
}
return Status::OK();
}
private:
// Sets up reader streams to read from the file at `current_file_index_`.
Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (current_file_index_ >= dataset()->filenames_.size()) {
return errors::InvalidArgument(
"current_file_index_:", current_file_index_,
" >= filenames_.size():", dataset()->filenames_.size());
}
// Actually move on to next file.
const string& next_filename = dataset()->filenames_[current_file_index_];
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(next_filename, &file_));
reader_ = absl::make_unique<io::SequentialRecordReader>(
file_.get(), dataset()->options_);
return Status::OK();
}
// Resets all reader streams.
void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
reader_.reset();
file_.reset();
}
mutex mu_;
size_t current_file_index_ GUARDED_BY(mu_) = 0;
// `reader_` will borrow the object that `file_` points to, so
// we must destroy `reader_` before `file_`.
std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_);
std::unique_ptr<io::SequentialRecordReader> reader_ GUARDED_BY(mu_);
};
const std::vector<string> filenames_;
const string compression_type_;
io::RecordReaderOptions options_;
};
TFRecordDatasetOp::TFRecordDatasetOp(OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {}
void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
const Tensor* filenames_tensor;
OP_REQUIRES_OK(ctx, ctx->input(kFileNames, &filenames_tensor));
OP_REQUIRES(
ctx, filenames_tensor->dims() <= 1,
errors::InvalidArgument("`filenames` must be a scalar or a vector."));
std::vector<string> filenames;
filenames.reserve(filenames_tensor->NumElements());
for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
VLOG(2) << "Reading file: " << filenames_tensor->flat<string>()(i);
filenames.push_back(filenames_tensor->flat<string>()(i));
}
string compression_type;
OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, kCompressionType,
&compression_type));
int64 buffer_size = -1;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
OP_REQUIRES(ctx, buffer_size >= 0,
errors::InvalidArgument(
"`buffer_size` must be >= 0 (0 == no buffering)"));
*output =
new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
}
namespace {
REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
TFRecordDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

View File: tensorflow/core/kernels/data/tf_record_dataset_op.h

@@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
class TFRecordDatasetOp : public DatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "TFRecord";
static constexpr const char* const kFileNames = "filenames";
static constexpr const char* const kCompressionType = "compression_type";
static constexpr const char* const kBufferSize = "buffer_size";
explicit TFRecordDatasetOp(OpKernelConstruction* ctx);
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
private:
class Dataset;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_
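An aside for readers comparing this header with the deleted reader_dataset_ops implementation above: name_utils itself is not part of this diff, so the gtest-style sketch below only records the strings one would expect if the refactor preserves the old behavior (op name from the kernel registration, debug string and iterator prefix from the deleted code); treat the literals as assumptions rather than documented name_utils output.

// Hedged expectations, inferred from this change rather than from name_utils.
TEST(TFRecordDatasetOpNames, ExpectedStrings) {
  // The kernel is registered as "TFRecordDataset" in tf_record_dataset_op.cc.
  EXPECT_EQ(name_utils::OpName(TFRecordDatasetOp::kDatasetType),
            "TFRecordDataset");
  // The deleted DebugString() returned "TFRecordDatasetOp::Dataset".
  EXPECT_EQ(name_utils::DatasetDebugString(TFRecordDatasetOp::kDatasetType),
            "TFRecordDatasetOp::Dataset");
  // The deleted code built the prefix as StrCat(prefix, "::", "TFRecord").
  EXPECT_EQ(name_utils::IteratorPrefix(TFRecordDatasetOp::kDatasetType,
                                       "Iterator"),
            "Iterator::TFRecord");
}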

View File: tensorflow/core/kernels/data/tf_record_dataset_op_test.cc

@@ -0,0 +1,604 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/tf_record_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNodeName[] = "tf_record_dataset";
constexpr char kIteratorPrefix[] = "Iterator";
class TFRecordDatasetOpTest : public DatasetOpsTestBase {
protected:
// Create a new `TFRecordDataset` op kernel.
Status CreateTFRecordDatasetOpKernel(
std::unique_ptr<OpKernel>* tf_record_dataset_op_kernel) {
NodeDef node_def = test::function::NDef(
kNodeName, name_utils::OpName(TFRecordDatasetOp::kDatasetType),
{TFRecordDatasetOp::kFileNames, TFRecordDatasetOp::kCompressionType,
TFRecordDatasetOp::kBufferSize},
{});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tf_record_dataset_op_kernel));
return Status::OK();
}
// Create a new `TFRecordDataset` op kernel context
Status CreateTFRecordDatasetContext(
OpKernel* const op_kernel,
gtl::InlinedVector<TensorValue, 4>* const inputs,
std::unique_ptr<OpKernelContext>* context) {
TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK();
}
};
struct TestCase {
std::vector<string> filenames;
std::vector<std::vector<string>> contents;
CompressionType compression_type;
int64 buffer_size;
std::vector<Tensor> expected_outputs;
DataTypeVector expected_output_dtypes;
std::vector<PartialTensorShape> expected_output_shapes;
int64 expected_cardinality;
std::vector<int> breakpoints;
};
Status CreateTestFiles(const TestCase& test_case) {
if (test_case.filenames.size() != test_case.contents.size()) {
return tensorflow::errors::InvalidArgument(
"The number of files does not match with the contents");
}
CompressionParams params;
params.compression_type = test_case.compression_type;
params.input_buffer_size = test_case.buffer_size;
for (int i = 0; i < test_case.filenames.size(); ++i) {
std::vector<absl::string_view> records(test_case.contents[i].begin(),
test_case.contents[i].end());
TF_RETURN_IF_ERROR(
WriteDataToTFRecordFile(test_case.filenames[i], records, params));
}
return Status::OK();
}
// Test case 1: multiple TFRecord files with ZLIB compression.
TestCase TestCase1() {
return {/*filenames*/ {absl::StrCat(testing::TmpDir(), "/tf_record_ZLIB_1"),
absl::StrCat(testing::TmpDir(), "/tf_record_ZLIB_2")},
/*contents*/
{{"1", "22", "333"}, {"a", "bb", "ccc"}},
/*compression_type*/ CompressionType::ZLIB,
/*buffer_size*/ 10,
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
/*expected_output_dtypes*/ {DT_STRING},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ kUnknownCardinality,
/*breakpoints*/ {0, 2, 7}};
}
// Test case 2: multiple TFRecord files with GZIP compression.
TestCase TestCase2() {
return {/*filenames*/ {absl::StrCat(testing::TmpDir(), "/tf_record_GZIP_1"),
absl::StrCat(testing::TmpDir(), "/tf_record_GZIP_2")},
/*contents*/
{{"1", "22", "333"}, {"a", "bb", "ccc"}},
/*compression_type*/ CompressionType::GZIP,
/*buffer_size*/ 10,
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
/*expected_output_dtypes*/ {DT_STRING},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ kUnknownCardinality,
/*breakpoints*/ {0, 2, 7}};
}
// Test case 3: multiple TFRecord files without compression.
TestCase TestCase3() {
return {/*filenames*/ {
absl::StrCat(testing::TmpDir(), "/tf_record_UNCOMPRESSED_1"),
absl::StrCat(testing::TmpDir(), "/tf_record_UNCOMPRESSED_2")},
/*contents*/
{{"1", "22", "333"}, {"a", "bb", "ccc"}},
/*compression_type*/ CompressionType::UNCOMPRESSED,
/*buffer_size*/ 10,
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
/*expected_output_dtypes*/ {DT_STRING},
/*expected_output_shapes*/ {PartialTensorShape({})},
/*expected_cardinality*/ kUnknownCardinality,
/*breakpoints*/ {0, 2, 7}};
}
class ParameterizedTFRecordDatasetOpTest
: public TFRecordDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParameterizedTFRecordDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(
CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
kIteratorPrefix, &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
while (!end_of_sequence) {
std::vector<Tensor> next;
TF_EXPECT_OK(
iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*compare_order*/ true));
}
TEST_F(TFRecordDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = TestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
EXPECT_EQ(tf_record_dataset->node_name(), kNodeName);
}
TEST_F(TFRecordDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = TestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
EXPECT_EQ(tf_record_dataset->type_string(),
name_utils::OpName(TFRecordDatasetOp::kDatasetType));
}
TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
TF_EXPECT_OK(VerifyTypesMatch(tf_record_dataset->output_dtypes(),
test_case.expected_output_dtypes));
}
TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
TF_EXPECT_OK(VerifyShapesCompatible(tf_record_dataset->output_shapes(),
test_case.expected_output_shapes));
}
TEST_P(ParameterizedTFRecordDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
EXPECT_EQ(tf_record_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(tf_record_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(
CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
kIteratorPrefix, &iterator));
TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
test_case.expected_output_dtypes));
}
TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(
CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
kIteratorPrefix, &iterator));
TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
test_case.expected_output_shapes));
}
TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(
CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
kIteratorPrefix, &iterator));
EXPECT_EQ(iterator->prefix(),
name_utils::IteratorPrefix(TFRecordDatasetOp::kDatasetType,
kIteratorPrefix));
}
TEST_P(ParameterizedTFRecordDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(
CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
kIteratorPrefix, &iterator));
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_iteration = 0;
const std::vector<int>& breakpoints = test_case.breakpoints;
for (int breakpoint : breakpoints) {
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, kIteratorPrefix,
*tf_record_dataset, &iterator));
while (cur_iteration <= breakpoint) {
std::vector<Tensor> next;
TF_EXPECT_OK(
iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
cur_iteration++;
}
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*compare_order*/ true));
}
INSTANTIATE_TEST_SUITE_P(TFRecordDatasetOpTest,
ParameterizedTFRecordDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{TestCase1(), TestCase2(), TestCase3()})));
} // namespace
} // namespace data
} // namespace tensorflow