Add a new lightweight queue-like object - RecordInput
Change: 144752664
parent bf67d0a1c5
commit 66b5684133
@@ -317,6 +317,19 @@ cc_library(
    ],
)

cc_library(
    name = "record_input_op",
    srcs = [
        "record_input_op.cc",
        "record_yielder.cc",
        "record_yielder.h",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "save_restore_tensor",
    srcs = ["save_restore_tensor.cc"],
@@ -1167,6 +1180,7 @@ cc_library(
        ":priority_queue_op",
        ":queue_ops",
        ":random_shuffle_queue_op",
        ":record_input_op",
        ":session_ops",
        ":sparse_conditional_accumulator_op",
        ":stack_ops",
tensorflow/core/kernels/record_input_op.cc (new file, 67 lines)
@@ -0,0 +1,67 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/record_yielder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

class RecordInputOp : public OpKernel {
 public:
  explicit RecordInputOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
#define GETATTR(TYPE, FIELD) \
  TYPE FIELD;                \
  OP_REQUIRES_OK(ctx, ctx->GetAttr(#FIELD, &FIELD));

    GETATTR(string, file_pattern);
    GETATTR(int64, file_random_seed);
    GETATTR(float, file_shuffle_shift_ratio);
    GETATTR(int64, file_buffer_size);
    GETATTR(int64, file_parallelism);
    GETATTR(int64, batch_size);
#undef GETATTR

    RecordYielder::Options yopts;
    yopts.file_pattern = file_pattern;
    yopts.seed = file_random_seed;
    yopts.bufsize = file_buffer_size;
    yopts.file_shuffle_shift_ratio = file_shuffle_shift_ratio;
    yopts.parallelism = file_parallelism;
    yielder_ = std::unique_ptr<RecordYielder>(new RecordYielder(ctx, yopts));

    batch_size_ = batch_size;
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor out(DT_STRING, {batch_size_});
    auto t_out = out.flat<string>();
    for (int i = 0; i < batch_size_; ++i) {
      OP_REQUIRES_OK(ctx, yielder_->Yield(&t_out(i)));
    }
    ctx->set_output(0, out);
  }

 private:
  int64 batch_size_;
  std::unique_ptr<RecordYielder> yielder_;
};

REGISTER_KERNEL_BUILDER(Name("RecordInput").Device(DEVICE_CPU), RecordInputOp);
}  // namespace tensorflow
tensorflow/core/kernels/record_yielder.cc (new file, 216 lines)
@@ -0,0 +1,216 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/record_yielder.h"

#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

RecordYielder::RecordYielder(OpKernelConstruction* context,
                             const RecordYielder::Options& opts)
    : opts_(opts),
      thread_(new thread::ThreadPool(context->env(), "record_yielder",
                                     1 + opts.parallelism)),
      epoch_(0),
      rnd_(opts.seed) {
  thread_->Schedule([this]() { MainLoop(); });
}

RecordYielder::~RecordYielder() {
  {
    mutex_lock l(mu_);
    stop_ = true;
    buf_empty_.notify_all();
    buf_enough_.notify_all();
    buf_not_full_.notify_all();
  }
  main_loop_done_.WaitForNotification();
  delete thread_;
}

Status RecordYielder::Yield(string* value) {
  mutex_lock l(mu_);
  while (!BufEnough()) {
    buf_enough_.wait(l);
  }
  if (status_.ok()) {
    bool notify_no_longer_full = !BufNotFull();
    CHECK(!stop_ && !buf_.empty());
    *value = std::move(buf_.back());
    buf_.pop_back();
    ++num_records_yielded_in_epoch_;
    // Assumption is that an epoch always has something in the buffer
    // until it ends. If the input pipeline was slower than the consumers
    // by a lot this might not be true. Not sure how to handle.
    if (buf_.empty()) {
      buf_empty_.notify_all();
    }
    if (notify_no_longer_full) {
      buf_not_full_.notify_all();
    }
  }
  return status_;
}

struct RecordYielder::Shard {
  int index;                      // Shard index.
  std::vector<string> filenames;  // File names given to this shard.
  Notification done;              // Notified when this shard is done.
  Status status;                  // Shard status.
};

bool RecordYielder::ShouldFinish(const Status& s) {
  mutex_lock l(mu_);
  status_.Update(s);
  return stop_ || !status_.ok();
}

static Status MatchFiles(const string& patterns,
                         std::vector<string>* filenames) {
  for (const auto& file_pattern : str_util::Split(patterns, ',')) {
    std::vector<string> tmp_filenames;
    TF_RETURN_IF_ERROR(
        Env::Default()->GetMatchingPaths(file_pattern, &tmp_filenames));
    filenames->insert(filenames->end(),
                      std::make_move_iterator(tmp_filenames.begin()),
                      std::make_move_iterator(tmp_filenames.end()));
  }
  return Status::OK();
}

void RecordYielder::MainLoop() {
  while (true) {
    ++epoch_;
    num_records_yielded_in_epoch_ = 0;

    // Finds all files.
    std::vector<string> filenames;
    Status s = MatchFiles(opts_.file_pattern, &filenames);
    if (ShouldFinish(s)) break;

    if (filenames.empty()) {
      s = errors::NotFound("Found no files at ", opts_.file_pattern);
      if (ShouldFinish(s)) break;
    }

    // Shuffles these files according to the epoch # and random seed.
    std::mt19937_64 shuffle_rnd(
        Hash64(reinterpret_cast<char*>(&epoch_), sizeof(epoch_), opts_.seed));
    std::shuffle(filenames.begin(), filenames.end(), shuffle_rnd);

    // Left-shifts the filename list.
    const int64 num = filenames.size();
    int64 shift;
    if (0 <= opts_.file_shuffle_shift_ratio &&
        opts_.file_shuffle_shift_ratio < 1) {
      shift = opts_.file_shuffle_shift_ratio * num;
      std::rotate(filenames.begin(), filenames.begin() + shift,
                  filenames.end());
    }

    // Shards the files and uses one thread to go through each shard.
    const int N = opts_.parallelism;
    std::vector<Shard> shards(N);
    for (int i = 0; i < N; ++i) {
      Shard* shard = &shards[i];
      shard->index = i;
      for (int j = i; j < filenames.size(); j += N) {
        shard->filenames.push_back(filenames[j]);
      }
      thread_->Schedule([this, shard]() { ShardLoop(shard); });
    }
    for (int i = 0; i < N; ++i) {
      shards[i].done.WaitForNotification();
      s.Update(shards[i].status);
    }
    if (ShouldFinish(s)) break;

    // Starts the next epoch once all buffered records are consumed.
    {
      mutex_lock l(mu_);
      epoch_end_ = true;
      while (!BufEmpty()) {
        buf_empty_.wait(l);
      }
      epoch_end_ = false;
    }
  }
  main_loop_done_.Notify();
}

bool RecordYielder::Add(std::vector<string>* values) {
  mutex_lock l(mu_);
  while (!BufNotFull()) {
    buf_not_full_.wait(l);
  }
  while (BufNotFull() && !values->empty()) {
    // Adds values->back(). Swaps its position with another random
    // element.
    auto index = rnd_() % (buf_.size() + 1);
    if (index == buf_.size()) {
      buf_.push_back(std::move(values->back()));
    } else {
      buf_.push_back(std::move(buf_[index]));
      buf_[index] = std::move(values->back());
    }
    values->pop_back();
  }
  if (BufEnough()) {
    buf_enough_.notify_all();
  }
  return stop_;
}

void RecordYielder::ShardLoop(Shard* shard) {
  std::vector<string> values;
  const int64 kRecords = 16;
  for (const string& filename : shard->filenames) {
    std::unique_ptr<RandomAccessFile> file;
    if (ShouldFinish(Status::OK())) break;
    Status s = Env::Default()->NewRandomAccessFile(filename, &file);
    if (!s.ok()) {
      shard->status = errors::InvalidArgument("Can't open ", filename);
      break;
    }
    io::RecordReader rdr(file.get());
    uint64 offset = 0;
    string record;
    while (true) {
      Status s = rdr.ReadRecord(&offset, &record);
      if (s.ok()) {
        values.emplace_back(std::move(record));
        if (values.size() >= kRecords && Add(&values)) {
          shard->status = errors::Aborted("stopped");
          break;
        }
      } else if (errors::IsOutOfRange(s)) {
        break;
      } else {
        shard->status = s;
        break;
      }
    }
  }
  // Adds the remaining values of this shard to buf_.
  while (!values.empty()) {
    Add(&values);
  }
  shard->done.Notify();
}

}  // namespace tensorflow
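The randomization in Add() above is worth spelling out: each incoming record is placed at a random slot in the buffer by swapping with an existing element, and Yield() always pops from the back of the buffer, which is what produces the randomized output order. A small Python model of that insertion, for illustration only (not part of this change; the buffer, seed, and record values below are made up):

import random

_rnd = random.Random(301)  # illustrative seed

def add_to_buffer(buf, values, bufsize):
  # Mirrors RecordYielder::Add(): move records from `values` into `buf`,
  # giving each one a random slot via a swap with the displaced element.
  while len(buf) < bufsize and values:
    index = _rnd.randrange(len(buf) + 1)
    if index == len(buf):
      buf.append(values.pop())
    else:
      buf.append(buf[index])       # displaced element moves to the back
      buf[index] = values.pop()    # new record takes the random slot

buf = []
add_to_buffer(buf, [b"r1", b"r2", b"r3", b"r4"], bufsize=8)
print(buf)  # e.g. [b'r4', b'r2', b'r3', b'r1'], a random permutation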
tensorflow/core/kernels/record_yielder.h (new file, 157 lines)
@@ -0,0 +1,157 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_KERNELS_RECORD_YIELDER_H_
|
||||
#define TENSORFLOW_KERNELS_RECORD_YIELDER_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RecordYielder produces value records from a set of tfrecord files
|
||||
// in a random order.
|
||||
//
|
||||
// It guarantees that:
|
||||
// 1) all records in tfrecords are yielded within every epoch;
|
||||
// 2) each record is yielded only once within every epoch;
|
||||
// 3) the order in which records are yielded are highly randomized.
|
||||
// 4) the peak memory usage is roughly avg record size *
|
||||
// (opts.bufsize + opts.parellelism * 16).
|
||||
//
|
||||
// Usage example:
|
||||
// RecordYielder::Options opts;
|
||||
// opts.file_pattern = "input-*";
|
||||
// opts.seed = 301;
|
||||
// opts.bufsize = 1000000; // A randomized buffer with 1M records.
|
||||
// opts.parallelism = 8; // Uses 8 tfrecord iterators to iterate
|
||||
// // through all files.
|
||||
// RecordYielder yielder(opts);
|
||||
// string val;
|
||||
// while (true) {
|
||||
// yielder.Yield(&val);
|
||||
// // process val
|
||||
// }
|
||||
//
|
||||
// RecordYielder can be accessed by multiple threads concurrently.
|
||||
class RecordYielder {
|
||||
public:
|
||||
struct Options {
|
||||
// Glob pattern for tfrecords.
|
||||
string file_pattern;
|
||||
|
||||
// Random seed. It determines how data files are shuffled and how
|
||||
// records are shuffled.
|
||||
int64 seed = 0;
|
||||
|
||||
// Each epoch, all files are first shuffled according to the
|
||||
// random seed and the epoch number, and then all files are
|
||||
// left-shifted by file_shuffle_shift_ratio * num_files slots. If
|
||||
// file_shuffle_shift_ratio is not within [0, 1), the
|
||||
// implementation clip it to [0, 1).
|
||||
float file_shuffle_shift_ratio = 0;
|
||||
|
||||
// Randomization buffer keeps these many records.
|
||||
uint64 bufsize = 1;
|
||||
|
||||
// Uses these many concurrent tfrecord iterators to iterate through
|
||||
// tfrecords.
|
||||
int32 parallelism = 1;
|
||||
};
|
||||
|
||||
explicit RecordYielder(OpKernelConstruction* context,
|
||||
const RecordYielder::Options& opts);
|
||||
~RecordYielder();
|
||||
|
||||
RecordYielder(const RecordYielder&) = delete;
|
||||
RecordYielder& operator=(const RecordYielder&) = delete;
|
||||
|
||||
// Yields one 'value'.
|
||||
Status Yield(string* value);
|
||||
|
||||
// Returns the current epoch number.
|
||||
int64 current_epoch() const { return epoch_; }
|
||||
|
||||
private:
|
||||
typedef RecordYielder ME;
|
||||
|
||||
Options opts_;
|
||||
|
||||
// Backgrounds threads. Owned.
|
||||
thread::ThreadPool* thread_;
|
||||
|
||||
// Epoch number.
|
||||
std::atomic<int64> epoch_;
|
||||
|
||||
mutex mu_;
|
||||
|
||||
// Turned to true when this is deleted.
|
||||
bool stop_ GUARDED_BY(mu_) = false;
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
|
||||
// PRG used for randomization.
|
||||
std::mt19937_64 rnd_ GUARDED_BY(mu_);
|
||||
|
||||
// Randomization buffer.
|
||||
std::vector<string> buf_ GUARDED_BY(mu_);
|
||||
|
||||
// True iff we are draining an epoch.
|
||||
bool epoch_end_ = false;
|
||||
|
||||
int64 num_records_yielded_in_epoch_ = 0;
|
||||
|
||||
// Trigger when the main loop has exited.
|
||||
Notification main_loop_done_;
|
||||
|
||||
// condition_variables.
|
||||
condition_variable buf_empty_;
|
||||
bool BufEmpty() const SHARED_LOCKS_REQUIRED(mu_) {
|
||||
return stop_ || buf_.empty();
|
||||
}
|
||||
|
||||
condition_variable buf_not_full_;
|
||||
bool BufNotFull() const SHARED_LOCKS_REQUIRED(mu_) {
|
||||
return stop_ || buf_.size() < opts_.bufsize;
|
||||
}
|
||||
|
||||
condition_variable buf_enough_;
|
||||
bool BufEnough() const SHARED_LOCKS_REQUIRED(mu_) {
|
||||
// NOTE: Unless we are finishing an epoch, we want to make sure
|
||||
// the buf_ contains enough randomized elements before yielding
|
||||
// any.
|
||||
return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) ||
|
||||
(!epoch_end_ &&
|
||||
buf_.size() >= std::max<int64>(1, opts_.bufsize / 2));
|
||||
}
|
||||
|
||||
void MainLoop();
|
||||
struct Shard;
|
||||
void ShardLoop(Shard* shard);
|
||||
bool ShouldFinish(const Status& s);
|
||||
bool Add(std::vector<string>* values);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_KERNELS_RECORD_YIELDER_H_
|
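The Options comment above describes the per-epoch file ordering: a deterministic shuffle keyed on the epoch number and seed, followed by a left shift of shift_ratio * num_files slots. A rough Python sketch of that ordering, for illustration only (the actual implementation in MainLoop() uses Hash64, std::mt19937_64, std::shuffle, and std::rotate; the file names below are made up):

import random

def epoch_file_order(filenames, epoch, seed, shift_ratio):
  files = list(filenames)
  # Shuffle deterministically from (epoch, seed): a different but
  # reproducible order every epoch.
  random.Random(hash((epoch, seed))).shuffle(files)
  # Left-shift (rotate) by shift_ratio * num_files slots.
  if 0 <= shift_ratio < 1:
    shift = int(shift_ratio * len(files))
    files = files[shift:] + files[:shift]
  return files

names = ["input-%02d" % i for i in range(10)]
print(epoch_file_order(names, epoch=1, seed=301, shift_ratio=0.3))
print(epoch_file_order(names, epoch=2, seed=301, shift_ratio=0.3))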
@@ -2211,4 +2211,27 @@ dequeue with many fewer capabilities and options. This Op is optimized for
performance.
)doc");

REGISTER_OP("RecordInput")
    .Output("records: string")
    .Attr("file_pattern: string")
    .Attr("file_random_seed: int = 301")
    .Attr("file_shuffle_shift_ratio: float = 0")
    .Attr("file_buffer_size: int = 10000")
    .Attr("file_parallelism: int = 16")
    .Attr("batch_size: int = 32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Emits randomized records.

records: A tensor of shape [batch_size].
file_pattern: Glob pattern for the data files.
file_random_seed: Random seed used to produce randomized records.
file_shuffle_shift_ratio: Shifts the list of files after the list is randomly
  shuffled.
file_buffer_size: The maximum number of records held in the randomization
  buffer.
file_parallelism: How many tfrecord files are opened and concurrently iterated
  over.
batch_size: The batch size.
)doc");

}  // namespace tensorflow
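For reference, the attributes registered above map directly onto the generated Python wrapper that the RecordInput class (added further below) calls. A hedged sketch of invoking the raw op with the registered defaults spelled out; the file pattern is a made-up example path:

from tensorflow.python.ops import gen_data_flow_ops

records = gen_data_flow_ops.record_input(
    file_pattern="/tmp/train/part-*",  # hypothetical glob
    file_random_seed=301,
    file_shuffle_shift_ratio=0.0,
    file_buffer_size=10000,
    file_parallelism=16,
    batch_size=32,
    name="record_input")
# `records` is a string tensor of shape [batch_size]; each execution of this
# stateful op yields the next batch of randomized records.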
@@ -341,6 +341,18 @@ tf_py_test(
    ],
)

tf_py_test(
    name = "record_input_test",
    size = "small",
    srcs = ["record_input_test.py"],
    additional_deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:data_flow_ops",
        "//tensorflow/python:io_ops",
        "//tensorflow/python:util",
    ],
)

tf_py_test(
    name = "io_ops_test",
    size = "small",
tensorflow/python/kernel_tests/record_input_test.py (new file, 80 lines)
@@ -0,0 +1,80 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for record_input_op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import test


class RecordInputOpTest(test.TestCase):

  def generateTestData(self, prefix, n, m):
    for i in range(n):
      f = os.path.join(self.get_temp_dir(), prefix + "." + str(i))
      w = tf_record.TFRecordWriter(f)

      for j in range(m):
        w.write("{0:0{width}}".format(i * m + j, width=10))

      w.close()

  def testRecordInputSimple(self):
    with self.test_session() as sess:
      self.generateTestData("basic", 1, 1)

      yield_op = data_flow_ops.RecordInput(
          file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
          parallelism=1,
          buffer_size=1,
          batch_size=1,
          name="record_input").get_yield_op()

      self.assertEqual(sess.run(yield_op), "0000000000")

  def testRecordInputEpochs(self):
    files = 100
    records_per_file = 100
    with self.test_session() as sess:
      self.generateTestData("basic", files, records_per_file)

      records = data_flow_ops.RecordInput(
          file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
          parallelism=2,
          buffer_size=2000,
          batch_size=1,
          shift_ratio=0.33,
          seed=10,
          name="record_input")

      yield_op = records.get_yield_op()

      # Cycle over 3 epochs and make sure we never yield a duplicate record
      # within an epoch.
      for _ in range(3):
        epoch_set = set()
        for _ in range(files * records_per_file):
          r = sess.run(yield_op)
          self.assertTrue(r[0] not in epoch_set)
          epoch_set.add(r[0])


if __name__ == "__main__":
  test.main()
@@ -1613,3 +1613,65 @@ class StagingArea(object):
      output.set_shape(shape)

    return self._get_return_value(ret)


class RecordInput(object):
  """RecordInput asynchronously reads and randomly yields TFRecords.

  A RecordInput Op will continuously read a batch of records asynchronously
  into a buffer of some fixed capacity. It can also asynchronously yield
  random records from this buffer.

  It will not start yielding until at least `buffer_size / 2` elements have
  been placed into the buffer so that sufficient randomization can take place.

  The order in which files are read is shifted each epoch by
  `shift_ratio * num_files` slots so that the data is presented in a different
  order every epoch.
  """

  def __init__(self,
               file_pattern,
               batch_size=1,
               buffer_size=1,
               parallelism=1,
               shift_ratio=0,
               seed=0,
               name=None):
    """Constructs a RecordInput Op.

    Args:
      file_pattern: File path to the dataset, possibly containing wildcards.
        All matching files will be iterated over each epoch.
      batch_size: How many records to return at a time.
      buffer_size: The maximum number of records the buffer will contain. This
        _must_ be smaller than the total number of records in an epoch, or
        deadlock can occur.
      parallelism: How many reader threads to use for reading from files.
      shift_ratio: What fraction of the total number of files to shift the
        start file forward by each epoch.
      seed: The random number seed used by the generator that randomizes
        records.
      name: Optional name for the operation.

    Raises:
      ValueError: If one of the arguments is invalid.
    """

    self._batch_size = batch_size
    self._file_pattern = file_pattern
    self._buffer_size = buffer_size
    self._parallelism = parallelism
    self._shift_ratio = shift_ratio
    self._seed = seed
    self._name = name

  def get_yield_op(self):
    """Adds a node that yields a minibatch every time it is executed."""
    return gen_data_flow_ops.record_input(
        file_pattern=self._file_pattern,
        file_buffer_size=self._buffer_size,
        file_parallelism=self._parallelism,
        file_shuffle_shift_ratio=self._shift_ratio,
        batch_size=self._batch_size,
        file_random_seed=self._seed,
        name=self._name)
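A minimal end-to-end sketch of the new wrapper, modeled on record_input_test.py above; the TFRecord glob, batch size, and session handling below are illustrative, not part of the commit:

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

records = data_flow_ops.RecordInput(
    file_pattern="/tmp/tfrecords/basic.*",  # hypothetical TFRecord glob
    batch_size=4,
    buffer_size=2000,   # keep this below the number of records per epoch
    parallelism=2,
    shift_ratio=0.33,
    seed=10,
    name="record_input")
yield_op = records.get_yield_op()

with tf.Session() as sess:
  batch = sess.run(yield_op)  # ndarray of 4 serialized records (string dtype)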