Add a test in multi_process_runner_test to assert that terminate_all does not suppress the errors that occur in the function.
PiperOrigin-RevId: 314442559 Change-Id: I5b8cd58bfbdb5e11011c61c90621bfe2ccf2d952
This commit is contained in:
parent
a95df90cb9
commit
78390fab9a
tensorflow/python/distribute
@ -422,32 +422,28 @@ class MultiProcessRunner(object):
|
||||
timeout = float('inf')
|
||||
start_time = time.time()
|
||||
while self._outstanding_subprocess_count > 0:
|
||||
while True:
|
||||
try:
|
||||
process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
|
||||
try:
|
||||
process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
|
||||
|
||||
self._outstanding_subprocess_count -= 1
|
||||
assert isinstance(process_status, _ProcessStatusInfo)
|
||||
if not process_status.is_successful:
|
||||
six.reraise(*process_status.exc_info)
|
||||
|
||||
if self._dependence_on_chief and process_status.task_type == 'chief':
|
||||
self.terminate_all()
|
||||
break
|
||||
except Queue.Empty:
|
||||
if self._all_forced_terminated:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
# Send SIGTERM signal to subprocesses to dump their current
|
||||
# stack trace.
|
||||
self.terminate_all(sig=signal.SIGTERM)
|
||||
# If none of those did, report timeout to user.
|
||||
raise RuntimeError('One or more subprocesses timed out. '
|
||||
'Number of outstanding subprocesses '
|
||||
'is %d.' % self._outstanding_subprocess_count)
|
||||
|
||||
if self._all_forced_terminated:
|
||||
break
|
||||
self._outstanding_subprocess_count -= 1
|
||||
assert isinstance(process_status, _ProcessStatusInfo)
|
||||
if not process_status.is_successful:
|
||||
six.reraise(*process_status.exc_info)
|
||||
|
||||
if self._dependence_on_chief and process_status.task_type == 'chief':
|
||||
self.terminate_all()
|
||||
break
|
||||
except Queue.Empty:
|
||||
if self._all_forced_terminated:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
# Send SIGTERM signal to subprocesses to dump their current
|
||||
# stack trace.
|
||||
self.terminate_all(sig=signal.SIGTERM)
|
||||
# If none of those did, report timeout to user.
|
||||
raise RuntimeError('One or more subprocesses timed out. '
|
||||
'Number of outstanding subprocesses '
|
||||
'is %d.' % self._outstanding_subprocess_count)
|
||||
|
||||
# Giving threads some time to finish the message reading from subprocesses.
|
||||
time.sleep(5)
|
||||
|
@ -285,5 +285,15 @@ class MultiProcessRunnerTest(test.TestCase):
|
||||
any('{}-0, i: {}'.format(job, iteration) in line
|
||||
for line in list_to_assert))
|
||||
|
||||
def test_terminate_all_does_not_ignore_error(self):
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
proc_func_that_errors,
|
||||
multi_worker_test_base.create_cluster_spec(num_workers=2),
|
||||
list_stdout=True)
|
||||
mpr.start()
|
||||
mpr.terminate_all()
|
||||
with self.assertRaisesRegexp(ValueError, 'This is an error.'):
|
||||
mpr.join()
|
||||
|
||||
if __name__ == '__main__':
|
||||
multi_process_runner.test_main()
|
||||
|
Loading…
Reference in New Issue
Block a user