Refactor MultiProcessRunner:

* break the `run` function into smaller functions
* create a Python thread/process-like API
* remove some seldom-used features

PiperOrigin-RevId: 279433471
Change-Id: Ibb7febc2347f45872d1455e000ddcfb9541eee23
Yuefeng Zhou 2019-11-08 18:47:13 -08:00 committed by TensorFlower Gardener
parent 7e48fd7fcf
commit 77b19fb814
5 changed files with 295 additions and 284 deletions
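For context on the new API: this change replaces the single blocking `run()` call on a `MultiProcessRunner` instance with a constructor plus `start()` and `join()`, similar to `threading.Thread`/`multiprocessing.Process`, and keeps a module-level `run()` convenience wrapper. A minimal before/after sketch of the usage, assuming a hypothetical `proc_func` and a test binary that already routes its main through `test_main()` (see the diffs below):

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base


def proc_func():
  # Hypothetical process function; runs once per task in the cluster spec.
  return 'done'


cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)

# Old API (removed by this change): a single blocking call.
#   MultiProcessRunner().run(proc_func, cluster_spec, timeout=200)

# New API: construct, start, then join.
runner = multi_process_runner.MultiProcessRunner(proc_func, cluster_spec)
runner.start()
# std_stream_data is None unless capture_std_stream=True.
return_data, std_stream_data = runner.join()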

View File

@@ -1321,6 +1321,8 @@ py_library(
":multi_process_runner_util",
":multi_worker_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:tf2",
"//tensorflow/python/compat:v2_compat",
"@six_archive//:six",
],
)

View File

@@ -24,12 +24,13 @@ import json
import os
import signal
import sys
import time
from absl import flags
import six
from six.moves import queue as Queue
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import multi_process_lib
from tensorflow.python.eager import context
from tensorflow.python.platform import test
@@ -37,23 +38,20 @@ from tensorflow.python.platform import test
_FINISH_PROPERLY_MESSAGE = 'OK'
_ExcInfoWrapper = collections.namedtuple('_ExcInfoWrapper', ['exc_info'])
class _AvailableQueues(object):
"""Names of the available queues used by `multi_process_runner`."""
# Internal queue is used by `multi_process_runner` internally for
# communication from subprocesses to the parent process. The message
# can be _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended successfully, or
# the detailed message of an exception if the subprocess has raised
# one so it can be re-raised by the parent process.
INTERNAL_QUEUE = 'internal_queue'
# Public queue is intended to be used by users of `multi_process_runner`
# for the process function to return information to the caller of
# `multi_process_runner.run()`.
PUBLIC_QUEUE = 'public_queue'
# Standard stream queue is used by `multi_process_runner` to collect
# information streamed to stdout and stderr to be reported back to the
# parent process.
STD_STREAM_QUEUE = 'std_stream_queue'
# Process status queue is used by `multi_process_runner` internally for
# communication from subprocesses to the parent process. The message can be
# _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended
# successfully, or the detailed message of an exception if the subprocess has
# raised one so it can be re-raised by the parent process.
PROCESS_STATUS_QUEUE = 'process_status_queue'
# Return value queue is intended to be used by users of `multi_process_runner`
# for the process function to return information to the caller of
# `multi_process_runner.run()`.
RETURN_VALUE_QUEUE = 'return_value_queue'
# Standard stream queue is used by `multi_process_runner` to collect
# information streamed to stdout and stderr to be reported back to the
# parent process.
STD_STREAM_QUEUE = 'std_stream_queue'
class _LogCollector(object):
@@ -72,29 +70,32 @@ class _LogCollector(object):
class MultiProcessRunner(object):
"""A utility to start multiple subprocesses to simulate multiple workers.
"""A utility class to start multiple processes to simulate a cluster.
Training with multiple workers with the eager runtime can be tested by
simulating the workers with multiple processes. See `run()` for more
information about the usage of this class.
We need to use multiple processes to simulate a cluster in TF 2.0 tests
because TF 2.0 has some process-global data structures that have to be kept
separate per process. We also need child processes to test fault tolerance,
because shutting down a standard TensorFlow server within its own process is
not supported.
Note: the main test program that uses this runner class must run its main
program via `test_main` defined in this file. Using this runner in non-test
binaries is not supported yet.
This class is not thread-safe. Child processes will inherit the TF2 behavior
flag.
"""
def run(self,
proc_func,
cluster_spec,
proc_flags=None,
timeout=200,
time_to_exit=None,
return_std_stream=False,
args=None,
kwargs=None):
"""Run functions on local sub-processes.
Experimental. API subject to change. To fully inspect logging from
subprocesses, use the `--test_arg=--logtostderr` flag with bazel test.
def __init__(self,
proc_func,
cluster_spec,
max_run_time=None,
capture_std_stream=False,
args=None,
kwargs=None):
"""Creates a multi-process runner.
Args:
proc_func: Function to be run on the processes. This will be run on
proc_func: Function to be run on child processes. This will be run on
processes for all task types.
cluster_spec: Dict for the cluster spec. The following is an example of a
cluster with three workers and two ps tasks.
@@ -103,39 +104,19 @@ class MultiProcessRunner(object):
"worker2.example.com:2222"],
"ps": ["ps0.example.com:2222",
"ps1.example.com:2222"]}
proc_flags: Dict that contains the key/value pairs of the flags set in the
processes.
timeout: Timeout in seconds. If a subprocess takes more than this time to
complete, an error is raised.
time_to_exit: If set, subprocesses are forced to exit at approximately this
many seconds after `run()` is called, through the `signal.alarm()` API.
This simulates an interruption of a process, so in such cases no error is
raised. Note that this is best effort at the Python level, since the Python
signal handler does not get executed inside the low-level (C) signal
handler, so it can be delayed.
return_std_stream: Boolean, whether the messages streamed to stdout and
stderr in subprocesses are captured. If True, the messages are stored in
a list returned as the second element.
max_run_time: If set, child processes are forced to exit at approximately
this many seconds after `start` is called. We achieve this through the
`signal.alarm()` API. Note that this is best effort at the Python level,
since the Python signal handler does not get executed while lower-level
C/C++ code is running, so it can be delayed for an arbitrarily long time.
capture_std_stream: Boolean, whether the messages streamed to stdout and
stderr in subprocesses are captured.
args: Positional arguments to be sent to functions run on processes.
kwargs: Keyword arguments to be sent to functions run on processes.
Returns:
If `return_std_stream` is False, a list that stores the return data added
by subprocesses through `multi_process_runner._add_return_data(data)` call,
or through normal function return; if `return_std_stream` is True, a
two-element tuple of `(return_data_list, std_stream_data_list)`, where
`return_data_list` stores the return data added by processes through
`multi_process_runner._add_return_data(data)` call or through normal
function return, and `std_stream_data_list` stores the messages streamed to
stdout and stderr in the subprocesses.
Raises:
RuntimeError: If any of the subprocesses raise an error, or if any of the
subprocesses does not return or error out within `timeout` seconds.
RuntimeError: if `multi_process_runner.test_main()` is not called.
"""
assert cluster_spec is not None
assert callable(proc_func)
@@ -146,192 +127,233 @@ class MultiProcessRunner(object):
'in your python module to properly initialize '
'`multi_process_runner`.')
processes = []
args = args or ()
kwargs = kwargs or {}
self._proc_func = proc_func
self._cluster_spec = cluster_spec
self._max_run_time = max_run_time
self._capture_std_stream = capture_std_stream
self._args = args or ()
self._kwargs = kwargs or {}
self._processes = []
def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
executing_eagerly, *arg, **kwargs):
"""The wrapper function that actually gets run on the process(es)."""
# Child processes should have the same v2 and eager behavior.
self._v2_enabled = tf2.enabled()
self._executing_eagerly = context.executing_eagerly()
@contextlib.contextmanager
def runtime_mode(executing_eagerly):
if executing_eagerly:
with context.eager_mode():
yield
else:
with context.graph_mode():
yield
with runtime_mode(executing_eagerly):
os.environ['TF_CONFIG'] = tf_config_as_json
if proc_flags is not None:
for flag_key, flag_value in proc_flags.items():
setattr(flags.FLAGS, flag_key, flag_value)
stdout_collector = _LogCollector(
sys.__stdout__) if return_std_stream else None
stderr_collector = _LogCollector(
sys.__stderr__) if return_std_stream else None
def finish_wrapper_func_properly(func_result):
"""Call to finish `wrapper_func` properly."""
# Clear the alarm.
signal.alarm(0)
if (return_std_stream and stdout_collector is not None and
stderr_collector is not None):
# If stdout and stderr are to be collected, add them to std stream
# queue.
self._add_std_stream_data_flattened(stdout_collector.log)
self._add_std_stream_data_flattened(stderr_collector.log)
# Un-redirect stdout and stderr.
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
self._get_internal_queue().put(func_result)
if time_to_exit is not None:
def handler(signum, frame):
del signum, frame
finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
# pylint: disable=protected-access
os._exit(0)
signal.signal(signal.SIGALRM, handler)
signal.alarm(time_to_exit)
if return_std_stream:
sys.stdout = stdout_collector
sys.stderr = stderr_collector
try:
return_data = proc_func(*arg, **kwargs)
if return_data is not None:
self._add_return_data(return_data)
# pylint: disable=broad-except
except Exception:
# Capture all exceptions to be reported to parent process.
finish_wrapper_func_properly(_ExcInfoWrapper(sys.exc_info()))
# Re-raise the exception in addition to reporting it to the parent
# process, so that even if `--test_timeout` flag is set and the
# error doesn't make it to be shown in parent process before bazel's
# timeout, the log would still show what happens in this subprocess,
# instead of silently suppressing the error due to early bazel
# timeout. Raising an error in the subprocess produces stack trace in
# the log, but the program continues running.
raise
finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
# Start one process per task listed in `cluster_spec`.
for job_type, addresses in cluster_spec.items():
for task_id, _ in enumerate(addresses):
tf_config_as_json = json.dumps({
'cluster': cluster_spec,
'task': {
'type': job_type,
'index': task_id
}
})
p = multi_process_lib.Process(
target=wrapper_func,
args=(tf_config_as_json, proc_func, proc_flags, time_to_exit,
context.executing_eagerly()) + args,
kwargs=kwargs)
p.start()
processes.append(p)
internal_queue_results = []
for _ in range(len(processes)):
try:
internal_queue_results.append(
self._get_internal_queue().get(timeout=timeout))
except Queue.Empty:
# First check if any of the subprocesses raised exception.
for internal_queue_result in internal_queue_results:
if isinstance(internal_queue_result, _ExcInfoWrapper):
six.reraise(*internal_queue_result.exc_info)
# If none of those did, report time out to user.
raise RuntimeError(
'One or more subprocesses timed out. Please use '
'`--test_arg=--logtostderr` bazel flag to inspect logs for '
'subprocess debugging info. Timeout = {} sec.'.format(timeout))
for internal_queue_result in internal_queue_results:
if isinstance(internal_queue_result, _ExcInfoWrapper):
six.reraise(*internal_queue_result.exc_info)
assert internal_queue_result == _FINISH_PROPERLY_MESSAGE
def queue_to_list(queue_to_convert):
"""Convert `queue.Queue` to `list`."""
list_to_return = []
while True:
try:
list_to_return.append(queue_to_convert.get(block=False))
except Queue.Empty:
break
return list_to_return
if return_std_stream:
return tuple(
queue_to_list(multi_process_lib.get_user_data()[queue_name])
for queue_name in
[_AvailableQueues.PUBLIC_QUEUE, _AvailableQueues.STD_STREAM_QUEUE])
@contextlib.contextmanager
def _runtime_mode(self):
if self._executing_eagerly:
with context.eager_mode():
yield
else:
return queue_to_list(
multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE])
with context.graph_mode():
yield
def _add_return_data(self, data):
"""Add return data that will be returned by `multi_process_runner.run()`.
def _finish_process(self, func_status, return_value, stdout_collector,
stderr_collector):
"""Adds data to queues before program exits."""
# Clear the alarm.
signal.alarm(0)
if return_value is not None:
self._add_return_data(return_value)
if self._capture_std_stream:
# If stdout and stderr are to be collected, add them to std stream
# queue.
self._add_std_stream_data_flattened(stdout_collector.log)
self._add_std_stream_data_flattened(stderr_collector.log)
self._get_process_status_queue().put(func_status)
The function provides a way for processes started by
`multi_process_runner.run()` to communicate with the original process
that started the sub-processes. Data passed to `_add_return_data` will
be available in a Python Queue.Queue that is eventually returned by
`multi_process_runner.run()`.
def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs):
"""The wrapper function that actually gets run in child process(es)."""
os.environ['TF_CONFIG'] = json.dumps({
'cluster': self._cluster_spec,
'task': {
'type': task_type,
'index': task_id,
}
})
if self._capture_std_stream:
# TODO(yuefengz): consider a lighter way of capturing std streams.
stdout_collector = _LogCollector(sys.__stdout__)
stderr_collector = _LogCollector(sys.__stderr__)
sys.stdout = stdout_collector
sys.stderr = stderr_collector
else:
stdout_collector = None
stderr_collector = None
if self._v2_enabled:
v2_compat.enable_v2_behavior()
return_value = None
if self._max_run_time is not None:
# Register a SIGALRM handler to exit the process once it has run for
# approximately `max_run_time` seconds. A process reaching `max_run_time`
# doesn't necessarily indicate an issue.
def handler(signum, frame):
del signum, frame
self._finish_process(_FINISH_PROPERLY_MESSAGE, None, stdout_collector,
stderr_collector)
os._exit(0) # pylint: disable=protected-access
signal.signal(signal.SIGALRM, handler)
signal.alarm(self._max_run_time)
try:
with self._runtime_mode():
return_value = self._proc_func(*arg, **kwargs)
except Exception: # pylint: disable=broad-except
# Capture all exceptions to be reported to parent process.
self._finish_process(
_ExcInfoWrapper(sys.exc_info()), return_value, stdout_collector,
stderr_collector)
# Re-raise the exception in addition to reporting it to the parent
# process, so that even if `--test_timeout` flag is set and the
# error doesn't make it to be shown in parent process before bazel's
# timeout, the log would still show what happens in this subprocess,
# instead of silently suppressing the error due to early bazel
# timeout. Raising an error in the subprocess produces stack trace in
# the log, but the program continues running.
raise
self._finish_process(_FINISH_PROPERLY_MESSAGE, return_value,
stdout_collector, stderr_collector)
def start(self):
"""Starts processes, one for each task in `cluster_spec`."""
for task_type, addresses in self._cluster_spec.items():
for task_id, _ in enumerate(addresses):
p = multi_process_lib.Process(
target=self._proc_func_wrapper,
args=(task_type, task_id) + self._args,
kwargs=self._kwargs)
p.start()
self._processes.append(p)
def _queue_to_list(self, queue_to_convert):
"""Convert `queue.Queue` to `list`."""
list_to_return = []
# Calling `queue.empty()` is not reliable.
while True:
try:
list_to_return.append(queue_to_convert.get(block=False))
except Queue.Empty:
break
return list_to_return
def join(self, timeout=None):
"""Joins all the processes with timeout.
Args:
data: data to be made available in the queue returned by
`multi_process_runner.run()`.
timeout: if set and not all processes report status within roughly
`timeout` seconds, a `RuntimeError` exception will be raised.
Returns:
It returns a tuple. The first element is a list that stores the return
data added by subprocesses through `_add_return_data` or through normal
function return; the second element is a list of the messages streamed to
stdout and stderr in the subprocesses if `capture_std_stream` is True, or
`None` otherwise.
Raises:
RuntimeError: if not all processes report status within `timeout` seconds,
or an exception propagated from any child process.
"""
if not timeout:
if self._max_run_time:
timeout = self._max_run_time + 10 # add 10 seconds grace period
else:
timeout = float('inf')
num_returned = 0
start_time = time.time()
while num_returned < len(self._processes):
while True:
try:
process_status = self._get_process_status_queue().get(timeout=10)
break
except Queue.Empty:
if time.time() - start_time > timeout:
# Report the timeout to the user.
raise RuntimeError(
'One or more subprocesses timed out. Please use '
'`--test_arg=--logtostderr` bazel flag to inspect logs for '
'subprocess debugging info. Number of returned processes is '
'%d.' % num_returned)
num_returned += 1
if isinstance(process_status, _ExcInfoWrapper):
six.reraise(*process_status.exc_info)
assert process_status == _FINISH_PROPERLY_MESSAGE
self._processes = []
if self._capture_std_stream:
# TODO(yuefengz): we need to make sure elements match the same process in
# the two returned lists so as to not surprise users. Consider creating a
# `ReturnData` class.
return tuple(
self._queue_to_list(multi_process_lib.get_user_data()[queue_name])
for queue_name in [RETURN_VALUE_QUEUE, STD_STREAM_QUEUE])
else:
return (self._queue_to_list(
multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE]), None)
def _add_return_data(self, data):
"""Adds return data that will be returned by `join`.
The function provides a way for child processes to communicate with the
parent process. Data passed to `_add_return_data` will be available in a
Python Queue.Queue that is eventually returned by `join`.
Args:
data: data to be made available in the queue returned by `join`.
"""
# TODO(rchao): Incorporate the task type and id information in a data
# wrapper that becomes what is stored in the queue so we can tell where
# the data is from.
multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE].put(data)
multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE].put(data)
def _add_std_stream_data_flattened(self, data):
std_stream_queue = multi_process_lib.get_user_data()[
_AvailableQueues.STD_STREAM_QUEUE]
# TODO(yuefengz): currently the same queue is used by multiple processes. It
# is difficult for users to distinguish between logs from different
# processes.
std_stream_queue = multi_process_lib.get_user_data()[STD_STREAM_QUEUE]
for d in list(data):
std_stream_queue.put(d)
def _get_internal_queue(self):
return multi_process_lib.get_user_data()[_AvailableQueues.INTERNAL_QUEUE]
def _get_process_status_queue(self):
return multi_process_lib.get_user_data()[PROCESS_STATUS_QUEUE]
def run(proc_func,
cluster_spec,
max_run_time=None,
capture_std_stream=False,
args=None,
kwargs=None): # pylint: disable=g-doc-args
"""Runs functions in local child processes.
It is a convenience method that creates a `MultiProcessRunner` object and
invokes its `start` and `join` methods. Please see those methods for
detailed documentation.
Returns:
A tuple returned from `MultiProcessRunner.join()`.
"""
runner = MultiProcessRunner(
proc_func,
cluster_spec,
max_run_time=max_run_time,
capture_std_stream=capture_std_stream,
args=args,
kwargs=kwargs)
runner.start()
return runner.join()
def test_main():
"""Main function to be called within `__main__` of a test file."""
with multi_process_lib.context_manager():
test.main()
def job_count_to_cluster_spec(job_count_dict):
"""Convert a job count dict to cluster spec.
Args:
job_count_dict: Dict for task_type/count of such task type.
{'worker': 1, 'ps': 1} is an example of a cluster with a worker and a
ps.
Returns:
The converted cluster spec dict.
"""
cluster_spec = {}
for task_type, count in job_count_dict.items():
cluster_spec[task_type] = [
'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
for _ in range(count)
]
return cluster_spec
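For reference, `_proc_func_wrapper` above serializes the cluster spec and the task assignment into the `TF_CONFIG` environment variable before invoking `proc_func`, and `job_count_to_cluster_spec` builds that spec from a task-type/count dict. A small sketch of both, with illustrative localhost ports (the real code picks unused ports via `multi_worker_test_base.pick_unused_port()`):

import json
import os

# Roughly what job_count_to_cluster_spec({'worker': 2}) produces, with made-up ports:
cluster_spec = {'worker': ['localhost:12345', 'localhost:23456']}

# What _proc_func_wrapper exports for the task ('worker', 0):
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': cluster_spec,
    'task': {'type': 'worker', 'index': 0},
})
# -> '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]},
#      "task": {"type": "worker", "index": 0}}'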

View File

@@ -19,23 +19,22 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test
class MultiProcessRunnerNoInitTest(test.TestCase):
def test_stdout_captured(self):
def test_not_calling_correct_main(self):
def simple_func():
return 'foobar'
job_count_dict = {'worker': 1}
with self.assertRaisesRegexp(RuntimeError,
'`multi_process_runner` is not initialized.'):
MultiProcessRunner().run(
multi_process_runner.run(
simple_func,
multi_process_runner.job_count_to_cluster_spec(job_count_dict))
multi_worker_test_base.create_cluster_spec(num_workers=1))
if __name__ == '__main__':
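As the expected error message above indicates, a test module that uses the runner has to hand control to `test_main()` (defined in multi_process_runner.py earlier in this diff) so that `multi_process_lib.context_manager()` is entered before tests run. A typical main block, as a sketch:

if __name__ == '__main__':
  multi_process_runner.test_main()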

View File

@@ -20,19 +20,15 @@ from __future__ import print_function
import time
from absl import flags
from six.moves import queue as Queue
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner
from tensorflow.python.eager import test
flags.DEFINE_boolean(name='test_flag', default=0, help='Test flag')
def proc_func_that_adds_task_type_in_return_data(test_obj):
test_obj.assertTrue(flags.FLAGS.test_flag == 3)
def proc_func_that_adds_task_type_in_return_data(test_obj, val):
test_obj.assertEqual(val, 3)
return multi_worker_test_base.get_task_type()
@@ -55,16 +51,13 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs):
class MultiProcessRunnerTest(test.TestCase):
def test_multi_process_runner(self):
job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 2}
proc_flags = {
'test_flag': 3,
}
returned_data = MultiProcessRunner().run(
returned_data, _ = multi_process_runner.run(
proc_func_that_adds_task_type_in_return_data,
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
proc_flags=proc_flags,
args=(self,))
multi_worker_test_base.create_cluster_spec(
num_workers=2, num_ps=3, has_eval=1),
args=(self, 3))
job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
for data in returned_data:
job_count_dict[data] -= 1
@@ -73,31 +66,29 @@ class MultiProcessRunnerTest(test.TestCase):
self.assertEqual(job_count_dict['evaluator'], 0)
def test_multi_process_runner_error_propagates_from_subprocesses(self):
job_count_dict = {'worker': 1, 'ps': 1}
runner = multi_process_runner.MultiProcessRunner(
proc_func_that_errors,
multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
max_run_time=20)
runner.start()
with self.assertRaisesRegexp(ValueError, 'This is an error.'):
MultiProcessRunner().run(
proc_func_that_errors,
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
timeout=20)
runner.join()
def test_multi_process_runner_queue_emptied_between_runs(self):
job_count_dict = {'worker': 2}
cluster_spec = multi_process_runner.job_count_to_cluster_spec(
job_count_dict)
returned_data = MultiProcessRunner().run(
cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
returned_data, _ = multi_process_runner.run(
proc_func_that_adds_simple_return_data, cluster_spec)
self.assertTrue(returned_data)
self.assertEqual(returned_data[0], 'dummy_data')
self.assertEqual(returned_data[1], 'dummy_data')
returned_data = MultiProcessRunner().run(proc_func_that_does_nothing,
cluster_spec)
returned_data, _ = multi_process_runner.run(proc_func_that_does_nothing,
cluster_spec)
self.assertFalse(returned_data)
def test_multi_process_runner_args_passed_correctly(self):
job_count_dict = {'worker': 1}
returned_data = MultiProcessRunner().run(
returned_data, _ = multi_process_runner.run(
proc_func_that_return_args_and_kwargs,
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
multi_worker_test_base.create_cluster_spec(num_workers=1),
args=('a', 'b'),
kwargs={'c_k': 'c_v'})
self.assertEqual(returned_data[0][0], 'a')
@@ -110,11 +101,10 @@ class MultiProcessRunnerTest(test.TestCase):
print('This is something printed.')
return 'This is returned data.'
job_count_dict = {'worker': 2}
returned_data, std_stream_data = MultiProcessRunner().run(
returned_data, std_stream_data = multi_process_runner.run(
simple_print_func,
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
return_std_stream=True)
multi_worker_test_base.create_cluster_spec(num_workers=2),
capture_std_stream=True)
num_string_std_stream = len(
[d for d in std_stream_data if d == 'This is something printed.'])
num_string_returned_data = len(
@@ -123,34 +113,32 @@ class MultiProcessRunnerTest(test.TestCase):
self.assertEqual(num_string_returned_data, 2)
def test_process_that_exits(self):
mpr = MultiProcessRunner()
def func_to_exit_in_10_sec():
time.sleep(5)
mpr._add_return_data('foo')
time.sleep(20)
mpr._add_return_data('bar')
job_count_dict = {'worker': 1}
returned_data = mpr.run(
mpr = multi_process_runner.MultiProcessRunner(
func_to_exit_in_10_sec,
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
time_to_exit=10)
multi_worker_test_base.create_cluster_spec(num_workers=1),
max_run_time=10)
mpr.start()
returned_data, _ = mpr.join()
self.assertLen(returned_data, 1)
def test_signal_doesnt_fire_after_process_exits(self):
job_count_dict = {'worker': 1}
mpr = MultiProcessRunner()
mpr.run(
mpr = multi_process_runner.MultiProcessRunner(
proc_func_that_does_nothing,
multi_process_runner.job_count_to_cluster_spec(job_count_dict),
time_to_exit=10)
time.sleep(15)
multi_worker_test_base.create_cluster_spec(num_workers=1),
max_run_time=10)
mpr.start()
mpr.join()
with self.assertRaisesRegexp(Queue.Empty, ''):
# If the signal had fired, another message would have been added to the
# internal queue, so verify that it's empty.
mpr._get_internal_queue().get(block=False)
mpr._get_process_status_queue().get(block=False)
if __name__ == '__main__':
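The last two tests exercise the `max_run_time` path, which is implemented with `signal.alarm` in `_proc_func_wrapper`. A standalone sketch of that mechanism (handler and durations are illustrative; the real handler also reports status to the parent process before exiting):

import os
import signal
import time


def _on_alarm(signum, frame):
  # Illustrative handler: exit immediately when the alarm fires.
  del signum, frame
  os._exit(0)


signal.signal(signal.SIGALRM, _on_alarm)
signal.alarm(2)  # force the process to exit after roughly 2 seconds
time.sleep(10)   # interrupted by SIGALRM well before 10 seconds elapse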

View File

@@ -78,7 +78,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
# TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
with multi_process_runner_util.try_run_and_except_connection_error(self):
multi_process_runner.MultiProcessRunner().run(
multi_process_runner.run(
worker_fn,
cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))