# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Reader ops from io_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import gzip
import os
import shutil
import sys
import threading
import zlib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.util import compat
prefix_path = "tensorflow/core/lib"
# pylint: disable=invalid-name
TFRecordCompressionType = tf_record.TFRecordCompressionType
# pylint: enable=invalid-name
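
# The compressed-file tests below pair a writer and a reader on the same
# tf_record.TFRecordOptions value built from this alias. A minimal sketch of
# that pairing (illustrative only; the path is hypothetical, and the real
# usage appears in TFRecordReaderTest.testReadZlibFiles further down):
#
#   options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
#   with tf_record.TFRecordWriter("/tmp/example.tfrecord", options=options) as w:
#     w.write(b"payload")
#   reader = io_ops.TFRecordReader(options=options)
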
# Edgar Allan Poe's 'Eldorado'
_TEXT = b"""Gaily bedight,
A gallant knight,
In sunshine and in shadow,
Had journeyed long,
Singing a song,
In search of Eldorado.
But he grew old
This knight so bold
And o'er his heart a shadow
Fell as he found
No spot of ground
That looked like Eldorado.
And, as his strength
Failed him at length,
He met a pilgrim shadow
'Shadow,' said he,
'Where can it be
This land of Eldorado?'
'Over the Mountains
Of the Moon'
Down the Valley of the Shadow,
Ride, boldly ride,'
The shade replied,
'If you seek for Eldorado!'
"""


class TFCompressionTestCase(test.TestCase):
  """Base class with helpers for writing and (de)compressing record files."""

  def setUp(self):
    super(TFCompressionTestCase, self).setUp()
    self._num_files = 2
    self._num_records = 7

  def _Record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _CreateFiles(self, options=None, prefix=""):
    filenames = []
    for i in range(self._num_files):
      name = prefix + "tfrecord.%d.txt" % i
      records = [self._Record(i, j) for j in range(self._num_records)]
      fn = self._WriteRecordsToFile(records, name, options)
      filenames.append(fn)
    return filenames

  def _WriteRecordsToFile(self, records, name="tfrecord", options=None):
    fn = os.path.join(self.get_temp_dir(), name)
    with tf_record.TFRecordWriter(fn, options=options) as writer:
      for r in records:
        writer.write(r)
    return fn

  def _ZlibCompressFile(self, infile, name="tfrecord.z"):
    # zlib-compress the file and write the compressed contents to a new file.
    with open(infile, "rb") as f:
      cdata = zlib.compress(f.read())
    zfn = os.path.join(self.get_temp_dir(), name)
    with open(zfn, "wb") as f:
      f.write(cdata)
    return zfn

  def _GzipCompressFile(self, infile, name="tfrecord.gz"):
    # gzip-compress the file and write the compressed contents to a new file.
    with open(infile, "rb") as f:
      cdata = f.read()
    gzfn = os.path.join(self.get_temp_dir(), name)
    with gzip.GzipFile(gzfn, "wb") as f:
      f.write(cdata)
    return gzfn

  def _ZlibDecompressFile(self, infile, name="tfrecord"):
    with open(infile, "rb") as f:
      cdata = zlib.decompress(f.read())
    fn = os.path.join(self.get_temp_dir(), name)
    with open(fn, "wb") as f:
      f.write(cdata)
    return fn

  def _GzipDecompressFile(self, infile, name="tfrecord"):
    with gzip.GzipFile(infile, "rb") as f:
      cdata = f.read()
    fn = os.path.join(self.get_temp_dir(), name)
    with open(fn, "wb") as f:
      f.write(cdata)
    return fn
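

# Every test below follows the same queue-driven pattern: a FIFOQueue holds
# work items (raw strings or filenames), a Reader consumes one item per
# `read`, and the test drains records until the closed queue runs out. A
# minimal sketch of the pattern (illustrative values, TF1-style API):
#
#   reader = io_ops.IdentityReader("sketch_reader")
#   queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
#   key, value = reader.read(queue)
#   # After queue.enqueue_many([["A", "B"]]), evaluating (key, value) yields
#   # (b"A", b"A") then (b"B", b"B"); a closed, empty queue raises OutOfRange.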


class IdentityReaderTest(test.TestCase):

  def _ExpectRead(self, key, value, expected):
    k, v = self.evaluate([key, value])
    self.assertAllEqual(expected, k)
    self.assertAllEqual(expected, v)

  @test_util.run_deprecated_v1
  def testOneEpoch(self):
    reader = io_ops.IdentityReader("test_reader")
    work_completed = reader.num_work_units_completed()
    produced = reader.num_records_produced()
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    queued_length = queue.size()
    key, value = reader.read(queue)

    self.assertAllEqual(0, self.evaluate(work_completed))
    self.assertAllEqual(0, self.evaluate(produced))
    self.assertAllEqual(0, self.evaluate(queued_length))

    self.evaluate(queue.enqueue_many([["A", "B", "C"]]))
    self.evaluate(queue.close())
    self.assertAllEqual(3, self.evaluate(queued_length))

    self._ExpectRead(key, value, b"A")
    self.assertAllEqual(1, self.evaluate(produced))

    self._ExpectRead(key, value, b"B")

    self._ExpectRead(key, value, b"C")
    self.assertAllEqual(3, self.evaluate(produced))
    self.assertAllEqual(0, self.evaluate(queued_length))

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      self.evaluate([key, value])

    self.assertAllEqual(3, self.evaluate(work_completed))
    self.assertAllEqual(3, self.evaluate(produced))
    self.assertAllEqual(0, self.evaluate(queued_length))

  @test_util.run_deprecated_v1
  def testMultipleEpochs(self):
    reader = io_ops.IdentityReader("test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    enqueue = queue.enqueue_many([["DD", "EE"]])
    key, value = reader.read(queue)

    self.evaluate(enqueue)
    self._ExpectRead(key, value, b"DD")
    self._ExpectRead(key, value, b"EE")
    self.evaluate(enqueue)
    self._ExpectRead(key, value, b"DD")
    self._ExpectRead(key, value, b"EE")
    self.evaluate(enqueue)
    self._ExpectRead(key, value, b"DD")
    self._ExpectRead(key, value, b"EE")
    self.evaluate(queue.close())
    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      self.evaluate([key, value])

  @test_util.run_deprecated_v1
  def testSerializeRestore(self):
    reader = io_ops.IdentityReader("test_reader")
    produced = reader.num_records_produced()
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    self.evaluate(queue.enqueue_many([["X", "Y", "Z"]]))
    key, value = reader.read(queue)

    self._ExpectRead(key, value, b"X")
    self.assertAllEqual(1, self.evaluate(produced))
    state = self.evaluate(reader.serialize_state())

    self._ExpectRead(key, value, b"Y")
    self._ExpectRead(key, value, b"Z")
    self.assertAllEqual(3, self.evaluate(produced))

    self.evaluate(queue.enqueue_many([["Y", "Z"]]))
    self.evaluate(queue.close())

    self.evaluate(reader.restore_state(state))
    self.assertAllEqual(1, self.evaluate(produced))
    self._ExpectRead(key, value, b"Y")
    self._ExpectRead(key, value, b"Z")
    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      self.evaluate([key, value])
    self.assertAllEqual(3, self.evaluate(produced))

    self.assertEqual(bytes, type(state))

    with self.assertRaises(ValueError):
      reader.restore_state([])

    with self.assertRaises(ValueError):
      reader.restore_state([state, state])

    with self.assertRaisesOpError(
        "Could not parse state for IdentityReader 'test_reader'"):
      self.evaluate(reader.restore_state(state[1:]))

    with self.assertRaisesOpError(
        "Could not parse state for IdentityReader 'test_reader'"):
      self.evaluate(reader.restore_state(state[:-1]))

    with self.assertRaisesOpError(
        "Could not parse state for IdentityReader 'test_reader'"):
      self.evaluate(reader.restore_state(state + b"ExtraJunk"))

    with self.assertRaisesOpError(
        "Could not parse state for IdentityReader 'test_reader'"):
      self.evaluate(reader.restore_state(b"PREFIX" + state))

    with self.assertRaisesOpError(
        "Could not parse state for IdentityReader 'test_reader'"):
      self.evaluate(reader.restore_state(b"BOGUS" + state[5:]))

  @test_util.run_deprecated_v1
  def testReset(self):
    reader = io_ops.IdentityReader("test_reader")
    work_completed = reader.num_work_units_completed()
    produced = reader.num_records_produced()
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    queued_length = queue.size()
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([["X", "Y", "Z"]]))
    self._ExpectRead(key, value, b"X")
    self.assertLess(0, self.evaluate(queued_length))
    self.assertAllEqual(1, self.evaluate(produced))

    self._ExpectRead(key, value, b"Y")
    self.assertLess(0, self.evaluate(work_completed))
    self.assertAllEqual(2, self.evaluate(produced))

    self.evaluate(reader.reset())
    self.assertAllEqual(0, self.evaluate(work_completed))
    self.assertAllEqual(0, self.evaluate(produced))
    self.assertAllEqual(1, self.evaluate(queued_length))
    self._ExpectRead(key, value, b"Z")

    self.evaluate(queue.enqueue_many([["K", "L"]]))
    self._ExpectRead(key, value, b"K")
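

# Reader state is an opaque byte string, so serialize_state/restore_state act
# as a rewind facility. A sketch of the round trip exercised in
# testSerializeRestore above (illustrative, TF1-style API):
#
#   state = self.evaluate(reader.serialize_state())  # after reading b"X"
#   ...                                              # read b"Y", then b"Z"
#   self.evaluate(reader.restore_state(state))       # rewind: the next reads
#                                                    # yield b"Y", b"Z" again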


class WholeFileReaderTest(test.TestCase):

  def setUp(self):
    super(WholeFileReaderTest, self).setUp()
    self._filenames = [
        os.path.join(self.get_temp_dir(), "whole_file.%d.txt" % i)
        for i in range(3)
    ]
    self._content = [b"One\na\nb\n", b"Two\nC\nD", b"Three x, y, z"]
    for fn, c in zip(self._filenames, self._content):
      with open(fn, "wb") as h:
        h.write(c)

  def tearDown(self):
    for fn in self._filenames:
      os.remove(fn)
    super(WholeFileReaderTest, self).tearDown()

  def _ExpectRead(self, key, value, index):
    k, v = self.evaluate([key, value])
    self.assertAllEqual(compat.as_bytes(self._filenames[index]), k)
    self.assertAllEqual(self._content[index], v)

  @test_util.run_deprecated_v1
  def testOneEpoch(self):
    reader = io_ops.WholeFileReader("test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    self.evaluate(queue.enqueue_many([self._filenames]))
    self.evaluate(queue.close())
    key, value = reader.read(queue)

    self._ExpectRead(key, value, 0)
    self._ExpectRead(key, value, 1)
    self._ExpectRead(key, value, 2)

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      self.evaluate([key, value])

  @test_util.run_deprecated_v1
  def testInfiniteEpochs(self):
    reader = io_ops.WholeFileReader("test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    enqueue = queue.enqueue_many([self._filenames])
    key, value = reader.read(queue)

    self.evaluate(enqueue)
    self._ExpectRead(key, value, 0)
    self._ExpectRead(key, value, 1)
    self.evaluate(enqueue)
    self._ExpectRead(key, value, 2)
    self._ExpectRead(key, value, 0)
    self._ExpectRead(key, value, 1)
    self.evaluate(enqueue)
    self._ExpectRead(key, value, 2)
    self._ExpectRead(key, value, 0)


class TextLineReaderTest(test.TestCase):

  def setUp(self):
    super(TextLineReaderTest, self).setUp()
    self._num_files = 2
    self._num_lines = 5

  def _LineText(self, f, l):
    return compat.as_bytes("%d: %d" % (f, l))

  def _CreateFiles(self, crlf=False):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_lines):
          f.write(self._LineText(i, j))
          # Always include a newline after the record unless it is at the end
          # of the file, in which case we include it only for file 0, so both
          # terminated and unterminated final lines are exercised.
          if j + 1 != self._num_lines or i == 0:
            f.write(b"\r\n" if crlf else b"\n")
    return filenames

  def _testOneEpoch(self, files):
    reader = io_ops.TextLineReader(name="test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    for i in range(self._num_files):
      for j in range(self._num_lines):
        k, v = self.evaluate([key, value])
        self.assertAllEqual("%s:%d" % (files[i], j + 1), compat.as_text(k))
        self.assertAllEqual(self._LineText(i, j), v)

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = self.evaluate([key, value])

  @test_util.run_deprecated_v1
  def testOneEpochLF(self):
    self._testOneEpoch(self._CreateFiles(crlf=False))

  @test_util.run_deprecated_v1
  def testOneEpochCRLF(self):
    self._testOneEpoch(self._CreateFiles(crlf=True))

  @test_util.run_deprecated_v1
  def testSkipHeaderLines(self):
    files = self._CreateFiles()
    reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    for i in range(self._num_files):
      for j in range(self._num_lines - 1):
        k, v = self.evaluate([key, value])
        self.assertAllEqual("%s:%d" % (files[i], j + 2), compat.as_text(k))
        self.assertAllEqual(self._LineText(i, j + 1), v)

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = self.evaluate([key, value])


class FixedLengthRecordReaderTest(TFCompressionTestCase):

  def setUp(self):
    super(FixedLengthRecordReaderTest, self).setUp()
    self._num_files = 2
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2
    self._hop_bytes = 2

  def _Record(self, f, r):
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _OverlappedRecord(self, f, r):
    record_str = "".join([
        str(i)[0]
        for i in range(r * self._hop_bytes,
                       r * self._hop_bytes + self._record_bytes)
    ])
    return compat.as_bytes(record_str)
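
  # A sketch of the on-disk layouts the helpers below produce (illustrative,
  # using the setUp values header=5, record=3, footer=2, hop=2):
  #
  #   gapped:     HHHHH 000 G 111 G 222 ... FF
  #               header, then records separated by gap_bytes of filler,
  #               where gap_bytes = hop_bytes - record_bytes.
  #
  #   overlapped: HHHHH 01234... FF
  #               records start every hop_bytes, so with hop_bytes smaller
  #               than record_bytes, consecutive records share bytes.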

  # gap_bytes = hop_bytes - record_bytes
  def _CreateFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def _CreateOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
    return filenames

  # gap_bytes = hop_bytes - record_bytes
  def _CreateGzipFiles(self, num_records, gap_bytes):
    filenames = self._CreateFiles(num_records, gap_bytes)
    for fn in filenames:
      # Compress in place.
      self._GzipCompressFile(fn, fn)
    return filenames

  # gap_bytes = hop_bytes - record_bytes
  def _CreateZlibFiles(self, num_records, gap_bytes):
    filenames = self._CreateFiles(num_records, gap_bytes)
    for fn in filenames:
      # Compress in place.
      self._ZlibCompressFile(fn, fn)
    return filenames

  def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with gzip.GzipFile(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def _CreateZlibOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn + ".tmp", "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
      self._ZlibCompressFile(fn + ".tmp", fn)
    return filenames

  # gap_bytes = hop_bytes - record_bytes
  def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
    hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
    reader = io_ops.FixedLengthRecordReader(
        header_bytes=self._header_bytes,
        record_bytes=self._record_bytes,
        footer_bytes=self._footer_bytes,
        hop_bytes=hop_bytes,
        encoding=encoding,
        name="test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    for i in range(self._num_files):
      for j in range(num_records):
        k, v = self.evaluate([key, value])
        self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
        self.assertAllEqual(self._Record(i, j), v)

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = self.evaluate([key, value])

  def _TestOneEpochWithHopBytes(self,
                                files,
                                num_overlapped_records,
                                encoding=None):
    reader = io_ops.FixedLengthRecordReader(
        header_bytes=self._header_bytes,
        record_bytes=self._record_bytes,
        footer_bytes=self._footer_bytes,
        hop_bytes=self._hop_bytes,
        encoding=encoding,
        name="test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    for i in range(self._num_files):
      for j in range(num_overlapped_records):
        k, v = self.evaluate([key, value])
        self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
        self.assertAllEqual(self._OverlappedRecord(i, j), v)

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = self.evaluate([key, value])

  @test_util.run_deprecated_v1
  def testOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes)

  @test_util.run_deprecated_v1
  def testGzipOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateGzipFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes, encoding="GZIP")

  @test_util.run_deprecated_v1
  def testZlibOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateZlibFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes, encoding="ZLIB")

  @test_util.run_deprecated_v1
  def testOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(files, num_overlapped_records)

  @test_util.run_deprecated_v1
  def testGzipOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(
          files, num_overlapped_records, encoding="GZIP")

  @test_util.run_deprecated_v1
  def testZlibOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateZlibOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(
          files, num_overlapped_records, encoding="ZLIB")


class TFRecordReaderTest(TFCompressionTestCase):

  def setUp(self):
    super(TFRecordReaderTest, self).setUp()

  @test_util.run_deprecated_v1
  def testOneEpoch(self):
    files = self._CreateFiles()
    reader = io_ops.TFRecordReader(name="test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    for i in range(self._num_files):
      for j in range(self._num_records):
        k, v = self.evaluate([key, value])
        self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
        self.assertAllEqual(self._Record(i, j), v)

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = self.evaluate([key, value])

  @test_util.run_deprecated_v1
  def testReadUpTo(self):
    files = self._CreateFiles()
    reader = io_ops.TFRecordReader(name="test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    batch_size = 3
    key, value = reader.read_up_to(queue, batch_size)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    num_k = 0
    num_v = 0

    while True:
      try:
        k, v = self.evaluate([key, value])
        # Test reading *up to* batch_size records.
        self.assertLessEqual(len(k), batch_size)
        self.assertLessEqual(len(v), batch_size)
        num_k += len(k)
        num_v += len(v)
      except errors_impl.OutOfRangeError:
        break

    # Test that we have read everything.
    self.assertEqual(self._num_files * self._num_records, num_k)
    self.assertEqual(self._num_files * self._num_records, num_v)
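
  # read_up_to returns rank-1 key/value tensors holding at most the requested
  # number of records per call; short batches (for example at the end of a
  # file) are expected, hence the <= checks and total-count assertions above.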

  @test_util.run_deprecated_v1
  def testReadZlibFiles(self):
    options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
    files = self._CreateFiles(options)
    reader = io_ops.TFRecordReader(name="test_reader", options=options)
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    for i in range(self._num_files):
      for j in range(self._num_records):
        k, v = self.evaluate([key, value])
        self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
        self.assertAllEqual(self._Record(i, j), v)

  @test_util.run_deprecated_v1
  def testReadGzipFiles(self):
    options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
    files = self._CreateFiles(options)
    reader = io_ops.TFRecordReader(name="test_reader", options=options)
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    for i in range(self._num_files):
      for j in range(self._num_records):
        k, v = self.evaluate([key, value])
        self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
        self.assertAllEqual(self._Record(i, j), v)


class AsyncReaderTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testNoDeadlockFromQueue(self):
    """Tests that reading does not block main execution threads."""
    config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, intra_op_parallelism_threads=1)
    with self.session(config=config) as sess:
      thread_data_t = collections.namedtuple("thread_data_t",
                                             ["thread", "queue", "output"])
      thread_data = []

      # Create different readers, each with its own queue.
      for i in range(3):
        queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
        reader = io_ops.TextLineReader()
        _, line = reader.read(queue)
        output = []
        t = threading.Thread(
            target=AsyncReaderTest._RunSessionAndSave,
            args=(sess, [line], output))
        thread_data.append(thread_data_t(t, queue, output))

      # Start all readers. They are all blocked waiting for queue entries.
      self.evaluate(variables.global_variables_initializer())
      for d in thread_data:
        d.thread.start()

      # Unblock the readers.
      for i, d in enumerate(reversed(thread_data)):
        fname = os.path.join(self.get_temp_dir(), "deadlock.%s.txt" % i)
        with open(fname, "wb") as f:
          f.write(("file-%s" % i).encode())
        self.evaluate(d.queue.enqueue_many([[fname]]))
        d.thread.join()
        self.assertEqual([[("file-%s" % i).encode()]], d.output)

  @staticmethod
  def _RunSessionAndSave(sess, args, output):
    output.append(sess.run(args))


class LMDBReaderTest(test.TestCase):

  def setUp(self):
    super(LMDBReaderTest, self).setUp()
    # Copy database out because we need the path to be writable to use locks.
    # The on-disk format of an LMDB file is different on big-endian machines,
    # because LMDB is a memory-mapped database.
    db_file = "data.mdb" if sys.byteorder == "little" else "data_bigendian.mdb"
    path = os.path.join(prefix_path, "lmdb", "testdata", db_file)
    self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
    shutil.copy(path, self.db_path)
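
  # The checked-in test database maps keys b"0"..b"9" to values b"a"..b"j";
  # the expectations in the tests below are derived from that fixture.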

  @test_util.run_deprecated_v1
  def testReadFromFile(self):
    reader = io_ops.LMDBReader(name="test_read_from_file")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue([self.db_path]))
    self.evaluate(queue.close())
    for i in range(10):
      k, v = self.evaluate([key, value])
      self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
      self.assertAllEqual(
          compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = self.evaluate([key, value])

  @test_util.run_deprecated_v1
  def testReadFromSameFile(self):
    with self.cached_session() as sess:
      reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
      reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key1, value1 = reader1.read(filename_queue)
      key2, value2 = reader2.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      for _ in range(3):
        for _ in range(10):
          k1, v1, k2, v2 = self.evaluate([key1, value1, key2, value2])
          self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
          self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))
      coord.request_stop()
      coord.join(threads)

  @test_util.run_deprecated_v1
  def testReadFromFolder(self):
    # Despite the name, this reads the copied database file directly, just
    # like testReadFromFile above.
    reader = io_ops.LMDBReader(name="test_read_from_folder")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue([self.db_path]))
    self.evaluate(queue.close())
    for i in range(10):
      k, v = self.evaluate([key, value])
      self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
      self.assertAllEqual(
          compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = self.evaluate([key, value])

  @test_util.run_deprecated_v1
  def testReadFromFileRepeatedly(self):
    with self.cached_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key, value = reader.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      # Iterate over the LMDB database 3 times.
      for _ in range(3):
        # Go over all 10 records each time.
        for j in range(10):
          k, v = self.evaluate([key, value])
          self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
          self.assertAllEqual(
              compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))
      coord.request_stop()
      coord.join(threads)


if __name__ == "__main__":
  test.main()