From 838312242534f725d45ecfbc9f01e36598440485 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy <gunan@google.com> Date: Thu, 27 Oct 2016 15:32:05 -0700 Subject: [PATCH] Fix graph_io_test flakiness. --- .../python/learn/learn_io/graph_io_test.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index a0c143e9bb5..f9f42bbfad2 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -237,7 +237,7 @@ class GraphIOTest(tf.test.TestCase): session.run(tf.initialize_local_variables()) coord = tf.train.Coordinator() - tf.train.start_queue_runners(session, coord=coord) + threads = tf.train.start_queue_runners(session, coord=coord) self.assertAllEqual(session.run(inputs), [b"ABC"]) self.assertAllEqual(session.run(inputs), [b"DEF"]) @@ -246,6 +246,7 @@ class GraphIOTest(tf.test.TestCase): session.run(inputs) coord.request_stop() + coord.join(threads) def test_read_keyed_batch_features_mutual_exclusive_args(self): filename = self._create_temp_file("abcde") @@ -297,6 +298,7 @@ class GraphIOTest(tf.test.TestCase): coord.request_stop() coord.join(threads) + parsed_records = [item for sublist in [d["sequence"] for d in data] for item in sublist] # Check that the number of records matches expected and all records @@ -320,7 +322,7 @@ class GraphIOTest(tf.test.TestCase): session.run(tf.initialize_local_variables()) coord = tf.train.Coordinator() - tf.train.start_queue_runners(session, coord=coord) + threads = tf.train.start_queue_runners(session, coord=coord) self.assertEqual("%s:1" % name, inputs.name) file_name_queue_name = "%s/file_name_queue" % name @@ -341,6 +343,7 @@ class GraphIOTest(tf.test.TestCase): session.run(inputs) coord.request_stop() + coord.join(threads) def test_read_text_lines_multifile_with_shared_queue(self): gfile.Glob = self._orig_glob @@ -362,7 +365,7 @@ class GraphIOTest(tf.test.TestCase): session.run(tf.initialize_local_variables()) coord = tf.train.Coordinator() - tf.train.start_queue_runners(session, coord=coord) + threads = tf.train.start_queue_runners(session, coord=coord) self.assertEqual("%s:1" % name, inputs.name) shared_file_name_queue_name = "%s/file_name_queue" % name @@ -385,6 +388,7 @@ class GraphIOTest(tf.test.TestCase): session.run(inputs) coord.request_stop() + coord.join(threads) def _get_qr(self, name): for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS): @@ -472,7 +476,7 @@ class GraphIOTest(tf.test.TestCase): session.run(tf.initialize_local_variables()) coord = tf.train.Coordinator() - tf.train.start_queue_runners(session, coord=coord) + threads = tf.train.start_queue_runners(session, coord=coord) self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"]) self.assertAllEqual(session.run(inputs), [b"D", b"E"]) @@ -480,6 +484,7 @@ class GraphIOTest(tf.test.TestCase): session.run(inputs) coord.request_stop() + coord.join(threads) def test_keyed_read_text_lines(self): gfile.Glob = self._orig_glob @@ -497,7 +502,7 @@ class GraphIOTest(tf.test.TestCase): session.run(tf.initialize_local_variables()) coord = tf.train.Coordinator() - tf.train.start_queue_runners(session, coord=coord) + threads = tf.train.start_queue_runners(session, coord=coord) self.assertAllEqual(session.run([keys, inputs]), [[filename.encode("utf-8") + b":1"], [b"ABC"]]) @@ -509,6 +514,7 @@ class GraphIOTest(tf.test.TestCase): session.run(inputs) coord.request_stop() + coord.join(threads) def test_keyed_parse_json(self): gfile.Glob = self._orig_glob @@ -534,7 +540,7 @@ class GraphIOTest(tf.test.TestCase): session.run(tf.initialize_local_variables()) coord = tf.train.Coordinator() - tf.train.start_queue_runners(session, coord=coord) + threads = tf.train.start_queue_runners(session, coord=coord) key, age = session.run([keys, inputs["age"]]) self.assertAllEqual(age, [[0]]) @@ -549,6 +555,7 @@ class GraphIOTest(tf.test.TestCase): session.run(inputs) coord.request_stop() + coord.join(threads) if __name__ == "__main__":