From 4b7f4c1f09a9e0b221a25049af676183a43e1f2b Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 2 Jan 2020 15:45:29 -0800 Subject: [PATCH] Rename `seed` to `shuffle_seed` to better indicate use, and add a test for it. PiperOrigin-RevId: 287907717 Change-Id: I4f1649ab669058d036a3a076a87024803640eb32 --- .../kernel_tests/snapshot_test.py | 35 +++++++++++++++++++ .../python/data/experimental/ops/snapshot.py | 14 ++++---- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index d4868e8701a..55f730b4e2a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -321,6 +321,41 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, dataset3 = dataset3.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True)) self.assertDatasetProduces(dataset3, expected, assert_items_equal=True) + @combinations.generate(test_base.default_test_combinations()) + def testReadShuffledSnapshotWithSeedAfterWrite(self): + self.setUpTFRecord(num_files=10, num_records=50) + filenames = self.test_filenames + + expected = [ + b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension + for f in range(0, 10) + for r in range(0, 50) + ] + + tmpdir = self.makeSnapshotDirectory() + dataset = core_readers._TFRecordDataset(filenames) + dataset = dataset.apply(snapshot.snapshot(tmpdir, shard_size_bytes=10)) + self.assertDatasetProduces(dataset, expected) + + # remove the original files and try to read the data back only from snapshot + self.removeTFRecords() + + dataset2 = core_readers._TFRecordDataset(filenames) + dataset2 = dataset2.apply( + snapshot.snapshot(tmpdir, shuffle_on_read=True, shuffle_seed=123456)) + next2 = self.getNext(dataset2) + + dataset3 = core_readers._TFRecordDataset(filenames) + dataset3 = dataset3.apply( + snapshot.snapshot(tmpdir, shuffle_on_read=True, shuffle_seed=123456)) + next3 = self.getNext(dataset3) + + # make sure that the items are read back in the same order for both datasets + for _ in range(500): + res2 = self.evaluate(next2()) + res3 = self.evaluate(next3()) + self.assertEqual(res2, res3) + @combinations.generate(test_base.default_test_combinations()) def testReadSnapshotParallelAfterWrite(self): self.setUpTFRecord(10, 4000) diff --git a/tensorflow/python/data/experimental/ops/snapshot.py b/tensorflow/python/data/experimental/ops/snapshot.py index eaf4225acfd..9bba2757dd7 100644 --- a/tensorflow/python/data/experimental/ops/snapshot.py +++ b/tensorflow/python/data/experimental/ops/snapshot.py @@ -46,7 +46,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset): num_writer_threads=None, writer_buffer_size=None, shuffle_on_read=None, - seed=None, + shuffle_seed=None, mode=None, snapshot_name=None): @@ -73,7 +73,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset): self._mode = (mode if mode is not None else "auto") self._snapshot_name = (snapshot_name if snapshot_name is not None else "") - self._seed, self._seed2 = random_seed.get_seed(seed) + self._seed, self._seed2 = random_seed.get_seed(shuffle_seed) self._input_dataset = input_dataset self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path") @@ -129,7 +129,7 @@ def snapshot(path, num_writer_threads=None, writer_buffer_size=None, shuffle_on_read=None, - seed=None, + shuffle_seed=None, mode=None, snapshot_name=None): """Writes to/reads from a snapshot of a dataset. @@ -170,8 +170,10 @@ def snapshot(path, buffer before writing them out using `num_writer_threads`. shuffle_on_read: If this is True, then the order in which examples are produced when reading from a snapshot will be random. Defaults to False. - seed: If seed is set, the random number generator is seeded by the given - seed. Otherwise, it is seeded by a random seed. + shuffle_seed: Optional. If shuffle_seed is set, the random number generator + used for shuffling (when shuffle_on_read is turned on) is seeded by the + given seed. Otherwise, it is seeded by a random seed that differs for + every run. mode: The mode at which snapshot should operate. Valid options are "auto", "read", "write", and "passthrough". The default mode is "auto", where the snapshot op will automatically determine what mode to operate in. @@ -198,7 +200,7 @@ def snapshot(path, num_writer_threads=num_writer_threads, writer_buffer_size=writer_buffer_size, shuffle_on_read=shuffle_on_read, - seed=seed, + shuffle_seed=shuffle_seed, mode=mode, snapshot_name=snapshot_name)