Merge pull request #5246 from gunan/master
Fix graph_io_test flakiness.
This commit is contained in:
commit
66f979714c
@ -237,7 +237,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(tf.initialize_local_variables())
|
||||
|
||||
coord = tf.train.Coordinator()
|
||||
tf.train.start_queue_runners(session, coord=coord)
|
||||
threads = tf.train.start_queue_runners(session, coord=coord)
|
||||
|
||||
self.assertAllEqual(session.run(inputs), [b"ABC"])
|
||||
self.assertAllEqual(session.run(inputs), [b"DEF"])
|
||||
@ -246,6 +246,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(inputs)
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def test_read_keyed_batch_features_mutual_exclusive_args(self):
|
||||
filename = self._create_temp_file("abcde")
|
||||
@ -297,6 +298,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
coord.request_stop()
|
||||
|
||||
coord.join(threads)
|
||||
|
||||
parsed_records = [item for sublist in [d["sequence"] for d in data]
|
||||
for item in sublist]
|
||||
# Check that the number of records matches expected and all records
|
||||
@ -320,7 +322,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(tf.initialize_local_variables())
|
||||
|
||||
coord = tf.train.Coordinator()
|
||||
tf.train.start_queue_runners(session, coord=coord)
|
||||
threads = tf.train.start_queue_runners(session, coord=coord)
|
||||
|
||||
self.assertEqual("%s:1" % name, inputs.name)
|
||||
file_name_queue_name = "%s/file_name_queue" % name
|
||||
@ -341,6 +343,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(inputs)
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def test_read_text_lines_multifile_with_shared_queue(self):
|
||||
gfile.Glob = self._orig_glob
|
||||
@ -362,7 +365,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(tf.initialize_local_variables())
|
||||
|
||||
coord = tf.train.Coordinator()
|
||||
tf.train.start_queue_runners(session, coord=coord)
|
||||
threads = tf.train.start_queue_runners(session, coord=coord)
|
||||
|
||||
self.assertEqual("%s:1" % name, inputs.name)
|
||||
shared_file_name_queue_name = "%s/file_name_queue" % name
|
||||
@ -385,6 +388,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(inputs)
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def _get_qr(self, name):
|
||||
for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
|
||||
@ -472,7 +476,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(tf.initialize_local_variables())
|
||||
|
||||
coord = tf.train.Coordinator()
|
||||
tf.train.start_queue_runners(session, coord=coord)
|
||||
threads = tf.train.start_queue_runners(session, coord=coord)
|
||||
|
||||
self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
|
||||
self.assertAllEqual(session.run(inputs), [b"D", b"E"])
|
||||
@ -480,6 +484,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(inputs)
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def test_keyed_read_text_lines(self):
|
||||
gfile.Glob = self._orig_glob
|
||||
@ -497,7 +502,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(tf.initialize_local_variables())
|
||||
|
||||
coord = tf.train.Coordinator()
|
||||
tf.train.start_queue_runners(session, coord=coord)
|
||||
threads = tf.train.start_queue_runners(session, coord=coord)
|
||||
|
||||
self.assertAllEqual(session.run([keys, inputs]),
|
||||
[[filename.encode("utf-8") + b":1"], [b"ABC"]])
|
||||
@ -509,6 +514,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(inputs)
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
def test_keyed_parse_json(self):
|
||||
gfile.Glob = self._orig_glob
|
||||
@ -534,7 +540,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(tf.initialize_local_variables())
|
||||
|
||||
coord = tf.train.Coordinator()
|
||||
tf.train.start_queue_runners(session, coord=coord)
|
||||
threads = tf.train.start_queue_runners(session, coord=coord)
|
||||
|
||||
key, age = session.run([keys, inputs["age"]])
|
||||
self.assertAllEqual(age, [[0]])
|
||||
@ -549,6 +555,7 @@ class GraphIOTest(tf.test.TestCase):
|
||||
session.run(inputs)
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user