Simplify the Coordinator API by automatically tracking the set of coordinated
threads and computing the list of threads to join.
Change: 129249818
This commit is contained in:
A. Unique TensorFlower 2016-08-03 12:18:10 -08:00 committed by TensorFlower Gardener
parent 2ac55447c6
commit a0812ee71d
5 changed files with 69 additions and 21 deletions

View File

@ -150,6 +150,11 @@ class Coordinator(object):
self._exc_info_to_raise = None
# True if we have called join() already.
self._joined = False
# Set of threads registered for joining when join() is called. These
# threads will be joined in addition to the threads passed to the join()
# call. It's ok if threads are both registered and passed to the join()
# call.
self._registered_threads = set()
def _filter_exception(self, ex):
"""Check if the exception indicated in 'ex' should be ignored.
@ -305,10 +310,22 @@ class Coordinator(object):
"""
return self._stop_event.wait(timeout)
def join(self, threads, stop_grace_period_secs=120):
def register_thread(self, thread):
  """Adds `thread` to the set of threads that `join()` waits for.

  The registration set is updated while holding the coordinator's lock.
  Registering the same thread more than once has no additional effect,
  because the registrations are stored in a set.

  Args:
    thread: A Python thread to join.
  """
  with self._lock:
    registered = self._registered_threads
    registered.add(thread)
def join(self, threads=None, stop_grace_period_secs=120):
"""Wait for threads to terminate.
Blocks until all `threads` have terminated or `request_stop()` is called.
This call blocks until a set of threads have terminated. The set of threads
is the union of the threads passed in the `threads` argument and the list
of threads that registered with the coordinator by calling
`Coordinator.register_thread()`.
After the threads stop, if an `exc_info` was passed to `request_stop`, that
exception is re-raised.
@ -320,7 +337,8 @@ class Coordinator(object):
that `RuntimeError`.
Args:
threads: List of `threading.Threads`. The started threads to join.
threads: List of `threading.Threads`. The started threads to join in
addition to the registered threads.
stop_grace_period_secs: Number of seconds given to threads to stop after
`request_stop()` has been called.
@ -328,6 +346,13 @@ class Coordinator(object):
RuntimeError: If any thread is still alive after `request_stop()`
is called and the grace period expires.
"""
# Threads registered after this call will not be joined.
with self._lock:
if threads is None:
threads = self._registered_threads
else:
threads = self._registered_threads.union(set(threads))
# Wait for all threads to stop or for request_stop() to be called.
while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
pass
@ -353,6 +378,7 @@ class Coordinator(object):
# Terminate with an exception if appropriate.
with self._lock:
self._joined = True
self._registered_threads = set()
if self._exc_info_to_raise:
six.reraise(*self._exc_info_to_raise)
elif stragglers:
@ -411,6 +437,7 @@ class LooperThread(threading.Thread):
elif args or kwargs:
raise ValueError("'args' and 'kwargs' argument require that you also "
"pass 'target'")
self._coord.register_thread(self)
@staticmethod
def loop(coord, timer_interval_secs, target, args=None, kwargs=None):

View File

@ -47,7 +47,9 @@ def RaiseInNUsingContextHandler(coord, n_secs, ex):
raise ex
def SleepABit(n_secs):
def SleepABit(n_secs, coord=None):
  """Sleeps for `n_secs` seconds, optionally registering the calling thread.

  Args:
    n_secs: Number of seconds to sleep.
    coord: Optional coordinator. When given (and truthy), the thread running
      this function is registered with it before sleeping.
  """
  if coord:
    caller = threading.current_thread()
    coord.register_thread(caller)
  time.sleep(n_secs)
@ -80,6 +82,33 @@ class CoordinatorTest(tf.test.TestCase):
for t in threads:
t.start()
coord.join(threads)
for t in threads:
self.assertFalse(t.is_alive())
def testJoinAllRegistered(self):
  """Checks that join() with no arguments waits for all registered threads."""
  coord = tf.train.Coordinator()
  threads = [
      threading.Thread(target=SleepABit, args=(0.01, coord)),
      threading.Thread(target=SleepABit, args=(0.02, coord)),
      threading.Thread(target=SleepABit, args=(0.01, coord))]
  for t in threads:
    t.start()
  # Each thread registers itself from inside SleepABit(), and join() only
  # waits for threads registered before it is called. Wait (bounded) until
  # every thread has registered so the test does not race with thread startup.
  # pylint: disable=protected-access
  deadline = time.time() + 10.0
  while time.time() < deadline:
    with coord._lock:
      if len(coord._registered_threads) == len(threads):
        break
    time.sleep(0.001)
  # pylint: enable=protected-access
  coord.join()
  for t in threads:
    self.assertFalse(t.is_alive())
def testJoinSomeRegistered(self):
  """Checks that join() waits for registered and explicitly passed threads."""
  coord = tf.train.Coordinator()
  threads = [
      threading.Thread(target=SleepABit, args=(0.01, coord)),
      # `args` must be a tuple; `(0.02)` is just a float and would make the
      # thread die with a TypeError instead of sleeping.
      threading.Thread(target=SleepABit, args=(0.02,)),
      threading.Thread(target=SleepABit, args=(0.01, coord))]
  for t in threads:
    t.start()
  # threads[0] and threads[2] register themselves from inside SleepABit(),
  # and join() only waits for threads registered before it is called. Wait
  # (bounded) until both have registered so the test does not race with
  # thread startup.
  # pylint: disable=protected-access
  deadline = time.time() + 10.0
  while time.time() < deadline:
    with coord._lock:
      if len(coord._registered_threads) == 2:
        break
    time.sleep(0.001)
  # pylint: enable=protected-access
  # threads[1] is not registered, so we must pass it in. Note that
  # `threads[1:1]` would be an empty list.
  coord.join(threads[1:2])
  for t in threads:
    self.assertFalse(t.is_alive())
def testJoinGraceExpires(self):
def TestWithGracePeriod(stop_grace_period):

View File

@ -176,6 +176,8 @@ class QueueRunner(object):
coord: Optional Coordinator object for reporting errors and checking
for stop conditions.
"""
if coord:
coord.register_thread(threading.current_thread())
decremented = False
try:
while True:
@ -218,6 +220,7 @@ class QueueRunner(object):
cancel_op: The Operation to run.
coord: Coordinator.
"""
coord.register_thread(threading.current_thread())
coord.wait_for_stop()
try:
sess.run(cancel_op)

View File

@ -122,7 +122,7 @@ class QueueRunnerTest(tf.test.TestCase):
threads = qr.create_threads(sess, coord)
for t in threads:
t.start()
coord.join(threads)
coord.join()
self.assertEqual(0, len(qr.exceptions_raised))
# The variable should be 0.
self.assertEqual(0, var.eval())
@ -137,7 +137,7 @@ class QueueRunnerTest(tf.test.TestCase):
t.start()
# The exception should be re-raised when joining.
with self.assertRaisesRegexp(ValueError, "Operation not in the graph"):
coord.join(threads)
coord.join()
def testGracePeriod(self):
with self.test_session() as sess:
@ -147,14 +147,14 @@ class QueueRunnerTest(tf.test.TestCase):
dequeue = queue.dequeue()
qr = tf.train.QueueRunner(queue, [enqueue])
coord = tf.train.Coordinator()
threads = qr.create_threads(sess, coord, start=True)
qr.create_threads(sess, coord, start=True)
# Dequeue one element and then request stop.
dequeue.op.run()
time.sleep(0.02)
coord.request_stop()
# We should be able to join because the RequestStop() will cause
# the queue to be closed and the enqueue to terminate.
coord.join(threads, stop_grace_period_secs=0.05)
coord.join(stop_grace_period_secs=0.05)
def testIgnoreMultiStarts(self):
with self.test_session() as sess:
@ -171,7 +171,7 @@ class QueueRunnerTest(tf.test.TestCase):
new_threads = qr.create_threads(sess, coord=coord)
self.assertEqual([], new_threads)
coord.request_stop()
coord.join(threads, stop_grace_period_secs=0.5)
coord.join(stop_grace_period_secs=0.5)
def testThreads(self):
with self.test_session() as sess:

View File

@ -296,7 +296,6 @@ class Supervisor(object):
self._graph = graph
self._is_chief = is_chief
self._coord = coordinator.Coordinator()
self._started_threads = []
self._recovery_wait_secs = recovery_wait_secs
self._stop_grace_secs = stop_grace_secs
self._init_fn = init_fn
@ -636,8 +635,6 @@ class Supervisor(object):
threads.append(SVTimerCheckpointThread(self, sess))
for t in threads:
t.start()
self._started_threads.extend(threads)
return threads
def prepare_or_wait_for_session(self, master="", config=None,
@ -712,7 +709,6 @@ class Supervisor(object):
for qr in queue_runners:
threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
start=True))
self._started_threads.extend(threads)
return threads
def loop(self, timer_interval_secs, target, args=None, kwargs=None):
@ -737,7 +733,6 @@ class Supervisor(object):
looper = coordinator.LooperThread(self._coord, timer_interval_secs,
target=target, args=args, kwargs=kwargs)
looper.start()
self._started_threads.append(looper)
return looper
def stop(self, threads=None, close_summary_writer=True):
@ -755,16 +750,12 @@ class Supervisor(object):
`True` if the summary writer was created by the supervisor, `False`
otherwise.
"""
join_threads = []
join_threads.extend(self._started_threads)
if threads is not None:
join_threads.extend(threads)
self._coord.request_stop()
try:
# coord.join() re-raises the first reported exception; the "finally"
# block ensures that we clean up whether or not an exception was
# reported.
self._coord.join(join_threads,
self._coord.join(threads,
stop_grace_period_secs=self._stop_grace_secs)
finally:
# Close the writer last, in case one of the running threads was using it.
@ -775,8 +766,6 @@ class Supervisor(object):
self._summary_writer.close()
self._graph_added_to_summary = False
self._started_threads = []
def request_stop(self, ex=None):
"""Request that the coordinator stop the threads.