Implement save/restore for ShuffleDataset[reshuffle_each_iteration=False].

Also added SingleSampleAdapter::Skip for restoring rng state.

PiperOrigin-RevId: 174424108
Author: Saurabh Saxena, 2017-11-02 22:15:54 -07:00 (committed by TensorFlower Gardener)
parent 7bb2d57b0b
commit 902c91342a
8 changed files with 635 additions and 21 deletions
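
For context, a minimal sketch of the user-facing flow this change enables, mirroring the APIs the new tests below exercise (TF 1.x contrib names; the actual save/restore call sites are elided):

# Sketch only: assumes the TF 1.x contrib APIs used in the tests below.
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import saver as saver_lib

# reshuffle_each_iteration=False is the case this commit makes checkpointable.
iterator = (dataset_ops.Dataset.range(10)
            .shuffle(5, seed=42, reshuffle_each_iteration=False)
            .repeat(5)
            .make_initializable_iterator())

# Register the iterator's state (shuffle buffer + rng sample count) so a
# plain Saver checkpoints it alongside the model variables.
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = saver_lib.Saver(allow_empty=True)
# ...run some steps, saver.save(sess, path) mid-epoch; a later
# saver.restore(sess, path) resumes the exact shuffled order.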


@@ -393,11 +393,16 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//third_party/py/numpy",
],


@@ -18,16 +18,22 @@ from __future__ import division
from __future__ import print_function
import collections
import os
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
class ShuffleDatasetTest(test.TestCase):
@@ -42,8 +48,9 @@ class ShuffleDatasetTest(test.TestCase):
buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
.repeat(count_placeholder))
repeat_dataset = (
contrib_dataset_ops.Dataset.from_tensor_slices(components)
.repeat(count_placeholder))
shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
seed_placeholder)
@@ -134,8 +141,9 @@ class ShuffleDatasetTest(test.TestCase):
def testDefaultArguments(self):
components = [0, 1, 2, 3, 4]
iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
.repeat().make_one_shot_iterator())
iterator = (
contrib_dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
.repeat().make_one_shot_iterator())
get_next = iterator.get_next()
@@ -149,5 +157,322 @@ class ShuffleDatasetTest(test.TestCase):
self.assertEqual(10, counts[i])
class ShuffleDatasetSerializationTest(test.TestCase):
def tearDown(self):
# Remove all checkpoint files.
prefix = self._ckpt_path()
pattern = prefix + "*"
files = gfile.Glob(pattern)
# Note: map() is lazy in Python 3, so delete explicitly.
for f in files:
  gfile.Remove(f)
def _build_graph(self,
range_limit=10,
num_repeats=5,
buffer_size=5,
seed=None,
reshuffle_each_iteration=None,
build_saveable=True):
iterator = dataset_ops.Dataset.range(range_limit).shuffle(
buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration).repeat(
num_repeats).make_initializable_iterator()
if build_saveable:
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
init_op = iterator.initializer
get_next = iterator.get_next()
ops.add_to_collection("iterator_ops", init_op)
ops.add_to_collection("iterator_ops", get_next)
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
def _ckpt_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _latest_ckpt(self):
return saver_lib.latest_checkpoint(self.get_temp_dir())
def _save(self, sess, saver):
saver.save(sess, self._ckpt_path())
def _restore(self, saver, sess):
saver.restore(sess, self._latest_ckpt())
def _import_meta_graph(self):
meta_file_path = self._ckpt_path() + ".meta"
return saver_lib.import_meta_graph(meta_file_path)
def _testReadWithBreaks(self, break_points, init_before_restore=False):
seed = 55
range_limit = 10
num_repeats = 5
num_outputs = range_limit * num_repeats
buffer_sizes = [1, 3, 8, 10, 25, 50]
reshuffle_each_iteration = False
for buffer_size in buffer_sizes:
expected = []
actual = []
# Generate the ground truth.
with ops.Graph().as_default() as g:
g.seed = 10
init_op, get_next_op, _ = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(num_outputs):
expected.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
# Run and checkpoint after first break_point.
with ops.Graph().as_default() as g:
g.seed = 10
init_op, get_next_op, saver = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_points[0]):
actual.append(sess.run(get_next_op))
self._save(sess, saver)
# Load from the checkpoint and continue running, saving at each
# subsequent break point.
for i in range(len(break_points)):
with ops.Graph().as_default() as g:
saver = self._import_meta_graph()
init_op, get_next_op = ops.get_collection("iterator_ops")
with self.test_session(graph=g) as sess:
if init_before_restore:
sess.run(init_op)
self._restore(saver, sess)
start = break_points[i]
end = break_points[
i + 1] if i < len(break_points) - 1 else num_outputs
for _ in range(end - start):
actual.append(sess.run(get_next_op))
self._save(sess, saver)
if end == num_outputs:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self.assertEqual(expected, actual)
def testSaveRestore(self):
self._testReadWithBreaks([8]) # rng buffer_size: 0
self._testReadWithBreaks([13]) # rng buffer_size: 1
self._testReadWithBreaks([18]) # rng buffer_size: 2
self._testReadWithBreaks([23]) # rng buffer_size: 3
def testSaveUnusedIterator(self):
self._testReadWithBreaks([0])
def testSaveFullyUsedIterator(self):
self._testReadWithBreaks([50])
def testMultipleBreaks(self):
self._testReadWithBreaks([0, 5, 9, 15, 25, 32])
def testIdempotence(self):
# Attempt to save iterator immediately after restoring.
self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32])
def testInitThenRestore(self):
self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True)
def testRestoreExhaustedIterator(self):
seed = 55
range_limit = 10
num_repeats = 5
num_outputs = range_limit * num_repeats
buffer_sizes = [1, 3, 8, 10, 25, 50]
reshuffle_each_iteration = False
for buffer_size in buffer_sizes:
with ops.Graph().as_default() as g:
g.seed = 10
init_op, get_next_op, saver = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(num_outputs):
sess.run(get_next_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self._save(sess, saver)
with ops.Graph().as_default() as g:
saver = self._import_meta_graph()
init_op, get_next_op = ops.get_collection("iterator_ops")
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
def testResetRestoredIterator(self):
seed = 55
range_limit = 10
num_repeats = 5
num_outputs = range_limit * num_repeats
buffer_sizes = [1, 3, 8, 10, 25, 50]
reshuffle_each_iteration = False
for buffer_size in buffer_sizes:
with ops.Graph().as_default() as g:
g.seed = 10
init_op, get_next_op, saver = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(num_outputs // 2):
sess.run(get_next_op)
self._save(sess, saver)
outputs = []
with ops.Graph().as_default() as g:
saver = self._import_meta_graph()
init_op, get_next_op = ops.get_collection("iterator_ops")
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
sess.run(init_op)
for _ in range(num_outputs):
outputs.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
expected_outputs_sorted = sorted(
np.array([range(range_limit)
for _ in range(num_repeats)]).flatten())
self.assertEqual(expected_outputs_sorted, sorted(outputs))
def testRestoreInModifiedGraph(self):
seed = 55
break_point = 25
range_limit = 10
num_repeats = 5
num_outputs = range_limit * num_repeats
buffer_sizes = [3, 8, 10, 25, 50]
reshuffle_each_iteration = False
for buffer_size in buffer_sizes:
expected = []
actual_without_restore = []
actual = []
with ops.Graph().as_default() as g:
g.seed = 10
init_op, get_next_op, saver = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_point):
expected.append(sess.run(get_next_op))
actual.extend(expected)
self._save(sess, saver)
for _ in range(num_outputs - break_point):
expected.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
with ops.Graph().as_default() as g:
g.seed = 20 # Different seed than previous graph for shuffle rngs.
init_op, get_next_op, saver = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(num_outputs):
actual_without_restore.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
with ops.Graph().as_default() as g:
g.seed = 20 # Different seed than previous graph for shuffle rngs.
init_op, get_next_op, saver = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
# Since the modified graph has a different random seed it produces a
# different order of examples.
self.assertNotEqual(expected, actual_without_restore)
self.assertEqual(sorted(expected), sorted(actual_without_restore))
self.assertEqual(expected, actual)
def testDoNotBuildSaveable(self):
seed = 55
break_point = 25
range_limit = 10
num_repeats = 5
num_outputs = range_limit * num_repeats
buffer_sizes = [3, 8, 10, 25, 50]
reshuffle_each_iteration = False
for buffer_size in buffer_sizes:
actual = []
with ops.Graph().as_default() as g:
g.seed = 10
init_op, get_next_op, saver = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(break_point):
sess.run(get_next_op)
self._save(sess, saver)
with ops.Graph().as_default() as g:
g.seed = 20 # Different seed than previous graph for shuffle rngs.
init_op, get_next_op, saver = self._build_graph(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration,
build_saveable=False)
with self.test_session(graph=g) as sess:
# Since the SaveableObject was not added to Saver's list
# of saveables, iterator state is not restored by saver.restore().
self._restore(saver, sess)
with self.assertRaises(errors.FailedPreconditionError):
sess.run(get_next_op)
sess.run(init_op)
for _ in range(num_outputs):
actual.append(sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
expected_outputs_sorted = sorted(
np.array([range(range_limit) for _ in range(num_repeats)]).flatten())
self.assertEqual(expected_outputs_sorted, sorted(actual))
if __name__ == "__main__":
test.main()


@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -47,6 +49,7 @@ class IteratorStateReader {
public:
virtual Status ReadScalar(StringPiece key, int64* val) = 0;
virtual Status ReadScalar(StringPiece key, string* val) = 0;
virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
virtual bool Contains(StringPiece key) = 0;
virtual ~IteratorStateReader() {}
@@ -58,6 +61,7 @@ class IteratorStateWriter {
public:
virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
virtual Status WriteScalar(StringPiece key, const string& val) = 0;
virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
virtual ~IteratorStateWriter() {}
};
@@ -112,6 +116,13 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
template <class DatasetType>
Status AddDataset(const DatasetType* dataset,
const std::vector<NodeBuilder::NodeOut>& inputs,
Node** output) {
return AddDataset(dataset, inputs, {}, output);
}
// Adds a node corresponding to the `DatasetType` to the Graph.
// The return value of `DatasetType::op_name()` is used as the op type for
// the node.
@@ -122,7 +133,9 @@ class GraphDefBuilderWrapper {
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
template <class DatasetType>
Status AddDataset(const DatasetType* dataset,
std::vector<NodeBuilder::NodeOut> inputs, Node** output) {
const std::vector<NodeBuilder::NodeOut>& inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
const string& op_type_name = dataset->op_name();
std::unique_ptr<const GraphDefBuilder::Options> opts(
new GraphDefBuilder::Options(b_->opts()));
@@ -138,6 +151,10 @@ class GraphDefBuilderWrapper {
opts.reset(new GraphDefBuilder::Options(
opts->WithAttr("output_types", dataset->output_dtypes())));
}
for (const auto& attr : attrs) {
opts.reset(new GraphDefBuilder::Options(
opts->WithAttr(attr.first, attr.second)));
}
if (opts->HaveError()) {
return errors::Internal("AddDataset: Error building Options.");
}
@@ -187,6 +204,11 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
SetAttrValue(value, attr);
}
private:
void AddTensorInternal(const Tensor& val, Node** output) {
*output = ops::SourceOp(


@@ -188,6 +188,10 @@ class VariantTensorDataReader : public IteratorStateReader {
return ReadScalarInternal(key, val);
}
Status ReadTensor(StringPiece key, Tensor* val) override {
return ReadTensorInternal(key, val);
}
bool Contains(StringPiece key) override {
return map_.find(key.ToString()) != map_.end();
}
@@ -217,6 +221,14 @@ class VariantTensorDataReader : public IteratorStateReader {
return Status::OK();
}
Status ReadTensorInternal(StringPiece key, Tensor* val) {
if (map_.find(key.ToString()) == map_.end()) {
return errors::NotFound(key);
}
*val = data_->tensors(map_[key.ToString()]);
return Status::OK();
}
std::map<string, size_t> map_;
const VariantTensorData* data_; // Not owned.
Status status_;
@@ -236,6 +248,10 @@ class VariantTensorDataWriter : public IteratorStateWriter {
return WriteScalarInternal(key, val);
}
Status WriteTensor(StringPiece key, const Tensor& val) override {
return WriteTensorInternal(key, val);
}
// Writes the metadata to `data_`.
Status Flush() {
string metadata;
@@ -249,15 +265,19 @@ class VariantTensorDataWriter : public IteratorStateWriter {
private:
template <typename T>
Status WriteScalarInternal(StringPiece key, const T& val) {
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
val_t.scalar<T>()() = val;
return WriteTensorInternal(key, val_t);
}
Status WriteTensorInternal(StringPiece key, const Tensor& val) {
// Write key to the metadata proto. This gets written to `data_`
// when `Flush()` is called. We do this lazily to avoid multiple
// serialization calls.
metadata_proto_.add_keys(key.ToString());
// Update tensors.
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
val_t.scalar<T>()() = val;
*(data_->add_tensors()) = std::move(val_t);
*(data_->add_tensors()) = val;
return Status::OK();
}
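
In short: the writer records keys lazily and serializes the metadata proto once, in Flush(), and scalars are boxed as 0-d tensors so both Write paths funnel through WriteTensorInternal. A rough Python model of that behavior (class and method names here are invented, for illustration only):

class FakeVariantWriter(object):
  """Toy model of VariantTensorDataWriter's lazy metadata flush."""

  def __init__(self):
    self.keys = []      # stands in for metadata_proto_.keys
    self.tensors = []   # stands in for data_->tensors
    self.metadata = None

  def write_scalar(self, key, val):
    # Mirrors WriteScalarInternal: box the scalar, reuse the tensor path.
    self.write_tensor(key, val)

  def write_tensor(self, key, val):
    # The key is only recorded here; nothing is serialized yet.
    self.keys.append(key)
    self.tensors.append(val)

  def flush(self):
    # A single serialization call, no matter how many writes happened.
    self.metadata = ",".join(self.keys)

w = FakeVariantWriter()
w.write_scalar("num_random_samples", 7)
w.write_tensor("buffer_0_0", [1, 2, 3])
w.flush()
assert w.metadata == "num_random_samples,buffer_0_0"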


@@ -60,18 +60,19 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
}
if (reshuffle_each_iteration_) {
*output = new ReshufflingDataset(input, buffer_size, seed, seed2);
*output = new ReshufflingDataset(ctx, input, buffer_size, seed, seed2);
} else {
*output = new FixedSeedDataset(input, buffer_size, seed, seed2);
*output = new FixedSeedDataset(ctx, input, buffer_size, seed, seed2);
}
}
private:
// Abstract base dataset that implements a shuffling iterator.
class ShuffleDatasetBase : public DatasetBase {
class ShuffleDatasetBase : public GraphDatasetBase {
public:
ShuffleDatasetBase(const DatasetBase* input, int64 buffer_size)
: input_(input), buffer_size_(buffer_size) {
ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
int64 buffer_size)
: GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) {
input_->Ref();
}
@@ -91,6 +92,8 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
explicit Iterator(const Params& params, int64 seed, int64 seed2)
: DatasetIterator<ShuffleDatasetBase>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
seed_(seed),
seed2_(seed2),
parent_generator_(seed, seed2),
generator_(&parent_generator_) {
buffer_.reserve(params.dataset->buffer_size_);
@@ -115,6 +118,8 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
&end_of_input_sequence_));
if (!end_of_input_sequence_) {
buffer_.emplace_back(std::move(input_element));
} else {
input_impl_.reset();
}
}
if (num_log_entries > 0) {
@@ -125,7 +130,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = false;
// Choose an element to produce uniformly at random, and
// swap the last element into its place in the buffer.
int64 index = generator_() % buffer_.size();
int64 index = Random() % buffer_.size();
*out_tensors = std::move(buffer_[index]);
std::swap(buffer_[index], buffer_.back());
buffer_.pop_back();
@@ -136,14 +141,102 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
// Save the tensors in the buffer.
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
for (int i = 0; i < buffer_.size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("buffer_", i, "_size")),
buffer_[i].size()));
for (int j = 0; j < buffer_[i].size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("buffer_", i, "_", j)),
buffer_[i][j]));
}
}
// Save state needed to restore the random number generators.
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
num_random_samples_));
// Save the input iterator if it hasn't been exhausted; otherwise, write
// "end_of_input_sequence".
if (end_of_input_sequence_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("end_of_input_sequence"), ""));
} else {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
}
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
buffer_.clear();
// Restore the buffer.
int64 buffer_size;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("buffer_size"), &buffer_size));
for (int i = 0; i < buffer_size; i++) {
int64 list_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("buffer_", i, "_size")), &list_size));
buffer_.emplace_back(std::vector<Tensor>(list_size));
for (int j = 0; j < list_size; j++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("buffer_", i, "_", j)),
&buffer_[i][j]));
}
}
// Restore the random number generators.
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
&num_random_samples_));
ResetRngs();
// Restore the input iterator if it wasn't already exhausted.
if (!reader->Contains(full_name("end_of_input_sequence"))) {
input_impl_ = dataset()->input_->MakeIterator(prefix());
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
end_of_input_sequence_ = true;
input_impl_.reset();
}
return Status::OK();
}
private:
random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
num_random_samples_++;
auto out = generator_();
return out;
}
void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
// Reset the generators based on the current iterator seeds.
parent_generator_ = random::PhiloxRandom(seed_, seed2_);
generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
&parent_generator_);
generator_.Skip(num_random_samples_);
}
mutex mu_;
std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
bool end_of_input_sequence_ GUARDED_BY(mu_) = false;
const int64 seed_ GUARDED_BY(mu_);
const int64 seed2_ GUARDED_BY(mu_);
random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
random::SingleSampleAdapter<random::PhiloxRandom> generator_
GUARDED_BY(mu_);
int64 num_random_samples_ GUARDED_BY(mu_) = 0;
};
const DatasetBase* const input_;
@@ -154,9 +247,9 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
// created from it. Used when `reshuffle_each_iteration` is true.
class ReshufflingDataset : public ShuffleDatasetBase {
public:
ReshufflingDataset(const DatasetBase* input, int64 buffer_size, int64 seed,
int64 seed2)
: ShuffleDatasetBase(input, buffer_size),
ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input,
int64 buffer_size, int64 seed, int64 seed2)
: ShuffleDatasetBase(ctx, input, buffer_size),
seed_(seed),
seed2_(seed2),
parent_generator_(seed, seed2),
@@ -181,6 +274,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
iterator_seed2));
}
private:
const int64 seed_;
const int64 seed2_;
mutable mutex mu_;
@@ -193,9 +287,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
// Used when `reshuffle_each_iteration` is false.
class FixedSeedDataset : public ShuffleDatasetBase {
public:
FixedSeedDataset(const DatasetBase* input, int64 buffer_size, int64 seed,
int64 seed2)
: ShuffleDatasetBase(input, buffer_size), seed_(seed), seed2_(seed) {}
FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input,
int64 buffer_size, int64 seed, int64 seed2)
: ShuffleDatasetBase(ctx, input, buffer_size),
seed_(seed),
seed2_(seed2) {}
string DebugString() override {
return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
@@ -208,6 +304,29 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
{this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
}
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node));
Node* buffer_size = nullptr;
Node* seed = nullptr;
Node* seed2 = nullptr;
AttrValue reshuffle_each_iteration;
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
b->BuildAttrValue(false, &reshuffle_each_iteration);
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, buffer_size, seed, seed2}, // Inputs
{std::make_pair("reshuffle_each_iteration",
reshuffle_each_iteration)}, // Attrs
output));
return Status::OK();
}
private:
const int64 seed_;
const int64 seed2_;
};
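
Putting SaveInternal together: for a two-element buffer whose elements each hold one tensor, the checkpoint reduces to a flat key/value view like the sketch below (full_name() prefixing and Tensor boxing omitted; the values are invented):

# Hypothetical flat view of one ShuffleDataset iterator checkpoint.
state = {
    "buffer_size": 2,          # elements currently buffered
    "buffer_0_size": 1,        # tensors in buffered element 0
    "buffer_0_0": 3,           # element 0, tensor 0 (via WriteTensor)
    "buffer_1_size": 1,
    "buffer_1_0": 8,
    "num_random_samples": 7,   # rng draws made so far
}
# RestoreInternal rebuilds the buffer from these keys, re-seeds the Philox
# generators, and calls generator_.Skip(num_random_samples) to fast-forward
# the rng; the input iterator's own state is restored via RestoreParent()
# unless "end_of_input_sequence" was written instead.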


@@ -0,0 +1,27 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"
namespace tensorflow {
namespace random {
template <>
void SingleSampleAdapter<PhiloxRandom>::SkipFromGenerator(uint64 num_skips) {
// Use the O(1) PhiloxRandom::Skip instead of the default O(N) impl.
generator_->Skip(num_skips);
}
} // namespace random
} // namespace tensorflow
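
The specialization is O(1) because PhiloxRandom is counter-based: Skip(n) simply advances the counter by n, so a large skip costs the same as a small one. A toy counter-based generator in the same spirit (illustrative only; block() is a stand-in, not the Philox round function):

K = 4  # outputs per counter value, like PhiloxRandom's four uint32s

def block(counter):
  # Any pure function of the counter works for illustration.
  return [hash((counter, i)) & 0xFFFFFFFF for i in range(K)]

class CounterRng(object):
  def __init__(self):
    self.counter = 0

  def __call__(self):
    out = block(self.counter)
    self.counter += 1
    return out

  def skip(self, num_blocks):
    self.counter += num_blocks  # O(1), regardless of skip size

a, b = CounterRng(), CounterRng()
for _ in range(1000):
  a()
b.skip(1000)
assert a() == b()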


@@ -219,7 +219,37 @@ class SingleSampleAdapter {
return unused_results_[used_result_index_++];
}
PHILOX_DEVICE_INLINE
void Skip(uint64 num_skips) {
if (!num_skips) {
return;
}
int num_unused_results = kNativeElementCount - used_result_index_;
if (num_skips <= num_unused_results) {
used_result_index_ += num_skips;
return;
}
num_skips -= num_unused_results;
used_result_index_ = kNativeElementCount;
SkipFromGenerator(num_skips / kNativeElementCount);
num_skips = num_skips % kNativeElementCount;
if (num_skips) {
unused_results_ = (*generator_)();
used_result_index_ = num_skips;
}
}
private:
// This implementation iteratively skips over `num_skips` samples
// from `generator_`. There is an O(1) implementation for PhiloxRandom
// in random_distributions.cc.
PHILOX_DEVICE_INLINE
void SkipFromGenerator(uint64 num_skips) {
while (num_skips--) {
(*generator_)();
}
}
Generator* generator_;
typename Generator::ResultType unused_results_;
int used_result_index_;
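
The arithmetic above splits a skip into three parts: drain the unused tail of the current native block, hand whole-block skips to SkipFromGenerator, and land any remainder inside a freshly fetched block. A pure-Python model of the same state machine, cross-checked against naive one-at-a-time skipping (kNativeElementCount = 4 matches PhiloxRandom; all names here are illustrative):

K = 4  # kNativeElementCount: samples per native generator block

class ToyAdapter(object):
  """Single-sample adapter over a generator of consecutive-int blocks."""

  def __init__(self):
    self.next_block = 0
    self.block = None
    self.used = K  # force a fetch on the first sample

  def _fetch(self):
    out = [self.next_block * K + i for i in range(K)]
    self.next_block += 1
    return out

  def sample(self):
    if self.used == K:
      self.block = self._fetch()
      self.used = 0
    out = self.block[self.used]
    self.used += 1
    return out

  def skip(self, n):
    # Mirrors SingleSampleAdapter::Skip above.
    if not n:
      return
    unused = K - self.used
    if n <= unused:            # the skip fits inside the current block
      self.used += n
      return
    n -= unused
    self.used = K
    self.next_block += n // K  # SkipFromGenerator(num_skips / K)
    n %= K
    if n:                      # remainder lands in a fresh block
      self.block = self._fetch()
      self.used = n

# Cross-check skip() against naive sampling for many offsets.
for warmup in range(6):
  for n in range(12):
    a, b = ToyAdapter(), ToyAdapter()
    for _ in range(warmup):
      a.sample()
      b.sample()
    for _ in range(n):
      a.sample()
    b.skip(n)
    assert all(a.sample() == b.sample() for _ in range(20))

(The new SingleSampleAdapterSkipTest below performs the same cross-check against the real generators.)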


@@ -280,6 +280,72 @@ TEST(PhiloxRandomTest, RandomParametersDoubleMomentsTest) {
RandomParametersMomentsTest<double>(1 << 20, 40, strides, kZLimit);
}
class MockGenerator {
public:
explicit MockGenerator(uint64 seed) : counter_(seed) {}
using ResultType = std::vector<uint32>;
using ResultElementType = uint32;
static const int kResultElementCount = 1;
ResultType operator()() {
ResultType result;
result.push_back(counter_++);
return result;
}
private:
uint32 counter_;
};
template <typename T>
void SingleSampleAdapterSkipTest() {
std::vector<uint64> skips(10);
std::vector<uint64> skip_afters(10);
std::iota(skips.begin(), skips.end(), 0);
std::iota(skip_afters.begin(), skip_afters.end(), 0);
uint64 total_samples = 100;
uint64 seed = GetTestSeed();
for (uint64 skip : skips) {
for (uint64 skip_after : skip_afters) {
// Baseline rngs.
T parent_gen(seed);
SingleSampleAdapter<T> gen(&parent_gen);
// Rng on which Skip() is performed.
T parent_gen_to_skip(seed);
SingleSampleAdapter<T> gen_to_skip(&parent_gen_to_skip);
// Skip over `skip_after` samples from both `gen` and `gen_to_skip`.
int cur = 0;
for (; cur < skip_after; cur++) {
gen();
gen_to_skip();
}
// Skip over `skip` samples from `gen` iteratively.
for (; cur < skip_after + skip; cur++) {
gen();
}
// Skip over `skip` samples from `gen_to_skip` by calling `Skip()`.
gen_to_skip.Skip(skip);
// Assert that they produce the same outputs afterwards.
for (; cur < total_samples; cur++) {
ASSERT_EQ(gen(), gen_to_skip());
}
}
}
}
TEST(SingleSampleAdapterTest, PhiloxRandomSkip) {
SingleSampleAdapterSkipTest<PhiloxRandom>();
}
TEST(SingleSampleAdapterTest, MockGeneratorSkip) {
SingleSampleAdapterSkipTest<MockGenerator>();
}
} // namespace
} // namespace random
} // namespace tensorflow