Simplify the Coordinator API by automatically tracking the set of coordinated
threads and computing the list of threads to join automatically. Change: 129249818
This commit is contained in:
parent
2ac55447c6
commit
a0812ee71d
@ -150,6 +150,11 @@ class Coordinator(object):
|
||||
self._exc_info_to_raise = None
|
||||
# True if we have called join() already.
|
||||
self._joined = False
|
||||
# Set of threads registered for joining when join() is called. These
|
||||
# threads will be joined in addition to the threads passed to the join()
|
||||
# call. It's ok if threads are both registered and passed to the join()
|
||||
# call.
|
||||
self._registered_threads = set()
|
||||
|
||||
def _filter_exception(self, ex):
|
||||
"""Check if the exception indicated in 'ex' should be ignored.
|
||||
@ -305,10 +310,22 @@ class Coordinator(object):
|
||||
"""
|
||||
return self._stop_event.wait(timeout)
|
||||
|
||||
def join(self, threads, stop_grace_period_secs=120):
|
||||
def register_thread(self, thread):
|
||||
"""Register a thread to join.
|
||||
|
||||
Args:
|
||||
thread: A Python thread to join.
|
||||
"""
|
||||
with self._lock:
|
||||
self._registered_threads.add(thread)
|
||||
|
||||
def join(self, threads=None, stop_grace_period_secs=120):
|
||||
"""Wait for threads to terminate.
|
||||
|
||||
Blocks until all `threads` have terminated or `request_stop()` is called.
|
||||
This call blocks until a set of threads have terminated. The set of threads
|
||||
is the union of the threads passed in the `threads` argument and the list
|
||||
of threads that registered with the coordinator by calling
|
||||
`Coordinator.register_thread()`.
|
||||
|
||||
After the threads stop, if an `exc_info` was passed to `request_stop`, that
|
||||
exception is re-raised.
|
||||
@ -320,7 +337,8 @@ class Coordinator(object):
|
||||
that `RuntimeError`.
|
||||
|
||||
Args:
|
||||
threads: List of `threading.Threads`. The started threads to join.
|
||||
threads: List of `threading.Thread` objects. The started threads to join in
|
||||
addition to the registered threads.
|
||||
stop_grace_period_secs: Number of seconds given to threads to stop after
|
||||
`request_stop()` has been called.
|
||||
|
||||
@ -328,6 +346,13 @@ class Coordinator(object):
|
||||
RuntimeError: If any thread is still alive after `request_stop()`
|
||||
is called and the grace period expires.
|
||||
"""
|
||||
# Threads registered after this call will not be joined.
|
||||
with self._lock:
|
||||
if threads is None:
|
||||
threads = self._registered_threads
|
||||
else:
|
||||
threads = self._registered_threads.union(set(threads))
|
||||
|
||||
# Wait for all threads to stop or for request_stop() to be called.
|
||||
while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
|
||||
pass
|
||||
@ -353,6 +378,7 @@ class Coordinator(object):
|
||||
# Terminate with an exception if appropriate.
|
||||
with self._lock:
|
||||
self._joined = True
|
||||
self._registered_threads = set()
|
||||
if self._exc_info_to_raise:
|
||||
six.reraise(*self._exc_info_to_raise)
|
||||
elif stragglers:
|
||||
@ -411,6 +437,7 @@ class LooperThread(threading.Thread):
|
||||
elif args or kwargs:
|
||||
raise ValueError("'args' and 'kwargs' argument require that you also "
|
||||
"pass 'target'")
|
||||
self._coord.register_thread(self)
|
||||
|
||||
@staticmethod
|
||||
def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
|
||||
|
@ -47,7 +47,9 @@ def RaiseInNUsingContextHandler(coord, n_secs, ex):
|
||||
raise ex
|
||||
|
||||
|
||||
def SleepABit(n_secs):
|
||||
def SleepABit(n_secs, coord=None):
|
||||
if coord:
|
||||
coord.register_thread(threading.current_thread())
|
||||
time.sleep(n_secs)
|
||||
|
||||
|
||||
@ -80,6 +82,33 @@ class CoordinatorTest(tf.test.TestCase):
|
||||
for t in threads:
|
||||
t.start()
|
||||
coord.join(threads)
|
||||
for t in threads:
|
||||
self.assertFalse(t.is_alive())
|
||||
|
||||
def testJoinAllRegistered(self):
|
||||
coord = tf.train.Coordinator()
|
||||
threads = [
|
||||
threading.Thread(target=SleepABit, args=(0.01, coord)),
|
||||
threading.Thread(target=SleepABit, args=(0.02, coord)),
|
||||
threading.Thread(target=SleepABit, args=(0.01, coord))]
|
||||
for t in threads:
|
||||
t.start()
|
||||
coord.join()
|
||||
for t in threads:
|
||||
self.assertFalse(t.is_alive())
|
||||
|
||||
def testJoinSomeRegistered(self):
|
||||
coord = tf.train.Coordinator()
|
||||
threads = [
|
||||
threading.Thread(target=SleepABit, args=(0.01, coord)),
|
||||
threading.Thread(target=SleepABit, args=(0.02)),
|
||||
threading.Thread(target=SleepABit, args=(0.01, coord))]
|
||||
for t in threads:
|
||||
t.start()
|
||||
# threads[1] is not registered, so we must pass it in.
|
||||
coord.join(threads[1:1])
|
||||
for t in threads:
|
||||
self.assertFalse(t.is_alive())
|
||||
|
||||
def testJoinGraceExpires(self):
|
||||
def TestWithGracePeriod(stop_grace_period):
|
||||
|
@ -176,6 +176,8 @@ class QueueRunner(object):
|
||||
coord: Optional Coordinator object for reporting errors and checking
|
||||
for stop conditions.
|
||||
"""
|
||||
if coord:
|
||||
coord.register_thread(threading.current_thread())
|
||||
decremented = False
|
||||
try:
|
||||
while True:
|
||||
@ -218,6 +220,7 @@ class QueueRunner(object):
|
||||
cancel_op: The Operation to run.
|
||||
coord: Coordinator.
|
||||
"""
|
||||
coord.register_thread(threading.current_thread())
|
||||
coord.wait_for_stop()
|
||||
try:
|
||||
sess.run(cancel_op)
|
||||
|
@ -122,7 +122,7 @@ class QueueRunnerTest(tf.test.TestCase):
|
||||
threads = qr.create_threads(sess, coord)
|
||||
for t in threads:
|
||||
t.start()
|
||||
coord.join(threads)
|
||||
coord.join()
|
||||
self.assertEqual(0, len(qr.exceptions_raised))
|
||||
# The variable should be 0.
|
||||
self.assertEqual(0, var.eval())
|
||||
@ -137,7 +137,7 @@ class QueueRunnerTest(tf.test.TestCase):
|
||||
t.start()
|
||||
# The exception should be re-raised when joining.
|
||||
with self.assertRaisesRegexp(ValueError, "Operation not in the graph"):
|
||||
coord.join(threads)
|
||||
coord.join()
|
||||
|
||||
def testGracePeriod(self):
|
||||
with self.test_session() as sess:
|
||||
@ -147,14 +147,14 @@ class QueueRunnerTest(tf.test.TestCase):
|
||||
dequeue = queue.dequeue()
|
||||
qr = tf.train.QueueRunner(queue, [enqueue])
|
||||
coord = tf.train.Coordinator()
|
||||
threads = qr.create_threads(sess, coord, start=True)
|
||||
qr.create_threads(sess, coord, start=True)
|
||||
# Dequeue one element and then request stop.
|
||||
dequeue.op.run()
|
||||
time.sleep(0.02)
|
||||
coord.request_stop()
|
||||
# We should be able to join because the RequestStop() will cause
|
||||
# the queue to be closed and the enqueue to terminate.
|
||||
coord.join(threads, stop_grace_period_secs=0.05)
|
||||
coord.join(stop_grace_period_secs=0.05)
|
||||
|
||||
def testIgnoreMultiStarts(self):
|
||||
with self.test_session() as sess:
|
||||
@ -171,7 +171,7 @@ class QueueRunnerTest(tf.test.TestCase):
|
||||
new_threads = qr.create_threads(sess, coord=coord)
|
||||
self.assertEqual([], new_threads)
|
||||
coord.request_stop()
|
||||
coord.join(threads, stop_grace_period_secs=0.5)
|
||||
coord.join(stop_grace_period_secs=0.5)
|
||||
|
||||
def testThreads(self):
|
||||
with self.test_session() as sess:
|
||||
|
@ -296,7 +296,6 @@ class Supervisor(object):
|
||||
self._graph = graph
|
||||
self._is_chief = is_chief
|
||||
self._coord = coordinator.Coordinator()
|
||||
self._started_threads = []
|
||||
self._recovery_wait_secs = recovery_wait_secs
|
||||
self._stop_grace_secs = stop_grace_secs
|
||||
self._init_fn = init_fn
|
||||
@ -636,8 +635,6 @@ class Supervisor(object):
|
||||
threads.append(SVTimerCheckpointThread(self, sess))
|
||||
for t in threads:
|
||||
t.start()
|
||||
self._started_threads.extend(threads)
|
||||
|
||||
return threads
|
||||
|
||||
def prepare_or_wait_for_session(self, master="", config=None,
|
||||
@ -712,7 +709,6 @@ class Supervisor(object):
|
||||
for qr in queue_runners:
|
||||
threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
|
||||
start=True))
|
||||
self._started_threads.extend(threads)
|
||||
return threads
|
||||
|
||||
def loop(self, timer_interval_secs, target, args=None, kwargs=None):
|
||||
@ -737,7 +733,6 @@ class Supervisor(object):
|
||||
looper = coordinator.LooperThread(self._coord, timer_interval_secs,
|
||||
target=target, args=args, kwargs=kwargs)
|
||||
looper.start()
|
||||
self._started_threads.append(looper)
|
||||
return looper
|
||||
|
||||
def stop(self, threads=None, close_summary_writer=True):
|
||||
@ -755,16 +750,12 @@ class Supervisor(object):
|
||||
`True` if the summary writer was created by the supervisor, `False`
|
||||
otherwise.
|
||||
"""
|
||||
join_threads = []
|
||||
join_threads.extend(self._started_threads)
|
||||
if threads is not None:
|
||||
join_threads.extend(threads)
|
||||
self._coord.request_stop()
|
||||
try:
|
||||
# coord.join() re-raises the first reported exception; the "finally"
|
||||
# block ensures that we clean up whether or not an exception was
|
||||
# reported.
|
||||
self._coord.join(join_threads,
|
||||
self._coord.join(threads,
|
||||
stop_grace_period_secs=self._stop_grace_secs)
|
||||
finally:
|
||||
# Close the writer last, in case one of the running threads was using it.
|
||||
@ -775,8 +766,6 @@ class Supervisor(object):
|
||||
self._summary_writer.close()
|
||||
self._graph_added_to_summary = False
|
||||
|
||||
self._started_threads = []
|
||||
|
||||
def request_stop(self, ex=None):
|
||||
"""Request that the coordinator stop the threads.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user