Recover MonitoredSession when the Coordinator is requested to stop with one of the _PREEMPTION_ERRORS.
When SyncReplicasOptimizer is used, a preemption in the Coordinator may result in two cases: Case 1) the session gets silently marked as complete Case 2) the session gets stuck This CL aims to solve and verify solutions for both of these problems. Fix 1 changes the should_stop logic. Fix 2 changes the CoordinatedSession.run() logic. SyncReplicasOptimizer runs a separate set of threads using a Coordinator instance. Those threads do FIFOQueue.enqueue; the main thread does a blocking FIFOQueue.dequeue. `sync_token_q` FIFOQueue is on parameter-servers. When one of the PS instances gets preempted, an AbortedError causes the Coordinator to stop via request_stop(ex). That by itself changes the state of MonitoredSession.should_stop() to True (Fix 1). Results of the blocking Dequeue operation are sent to the chief worker via Recv. What happens next depends on the amount of tokens in `sync_token_q`. If there are enough for the next call to Dequeue to return, then the low-level "tf session run() call" returns. The next iteration of the `while not MonitoredSession.should_stop()` loop decides that the training is complete (Case 1). If there are not enough tokens in `sync_token_q`, then the blocking Dequeue is going to keep waiting for them. This results in the graph execution getting stuck and the whole session getting garbage collected after 10 minutes (Case 2). We decided to fix that by re-creating a session after it gets garbage collected (Fix 2). An alternative was to try to cancel the pending Dequeue operation, but it's not clear that it is the right thing to do and it is also not easy. PiperOrigin-RevId: 164888390
This commit is contained in:
parent
46e4de6e5f
commit
8f9b1af8ae
@ -855,6 +855,27 @@ class _RecoverableSession(_WrappedSession):
|
|||||||
'or parameter server. A new session will be created. '
|
'or parameter server. A new session will be created. '
|
||||||
'Error: %s', e)
|
'Error: %s', e)
|
||||||
|
|
||||||
|
def _check_stop(self):
|
||||||
|
try:
|
||||||
|
if self._sess:
|
||||||
|
return self._sess._check_stop() # pylint: disable=protected-access
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
except _PREEMPTION_ERRORS as e:
|
||||||
|
logging.info('An error was raised while considering whether the '
|
||||||
|
'session is complete. This may be due to a preemption in '
|
||||||
|
'a connected worker or parameter server. The current '
|
||||||
|
'session will be closed and a new session will be '
|
||||||
|
'created. Error: %s', e)
|
||||||
|
self.close()
|
||||||
|
self._sess = self._create_session()
|
||||||
|
# Since we have just recreated the session, the overall computation should
|
||||||
|
# not stop:
|
||||||
|
return False
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
# `should_stop` should return True instead of raising an exception.
|
||||||
|
return True
|
||||||
|
|
||||||
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
|
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -901,7 +922,10 @@ class _CoordinatedSession(_WrappedSession):
|
|||||||
self._stop_grace_period_secs = stop_grace_period_secs
|
self._stop_grace_period_secs = stop_grace_period_secs
|
||||||
|
|
||||||
def _check_stop(self):
|
def _check_stop(self):
|
||||||
# Check with the coordinator if we should stop.
|
# If the coordinator was asked to stop due to an exception, then it needs
|
||||||
|
# to be propagated to this stack.
|
||||||
|
self._coord.raise_requested_exception()
|
||||||
|
# At this point, no exceptions are recorded in the coordinator.
|
||||||
return self._coord.should_stop()
|
return self._coord.should_stop()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
@ -918,6 +942,24 @@ class _CoordinatedSession(_WrappedSession):
|
|||||||
# useful exceptions are already reported by join().
|
# useful exceptions are already reported by join().
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def run(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
return self._sess.run(*args, **kwargs)
|
||||||
|
except _PREEMPTION_ERRORS as original_exception:
|
||||||
|
raise original_exception
|
||||||
|
except Exception as original_exception: # pylint: disable=broad-except
|
||||||
|
# A non-preemption error could have been caused by a preemption error
|
||||||
|
# in the coordinator. If this is the case, raise that exception instead,
|
||||||
|
# since it's the root cause. Otherwise, stick to the `original_exception`.
|
||||||
|
try:
|
||||||
|
self._coord.raise_requested_exception()
|
||||||
|
except _PREEMPTION_ERRORS as preemption_in_coordinator:
|
||||||
|
raise preemption_in_coordinator
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
raise original_exception
|
||||||
|
else:
|
||||||
|
raise original_exception
|
||||||
|
|
||||||
|
|
||||||
class _HookedSession(_WrappedSession):
|
class _HookedSession(_WrappedSession):
|
||||||
"""A _WrappedSession that calls hooks during calls to run().
|
"""A _WrappedSession that calls hooks during calls to run().
|
||||||
|
@ -508,7 +508,7 @@ class CoordinatedSessionTest(test.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class AbortAtNSession(object):
|
class AbortAtNSession(object):
|
||||||
"""A mock sessionthat aborts at the N-th run call."""
|
"""A mock session that aborts at the N-th run call."""
|
||||||
|
|
||||||
def __init__(self, sess, n):
|
def __init__(self, sess, n):
|
||||||
self._sess = sess
|
self._sess = sess
|
||||||
@ -524,6 +524,99 @@ class AbortAtNSession(object):
|
|||||||
return self._sess.run(*args, **kwargs)
|
return self._sess.run(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class StopCoordinatorWithException(session_run_hook.SessionRunHook):
|
||||||
|
"""With this hook Coordinator throws an exception after N-runs."""
|
||||||
|
|
||||||
|
def __init__(self, calls_before_stopping, exception_to_raise=None):
|
||||||
|
self._started_the_side_thread_already = False
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._stored_exception_event = threading.Event()
|
||||||
|
self._calls_before_stopping = calls_before_stopping
|
||||||
|
self._exception_to_raise = (exception_to_raise or errors_impl.AbortedError(
|
||||||
|
None, None, 'Aborted at N'))
|
||||||
|
|
||||||
|
def _maybe_stop_with_exception(self, coord):
|
||||||
|
while True:
|
||||||
|
with self._lock:
|
||||||
|
if self._calls_before_stopping == 0:
|
||||||
|
try:
|
||||||
|
raise self._exception_to_raise
|
||||||
|
except Exception as e: # pylint: disable=broad-except
|
||||||
|
coord.request_stop(e)
|
||||||
|
self._stored_exception_event.set()
|
||||||
|
break
|
||||||
|
|
||||||
|
def after_create_session(self, session, coord):
|
||||||
|
if self._started_the_side_thread_already:
|
||||||
|
return
|
||||||
|
|
||||||
|
separate_thread = threading.Thread(
|
||||||
|
target=self._maybe_stop_with_exception, args=(coord,))
|
||||||
|
|
||||||
|
coord.register_thread(separate_thread)
|
||||||
|
separate_thread.start()
|
||||||
|
self._started_the_side_thread_already = True
|
||||||
|
# Coordinator will take care of joining `separate_thread`.
|
||||||
|
|
||||||
|
def after_run(self, run_context, run_values):
|
||||||
|
stopping_now = False
|
||||||
|
with self._lock:
|
||||||
|
self._calls_before_stopping -= 1
|
||||||
|
if self._calls_before_stopping == 0:
|
||||||
|
stopping_now = True
|
||||||
|
|
||||||
|
if stopping_now:
|
||||||
|
self._stored_exception_event.wait()
|
||||||
|
|
||||||
|
|
||||||
|
class FailTrainingAfterCoordinatorStopped(StopCoordinatorWithException):
|
||||||
|
"""With this hook training encounters an exception after N-runs."""
|
||||||
|
|
||||||
|
def __init__(self, calls_before_stopping):
|
||||||
|
StopCoordinatorWithException.__init__(self, calls_before_stopping)
|
||||||
|
self._coord = None
|
||||||
|
|
||||||
|
def after_create_session(self, session, coord):
|
||||||
|
self._coord = coord
|
||||||
|
return StopCoordinatorWithException.after_create_session(
|
||||||
|
self, session, coord)
|
||||||
|
|
||||||
|
def after_run(self, run_context, run_values):
|
||||||
|
StopCoordinatorWithException.after_run(self, run_context, run_values)
|
||||||
|
try:
|
||||||
|
# After a `run`, an exception could have been stored inside the
|
||||||
|
# coordinator.
|
||||||
|
self._coord.raise_requested_exception()
|
||||||
|
except errors_impl.AbortedError:
|
||||||
|
# In real world, the main thread may or may not know about the exception
|
||||||
|
# that stopped the coordinator. Because the coordinator has stopped, the
|
||||||
|
# main thread could have gotten stuck as well (for example, the
|
||||||
|
# coordinator was supposed to execute `FIFOQueue.enqueue` while the main
|
||||||
|
# thread is executing a blocking `FIFOQueue.dequeue`). After it got stuck,
|
||||||
|
# the session is going to get garbage collected after some time with:
|
||||||
|
raise errors_impl.CancelledError(None, None,
|
||||||
|
'Session got garbage-collected.')
|
||||||
|
|
||||||
|
|
||||||
|
class CountingSessionCreator(object):
|
||||||
|
"""A creator that counts the number of created sessions."""
|
||||||
|
|
||||||
|
def __init__(self, session):
|
||||||
|
self._initial_session = session
|
||||||
|
# We only have one session per test case. We can't re-create it, thus
|
||||||
|
# it shouldn't be closed.
|
||||||
|
self._initial_session.close = lambda *args: None
|
||||||
|
self._create_session_calls = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def number_of_sessions_created(self):
|
||||||
|
return self._create_session_calls
|
||||||
|
|
||||||
|
def create_session(self):
|
||||||
|
self._create_session_calls += 1
|
||||||
|
return self._initial_session
|
||||||
|
|
||||||
|
|
||||||
class RecoverableSessionTest(test.TestCase):
|
class RecoverableSessionTest(test.TestCase):
|
||||||
"""_RecoverableSession tests."""
|
"""_RecoverableSession tests."""
|
||||||
|
|
||||||
@ -597,6 +690,84 @@ class RecoverableSessionTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(IndexError, 'pop from empty list'):
|
with self.assertRaisesRegexp(IndexError, 'pop from empty list'):
|
||||||
recoverable_sess.run(v, feed_dict={c: -12})
|
recoverable_sess.run(v, feed_dict={c: -12})
|
||||||
|
|
||||||
|
def test_recovery_from_coordinator_exception(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
session = monitored_session.MonitoredSession(
|
||||||
|
session_creator,
|
||||||
|
[StopCoordinatorWithException(calls_before_stopping=2)])
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
# The coordinator will not abort during this call, since it's the call
|
||||||
|
# number 0.
|
||||||
|
self.assertEqual(51, session.run(v, feed_dict={c: 51}))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# The coordinator will abort during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run(v, feed_dict={c: 42}))
|
||||||
|
# Even though the coordinator was asked to stop, the underlying session is
|
||||||
|
# recreated and is to be continued.
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||||
|
|
||||||
|
def test_recovery_from_non_preemption_in_coordinator(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
hook = StopCoordinatorWithException(
|
||||||
|
calls_before_stopping=2,
|
||||||
|
exception_to_raise=errors_impl.UnknownError(
|
||||||
|
None, None, 'Some fatal exception inside the coordinator.'))
|
||||||
|
session = monitored_session.MonitoredSession(session_creator, [hook])
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
# The coordinator will not abort during this call, since it's the call
|
||||||
|
# number 0.
|
||||||
|
self.assertEqual(51, session.run(v, feed_dict={c: 51}))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# The coordinator will abort during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run(v, feed_dict={c: 42}))
|
||||||
|
# The coordinator was asked to stop due to non-redeemable error. Training
|
||||||
|
# should stop and the session should not be recreated.
|
||||||
|
self.assertTrue(session.should_stop())
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
with self.assertRaises(errors_impl.UnknownError):
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def test_recovery_from_session_getting_stuck(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
session = monitored_session.MonitoredSession(
|
||||||
|
session_creator,
|
||||||
|
[FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
# Training will not fail, since it's the call number 0.
|
||||||
|
self.assertEqual(51, session.run(v, feed_dict={c: 51}))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# Training will fail during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run(v, feed_dict={c: 42}))
|
||||||
|
# Even though the coordinator stopped which and training failed, the
|
||||||
|
# underlying session is recreated and training is to be continued.
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||||
|
|
||||||
|
|
||||||
class FakeSession(monitored_session._WrappedSession):
|
class FakeSession(monitored_session._WrappedSession):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user