In Coordinator.request_stop(ex) just re-raise the exception immediately if the

coordinator had already been stopped.
This occur in an anti-pattern where you continue to use a session after the
coordinator threads have been stopped and joined.  In that case the reported
exception would be ignored and never reported anywhere.
Change: 128392626
This commit is contained in:
A. Unique TensorFlower 2016-07-25 12:24:58 -08:00 committed by TensorFlower Gardener
parent 6e38ff720b
commit 5bfa4aa142
2 changed files with 65 additions and 1 deletions

View File

@ -148,6 +148,8 @@ class Coordinator(object):
# If not None, it should hold the returned value of sys.exc_info(), which is
# a tuple containing exception (type, value, traceback).
self._exc_info_to_raise = None
# True if we have called join() already.
self._joined = False
def _filter_exception(self, ex):
"""Check if the exception indicated in 'ex' should be ignored.
@ -189,8 +191,19 @@ class Coordinator(object):
`sys.exc_info()`. If this is the first call to `request_stop()` the
corresponding exception is recorded and re-raised from `join()`.
"""
ex = self._filter_exception(ex)
with self._lock:
ex = self._filter_exception(ex)
# If we have already joined the coordinator the exception will not have a
# chance to be reported, so just raise it normally. This can happen if
# you continue to use a session have having stopped and joined the
# coordinator threads.
if self._joined:
if isinstance(ex, tuple):
six.reraise(*ex)
elif ex is not None:
# NOTE(touts): This is bogus if request_stop() is not called
# from the exception handler that raised ex.
six.reraise(*sys.exc_info())
if not self._stop_event.is_set():
if ex and self._exc_info_to_raise is None:
if isinstance(ex, tuple):
@ -217,6 +230,8 @@ class Coordinator(object):
% self._exc_info_to_raise)
except ValueError:
# Record this error so it kills the coordinator properly.
# NOTE(touts): As above, this is bogus if request_stop() is not
# called from the exception handler that raised ex.
self._exc_info_to_raise = sys.exc_info()
self._stop_event.set()
@ -227,6 +242,7 @@ class Coordinator(object):
After this is called, calls to `should_stop()` will return `False`.
"""
with self._lock:
self._joined = False
self._exc_info_to_raise = None
if self._stop_event.is_set():
self._stop_event.clear()
@ -336,6 +352,7 @@ class Coordinator(object):
# Terminate with an exception if appropriate.
with self._lock:
self._joined = True
if self._exc_info_to_raise:
six.reraise(*self._exc_info_to_raise)
elif stragglers:
@ -343,6 +360,10 @@ class Coordinator(object):
"Coordinator stopped with threads still running: %s" %
" ".join(stragglers))
@property
def joined(self):
return self._joined
# Threads for the standard services.
class LooperThread(threading.Thread):

View File

@ -174,6 +174,49 @@ class CoordinatorTest(tf.test.TestCase):
with self.assertRaisesRegexp(RuntimeError, "Second"):
coord.join(threads)
def testRequestStopRaisesIfJoined(self):
coord = tf.train.Coordinator()
# Join the coordinator right away.
coord.join([])
reported = False
with self.assertRaisesRegexp(RuntimeError, "Too late"):
try:
raise RuntimeError("Too late")
except RuntimeError as e:
reported = True
coord.request_stop(e)
self.assertTrue(reported)
# If we clear_stop the exceptions are handled normally.
coord.clear_stop()
try:
raise RuntimeError("After clear")
except RuntimeError as e:
coord.request_stop(e)
with self.assertRaisesRegexp(RuntimeError, "After clear"):
coord.join([])
def testRequestStopRaisesIfJoined_ExcInfo(self):
# Same as testRequestStopRaisesIfJoined but using syc.exc_info().
coord = tf.train.Coordinator()
# Join the coordinator right away.
coord.join([])
reported = False
with self.assertRaisesRegexp(RuntimeError, "Too late"):
try:
raise RuntimeError("Too late")
except RuntimeError:
reported = True
coord.request_stop(sys.exc_info())
self.assertTrue(reported)
# If we clear_stop the exceptions are handled normally.
coord.clear_stop()
try:
raise RuntimeError("After clear")
except RuntimeError:
coord.request_stop(sys.exc_info())
with self.assertRaisesRegexp(RuntimeError, "After clear"):
coord.join([])
def _StopAt0(coord, n):
if n[0] == 0: