diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py index 33bb8cf51fa..a7fc169d18a 100644 --- a/tensorflow/python/training/coordinator.py +++ b/tensorflow/python/training/coordinator.py @@ -148,6 +148,8 @@ class Coordinator(object): # If not None, it should hold the returned value of sys.exc_info(), which is # a tuple containing exception (type, value, traceback). self._exc_info_to_raise = None + # True if we have called join() already. + self._joined = False def _filter_exception(self, ex): """Check if the exception indicated in 'ex' should be ignored. @@ -189,8 +191,19 @@ class Coordinator(object): `sys.exc_info()`. If this is the first call to `request_stop()` the corresponding exception is recorded and re-raised from `join()`. """ - ex = self._filter_exception(ex) with self._lock: + ex = self._filter_exception(ex) + # If we have already joined the coordinator the exception will not have a + # chance to be reported, so just raise it normally. This can happen if + # you continue to use a session have having stopped and joined the + # coordinator threads. + if self._joined: + if isinstance(ex, tuple): + six.reraise(*ex) + elif ex is not None: + # NOTE(touts): This is bogus if request_stop() is not called + # from the exception handler that raised ex. + six.reraise(*sys.exc_info()) if not self._stop_event.is_set(): if ex and self._exc_info_to_raise is None: if isinstance(ex, tuple): @@ -217,6 +230,8 @@ class Coordinator(object): % self._exc_info_to_raise) except ValueError: # Record this error so it kills the coordinator properly. + # NOTE(touts): As above, this is bogus if request_stop() is not + # called from the exception handler that raised ex. self._exc_info_to_raise = sys.exc_info() self._stop_event.set() @@ -227,6 +242,7 @@ class Coordinator(object): After this is called, calls to `should_stop()` will return `False`. """ with self._lock: + self._joined = False self._exc_info_to_raise = None if self._stop_event.is_set(): self._stop_event.clear() @@ -336,6 +352,7 @@ class Coordinator(object): # Terminate with an exception if appropriate. with self._lock: + self._joined = True if self._exc_info_to_raise: six.reraise(*self._exc_info_to_raise) elif stragglers: @@ -343,6 +360,10 @@ class Coordinator(object): "Coordinator stopped with threads still running: %s" % " ".join(stragglers)) + @property + def joined(self): + return self._joined + # Threads for the standard services. class LooperThread(threading.Thread): diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py index 95c7cb7160c..764307fd7d2 100644 --- a/tensorflow/python/training/coordinator_test.py +++ b/tensorflow/python/training/coordinator_test.py @@ -174,6 +174,49 @@ class CoordinatorTest(tf.test.TestCase): with self.assertRaisesRegexp(RuntimeError, "Second"): coord.join(threads) + def testRequestStopRaisesIfJoined(self): + coord = tf.train.Coordinator() + # Join the coordinator right away. + coord.join([]) + reported = False + with self.assertRaisesRegexp(RuntimeError, "Too late"): + try: + raise RuntimeError("Too late") + except RuntimeError as e: + reported = True + coord.request_stop(e) + self.assertTrue(reported) + # If we clear_stop the exceptions are handled normally. + coord.clear_stop() + try: + raise RuntimeError("After clear") + except RuntimeError as e: + coord.request_stop(e) + with self.assertRaisesRegexp(RuntimeError, "After clear"): + coord.join([]) + + def testRequestStopRaisesIfJoined_ExcInfo(self): + # Same as testRequestStopRaisesIfJoined but using syc.exc_info(). + coord = tf.train.Coordinator() + # Join the coordinator right away. + coord.join([]) + reported = False + with self.assertRaisesRegexp(RuntimeError, "Too late"): + try: + raise RuntimeError("Too late") + except RuntimeError: + reported = True + coord.request_stop(sys.exc_info()) + self.assertTrue(reported) + # If we clear_stop the exceptions are handled normally. + coord.clear_stop() + try: + raise RuntimeError("After clear") + except RuntimeError: + coord.request_stop(sys.exc_info()) + with self.assertRaisesRegexp(RuntimeError, "After clear"): + coord.join([]) + def _StopAt0(coord, n): if n[0] == 0: