Implement save/restore for ShuffleDataset[reshuffle_each_iteration=False].

Also add SingleSampleAdapter::Skip for restoring RNG state.

PiperOrigin-RevId: 174424108
parent 7bb2d57b0b
commit 902c91342a
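Below, a minimal sketch of how the new save/restore path is driven from Python, modeled on the contrib test added in this commit (the checkpoint path is illustrative, and the sketch assumes a TF 1.x build containing this change):

import tensorflow as tf
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops

# Fixed-seed shuffling (reshuffle_each_iteration=False) is the case this
# commit makes saveable.
dataset = tf.data.Dataset.range(10).shuffle(
    buffer_size=5, seed=55, reshuffle_each_iteration=False).repeat(5)
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()

# Register the iterator state with the Saver via a SaveableObject.
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  for _ in range(8):
    sess.run(get_next)
  saver.save(sess, "/tmp/shuffle_iterator")  # later: saver.restore(sess, ...)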
@@ -393,11 +393,16 @@ py_test(
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/data/python/ops:dataset_ops",
        "//tensorflow/contrib/data/python/ops:iterator_ops",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:training",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:iterator_ops",
        "//third_party/py/numpy",
    ],
@@ -18,16 +18,22 @@ from __future__ import division
from __future__ import print_function

import collections
import os

import numpy as np

from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib


class ShuffleDatasetTest(test.TestCase):
@@ -42,8 +48,9 @@ class ShuffleDatasetTest(test.TestCase):
    buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
                      .repeat(count_placeholder))
    repeat_dataset = (
        contrib_dataset_ops.Dataset.from_tensor_slices(components)
        .repeat(count_placeholder))

    shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
                                             seed_placeholder)
@@ -134,8 +141,9 @@ class ShuffleDatasetTest(test.TestCase):

  def testDefaultArguments(self):
    components = [0, 1, 2, 3, 4]
    iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
                .repeat().make_one_shot_iterator())
    iterator = (
        contrib_dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
        .repeat().make_one_shot_iterator())

    get_next = iterator.get_next()
@@ -149,5 +157,322 @@ class ShuffleDatasetTest(test.TestCase):
      self.assertEqual(10, counts[i])


class ShuffleDatasetSerializationTest(test.TestCase):

  def tearDown(self):
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    # In Python 3, `map` is lazy, so iterate explicitly to delete the files.
    for f in files:
      gfile.Remove(f)

  def _build_graph(self,
                   range_limit=10,
                   num_repeats=5,
                   buffer_size=5,
                   seed=None,
                   reshuffle_each_iteration=None,
                   build_saveable=True):
    iterator = dataset_ops.Dataset.range(range_limit).shuffle(
        buffer_size,
        seed=seed,
        reshuffle_each_iteration=reshuffle_each_iteration).repeat(
            num_repeats).make_initializable_iterator()
    if build_saveable:
      saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    get_next = iterator.get_next()
    ops.add_to_collection("iterator_ops", init_op)
    ops.add_to_collection("iterator_ops", get_next)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _ckpt_path(self):
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    return saver_lib.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    saver.restore(sess, self._latest_ckpt())

  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _testReadWithBreaks(self, break_points, init_before_restore=False):
    seed = 55
    range_limit = 10
    num_repeats = 5
    num_outputs = range_limit * num_repeats
    buffer_sizes = [1, 3, 8, 10, 25, 50]
    reshuffle_each_iteration = False
    for buffer_size in buffer_sizes:
      expected = []
      actual = []
      # Generate the ground truth.
      with ops.Graph().as_default() as g:
        g.seed = 10
        init_op, get_next_op, _ = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration)
        with self.test_session(graph=g) as sess:
          sess.run(init_op)
          for _ in range(num_outputs):
            expected.append(sess.run(get_next_op))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

      # Run and checkpoint after first break_point.
      with ops.Graph().as_default() as g:
        g.seed = 10
        init_op, get_next_op, saver = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration)
        with self.test_session(graph=g) as sess:
          sess.run(init_op)
          for _ in range(break_points[0]):
            actual.append(sess.run(get_next_op))
          self._save(sess, saver)

      # Load from checkpoint and continue running while stopping at each
      # subsequent checkpoint.
      for i in range(len(break_points)):
        with ops.Graph().as_default() as g:
          saver = self._import_meta_graph()
          init_op, get_next_op = ops.get_collection("iterator_ops")
          with self.test_session(graph=g) as sess:
            if init_before_restore:
              sess.run(init_op)
            self._restore(saver, sess)
            start = break_points[i]
            end = break_points[
                i + 1] if i < len(break_points) - 1 else num_outputs
            for _ in range(end - start):
              actual.append(sess.run(get_next_op))
            self._save(sess, saver)
            if end == num_outputs:
              with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next_op)
      self.assertEqual(expected, actual)

  def testSaveRestore(self):
    self._testReadWithBreaks([8])   # rng buffer_size: 0
    self._testReadWithBreaks([13])  # rng buffer_size: 1
    self._testReadWithBreaks([18])  # rng buffer_size: 2
    self._testReadWithBreaks([23])  # rng buffer_size: 3

  def testSaveUnusedIterator(self):
    self._testReadWithBreaks([0])

  def testSaveFullyUsedIterator(self):
    self._testReadWithBreaks([50])

  def testMultipleBreaks(self):
    self._testReadWithBreaks([0, 5, 9, 15, 25, 32])

  def testIdempotence(self):
    # Attempt to save iterator immediately after restoring.
    self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32])

  def testInitThenRestore(self):
    self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True)

  def testRestoreExhaustedIterator(self):
    seed = 55
    range_limit = 10
    num_repeats = 5
    num_outputs = range_limit * num_repeats
    buffer_sizes = [1, 3, 8, 10, 25, 50]
    reshuffle_each_iteration = False
    for buffer_size in buffer_sizes:
      with ops.Graph().as_default() as g:
        g.seed = 10
        init_op, get_next_op, saver = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration)
        with self.test_session(graph=g) as sess:
          sess.run(init_op)
          for _ in range(num_outputs):
            sess.run(get_next_op)
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
          self._save(sess, saver)

      with ops.Graph().as_default() as g:
        saver = self._import_meta_graph()
        init_op, get_next_op = ops.get_collection("iterator_ops")
        with self.test_session(graph=g) as sess:
          self._restore(saver, sess)
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

  def testResetRestoredIterator(self):
    seed = 55
    range_limit = 10
    num_repeats = 5
    num_outputs = range_limit * num_repeats
    buffer_sizes = [1, 3, 8, 10, 25, 50]
    reshuffle_each_iteration = False
    for buffer_size in buffer_sizes:
      with ops.Graph().as_default() as g:
        g.seed = 10
        init_op, get_next_op, saver = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration)
        with self.test_session(graph=g) as sess:
          sess.run(init_op)
          for _ in range(num_outputs // 2):
            sess.run(get_next_op)
          self._save(sess, saver)

      outputs = []
      with ops.Graph().as_default() as g:
        saver = self._import_meta_graph()
        init_op, get_next_op = ops.get_collection("iterator_ops")
        with self.test_session(graph=g) as sess:
          self._restore(saver, sess)
          sess.run(init_op)
          for _ in range(num_outputs):
            outputs.append(sess.run(get_next_op))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
      expected_outputs_sorted = sorted(
          np.array([range(range_limit)
                    for _ in range(num_repeats)]).flatten())
      self.assertEqual(expected_outputs_sorted, sorted(outputs))

  def testRestoreInModifiedGraph(self):
    seed = 55
    break_point = 25
    range_limit = 10
    num_repeats = 5
    num_outputs = range_limit * num_repeats
    buffer_sizes = [3, 8, 10, 25, 50]
    reshuffle_each_iteration = False
    for buffer_size in buffer_sizes:
      expected = []
      actual_without_restore = []
      actual = []
      with ops.Graph().as_default() as g:
        g.seed = 10
        init_op, get_next_op, saver = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration)
        with self.test_session(graph=g) as sess:
          sess.run(init_op)
          for _ in range(break_point):
            expected.append(sess.run(get_next_op))
          actual.extend(expected)
          self._save(sess, saver)
          for _ in range(num_outputs - break_point):
            expected.append(sess.run(get_next_op))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

      with ops.Graph().as_default() as g:
        g.seed = 20  # Different seed than previous graph for shuffle rngs.
        init_op, get_next_op, saver = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration)
        with self.test_session(graph=g) as sess:
          sess.run(init_op)
          for _ in range(num_outputs):
            actual_without_restore.append(sess.run(get_next_op))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

      with ops.Graph().as_default() as g:
        g.seed = 20  # Different seed than previous graph for shuffle rngs.
        init_op, get_next_op, saver = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration)
        with self.test_session(graph=g) as sess:
          self._restore(saver, sess)
          for _ in range(num_outputs - break_point):
            actual.append(sess.run(get_next_op))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

      # Since the modified graph has a different random seed it produces a
      # different order of examples.
      self.assertNotEqual(expected, actual_without_restore)
      self.assertEqual(sorted(expected), sorted(actual_without_restore))
      self.assertEqual(expected, actual)

  def testDoNotBuildSaveable(self):
    seed = 55
    break_point = 25
    range_limit = 10
    num_repeats = 5
    num_outputs = range_limit * num_repeats
    buffer_sizes = [3, 8, 10, 25, 50]
    reshuffle_each_iteration = False
    for buffer_size in buffer_sizes:
      actual = []
      with ops.Graph().as_default() as g:
        g.seed = 10
        init_op, get_next_op, saver = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration)
        with self.test_session(graph=g) as sess:
          sess.run(init_op)
          for _ in range(break_point):
            sess.run(get_next_op)
          self._save(sess, saver)

      with ops.Graph().as_default() as g:
        g.seed = 20  # Different seed than previous graph for shuffle rngs.
        init_op, get_next_op, saver = self._build_graph(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration,
            build_saveable=False)
        with self.test_session(graph=g) as sess:
          # Since the SaveableObject was not added to Saver's list
          # of saveables, iterator state is not restored by saver.restore().
          self._restore(saver, sess)
          with self.assertRaises(errors.FailedPreconditionError):
            sess.run(get_next_op)
          sess.run(init_op)
          for _ in range(num_outputs):
            actual.append(sess.run(get_next_op))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
      expected_outputs_sorted = sorted(
          np.array([range(range_limit) for _ in range(num_repeats)]).flatten())
      self.assertEqual(expected_outputs_sorted, sorted(actual))


if __name__ == "__main__":
  test.main()
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -47,6 +49,7 @@ class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece key, string* val) = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
  virtual bool Contains(StringPiece key) = 0;

  virtual ~IteratorStateReader() {}
@@ -58,6 +61,7 @@ class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
@@ -112,6 +116,13 @@ class GraphDefBuilderWrapper {
    return Status::OK();
  }

  template <class DatasetType>
  Status AddDataset(const DatasetType* dataset,
                    const std::vector<NodeBuilder::NodeOut>& inputs,
                    Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }

  // Adds a node corresponding to the `DatasetType` to the Graph.
  // Return value of `DatasetType::op_name()` is used as the op type for the
  // node.
@@ -122,7 +133,9 @@ class GraphDefBuilderWrapper {
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  template <class DatasetType>
  Status AddDataset(const DatasetType* dataset,
                    std::vector<NodeBuilder::NodeOut> inputs, Node** output) {
                    const std::vector<NodeBuilder::NodeOut>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    const string& op_type_name = dataset->op_name();
    std::unique_ptr<const GraphDefBuilder::Options> opts(
        new GraphDefBuilder::Options(b_->opts()));
@@ -138,6 +151,10 @@ class GraphDefBuilderWrapper {
      opts.reset(new GraphDefBuilder::Options(
          opts->WithAttr("output_types", dataset->output_dtypes())));
    }
    for (auto attr : attrs) {
      opts.reset(new GraphDefBuilder::Options(
          opts->WithAttr(attr.first, attr.second)));
    }
    if (opts->HaveError()) {
      return errors::Internal("AddDataset: Error building Options.");
    }
@@ -187,6 +204,11 @@ class GraphDefBuilderWrapper {
    return Status::OK();
  }

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 private:
  void AddTensorInternal(const Tensor& val, Node** output) {
    *output = ops::SourceOp(
@@ -188,6 +188,10 @@ class VariantTensorDataReader : public IteratorStateReader {
    return ReadScalarInternal(key, val);
  }

  Status ReadTensor(StringPiece key, Tensor* val) override {
    return ReadTensorInternal(key, val);
  }

  bool Contains(StringPiece key) override {
    return map_.find(key.ToString()) != map_.end();
  }
@@ -217,6 +221,14 @@ class VariantTensorDataReader : public IteratorStateReader {
    return Status::OK();
  }

  Status ReadTensorInternal(StringPiece key, Tensor* val) {
    if (map_.find(key.ToString()) == map_.end()) {
      return errors::NotFound(key);
    }
    *val = data_->tensors(map_[key.ToString()]);
    return Status::OK();
  }

  std::map<string, size_t> map_;
  const VariantTensorData* data_;  // Not owned.
  Status status_;
@@ -236,6 +248,10 @@ class VariantTensorDataWriter : public IteratorStateWriter {
    return WriteScalarInternal(key, val);
  }

  Status WriteTensor(StringPiece key, const Tensor& val) override {
    return WriteTensorInternal(key, val);
  }

  // Writes the metadata to `data_`.
  Status Flush() {
    string metadata;
@@ -249,15 +265,19 @@ class VariantTensorDataWriter : public IteratorStateWriter {
 private:
  template <typename T>
  Status WriteScalarInternal(StringPiece key, const T& val) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    return WriteTensorInternal(key, val_t);
  }

  Status WriteTensorInternal(StringPiece key, const Tensor& val) {
    // Write key to the metadata proto. This gets written to `data_`
    // when `Flush()` is called. We do this lazily to avoid multiple
    // serialization calls.
    metadata_proto_.add_keys(key.ToString());

    // Update tensors.
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    *(data_->add_tensors()) = std::move(val_t);
    *(data_->add_tensors()) = val;
    return Status::OK();
  }
@@ -60,18 +60,19 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
    }

    if (reshuffle_each_iteration_) {
      *output = new ReshufflingDataset(input, buffer_size, seed, seed2);
      *output = new ReshufflingDataset(ctx, input, buffer_size, seed, seed2);
    } else {
      *output = new FixedSeedDataset(input, buffer_size, seed, seed2);
      *output = new FixedSeedDataset(ctx, input, buffer_size, seed, seed2);
    }
  }

 private:
  // Abstract base dataset that implements a shuffling iterator.
  class ShuffleDatasetBase : public DatasetBase {
  class ShuffleDatasetBase : public GraphDatasetBase {
   public:
    ShuffleDatasetBase(const DatasetBase* input, int64 buffer_size)
        : input_(input), buffer_size_(buffer_size) {
    ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                       int64 buffer_size)
        : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) {
      input_->Ref();
    }

@@ -91,6 +92,8 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
      explicit Iterator(const Params& params, int64 seed, int64 seed2)
          : DatasetIterator<ShuffleDatasetBase>(params),
            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
            seed_(seed),
            seed2_(seed2),
            parent_generator_(seed, seed2),
            generator_(&parent_generator_) {
        buffer_.reserve(params.dataset->buffer_size_);
@@ -115,6 +118,8 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
                                                  &end_of_input_sequence_));
          if (!end_of_input_sequence_) {
            buffer_.emplace_back(std::move(input_element));
          } else {
            input_impl_.reset();
          }
        }
        if (num_log_entries > 0) {
@@ -125,7 +130,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
        *end_of_sequence = false;
        // Choose an element to produce uniformly at random, and
        // swap the last element into its place in the buffer.
        int64 index = generator_() % buffer_.size();
        int64 index = Random() % buffer_.size();
        *out_tensors = std::move(buffer_[index]);
        std::swap(buffer_[index], buffer_.back());
        buffer_.pop_back();
@@ -136,14 +141,102 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
        return Status::OK();
      }

     protected:
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);

        // Save the tensors in the buffer.
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
        for (int i = 0; i < buffer_.size(); i++) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("buffer_", i, "_size")),
              buffer_[i].size()));
          for (int j = 0; j < buffer_[i].size(); j++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat("buffer_", i, "_", j)),
                buffer_[i][j]));
          }
        }

        // Save state needed to restore the random number generators.
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
                                               num_random_samples_));

        // Save input iterator if it hasn't been exhausted else write
        // "end_of_input_sequence".
        if (end_of_input_sequence_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("end_of_input_sequence"), ""));
        } else {
          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
        }
        return Status::OK();
      }

      Status RestoreInternal(OpKernelContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        buffer_.clear();

        // Restore the buffer.
        int64 buffer_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("buffer_size"), &buffer_size));
        for (int i = 0; i < buffer_size; i++) {
          int64 list_size;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat("buffer_", i, "_size")), &list_size));
          buffer_.emplace_back(std::vector<Tensor>(list_size));
          for (int j = 0; j < list_size; j++) {
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                full_name(strings::StrCat("buffer_", i, "_", j)),
                &buffer_[i][j]));
          }
        }

        // Restore the random number generators.
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
                                              &num_random_samples_));
        ResetRngs();

        // Restore the input iterator if it wasn't already exhausted.
        if (!reader->Contains(full_name("end_of_input_sequence"))) {
          input_impl_ = dataset()->input_->MakeIterator(prefix());
          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
        } else {
          end_of_input_sequence_ = true;
          input_impl_.reset();
        }
        return Status::OK();
      }

     private:
      random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        num_random_samples_++;
        auto out = generator_();
        return out;
      }

      void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        // Reset the generators based on the current iterator seeds.
        parent_generator_ = random::PhiloxRandom(seed_, seed2_);
        generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
            &parent_generator_);
        generator_.Skip(num_random_samples_);
      }

      mutex mu_;
      std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
      bool end_of_input_sequence_ GUARDED_BY(mu_) = false;
      const int64 seed_ GUARDED_BY(mu_);
      const int64 seed2_ GUARDED_BY(mu_);
      random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
      random::SingleSampleAdapter<random::PhiloxRandom> generator_
          GUARDED_BY(mu_);
      int64 num_random_samples_ GUARDED_BY(mu_) = 0;
    };

    const DatasetBase* const input_;
@@ -154,9 +247,9 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
  // created from it. Used when `reshuffle_each_iteration` is true.
  class ReshufflingDataset : public ShuffleDatasetBase {
   public:
    ReshufflingDataset(const DatasetBase* input, int64 buffer_size, int64 seed,
                       int64 seed2)
        : ShuffleDatasetBase(input, buffer_size),
    ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input,
                       int64 buffer_size, int64 seed, int64 seed2)
        : ShuffleDatasetBase(ctx, input, buffer_size),
          seed_(seed),
          seed2_(seed2),
          parent_generator_(seed, seed2),
@@ -181,6 +274,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
                                            iterator_seed2));
    }

   private:
    const int64 seed_;
    const int64 seed2_;
    mutable mutex mu_;
@@ -193,9 +287,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
  // Used when `reshuffle_each_iteration` is false.
  class FixedSeedDataset : public ShuffleDatasetBase {
   public:
    FixedSeedDataset(const DatasetBase* input, int64 buffer_size, int64 seed,
                     int64 seed2)
        : ShuffleDatasetBase(input, buffer_size), seed_(seed), seed2_(seed) {}
    FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input,
                     int64 buffer_size, int64 seed, int64 seed2)
        : ShuffleDatasetBase(ctx, input, buffer_size),
          seed_(seed),
          seed2_(seed2) {}

    string DebugString() override {
      return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
@@ -208,6 +304,29 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
          {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
    }

   protected:
    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node));
      Node* buffer_size = nullptr;
      Node* seed = nullptr;
      Node* seed2 = nullptr;
      AttrValue reshuffle_each_iteration;

      TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
      TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
      TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
      b->BuildAttrValue(false, &reshuffle_each_iteration);
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_graph_node, buffer_size, seed, seed2},  // Inputs
          {std::make_pair("reshuffle_each_iteration",
                          reshuffle_each_iteration)},  // Attrs
          output));
      return Status::OK();
    }

   private:
    const int64 seed_;
    const int64 seed2_;
  };
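The save format above persists the shuffle buffer verbatim but, for the RNG, only a single counter: num_random_samples_. RestoreInternal then reseeds the Philox generators and fast-forwards them with Skip(). A toy Python sketch of the same count-and-skip pattern (a hypothetical helper, not part of this commit; random.Random stands in for PhiloxRandom, and the replay loop is the O(N) analogue of the default SkipFromGenerator):

import random

class RestorableRng(object):
  """Counts samples drawn so the stream can be reproduced after a restore."""

  def __init__(self, seed):
    self._seed = seed
    self._rng = random.Random(seed)
    self._num_random_samples = 0

  def sample(self):
    self._num_random_samples += 1
    return self._rng.random()

  def save(self):
    # Mirrors SaveInternal: the only RNG state persisted is the sample count.
    return {"num_random_samples": self._num_random_samples}

  def restore(self, state):
    # Mirrors RestoreInternal + ResetRngs: reseed, then skip forward.
    self._num_random_samples = state["num_random_samples"]
    self._rng = random.Random(self._seed)
    for _ in range(self._num_random_samples):  # O(N) stand-in for Skip().
      self._rng.random()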
tensorflow/core/lib/random/random_distributions.cc (new file, 27 lines)
@@ -0,0 +1,27 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"

namespace tensorflow {
namespace random {
template <>
void SingleSampleAdapter<PhiloxRandom>::SkipFromGenerator(uint64 num_skips) {
  // Use the O(1) PhiloxRandom::Skip instead of the default O(N) impl.
  generator_->Skip(num_skips);
}
}  // namespace random
}  // namespace tensorflow
@@ -219,7 +219,37 @@ class SingleSampleAdapter {
    return unused_results_[used_result_index_++];
  }

  PHILOX_DEVICE_INLINE
  void Skip(uint64 num_skips) {
    if (!num_skips) {
      return;
    }
    int num_unused_results = kNativeElementCount - used_result_index_;
    if (num_skips <= num_unused_results) {
      used_result_index_ += num_skips;
      return;
    }
    num_skips -= num_unused_results;
    used_result_index_ = kNativeElementCount;
    SkipFromGenerator(num_skips / kNativeElementCount);
    num_skips = num_skips % kNativeElementCount;
    if (num_skips) {
      unused_results_ = (*generator_)();
      used_result_index_ = num_skips;
    }
  }

 private:
  // This implementation iteratively skips over `num_skips` samples
  // from `generator_`. There is an O(1) implementation for PhiloxRandom
  // in random_distributions.cc.
  PHILOX_DEVICE_INLINE
  void SkipFromGenerator(uint64 num_skips) {
    while (num_skips--) {
      (*generator_)();
    }
  }

  Generator* generator_;
  typename Generator::ResultType unused_results_;
  int used_result_index_;
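To see the Skip() bookkeeping with illustrative numbers: with kNativeElementCount = 4 and used_result_index_ = 3, a call to Skip(6) first consumes the 1 unused sample buffered from the current native block, leaving num_skips = 5; it then calls SkipFromGenerator(5 / 4 = 1) to jump one whole block, and since 5 % 4 = 1 is nonzero it draws one fresh block and sets used_result_index_ = 1. The test below checks exactly this equivalence against sample-by-sample skipping.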
@@ -280,6 +280,72 @@ TEST(PhiloxRandomTest, RandomParametersDoubleMomentsTest) {
  RandomParametersMomentsTest<double>(1 << 20, 40, strides, kZLimit);
}

class MockGenerator {
 public:
  explicit MockGenerator(uint64 seed) : counter_(seed) {}
  using ResultType = std::vector<uint32>;
  using ResultElementType = uint32;
  static const int kResultElementCount = 1;
  ResultType operator()() {
    ResultType result;
    result.push_back(counter_++);
    return result;
  }

 private:
  uint32 counter_;
};

template <typename T>
void SingleSampleAdapterSkipTest() {
  std::vector<uint64> skips(10);
  std::vector<uint64> skip_afters(10);
  std::iota(skips.begin(), skips.end(), 0);
  std::iota(skip_afters.begin(), skip_afters.end(), 0);
  uint64 total_samples = 100;
  uint64 seed = GetTestSeed();

  for (uint64 skip : skips) {
    for (uint64 skip_after : skip_afters) {
      // Baseline rngs.
      T parent_gen(seed);
      SingleSampleAdapter<T> gen(&parent_gen);

      // Rng on which Skip() is performed.
      T parent_gen_to_skip(seed);
      SingleSampleAdapter<T> gen_to_skip(&parent_gen_to_skip);

      // Skip over `skip_after` samples from both `gen` and `gen_to_skip`.
      int cur = 0;
      for (; cur < skip_after; cur++) {
        gen();
        gen_to_skip();
      }

      // Skip over `skip` samples from `gen` iteratively.
      for (; cur < skip_after + skip; cur++) {
        gen();
      }

      // Skip over `skip` samples from `gen_to_skip` by calling `Skip()`.
      gen_to_skip.Skip(skip);

      // Assert that they produce same outputs afterwards.
      for (; cur < total_samples; cur++) {
        ASSERT_EQ(gen(), gen_to_skip());
      }
    }
  }
}

TEST(SingleSampleAdapterTest, PhiloxRandomSkip) {
  SingleSampleAdapterSkipTest<PhiloxRandom>();
}

TEST(SingleSampleAdapterTest, MockGeneratorSkip) {
  SingleSampleAdapterSkipTest<MockGenerator>();
}

}  // namespace
}  // namespace random
}  // namespace tensorflow