Rename `seed` to `shuffle_seed` to better indicate use, and add a test for it.
PiperOrigin-RevId: 287907717
Change-Id: I4f1649ab669058d036a3a076a87024803640eb32
This commit is contained in:
parent
043abbdf86
commit
4b7f4c1f09
|
@ -321,6 +321,41 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
|||
dataset3 = dataset3.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True))
|
||||
self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
def testReadShuffledSnapshotWithSeedAfterWrite(self):
  """Checks that two equally seeded shuffled snapshot reads agree in order.

  Writes a snapshot of the TFRecord files, removes the original files, and
  then reads the snapshot twice with `shuffle_on_read=True` and the same
  `shuffle_seed`. Every pair of elements drawn from the two readers must be
  equal.
  """
  num_files, records_per_file = 10, 50
  self.setUpTFRecord(num_files=num_files, num_records=records_per_file)
  source_files = self.test_filenames

  expected_records = []
  for file_index in range(0, num_files):
    for record_index in range(0, records_per_file):
      expected_records.append(
          b"Record %d of file %d" % (record_index, file_index))

  snapshot_dir = self.makeSnapshotDirectory()
  written = core_readers._TFRecordDataset(source_files)
  written = written.apply(snapshot.snapshot(snapshot_dir, shard_size_bytes=10))
  self.assertDatasetProduces(written, expected_records)

  # Drop the original TFRecord files so the reads below can only be served
  # from the snapshot.
  self.removeTFRecords()

  def seeded_shuffled_reader():
    """Returns a get-next callable over a shuffled read with a fixed seed."""
    reader = core_readers._TFRecordDataset(source_files)
    reader = reader.apply(
        snapshot.snapshot(
            snapshot_dir, shuffle_on_read=True, shuffle_seed=123456))
    return self.getNext(reader)

  first_next = seeded_shuffled_reader()
  second_next = seeded_shuffled_reader()

  # Both readers use the same seed, so they must yield identical elements at
  # every position.
  for _ in range(num_files * records_per_file):
    first_item = self.evaluate(first_next())
    second_item = self.evaluate(second_next())
    self.assertEqual(first_item, second_item)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadSnapshotParallelAfterWrite(self):
|
||||
self.setUpTFRecord(10, 4000)
|
||||
|
|
|
@ -46,7 +46,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
|||
num_writer_threads=None,
|
||||
writer_buffer_size=None,
|
||||
shuffle_on_read=None,
|
||||
seed=None,
|
||||
shuffle_seed=None,
|
||||
mode=None,
|
||||
snapshot_name=None):
|
||||
|
||||
|
@ -73,7 +73,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
|||
self._mode = (mode if mode is not None else "auto")
|
||||
self._snapshot_name = (snapshot_name if snapshot_name is not None else "")
|
||||
|
||||
self._seed, self._seed2 = random_seed.get_seed(seed)
|
||||
self._seed, self._seed2 = random_seed.get_seed(shuffle_seed)
|
||||
|
||||
self._input_dataset = input_dataset
|
||||
self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
|
||||
|
@ -129,7 +129,7 @@ def snapshot(path,
|
|||
num_writer_threads=None,
|
||||
writer_buffer_size=None,
|
||||
shuffle_on_read=None,
|
||||
seed=None,
|
||||
shuffle_seed=None,
|
||||
mode=None,
|
||||
snapshot_name=None):
|
||||
"""Writes to/reads from a snapshot of a dataset.
|
||||
|
@ -170,8 +170,10 @@ def snapshot(path,
|
|||
buffer before writing them out using `num_writer_threads`.
|
||||
shuffle_on_read: If this is True, then the order in which examples are
|
||||
produced when reading from a snapshot will be random. Defaults to False.
|
||||
seed: If seed is set, the random number generator is seeded by the given
|
||||
seed. Otherwise, it is seeded by a random seed.
|
||||
shuffle_seed: Optional. If shuffle_seed is set, the random number generator
|
||||
used for shuffling (when shuffle_on_read is turned on) is seeded by the
|
||||
given seed. Otherwise, it is seeded by a random seed that differs for
|
||||
every run.
|
||||
mode: The mode at which snapshot should operate. Valid options are "auto",
|
||||
"read", "write", and "passthrough". The default mode is "auto", where the
|
||||
snapshot op will automatically determine what mode to operate in.
|
||||
|
@ -198,7 +200,7 @@ def snapshot(path,
|
|||
num_writer_threads=num_writer_threads,
|
||||
writer_buffer_size=writer_buffer_size,
|
||||
shuffle_on_read=shuffle_on_read,
|
||||
seed=seed,
|
||||
shuffle_seed=shuffle_seed,
|
||||
mode=mode,
|
||||
snapshot_name=snapshot_name)
|
||||
|
||||
|
|
Loading…
Reference in New Issue