Re-raise the exception raised by a subprocess in the multi-process runner.

Before this change, exceptions raised by processes spawned by the runner process were converted to strings and eventually re-raised as `RuntimeError` by the runner.

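After this change, the runner re-raises the exception raised by the subprocess itself, so callers see the original exception type rather than a generic `RuntimeError`.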

PiperOrigin-RevId: 273325020
This commit is contained in:
Ayush Dubey 2019-10-07 10:39:27 -07:00 committed by TensorFlower Gardener
parent 0a2af00d23
commit 785de35b28
2 changed files with 8 additions and 14 deletions

View File

@ -29,7 +29,6 @@ import json
import os
import signal
import sys
import traceback
from absl import flags
from six.moves import queue as Queue
@ -167,7 +166,7 @@ def run(proc_func,
stderr_collector = _LogCollector(
sys.__stderr__) if return_std_stream else None
def finish_wrapper_func_properly(finish_message):
def finish_wrapper_func_properly(func_result):
"""Call to finish `wrapper_func` properly."""
# Clear the alarm.
signal.alarm(0)
@ -180,7 +179,7 @@ def run(proc_func,
# Un-redirect stdout and stderr.
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
_get_internal_queue().put(finish_message)
_get_internal_queue().put(func_result)
if time_to_exit is not None:
@ -204,9 +203,7 @@ def run(proc_func,
# pylint: disable=broad-except
except Exception as e:
# Capture all exceptions to be reported to parent process.
finish_wrapper_func_properly(
'Exception raised by subprocess: {}: {}. {}'.format(
e.__class__.__name__, str(e), traceback.format_exc()))
finish_wrapper_func_properly(e)
return
finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
@ -236,17 +233,16 @@ def run(proc_func,
except Queue.Empty:
# First check if any of the subprocesses raised exception.
for internal_queue_result in internal_queue_results:
if internal_queue_result.startswith('Exception raised by subprocess'):
# TODO(b/142073790): Recover the original exception type.
raise RuntimeError(internal_queue_result)
if isinstance(internal_queue_result, Exception):
raise internal_queue_result
# If none of those did, report time out to user.
raise RuntimeError(
'One or more subprocesses timed out. Please inspect logs for '
'subprocess debugging info. Timeout = {} sec.'.format(timeout))
for internal_queue_result in internal_queue_results:
if internal_queue_result.startswith('Exception raised by subprocess'):
raise RuntimeError(internal_queue_result)
if isinstance(internal_queue_result, Exception):
raise internal_queue_result
assert internal_queue_result == _FINISH_PROPERLY_MESSAGE
def queue_to_list(queue_to_convert):

View File

@ -76,9 +76,7 @@ class MultiProcessRunnerTest(test.TestCase):
def test_multi_process_runner_error_propagates_from_subprocesses(self):
job_count_dict = {'worker': 1, 'ps': 1}
with self.assertRaisesRegexp(
RuntimeError, 'Exception raised by subprocess: RuntimeError: '
'This is an error.'):
with self.assertRaisesRegexp(RuntimeError, 'This is an error.'):
multi_process_runner.run(
proc_func_that_errors,
multi_process_runner.job_count_to_cluster_spec(job_count_dict),