diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 6263351a53e..e5484e02b5e 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -855,6 +855,27 @@ class _RecoverableSession(_WrappedSession):
                      'or parameter server. A new session will be created. '
                      'Error: %s', e)
 
+  def _check_stop(self):
+    try:
+      if self._sess:
+        return self._sess._check_stop()  # pylint: disable=protected-access
+      else:
+        return True
+    except _PREEMPTION_ERRORS as e:
+      logging.info('An error was raised while considering whether the '
+                   'session is complete. This may be due to a preemption in '
+                   'a connected worker or parameter server. The current '
+                   'session will be closed and a new session will be '
+                   'created. Error: %s', e)
+      self.close()
+      self._sess = self._create_session()
+      # Since we have just recreated the session, the overall computation should
+      # not stop:
+      return False
+    except Exception:  # pylint: disable=broad-except
+      # `should_stop` should return True instead of raising an exception.
+      return True
+
   def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
     while True:
       try:
@@ -901,7 +922,10 @@ class _CoordinatedSession(_WrappedSession):
     self._stop_grace_period_secs = stop_grace_period_secs
 
   def _check_stop(self):
-    # Check with the coordinator if we should stop.
+    # If the coordinator was asked to stop due to an exception, then it needs
+    # to be propagated to this stack.
+    self._coord.raise_requested_exception()
+    # At this point, no exceptions are recorded in the coordinator.
     return self._coord.should_stop()
 
   def close(self):
@@ -918,6 +942,24 @@ class _CoordinatedSession(_WrappedSession):
         # useful exceptions are already reported by join().
         pass
 
+  def run(self, *args, **kwargs):
+    try:
+      return self._sess.run(*args, **kwargs)
+    except _PREEMPTION_ERRORS as original_exception:
+      raise original_exception
+    except Exception as original_exception:  # pylint: disable=broad-except
+      # A non-preemption error could have been caused by a preemption error
+      # in the coordinator. If this is the case, raise that exception instead,
+      # since it's the root cause. Otherwise, stick to the `original_exception`.
+      try:
+        self._coord.raise_requested_exception()
+      except _PREEMPTION_ERRORS as preemption_in_coordinator:
+        raise preemption_in_coordinator
+      except Exception:  # pylint: disable=broad-except
+        raise original_exception
+      else:
+        raise original_exception
+
 
 class _HookedSession(_WrappedSession):
   """A _WrappedSession that calls hooks during calls to run().
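
The first hunk above teaches `_RecoverableSession._check_stop()` to treat the errors in `_PREEMPTION_ERRORS` as transient: the broken session is closed, a new one is created, and the method reports that the computation should continue, while any other exception is converted into a plain "stop" signal instead of propagating. The stand-alone sketch below mirrors that control flow; it is illustrative only, and the names `FakePreemptionError`, `SketchSession`, and `RecoverableWrapperSketch` are invented stand-ins, not TensorFlow classes.

```python
class FakePreemptionError(Exception):
  """Stand-in for the preemption errors grouped in _PREEMPTION_ERRORS."""


class SketchSession(object):
  """Toy session whose should_stop() raises a preemption exactly once."""

  def __init__(self):
    self._already_preempted = False

  def should_stop(self):
    if not self._already_preempted:
      self._already_preempted = True
      raise FakePreemptionError('worker preempted')
    return False


class RecoverableWrapperSketch(object):
  """Recreates the session on preemption, reports 'stop' on anything else."""

  def __init__(self, create_session_fn):
    self._create_session = create_session_fn
    self._sess = self._create_session()

  def should_stop(self):
    try:
      return self._sess.should_stop() if self._sess else True
    except FakePreemptionError:
      # Preemption is transient: rebuild the session and keep training.
      self._sess = self._create_session()
      return False
    except Exception:  # pylint: disable=broad-except
      # Like the patched code, answer 'stop' instead of raising.
      return True


if __name__ == '__main__':
  wrapper = RecoverableWrapperSketch(SketchSession)
  # The first check hits the injected preemption, rebuilds the session, and
  # reports 'do not stop', so a training loop would simply keep going.
  print(wrapper.should_stop())  # prints: False
```
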
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 85a5ceeb08f..ffd9ed311fa 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -508,7 +508,7 @@ class CoordinatedSessionTest(test.TestCase):
 
 
 class AbortAtNSession(object):
-  """A mock sessionthat aborts at the N-th run call."""
+  """A mock session that aborts at the N-th run call."""
 
   def __init__(self, sess, n):
     self._sess = sess
@@ -524,6 +524,99 @@ class AbortAtNSession(object):
     return self._sess.run(*args, **kwargs)
 
 
+class StopCoordinatorWithException(session_run_hook.SessionRunHook):
+  """With this hook, the coordinator raises an exception after N runs."""
+
+  def __init__(self, calls_before_stopping, exception_to_raise=None):
+    self._started_the_side_thread_already = False
+    self._lock = threading.Lock()
+    self._stored_exception_event = threading.Event()
+    self._calls_before_stopping = calls_before_stopping
+    self._exception_to_raise = (exception_to_raise or errors_impl.AbortedError(
+        None, None, 'Aborted at N'))
+
+  def _maybe_stop_with_exception(self, coord):
+    while True:
+      with self._lock:
+        if self._calls_before_stopping == 0:
+          try:
+            raise self._exception_to_raise
+          except Exception as e:  # pylint: disable=broad-except
+            coord.request_stop(e)
+          self._stored_exception_event.set()
+          break
+
+  def after_create_session(self, session, coord):
+    if self._started_the_side_thread_already:
+      return
+
+    separate_thread = threading.Thread(
+        target=self._maybe_stop_with_exception, args=(coord,))
+
+    coord.register_thread(separate_thread)
+    separate_thread.start()
+    self._started_the_side_thread_already = True
+    # Coordinator will take care of joining `separate_thread`.
+
+  def after_run(self, run_context, run_values):
+    stopping_now = False
+    with self._lock:
+      self._calls_before_stopping -= 1
+      if self._calls_before_stopping == 0:
+        stopping_now = True
+
+    if stopping_now:
+      self._stored_exception_event.wait()
+
+
+class FailTrainingAfterCoordinatorStopped(StopCoordinatorWithException):
+  """With this hook, training encounters an exception after N runs."""
+
+  def __init__(self, calls_before_stopping):
+    StopCoordinatorWithException.__init__(self, calls_before_stopping)
+    self._coord = None
+
+  def after_create_session(self, session, coord):
+    self._coord = coord
+    return StopCoordinatorWithException.after_create_session(
+        self, session, coord)
+
+  def after_run(self, run_context, run_values):
+    StopCoordinatorWithException.after_run(self, run_context, run_values)
+    try:
+      # After a `run`, an exception could have been stored inside the
+      # coordinator.
+      self._coord.raise_requested_exception()
+    except errors_impl.AbortedError:
+      # In the real world, the main thread may or may not know about the
+      # exception that stopped the coordinator. Because the coordinator has
+      # stopped, the main thread could have gotten stuck as well (for example,
+      # the coordinator was supposed to execute `FIFOQueue.enqueue` while the
+      # main thread is executing a blocking `FIFOQueue.dequeue`). After it got
+      # stuck, the session will eventually get garbage collected with:
+      raise errors_impl.CancelledError(None, None,
+                                       'Session got garbage-collected.')
+
+
+class CountingSessionCreator(object):
+  """A creator that counts the number of created sessions."""
+
+  def __init__(self, session):
+    self._initial_session = session
+    # We only have one session per test case. We can't re-create it, thus
+    # it shouldn't be closed.
+    self._initial_session.close = lambda *args: None
+    self._create_session_calls = 0
+
+  @property
+  def number_of_sessions_created(self):
+    return self._create_session_calls
+
+  def create_session(self):
+    self._create_session_calls += 1
+    return self._initial_session
+
+
 class RecoverableSessionTest(test.TestCase):
   """_RecoverableSession tests."""
 
@@ -597,6 +690,84 @@ class RecoverableSessionTest(test.TestCase):
       with self.assertRaisesRegexp(IndexError, 'pop from empty list'):
         recoverable_sess.run(v, feed_dict={c: -12})
 
+  def test_recovery_from_coordinator_exception(self):
+    with self.test_session() as test_session:
+      session_creator = CountingSessionCreator(test_session)
+      session = monitored_session.MonitoredSession(
+          session_creator,
+          [StopCoordinatorWithException(calls_before_stopping=2)])
+
+      self.assertEqual(1, session_creator.number_of_sessions_created)
+      self.assertFalse(session.should_stop())
+
+      c = constant_op.constant(0)
+      v = array_ops.identity(c)
+
+      # The coordinator will not abort during this call, since it's the call
+      # number 0.
+      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
+      self.assertFalse(session.should_stop())
+      # The coordinator will abort during the next call, since it's the call
+      # number 1.
+      self.assertEqual(42, session.run(v, feed_dict={c: 42}))
+      # Even though the coordinator was asked to stop, the underlying session
+      # is recreated and training continues.
+      self.assertFalse(session.should_stop())
+      self.assertEqual(2, session_creator.number_of_sessions_created)
+
+  def test_recovery_from_non_preemption_in_coordinator(self):
+    with self.test_session() as test_session:
+      session_creator = CountingSessionCreator(test_session)
+      hook = StopCoordinatorWithException(
+          calls_before_stopping=2,
+          exception_to_raise=errors_impl.UnknownError(
+              None, None, 'Some fatal exception inside the coordinator.'))
+      session = monitored_session.MonitoredSession(session_creator, [hook])
+
+      self.assertEqual(1, session_creator.number_of_sessions_created)
+      self.assertFalse(session.should_stop())
+
+      c = constant_op.constant(0)
+      v = array_ops.identity(c)
+
+      # The coordinator will not abort during this call, since it's the call
+      # number 0.
+      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
+      self.assertFalse(session.should_stop())
+      # The coordinator will abort during the next call, since it's the call
+      # number 1.
+      self.assertEqual(42, session.run(v, feed_dict={c: 42}))
+      # The coordinator was asked to stop due to a non-recoverable error.
+      # Training should stop and the session should not be recreated.
+      self.assertTrue(session.should_stop())
+      self.assertEqual(1, session_creator.number_of_sessions_created)
+      with self.assertRaises(errors_impl.UnknownError):
+        session.close()
+
+  def test_recovery_from_session_getting_stuck(self):
+    with self.test_session() as test_session:
+      session_creator = CountingSessionCreator(test_session)
+      session = monitored_session.MonitoredSession(
+          session_creator,
+          [FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])
+
+      self.assertEqual(1, session_creator.number_of_sessions_created)
+      self.assertFalse(session.should_stop())
+
+      c = constant_op.constant(0)
+      v = array_ops.identity(c)
+
+      # Training will not fail, since it's the call number 0.
+      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
+      self.assertFalse(session.should_stop())
+      # Training will fail during the next call, since it's the call
+      # number 1.
+      self.assertEqual(42, session.run(v, feed_dict={c: 42}))
+      # Even though the coordinator stopped and training failed, the
+      # underlying session is recreated and training continues.
+      self.assertFalse(session.should_stop())
+      self.assertEqual(2, session_creator.number_of_sessions_created)
+
 
 class FakeSession(monitored_session._WrappedSession):
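
The three tests above exercise the canonical MonitoredSession training loop: after this change, `should_stop()` recreates the wrapped session when a preemption (such as `AbortedError`) is reported through the coordinator, so training continues, while a non-preemption error such as `UnknownError` still stops the loop and is re-raised when the session is closed. Below is a minimal usage sketch of that loop, assuming a TensorFlow 1.x environment and substituting a trivial counter update for a real training op.

```python
import tensorflow as tf

# Trivial stand-in for a training step: a counter that gets incremented.
counter = tf.Variable(0, name='counter')
train_op = tf.assign_add(counter, 1)

with tf.train.MonitoredSession(
    session_creator=tf.train.ChiefSessionCreator()) as sess:
  while not sess.should_stop():
    # With this patch, a preemption surfaced through a hook or the coordinator
    # makes should_stop() close the broken session, build a fresh one via the
    # session creator, and return False, so this loop keeps running; errors
    # such as UnknownError still end training and resurface on close().
    step = sess.run(train_op)
    if step >= 3:  # Stop after a few steps; real code would rely on hooks.
      break
```
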