Raise the exception raised by subprocess in multi-process runner.
Before this change, exceptions raised by processes spawned by the runner process would be converted to strings, and eventually raised as `RuntimeError` by runner. After this change we raise the same error thrown by the subprocess. PiperOrigin-RevId: 273325020
This commit is contained in:
parent
0a2af00d23
commit
785de35b28
@ -29,7 +29,6 @@ import json
|
|||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
|
|
||||||
from absl import flags
|
from absl import flags
|
||||||
from six.moves import queue as Queue
|
from six.moves import queue as Queue
|
||||||
@ -167,7 +166,7 @@ def run(proc_func,
|
|||||||
stderr_collector = _LogCollector(
|
stderr_collector = _LogCollector(
|
||||||
sys.__stderr__) if return_std_stream else None
|
sys.__stderr__) if return_std_stream else None
|
||||||
|
|
||||||
def finish_wrapper_func_properly(finish_message):
|
def finish_wrapper_func_properly(func_result):
|
||||||
"""Call to finish `wrapper_func` properly."""
|
"""Call to finish `wrapper_func` properly."""
|
||||||
# Clear the alarm.
|
# Clear the alarm.
|
||||||
signal.alarm(0)
|
signal.alarm(0)
|
||||||
@ -180,7 +179,7 @@ def run(proc_func,
|
|||||||
# Un-redirect stdout and stderr.
|
# Un-redirect stdout and stderr.
|
||||||
sys.stdout = sys.__stdout__
|
sys.stdout = sys.__stdout__
|
||||||
sys.stderr = sys.__stderr__
|
sys.stderr = sys.__stderr__
|
||||||
_get_internal_queue().put(finish_message)
|
_get_internal_queue().put(func_result)
|
||||||
|
|
||||||
if time_to_exit is not None:
|
if time_to_exit is not None:
|
||||||
|
|
||||||
@ -204,9 +203,7 @@ def run(proc_func,
|
|||||||
# pylint: disable=broad-except
|
# pylint: disable=broad-except
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Capture all exceptions to be reported to parent process.
|
# Capture all exceptions to be reported to parent process.
|
||||||
finish_wrapper_func_properly(
|
finish_wrapper_func_properly(e)
|
||||||
'Exception raised by subprocess: {}: {}. {}'.format(
|
|
||||||
e.__class__.__name__, str(e), traceback.format_exc()))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
|
finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
|
||||||
@ -236,17 +233,16 @@ def run(proc_func,
|
|||||||
except Queue.Empty:
|
except Queue.Empty:
|
||||||
# First check if any of the subprocesses raised exception.
|
# First check if any of the subprocesses raised exception.
|
||||||
for internal_queue_result in internal_queue_results:
|
for internal_queue_result in internal_queue_results:
|
||||||
if internal_queue_result.startswith('Exception raised by subprocess'):
|
if isinstance(internal_queue_result, Exception):
|
||||||
# TODO(b/142073790): Recover the original exception type.
|
raise internal_queue_result
|
||||||
raise RuntimeError(internal_queue_result)
|
|
||||||
# If none of those did, report time out to user.
|
# If none of those did, report time out to user.
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'One or more subprocesses timed out. Please inspect logs for '
|
'One or more subprocesses timed out. Please inspect logs for '
|
||||||
'subprocess debugging info. Timeout = {} sec.'.format(timeout))
|
'subprocess debugging info. Timeout = {} sec.'.format(timeout))
|
||||||
|
|
||||||
for internal_queue_result in internal_queue_results:
|
for internal_queue_result in internal_queue_results:
|
||||||
if internal_queue_result.startswith('Exception raised by subprocess'):
|
if isinstance(internal_queue_result, Exception):
|
||||||
raise RuntimeError(internal_queue_result)
|
raise internal_queue_result
|
||||||
assert internal_queue_result == _FINISH_PROPERLY_MESSAGE
|
assert internal_queue_result == _FINISH_PROPERLY_MESSAGE
|
||||||
|
|
||||||
def queue_to_list(queue_to_convert):
|
def queue_to_list(queue_to_convert):
|
||||||
|
|||||||
@ -76,9 +76,7 @@ class MultiProcessRunnerTest(test.TestCase):
|
|||||||
|
|
||||||
def test_multi_process_runner_error_propagates_from_subprocesses(self):
|
def test_multi_process_runner_error_propagates_from_subprocesses(self):
|
||||||
job_count_dict = {'worker': 1, 'ps': 1}
|
job_count_dict = {'worker': 1, 'ps': 1}
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(RuntimeError, 'This is an error.'):
|
||||||
RuntimeError, 'Exception raised by subprocess: RuntimeError: '
|
|
||||||
'This is an error.'):
|
|
||||||
multi_process_runner.run(
|
multi_process_runner.run(
|
||||||
proc_func_that_errors,
|
proc_func_that_errors,
|
||||||
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
|
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user