Raise the exception raised by subprocess in multi-process runner.
Before this change, exceptions raised by processes spawned by the runner process were converted to strings and eventually re-raised as `RuntimeError` by the runner. After this change, the runner raises the same exception that was raised in the subprocess. PiperOrigin-RevId: 273325020
This commit is contained in:
parent
0a2af00d23
commit
785de35b28
@ -29,7 +29,6 @@ import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from absl import flags
|
||||
from six.moves import queue as Queue
|
||||
@ -167,7 +166,7 @@ def run(proc_func,
|
||||
stderr_collector = _LogCollector(
|
||||
sys.__stderr__) if return_std_stream else None
|
||||
|
||||
def finish_wrapper_func_properly(finish_message):
|
||||
def finish_wrapper_func_properly(func_result):
|
||||
"""Call to finish `wrapper_func` properly."""
|
||||
# Clear the alarm.
|
||||
signal.alarm(0)
|
||||
@ -180,7 +179,7 @@ def run(proc_func,
|
||||
# Un-redirect stdout and stderr.
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
_get_internal_queue().put(finish_message)
|
||||
_get_internal_queue().put(func_result)
|
||||
|
||||
if time_to_exit is not None:
|
||||
|
||||
@ -204,9 +203,7 @@ def run(proc_func,
|
||||
# pylint: disable=broad-except
|
||||
except Exception as e:
|
||||
# Capture all exceptions to be reported to parent process.
|
||||
finish_wrapper_func_properly(
|
||||
'Exception raised by subprocess: {}: {}. {}'.format(
|
||||
e.__class__.__name__, str(e), traceback.format_exc()))
|
||||
finish_wrapper_func_properly(e)
|
||||
return
|
||||
|
||||
finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
|
||||
@ -236,17 +233,16 @@ def run(proc_func,
|
||||
except Queue.Empty:
|
||||
# First check if any of the subprocesses raised exception.
|
||||
for internal_queue_result in internal_queue_results:
|
||||
if internal_queue_result.startswith('Exception raised by subprocess'):
|
||||
# TODO(b/142073790): Recover the original exception type.
|
||||
raise RuntimeError(internal_queue_result)
|
||||
if isinstance(internal_queue_result, Exception):
|
||||
raise internal_queue_result
|
||||
# If none of those did, report time out to user.
|
||||
raise RuntimeError(
|
||||
'One or more subprocesses timed out. Please inspect logs for '
|
||||
'subprocess debugging info. Timeout = {} sec.'.format(timeout))
|
||||
|
||||
for internal_queue_result in internal_queue_results:
|
||||
if internal_queue_result.startswith('Exception raised by subprocess'):
|
||||
raise RuntimeError(internal_queue_result)
|
||||
if isinstance(internal_queue_result, Exception):
|
||||
raise internal_queue_result
|
||||
assert internal_queue_result == _FINISH_PROPERLY_MESSAGE
|
||||
|
||||
def queue_to_list(queue_to_convert):
|
||||
|
@ -76,9 +76,7 @@ class MultiProcessRunnerTest(test.TestCase):
|
||||
|
||||
def test_multi_process_runner_error_propagates_from_subprocesses(self):
|
||||
job_count_dict = {'worker': 1, 'ps': 1}
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, 'Exception raised by subprocess: RuntimeError: '
|
||||
'This is an error.'):
|
||||
with self.assertRaisesRegexp(RuntimeError, 'This is an error.'):
|
||||
multi_process_runner.run(
|
||||
proc_func_that_errors,
|
||||
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
|
||||
|
Loading…
Reference in New Issue
Block a user