Add a new lightweight queue-like object - RecordInput
Change: 144752664
This commit is contained in:
parent
bf67d0a1c5
commit
66b5684133
@ -317,6 +317,19 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Kernel library for the RecordInput op and the RecordYielder it is built on.
cc_library(
    name = "record_input_op",
    srcs = [
        "record_input_op.cc",
        "record_yielder.cc",
        "record_yielder.h",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
    # record_input_op.cc registers its kernel through a static initializer
    # (REGISTER_KERNEL_BUILDER) that no other target references by symbol.
    # Force-link the objects so the registration is not dropped by the linker.
    alwayslink = 1,
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "save_restore_tensor",
|
name = "save_restore_tensor",
|
||||||
srcs = ["save_restore_tensor.cc"],
|
srcs = ["save_restore_tensor.cc"],
|
||||||
@ -1167,6 +1180,7 @@ cc_library(
|
|||||||
":priority_queue_op",
|
":priority_queue_op",
|
||||||
":queue_ops",
|
":queue_ops",
|
||||||
":random_shuffle_queue_op",
|
":random_shuffle_queue_op",
|
||||||
|
":record_input_op",
|
||||||
":session_ops",
|
":session_ops",
|
||||||
":sparse_conditional_accumulator_op",
|
":sparse_conditional_accumulator_op",
|
||||||
":stack_ops",
|
":stack_ops",
|
||||||
|
|||||||
67
tensorflow/core/kernels/record_input_op.cc
Normal file
67
tensorflow/core/kernels/record_input_op.cc
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/kernels/record_yielder.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class RecordInputOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit RecordInputOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
|
#define GETATTR(TYPE, FIELD) \
|
||||||
|
TYPE FIELD; \
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr(#FIELD, &FIELD));
|
||||||
|
|
||||||
|
GETATTR(string, file_pattern);
|
||||||
|
GETATTR(int64, file_random_seed);
|
||||||
|
GETATTR(float, file_shuffle_shift_ratio);
|
||||||
|
GETATTR(int64, file_buffer_size);
|
||||||
|
GETATTR(int64, file_parallelism);
|
||||||
|
GETATTR(int64, batch_size);
|
||||||
|
#undef GETATTR
|
||||||
|
|
||||||
|
RecordYielder::Options yopts;
|
||||||
|
yopts.file_pattern = file_pattern;
|
||||||
|
yopts.seed = file_random_seed;
|
||||||
|
yopts.bufsize = file_buffer_size;
|
||||||
|
yopts.file_shuffle_shift_ratio = file_shuffle_shift_ratio;
|
||||||
|
yopts.parallelism = file_parallelism;
|
||||||
|
yielder_ = std::unique_ptr<RecordYielder>(new RecordYielder(ctx, yopts));
|
||||||
|
|
||||||
|
batch_size_ = batch_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
Tensor out(DT_STRING, {batch_size_});
|
||||||
|
auto t_out = out.flat<string>();
|
||||||
|
for (int i = 0; i < batch_size_; ++i) {
|
||||||
|
OP_REQUIRES_OK(ctx, yielder_->Yield(&t_out(i)));
|
||||||
|
}
|
||||||
|
ctx->set_output(0, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64 batch_size_;
|
||||||
|
std::unique_ptr<RecordYielder> yielder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("RecordInput").Device(DEVICE_CPU), RecordInputOp);
|
||||||
|
} // namespace tensorflow
|
||||||
216
tensorflow/core/kernels/record_yielder.cc
Normal file
216
tensorflow/core/kernels/record_yielder.cc
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/record_yielder.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/io/record_reader.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Copies the options, creates a thread pool with `opts.parallelism + 1`
// threads, seeds the shuffle PRNG from `opts.seed`, and schedules MainLoop()
// on the pool.
RecordYielder::RecordYielder(OpKernelConstruction* context,
                             const RecordYielder::Options& opts)
    : opts_(opts),
      thread_(new thread::ThreadPool(context->env(), "record_yielder",
                                     1 + opts.parallelism)),
      epoch_(0),
      rnd_(opts.seed) {
  auto run_main_loop = [this]() { MainLoop(); };
  thread_->Schedule(run_main_loop);
}
|
||||||
|
|
||||||
|
// Requests shutdown, wakes every waiter so it can observe `stop_`, waits for
// MainLoop() to finish, and then destroys the thread pool.
RecordYielder::~RecordYielder() {
  {
    mutex_lock lock(mu_);
    stop_ = true;
    // All three predicates (BufEmpty/BufEnough/BufNotFull) become true once
    // stop_ is set, so wake everyone blocked on any of them.
    buf_not_full_.notify_all();
    buf_enough_.notify_all();
    buf_empty_.notify_all();
  }
  main_loop_done_.WaitForNotification();
  delete thread_;
}
|
||||||
|
|
||||||
|
// Blocks until BufEnough() holds, then either returns the sticky error status
// or moves the last buffered record into `*value` and returns OK.
Status RecordYielder::Yield(string* value) {
  mutex_lock lock(mu_);
  while (!BufEnough()) {
    buf_enough_.wait(lock);
  }
  if (!status_.ok()) {
    return status_;
  }

  // Remember whether producers may be blocked on a full buffer before we
  // take an element out of it.
  const bool was_full = !BufNotFull();
  CHECK(!stop_ && !buf_.empty());
  *value = std::move(buf_.back());
  buf_.pop_back();
  ++num_records_yielded_in_epoch_;

  // Assumption is that an epoch always has something in the buffer
  // until it ends.  If the input pipeline was slower than the consumers
  // by a lot this might not be true.  Not sure how to handle.
  if (buf_.empty()) {
    buf_empty_.notify_all();
  }
  if (was_full) {
    buf_not_full_.notify_all();
  }
  return status_;
}
|
||||||
|
|
||||||
|
// State for one shard of an epoch's file list. MainLoop() fills in `index`
// and `filenames` and waits on `done`; ShardLoop() sets `status` and
// notifies `done`.
struct RecordYielder::Shard {
  // Shard index.
  int index;

  // File names given to this shard.
  std::vector<string> filenames;

  // Notified when this shard is done.
  Notification done;

  // Shard status.
  Status status;
};
|
||||||
|
|
||||||
|
bool RecordYielder::ShouldFinish(const Status& s) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
status_.Update(s);
|
||||||
|
return stop_ || !status_.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expands a comma-separated list of glob patterns and appends every matching
// path to `*filenames`, in pattern order. Returns the first matching error.
static Status MatchFiles(const string& patterns,
                         std::vector<string>* filenames) {
  const auto pattern_list = str_util::Split(patterns, ',');
  for (const auto& pattern : pattern_list) {
    std::vector<string> matched;
    TF_RETURN_IF_ERROR(Env::Default()->GetMatchingPaths(pattern, &matched));
    for (string& path : matched) {
      filenames->push_back(std::move(path));
    }
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
void RecordYielder::MainLoop() {
|
||||||
|
while (true) {
|
||||||
|
++epoch_;
|
||||||
|
num_records_yielded_in_epoch_ = 0;
|
||||||
|
|
||||||
|
// Finds all files.
|
||||||
|
std::vector<string> filenames;
|
||||||
|
Status s = MatchFiles(opts_.file_pattern, &filenames);
|
||||||
|
if (ShouldFinish(s)) break;
|
||||||
|
|
||||||
|
if (filenames.empty()) {
|
||||||
|
s = errors::NotFound("Found no files at ", opts_.file_pattern);
|
||||||
|
if (ShouldFinish(s)) break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shuffles these files according to the epoch # and random seed.
|
||||||
|
std::mt19937_64 shuffle_rnd(
|
||||||
|
Hash64(reinterpret_cast<char*>(&epoch_), sizeof(epoch_), opts_.seed));
|
||||||
|
std::shuffle(filenames.begin(), filenames.end(), shuffle_rnd);
|
||||||
|
|
||||||
|
// Left-shift the filename list.
|
||||||
|
const int64 num = filenames.size();
|
||||||
|
int64 shift;
|
||||||
|
if (0 <= opts_.file_shuffle_shift_ratio &&
|
||||||
|
opts_.file_shuffle_shift_ratio < 1) {
|
||||||
|
shift = opts_.file_shuffle_shift_ratio * num;
|
||||||
|
std::rotate(filenames.begin(), filenames.begin() + shift,
|
||||||
|
filenames.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shards files and use one thread to go through each shard.
|
||||||
|
const int N = opts_.parallelism;
|
||||||
|
std::vector<Shard> shards(N);
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
Shard* shard = &shards[i];
|
||||||
|
shard->index = i;
|
||||||
|
for (int j = i; j < filenames.size(); j += N) {
|
||||||
|
shard->filenames.push_back(filenames[j]);
|
||||||
|
}
|
||||||
|
thread_->Schedule([this, shard]() { ShardLoop(shard); });
|
||||||
|
}
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
shards[i].done.WaitForNotification();
|
||||||
|
s.Update(shards[i].status);
|
||||||
|
}
|
||||||
|
if (ShouldFinish(s)) break;
|
||||||
|
|
||||||
|
// Starts the next epoch once all buffered records are consumed.
|
||||||
|
{
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
epoch_end_ = true;
|
||||||
|
while (!BufEmpty()) {
|
||||||
|
buf_empty_.wait(l);
|
||||||
|
}
|
||||||
|
epoch_end_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
main_loop_done_.Notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RecordYielder::Add(std::vector<string>* values) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
while (!BufNotFull()) {
|
||||||
|
buf_not_full_.wait(l);
|
||||||
|
}
|
||||||
|
while (BufNotFull() && !values->empty()) {
|
||||||
|
// Adds values->back(). Swaps its position with another random
|
||||||
|
// element.
|
||||||
|
auto index = rnd_() % (buf_.size() + 1);
|
||||||
|
if (index == buf_.size()) {
|
||||||
|
buf_.push_back(std::move(values->back()));
|
||||||
|
} else {
|
||||||
|
buf_.push_back(std::move(buf_[index]));
|
||||||
|
buf_[index] = std::move(values->back());
|
||||||
|
}
|
||||||
|
values->pop_back();
|
||||||
|
}
|
||||||
|
if (BufEnough()) {
|
||||||
|
buf_enough_.notify_all();
|
||||||
|
}
|
||||||
|
return stop_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reads every file assigned to `shard` in order and feeds its records into
// the shuffle buffer in batches of kRecords. Stops early on shutdown, on a
// previously recorded global error, or on the first error in this shard.
// The outcome is left in `shard->status`, and `shard->done` is always
// notified before returning.
void RecordYielder::ShardLoop(Shard* shard) {
  std::vector<string> values;
  const size_t kRecords = 16;
  for (const string& filename : shard->filenames) {
    std::unique_ptr<RandomAccessFile> file;
    if (ShouldFinish(Status::OK())) break;
    const Status open_status =
        Env::Default()->NewRandomAccessFile(filename, &file);
    if (!open_status.ok()) {
      shard->status = errors::InvalidArgument("Can't open ", filename);
      break;
    }
    io::RecordReader rdr(file.get());
    uint64 offset = 0;
    string record;
    while (true) {
      const Status read_status = rdr.ReadRecord(&offset, &record);
      if (read_status.ok()) {
        values.emplace_back(std::move(record));
        // Add() returns true once shutdown has been requested.
        if (values.size() >= kRecords && Add(&values)) {
          shard->status = errors::Aborted("stopped");
          break;
        }
      } else if (errors::IsOutOfRange(read_status)) {
        // End of this file; move on to the next one.
        break;
      } else {
        shard->status = read_status;
        break;
      }
    }
    // The breaks above only leave the record loop. Once this shard has
    // failed, do not keep reading its remaining files.
    if (!shard->status.ok()) break;
  }
  // Adds the remaining values of this shard to buf_.
  while (!values.empty()) {
    Add(&values);
  }
  shard->done.Notify();
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
157
tensorflow/core/kernels/record_yielder.h
Normal file
157
tensorflow/core/kernels/record_yielder.h
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_KERNELS_RECORD_YIELDER_H_
|
||||||
|
#define TENSORFLOW_KERNELS_RECORD_YIELDER_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// RecordYielder produces value records from a set of tfrecord files
|
||||||
|
// in a random order.
|
||||||
|
//
|
||||||
|
// It guarantees that:
|
||||||
|
// 1) all records in tfrecords are yielded within every epoch;
|
||||||
|
// 2) each record is yielded only once within every epoch;
|
||||||
|
// 3) the order in which records are yielded is highly randomized.
|
||||||
|
// 4) the peak memory usage is roughly avg record size *
|
||||||
|
// (opts.bufsize + opts.parallelism * 16).
|
||||||
|
//
|
||||||
|
// Usage example:
|
||||||
|
// RecordYielder::Options opts;
|
||||||
|
// opts.file_pattern = "input-*";
|
||||||
|
// opts.seed = 301;
|
||||||
|
// opts.bufsize = 1000000; // A randomized buffer with 1M records.
|
||||||
|
// opts.parallelism = 8; // Uses 8 tfrecord iterators to iterate
|
||||||
|
// // through all files.
|
||||||
|
// RecordYielder yielder(context, opts);  // context: OpKernelConstruction*
|
||||||
|
// string val;
|
||||||
|
// while (true) {
|
||||||
|
// yielder.Yield(&val);
|
||||||
|
// // process val
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RecordYielder can be accessed by multiple threads concurrently.
|
||||||
|
class RecordYielder {
|
||||||
|
public:
|
||||||
|
struct Options {
|
||||||
|
// Glob pattern for tfrecords.
|
||||||
|
string file_pattern;
|
||||||
|
|
||||||
|
// Random seed. It determines how data files are shuffled and how
|
||||||
|
// records are shuffled.
|
||||||
|
int64 seed = 0;
|
||||||
|
|
||||||
|
// Each epoch, all files are first shuffled according to the
|
||||||
|
// random seed and the epoch number, and then all files are
|
||||||
|
// left-shifted by file_shuffle_shift_ratio * num_files slots. If
|
||||||
|
// file_shuffle_shift_ratio is not within [0, 1), the
|
||||||
|
// implementation clip it to [0, 1).
|
||||||
|
float file_shuffle_shift_ratio = 0;
|
||||||
|
|
||||||
|
// Randomization buffer keeps these many records.
|
||||||
|
uint64 bufsize = 1;
|
||||||
|
|
||||||
|
// Uses these many concurrent tfrecord iterators to iterate through
|
||||||
|
// tfrecords.
|
||||||
|
int32 parallelism = 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
explicit RecordYielder(OpKernelConstruction* context,
|
||||||
|
const RecordYielder::Options& opts);
|
||||||
|
~RecordYielder();
|
||||||
|
|
||||||
|
RecordYielder(const RecordYielder&) = delete;
|
||||||
|
RecordYielder& operator=(const RecordYielder&) = delete;
|
||||||
|
|
||||||
|
// Yields one 'value'.
|
||||||
|
Status Yield(string* value);
|
||||||
|
|
||||||
|
// Returns the current epoch number.
|
||||||
|
int64 current_epoch() const { return epoch_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
typedef RecordYielder ME;
|
||||||
|
|
||||||
|
Options opts_;
|
||||||
|
|
||||||
|
// Backgrounds threads. Owned.
|
||||||
|
thread::ThreadPool* thread_;
|
||||||
|
|
||||||
|
// Epoch number.
|
||||||
|
std::atomic<int64> epoch_;
|
||||||
|
|
||||||
|
mutex mu_;
|
||||||
|
|
||||||
|
// Turned to true when this is deleted.
|
||||||
|
bool stop_ GUARDED_BY(mu_) = false;
|
||||||
|
Status status_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// PRG used for randomization.
|
||||||
|
std::mt19937_64 rnd_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// Randomization buffer.
|
||||||
|
std::vector<string> buf_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// True iff we are draining an epoch.
|
||||||
|
bool epoch_end_ = false;
|
||||||
|
|
||||||
|
int64 num_records_yielded_in_epoch_ = 0;
|
||||||
|
|
||||||
|
// Trigger when the main loop has exited.
|
||||||
|
Notification main_loop_done_;
|
||||||
|
|
||||||
|
// condition_variables.
|
||||||
|
condition_variable buf_empty_;
|
||||||
|
bool BufEmpty() const SHARED_LOCKS_REQUIRED(mu_) {
|
||||||
|
return stop_ || buf_.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
condition_variable buf_not_full_;
|
||||||
|
bool BufNotFull() const SHARED_LOCKS_REQUIRED(mu_) {
|
||||||
|
return stop_ || buf_.size() < opts_.bufsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
condition_variable buf_enough_;
|
||||||
|
bool BufEnough() const SHARED_LOCKS_REQUIRED(mu_) {
|
||||||
|
// NOTE: Unless we are finishing an epoch, we want to make sure
|
||||||
|
// the buf_ contains enough randomized elements before yielding
|
||||||
|
// any.
|
||||||
|
return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) ||
|
||||||
|
(!epoch_end_ &&
|
||||||
|
buf_.size() >= std::max<int64>(1, opts_.bufsize / 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
void MainLoop();
|
||||||
|
struct Shard;
|
||||||
|
void ShardLoop(Shard* shard);
|
||||||
|
bool ShouldFinish(const Status& s);
|
||||||
|
bool Add(std::vector<string>* values);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_KERNELS_RECORD_YIELDER_H_
|
||||||
@ -2211,4 +2211,27 @@ dequeue with many fewer capabilities and options. This Op is optimized for
|
|||||||
performance.
|
performance.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
// Stateful source op: takes no inputs and emits a batch of serialized
// records on every execution. All configuration is passed through attrs.
REGISTER_OP("RecordInput")
    .Output("records: string")
    .Attr("file_pattern: string")
    .Attr("file_random_seed: int = 301")
    .Attr("file_shuffle_shift_ratio: float = 0")
    .Attr("file_buffer_size: int = 10000")
    .Attr("file_parallelism: int = 16")
    .Attr("batch_size: int = 32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
Emits randomized records.

records: A tensor of shape [batch_size].
file_pattern: Glob pattern for the data files.
file_random_seed: Random seed used to produce randomized records.
file_shuffle_shift_ratio: Shifts the list of files after the list is randomly
    shuffled.
file_buffer_size: The randomization shuffling buffer.
file_parallelism: How many TFRecord files are opened and concurrently iterated
    over.
batch_size: The batch size.
)doc");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -341,6 +341,18 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Small Python kernel test for the RecordInput op; the test source is
# record_input_test.py.
tf_py_test(
|
||||||
|
name = "record_input_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["record_input_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:data_flow_ops",
|
||||||
|
"//tensorflow/python:io_ops",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "io_ops_test",
|
name = "io_ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
|||||||
80
tensorflow/python/kernel_tests/record_input_test.py
Normal file
80
tensorflow/python/kernel_tests/record_input_test.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for record_input_op."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tensorflow.python.lib.io import tf_record
|
||||||
|
from tensorflow.python.ops import data_flow_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class RecordInputOpTest(test.TestCase):
  """Tests for the RecordInput wrapper in data_flow_ops."""

  def generateTestData(self, prefix, n, m):
    """Writes `n` TFRecord files with `m` records each into the temp dir.

    Files are named `<prefix>.<i>`. Record `j` of file `i` is the decimal
    value `i * m + j`, zero-padded to 10 digits, so every record written by
    one call is unique.

    Args:
      prefix: File name prefix.
      n: Number of files to write.
      m: Number of records per file.
    """
    for i in range(n):
      f = os.path.join(self.get_temp_dir(), prefix + "." + str(i))
      w = tf_record.TFRecordWriter(f)

      for j in range(m):
        # Records are raw bytes; encode explicitly so this also works where
        # str is a unicode type (Python 3).
        w.write("{0:0{width}}".format(i * m + j, width=10).encode("utf-8"))

      # Close each file's writer so its records are flushed before reading.
      w.close()

  def testRecordInputSimple(self):
    """A single file with a single record yields exactly that record."""
    with self.test_session() as sess:
      self.generateTestData("basic", 1, 1)

      yield_op = data_flow_ops.RecordInput(
          file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
          parallelism=1,
          buffer_size=1,
          batch_size=1,
          name="record_input").get_yield_op()

      # String tensors come back as bytes, so compare against a bytes literal.
      self.assertEqual(sess.run(yield_op), b"0000000000")

  def testRecordInputEpochs(self):
    """No record is yielded twice within an epoch."""
    files = 100
    records_per_file = 100
    with self.test_session() as sess:
      self.generateTestData("basic", files, records_per_file)

      records = data_flow_ops.RecordInput(
          file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
          parallelism=2,
          buffer_size=2000,
          batch_size=1,
          shift_ratio=0.33,
          seed=10,
          name="record_input")

      yield_op = records.get_yield_op()

      # cycle over 3 epochs and make sure we never duplicate
      for _ in range(3):
        epoch_set = set()
        for _ in range(files * records_per_file):
          r = sess.run(yield_op)
          self.assertNotIn(r[0], epoch_set)
          epoch_set.add(r[0])
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
||||||
@ -1613,3 +1613,65 @@ class StagingArea(object):
|
|||||||
output.set_shape(shape)
|
output.set_shape(shape)
|
||||||
|
|
||||||
return self._get_return_value(ret)
|
return self._get_return_value(ret)
|
||||||
|
|
||||||
|
|
||||||
|
class RecordInput(object):
  """Configuration holder for the `RecordInput` op.

  The underlying op reads TFRecord files in the background into a shuffle
  buffer of fixed capacity and hands out random records from it. It does not
  start yielding until at least `buffer_size / 2` records are buffered, so
  that the output is reasonably well mixed.

  Every epoch the list of matching files is shuffled and then rotated left by
  `shift_ratio` times the number of files.
  """

  def __init__(self,
               file_pattern,
               batch_size=1,
               buffer_size=1,
               parallelism=1,
               shift_ratio=0,
               seed=0,
               name=None):
    """Stores the settings later passed to the op by `get_yield_op`.

    The arguments are stored as given; no validation is performed here.

    Args:
      file_pattern: File path to the dataset, possibly containing wildcards.
        All matching files will be iterated over each epoch.
      batch_size: How many records to return at a time.
      buffer_size: The maximum number of records the buffer will contain.
        This _must_ be smaller than the total number of records in an epoch
        or deadlock can occur.
      parallelism: How many reader threads to use for reading from files.
      shift_ratio: Fraction of the total number of files by which the file
        list is shifted each epoch.
      seed: Seed for the random number generator that randomizes records.
      name: Optional name for the operation.
    """
    self._file_pattern = file_pattern
    self._batch_size = batch_size
    self._buffer_size = buffer_size
    self._parallelism = parallelism
    self._shift_ratio = shift_ratio
    self._seed = seed
    self._name = name

  def get_yield_op(self):
    """Adds a node that yields a minibatch every time it is executed."""
    return gen_data_flow_ops.record_input(
        file_pattern=self._file_pattern,
        file_random_seed=self._seed,
        file_shuffle_shift_ratio=self._shift_ratio,
        file_buffer_size=self._buffer_size,
        file_parallelism=self._parallelism,
        batch_size=self._batch_size,
        name=self._name)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user