Rename `seed` to `shuffle_seed` to better indicate use, and add a test for it.
PiperOrigin-RevId: 287907717 Change-Id: I4f1649ab669058d036a3a076a87024803640eb32
This commit is contained in:
parent
043abbdf86
commit
4b7f4c1f09
|
@ -321,6 +321,41 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||||
dataset3 = dataset3.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True))
|
dataset3 = dataset3.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True))
|
||||||
self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)
|
self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
def testReadShuffledSnapshotWithSeedAfterWrite(self):
  """Checks that equal `shuffle_seed` values give the same shuffled order.

  A snapshot of the TFRecord files is produced first, the source files are
  removed, and then two independent readers are built over the snapshot with
  `shuffle_on_read=True` and an identical `shuffle_seed`. The two readers
  must yield equal elements at every step.
  """
  num_files = 10
  records_per_file = 50
  shuffle_seed = 123456

  self.setUpTFRecord(num_files=num_files, num_records=records_per_file)
  filenames = self.test_filenames

  expected = [
      b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
      for f in range(0, num_files)
      for r in range(0, records_per_file)
  ]

  # Run the pipeline once through the snapshot transformation; the reads
  # below are meant to be served from what this pass leaves in the directory.
  snapshot_dir = self.makeSnapshotDirectory()
  source = core_readers._TFRecordDataset(filenames)
  source = source.apply(snapshot.snapshot(snapshot_dir, shard_size_bytes=10))
  self.assertDatasetProduces(source, expected)

  # Drop the original files so the data can only come back from the snapshot.
  self.removeTFRecords()

  def seeded_shuffled_reader():
    """Builds a shuffled snapshot reader and returns its get-next callable."""
    reader = core_readers._TFRecordDataset(filenames)
    reader = reader.apply(
        snapshot.snapshot(
            snapshot_dir, shuffle_on_read=True, shuffle_seed=shuffle_seed))
    return self.getNext(reader)

  first_next = seeded_shuffled_reader()
  second_next = seeded_shuffled_reader()

  # Both readers share the seed, so they must agree element by element for
  # every record that was written.
  for _ in range(len(expected)):
    first_item = self.evaluate(first_next())
    second_item = self.evaluate(second_next())
    self.assertEqual(first_item, second_item)
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testReadSnapshotParallelAfterWrite(self):
|
def testReadSnapshotParallelAfterWrite(self):
|
||||||
self.setUpTFRecord(10, 4000)
|
self.setUpTFRecord(10, 4000)
|
||||||
|
|
|
@ -46,7 +46,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||||
num_writer_threads=None,
|
num_writer_threads=None,
|
||||||
writer_buffer_size=None,
|
writer_buffer_size=None,
|
||||||
shuffle_on_read=None,
|
shuffle_on_read=None,
|
||||||
seed=None,
|
shuffle_seed=None,
|
||||||
mode=None,
|
mode=None,
|
||||||
snapshot_name=None):
|
snapshot_name=None):
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||||
self._mode = (mode if mode is not None else "auto")
|
self._mode = (mode if mode is not None else "auto")
|
||||||
self._snapshot_name = (snapshot_name if snapshot_name is not None else "")
|
self._snapshot_name = (snapshot_name if snapshot_name is not None else "")
|
||||||
|
|
||||||
self._seed, self._seed2 = random_seed.get_seed(seed)
|
self._seed, self._seed2 = random_seed.get_seed(shuffle_seed)
|
||||||
|
|
||||||
self._input_dataset = input_dataset
|
self._input_dataset = input_dataset
|
||||||
self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
|
self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
|
||||||
|
@ -129,7 +129,7 @@ def snapshot(path,
|
||||||
num_writer_threads=None,
|
num_writer_threads=None,
|
||||||
writer_buffer_size=None,
|
writer_buffer_size=None,
|
||||||
shuffle_on_read=None,
|
shuffle_on_read=None,
|
||||||
seed=None,
|
shuffle_seed=None,
|
||||||
mode=None,
|
mode=None,
|
||||||
snapshot_name=None):
|
snapshot_name=None):
|
||||||
"""Writes to/reads from a snapshot of a dataset.
|
"""Writes to/reads from a snapshot of a dataset.
|
||||||
|
@ -170,8 +170,10 @@ def snapshot(path,
|
||||||
buffer before writing them out using `num_writer_threads`.
|
buffer before writing them out using `num_writer_threads`.
|
||||||
shuffle_on_read: If this is True, then the order in which examples are
|
shuffle_on_read: If this is True, then the order in which examples are
|
||||||
produced when reading from a snapshot will be random. Defaults to False.
|
produced when reading from a snapshot will be random. Defaults to False.
|
||||||
seed: If seed is set, the random number generator is seeded by the given
|
shuffle_seed: Optional. If shuffle_seed is set, the random number generator
|
||||||
seed. Otherwise, it is seeded by a random seed.
|
used for shuffling (when shuffle_on_read is turned on) is seeded by the
|
||||||
|
given seed. Otherwise, it is seeded by a random seed that differs for
|
||||||
|
every run.
|
||||||
mode: The mode at which snapshot should operate. Valid options are "auto",
|
mode: The mode at which snapshot should operate. Valid options are "auto",
|
||||||
"read", "write", and "passthrough". The default mode is "auto", where the
|
"read", "write", and "passthrough". The default mode is "auto", where the
|
||||||
snapshot op will automatically determine what mode to operate in.
|
snapshot op will automatically determine what mode to operate in.
|
||||||
|
@ -198,7 +200,7 @@ def snapshot(path,
|
||||||
num_writer_threads=num_writer_threads,
|
num_writer_threads=num_writer_threads,
|
||||||
writer_buffer_size=writer_buffer_size,
|
writer_buffer_size=writer_buffer_size,
|
||||||
shuffle_on_read=shuffle_on_read,
|
shuffle_on_read=shuffle_on_read,
|
||||||
seed=seed,
|
shuffle_seed=shuffle_seed,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
snapshot_name=snapshot_name)
|
snapshot_name=snapshot_name)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue