From 66b5684133bda0a3050e1573f747d86c645dfd67 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 17 Jan 2017 14:03:09 -0800
Subject: [PATCH] Add a new lightweight queue-like object - RecordInput

Change: 144752664
---
 tensorflow/core/kernels/BUILD                  |  14 ++
 tensorflow/core/kernels/record_input_op.cc     |  67 ++++++
 tensorflow/core/kernels/record_yielder.cc      | 216 ++++++++++++++++++
 tensorflow/core/kernels/record_yielder.h       | 157 +++++++++++++
 tensorflow/core/ops/data_flow_ops.cc           |  23 ++
 tensorflow/python/kernel_tests/BUILD           |  12 +
 .../python/kernel_tests/record_input_test.py   |  80 +++++++
 tensorflow/python/ops/data_flow_ops.py         |  62 +++++
 8 files changed, 631 insertions(+)
 create mode 100644 tensorflow/core/kernels/record_input_op.cc
 create mode 100644 tensorflow/core/kernels/record_yielder.cc
 create mode 100644 tensorflow/core/kernels/record_yielder.h
 create mode 100644 tensorflow/python/kernel_tests/record_input_test.py

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 5137156107f..ac9af1c7932 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -317,6 +317,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "record_input_op",
+    srcs = [
+        "record_input_op.cc",
+        "record_yielder.cc",
+        "record_yielder.h",
+    ],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
 cc_library(
     name = "save_restore_tensor",
     srcs = ["save_restore_tensor.cc"],
@@ -1167,6 +1180,7 @@ cc_library(
         ":priority_queue_op",
         ":queue_ops",
         ":random_shuffle_queue_op",
+        ":record_input_op",
         ":session_ops",
         ":sparse_conditional_accumulator_op",
         ":stack_ops",
diff --git a/tensorflow/core/kernels/record_input_op.cc b/tensorflow/core/kernels/record_input_op.cc
new file mode 100644
index 00000000000..60c0a7d2d8c
--- /dev/null
+++ b/tensorflow/core/kernels/record_input_op.cc
@@ -0,0 +1,67 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/record_yielder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+class RecordInputOp : public OpKernel {
+ public:
+  explicit RecordInputOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+#define GETATTR(TYPE, FIELD) \
+  TYPE FIELD;                \
+  OP_REQUIRES_OK(ctx, ctx->GetAttr(#FIELD, &FIELD));
+
+    GETATTR(string, file_pattern);
+    GETATTR(int64, file_random_seed);
+    GETATTR(float, file_shuffle_shift_ratio);
+    GETATTR(int64, file_buffer_size);
+    GETATTR(int64, file_parallelism);
+    GETATTR(int64, batch_size);
+#undef GETATTR
+
+    RecordYielder::Options yopts;
+    yopts.file_pattern = file_pattern;
+    yopts.seed = file_random_seed;
+    yopts.bufsize = file_buffer_size;
+    yopts.file_shuffle_shift_ratio = file_shuffle_shift_ratio;
+    yopts.parallelism = file_parallelism;
+    yielder_ = std::unique_ptr<RecordYielder>(new RecordYielder(ctx, yopts));
+
+    batch_size_ = batch_size;
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor out(DT_STRING, {batch_size_});
+    auto t_out = out.flat<string>();
+    for (int i = 0; i < batch_size_; ++i) {
+      OP_REQUIRES_OK(ctx, yielder_->Yield(&t_out(i)));
+    }
+    ctx->set_output(0, out);
+  }
+
+ private:
+  int64 batch_size_;
+  std::unique_ptr<RecordYielder> yielder_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("RecordInput").Device(DEVICE_CPU), RecordInputOp);
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/record_yielder.cc b/tensorflow/core/kernels/record_yielder.cc
new file mode 100644
index 00000000000..d65ba2c0260
--- /dev/null
+++ b/tensorflow/core/kernels/record_yielder.cc
@@ -0,0 +1,216 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/record_yielder.h"
+
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+RecordYielder::RecordYielder(OpKernelConstruction* context,
+                             const RecordYielder::Options& opts)
+    : opts_(opts),
+      thread_(new thread::ThreadPool(context->env(), "record_yielder",
+                                     1 + opts.parallelism)),
+      epoch_(0),
+      rnd_(opts.seed) {
+  thread_->Schedule([this]() { MainLoop(); });
+}
+
+RecordYielder::~RecordYielder() {
+  {
+    mutex_lock l(mu_);
+    stop_ = true;
+    buf_empty_.notify_all();
+    buf_enough_.notify_all();
+    buf_not_full_.notify_all();
+  }
+  main_loop_done_.WaitForNotification();
+  delete thread_;
+}
+
+Status RecordYielder::Yield(string* value) {
+  mutex_lock l(mu_);
+  while (!BufEnough()) {
+    buf_enough_.wait(l);
+  }
+  if (status_.ok()) {
+    bool notify_no_longer_full = !BufNotFull();
+    CHECK(!stop_ && !buf_.empty());
+    *value = std::move(buf_.back());
+    buf_.pop_back();
+    ++num_records_yielded_in_epoch_;
+    // Assumption is that an epoch always has something in the buffer
+    // until it ends. If the input pipeline is much slower than the
+    // consumers, this might not be true. Not sure how to handle.
+    if (buf_.empty()) {
+      buf_empty_.notify_all();
+    }
+    if (notify_no_longer_full) {
+      buf_not_full_.notify_all();
+    }
+  }
+  return status_;
+}
+
+struct RecordYielder::Shard {
+  int index;                      // Shard index.
+  std::vector<string> filenames;  // File names given to this shard.
+  Notification done;              // Notified when this shard is done.
+  Status status;                  // Shard status.
+};
+
+bool RecordYielder::ShouldFinish(const Status& s) {
+  mutex_lock l(mu_);
+  status_.Update(s);
+  return stop_ || !status_.ok();
+}
+
+static Status MatchFiles(const string& patterns,
+                         std::vector<string>* filenames) {
+  for (const auto& file_pattern : str_util::Split(patterns, ',')) {
+    std::vector<string> tmp_filenames;
+    TF_RETURN_IF_ERROR(
+        Env::Default()->GetMatchingPaths(file_pattern, &tmp_filenames));
+    filenames->insert(filenames->end(),
+                      std::make_move_iterator(tmp_filenames.begin()),
+                      std::make_move_iterator(tmp_filenames.end()));
+  }
+  return Status::OK();
+}
+
+void RecordYielder::MainLoop() {
+  while (true) {
+    ++epoch_;
+    num_records_yielded_in_epoch_ = 0;
+
+    // Finds all files.
+    std::vector<string> filenames;
+    Status s = MatchFiles(opts_.file_pattern, &filenames);
+    if (ShouldFinish(s)) break;
+
+    if (filenames.empty()) {
+      s = errors::NotFound("Found no files at ", opts_.file_pattern);
+      if (ShouldFinish(s)) break;
+    }
+
+    // Shuffles these files according to the epoch # and random seed.
+    std::mt19937_64 shuffle_rnd(
+        Hash64(reinterpret_cast<char*>(&epoch_), sizeof(epoch_), opts_.seed));
+    std::shuffle(filenames.begin(), filenames.end(), shuffle_rnd);
+
+    // Left-shift the filename list.
+    const int64 num = filenames.size();
+    int64 shift;
+    if (0 <= opts_.file_shuffle_shift_ratio &&
+        opts_.file_shuffle_shift_ratio < 1) {
+      shift = opts_.file_shuffle_shift_ratio * num;
+      std::rotate(filenames.begin(), filenames.begin() + shift,
+                  filenames.end());
+    }
+
+    // Shards files and uses one thread to go through each shard.
+    const int N = opts_.parallelism;
+    std::vector<Shard> shards(N);
+    for (int i = 0; i < N; ++i) {
+      Shard* shard = &shards[i];
+      shard->index = i;
+      for (int j = i; j < filenames.size(); j += N) {
+        shard->filenames.push_back(filenames[j]);
+      }
+      thread_->Schedule([this, shard]() { ShardLoop(shard); });
+    }
+    for (int i = 0; i < N; ++i) {
+      shards[i].done.WaitForNotification();
+      s.Update(shards[i].status);
+    }
+    if (ShouldFinish(s)) break;
+
+    // Starts the next epoch once all buffered records are consumed.
+    {
+      mutex_lock l(mu_);
+      epoch_end_ = true;
+      while (!BufEmpty()) {
+        buf_empty_.wait(l);
+      }
+      epoch_end_ = false;
+    }
+  }
+  main_loop_done_.Notify();
+}
+
+bool RecordYielder::Add(std::vector<string>* values) {
+  mutex_lock l(mu_);
+  while (!BufNotFull()) {
+    buf_not_full_.wait(l);
+  }
+  while (BufNotFull() && !values->empty()) {
+    // Adds values->back(). Swaps its position with another random
+    // element.
+    auto index = rnd_() % (buf_.size() + 1);
+    if (index == buf_.size()) {
+      buf_.push_back(std::move(values->back()));
+    } else {
+      buf_.push_back(std::move(buf_[index]));
+      buf_[index] = std::move(values->back());
+    }
+    values->pop_back();
+  }
+  if (BufEnough()) {
+    buf_enough_.notify_all();
+  }
+  return stop_;
+}
+
+void RecordYielder::ShardLoop(Shard* shard) {
+  std::vector<string> values;
+  const int64 kRecords = 16;
+  for (const string& filename : shard->filenames) {
+    std::unique_ptr<RandomAccessFile> file;
+    if (ShouldFinish(Status::OK())) break;
+    Status s = Env::Default()->NewRandomAccessFile(filename, &file);
+    if (!s.ok()) {
+      shard->status = errors::InvalidArgument("Can't open ", filename);
+      break;
+    }
+    io::RecordReader rdr(file.get());
+    uint64 offset = 0;
+    string record;
+    while (true) {
+      Status s = rdr.ReadRecord(&offset, &record);
+      if (s.ok()) {
+        values.emplace_back(std::move(record));
+        if (values.size() >= kRecords && Add(&values)) {
+          shard->status = errors::Aborted("stopped");
+          break;
+        }
+      } else if (errors::IsOutOfRange(s)) {
+        break;
+      } else {
+        shard->status = s;
+        break;
+      }
+    }
+  }
+  // Adds the remaining values of this shard to buf_.
+  while (!values.empty()) {
+    Add(&values);
+  }
+  shard->done.Notify();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/record_yielder.h b/tensorflow/core/kernels/record_yielder.h
new file mode 100644
index 00000000000..d331995e47f
--- /dev/null
+++ b/tensorflow/core/kernels/record_yielder.h
@@ -0,0 +1,157 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_RECORD_YIELDER_H_
+#define TENSORFLOW_KERNELS_RECORD_YIELDER_H_
+
+#include <atomic>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// RecordYielder produces value records from a set of tfrecord files
+// in a random order.
+//
+// It guarantees that:
+//   1) all records in the tfrecords are yielded within every epoch;
+//   2) each record is yielded only once within every epoch;
+//   3) the order in which records are yielded is highly randomized;
+//   4) the peak memory usage is roughly avg record size *
+//      (opts.bufsize + opts.parallelism * 16).
+//
+// Usage example:
+//   RecordYielder::Options opts;
+//   opts.file_pattern = "input-*";
+//   opts.seed = 301;
+//   opts.bufsize = 1000000;  // A randomized buffer with 1M records.
+//   opts.parallelism = 8;    // Uses 8 tfrecord iterators to iterate
+//                            // through all files.
+//   RecordYielder yielder(opts);
+//   string val;
+//   while (true) {
+//     yielder.Yield(&val);
+//     // process val
+//   }
+//
+// RecordYielder can be accessed by multiple threads concurrently.
class RecordYielder {
+ public:
+  struct Options {
+    // Glob pattern for tfrecords.
+    string file_pattern;
+
+    // Random seed. It determines how data files are shuffled and how
+    // records are shuffled.
+    int64 seed = 0;
+
+    // Each epoch, all files are first shuffled according to the
+    // random seed and the epoch number, and then all files are
+    // left-shifted by file_shuffle_shift_ratio * num_files slots. If
+    // file_shuffle_shift_ratio is not within [0, 1), the
+    // implementation clips it to [0, 1).
+    float file_shuffle_shift_ratio = 0;
+
+    // The randomization buffer keeps this many records.
+    uint64 bufsize = 1;
+
+    // Uses this many concurrent tfrecord iterators to iterate through
+    // tfrecords.
+    int32 parallelism = 1;
+  };
+
+  explicit RecordYielder(OpKernelConstruction* context,
+                         const RecordYielder::Options& opts);
+  ~RecordYielder();
+
+  RecordYielder(const RecordYielder&) = delete;
+  RecordYielder& operator=(const RecordYielder&) = delete;
+
+  // Yields one 'value'.
+  Status Yield(string* value);
+
+  // Returns the current epoch number.
+  int64 current_epoch() const { return epoch_; }
+
+ private:
+  typedef RecordYielder ME;
+
+  Options opts_;
+
+  // Background threads. Owned.
+  thread::ThreadPool* thread_;
+
+  // Epoch number.
+  std::atomic<int64> epoch_;
+
+  mutex mu_;
+
+  // Turned to true when this is deleted.
+  bool stop_ GUARDED_BY(mu_) = false;
+  Status status_ GUARDED_BY(mu_);
+
+  // PRNG used for randomization.
+  std::mt19937_64 rnd_ GUARDED_BY(mu_);
+
+  // Randomization buffer.
+  std::vector<string> buf_ GUARDED_BY(mu_);
+
+  // True iff we are draining an epoch.
+  bool epoch_end_ = false;
+
+  int64 num_records_yielded_in_epoch_ = 0;
+
+  // Notified when the main loop has exited.
+  Notification main_loop_done_;
+
+  // Condition variables.
+  condition_variable buf_empty_;
+  bool BufEmpty() const SHARED_LOCKS_REQUIRED(mu_) {
+    return stop_ || buf_.empty();
+  }
+
+  condition_variable buf_not_full_;
+  bool BufNotFull() const SHARED_LOCKS_REQUIRED(mu_) {
+    return stop_ || buf_.size() < opts_.bufsize;
+  }
+
+  condition_variable buf_enough_;
+  bool BufEnough() const SHARED_LOCKS_REQUIRED(mu_) {
+    // NOTE: Unless we are finishing an epoch, we want to make sure
+    // the buf_ contains enough randomized elements before yielding
+    // any.
+    return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) ||
+           (!epoch_end_ &&
+            buf_.size() >= std::max<uint64>(1, opts_.bufsize / 2));
+  }
+
+  void MainLoop();
+  struct Shard;
+  void ShardLoop(Shard* shard);
+  bool ShouldFinish(const Status& s);
+  bool Add(std::vector<string>* values);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_RECORD_YIELDER_H_
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 54e766e8e9c..a19d9483a1f 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -2211,4 +2211,27 @@ dequeue with many fewer capabilities and options. This Op is optimized for
 performance.
 )doc");
 
+REGISTER_OP("RecordInput")
+    .Output("records: string")
+    .Attr("file_pattern: string")
+    .Attr("file_random_seed: int = 301")
+    .Attr("file_shuffle_shift_ratio: float = 0")
+    .Attr("file_buffer_size: int = 10000")
+    .Attr("file_parallelism: int = 16")
+    .Attr("batch_size: int = 32")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnknownShape)
+    .Doc(R"doc(
+Emits randomized records.
+
+records: A tensor of shape [batch_size].
+file_pattern: Glob pattern for the data files.
+file_random_seed: Random seed used to produce randomized records.
+file_shuffle_shift_ratio: Shifts the list of files after the list is randomly
+    shuffled.
+file_buffer_size: The size of the randomization shuffling buffer.
+file_parallelism: How many tfrecord files are opened and concurrently iterated
+    over.
+batch_size: The batch size.
+)doc");
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index f4c3dcf99fb..13b6923c3c8 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -341,6 +341,18 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "record_input_test",
+    size = "small",
+    srcs = ["record_input_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:data_flow_ops",
+        "//tensorflow/python:io_ops",
+        "//tensorflow/python:util",
+    ],
+)
+
 tf_py_test(
     name = "io_ops_test",
     size = "small",
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py
new file mode 100644
index 00000000000..4aa40595760
--- /dev/null
+++ b/tensorflow/python/kernel_tests/record_input_test.py
@@ -0,0 +1,80 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for record_input_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.platform import test
+
+
+class RecordInputOpTest(test.TestCase):
+
+  def generateTestData(self, prefix, n, m):
+    for i in range(n):
+      f = os.path.join(self.get_temp_dir(), prefix + "." + str(i))
+      w = tf_record.TFRecordWriter(f)
+
+      for j in range(m):
+        w.write("{0:0{width}}".format(i * m + j, width=10))
+
+      w.close()
+
+  def testRecordInputSimple(self):
+    with self.test_session() as sess:
+      self.generateTestData("basic", 1, 1)
+
+      yield_op = data_flow_ops.RecordInput(
+          file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
+          parallelism=1,
+          buffer_size=1,
+          batch_size=1,
+          name="record_input").get_yield_op()
+
+      self.assertEqual(sess.run(yield_op), "0000000000")
+
+  def testRecordInputEpochs(self):
+    files = 100
+    records_per_file = 100
+    with self.test_session() as sess:
+      self.generateTestData("basic", files, records_per_file)
+
+      records = data_flow_ops.RecordInput(
+          file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
+          parallelism=2,
+          buffer_size=2000,
+          batch_size=1,
+          shift_ratio=0.33,
+          seed=10,
+          name="record_input")
+
+      yield_op = records.get_yield_op()
+
+      # Cycle over 3 epochs and make sure we never yield a duplicate record
+      # within an epoch.
+      for _ in range(3):
+        epoch_set = set()
+        for _ in range(files * records_per_file):
+          r = sess.run(yield_op)
+          self.assertTrue(r[0] not in epoch_set)
+          epoch_set.add(r[0])
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 72f0454e30c..037c3a8187f 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1613,3 +1613,65 @@ class StagingArea(object):
       output.set_shape(shape)
 
     return self._get_return_value(ret)
+
+
+class RecordInput(object):
+  """RecordInput asynchronously reads and randomly yields TFRecords.
+
+  A RecordInput Op will continuously read a batch of records asynchronously
+  into a buffer of some fixed capacity. It can also asynchronously yield
+  random records from this buffer.
+
+  It will not start yielding until at least `buffer_size / 2` elements have
+  been placed into the buffer so that sufficient randomization can take place.
+
+  The order the files are read will be shifted each epoch by `shift_ratio` so
+  that the data is presented in a different order every epoch.
+  """
+
+  def __init__(self,
+               file_pattern,
+               batch_size=1,
+               buffer_size=1,
+               parallelism=1,
+               shift_ratio=0,
+               seed=0,
+               name=None):
+    """Constructs a RecordInput Op.
+
+    Args:
+      file_pattern: File path to the dataset, possibly containing wildcards.
+        All matching files will be iterated over each epoch.
+      batch_size: How many records to return at a time.
+      buffer_size: The maximum number of records the buffer will contain.
+        This _must_ be smaller than the total number of records in an epoch,
+        or deadlock can occur.
+      parallelism: How many reader threads to use for reading from files.
+      shift_ratio: What fraction of the total number of files to shift the
+        start file forward by each epoch.
+      seed: The random number seed used by the generator that randomizes
+        records.
+      name: Optional name for the operation.
+
+    Raises:
+      ValueError: If one of the arguments is invalid.
+    """
+
+    self._batch_size = batch_size
+    self._file_pattern = file_pattern
+    self._buffer_size = buffer_size
+    self._parallelism = parallelism
+    self._shift_ratio = shift_ratio
+    self._seed = seed
+    self._name = name
+
+  def get_yield_op(self):
+    """Adds a node that yields a minibatch every time it is executed."""
+    return gen_data_flow_ops.record_input(
+        file_pattern=self._file_pattern,
+        file_buffer_size=self._buffer_size,
+        file_parallelism=self._parallelism,
+        file_shuffle_shift_ratio=self._shift_ratio,
+        batch_size=self._batch_size,
+        file_random_seed=self._seed,
+        name=self._name)
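--
Usage sketch (reviewer note, not part of the applied diff): the snippet below
exercises the new Python wrapper end to end, in the spirit of
record_input_test.py. It is a minimal sketch that assumes a build containing
this change; the temp-directory setup, file names, and record counts are
illustrative only, and tf.Session usage follows the TF API of this era.

    import os
    import tempfile

    import tensorflow as tf
    from tensorflow.python.lib.io import tf_record
    from tensorflow.python.ops import data_flow_ops

    tmp_dir = tempfile.mkdtemp()

    # Write four small TFRecord files with 100 records each.
    for i in range(4):
      path = os.path.join(tmp_dir, "demo." + str(i))
      writer = tf_record.TFRecordWriter(path)
      for j in range(100):
        writer.write("record-%04d" % (i * 100 + j))
      writer.close()

    # Yield shuffled minibatches of 16 raw records from the files above.
    # buffer_size must stay below the records in one epoch (400 here).
    records = data_flow_ops.RecordInput(
        file_pattern=os.path.join(tmp_dir, "demo.*"),
        batch_size=16,
        buffer_size=200,
        parallelism=2,
        shift_ratio=0.5,
        seed=301,
        name="demo_record_input")
    yield_op = records.get_yield_op()

    with tf.Session() as sess:
      batch = sess.run(yield_op)  # a [16] string tensor of raw records
      print(batch)

Yielding does not begin until buffer_size / 2 records have been buffered, so
the first sess.run call may block briefly while the reader threads fill the
buffer.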