Rename `seed` to `shuffle_seed` to better indicate use, and add a test for it.

PiperOrigin-RevId: 287907717
Change-Id: I4f1649ab669058d036a3a076a87024803640eb32
This commit is contained in:
Frank Chen 2020-01-02 15:45:29 -08:00 committed by TensorFlower Gardener
parent 043abbdf86
commit 4b7f4c1f09
2 changed files with 43 additions and 6 deletions

View File

@ -321,6 +321,41 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
dataset3 = dataset3.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True)) dataset3 = dataset3.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True))
self.assertDatasetProduces(dataset3, expected, assert_items_equal=True) self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)
@combinations.generate(test_base.default_test_combinations())
def testReadShuffledSnapshotWithSeedAfterWrite(self):
  """Checks that two shuffled snapshot reads with one seed agree in order.

  A snapshot is written from freshly created TFRecord files, the source
  files are then removed, and the snapshot is read back twice with
  `shuffle_on_read=True` and the same `shuffle_seed`. The two readers are
  stepped in lockstep and every pair of elements must compare equal.
  """
  self.setUpTFRecord(num_files=10, num_records=50)
  source_files = self.test_filenames

  # Records listed file by file, 50 records per file; this is the sequence
  # the write pass below is asserted to produce.
  expected_records = []
  for file_idx in range(0, 10):
    for record_idx in range(0, 50):
      expected_records.append(
          b"Record %d of file %d" % (record_idx, file_idx))

  snapshot_dir = self.makeSnapshotDirectory()

  # First pass: run the source dataset through snapshot() so that the
  # snapshot directory gets populated.
  writer = core_readers._TFRecordDataset(source_files)
  writer = writer.apply(snapshot.snapshot(snapshot_dir, shard_size_bytes=10))
  self.assertDatasetProduces(writer, expected_records)

  # Delete the original files so that the reads below can only be served
  # from the snapshot.
  self.removeTFRecords()

  def seeded_reader():
    """Returns a get-next callable over a seeded, shuffled snapshot read."""
    reader = core_readers._TFRecordDataset(source_files)
    reader = reader.apply(
        snapshot.snapshot(
            snapshot_dir, shuffle_on_read=True, shuffle_seed=123456))
    return self.getNext(reader)

  next_first = seeded_reader()
  next_second = seeded_reader()

  # 10 files * 50 records = 500 elements; both readers must yield the same
  # element at every step.
  for _ in range(500):
    from_first = self.evaluate(next_first())
    from_second = self.evaluate(next_second())
    self.assertEqual(from_first, from_second)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testReadSnapshotParallelAfterWrite(self): def testReadSnapshotParallelAfterWrite(self):
self.setUpTFRecord(10, 4000) self.setUpTFRecord(10, 4000)

View File

@ -46,7 +46,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
num_writer_threads=None, num_writer_threads=None,
writer_buffer_size=None, writer_buffer_size=None,
shuffle_on_read=None, shuffle_on_read=None,
seed=None, shuffle_seed=None,
mode=None, mode=None,
snapshot_name=None): snapshot_name=None):
@ -73,7 +73,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
self._mode = (mode if mode is not None else "auto") self._mode = (mode if mode is not None else "auto")
self._snapshot_name = (snapshot_name if snapshot_name is not None else "") self._snapshot_name = (snapshot_name if snapshot_name is not None else "")
self._seed, self._seed2 = random_seed.get_seed(seed) self._seed, self._seed2 = random_seed.get_seed(shuffle_seed)
self._input_dataset = input_dataset self._input_dataset = input_dataset
self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path") self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
@ -129,7 +129,7 @@ def snapshot(path,
num_writer_threads=None, num_writer_threads=None,
writer_buffer_size=None, writer_buffer_size=None,
shuffle_on_read=None, shuffle_on_read=None,
seed=None, shuffle_seed=None,
mode=None, mode=None,
snapshot_name=None): snapshot_name=None):
"""Writes to/reads from a snapshot of a dataset. """Writes to/reads from a snapshot of a dataset.
@ -170,8 +170,10 @@ def snapshot(path,
buffer before writing them out using `num_writer_threads`. buffer before writing them out using `num_writer_threads`.
shuffle_on_read: If this is True, then the order in which examples are shuffle_on_read: If this is True, then the order in which examples are
produced when reading from a snapshot will be random. Defaults to False. produced when reading from a snapshot will be random. Defaults to False.
seed: If seed is set, the random number generator is seeded by the given shuffle_seed: Optional. If shuffle_seed is set, the random number generator
seed. Otherwise, it is seeded by a random seed. used for shuffling (when shuffle_on_read is turned on) is seeded by the
given seed. Otherwise, it is seeded by a random seed that differs for
every run.
mode: The mode at which snapshot should operate. Valid options are "auto", mode: The mode at which snapshot should operate. Valid options are "auto",
"read", "write", and "passthrough". The default mode is "auto", where the "read", "write", and "passthrough". The default mode is "auto", where the
snapshot op will automatically determine what mode to operate in. snapshot op will automatically determine what mode to operate in.
@ -198,7 +200,7 @@ def snapshot(path,
num_writer_threads=num_writer_threads, num_writer_threads=num_writer_threads,
writer_buffer_size=writer_buffer_size, writer_buffer_size=writer_buffer_size,
shuffle_on_read=shuffle_on_read, shuffle_on_read=shuffle_on_read,
seed=seed, shuffle_seed=shuffle_seed,
mode=mode, mode=mode,
snapshot_name=snapshot_name) snapshot_name=snapshot_name)