Add a test in multi_process_runner_test to assert that terminate_all does not suppress the errors that occur in the function.

PiperOrigin-RevId: 314442559
Change-Id: I5b8cd58bfbdb5e11011c61c90621bfe2ccf2d952
This commit is contained in:
Rick Chao 2020-06-02 18:13:16 -07:00 committed by TensorFlower Gardener
parent a95df90cb9
commit 78390fab9a
2 changed files with 31 additions and 25 deletions

View File

@ -422,32 +422,28 @@ class MultiProcessRunner(object):
timeout = float('inf')
start_time = time.time()
while self._outstanding_subprocess_count > 0:
while True:
try:
process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
try:
process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
self._outstanding_subprocess_count -= 1
assert isinstance(process_status, _ProcessStatusInfo)
if not process_status.is_successful:
six.reraise(*process_status.exc_info)
if self._dependence_on_chief and process_status.task_type == 'chief':
self.terminate_all()
break
except Queue.Empty:
if self._all_forced_terminated:
break
if time.time() - start_time > timeout:
# Send SIGTERM signal to subprocesses to dump their current
# stack trace.
self.terminate_all(sig=signal.SIGTERM)
# If none of those did, report timeout to user.
raise RuntimeError('One or more subprocesses timed out. '
'Number of outstanding subprocesses '
'is %d.' % self._outstanding_subprocess_count)
if self._all_forced_terminated:
break
self._outstanding_subprocess_count -= 1
assert isinstance(process_status, _ProcessStatusInfo)
if not process_status.is_successful:
six.reraise(*process_status.exc_info)
if self._dependence_on_chief and process_status.task_type == 'chief':
self.terminate_all()
break
except Queue.Empty:
if self._all_forced_terminated:
break
if time.time() - start_time > timeout:
# Send SIGTERM signal to subprocesses to dump their current
# stack trace.
self.terminate_all(sig=signal.SIGTERM)
# If none of those did, report timeout to user.
raise RuntimeError('One or more subprocesses timed out. '
'Number of outstanding subprocesses '
'is %d.' % self._outstanding_subprocess_count)
# Giving threads some time to finish the message reading from subprocesses.
time.sleep(5)

View File

@ -285,5 +285,15 @@ class MultiProcessRunnerTest(test.TestCase):
any('{}-0, i: {}'.format(job, iteration) in line
for line in list_to_assert))
def test_terminate_all_does_not_ignore_error(self):
mpr = multi_process_runner.MultiProcessRunner(
proc_func_that_errors,
multi_worker_test_base.create_cluster_spec(num_workers=2),
list_stdout=True)
mpr.start()
mpr.terminate_all()
with self.assertRaisesRegexp(ValueError, 'This is an error.'):
mpr.join()
if __name__ == '__main__':
multi_process_runner.test_main()