diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index f96f75f57af..1ff5878bb65 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -302,6 +302,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return s; } + *end_of_sequence = false; experimental::SnapshotRecord record; record.ParseFromString(record_bytes); diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index 50090f2971e..d21a4814017 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -114,6 +114,16 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase): self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + def testWriteSnapshotRepeatAfterwards(self): + tmpdir = self.makeSnapshotDirectory() + + dataset = dataset_ops.Dataset.range(10) + dataset = dataset.apply(snapshot.snapshot(tmpdir)) + dataset = dataset.repeat(10) + self.assertDatasetProduces(dataset, list(range(10)) * 10) + + self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + def testWriteSnapshotMultiFileSuccessful(self): tmpdir = self.makeSnapshotDirectory()