From 785de35b28ef695b8b4c654e6f077775b2122a86 Mon Sep 17 00:00:00 2001 From: Ayush Dubey Date: Mon, 7 Oct 2019 10:39:27 -0700 Subject: [PATCH] Raise the exception raised by subprocess in multi-process runner. Before this change, exceptions raised by processes spawned by the runner process would be converted to strings, and eventually raised as `RuntimeError` by runner. After this change we raise the same error thrown by the subprocess. PiperOrigin-RevId: 273325020 --- .../python/distribute/multi_process_runner.py | 18 +++++++----------- .../distribute/multi_process_runner_test.py | 4 +--- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 5d421e51aae..f048200cda9 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -29,7 +29,6 @@ import json import os import signal import sys -import traceback from absl import flags from six.moves import queue as Queue @@ -167,7 +166,7 @@ def run(proc_func, stderr_collector = _LogCollector( sys.__stderr__) if return_std_stream else None - def finish_wrapper_func_properly(finish_message): + def finish_wrapper_func_properly(func_result): """Call to finish `wrapper_func` properly.""" # Clear the alarm. signal.alarm(0) @@ -180,7 +179,7 @@ def run(proc_func, # Un-redirect stdout and stderr. sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ - _get_internal_queue().put(finish_message) + _get_internal_queue().put(func_result) if time_to_exit is not None: @@ -204,9 +203,7 @@ def run(proc_func, # pylint: disable=broad-except except Exception as e: # Capture all exceptions to be reported to parent process. - finish_wrapper_func_properly( - 'Exception raised by subprocess: {}: {}. {}'.format( - e.__class__.__name__, str(e), traceback.format_exc())) + finish_wrapper_func_properly(e) return finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE) @@ -236,17 +233,16 @@ def run(proc_func, except Queue.Empty: # First check if any of the subprocesses raised exception. for internal_queue_result in internal_queue_results: - if internal_queue_result.startswith('Exception raised by subprocess'): - # TODO(b/142073790): Recover the original exception type. - raise RuntimeError(internal_queue_result) + if isinstance(internal_queue_result, Exception): + raise internal_queue_result # If none of those did, report time out to user. raise RuntimeError( 'One or more subprocesses timed out. Please inspect logs for ' 'subprocess debugging info. Timeout = {} sec.'.format(timeout)) for internal_queue_result in internal_queue_results: - if internal_queue_result.startswith('Exception raised by subprocess'): - raise RuntimeError(internal_queue_result) + if isinstance(internal_queue_result, Exception): + raise internal_queue_result assert internal_queue_result == _FINISH_PROPERLY_MESSAGE def queue_to_list(queue_to_convert): diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index a9d0be0f893..fd99b3a419e 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -76,9 +76,7 @@ class MultiProcessRunnerTest(test.TestCase): def test_multi_process_runner_error_propagates_from_subprocesses(self): job_count_dict = {'worker': 1, 'ps': 1} - with self.assertRaisesRegexp( - RuntimeError, 'Exception raised by subprocess: RuntimeError: ' - 'This is an error.'): + with self.assertRaisesRegexp(RuntimeError, 'This is an error.'): multi_process_runner.run( proc_func_that_errors, multi_process_runner.job_count_to_cluster_spec(job_count_dict),