Merge pull request #30723 from feihugis:Refactor_DatasetOps_8
PiperOrigin-RevId: 258254881
commit 2beeba1eb6
tensorflow/core/kernels/data/BUILD

@@ -972,18 +972,6 @@ tf_cc_test(
    ],
)

tf_kernel_library(
    name = "reader_dataset_ops",
    srcs = ["reader_dataset_ops.cc"],
    deps = [
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

tf_kernel_library(
    name = "text_line_dataset_op",
    srcs = ["text_line_dataset_op.cc"],

@@ -1050,6 +1038,39 @@ tf_cc_test(
    ],
)

tf_kernel_library(
    name = "tf_record_dataset_op",
    srcs = ["tf_record_dataset_op.cc"],
    hdrs = ["tf_record_dataset_op.h"],
    deps = [
        ":name_utils",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

tf_cc_test(
    name = "tf_record_dataset_op_test",
    size = "small",
    srcs = ["tf_record_dataset_op_test.cc"],
    deps = [
        ":dataset_test_base",
        ":dataset_utils",
        ":iterator_ops",
        ":tf_record_dataset_op",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

tf_kernel_library(
    name = "iterator_ops",
    srcs = ["iterator_ops.cc"],

@@ -1242,7 +1263,6 @@ tf_kernel_library(
        ":parallel_map_dataset_op",
        ":prefetch_dataset_op",
        ":range_dataset_op",
        ":reader_dataset_ops",
        ":repeat_dataset_op",
        ":shard_dataset_op",
        ":shuffle_dataset_op",

@@ -1252,6 +1272,7 @@ tf_kernel_library(
        ":tensor_dataset_op",
        ":tensor_slice_dataset_op",
        ":text_line_dataset_op",
        ":tf_record_dataset_op",
        ":window_dataset_op",
        ":zip_dataset_op",
        "//tensorflow/core:array_ops_op_lib",
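Net effect of the BUILD changes above: the reader_dataset_ops library, whose only remaining source was the TFRecordDataset kernel, is deleted; a self-contained tf_record_dataset_op target (now with a public header) plus a matching tf_record_dataset_op_test take its place, and the umbrella dataset kernel library swaps its ":reader_dataset_ops" dependency for ":tf_record_dataset_op".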
tensorflow/core/kernels/data/dataset_test_base.cc

@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/io/record_writer.h"

namespace tensorflow {
namespace data {

@@ -72,6 +73,7 @@ Status WriteDataToFile(const string& filename, const char* data,
                             zlib_compression_options);
    TF_RETURN_IF_ERROR(out.Init());
    TF_RETURN_IF_ERROR(out.Append(data));
    TF_RETURN_IF_ERROR(out.Flush());
    TF_RETURN_IF_ERROR(out.Close());
  } else {
    return tensorflow::errors::InvalidArgument(

@@ -84,6 +86,26 @@ Status WriteDataToFile(const string& filename, const char* data,
  return Status::OK();
}

Status WriteDataToTFRecordFile(const string& filename,
                               const std::vector<absl::string_view>& records,
                               const CompressionParams& params) {
  Env* env = Env::Default();
  std::unique_ptr<WritableFile> file_writer;
  TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file_writer));
  auto options = io::RecordWriterOptions::CreateRecordWriterOptions(
      ToString(params.compression_type));
  options.zlib_options.input_buffer_size = params.input_buffer_size;
  io::RecordWriter record_writer(file_writer.get(), options);
  for (const auto& record : records) {
    TF_RETURN_IF_ERROR(record_writer.WriteRecord(record));
  }
  TF_RETURN_IF_ERROR(record_writer.Flush());
  TF_RETURN_IF_ERROR(record_writer.Close());
  TF_RETURN_IF_ERROR(file_writer->Flush());
  TF_RETURN_IF_ERROR(file_writer->Close());
  return Status::OK();
}

template <typename T>
Status IsEqual(const Tensor& t1, const Tensor& t2) {
  if (t1.dtype() != t2.dtype()) {
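As a usage sketch of the new helper (mirroring CreateTestFiles in the test added below; CompressionParams, CompressionType, ToString, and testing::TmpDir are the existing test utilities from dataset_test_base, so treat the exact spelling as an assumption):

  // Sketch: write three ZLIB-compressed records into a temp TFRecord file.
  CompressionParams params;
  params.compression_type = CompressionType::ZLIB;
  params.input_buffer_size = 10;
  std::vector<absl::string_view> records = {"1", "22", "333"};
  TF_RETURN_IF_ERROR(WriteDataToTFRecordFile(
      absl::StrCat(testing::TmpDir(), "/tf_record_ZLIB_1"), records, params));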
tensorflow/core/kernels/data/dataset_test_base.h

@@ -71,6 +71,11 @@ Status WriteDataToFile(const string& filename, const char* data);
Status WriteDataToFile(const string& filename, const char* data,
                       const CompressionParams& params);

// Writes the input data into the TFRecord file with the specified compression.
Status WriteDataToTFRecordFile(const string& filename,
                               const std::vector<absl::string_view>& records,
                               const CompressionParams& params);

// Helpful functions to test Dataset op kernels.
class DatasetOpsTestBase : public ::testing::Test {
 public:
tensorflow/core/kernels/data/reader_dataset_ops.cc (deleted)

@@ -1,246 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

constexpr char kTFRecordDatasetName[] = "TFRecord";

class TFRecordDatasetOp : public DatasetOpKernel {
 public:
  using DatasetOpKernel::DatasetOpKernel;

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    const Tensor* filenames_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
    OP_REQUIRES(
        ctx, filenames_tensor->dims() <= 1,
        errors::InvalidArgument("`filenames` must be a scalar or a vector."));

    std::vector<string> filenames;
    filenames.reserve(filenames_tensor->NumElements());
    for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
      VLOG(2) << "Reading file: " << filenames_tensor->flat<string>()(i);
      filenames.push_back(filenames_tensor->flat<string>()(i));
    }

    string compression_type;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
                                                    &compression_type));

    int64 buffer_size = -1;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
    OP_REQUIRES(ctx, buffer_size >= 0,
                errors::InvalidArgument(
                    "`buffer_size` must be >= 0 (0 == no buffering)"));

    *output =
        new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
                     const string& compression_type, int64 buffer_size)
        : DatasetBase(DatasetContext(ctx)),
          filenames_(std::move(filenames)),
          compression_type_(compression_type),
          options_(io::RecordReaderOptions::CreateRecordReaderOptions(
              compression_type)) {
      if (buffer_size > 0) {
        options_.buffer_size = buffer_size;
      }
    }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(Iterator::Params{
          this, strings::StrCat(prefix, "::", kTFRecordDatasetName)});
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
      return *dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}});
      return *shapes;
    }

    string DebugString() const override { return "TFRecordDatasetOp::Dataset"; }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* filenames = nullptr;
      TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
      Node* compression_type = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
      Node* buffer_size = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {filenames, compression_type, buffer_size}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        do {
          // We are currently processing a file, so try to read the next record.
          if (reader_) {
            out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                      TensorShape({}));
            Status s =
                reader_->ReadRecord(&out_tensors->back().scalar<string>()());
            if (s.ok()) {
              metrics::RecordTFDataBytesRead(
                  kTFRecordDatasetName,
                  out_tensors->back().scalar<string>()().size());
              *end_of_sequence = false;
              return Status::OK();
            }
            out_tensors->pop_back();
            if (!errors::IsOutOfRange(s)) {
              // In case of other errors e.g., DataLoss, we still move forward
              // the file index so that it works with ignore_errors.
              // Otherwise the same file will repeat.
              ResetStreamsLocked();
              ++current_file_index_;
              return s;
            }

            // We have reached the end of the current file, so maybe
            // move on to next file.
            ResetStreamsLocked();
            ++current_file_index_;
          }

          // Iteration ends when there are no more files to process.
          if (current_file_index_ == dataset()->filenames_.size()) {
            *end_of_sequence = true;
            return Status::OK();
          }

          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
        } while (true);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeSourceNode(std::move(args));
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
                                               current_file_index_));

        if (reader_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("offset"), reader_->TellOffset()));
        }
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        ResetStreamsLocked();
        int64 current_file_index;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
                                              &current_file_index));
        current_file_index_ = size_t(current_file_index);
        if (reader->Contains(full_name("offset"))) {
          int64 offset;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset"), &offset));
          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
          TF_RETURN_IF_ERROR(reader_->SeekOffset(offset));
        }
        return Status::OK();
      }

     private:
      // Sets up reader streams to read from the file at `current_file_index_`.
      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (current_file_index_ >= dataset()->filenames_.size()) {
          return errors::InvalidArgument(
              "current_file_index_:", current_file_index_,
              " >= filenames_.size():", dataset()->filenames_.size());
        }

        // Actually move on to next file.
        const string& next_filename =
            dataset()->filenames_[current_file_index_];
        TF_RETURN_IF_ERROR(env->NewRandomAccessFile(next_filename, &file_));
        reader_ = absl::make_unique<io::SequentialRecordReader>(
            file_.get(), dataset()->options_);
        return Status::OK();
      }

      // Resets all reader streams.
      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        reader_.reset();
        file_.reset();
      }

      mutex mu_;
      size_t current_file_index_ GUARDED_BY(mu_) = 0;

      // `reader_` will borrow the object that `file_` points to, so
      // we must destroy `reader_` before `file_`.
      std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_);
      std::unique_ptr<io::SequentialRecordReader> reader_ GUARDED_BY(mu_);
    };

    const std::vector<string> filenames_;
    const string compression_type_;
    io::RecordReaderOptions options_;
  };
};

REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
                        TFRecordDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow
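The 246 lines removed above are the old reader_dataset_ops.cc, where TFRecordDatasetOp lived as an anonymous class; the three new files below re-land it as a named kernel with a public header, which is what lets the new unit test reference TFRecordDatasetOp::kDatasetType and the input-name constants directly.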
tensorflow/core/kernels/data/tf_record_dataset_op.cc (new file, 251 lines)
@@ -0,0 +1,251 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/tf_record_dataset_op.h"

#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

/* static */ constexpr const char* const TFRecordDatasetOp::kDatasetType;
/* static */ constexpr const char* const TFRecordDatasetOp::kFileNames;
/* static */ constexpr const char* const TFRecordDatasetOp::kCompressionType;
/* static */ constexpr const char* const TFRecordDatasetOp::kBufferSize;

constexpr char kCurrentFileIndex[] = "current_file_index";
constexpr char kOffset[] = "offset";

class TFRecordDatasetOp::Dataset : public DatasetBase {
 public:
  explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
                   const string& compression_type, int64 buffer_size)
      : DatasetBase(DatasetContext(ctx)),
        filenames_(std::move(filenames)),
        compression_type_(compression_type),
        options_(io::RecordReaderOptions::CreateRecordReaderOptions(
            compression_type)) {
    if (buffer_size > 0) {
      options_.buffer_size = buffer_size;
    }
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
    return *dtypes;
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    static std::vector<PartialTensorShape>* shapes =
        new std::vector<PartialTensorShape>({{}});
    return *shapes;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* filenames = nullptr;
    TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
    Node* compression_type = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
    Node* buffer_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {filenames, compression_type, buffer_size}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      do {
        // We are currently processing a file, so try to read the next record.
        if (reader_) {
          out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                    TensorShape({}));
          Status s =
              reader_->ReadRecord(&out_tensors->back().scalar<string>()());
          if (s.ok()) {
            metrics::RecordTFDataBytesRead(
                kDatasetType, out_tensors->back().scalar<string>()().size());
            *end_of_sequence = false;
            return Status::OK();
          }
          out_tensors->pop_back();
          if (!errors::IsOutOfRange(s)) {
            // In case of other errors e.g., DataLoss, we still move forward
            // the file index so that it works with ignore_errors.
            // Otherwise the same file will repeat.
            ResetStreamsLocked();
            ++current_file_index_;
            return s;
          }

          // We have reached the end of the current file, so maybe move on to
          // next file.
          ResetStreamsLocked();
          ++current_file_index_;
        }

        // Iteration ends when there are no more files to process.
        if (current_file_index_ == dataset()->filenames_.size()) {
          *end_of_sequence = true;
          return Status::OK();
        }

        TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
      } while (true);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeSourceNode(std::move(args));
    }

    Status SaveInternal(IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex),
                                             current_file_index_));

      if (reader_) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kOffset), reader_->TellOffset()));
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      ResetStreamsLocked();
      int64 current_file_index;
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentFileIndex),
                                            &current_file_index));
      current_file_index_ = size_t(current_file_index);
      if (reader->Contains(full_name(kOffset))) {
        int64 offset;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kOffset), &offset));
        TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
        TF_RETURN_IF_ERROR(reader_->SeekOffset(offset));
      }
      return Status::OK();
    }

   private:
    // Sets up reader streams to read from the file at `current_file_index_`.
    Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (current_file_index_ >= dataset()->filenames_.size()) {
        return errors::InvalidArgument(
            "current_file_index_:", current_file_index_,
            " >= filenames_.size():", dataset()->filenames_.size());
      }

      // Actually move on to next file.
      const string& next_filename = dataset()->filenames_[current_file_index_];
      TF_RETURN_IF_ERROR(env->NewRandomAccessFile(next_filename, &file_));
      reader_ = absl::make_unique<io::SequentialRecordReader>(
          file_.get(), dataset()->options_);
      return Status::OK();
    }

    // Resets all reader streams.
    void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      reader_.reset();
      file_.reset();
    }

    mutex mu_;
    size_t current_file_index_ GUARDED_BY(mu_) = 0;

    // `reader_` will borrow the object that `file_` points to, so
    // we must destroy `reader_` before `file_`.
    std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_);
    std::unique_ptr<io::SequentialRecordReader> reader_ GUARDED_BY(mu_);
  };

  const std::vector<string> filenames_;
  const string compression_type_;
  io::RecordReaderOptions options_;
};

TFRecordDatasetOp::TFRecordDatasetOp(OpKernelConstruction* ctx)
    : DatasetOpKernel(ctx) {}

void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
                                    DatasetBase** output) {
  const Tensor* filenames_tensor;
  OP_REQUIRES_OK(ctx, ctx->input(kFileNames, &filenames_tensor));
  OP_REQUIRES(
      ctx, filenames_tensor->dims() <= 1,
      errors::InvalidArgument("`filenames` must be a scalar or a vector."));

  std::vector<string> filenames;
  filenames.reserve(filenames_tensor->NumElements());
  for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
    VLOG(2) << "Reading file: " << filenames_tensor->flat<string>()(i);
    filenames.push_back(filenames_tensor->flat<string>()(i));
  }

  string compression_type;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, kCompressionType,
                                                  &compression_type));

  int64 buffer_size = -1;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0,
              errors::InvalidArgument(
                  "`buffer_size` must be >= 0 (0 == no buffering)"));

  *output =
      new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
                        TFRecordDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow
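A reading aid for the checkpoint logic above (key names shown without the full_name() scoping the code applies to them):

  // SaveInternal / RestoreInternal persist two scalars:
  //   "current_file_index": which entry of filenames_ the iterator is on;
  //   "offset": the byte offset inside the currently open file, written only
  //       while a file is open and replayed on restore via SeekOffset().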
tensorflow/core/kernels/data/tf_record_dataset_op.h (new file, 42 lines)
@@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_

#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {

class TFRecordDatasetOp : public DatasetOpKernel {
 public:
  static constexpr const char* const kDatasetType = "TFRecord";
  static constexpr const char* const kFileNames = "filenames";
  static constexpr const char* const kCompressionType = "compression_type";
  static constexpr const char* const kBufferSize = "buffer_size";

  explicit TFRecordDatasetOp(OpKernelConstruction* ctx);

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;

 private:
  class Dataset;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_
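For readers who have not followed the earlier PRs in this series: judging from the hand-rolled strings this refactor replaces, the name_utils helpers are presumably equivalent to the following (an inference from this diff, not a documented contract):

  // name_utils::OpName("TFRecord")            -> "TFRecordDataset"
  // name_utils::IteratorPrefix("TFRecord", p) -> strings::StrCat(p, "::", "TFRecord")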
tensorflow/core/kernels/data/tf_record_dataset_op_test.cc (new file, 604 lines)
@@ -0,0 +1,604 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/tf_record_dataset_op.h"

#include "tensorflow/core/kernels/data/dataset_test_base.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kNodeName[] = "tf_record_dataset";
constexpr char kIteratorPrefix[] = "Iterator";

class TFRecordDatasetOpTest : public DatasetOpsTestBase {
 protected:
  // Create a new `TFRecordDataset` op kernel.
  Status CreateTFRecordDatasetOpKernel(
      std::unique_ptr<OpKernel>* tf_record_dataset_op_kernel) {
    NodeDef node_def = test::function::NDef(
        kNodeName, name_utils::OpName(TFRecordDatasetOp::kDatasetType),
        {TFRecordDatasetOp::kFileNames, TFRecordDatasetOp::kCompressionType,
         TFRecordDatasetOp::kBufferSize},
        {});
    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tf_record_dataset_op_kernel));
    return Status::OK();
  }

  // Create a new `TFRecordDataset` op kernel context.
  Status CreateTFRecordDatasetContext(
      OpKernel* const op_kernel,
      gtl::InlinedVector<TensorValue, 4>* const inputs,
      std::unique_ptr<OpKernelContext>* context) {
    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
    TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
    return Status::OK();
  }
};

struct TestCase {
  std::vector<string> filenames;
  std::vector<std::vector<string>> contents;
  CompressionType compression_type;
  int64 buffer_size;
  std::vector<Tensor> expected_outputs;
  DataTypeVector expected_output_dtypes;
  std::vector<PartialTensorShape> expected_output_shapes;
  int64 expected_cardinality;
  std::vector<int> breakpoints;
};

Status CreateTestFiles(const TestCase& test_case) {
  if (test_case.filenames.size() != test_case.contents.size()) {
    return tensorflow::errors::InvalidArgument(
        "The number of files does not match with the contents");
  }

  CompressionParams params;
  params.compression_type = test_case.compression_type;
  params.input_buffer_size = test_case.buffer_size;
  for (int i = 0; i < test_case.filenames.size(); ++i) {
    std::vector<absl::string_view> records(test_case.contents[i].begin(),
                                           test_case.contents[i].end());
    TF_RETURN_IF_ERROR(
        WriteDataToTFRecordFile(test_case.filenames[i], records, params));
  }
  return Status::OK();
}
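The breakpoints field drives the Roundtrip test at the end of the file: after reaching each listed element count, the iterator is saved to a VariantTensorData and restored before iteration resumes, so {0, 2, 7} exercises a restore before any element is produced, one midway through the first three-record file, and one past the six-record end of sequence.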
// Test case 1: multiple TFRecord files with ZLIB compression.
TestCase TestCase1() {
  return {/*filenames*/ {absl::StrCat(testing::TmpDir(), "/tf_record_ZLIB_1"),
                         absl::StrCat(testing::TmpDir(), "/tf_record_ZLIB_2")},
          /*contents*/
          {{"1", "22", "333"}, {"a", "bb", "ccc"}},
          /*compression_type*/ CompressionType::ZLIB,
          /*buffer_size*/ 10,
          /*expected_outputs*/
          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
          /*expected_output_dtypes*/ {DT_STRING},
          /*expected_output_shapes*/ {PartialTensorShape({})},
          /*expected_cardinality*/ kUnknownCardinality,
          /*breakpoints*/ {0, 2, 7}};
}

// Test case 2: multiple TFRecord files with GZIP compression.
TestCase TestCase2() {
  return {/*filenames*/ {absl::StrCat(testing::TmpDir(), "/tf_record_GZIP_1"),
                         absl::StrCat(testing::TmpDir(), "/tf_record_GZIP_2")},
          /*contents*/
          {{"1", "22", "333"}, {"a", "bb", "ccc"}},
          /*compression_type*/ CompressionType::GZIP,
          /*buffer_size*/ 10,
          /*expected_outputs*/
          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
          /*expected_output_dtypes*/ {DT_STRING},
          /*expected_output_shapes*/ {PartialTensorShape({})},
          /*expected_cardinality*/ kUnknownCardinality,
          /*breakpoints*/ {0, 2, 7}};
}

// Test case 3: multiple TFRecord files without compression.
TestCase TestCase3() {
  return {/*filenames*/ {
              absl::StrCat(testing::TmpDir(), "/tf_record_UNCOMPRESSED_1"),
              absl::StrCat(testing::TmpDir(), "/tf_record_UNCOMPRESSED_2")},
          /*contents*/
          {{"1", "22", "333"}, {"a", "bb", "ccc"}},
          /*compression_type*/ CompressionType::UNCOMPRESSED,
          /*buffer_size*/ 10,
          /*expected_outputs*/
          {DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"1"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"22"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"333"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"a"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"bb"}),
           DatasetOpsTestBase::CreateTensor<string>(TensorShape({}), {"ccc"})},
          /*expected_output_dtypes*/ {DT_STRING},
          /*expected_output_shapes*/ {PartialTensorShape({})},
          /*expected_cardinality*/ kUnknownCardinality,
          /*breakpoints*/ {0, 2, 7}};
}

class ParameterizedTFRecordDatasetOpTest
    : public TFRecordDatasetOpTest,
      public ::testing::WithParamInterface<TestCase> {};

TEST_P(ParameterizedTFRecordDatasetOpTest, GetNext) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);

  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(
      CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
                                               kIteratorPrefix, &iterator));
  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  while (!end_of_sequence) {
    std::vector<Tensor> next;
    TF_EXPECT_OK(
        iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
    out_tensors.insert(out_tensors.end(), next.begin(), next.end());
  }

  TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
                           /*compare_order*/ true));
}

TEST_F(TFRecordDatasetOpTest, DatasetNodeName) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = TestCase1();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);
  EXPECT_EQ(tf_record_dataset->node_name(), kNodeName);
}

TEST_F(TFRecordDatasetOpTest, DatasetTypeString) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = TestCase1();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);
  EXPECT_EQ(tf_record_dataset->type_string(),
            name_utils::OpName(TFRecordDatasetOp::kDatasetType));
}

TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputDtypes) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);
  TF_EXPECT_OK(VerifyTypesMatch(tf_record_dataset->output_dtypes(),
                                test_case.expected_output_dtypes));
}

TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputShapes) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);
  TF_EXPECT_OK(VerifyShapesCompatible(tf_record_dataset->output_shapes(),
                                      test_case.expected_output_shapes));
}

TEST_P(ParameterizedTFRecordDatasetOpTest, Cardinality) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);
  EXPECT_EQ(tf_record_dataset->Cardinality(), test_case.expected_cardinality);
}

TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetSave) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);

  std::unique_ptr<SerializationContext> serialization_context;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
  VariantTensorData data;
  VariantTensorDataWriter writer(&data);
  TF_ASSERT_OK(tf_record_dataset->Save(serialization_context.get(), &writer));
  TF_ASSERT_OK(writer.Flush());
}

TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputDtypes) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);

  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(
      CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
                                               kIteratorPrefix, &iterator));

  TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
                                test_case.expected_output_dtypes));
}

TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputShapes) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);

  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(
      CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
                                               kIteratorPrefix, &iterator));

  TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
                                      test_case.expected_output_shapes));
}

TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputPrefix) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);

  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(
      CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
                                               kIteratorPrefix, &iterator));

  EXPECT_EQ(iterator->prefix(),
            name_utils::IteratorPrefix(TFRecordDatasetOp::kDatasetType,
                                       kIteratorPrefix));
}

TEST_P(ParameterizedTFRecordDatasetOpTest, Roundtrip) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));

  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));

  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));

  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);

  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(
      CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
                                               kIteratorPrefix, &iterator));

  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));

  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  int cur_iteration = 0;
  const std::vector<int>& breakpoints = test_case.breakpoints;
  for (int breakpoint : breakpoints) {
    VariantTensorData data;
    VariantTensorDataWriter writer(&data);
    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
    TF_EXPECT_OK(writer.Flush());
    VariantTensorDataReader reader(&data);
    TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, kIteratorPrefix,
                                 *tf_record_dataset, &iterator));

    while (cur_iteration <= breakpoint) {
      std::vector<Tensor> next;
      TF_EXPECT_OK(
          iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
      cur_iteration++;
    }
  }

  TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
                           /*compare_order*/ true));
}

INSTANTIATE_TEST_SUITE_P(TFRecordDatasetOpTest,
                         ParameterizedTFRecordDatasetOpTest,
                         ::testing::ValuesIn(std::vector<TestCase>(
                             {TestCase1(), TestCase2(), TestCase3()})));

}  // namespace
}  // namespace data
}  // namespace tensorflow