From 78390fab9a0c7dff3e99de766ff25c05b3801cdb Mon Sep 17 00:00:00 2001
From: Rick Chao
Date: Tue, 2 Jun 2020 18:13:16 -0700
Subject: [PATCH] Add a test in multi_process_runner_test to assert that
 terminate_all does not suppress the errors that occur in the function.

PiperOrigin-RevId: 314442559
Change-Id: I5b8cd58bfbdb5e11011c61c90621bfe2ccf2d952
---
 .../python/distribute/multi_process_runner.py | 46 +++++++++----------
 .../distribute/multi_process_runner_test.py   | 10 ++++
 2 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py
index 3e28cf44072..aa52ee684ec 100644
--- a/tensorflow/python/distribute/multi_process_runner.py
+++ b/tensorflow/python/distribute/multi_process_runner.py
@@ -422,32 +422,28 @@ class MultiProcessRunner(object):
       timeout = float('inf')
     start_time = time.time()
     while self._outstanding_subprocess_count > 0:
-      while True:
-        try:
-          process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
+      try:
+        process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
+
+        self._outstanding_subprocess_count -= 1
+        assert isinstance(process_status, _ProcessStatusInfo)
+        if not process_status.is_successful:
+          six.reraise(*process_status.exc_info)
+
+        if self._dependence_on_chief and process_status.task_type == 'chief':
+          self.terminate_all()
           break
-        except Queue.Empty:
-          if self._all_forced_terminated:
-            break
-          if time.time() - start_time > timeout:
-            # Send SIGTERM signal to subprocesses to dump their current
-            # stack trace.
-            self.terminate_all(sig=signal.SIGTERM)
-            # If none of those did, report timeout to user.
-            raise RuntimeError('One or more subprocesses timed out. '
-                               'Number of outstanding subprocesses '
-                               'is %d.' % self._outstanding_subprocess_count)
-
-      if self._all_forced_terminated:
-        break
-      self._outstanding_subprocess_count -= 1
-      assert isinstance(process_status, _ProcessStatusInfo)
-      if not process_status.is_successful:
-        six.reraise(*process_status.exc_info)
-
-      if self._dependence_on_chief and process_status.task_type == 'chief':
-        self.terminate_all()
-        break
+      except Queue.Empty:
+        if self._all_forced_terminated:
+          break
+        if time.time() - start_time > timeout:
+          # Send SIGTERM signal to subprocesses to dump their current
+          # stack trace.
+          self.terminate_all(sig=signal.SIGTERM)
+          # If none of those did, report timeout to user.
+          raise RuntimeError('One or more subprocesses timed out. '
+                             'Number of outstanding subprocesses '
+                             'is %d.' % self._outstanding_subprocess_count)

     # Giving threads some time to finish the message reading from subprocesses.
     time.sleep(5)
diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py
index 1413777d0bc..69e84581af3 100644
--- a/tensorflow/python/distribute/multi_process_runner_test.py
+++ b/tensorflow/python/distribute/multi_process_runner_test.py
@@ -285,5 +285,15 @@ class MultiProcessRunnerTest(test.TestCase):
           any('{}-0, i: {}'.format(job, iteration) in line
               for line in list_to_assert))

+  def test_terminate_all_does_not_ignore_error(self):
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func_that_errors,
+        multi_worker_test_base.create_cluster_spec(num_workers=2),
+        list_stdout=True)
+    mpr.start()
+    mpr.terminate_all()
+    with self.assertRaisesRegexp(ValueError, 'This is an error.'):
+      mpr.join()
+
 if __name__ == '__main__':
   multi_process_runner.test_main()
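
Note: the new test reuses proc_func_that_errors, a helper that already exists
elsewhere in multi_process_runner_test.py and is not part of this patch. A
minimal sketch of such a helper, assuming it only needs to raise the
ValueError that the test matches on, might look like:

# Hypothetical sketch of the pre-existing proc_func_that_errors helper
# referenced by the new test; the real definition lives elsewhere in
# multi_process_runner_test.py and is not shown in this patch.
def proc_func_that_errors():
  raise ValueError('This is an error.')

The test then checks that calling terminate_all() before join() does not
swallow this error: join() still re-raises it through
six.reraise(*process_status.exc_info) in the reworked loop above.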