From 77b19fb814d53607fc7c66f3731559cdf71dc210 Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Fri, 8 Nov 2019 18:47:13 -0800 Subject: [PATCH] Refactor MultiProcessRunner: * break run function into smaller functions * create a python thread/process-like API * remove some seldom used features PiperOrigin-RevId: 279433471 Change-Id: Ibb7febc2347f45872d1455e000ddcfb9541eee23 --- tensorflow/python/distribute/BUILD | 2 + .../python/distribute/multi_process_runner.py | 486 +++++++++--------- .../multi_process_runner_no_init_test.py | 9 +- .../distribute/multi_process_runner_test.py | 80 ++- .../multi_worker_continuous_run_test.py | 2 +- 5 files changed, 295 insertions(+), 284 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 69353e0fcb2..d273d7176b3 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1321,6 +1321,8 @@ py_library( ":multi_process_runner_util", ":multi_worker_test_base", "//tensorflow/python:client_testlib", + "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "@six_archive//:six", ], ) diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 253c17d1c25..c57741eadd1 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -24,12 +24,13 @@ import json import os import signal import sys +import time -from absl import flags import six from six.moves import queue as Queue -from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import multi_process_lib from tensorflow.python.eager import context from tensorflow.python.platform import test @@ -37,23 +38,20 @@ from tensorflow.python.platform import test _FINISH_PROPERLY_MESSAGE = 'OK' _ExcInfoWrapper = collections.namedtuple('_ExcInfoWrapper', ['exc_info']) - -class _AvailableQueues(object): - """Names of the available queues used by `multi_process_runner`.""" - # Internal queue is used by `multi_process_runner` internally for - # communication from subprocesses to the parent process. The message - # can be _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended successfully, or - # the detailed message of an exception if the subprocess has raised - # one so it can be re-raised by the parent process. - INTERNAL_QUEUE = 'internal_queue' - # Public queue is intended to be used by users of `multi_process_runner` - # for the process function to return information to the caller of - # `multi_process_runner.run()`. - PUBLIC_QUEUE = 'public_queue' - # Standard stream queue is used by `multi_process_runner` to collect - # information streamed to stdout and stderr to be reported back to the - # parent process. - STD_STREAM_QUEUE = 'std_stream_queue' +# Process status queue is used by `multi_process_runner` internally for +# communication from subprocesses to the parent process. The message can be +# _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended +# successfully, or the detailed message of an exception if the subprocess has +# raised one so it can be re-raised by the parent process. +PROCESS_STATUS_QUEUE = 'process_status_queue' +# Return value queue is intended to be used by users of `multi_process_runner` +# for the process function to return information to the caller of +# `multi_process_runner.run()`. 
+RETURN_VALUE_QUEUE = 'return_value_queue' +# Standard stream queue is used by `multi_process_runner` to collect +# information streamed to stdout and stderr to be reported back to the +# parent process. +STD_STREAM_QUEUE = 'std_stream_queue' class _LogCollector(object): @@ -72,29 +70,32 @@ class _LogCollector(object): class MultiProcessRunner(object): - """A utility to start multiple subprocesses to simulate multiple workers. + """A utility class to start multiple processes to simulate a cluster. - Training with multiple workers with eager runtime can be tested by simulating - using multiple processes. See `run()` for more information about the usage - of this class. + We need to use multiple processes to simulate a cluster in TF 2.0 tests + because TF 2.0 has some process-global data structures that have to be + separated by processes. We also need child processes to test out our fault + tolerance because shutting down a standard TensorFlow server within its + process is not supported. + + Note: the main test program that uses this runner class must run main program + via `test_main` defined in this file. Using this runner in non-test binaries + is not supported yet. + + This class is not thread-safe. Child processes will inherit TF2 behavior flag. """ - def run(self, - proc_func, - cluster_spec, - proc_flags=None, - timeout=200, - time_to_exit=None, - return_std_stream=False, - args=None, - kwargs=None): - """Run functions on local sub-processes. - - Experimental. API subject to change. To fully inspect logging from - subprocesses, use `--test_arg=--logtostderr` flag with bazel test. + def __init__(self, + proc_func, + cluster_spec, + max_run_time=None, + capture_std_stream=False, + args=None, + kwargs=None): + """Creates a multi-process runner. Args: - proc_func: Function to be run on the processes. This will be run on + proc_func: Function to be run on child processes. This will be run on processes for all task types. cluster_spec: Dict for cluster spec. The following is an example of cluster with three workers and two ps's. @@ -103,39 +104,19 @@ class MultiProcessRunner(object): "worker2.example.com:2222"], "ps": ["ps0.example.com:2222", "ps1.example.com:2222"]} - proc_flags: Dict that contains the key/values of the flags used on the - processes. - timeout: Time out in seconds. If the sub-process takes more than this time - to complete, raise an error. - time_to_exit: If set, sub-processes is forced to exit at approximately - this many seconds after `run()` is called, through `signal.alarm()` api. - This is for simulation of interruption on a process so in such cases no - error is raised. Note that this is best effort at Python level since - Python signal handler does not get executed inside the low-level (C) - signal handler, so it can be delayed. - return_std_stream: Boolean, whether the messages streamed to stdout and - stderr in subprocesses are captured. If True, the messages are stored in - a list returned as the second element. + max_run_time: If set, child processes is forced to exit at approximately + this many seconds after `start` is called. We achieve this through + `signal.alarm()` api. Note that this is best effort at Python level + since Python signal handler does not get executed when it runs lower + level C/C++ code. So it can be delayed for arbitrarily long time. + capture_std_stream: Boolean, whether the messages streamed to stdout and + stderr in subprocesses are captured. args: Positional arguments to be sent to functions run on processes. 
kwargs: Keyword arguments to be sent to functions run on processes. - Returns: - If `return_std_stream` is False, a list that stores the return data added - by subprocesses through `multi_process_runner._add_return_data(data)` - call, - or through normal function return; if `return_std_stream` is True, a - two-element tuple of `(return_data_list, std_stream_data_list)`, where - `return_data_list` stores the return data added by processes through - `multi_process_runner._add_return_data(data)` call or through normal - function - return, and `std_stream_data_list` stores the messages streamed to stdout - and stderr in the subprocesses. - Raises: - RuntimeError: If any of the subprocesses raise an error, or if any of the - subprocesses does not return or error out within `timeout` seconds. + RuntimeError: if `multi_process_runner.test_main()` is not called. """ - assert cluster_spec is not None assert callable(proc_func) @@ -146,192 +127,233 @@ class MultiProcessRunner(object): 'in your python module to properly initialize ' '`multi_process_runner`.') - processes = [] - args = args or () - kwargs = kwargs or {} + self._proc_func = proc_func + self._cluster_spec = cluster_spec + self._max_run_time = max_run_time + self._capture_std_stream = capture_std_stream + self._args = args or () + self._kwargs = kwargs or {} + self._processes = [] - def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit, - executing_eagerly, *arg, **kwargs): - """The wrapper function that actually gets run on the process(es).""" + # Child processes should have the same v2 and eager behavior. + self._v2_enabled = tf2.enabled() + self._executing_eagerly = context.executing_eagerly() - @contextlib.contextmanager - def runtime_mode(executing_eagerly): - if executing_eagerly: - with context.eager_mode(): - yield - else: - with context.graph_mode(): - yield - - with runtime_mode(executing_eagerly): - os.environ['TF_CONFIG'] = tf_config_as_json - if proc_flags is not None: - for flag_key, flag_value in proc_flags.items(): - setattr(flags.FLAGS, flag_key, flag_value) - - stdout_collector = _LogCollector( - sys.__stdout__) if return_std_stream else None - stderr_collector = _LogCollector( - sys.__stderr__) if return_std_stream else None - - def finish_wrapper_func_properly(func_result): - """Call to finish `wrapper_func` properly.""" - # Clear the alarm. - signal.alarm(0) - if (return_std_stream and stdout_collector is not None and - stderr_collector is not None): - # If stdout and stderr are to be collected, add them to std stream - # queue. - self._add_std_stream_data_flattened(stdout_collector.log) - self._add_std_stream_data_flattened(stderr_collector.log) - # Un-redirect stdout and stderr. - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - self._get_internal_queue().put(func_result) - - if time_to_exit is not None: - - def handler(signum, frame): - del signum, frame - finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE) - # pylint: disable=protected-access - os._exit(0) - - signal.signal(signal.SIGALRM, handler) - signal.alarm(time_to_exit) - - if return_std_stream: - sys.stdout = stdout_collector - sys.stderr = stderr_collector - - try: - return_data = proc_func(*arg, **kwargs) - if return_data is not None: - self._add_return_data(return_data) - # pylint: disable=broad-except - except Exception: - # Capture all exceptions to be reported to parent process. 
- finish_wrapper_func_properly(_ExcInfoWrapper(sys.exc_info())) - - # Re-raise the exception in addition to reporting it to the parent - # process, so that even if `--test_timeout` flag is set and the - # error doesn't make it to be shown in parent process before bazel's - # timeout, the log would still show what happens in this subprocess, - # instead of silently suppressing the error due to early bazel - # timeout. Raising an error in the subprocess produces stack trace in - # the log, but the program continues running. - raise - - finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE) - - # Start number of processes according to `count_dict`. - for job_type, addresses in cluster_spec.items(): - for task_id, _ in enumerate(addresses): - tf_config_as_json = json.dumps({ - 'cluster': cluster_spec, - 'task': { - 'type': job_type, - 'index': task_id - } - }) - p = multi_process_lib.Process( - target=wrapper_func, - args=(tf_config_as_json, proc_func, proc_flags, time_to_exit, - context.executing_eagerly()) + args, - kwargs=kwargs) - p.start() - processes.append(p) - - internal_queue_results = [] - for _ in range(len(processes)): - try: - internal_queue_results.append( - self._get_internal_queue().get(timeout=timeout)) - except Queue.Empty: - # First check if any of the subprocesses raised exception. - for internal_queue_result in internal_queue_results: - if isinstance(internal_queue_result, _ExcInfoWrapper): - six.reraise(*internal_queue_result.exc_info) - # If none of those did, report time out to user. - raise RuntimeError( - 'One or more subprocesses timed out. Please use ' - '`--test_arg=--logtostderr` bazel flag to inspect logs for ' - 'subprocess debugging info. Timeout = {} sec.'.format(timeout)) - - for internal_queue_result in internal_queue_results: - if isinstance(internal_queue_result, _ExcInfoWrapper): - six.reraise(*internal_queue_result.exc_info) - assert internal_queue_result == _FINISH_PROPERLY_MESSAGE - - def queue_to_list(queue_to_convert): - """Convert `queue.Queue` to `list`.""" - list_to_return = [] - while True: - try: - list_to_return.append(queue_to_convert.get(block=False)) - except Queue.Empty: - break - return list_to_return - - if return_std_stream: - return tuple( - queue_to_list(multi_process_lib.get_user_data()[queue_name]) - for queue_name in - [_AvailableQueues.PUBLIC_QUEUE, _AvailableQueues.STD_STREAM_QUEUE]) + @contextlib.contextmanager + def _runtime_mode(self): + if self._executing_eagerly: + with context.eager_mode(): + yield else: - return queue_to_list( - multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE]) + with context.graph_mode(): + yield - def _add_return_data(self, data): - """Add return data that will be returned by `multi_process_runner.run()`. + def _finish_process(self, func_status, return_value, stdout_collector, + stderr_collector): + """Adds data to queues before program exits.""" + # Clear the alarm. + signal.alarm(0) + if return_value is not None: + self._add_return_data(return_value) + if self._capture_std_stream: + # If stdout and stderr are to be collected, add them to std stream + # queue. + self._add_std_stream_data_flattened(stdout_collector.log) + self._add_std_stream_data_flattened(stderr_collector.log) + self._get_process_status_queue().put(func_status) - The function provides a way for processes started by - `multi_process_runner.run()` to communicate with the original process - that started the sub-processes. 
Data passed to `_add_return_data` will - be available in a python Queue.Queue that is eventually returned by - `multi_process_runner.run()`. + def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs): + """The wrapper function that actually gets run in child process(es).""" + os.environ['TF_CONFIG'] = json.dumps({ + 'cluster': self._cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id, + } + }) + + if self._capture_std_stream: + # TODO(yuefengz): consider a lighter way of capturing std streams. + stdout_collector = _LogCollector(sys.__stdout__) + stderr_collector = _LogCollector(sys.__stderr__) + sys.stdout = stdout_collector + sys.stderr = stderr_collector + else: + stdout_collector = None + stderr_collector = None + + if self._v2_enabled: + v2_compat.enable_v2_behavior() + + return_value = None + + if self._max_run_time is not None: + # Register an sigalarm handler to exit the process when it reaches + # `timeout` seconds. A program reaching `timeout` doesn't necessarily + # indicate an issue. + def handler(signum, frame): + del signum, frame + self._finish_process(_FINISH_PROPERLY_MESSAGE, None, stdout_collector, + stderr_collector) + os._exit(0) # pylint: disable=protected-access + + signal.signal(signal.SIGALRM, handler) + signal.alarm(self._max_run_time) + + try: + with self._runtime_mode(): + return_value = self._proc_func(*arg, **kwargs) + except Exception: # pylint: disable=broad-except + # Capture all exceptions to be reported to parent process. + self._finish_process( + _ExcInfoWrapper(sys.exc_info()), return_value, stdout_collector, + stderr_collector) + + # Re-raise the exception in addition to reporting it to the parent + # process, so that even if `--test_timeout` flag is set and the + # error doesn't make it to be shown in parent process before bazel's + # timeout, the log would still show what happens in this subprocess, + # instead of silently suppressing the error due to early bazel + # timeout. Raising an error in the subprocess produces stack trace in + # the log, but the program continues running. + raise + + self._finish_process(_FINISH_PROPERLY_MESSAGE, return_value, + stdout_collector, stderr_collector) + + def start(self): + """Starts processes, one for each task in `cluster_spec`.""" + for task_type, addresses in self._cluster_spec.items(): + for task_id, _ in enumerate(addresses): + p = multi_process_lib.Process( + target=self._proc_func_wrapper, + args=(task_type, task_id) + self._args, + kwargs=self._kwargs) + p.start() + self._processes.append(p) + + def _queue_to_list(self, queue_to_convert): + """Convert `queue.Queue` to `list`.""" + list_to_return = [] + # Calling `queue.empty()` is not reliable. + while True: + try: + list_to_return.append(queue_to_convert.get(block=False)) + except Queue.Empty: + break + return list_to_return + + def join(self, timeout=None): + """Joins all the processes with timeout. Args: - data: data to be made available in the queue returned by - `multi_process_runner.run()`. + timeout: if set and not all processes report status within roughly + `timeout` seconds, a `RuntimeError` exception will be thrown. + + Returns: + It returns a tuple. The first element is a list that stores the return + data added by subprocesses through `_add_return_data` or through normal + function return; The second element is a list of the messages streamed to + stdout and stderr in the subprocesses if `capture_std_stream` is True or + `None` otherwise. 
+ + Raises: + RuntimeError: if not all processes report status within `timeout` seconds. + Or the exception propagated from any child process. + """ + if not timeout: + if self._max_run_time: + timeout = self._max_run_time + 10 # add 10 seconds grace period + else: + timeout = float('inf') + num_returned = 0 + start_time = time.time() + while num_returned < len(self._processes): + while True: + try: + process_status = self._get_process_status_queue().get(timeout=10) + break + except Queue.Empty: + if time.time() - start_time > timeout: + # If none of those did, report timeout to user. + raise RuntimeError( + 'One or more subprocesses timed out. Please use ' + '`--test_arg=--logtostderr` bazel flag to inspect logs for ' + 'subprocess debugging info. Number of returned processes is ' + '%d.' % num_returned) + + num_returned += 1 + if isinstance(process_status, _ExcInfoWrapper): + six.reraise(*process_status.exc_info) + assert process_status == _FINISH_PROPERLY_MESSAGE + + self._processes = [] + + if self._capture_std_stream: + # TODO(yuefengz): we need to make sure elements match the same process in + # the two returned lists so as to not surprise users. Consider creating a + # `ReturnData` class. + return tuple( + self._queue_to_list(multi_process_lib.get_user_data()[queue_name]) + for queue_name in [RETURN_VALUE_QUEUE, STD_STREAM_QUEUE]) + else: + return (self._queue_to_list( + multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE]), None) + + def _add_return_data(self, data): + """Adds return data that will be returned by `join`. + + The function provides a way for child processes to communicate with the + parent process. Data passed to `_add_return_data` will be available in a + Python Queue.Queue that is eventually returned by `join`. + + Args: + data: data to be made available in the queue returned by `join`. """ # TODO(rchao): Incorporate the task type and id information in a data # wrapper that becomes what is stored in the queue so we can tell where # the data is from. - multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE].put(data) + multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE].put(data) def _add_std_stream_data_flattened(self, data): - std_stream_queue = multi_process_lib.get_user_data()[ - _AvailableQueues.STD_STREAM_QUEUE] + # TODO(yuefengz): currently the same queue is used by multiple processes. It + # is difficult for users to distinguish between logs from different + # processes. + std_stream_queue = multi_process_lib.get_user_data()[STD_STREAM_QUEUE] for d in list(data): std_stream_queue.put(d) - def _get_internal_queue(self): - return multi_process_lib.get_user_data()[_AvailableQueues.INTERNAL_QUEUE] + def _get_process_status_queue(self): + return multi_process_lib.get_user_data()[PROCESS_STATUS_QUEUE] + + +def run(proc_func, + cluster_spec, + max_run_time=None, + capture_std_stream=False, + args=None, + kwargs=None): # pylint: disable=g-doc-args + """Runs functions in local child processes. + + It is a convenience method that creates a `MultiProcessRunner` object and + invokes `start` and `join` method. Please see these methods for detailed + documentations. + + Returns: + A tuple returned from `MultiProcessRunner.join()`. 
+ """ + runner = MultiProcessRunner( + proc_func, + cluster_spec, + max_run_time=max_run_time, + capture_std_stream=capture_std_stream, + args=args, + kwargs=kwargs) + runner.start() + return runner.join() def test_main(): """Main function to be called within `__main__` of a test file.""" with multi_process_lib.context_manager(): test.main() - - -def job_count_to_cluster_spec(job_count_dict): - """Convert a job count dict to cluster spec. - - Args: - job_count_dict: Dict for task_type/count of such task type. - {'worker': 1, 'ps': 1} is an example of a cluster with a worker and a - ps. - - Returns: - The converted cluster spec dict. - """ - - cluster_spec = {} - for task_type, count in job_count_dict.items(): - cluster_spec[task_type] = [ - 'localhost:{}'.format(multi_worker_test_base.pick_unused_port()) - for _ in range(count) - ] - return cluster_spec diff --git a/tensorflow/python/distribute/multi_process_runner_no_init_test.py b/tensorflow/python/distribute/multi_process_runner_no_init_test.py index c5820271d1a..475255d5e0a 100644 --- a/tensorflow/python/distribute/multi_process_runner_no_init_test.py +++ b/tensorflow/python/distribute/multi_process_runner_no_init_test.py @@ -19,23 +19,22 @@ from __future__ import division from __future__ import print_function from tensorflow.python.distribute import multi_process_runner -from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner +from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.eager import test class MultiProcessRunnerNoInitTest(test.TestCase): - def test_stdout_captured(self): + def test_not_calling_correct_main(self): def simple_func(): return 'foobar' - job_count_dict = {'worker': 1} with self.assertRaisesRegexp(RuntimeError, '`multi_process_runner` is not initialized.'): - MultiProcessRunner().run( + multi_process_runner.run( simple_func, - multi_process_runner.job_count_to_cluster_spec(job_count_dict)) + multi_worker_test_base.create_cluster_spec(num_workers=1)) if __name__ == '__main__': diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index 98ca282b7b3..4144eb6f040 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -20,19 +20,15 @@ from __future__ import print_function import time -from absl import flags from six.moves import queue as Queue from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base -from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner from tensorflow.python.eager import test -flags.DEFINE_boolean(name='test_flag', default=0, help='Test flag') - -def proc_func_that_adds_task_type_in_return_data(test_obj): - test_obj.assertTrue(flags.FLAGS.test_flag == 3) +def proc_func_that_adds_task_type_in_return_data(test_obj, val): + test_obj.assertEqual(val, 3) return multi_worker_test_base.get_task_type() @@ -55,16 +51,13 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs): class MultiProcessRunnerTest(test.TestCase): def test_multi_process_runner(self): - job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 2} - proc_flags = { - 'test_flag': 3, - } - returned_data = MultiProcessRunner().run( + returned_data, _ = multi_process_runner.run( proc_func_that_adds_task_type_in_return_data, - multi_process_runner.job_count_to_cluster_spec(job_count_dict), - proc_flags=proc_flags, - args=(self,)) + 
multi_worker_test_base.create_cluster_spec( + num_workers=2, num_ps=3, has_eval=1), + args=(self, 3)) + job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1} for data in returned_data: job_count_dict[data] -= 1 @@ -73,31 +66,29 @@ class MultiProcessRunnerTest(test.TestCase): self.assertEqual(job_count_dict['evaluator'], 0) def test_multi_process_runner_error_propagates_from_subprocesses(self): - job_count_dict = {'worker': 1, 'ps': 1} + runner = multi_process_runner.MultiProcessRunner( + proc_func_that_errors, + multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1), + max_run_time=20) + runner.start() with self.assertRaisesRegexp(ValueError, 'This is an error.'): - MultiProcessRunner().run( - proc_func_that_errors, - multi_process_runner.job_count_to_cluster_spec(job_count_dict), - timeout=20) + runner.join() def test_multi_process_runner_queue_emptied_between_runs(self): - job_count_dict = {'worker': 2} - cluster_spec = multi_process_runner.job_count_to_cluster_spec( - job_count_dict) - returned_data = MultiProcessRunner().run( + cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) + returned_data, _ = multi_process_runner.run( proc_func_that_adds_simple_return_data, cluster_spec) self.assertTrue(returned_data) self.assertEqual(returned_data[0], 'dummy_data') self.assertEqual(returned_data[1], 'dummy_data') - returned_data = MultiProcessRunner().run(proc_func_that_does_nothing, - cluster_spec) + returned_data, _ = multi_process_runner.run(proc_func_that_does_nothing, + cluster_spec) self.assertFalse(returned_data) def test_multi_process_runner_args_passed_correctly(self): - job_count_dict = {'worker': 1} - returned_data = MultiProcessRunner().run( + returned_data, _ = multi_process_runner.run( proc_func_that_return_args_and_kwargs, - multi_process_runner.job_count_to_cluster_spec(job_count_dict), + multi_worker_test_base.create_cluster_spec(num_workers=1), args=('a', 'b'), kwargs={'c_k': 'c_v'}) self.assertEqual(returned_data[0][0], 'a') @@ -110,11 +101,10 @@ class MultiProcessRunnerTest(test.TestCase): print('This is something printed.') return 'This is returned data.' 
- job_count_dict = {'worker': 2} - returned_data, std_stream_data = MultiProcessRunner().run( + returned_data, std_stream_data = multi_process_runner.run( simple_print_func, - multi_process_runner.job_count_to_cluster_spec(job_count_dict), - return_std_stream=True) + multi_worker_test_base.create_cluster_spec(num_workers=2), + capture_std_stream=True) num_string_std_stream = len( [d for d in std_stream_data if d == 'This is something printed.']) num_string_returned_data = len( @@ -123,34 +113,32 @@ class MultiProcessRunnerTest(test.TestCase): self.assertEqual(num_string_returned_data, 2) def test_process_that_exits(self): - - mpr = MultiProcessRunner() - def func_to_exit_in_10_sec(): time.sleep(5) mpr._add_return_data('foo') time.sleep(20) mpr._add_return_data('bar') - job_count_dict = {'worker': 1} - returned_data = mpr.run( + mpr = multi_process_runner.MultiProcessRunner( func_to_exit_in_10_sec, - multi_process_runner.job_count_to_cluster_spec(job_count_dict), - time_to_exit=10) + multi_worker_test_base.create_cluster_spec(num_workers=1), + max_run_time=10) + + mpr.start() + returned_data, _ = mpr.join() self.assertLen(returned_data, 1) def test_signal_doesnt_fire_after_process_exits(self): - job_count_dict = {'worker': 1} - mpr = MultiProcessRunner() - mpr.run( + mpr = multi_process_runner.MultiProcessRunner( proc_func_that_does_nothing, - multi_process_runner.job_count_to_cluster_spec(job_count_dict), - time_to_exit=10) - time.sleep(15) + multi_worker_test_base.create_cluster_spec(num_workers=1), + max_run_time=10) + mpr.start() + mpr.join() with self.assertRaisesRegexp(Queue.Empty, ''): # If the signal was fired, another message would be added to internal # queue, so verifying it's empty. - mpr._get_internal_queue().get(block=False) + mpr._get_process_status_queue().get(block=False) if __name__ == '__main__': diff --git a/tensorflow/python/distribute/multi_worker_continuous_run_test.py b/tensorflow/python/distribute/multi_worker_continuous_run_test.py index 19790a0d69f..9668bc23351 100644 --- a/tensorflow/python/distribute/multi_worker_continuous_run_test.py +++ b/tensorflow/python/distribute/multi_worker_continuous_run_test.py @@ -78,7 +78,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved. with multi_process_runner_util.try_run_and_except_connection_error(self): - multi_process_runner.MultiProcessRunner().run( + multi_process_runner.run( worker_fn, cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))
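
Usage sketch (illustrative only, not part of the applied diff): the snippet below shows the thread/process-like API this change introduces — construct a `MultiProcessRunner`, call `start()`, then `join()`, or use the `run()` convenience wrapper. It assumes the patch is applied and that the test binary enters through `multi_process_runner.test_main()`; the test class name, `proc_func`, and the asserted values are made up for illustration, while the `multi_worker_test_base.create_cluster_spec` calls mirror the ones used in the updated tests above.

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test


def proc_func(value):
  # Runs once per task in the cluster spec. TF_CONFIG is set by the runner
  # before this function is called, so each task knows its type and index.
  return 'finished with %s' % value


class MultiProcessRunnerUsageTest(test.TestCase):

  def test_run_convenience_wrapper(self):
    # `run` constructs a MultiProcessRunner and invokes `start` and `join`.
    returned_data, std_streams = multi_process_runner.run(
        proc_func,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        capture_std_stream=True,
        args=('hello',))
    # One return value per task; captured std streams come back as a flat
    # list (or None when capture_std_stream is False).
    self.assertLen(returned_data, 2)
    self.assertIsNotNone(std_streams)

  def test_start_then_join(self):
    # The thread/process-like API: construct, start, then join with a timeout.
    runner = multi_process_runner.MultiProcessRunner(
        proc_func,
        multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
        max_run_time=20,
        args=('hello',))
    runner.start()
    returned_data, _ = runner.join(timeout=60)
    self.assertLen(returned_data, 2)


if __name__ == '__main__':
  multi_process_runner.test_main()

Compared with the old single `run()` entry point, separating construction, `start`, and `join` lets a test overlap child-process work with parent-side setup or assertions and choose its own join timeout, which is the motivation stated in the commit message.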