Marking end_of_sequence as false explicitly in SnapshotReaderIterator. This fixes issues when we have a repeat afterwards and the end_of_sequence bit propagates from the SnapshotWriterIterator reaching the end of sequence.

PiperOrigin-RevId: 247761255
This commit is contained in:
Rohan Jain 2019-05-11 09:24:21 -07:00 committed by TensorFlower Gardener
parent 91725f5881
commit 7bbd0940e8
2 changed files with 11 additions and 0 deletions

View File

@ -302,6 +302,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
return s;
}
*end_of_sequence = false;
experimental::SnapshotRecord record;
record.ParseFromString(record_bytes);

View File

@ -114,6 +114,16 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
def testWriteSnapshotRepeatAfterwards(self):
tmpdir = self.makeSnapshotDirectory()
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(snapshot.snapshot(tmpdir))
dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10)
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
def testWriteSnapshotMultiFileSuccessful(self):
tmpdir = self.makeSnapshotDirectory()