Refactor MultiProcessRunner:

* break the run function into smaller functions
  * create a Python thread/process-like API
  * remove some seldom-used features

PiperOrigin-RevId: 279433471
Change-Id: Ibb7febc2347f45872d1455e000ddcfb9541eee23
Yuefeng Zhou 2019-11-08 18:47:13 -08:00 committed by TensorFlower Gardener
parent 7e48fd7fcf
commit 77b19fb814
5 changed files with 295 additions and 284 deletions
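
In the refactored API, a test constructs a `MultiProcessRunner`, calls `start()`, and then `join()`, replacing the old single `run()` method. A rough usage sketch based on the diffs below; the cluster size, `max_run_time` value, and the `demo` wrapper are illustrative and not part of this commit:

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base


def proc_func():
  # Runs once in every child process; the return value is collected by the
  # parent through the return-value queue.
  return multi_worker_test_base.get_task_type()


def demo():
  runner = multi_process_runner.MultiProcessRunner(
      proc_func,
      multi_worker_test_base.create_cluster_spec(num_workers=2),
      max_run_time=10,          # optional: children exit via SIGALRM after ~10s
      capture_std_stream=True)  # optional: also collect stdout/stderr
  runner.start()
  return_data, std_stream_data = runner.join()
  return return_data, std_stream_data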

View File

@@ -1321,6 +1321,8 @@ py_library(
         ":multi_process_runner_util",
         ":multi_worker_test_base",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:tf2",
+        "//tensorflow/python/compat:v2_compat",
         "@six_archive//:six",
     ],
 )

View File

@@ -24,12 +24,13 @@ import json
 import os
 import signal
 import sys
+import time

-from absl import flags
 import six
 from six.moves import queue as Queue

-from tensorflow.python.distribute import multi_worker_test_base
+from tensorflow.python import tf2
+from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute import multi_process_lib
 from tensorflow.python.eager import context
 from tensorflow.python.platform import test
@@ -37,23 +38,20 @@ from tensorflow.python.platform import test
 _FINISH_PROPERLY_MESSAGE = 'OK'

 _ExcInfoWrapper = collections.namedtuple('_ExcInfoWrapper', ['exc_info'])

-
-class _AvailableQueues(object):
-  """Names of the available queues used by `multi_process_runner`."""
-  # Internal queue is used by `multi_process_runner` internally for
-  # communication from subprocesses to the parent process. The message
-  # can be _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended successfully, or
-  # the detailed message of an exception if the subprocess has raised
-  # one so it can be re-raised by the parent process.
-  INTERNAL_QUEUE = 'internal_queue'
-  # Public queue is intended to be used by users of `multi_process_runner`
-  # for the process function to return information to the caller of
-  # `multi_process_runner.run()`.
-  PUBLIC_QUEUE = 'public_queue'
-  # Standard stream queue is used by `multi_process_runner` to collect
-  # information streamed to stdout and stderr to be reported back to the
-  # parent process.
-  STD_STREAM_QUEUE = 'std_stream_queue'
+# Process status queue is used by `multi_process_runner` internally for
+# communication from subprocesses to the parent process. The message can be
+# _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended
+# successfully, or the detailed message of an exception if the subprocess has
+# raised one so it can be re-raised by the parent process.
+PROCESS_STATUS_QUEUE = 'process_status_queue'
+# Return value queue is intended to be used by users of `multi_process_runner`
+# for the process function to return information to the caller of
+# `multi_process_runner.run()`.
+RETURN_VALUE_QUEUE = 'return_value_queue'
+# Standard stream queue is used by `multi_process_runner` to collect
+# information streamed to stdout and stderr to be reported back to the
+# parent process.
+STD_STREAM_QUEUE = 'std_stream_queue'


 class _LogCollector(object):
@@ -72,29 +70,32 @@ class _LogCollector(object):

 class MultiProcessRunner(object):
-  """A utility to start multiple subprocesses to simulate multiple workers.
-
-  Training with multiple workers with eager runtime can be tested by simulating
-  using multiple processes. See `run()` for more information about the usage
-  of this class.
+  """A utility class to start multiple processes to simulate a cluster.
+
+  We need to use multiple processes to simulate a cluster in TF 2.0 tests
+  because TF 2.0 has some process-global data structures that have to be
+  separated by processes. We also need child processes to test out our fault
+  tolerance because shutting down a standard TensorFlow server within its
+  process is not supported.
+
+  Note: the main test program that uses this runner class must run main program
+  via `test_main` defined in this file. Using this runner in non-test binaries
+  is not supported yet.
+
+  This class is not thread-safe. Child processes will inherit TF2 behavior flag.
   """

-  def run(self,
-          proc_func,
-          cluster_spec,
-          proc_flags=None,
-          timeout=200,
-          time_to_exit=None,
-          return_std_stream=False,
-          args=None,
-          kwargs=None):
-    """Run functions on local sub-processes.
-
-    Experimental. API subject to change. To fully inspect logging from
-    subprocesses, use `--test_arg=--logtostderr` flag with bazel test.
+  def __init__(self,
+               proc_func,
+               cluster_spec,
+               max_run_time=None,
+               capture_std_stream=False,
+               args=None,
+               kwargs=None):
+    """Creates a multi-process runner.

     Args:
-      proc_func: Function to be run on the processes. This will be run on
+      proc_func: Function to be run on child processes. This will be run on
         processes for all task types.
       cluster_spec: Dict for cluster spec. The following is an example of
         cluster with three workers and two ps's.
@@ -103,39 +104,19 @@ class MultiProcessRunner(object):
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
-      proc_flags: Dict that contains the key/values of the flags used on the
-        processes.
-      timeout: Time out in seconds. If the sub-process takes more than this time
-        to complete, raise an error.
-      time_to_exit: If set, sub-processes is forced to exit at approximately
-        this many seconds after `run()` is called, through `signal.alarm()` api.
-        This is for simulation of interruption on a process so in such cases no
-        error is raised. Note that this is best effort at Python level since
-        Python signal handler does not get executed inside the low-level (C)
-        signal handler, so it can be delayed.
-      return_std_stream: Boolean, whether the messages streamed to stdout and
-        stderr in subprocesses are captured. If True, the messages are stored in
-        a list returned as the second element.
+      max_run_time: If set, child processes is forced to exit at approximately
+        this many seconds after `start` is called. We achieve this through
+        `signal.alarm()` api. Note that this is best effort at Python level
+        since Python signal handler does not get executed when it runs lower
+        level C/C++ code. So it can be delayed for arbitrarily long time.
+      capture_std_stream: Boolean, whether the messages streamed to stdout and
+        stderr in subprocesses are captured.
       args: Positional arguments to be sent to functions run on processes.
       kwargs: Keyword arguments to be sent to functions run on processes.
-
-    Returns:
-      If `return_std_stream` is False, a list that stores the return data added
-      by subprocesses through `multi_process_runner._add_return_data(data)`
-      call,
-      or through normal function return; if `return_std_stream` is True, a
-      two-element tuple of `(return_data_list, std_stream_data_list)`, where
-      `return_data_list` stores the return data added by processes through
-      `multi_process_runner._add_return_data(data)` call or through normal
-      function
-      return, and `std_stream_data_list` stores the messages streamed to stdout
-      and stderr in the subprocesses.

     Raises:
-      RuntimeError: If any of the subprocesses raise an error, or if any of the
-        subprocesses does not return or error out within `timeout` seconds.
+      RuntimeError: if `multi_process_runner.test_main()` is not called.
     """
     assert cluster_spec is not None
     assert callable(proc_func)
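
The `max_run_time` behavior described in the docstring above relies on Python's `signal.alarm`. A minimal standalone sketch of that pattern on a Unix system, independent of this module (the two-second value is made up):

import os
import signal
import time


def _handler(signum, frame):
  del signum, frame
  # In MultiProcessRunner the handler reports _FINISH_PROPERLY_MESSAGE to the
  # parent before exiting; here we simply exit without unwinding the stack.
  os._exit(0)  # pylint: disable=protected-access


signal.signal(signal.SIGALRM, _handler)
signal.alarm(2)   # deliver SIGALRM in roughly 2 seconds
time.sleep(30)    # the alarm interrupts this sleep and the handler exits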
@@ -146,72 +127,87 @@ class MultiProcessRunner(object):
           'in your python module to properly initialize '
           '`multi_process_runner`.')
-    processes = []
-    args = args or ()
-    kwargs = kwargs or {}
-
-    def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
-                     executing_eagerly, *arg, **kwargs):
-      """The wrapper function that actually gets run on the process(es)."""
-
-      @contextlib.contextmanager
-      def runtime_mode(executing_eagerly):
-        if executing_eagerly:
-          with context.eager_mode():
-            yield
-        else:
-          with context.graph_mode():
-            yield
-
-      with runtime_mode(executing_eagerly):
-        os.environ['TF_CONFIG'] = tf_config_as_json
-        if proc_flags is not None:
-          for flag_key, flag_value in proc_flags.items():
-            setattr(flags.FLAGS, flag_key, flag_value)
-
-        stdout_collector = _LogCollector(
-            sys.__stdout__) if return_std_stream else None
-        stderr_collector = _LogCollector(
-            sys.__stderr__) if return_std_stream else None
-
-        def finish_wrapper_func_properly(func_result):
-          """Call to finish `wrapper_func` properly."""
-          # Clear the alarm.
-          signal.alarm(0)
-          if (return_std_stream and stdout_collector is not None and
-              stderr_collector is not None):
-            # If stdout and stderr are to be collected, add them to std stream
-            # queue.
-            self._add_std_stream_data_flattened(stdout_collector.log)
-            self._add_std_stream_data_flattened(stderr_collector.log)
-            # Un-redirect stdout and stderr.
-            sys.stdout = sys.__stdout__
-            sys.stderr = sys.__stderr__
-          self._get_internal_queue().put(func_result)
-
-        if time_to_exit is not None:
-
-          def handler(signum, frame):
-            del signum, frame
-            finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
-            # pylint: disable=protected-access
-            os._exit(0)
-
-          signal.signal(signal.SIGALRM, handler)
-          signal.alarm(time_to_exit)
-
-        if return_std_stream:
-          sys.stdout = stdout_collector
-          sys.stderr = stderr_collector
-
-        try:
-          return_data = proc_func(*arg, **kwargs)
-          if return_data is not None:
-            self._add_return_data(return_data)
-        # pylint: disable=broad-except
-        except Exception:
-          # Capture all exceptions to be reported to parent process.
-          finish_wrapper_func_properly(_ExcInfoWrapper(sys.exc_info()))
-
-          # Re-raise the exception in addition to reporting it to the parent
-          # process, so that even if `--test_timeout` flag is set and the
+    self._proc_func = proc_func
+    self._cluster_spec = cluster_spec
+    self._max_run_time = max_run_time
+    self._capture_std_stream = capture_std_stream
+    self._args = args or ()
+    self._kwargs = kwargs or {}
+    self._processes = []
+
+    # Child processes should have the same v2 and eager behavior.
+    self._v2_enabled = tf2.enabled()
+    self._executing_eagerly = context.executing_eagerly()
+
+  @contextlib.contextmanager
+  def _runtime_mode(self):
+    if self._executing_eagerly:
+      with context.eager_mode():
+        yield
+    else:
+      with context.graph_mode():
+        yield
+
+  def _finish_process(self, func_status, return_value, stdout_collector,
+                      stderr_collector):
+    """Adds data to queues before program exits."""
+    # Clear the alarm.
+    signal.alarm(0)
+
+    if return_value is not None:
+      self._add_return_data(return_value)
+    if self._capture_std_stream:
+      # If stdout and stderr are to be collected, add them to std stream
+      # queue.
+      self._add_std_stream_data_flattened(stdout_collector.log)
+      self._add_std_stream_data_flattened(stderr_collector.log)
+    self._get_process_status_queue().put(func_status)
+
+  def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs):
+    """The wrapper function that actually gets run in child process(es)."""
+    os.environ['TF_CONFIG'] = json.dumps({
+        'cluster': self._cluster_spec,
+        'task': {
+            'type': task_type,
+            'index': task_id,
+        }
+    })
+
+    if self._capture_std_stream:
+      # TODO(yuefengz): consider a lighter way of capturing std streams.
+      stdout_collector = _LogCollector(sys.__stdout__)
+      stderr_collector = _LogCollector(sys.__stderr__)
+      sys.stdout = stdout_collector
+      sys.stderr = stderr_collector
+    else:
+      stdout_collector = None
+      stderr_collector = None
+
+    if self._v2_enabled:
+      v2_compat.enable_v2_behavior()
+
+    return_value = None
+
+    if self._max_run_time is not None:
+      # Register an sigalarm handler to exit the process when it reaches
+      # `timeout` seconds. A program reaching `timeout` doesn't necessarily
+      # indicate an issue.
+      def handler(signum, frame):
+        del signum, frame
+        self._finish_process(_FINISH_PROPERLY_MESSAGE, None, stdout_collector,
+                             stderr_collector)
+        os._exit(0)  # pylint: disable=protected-access
+
+      signal.signal(signal.SIGALRM, handler)
+      signal.alarm(self._max_run_time)
+
+    try:
+      with self._runtime_mode():
+        return_value = self._proc_func(*arg, **kwargs)
+    except Exception:  # pylint: disable=broad-except
+      # Capture all exceptions to be reported to parent process.
+      self._finish_process(
+          _ExcInfoWrapper(sys.exc_info()), return_value, stdout_collector,
+          stderr_collector)
+
+      # Re-raise the exception in addition to reporting it to the parent
+      # process, so that even if `--test_timeout` flag is set and the
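
Each child process in the hunk above receives its cluster and task assignment through the `TF_CONFIG` environment variable. For the example cluster spec from the docstring, worker 1 would see roughly the following; a sketch only, with the addresses taken from the docstring example:

import json
import os

tf_config = {
    'cluster': {
        'worker': ['worker0.example.com:2222',
                   'worker1.example.com:2222',
                   'worker2.example.com:2222'],
        'ps': ['ps0.example.com:2222',
               'ps1.example.com:2222'],
    },
    'task': {'type': 'worker', 'index': 1},
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)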
@@ -222,50 +218,24 @@ class MultiProcessRunner(object):
       # the log, but the program continues running.
       raise

-        finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
-
-    # Start number of processes according to `count_dict`.
-    for job_type, addresses in cluster_spec.items():
-      for task_id, _ in enumerate(addresses):
-        tf_config_as_json = json.dumps({
-            'cluster': cluster_spec,
-            'task': {
-                'type': job_type,
-                'index': task_id
-            }
-        })
-        p = multi_process_lib.Process(
-            target=wrapper_func,
-            args=(tf_config_as_json, proc_func, proc_flags, time_to_exit,
-                  context.executing_eagerly()) + args,
-            kwargs=kwargs)
-        p.start()
-        processes.append(p)
-
-    internal_queue_results = []
-    for _ in range(len(processes)):
-      try:
-        internal_queue_results.append(
-            self._get_internal_queue().get(timeout=timeout))
-      except Queue.Empty:
-        # First check if any of the subprocesses raised exception.
-        for internal_queue_result in internal_queue_results:
-          if isinstance(internal_queue_result, _ExcInfoWrapper):
-            six.reraise(*internal_queue_result.exc_info)
-        # If none of those did, report time out to user.
-        raise RuntimeError(
-            'One or more subprocesses timed out. Please use '
-            '`--test_arg=--logtostderr` bazel flag to inspect logs for '
-            'subprocess debugging info. Timeout = {} sec.'.format(timeout))
-
-    for internal_queue_result in internal_queue_results:
-      if isinstance(internal_queue_result, _ExcInfoWrapper):
-        six.reraise(*internal_queue_result.exc_info)
-      assert internal_queue_result == _FINISH_PROPERLY_MESSAGE
-
-    def queue_to_list(queue_to_convert):
-      """Convert `queue.Queue` to `list`."""
-      list_to_return = []
-      while True:
-        try:
-          list_to_return.append(queue_to_convert.get(block=False))
+    self._finish_process(_FINISH_PROPERLY_MESSAGE, return_value,
+                         stdout_collector, stderr_collector)
+
+  def start(self):
+    """Starts processes, one for each task in `cluster_spec`."""
+    for task_type, addresses in self._cluster_spec.items():
+      for task_id, _ in enumerate(addresses):
+        p = multi_process_lib.Process(
+            target=self._proc_func_wrapper,
+            args=(task_type, task_id) + self._args,
+            kwargs=self._kwargs)
+        p.start()
+        self._processes.append(p)
+
+  def _queue_to_list(self, queue_to_convert):
+    """Convert `queue.Queue` to `list`."""
+    list_to_return = []
+    # Calling `queue.empty()` is not reliable.
+    while True:
+      try:
+        list_to_return.append(queue_to_convert.get(block=False))
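
`_queue_to_list` above drains a queue by calling `get(block=False)` until `Queue.Empty` is raised, because `queue.empty()` is not reliable across processes. The same pattern in standalone form, with made-up queue contents:

from six.moves import queue as Queue

q = Queue.Queue()
q.put('dummy_data')
q.put('dummy_data')

items = []
while True:
  try:
    items.append(q.get(block=False))
  except Queue.Empty:
    break
# items is now ['dummy_data', 'dummy_data']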
@@ -273,65 +243,117 @@ class MultiProcessRunner(object):
           break
       return list_to_return

-    if return_std_stream:
-      return tuple(
-          queue_to_list(multi_process_lib.get_user_data()[queue_name])
-          for queue_name in
-          [_AvailableQueues.PUBLIC_QUEUE, _AvailableQueues.STD_STREAM_QUEUE])
-    else:
-      return queue_to_list(
-          multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE])
+  def join(self, timeout=None):
+    """Joins all the processes with timeout.
+
+    Args:
+      timeout: if set and not all processes report status within roughly
+        `timeout` seconds, a `RuntimeError` exception will be thrown.
+
+    Returns:
+      It returns a tuple. The first element is a list that stores the return
+      data added by subprocesses through `_add_return_data` or through normal
+      function return; The second element is a list of the messages streamed to
+      stdout and stderr in the subprocesses if `capture_std_stream` is True or
+      `None` otherwise.
+
+    Raises:
+      RuntimeError: if not all processes report status within `timeout` seconds.
+        Or the exception propagated from any child process.
+    """
+    if not timeout:
+      if self._max_run_time:
+        timeout = self._max_run_time + 10  # add 10 seconds grace period
+      else:
+        timeout = float('inf')
+
+    num_returned = 0
+    start_time = time.time()
+    while num_returned < len(self._processes):
+      while True:
+        try:
+          process_status = self._get_process_status_queue().get(timeout=10)
+          break
+        except Queue.Empty:
+          if time.time() - start_time > timeout:
+            # If none of those did, report timeout to user.
+            raise RuntimeError(
+                'One or more subprocesses timed out. Please use '
+                '`--test_arg=--logtostderr` bazel flag to inspect logs for '
+                'subprocess debugging info. Number of returned processes is '
+                '%d.' % num_returned)
+
+      num_returned += 1
+      if isinstance(process_status, _ExcInfoWrapper):
+        six.reraise(*process_status.exc_info)
+      assert process_status == _FINISH_PROPERLY_MESSAGE
+    self._processes = []
+
+    if self._capture_std_stream:
+      # TODO(yuefengz): we need to make sure elements match the same process in
+      # the two returned lists so as to not surprise users. Consider creating a
+      # `ReturnData` class.
+      return tuple(
+          self._queue_to_list(multi_process_lib.get_user_data()[queue_name])
+          for queue_name in [RETURN_VALUE_QUEUE, STD_STREAM_QUEUE])
+    else:
+      return (self._queue_to_list(
+          multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE]), None)

   def _add_return_data(self, data):
-    """Add return data that will be returned by `multi_process_runner.run()`.
-
-    The function provides a way for processes started by
-    `multi_process_runner.run()` to communicate with the original process
-    that started the sub-processes. Data passed to `_add_return_data` will
-    be available in a python Queue.Queue that is eventually returned by
-    `multi_process_runner.run()`.
+    """Adds return data that will be returned by `join`.
+
+    The function provides a way for child processes to communicate with the
+    parent process. Data passed to `_add_return_data` will be available in a
+    Python Queue.Queue that is eventually returned by `join`.

     Args:
-      data: data to be made available in the queue returned by
-        `multi_process_runner.run()`.
+      data: data to be made available in the queue returned by `join`.
     """
     # TODO(rchao): Incorporate the task type and id information in a data
     # wrapper that becomes what is stored in the queue so we can tell where
     # the data is from.
-    multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE].put(data)
+    multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE].put(data)

   def _add_std_stream_data_flattened(self, data):
-    std_stream_queue = multi_process_lib.get_user_data()[
-        _AvailableQueues.STD_STREAM_QUEUE]
+    # TODO(yuefengz): currently the same queue is used by multiple processes. It
+    # is difficult for users to distinguish between logs from different
+    # processes.
+    std_stream_queue = multi_process_lib.get_user_data()[STD_STREAM_QUEUE]
     for d in list(data):
       std_stream_queue.put(d)

-  def _get_internal_queue(self):
-    return multi_process_lib.get_user_data()[_AvailableQueues.INTERNAL_QUEUE]
+  def _get_process_status_queue(self):
+    return multi_process_lib.get_user_data()[PROCESS_STATUS_QUEUE]
+
+
+def run(proc_func,
+        cluster_spec,
+        max_run_time=None,
+        capture_std_stream=False,
+        args=None,
+        kwargs=None):  # pylint: disable=g-doc-args
+  """Runs functions in local child processes.
+
+  It is a convenience method that creates a `MultiProcessRunner` object and
+  invokes `start` and `join` method. Please see these methods for detailed
+  documentations.
+
+  Returns:
+    A tuple returned from `MultiProcessRunner.join()`.
+  """
+  runner = MultiProcessRunner(
+      proc_func,
+      cluster_spec,
+      max_run_time=max_run_time,
+      capture_std_stream=capture_std_stream,
+      args=args,
+      kwargs=kwargs)
+  runner.start()
+  return runner.join()


 def test_main():
   """Main function to be called within `__main__` of a test file."""
   with multi_process_lib.context_manager():
     test.main()
-
-
-def job_count_to_cluster_spec(job_count_dict):
-  """Convert a job count dict to cluster spec.
-
-  Args:
-    job_count_dict: Dict for task_type/count of such task type.
-      {'worker': 1, 'ps': 1} is an example of a cluster with a worker and a
-      ps.
-
-  Returns:
-    The converted cluster spec dict.
-  """
-  cluster_spec = {}
-  for task_type, count in job_count_dict.items():
-    cluster_spec[task_type] = [
-        'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
-        for _ in range(count)
-    ]
-  return cluster_spec
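
The new module-level `run()` in the hunk above is the convenience path used by most of the updated tests below: it builds a runner, starts it, and joins it in one call. An illustrative sketch, with the proc function and cluster size made up:

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base


def proc_func():
  print('This is something printed.')  # ends up in std_stream_data if captured
  return 'This is returned data.'


def demo():
  return_data, std_stream_data = multi_process_runner.run(
      proc_func,
      multi_worker_test_base.create_cluster_spec(num_workers=2),
      capture_std_stream=True)
  return return_data, std_stream_data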

View File

@@ -19,23 +19,22 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner
+from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import test


 class MultiProcessRunnerNoInitTest(test.TestCase):

-  def test_stdout_captured(self):
+  def test_not_calling_correct_main(self):

     def simple_func():
       return 'foobar'

-    job_count_dict = {'worker': 1}
     with self.assertRaisesRegexp(RuntimeError,
                                  '`multi_process_runner` is not initialized.'):
-      MultiProcessRunner().run(
+      multi_process_runner.run(
           simple_func,
-          multi_process_runner.job_count_to_cluster_spec(job_count_dict))
+          multi_worker_test_base.create_cluster_spec(num_workers=1))


 if __name__ == '__main__':

View File

@@ -20,19 +20,15 @@ from __future__ import print_function
 import time

-from absl import flags
 from six.moves import queue as Queue

 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
-from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner
 from tensorflow.python.eager import test

-flags.DEFINE_boolean(name='test_flag', default=0, help='Test flag')
-

-def proc_func_that_adds_task_type_in_return_data(test_obj):
-  test_obj.assertTrue(flags.FLAGS.test_flag == 3)
+def proc_func_that_adds_task_type_in_return_data(test_obj, val):
+  test_obj.assertEqual(val, 3)
   return multi_worker_test_base.get_task_type()
@@ -55,16 +51,13 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs):

 class MultiProcessRunnerTest(test.TestCase):

   def test_multi_process_runner(self):
-    job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 2}
-    proc_flags = {
-        'test_flag': 3,
-    }
-    returned_data = MultiProcessRunner().run(
+    returned_data, _ = multi_process_runner.run(
         proc_func_that_adds_task_type_in_return_data,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-        proc_flags=proc_flags,
-        args=(self,))
+        multi_worker_test_base.create_cluster_spec(
+            num_workers=2, num_ps=3, has_eval=1),
+        args=(self, 3))

+    job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
     for data in returned_data:
       job_count_dict[data] -= 1
@@ -73,31 +66,29 @@ class MultiProcessRunnerTest(test.TestCase):
     self.assertEqual(job_count_dict['evaluator'], 0)

   def test_multi_process_runner_error_propagates_from_subprocesses(self):
-    job_count_dict = {'worker': 1, 'ps': 1}
-    with self.assertRaisesRegexp(ValueError, 'This is an error.'):
-      MultiProcessRunner().run(
-          proc_func_that_errors,
-          multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-          timeout=20)
+    runner = multi_process_runner.MultiProcessRunner(
+        proc_func_that_errors,
+        multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
+        max_run_time=20)
+    runner.start()
+    with self.assertRaisesRegexp(ValueError, 'This is an error.'):
+      runner.join()

   def test_multi_process_runner_queue_emptied_between_runs(self):
-    job_count_dict = {'worker': 2}
-    cluster_spec = multi_process_runner.job_count_to_cluster_spec(
-        job_count_dict)
-    returned_data = MultiProcessRunner().run(
+    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
+    returned_data, _ = multi_process_runner.run(
         proc_func_that_adds_simple_return_data, cluster_spec)
     self.assertTrue(returned_data)
     self.assertEqual(returned_data[0], 'dummy_data')
     self.assertEqual(returned_data[1], 'dummy_data')
-    returned_data = MultiProcessRunner().run(proc_func_that_does_nothing,
-                                             cluster_spec)
+    returned_data, _ = multi_process_runner.run(proc_func_that_does_nothing,
+                                                cluster_spec)
     self.assertFalse(returned_data)

   def test_multi_process_runner_args_passed_correctly(self):
-    job_count_dict = {'worker': 1}
-    returned_data = MultiProcessRunner().run(
+    returned_data, _ = multi_process_runner.run(
         proc_func_that_return_args_and_kwargs,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
         args=('a', 'b'),
         kwargs={'c_k': 'c_v'})
     self.assertEqual(returned_data[0][0], 'a')
@@ -110,11 +101,10 @@ class MultiProcessRunnerTest(test.TestCase):
       print('This is something printed.')
       return 'This is returned data.'

-    job_count_dict = {'worker': 2}
-    returned_data, std_stream_data = MultiProcessRunner().run(
+    returned_data, std_stream_data = multi_process_runner.run(
         simple_print_func,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-        return_std_stream=True)
+        multi_worker_test_base.create_cluster_spec(num_workers=2),
+        capture_std_stream=True)
     num_string_std_stream = len(
         [d for d in std_stream_data if d == 'This is something printed.'])
     num_string_returned_data = len(
@@ -123,34 +113,32 @@ class MultiProcessRunnerTest(test.TestCase):
     self.assertEqual(num_string_returned_data, 2)

   def test_process_that_exits(self):
-    mpr = MultiProcessRunner()

     def func_to_exit_in_10_sec():
       time.sleep(5)
       mpr._add_return_data('foo')
       time.sleep(20)
       mpr._add_return_data('bar')

-    job_count_dict = {'worker': 1}
-    returned_data = mpr.run(
+    mpr = multi_process_runner.MultiProcessRunner(
         func_to_exit_in_10_sec,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-        time_to_exit=10)
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        max_run_time=10)
+    mpr.start()
+    returned_data, _ = mpr.join()
     self.assertLen(returned_data, 1)

   def test_signal_doesnt_fire_after_process_exits(self):
-    job_count_dict = {'worker': 1}
-    mpr = MultiProcessRunner()
-    mpr.run(
+    mpr = multi_process_runner.MultiProcessRunner(
         proc_func_that_does_nothing,
-        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
-        time_to_exit=10)
-    time.sleep(15)
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        max_run_time=10)
+    mpr.start()
+    mpr.join()
     with self.assertRaisesRegexp(Queue.Empty, ''):
       # If the signal was fired, another message would be added to internal
       # queue, so verifying it's empty.
-      mpr._get_internal_queue().get(block=False)
+      mpr._get_process_status_queue().get(block=False)


 if __name__ == '__main__':

View File

@@ -78,7 +78,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
     # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
     with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.MultiProcessRunner().run(
+      multi_process_runner.run(
           worker_fn,
           cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))