diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index d4868e8701a..55f730b4e2a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -321,6 +321,41 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, dataset3 = dataset3.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True)) self.assertDatasetProduces(dataset3, expected, assert_items_equal=True) + @combinations.generate(test_base.default_test_combinations()) + def testReadShuffledSnapshotWithSeedAfterWrite(self): + self.setUpTFRecord(num_files=10, num_records=50) + filenames = self.test_filenames + + expected = [ + b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension + for f in range(0, 10) + for r in range(0, 50) + ] + + tmpdir = self.makeSnapshotDirectory() + dataset = core_readers._TFRecordDataset(filenames) + dataset = dataset.apply(snapshot.snapshot(tmpdir, shard_size_bytes=10)) + self.assertDatasetProduces(dataset, expected) + + # remove the original files and try to read the data back only from snapshot + self.removeTFRecords() + + dataset2 = core_readers._TFRecordDataset(filenames) + dataset2 = dataset2.apply( + snapshot.snapshot(tmpdir, shuffle_on_read=True, shuffle_seed=123456)) + next2 = self.getNext(dataset2) + + dataset3 = core_readers._TFRecordDataset(filenames) + dataset3 = dataset3.apply( + snapshot.snapshot(tmpdir, shuffle_on_read=True, shuffle_seed=123456)) + next3 = self.getNext(dataset3) + + # make sure that the items are read back in the same order for both datasets + for _ in range(500): + res2 = self.evaluate(next2()) + res3 = self.evaluate(next3()) + self.assertEqual(res2, res3) + @combinations.generate(test_base.default_test_combinations()) def testReadSnapshotParallelAfterWrite(self): self.setUpTFRecord(10, 4000) diff --git a/tensorflow/python/data/experimental/ops/snapshot.py b/tensorflow/python/data/experimental/ops/snapshot.py index eaf4225acfd..9bba2757dd7 100644 --- a/tensorflow/python/data/experimental/ops/snapshot.py +++ b/tensorflow/python/data/experimental/ops/snapshot.py @@ -46,7 +46,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset): num_writer_threads=None, writer_buffer_size=None, shuffle_on_read=None, - seed=None, + shuffle_seed=None, mode=None, snapshot_name=None): @@ -73,7 +73,7 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset): self._mode = (mode if mode is not None else "auto") self._snapshot_name = (snapshot_name if snapshot_name is not None else "") - self._seed, self._seed2 = random_seed.get_seed(seed) + self._seed, self._seed2 = random_seed.get_seed(shuffle_seed) self._input_dataset = input_dataset self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path") @@ -129,7 +129,7 @@ def snapshot(path, num_writer_threads=None, writer_buffer_size=None, shuffle_on_read=None, - seed=None, + shuffle_seed=None, mode=None, snapshot_name=None): """Writes to/reads from a snapshot of a dataset. @@ -170,8 +170,10 @@ def snapshot(path, buffer before writing them out using `num_writer_threads`. shuffle_on_read: If this is True, then the order in which examples are produced when reading from a snapshot will be random. Defaults to False. - seed: If seed is set, the random number generator is seeded by the given - seed. Otherwise, it is seeded by a random seed. + shuffle_seed: Optional. If shuffle_seed is set, the random number generator + used for shuffling (when shuffle_on_read is turned on) is seeded by the + given seed. Otherwise, it is seeded by a random seed that differs for + every run. mode: The mode at which snapshot should operate. Valid options are "auto", "read", "write", and "passthrough". The default mode is "auto", where the snapshot op will automatically determine what mode to operate in. @@ -198,7 +200,7 @@ def snapshot(path, num_writer_threads=num_writer_threads, writer_buffer_size=writer_buffer_size, shuffle_on_read=shuffle_on_read, - seed=seed, + shuffle_seed=shuffle_seed, mode=mode, snapshot_name=snapshot_name)