Marking end_of_sequence as false explicitly in SnapshotReaderIterator. This fixes issues when we have a repeat afterwards and the end_of_sequence bit propagates from the SnapshotWriterIterator reaching the end of sequence.
PiperOrigin-RevId: 247761255
This commit is contained in:
parent
91725f5881
commit
7bbd0940e8
@ -302,6 +302,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*end_of_sequence = false;
|
||||||
experimental::SnapshotRecord record;
|
experimental::SnapshotRecord record;
|
||||||
record.ParseFromString(record_bytes);
|
record.ParseFromString(record_bytes);
|
||||||
|
|
||||||
|
@ -114,6 +114,16 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
|
|||||||
|
|
||||||
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
||||||
|
|
||||||
|
def testWriteSnapshotRepeatAfterwards(self):
|
||||||
|
tmpdir = self.makeSnapshotDirectory()
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
|
dataset = dataset.apply(snapshot.snapshot(tmpdir))
|
||||||
|
dataset = dataset.repeat(10)
|
||||||
|
self.assertDatasetProduces(dataset, list(range(10)) * 10)
|
||||||
|
|
||||||
|
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
||||||
|
|
||||||
def testWriteSnapshotMultiFileSuccessful(self):
|
def testWriteSnapshotMultiFileSuccessful(self):
|
||||||
tmpdir = self.makeSnapshotDirectory()
|
tmpdir = self.makeSnapshotDirectory()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user