Merge pull request from gunan/master

Fix graph_io_test flakiness.
This commit is contained in:
gunan 2016-10-27 16:11:03 -07:00 committed by GitHub
commit 66f979714c

View File

@ -237,7 +237,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
@ -246,6 +246,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_read_keyed_batch_features_mutual_exclusive_args(self):
filename = self._create_temp_file("abcde")
@ -297,6 +298,7 @@ class GraphIOTest(tf.test.TestCase):
coord.request_stop()
coord.join(threads)
parsed_records = [item for sublist in [d["sequence"] for d in data]
for item in sublist]
# Check that the number of records matches expected and all records
@ -320,7 +322,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertEqual("%s:1" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name
@ -341,6 +343,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_read_text_lines_multifile_with_shared_queue(self):
gfile.Glob = self._orig_glob
@ -362,7 +365,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertEqual("%s:1" % name, inputs.name)
shared_file_name_queue_name = "%s/file_name_queue" % name
@ -385,6 +388,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def _get_qr(self, name):
for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
@ -472,7 +476,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
self.assertAllEqual(session.run(inputs), [b"D", b"E"])
@ -480,6 +484,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_keyed_read_text_lines(self):
gfile.Glob = self._orig_glob
@ -497,7 +502,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run([keys, inputs]),
[[filename.encode("utf-8") + b":1"], [b"ABC"]])
@ -509,6 +514,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_keyed_parse_json(self):
gfile.Glob = self._orig_glob
@ -534,7 +540,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
threads = tf.train.start_queue_runners(session, coord=coord)
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[0]])
@ -549,6 +555,7 @@ class GraphIOTest(tf.test.TestCase):
session.run(inputs)
coord.request_stop()
coord.join(threads)
if __name__ == "__main__":