Raise the exception raised by subprocess in multi-process runner.

Before this change, exceptions raised by processes spawned by the runner process were converted to strings and eventually re-raised by the runner as `RuntimeError`, which discarded the original exception type.

After this change, the runner re-raises the exception that the subprocess raised, so the caller sees its original type and message.

PiperOrigin-RevId: 273325020
This commit is contained in:
Ayush Dubey 2019-10-07 10:39:27 -07:00 committed by TensorFlower Gardener
parent 0a2af00d23
commit 785de35b28
2 changed files with 8 additions and 14 deletions

View File

@ -29,7 +29,6 @@ import json
import os import os
import signal import signal
import sys import sys
import traceback
from absl import flags from absl import flags
from six.moves import queue as Queue from six.moves import queue as Queue
@ -167,7 +166,7 @@ def run(proc_func,
stderr_collector = _LogCollector( stderr_collector = _LogCollector(
sys.__stderr__) if return_std_stream else None sys.__stderr__) if return_std_stream else None
def finish_wrapper_func_properly(finish_message): def finish_wrapper_func_properly(func_result):
"""Call to finish `wrapper_func` properly.""" """Call to finish `wrapper_func` properly."""
# Clear the alarm. # Clear the alarm.
signal.alarm(0) signal.alarm(0)
@ -180,7 +179,7 @@ def run(proc_func,
# Un-redirect stdout and stderr. # Un-redirect stdout and stderr.
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__
_get_internal_queue().put(finish_message) _get_internal_queue().put(func_result)
if time_to_exit is not None: if time_to_exit is not None:
@ -204,9 +203,7 @@ def run(proc_func,
# pylint: disable=broad-except # pylint: disable=broad-except
except Exception as e: except Exception as e:
# Capture all exceptions to be reported to parent process. # Capture all exceptions to be reported to parent process.
finish_wrapper_func_properly( finish_wrapper_func_properly(e)
'Exception raised by subprocess: {}: {}. {}'.format(
e.__class__.__name__, str(e), traceback.format_exc()))
return return
finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE) finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
@ -236,17 +233,16 @@ def run(proc_func,
except Queue.Empty: except Queue.Empty:
# First check if any of the subprocesses raised exception. # First check if any of the subprocesses raised exception.
for internal_queue_result in internal_queue_results: for internal_queue_result in internal_queue_results:
if internal_queue_result.startswith('Exception raised by subprocess'): if isinstance(internal_queue_result, Exception):
# TODO(b/142073790): Recover the original exception type. raise internal_queue_result
raise RuntimeError(internal_queue_result)
# If none of those did, report time out to user. # If none of those did, report time out to user.
raise RuntimeError( raise RuntimeError(
'One or more subprocesses timed out. Please inspect logs for ' 'One or more subprocesses timed out. Please inspect logs for '
'subprocess debugging info. Timeout = {} sec.'.format(timeout)) 'subprocess debugging info. Timeout = {} sec.'.format(timeout))
for internal_queue_result in internal_queue_results: for internal_queue_result in internal_queue_results:
if internal_queue_result.startswith('Exception raised by subprocess'): if isinstance(internal_queue_result, Exception):
raise RuntimeError(internal_queue_result) raise internal_queue_result
assert internal_queue_result == _FINISH_PROPERLY_MESSAGE assert internal_queue_result == _FINISH_PROPERLY_MESSAGE
def queue_to_list(queue_to_convert): def queue_to_list(queue_to_convert):

View File

@ -76,9 +76,7 @@ class MultiProcessRunnerTest(test.TestCase):
def test_multi_process_runner_error_propagates_from_subprocesses(self): def test_multi_process_runner_error_propagates_from_subprocesses(self):
job_count_dict = {'worker': 1, 'ps': 1} job_count_dict = {'worker': 1, 'ps': 1}
with self.assertRaisesRegexp( with self.assertRaisesRegexp(RuntimeError, 'This is an error.'):
RuntimeError, 'Exception raised by subprocess: RuntimeError: '
'This is an error.'):
multi_process_runner.run( multi_process_runner.run(
proc_func_that_errors, proc_func_that_errors,
multi_process_runner.job_count_to_cluster_spec(job_count_dict), multi_process_runner.job_count_to_cluster_spec(job_count_dict),