Add a new lightweight queue-like object - RecordInput

Change: 144752664
This commit is contained in:
A. Unique TensorFlower 2017-01-17 14:03:09 -08:00 committed by TensorFlower Gardener
parent bf67d0a1c5
commit 66b5684133
8 changed files with 631 additions and 0 deletions

View File

@ -317,6 +317,19 @@ cc_library(
],
)
cc_library(
name = "record_input_op",
srcs = [
"record_input_op.cc",
"record_yielder.cc",
"record_yielder.h",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
cc_library(
name = "save_restore_tensor",
srcs = ["save_restore_tensor.cc"],
@ -1167,6 +1180,7 @@ cc_library(
":priority_queue_op",
":queue_ops",
":random_shuffle_queue_op",
":record_input_op",
":session_ops",
":sparse_conditional_accumulator_op",
":stack_ops",

View File

@ -0,0 +1,67 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/record_yielder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
class RecordInputOp : public OpKernel {
public:
explicit RecordInputOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
#define GETATTR(TYPE, FIELD) \
TYPE FIELD; \
OP_REQUIRES_OK(ctx, ctx->GetAttr(#FIELD, &FIELD));
GETATTR(string, file_pattern);
GETATTR(int64, file_random_seed);
GETATTR(float, file_shuffle_shift_ratio);
GETATTR(int64, file_buffer_size);
GETATTR(int64, file_parallelism);
GETATTR(int64, batch_size);
#undef GETATTR
RecordYielder::Options yopts;
yopts.file_pattern = file_pattern;
yopts.seed = file_random_seed;
yopts.bufsize = file_buffer_size;
yopts.file_shuffle_shift_ratio = file_shuffle_shift_ratio;
yopts.parallelism = file_parallelism;
yielder_ = std::unique_ptr<RecordYielder>(new RecordYielder(ctx, yopts));
batch_size_ = batch_size;
}
void Compute(OpKernelContext* ctx) override {
Tensor out(DT_STRING, {batch_size_});
auto t_out = out.flat<string>();
for (int i = 0; i < batch_size_; ++i) {
OP_REQUIRES_OK(ctx, yielder_->Yield(&t_out(i)));
}
ctx->set_output(0, out);
}
private:
int64 batch_size_;
std::unique_ptr<RecordYielder> yielder_;
};
REGISTER_KERNEL_BUILDER(Name("RecordInput").Device(DEVICE_CPU), RecordInputOp);
} // namespace tensorflow

View File

@ -0,0 +1,216 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/record_yielder.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
RecordYielder::RecordYielder(OpKernelConstruction* context,
const RecordYielder::Options& opts)
: opts_(opts),
thread_(new thread::ThreadPool(context->env(), "record_yielder",
1 + opts.parallelism)),
epoch_(0),
rnd_(opts.seed) {
thread_->Schedule([this]() { MainLoop(); });
}
RecordYielder::~RecordYielder() {
{
mutex_lock l(mu_);
stop_ = true;
buf_empty_.notify_all();
buf_enough_.notify_all();
buf_not_full_.notify_all();
}
main_loop_done_.WaitForNotification();
delete thread_;
}
Status RecordYielder::Yield(string* value) {
mutex_lock l(mu_);
while (!BufEnough()) {
buf_enough_.wait(l);
}
if (status_.ok()) {
bool notify_no_longer_full = !BufNotFull();
CHECK(!stop_ && !buf_.empty());
*value = std::move(buf_.back());
buf_.pop_back();
++num_records_yielded_in_epoch_;
// Assumption is that an epoch always has something in the buffer
// until it ends. If the input pipeline was slower than the consumers
// by a lot this might not be true. Not sure how to handle.
if (buf_.empty()) {
buf_empty_.notify_all();
}
if (notify_no_longer_full) {
buf_not_full_.notify_all();
}
}
return status_;
}
struct RecordYielder::Shard {
int index; // Shard index.
std::vector<string> filenames; // File names given to this shard.
Notification done; // Notified when this shard is done.
Status status; // Shard status.
};
bool RecordYielder::ShouldFinish(const Status& s) {
mutex_lock l(mu_);
status_.Update(s);
return stop_ || !status_.ok();
}
static Status MatchFiles(const string& patterns,
std::vector<string>* filenames) {
for (const auto& file_pattern : str_util::Split(patterns, ',')) {
std::vector<string> tmp_filenames;
TF_RETURN_IF_ERROR(
Env::Default()->GetMatchingPaths(file_pattern, &tmp_filenames));
filenames->insert(filenames->end(),
std::make_move_iterator(tmp_filenames.begin()),
std::make_move_iterator(tmp_filenames.end()));
}
return Status::OK();
}
void RecordYielder::MainLoop() {
while (true) {
++epoch_;
num_records_yielded_in_epoch_ = 0;
// Finds all files.
std::vector<string> filenames;
Status s = MatchFiles(opts_.file_pattern, &filenames);
if (ShouldFinish(s)) break;
if (filenames.empty()) {
s = errors::NotFound("Found no files at ", opts_.file_pattern);
if (ShouldFinish(s)) break;
}
// Shuffles these files according to the epoch # and random seed.
std::mt19937_64 shuffle_rnd(
Hash64(reinterpret_cast<char*>(&epoch_), sizeof(epoch_), opts_.seed));
std::shuffle(filenames.begin(), filenames.end(), shuffle_rnd);
// Left-shift the filename list.
const int64 num = filenames.size();
int64 shift;
if (0 <= opts_.file_shuffle_shift_ratio &&
opts_.file_shuffle_shift_ratio < 1) {
shift = opts_.file_shuffle_shift_ratio * num;
std::rotate(filenames.begin(), filenames.begin() + shift,
filenames.end());
}
// Shards files and use one thread to go through each shard.
const int N = opts_.parallelism;
std::vector<Shard> shards(N);
for (int i = 0; i < N; ++i) {
Shard* shard = &shards[i];
shard->index = i;
for (int j = i; j < filenames.size(); j += N) {
shard->filenames.push_back(filenames[j]);
}
thread_->Schedule([this, shard]() { ShardLoop(shard); });
}
for (int i = 0; i < N; ++i) {
shards[i].done.WaitForNotification();
s.Update(shards[i].status);
}
if (ShouldFinish(s)) break;
// Starts the next epoch once all buffered records are consumed.
{
mutex_lock l(mu_);
epoch_end_ = true;
while (!BufEmpty()) {
buf_empty_.wait(l);
}
epoch_end_ = false;
}
}
main_loop_done_.Notify();
}
bool RecordYielder::Add(std::vector<string>* values) {
mutex_lock l(mu_);
while (!BufNotFull()) {
buf_not_full_.wait(l);
}
while (BufNotFull() && !values->empty()) {
// Adds values->back(). Swaps its position with another random
// element.
auto index = rnd_() % (buf_.size() + 1);
if (index == buf_.size()) {
buf_.push_back(std::move(values->back()));
} else {
buf_.push_back(std::move(buf_[index]));
buf_[index] = std::move(values->back());
}
values->pop_back();
}
if (BufEnough()) {
buf_enough_.notify_all();
}
return stop_;
}
void RecordYielder::ShardLoop(Shard* shard) {
std::vector<string> values;
const int64 kRecords = 16;
for (const string& filename : shard->filenames) {
std::unique_ptr<RandomAccessFile> file;
if (ShouldFinish(Status::OK())) break;
Status s = Env::Default()->NewRandomAccessFile(filename, &file);
if (!s.ok()) {
shard->status = errors::InvalidArgument("Can't open ", filename);
break;
}
io::RecordReader rdr(file.get());
uint64 offset = 0;
string record;
while (true) {
Status s = rdr.ReadRecord(&offset, &record);
if (s.ok()) {
values.emplace_back(std::move(record));
if (values.size() >= kRecords && Add(&values)) {
shard->status = errors::Aborted("stopped");
break;
}
} else if (errors::IsOutOfRange(s)) {
break;
} else {
shard->status = s;
break;
}
}
}
// Adds the remaining values of this shard to buf_.
while (!values.empty()) {
Add(&values);
}
shard->done.Notify();
}
} // namespace tensorflow

View File

@ -0,0 +1,157 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_RECORD_YIELDER_H_
#define TENSORFLOW_KERNELS_RECORD_YIELDER_H_
#include <atomic>
#include <random>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
// RecordYielder produces value records from a set of tfrecord files
// in a random order.
//
// It guarantees that:
// 1) all records in tfrecords are yielded within every epoch;
// 2) each record is yielded only once within every epoch;
// 3) the order in which records are yielded are highly randomized.
// 4) the peak memory usage is roughly avg record size *
// (opts.bufsize + opts.parellelism * 16).
//
// Usage example:
// RecordYielder::Options opts;
// opts.file_pattern = "input-*";
// opts.seed = 301;
// opts.bufsize = 1000000; // A randomized buffer with 1M records.
// opts.parallelism = 8; // Uses 8 tfrecord iterators to iterate
// // through all files.
// RecordYielder yielder(opts);
// string val;
// while (true) {
// yielder.Yield(&val);
// // process val
// }
//
// RecordYielder can be accessed by multiple threads concurrently.
class RecordYielder {
public:
struct Options {
// Glob pattern for tfrecords.
string file_pattern;
// Random seed. It determines how data files are shuffled and how
// records are shuffled.
int64 seed = 0;
// Each epoch, all files are first shuffled according to the
// random seed and the epoch number, and then all files are
// left-shifted by file_shuffle_shift_ratio * num_files slots. If
// file_shuffle_shift_ratio is not within [0, 1), the
// implementation clip it to [0, 1).
float file_shuffle_shift_ratio = 0;
// Randomization buffer keeps these many records.
uint64 bufsize = 1;
// Uses these many concurrent tfrecord iterators to iterate through
// tfrecords.
int32 parallelism = 1;
};
explicit RecordYielder(OpKernelConstruction* context,
const RecordYielder::Options& opts);
~RecordYielder();
RecordYielder(const RecordYielder&) = delete;
RecordYielder& operator=(const RecordYielder&) = delete;
// Yields one 'value'.
Status Yield(string* value);
// Returns the current epoch number.
int64 current_epoch() const { return epoch_; }
private:
typedef RecordYielder ME;
Options opts_;
// Backgrounds threads. Owned.
thread::ThreadPool* thread_;
// Epoch number.
std::atomic<int64> epoch_;
mutex mu_;
// Turned to true when this is deleted.
bool stop_ GUARDED_BY(mu_) = false;
Status status_ GUARDED_BY(mu_);
// PRG used for randomization.
std::mt19937_64 rnd_ GUARDED_BY(mu_);
// Randomization buffer.
std::vector<string> buf_ GUARDED_BY(mu_);
// True iff we are draining an epoch.
bool epoch_end_ = false;
int64 num_records_yielded_in_epoch_ = 0;
// Trigger when the main loop has exited.
Notification main_loop_done_;
// condition_variables.
condition_variable buf_empty_;
bool BufEmpty() const SHARED_LOCKS_REQUIRED(mu_) {
return stop_ || buf_.empty();
}
condition_variable buf_not_full_;
bool BufNotFull() const SHARED_LOCKS_REQUIRED(mu_) {
return stop_ || buf_.size() < opts_.bufsize;
}
condition_variable buf_enough_;
bool BufEnough() const SHARED_LOCKS_REQUIRED(mu_) {
// NOTE: Unless we are finishing an epoch, we want to make sure
// the buf_ contains enough randomized elements before yielding
// any.
return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) ||
(!epoch_end_ &&
buf_.size() >= std::max<int64>(1, opts_.bufsize / 2));
}
void MainLoop();
struct Shard;
void ShardLoop(Shard* shard);
bool ShouldFinish(const Status& s);
bool Add(std::vector<string>* values);
};
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_RECORD_YIELDER_H_

View File

@ -2211,4 +2211,27 @@ dequeue with many fewer capabilities and options. This Op is optimized for
performance.
)doc");
REGISTER_OP("RecordInput")
.Output("records: string")
.Attr("file_pattern: string")
.Attr("file_random_seed: int = 301")
.Attr("file_shuffle_shift_ratio: float = 0")
.Attr("file_buffer_size: int = 10000")
.Attr("file_parallelism: int = 16")
.Attr("batch_size: int = 32")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Emits randomized records.
records: A tensor of shape [batch_size].
file_pattern: Glob pattern for the data files.
file_random_seed: Random seeds used to produce randomized records.
file_shuffle_shift_ratio: Shifts the list of files after the list is randomly
shuffled.
file_buffer_size: The randomization shuffling buffer.
file_parallelism: How many sstables are opened and concurrently iterated over.
batch_size: The batch size.
)doc");
} // namespace tensorflow

View File

@ -341,6 +341,18 @@ tf_py_test(
],
)
tf_py_test(
name = "record_input_test",
size = "small",
srcs = ["record_input_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:util",
],
)
tf_py_test(
name = "io_ops_test",
size = "small",

View File

@ -0,0 +1,80 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for record_input_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import test
class RecordInputOpTest(test.TestCase):
def generateTestData(self, prefix, n, m):
for i in range(n):
f = os.path.join(self.get_temp_dir(), prefix + "." + str(i))
w = tf_record.TFRecordWriter(f)
for j in range(m):
w.write("{0:0{width}}".format(i * m + j, width=10))
w.close()
def testRecordInputSimple(self):
with self.test_session() as sess:
self.generateTestData("basic", 1, 1)
yield_op = data_flow_ops.RecordInput(
file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
parallelism=1,
buffer_size=1,
batch_size=1,
name="record_input").get_yield_op()
self.assertEqual(sess.run(yield_op), "0000000000")
def testRecordInputEpochs(self):
files = 100
records_per_file = 100
with self.test_session() as sess:
self.generateTestData("basic", files, records_per_file)
records = data_flow_ops.RecordInput(
file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
parallelism=2,
buffer_size=2000,
batch_size=1,
shift_ratio=0.33,
seed=10,
name="record_input")
yield_op = records.get_yield_op()
# cycle over 3 epochs and make sure we never duplicate
for _ in range(3):
epoch_set = set()
for _ in range(files * records_per_file):
r = sess.run(yield_op)
self.assertTrue(r[0] not in epoch_set)
epoch_set.add(r[0])
if __name__ == "__main__":
test.main()

View File

@ -1613,3 +1613,65 @@ class StagingArea(object):
output.set_shape(shape)
return self._get_return_value(ret)
class RecordInput(object):
"""RecordInput asynchronously reads and randomly yields TFRecords.
A RecordInput Op will continuously read a batch of records asynchronously
into a buffer of some fixed capacity. It can also asynchronously yield
random records from this buffer.
It will not start yielding until at least `buffer_size / 2` elements have been
placed into the buffer so that sufficient randomization can take place.
The order the files are read will be shifted each epoch by `shift_amount` so
that the data is presented in a different order every epoch.
"""
def __init__(self,
file_pattern,
batch_size=1,
buffer_size=1,
parallelism=1,
shift_ratio=0,
seed=0,
name=None):
"""Constructs a RecordInput Op.
Args:
file_pattern: File path to the dataset, possibly containing wildcards.
All matching files will be iterated over each epoch.
batch_size: How many records to return at a time.
buffer_size: The maximum number of records the buffer will contain. This
_must_ be smaller than the total number of records in an epoch or
deadlock can occur.
parallelism: How many reader threads to use for reading from files.
shift_ratio: What percentage of the total number files to move the start
file forward by each epoch.
seed: Specify the random number seed used by generator that randomizes
records.
name: Optional name for the operation.
Raises:
ValueError: If one of the arguments is invalid.
"""
self._batch_size = batch_size
self._file_pattern = file_pattern
self._buffer_size = buffer_size
self._parallelism = parallelism
self._shift_ratio = shift_ratio
self._seed = seed
self._name = name
def get_yield_op(self):
"""Add a node that yields a minibatch every time it is executed."""
return gen_data_flow_ops.record_input(
file_pattern=self._file_pattern,
file_buffer_size=self._buffer_size,
file_parallelism=self._parallelism,
file_shuffle_shift_ratio=self._shift_ratio,
batch_size=self._batch_size,
file_random_seed=self._seed,
name=self._name)