From 7bbd0940e8d52fcae03021a4b66e3d6b3e495f4c Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Sat, 11 May 2019 09:24:21 -0700 Subject: [PATCH] Marking end_of_sequence as false explicitly in SnapshotReaderIterator. This fixes issues when we have a repeat afterwards and the end_of_sequence bit propagates from the SnapshotWriterIterator reaching the end of sequence. PiperOrigin-RevId: 247761255 --- .../kernels/data/experimental/snapshot_dataset_op.cc | 1 + .../data/experimental/kernel_tests/snapshot_test.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index f96f75f57af..1ff5878bb65 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -302,6 +302,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return s; } + *end_of_sequence = false; experimental::SnapshotRecord record; record.ParseFromString(record_bytes); diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index 50090f2971e..d21a4814017 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -114,6 +114,16 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase): self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + def testWriteSnapshotRepeatAfterwards(self): + tmpdir = self.makeSnapshotDirectory() + + dataset = dataset_ops.Dataset.range(10) + dataset = dataset.apply(snapshot.snapshot(tmpdir)) + dataset = dataset.repeat(10) + self.assertDatasetProduces(dataset, list(range(10)) * 10) + + self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + def testWriteSnapshotMultiFileSuccessful(self): tmpdir = self.makeSnapshotDirectory()