Recover MonitoredSession when the Coordinator is requested to stop with one of the _PREEMPTION_ERRORS.

When SyncReplicasOptimizer is used, a preemption seen by the Coordinator can lead to one of two failure modes:
Case 1) the session gets silently marked as complete.
Case 2) the session gets stuck.

This CL fixes both problems and adds tests that verify the fixes. Fix 1 changes the should_stop logic; Fix 2 changes the _CoordinatedSession.run() logic.

SyncReplicasOptimizer runs a separate set of threads under a Coordinator instance. Those threads perform FIFOQueue.enqueue operations, while the main thread performs a blocking FIFOQueue.dequeue.
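
As a minimal sketch of that pattern only (not the actual SyncReplicasOptimizer code; the queue, ops, and thread names are made up for illustration), the interaction looks roughly like this:

import threading
import tensorflow as tf

# Hypothetical stand-in for the token queue that lives on the parameter servers.
sync_token_q = tf.FIFOQueue(capacity=10, dtypes=[tf.int64])
enqueue_op = sync_token_q.enqueue(tf.constant(1, dtype=tf.int64))
dequeue_op = sync_token_q.dequeue()

coord = tf.train.Coordinator()

def token_filler(sess, num_tokens=5):
  # Side thread: fills the queue, reporting any error to the coordinator.
  try:
    for _ in range(num_tokens):
      if coord.should_stop():
        break
      sess.run(enqueue_op)
  except Exception as e:  # e.g. AbortedError when a PS gets preempted
    coord.request_stop(e)

with tf.Session() as sess:
  filler = threading.Thread(target=token_filler, args=(sess,))
  coord.register_thread(filler)
  filler.start()
  sess.run(dequeue_op)  # the main thread blocks here until a token is available
  coord.request_stop()
  coord.join([filler])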

The `sync_token_q` FIFOQueue lives on the parameter servers. When one of the PS instances gets preempted, an AbortedError causes the Coordinator to stop via request_stop(ex). That by itself flips MonitoredSession.should_stop() to True (Fix 1).
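
For reference, the Coordinator behavior this relies on (standard tf.train.Coordinator API, shown here only as a small sketch):

import tensorflow as tf
from tensorflow.python.framework import errors_impl

coord = tf.train.Coordinator()
# A PS preemption surfaces as an AbortedError and gets recorded here.
coord.request_stop(errors_impl.AbortedError(None, None, 'PS was preempted'))

assert coord.should_stop()           # every should_stop() check now reports True
try:
  coord.raise_requested_exception()  # re-raises the recorded AbortedError
except errors_impl.AbortedError as e:
  print('coordinator stored: %s' % e)

Fix 1 makes _CoordinatedSession._check_stop() call raise_requested_exception(), so the recorded AbortedError propagates up to _RecoverableSession instead of silently turning into "training is complete".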

Results of the blocking Dequeue operation are sent to the chief worker via Recv. What happens next depends on the number of tokens in `sync_token_q`. If there are enough for the next call to Dequeue to return, then the low-level tf.Session.run() call returns. The next iteration of the `while not MonitoredSession.should_stop()` training loop then concludes that training is complete (Case 1).
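
The loop in question is the usual MonitoredSession training loop; a hedged sketch follows, where server, sync_replicas_hook, and train_op are placeholder arguments rather than part of this change:

import tensorflow as tf

def chief_training_loop(server, sync_replicas_hook, train_op):
  # Placeholder arguments: a tf.train.Server, the hook returned by
  # SyncReplicasOptimizer.make_session_run_hook(is_chief=True), and the op
  # returned by SyncReplicasOptimizer.minimize().
  with tf.train.MonitoredTrainingSession(master=server.target,
                                         is_chief=True,
                                         hooks=[sync_replicas_hook]) as sess:
    while not sess.should_stop():
      # If the token Dequeue still succeeds after the preemption, this run()
      # returns normally, but the next should_stop() check sees the stop
      # request and exits the loop as if training had finished (Case 1).
      sess.run(train_op)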

If there are not enough tokens in `sync_token_q`, the blocking Dequeue keeps waiting for them. Graph execution gets stuck, and the whole session gets garbage collected after 10 minutes (Case 2).

We decided to fix this by re-creating a session after it gets garbage collected (Fix 2). An alternative would have been to cancel the pending Dequeue operation, but it is not clear that this is the right thing to do, and it is not easy either.
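
Put together, the recovery flow amounts to the following simplified sketch; the class mirrors _RecoverableSession in the diff below, but it is an illustration, not the exact implementation:

from tensorflow.python.framework import errors

# Same tuple as monitored_session._PREEMPTION_ERRORS.
_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)

class RecoverableSessionSketch(object):
  """Illustration only: how Fix 1 and Fix 2 cooperate."""

  def __init__(self, sess_creator):
    self._sess_creator = sess_creator
    self._sess = self._sess_creator.create_session()

  def should_stop(self):
    try:
      # With Fix 1, a preemption recorded in the coordinator is re-raised
      # here rather than silently reported as "stop".
      return self._sess._check_stop()  # pylint: disable=protected-access
    except _PREEMPTION_ERRORS:
      # With Fix 2, the stale (possibly garbage-collected) session is closed,
      # a fresh one is created, and the training loop keeps going.
      self._sess.close()
      self._sess = self._sess_creator.create_session()
      return False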

PiperOrigin-RevId: 164888390
Author: Igor Saprykin
Committed by: TensorFlower Gardener
Date: 2017-08-10 12:36:13 -07:00
Commit: 8f9b1af8ae (parent: 46e4de6e5f)
2 changed files with 215 additions and 2 deletions


@@ -855,6 +855,27 @@ class _RecoverableSession(_WrappedSession):
                     'or parameter server. A new session will be created. '
                     'Error: %s', e)

  def _check_stop(self):
    try:
      if self._sess:
        return self._sess._check_stop()  # pylint: disable=protected-access
      else:
        return True
    except _PREEMPTION_ERRORS as e:
      logging.info('An error was raised while considering whether the '
                   'session is complete. This may be due to a preemption in '
                   'a connected worker or parameter server. The current '
                   'session will be closed and a new session will be '
                   'created. Error: %s', e)
      self.close()
      self._sess = self._create_session()
      # Since we have just recreated the session, the overall computation should
      # not stop:
      return False
    except Exception:  # pylint: disable=broad-except
      # `should_stop` should return True instead of raising an exception.
      return True

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    while True:
      try:
@@ -901,7 +922,10 @@ class _CoordinatedSession(_WrappedSession):
    self._stop_grace_period_secs = stop_grace_period_secs

  def _check_stop(self):
    # Check with the coordinator if we should stop.
    # If the coordinator was asked to stop due to an exception, then it needs
    # to be propagated to this stack.
    self._coord.raise_requested_exception()
    # At this point, no exceptions are recorded in the coordinator.
    return self._coord.should_stop()

  def close(self):
@@ -918,6 +942,24 @@ class _CoordinatedSession(_WrappedSession):
        # useful exceptions are already reported by join().
        pass

  def run(self, *args, **kwargs):
    try:
      return self._sess.run(*args, **kwargs)
    except _PREEMPTION_ERRORS as original_exception:
      raise original_exception
    except Exception as original_exception:  # pylint: disable=broad-except
      # A non-preemption error could have been caused by a preemption error
      # in the coordinator. If this is the case, raise that exception instead,
      # since it's the root cause. Otherwise, stick to the `original_exception`.
      try:
        self._coord.raise_requested_exception()
      except _PREEMPTION_ERRORS as preemption_in_coordinator:
        raise preemption_in_coordinator
      except Exception:  # pylint: disable=broad-except
        raise original_exception
      else:
        raise original_exception


class _HookedSession(_WrappedSession):
  """A _WrappedSession that calls hooks during calls to run().


@@ -508,7 +508,7 @@ class CoordinatedSessionTest(test.TestCase):

class AbortAtNSession(object):
  """A mock session that aborts at the N-th run call."""

  def __init__(self, sess, n):
    self._sess = sess
@@ -524,6 +524,99 @@ class AbortAtNSession(object):
    return self._sess.run(*args, **kwargs)


class StopCoordinatorWithException(session_run_hook.SessionRunHook):
  """With this hook, the Coordinator throws an exception after N runs."""

  def __init__(self, calls_before_stopping, exception_to_raise=None):
    self._started_the_side_thread_already = False
    self._lock = threading.Lock()
    self._stored_exception_event = threading.Event()
    self._calls_before_stopping = calls_before_stopping
    self._exception_to_raise = (exception_to_raise or errors_impl.AbortedError(
        None, None, 'Aborted at N'))

  def _maybe_stop_with_exception(self, coord):
    while True:
      with self._lock:
        if self._calls_before_stopping == 0:
          try:
            raise self._exception_to_raise
          except Exception as e:  # pylint: disable=broad-except
            coord.request_stop(e)
            self._stored_exception_event.set()
            break

  def after_create_session(self, session, coord):
    if self._started_the_side_thread_already:
      return

    separate_thread = threading.Thread(
        target=self._maybe_stop_with_exception, args=(coord,))

    coord.register_thread(separate_thread)
    separate_thread.start()
    self._started_the_side_thread_already = True
    # Coordinator will take care of joining `separate_thread`.

  def after_run(self, run_context, run_values):
    stopping_now = False

    with self._lock:
      self._calls_before_stopping -= 1
      if self._calls_before_stopping == 0:
        stopping_now = True

    if stopping_now:
      self._stored_exception_event.wait()


class FailTrainingAfterCoordinatorStopped(StopCoordinatorWithException):
  """With this hook, training encounters an exception after N runs."""

  def __init__(self, calls_before_stopping):
    StopCoordinatorWithException.__init__(self, calls_before_stopping)
    self._coord = None

  def after_create_session(self, session, coord):
    self._coord = coord
    return StopCoordinatorWithException.after_create_session(
        self, session, coord)

  def after_run(self, run_context, run_values):
    StopCoordinatorWithException.after_run(self, run_context, run_values)
    try:
      # After a `run`, an exception could have been stored inside the
      # coordinator.
      self._coord.raise_requested_exception()
    except errors_impl.AbortedError:
      # In the real world, the main thread may or may not know about the
      # exception that stopped the coordinator. Because the coordinator has
      # stopped, the main thread could have gotten stuck as well (for example,
      # the coordinator was supposed to execute `FIFOQueue.enqueue` while the
      # main thread is executing a blocking `FIFOQueue.dequeue`). After it got
      # stuck, the session is going to get garbage collected after some time
      # with:
      raise errors_impl.CancelledError(None, None,
                                       'Session got garbage-collected.')


class CountingSessionCreator(object):
  """A creator that counts the number of created sessions."""

  def __init__(self, session):
    self._initial_session = session
    # We only have one session per test case. We can't re-create it, thus
    # it shouldn't be closed.
    self._initial_session.close = lambda *args: None
    self._create_session_calls = 0

  @property
  def number_of_sessions_created(self):
    return self._create_session_calls

  def create_session(self):
    self._create_session_calls += 1
    return self._initial_session


class RecoverableSessionTest(test.TestCase):
  """_RecoverableSession tests."""
@@ -597,6 +690,84 @@ class RecoverableSessionTest(test.TestCase):
    with self.assertRaisesRegexp(IndexError, 'pop from empty list'):
      recoverable_sess.run(v, feed_dict={c: -12})

  def test_recovery_from_coordinator_exception(self):
    with self.test_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = monitored_session.MonitoredSession(
          session_creator,
          [StopCoordinatorWithException(calls_before_stopping=2)])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # The coordinator will not abort during this call, since it's call
      # number 0.
      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
      self.assertFalse(session.should_stop())

      # The coordinator will abort during the next call, since it's call
      # number 1.
      self.assertEqual(42, session.run(v, feed_dict={c: 42}))

      # Even though the coordinator was asked to stop, the underlying session
      # is recreated and training continues.
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session_creator.number_of_sessions_created)

  def test_recovery_from_non_preemption_in_coordinator(self):
    with self.test_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      hook = StopCoordinatorWithException(
          calls_before_stopping=2,
          exception_to_raise=errors_impl.UnknownError(
              None, None, 'Some fatal exception inside the coordinator.'))
      session = monitored_session.MonitoredSession(session_creator, [hook])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # The coordinator will not abort during this call, since it's call
      # number 0.
      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
      self.assertFalse(session.should_stop())

      # The coordinator will abort during the next call, since it's call
      # number 1.
      self.assertEqual(42, session.run(v, feed_dict={c: 42}))

      # The coordinator was asked to stop due to a non-recoverable error.
      # Training should stop and the session should not be recreated.
      self.assertTrue(session.should_stop())
      self.assertEqual(1, session_creator.number_of_sessions_created)

      with self.assertRaises(errors_impl.UnknownError):
        session.close()

  def test_recovery_from_session_getting_stuck(self):
    with self.test_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = monitored_session.MonitoredSession(
          session_creator,
          [FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # Training will not fail, since it's call number 0.
      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
      self.assertFalse(session.should_stop())

      # Training will fail during the next call, since it's call number 1.
      self.assertEqual(42, session.run(v, feed_dict={c: 42}))

      # Even though the coordinator stopped and training failed, the
      # underlying session is recreated and training continues.
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session_creator.number_of_sessions_created)


class FakeSession(monitored_session._WrappedSession):