From 838312242534f725d45ecfbc9f01e36598440485 Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy <gunan@google.com>
Date: Thu, 27 Oct 2016 15:32:05 -0700
Subject: [PATCH] Fix graph_io_test flakiness.

---
 .../python/learn/learn_io/graph_io_test.py    | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
index a0c143e9bb5..f9f42bbfad2 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
@@ -237,7 +237,7 @@ class GraphIOTest(tf.test.TestCase):
       session.run(tf.initialize_local_variables())
 
       coord = tf.train.Coordinator()
-      tf.train.start_queue_runners(session, coord=coord)
+      threads = tf.train.start_queue_runners(session, coord=coord)
 
       self.assertAllEqual(session.run(inputs), [b"ABC"])
       self.assertAllEqual(session.run(inputs), [b"DEF"])
@@ -246,6 +246,7 @@ class GraphIOTest(tf.test.TestCase):
         session.run(inputs)
 
       coord.request_stop()
+      coord.join(threads)
 
   def test_read_keyed_batch_features_mutual_exclusive_args(self):
     filename = self._create_temp_file("abcde")
@@ -297,6 +298,7 @@ class GraphIOTest(tf.test.TestCase):
         coord.request_stop()
 
       coord.join(threads)
+
     parsed_records = [item for sublist in [d["sequence"] for d in data]
                       for item in sublist]
     # Check that the number of records matches expected and all records
@@ -320,7 +322,7 @@ class GraphIOTest(tf.test.TestCase):
       session.run(tf.initialize_local_variables())
 
       coord = tf.train.Coordinator()
-      tf.train.start_queue_runners(session, coord=coord)
+      threads = tf.train.start_queue_runners(session, coord=coord)
 
       self.assertEqual("%s:1" % name, inputs.name)
       file_name_queue_name = "%s/file_name_queue" % name
@@ -341,6 +343,7 @@ class GraphIOTest(tf.test.TestCase):
         session.run(inputs)
 
       coord.request_stop()
+      coord.join(threads)
 
   def test_read_text_lines_multifile_with_shared_queue(self):
     gfile.Glob = self._orig_glob
@@ -362,7 +365,7 @@ class GraphIOTest(tf.test.TestCase):
       session.run(tf.initialize_local_variables())
 
       coord = tf.train.Coordinator()
-      tf.train.start_queue_runners(session, coord=coord)
+      threads = tf.train.start_queue_runners(session, coord=coord)
 
       self.assertEqual("%s:1" % name, inputs.name)
       shared_file_name_queue_name = "%s/file_name_queue" % name
@@ -385,6 +388,7 @@ class GraphIOTest(tf.test.TestCase):
         session.run(inputs)
 
       coord.request_stop()
+      coord.join(threads)
 
   def _get_qr(self, name):
     for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
@@ -472,7 +476,7 @@ class GraphIOTest(tf.test.TestCase):
       session.run(tf.initialize_local_variables())
 
       coord = tf.train.Coordinator()
-      tf.train.start_queue_runners(session, coord=coord)
+      threads = tf.train.start_queue_runners(session, coord=coord)
 
       self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
       self.assertAllEqual(session.run(inputs), [b"D", b"E"])
@@ -480,6 +484,7 @@ class GraphIOTest(tf.test.TestCase):
         session.run(inputs)
 
       coord.request_stop()
+      coord.join(threads)
 
   def test_keyed_read_text_lines(self):
     gfile.Glob = self._orig_glob
@@ -497,7 +502,7 @@ class GraphIOTest(tf.test.TestCase):
       session.run(tf.initialize_local_variables())
 
       coord = tf.train.Coordinator()
-      tf.train.start_queue_runners(session, coord=coord)
+      threads = tf.train.start_queue_runners(session, coord=coord)
 
       self.assertAllEqual(session.run([keys, inputs]),
                           [[filename.encode("utf-8") + b":1"], [b"ABC"]])
@@ -509,6 +514,7 @@ class GraphIOTest(tf.test.TestCase):
         session.run(inputs)
 
       coord.request_stop()
+      coord.join(threads)
 
   def test_keyed_parse_json(self):
     gfile.Glob = self._orig_glob
@@ -534,7 +540,7 @@ class GraphIOTest(tf.test.TestCase):
       session.run(tf.initialize_local_variables())
 
       coord = tf.train.Coordinator()
-      tf.train.start_queue_runners(session, coord=coord)
+      threads = tf.train.start_queue_runners(session, coord=coord)
 
       key, age = session.run([keys, inputs["age"]])
       self.assertAllEqual(age, [[0]])
@@ -549,6 +555,7 @@ class GraphIOTest(tf.test.TestCase):
         session.run(inputs)
 
       coord.request_stop()
+      coord.join(threads)
 
 
 if __name__ == "__main__":