Refactor MultiProcessRunner:
* break run function into smaller functions
* create a python thread/process-like API
* remove some seldom used features

PiperOrigin-RevId: 279433471
Change-Id: Ibb7febc2347f45872d1455e000ddcfb9541eee23
parent 7e48fd7fcf
commit 77b19fb814
@@ -1321,6 +1321,8 @@ py_library(
        ":multi_process_runner_util",
        ":multi_worker_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:tf2",
        "//tensorflow/python/compat:v2_compat",
        "@six_archive//:six",
    ],
)

@@ -24,12 +24,13 @@ import json
import os
import signal
import sys
import time

from absl import flags
import six
from six.moves import queue as Queue

from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import multi_process_lib
from tensorflow.python.eager import context
from tensorflow.python.platform import test

@@ -37,23 +38,20 @@ from tensorflow.python.platform import test
_FINISH_PROPERLY_MESSAGE = 'OK'
_ExcInfoWrapper = collections.namedtuple('_ExcInfoWrapper', ['exc_info'])


class _AvailableQueues(object):
  """Names of the available queues used by `multi_process_runner`."""
  # Internal queue is used by `multi_process_runner` internally for
  # communication from subprocesses to the parent process. The message
  # can be _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended
  # successfully, or the detailed message of an exception if the subprocess
  # has raised one so it can be re-raised by the parent process.
  INTERNAL_QUEUE = 'internal_queue'
  # Public queue is intended to be used by users of `multi_process_runner`
  # for the process function to return information to the caller of
  # `multi_process_runner.run()`.
  PUBLIC_QUEUE = 'public_queue'
  # Standard stream queue is used by `multi_process_runner` to collect
  # information streamed to stdout and stderr to be reported back to the
  # parent process.
  STD_STREAM_QUEUE = 'std_stream_queue'
# Process status queue is used by `multi_process_runner` internally for
# communication from subprocesses to the parent process. The message can be
# _FINISH_PROPERLY_MESSAGE in which case the subprocess has ended
# successfully, or the detailed message of an exception if the subprocess has
# raised one so it can be re-raised by the parent process.
PROCESS_STATUS_QUEUE = 'process_status_queue'
# Return value queue is intended to be used by users of `multi_process_runner`
# for the process function to return information to the caller of
# `multi_process_runner.run()`.
RETURN_VALUE_QUEUE = 'return_value_queue'
# Standard stream queue is used by `multi_process_runner` to collect
# information streamed to stdout and stderr to be reported back to the
# parent process.
STD_STREAM_QUEUE = 'std_stream_queue'


class _LogCollector(object):

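To make the roles of these queues concrete, the following is a small stdlib-only analogue of the status protocol. It uses `multiprocessing.Queue` directly for illustration only; the real implementation goes through `multi_process_lib.get_user_data()`, and a string traceback stands in for `_ExcInfoWrapper` because traceback objects do not pickle across processes.

import multiprocessing
import traceback

_FINISH_PROPERLY_MESSAGE = 'OK'

def _child(status_queue):
  try:
    pass  # the user's proc_func would run here
  except Exception:  # pylint: disable=broad-except
    # Report the failure to the parent so it can be surfaced there too.
    status_queue.put(traceback.format_exc())
    raise
  status_queue.put(_FINISH_PROPERLY_MESSAGE)

if __name__ == '__main__':
  status_queue = multiprocessing.Queue()
  proc = multiprocessing.Process(target=_child, args=(status_queue,))
  proc.start()
  proc.join()
  status = status_queue.get()
  if status != _FINISH_PROPERLY_MESSAGE:
    raise RuntimeError('child process failed:\n' + status)
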
@@ -72,29 +70,32 @@ class _LogCollector(object):


class MultiProcessRunner(object):
  """A utility to start multiple subprocesses to simulate multiple workers.
  """A utility class to start multiple processes to simulate a cluster.

  Training with multiple workers with eager runtime can be tested by simulating
  using multiple processes. See `run()` for more information about the usage
  of this class.
  We need to use multiple processes to simulate a cluster in TF 2.0 tests
  because TF 2.0 has some process-global data structures that have to be
  separated by processes. We also need child processes to test out our fault
  tolerance because shutting down a standard TensorFlow server within its
  process is not supported.

  Note: the main test program that uses this runner class must run main program
  via `test_main` defined in this file. Using this runner in non-test binaries
  is not supported yet.

  This class is not thread-safe. Child processes will inherit TF2 behavior flag.
  """

  def run(self,
          proc_func,
          cluster_spec,
          proc_flags=None,
          timeout=200,
          time_to_exit=None,
          return_std_stream=False,
          args=None,
          kwargs=None):
    """Run functions on local sub-processes.

    Experimental. API subject to change. To fully inspect logging from
    subprocesses, use `--test_arg=--logtostderr` flag with bazel test.
  def __init__(self,
               proc_func,
               cluster_spec,
               max_run_time=None,
               capture_std_stream=False,
               args=None,
               kwargs=None):
    """Creates a multi-process runner.

    Args:
      proc_func: Function to be run on the processes. This will be run on
      proc_func: Function to be run on child processes. This will be run on
        processes for all task types.
      cluster_spec: Dict for cluster spec. The following is an example of
        cluster with three workers and two ps's.
@@ -103,39 +104,19 @@ class MultiProcessRunner(object):
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      proc_flags: Dict that contains the key/values of the flags used on the
        processes.
      timeout: Time out in seconds. If the sub-process takes more than this time
        to complete, raise an error.
      time_to_exit: If set, sub-processes is forced to exit at approximately
        this many seconds after `run()` is called, through `signal.alarm()` api.
        This is for simulation of interruption on a process so in such cases no
        error is raised. Note that this is best effort at Python level since
        Python signal handler does not get executed inside the low-level (C)
        signal handler, so it can be delayed.
      return_std_stream: Boolean, whether the messages streamed to stdout and
        stderr in subprocesses are captured. If True, the messages are stored in
        a list returned as the second element.
      max_run_time: If set, child processes is forced to exit at approximately
        this many seconds after `start` is called. We achieve this through
        `signal.alarm()` api. Note that this is best effort at Python level
        since Python signal handler does not get executed when it runs lower
        level C/C++ code. So it can be delayed for arbitrarily long time.
      capture_std_stream: Boolean, whether the messages streamed to stdout and
        stderr in subprocesses are captured.
      args: Positional arguments to be sent to functions run on processes.
      kwargs: Keyword arguments to be sent to functions run on processes.

    Returns:
      If `return_std_stream` is False, a list that stores the return data added
      by subprocesses through `multi_process_runner._add_return_data(data)`
      call, or through normal function return; if `return_std_stream` is True,
      a two-element tuple of `(return_data_list, std_stream_data_list)`, where
      `return_data_list` stores the return data added by processes through
      `multi_process_runner._add_return_data(data)` call or through normal
      function return, and `std_stream_data_list` stores the messages streamed
      to stdout and stderr in the subprocesses.

    Raises:
      RuntimeError: If any of the subprocesses raise an error, or if any of the
        subprocesses does not return or error out within `timeout` seconds.
      RuntimeError: if `multi_process_runner.test_main()` is not called.
    """

    assert cluster_spec is not None
    assert callable(proc_func)

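For orientation, a minimal sketch of how the new constructor-based, thread/process-like API is meant to be driven from a test; `my_proc_func` is a placeholder, the addresses are the ones from the docstring example, and the test binary is assumed to be started via `multi_process_runner.test_main()`.

def my_proc_func():
  return 'ran'

cluster_spec = {
    'worker': ['worker0.example.com:2222',
               'worker1.example.com:2222',
               'worker2.example.com:2222'],
    'ps': ['ps0.example.com:2222',
           'ps1.example.com:2222'],
}

runner = MultiProcessRunner(my_proc_func, cluster_spec)
runner.start()                  # one child process per address, five in total
return_data, _ = runner.join()  # ['ran', 'ran', 'ran', 'ran', 'ran']
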
@@ -146,192 +127,233 @@ class MultiProcessRunner(object):
                         'in your python module to properly initialize '
                         '`multi_process_runner`.')

    processes = []
    args = args or ()
    kwargs = kwargs or {}
    self._proc_func = proc_func
    self._cluster_spec = cluster_spec
    self._max_run_time = max_run_time
    self._capture_std_stream = capture_std_stream
    self._args = args or ()
    self._kwargs = kwargs or {}
    self._processes = []

    def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                     executing_eagerly, *arg, **kwargs):
      """The wrapper function that actually gets run on the process(es)."""
    # Child processes should have the same v2 and eager behavior.
    self._v2_enabled = tf2.enabled()
    self._executing_eagerly = context.executing_eagerly()

      @contextlib.contextmanager
      def runtime_mode(executing_eagerly):
        if executing_eagerly:
          with context.eager_mode():
            yield
        else:
          with context.graph_mode():
            yield

      with runtime_mode(executing_eagerly):
        os.environ['TF_CONFIG'] = tf_config_as_json
        if proc_flags is not None:
          for flag_key, flag_value in proc_flags.items():
            setattr(flags.FLAGS, flag_key, flag_value)

        stdout_collector = _LogCollector(
            sys.__stdout__) if return_std_stream else None
        stderr_collector = _LogCollector(
            sys.__stderr__) if return_std_stream else None

        def finish_wrapper_func_properly(func_result):
          """Call to finish `wrapper_func` properly."""
          # Clear the alarm.
          signal.alarm(0)
          if (return_std_stream and stdout_collector is not None and
              stderr_collector is not None):
            # If stdout and stderr are to be collected, add them to std stream
            # queue.
            self._add_std_stream_data_flattened(stdout_collector.log)
            self._add_std_stream_data_flattened(stderr_collector.log)
            # Un-redirect stdout and stderr.
            sys.stdout = sys.__stdout__
            sys.stderr = sys.__stderr__
          self._get_internal_queue().put(func_result)

        if time_to_exit is not None:

          def handler(signum, frame):
            del signum, frame
            finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
            # pylint: disable=protected-access
            os._exit(0)

          signal.signal(signal.SIGALRM, handler)
          signal.alarm(time_to_exit)

        if return_std_stream:
          sys.stdout = stdout_collector
          sys.stderr = stderr_collector

        try:
          return_data = proc_func(*arg, **kwargs)
          if return_data is not None:
            self._add_return_data(return_data)
        # pylint: disable=broad-except
        except Exception:
          # Capture all exceptions to be reported to parent process.
          finish_wrapper_func_properly(_ExcInfoWrapper(sys.exc_info()))

          # Re-raise the exception in addition to reporting it to the parent
          # process, so that even if `--test_timeout` flag is set and the
          # error doesn't make it to be shown in parent process before bazel's
          # timeout, the log would still show what happens in this subprocess,
          # instead of silently suppressing the error due to early bazel
          # timeout. Raising an error in the subprocess produces stack trace in
          # the log, but the program continues running.
          raise

        finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)

    # Start number of processes according to `count_dict`.
    for job_type, addresses in cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        tf_config_as_json = json.dumps({
            'cluster': cluster_spec,
            'task': {
                'type': job_type,
                'index': task_id
            }
        })
        p = multi_process_lib.Process(
            target=wrapper_func,
            args=(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                  context.executing_eagerly()) + args,
            kwargs=kwargs)
        p.start()
        processes.append(p)

    internal_queue_results = []
    for _ in range(len(processes)):
      try:
        internal_queue_results.append(
            self._get_internal_queue().get(timeout=timeout))
      except Queue.Empty:
        # First check if any of the subprocesses raised exception.
        for internal_queue_result in internal_queue_results:
          if isinstance(internal_queue_result, _ExcInfoWrapper):
            six.reraise(*internal_queue_result.exc_info)
        # If none of those did, report time out to user.
        raise RuntimeError(
            'One or more subprocesses timed out. Please use '
            '`--test_arg=--logtostderr` bazel flag to inspect logs for '
            'subprocess debugging info. Timeout = {} sec.'.format(timeout))

    for internal_queue_result in internal_queue_results:
      if isinstance(internal_queue_result, _ExcInfoWrapper):
        six.reraise(*internal_queue_result.exc_info)
      assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

    def queue_to_list(queue_to_convert):
      """Convert `queue.Queue` to `list`."""
      list_to_return = []
      while True:
        try:
          list_to_return.append(queue_to_convert.get(block=False))
        except Queue.Empty:
          break
      return list_to_return

    if return_std_stream:
      return tuple(
          queue_to_list(multi_process_lib.get_user_data()[queue_name])
          for queue_name in
          [_AvailableQueues.PUBLIC_QUEUE, _AvailableQueues.STD_STREAM_QUEUE])
  @contextlib.contextmanager
  def _runtime_mode(self):
    if self._executing_eagerly:
      with context.eager_mode():
        yield
    else:
    return queue_to_list(
        multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE])
      with context.graph_mode():
        yield

  def _add_return_data(self, data):
    """Add return data that will be returned by `multi_process_runner.run()`.
  def _finish_process(self, func_status, return_value, stdout_collector,
                      stderr_collector):
    """Adds data to queues before program exits."""
    # Clear the alarm.
    signal.alarm(0)
    if return_value is not None:
      self._add_return_data(return_value)
    if self._capture_std_stream:
      # If stdout and stderr are to be collected, add them to std stream
      # queue.
      self._add_std_stream_data_flattened(stdout_collector.log)
      self._add_std_stream_data_flattened(stderr_collector.log)
    self._get_process_status_queue().put(func_status)

    The function provides a way for processes started by
    `multi_process_runner.run()` to communicate with the original process
    that started the sub-processes. Data passed to `_add_return_data` will
    be available in a python Queue.Queue that is eventually returned by
    `multi_process_runner.run()`.
  def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs):
    """The wrapper function that actually gets run in child process(es)."""
    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': self._cluster_spec,
        'task': {
            'type': task_type,
            'index': task_id,
        }
    })

    if self._capture_std_stream:
      # TODO(yuefengz): consider a lighter way of capturing std streams.
      stdout_collector = _LogCollector(sys.__stdout__)
      stderr_collector = _LogCollector(sys.__stderr__)
      sys.stdout = stdout_collector
      sys.stderr = stderr_collector
    else:
      stdout_collector = None
      stderr_collector = None

    if self._v2_enabled:
      v2_compat.enable_v2_behavior()

    return_value = None

    if self._max_run_time is not None:
      # Register an sigalarm handler to exit the process when it reaches
      # `timeout` seconds. A program reaching `timeout` doesn't necessarily
      # indicate an issue.
      def handler(signum, frame):
        del signum, frame
        self._finish_process(_FINISH_PROPERLY_MESSAGE, None, stdout_collector,
                             stderr_collector)
        os._exit(0)  # pylint: disable=protected-access

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(self._max_run_time)

    try:
      with self._runtime_mode():
        return_value = self._proc_func(*arg, **kwargs)
    except Exception:  # pylint: disable=broad-except
      # Capture all exceptions to be reported to parent process.
      self._finish_process(
          _ExcInfoWrapper(sys.exc_info()), return_value, stdout_collector,
          stderr_collector)

      # Re-raise the exception in addition to reporting it to the parent
      # process, so that even if `--test_timeout` flag is set and the
      # error doesn't make it to be shown in parent process before bazel's
      # timeout, the log would still show what happens in this subprocess,
      # instead of silently suppressing the error due to early bazel
      # timeout. Raising an error in the subprocess produces stack trace in
      # the log, but the program continues running.
      raise

    self._finish_process(_FINISH_PROPERLY_MESSAGE, return_value,
                         stdout_collector, stderr_collector)

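For readers unfamiliar with the `signal.alarm()` pattern used above, a standalone stdlib sketch of the best-effort exit follows; the handler only runs once the interpreter is back executing Python bytecode, so long-running C/C++ calls can delay it, which is why the docstring calls this best effort.

import os
import signal
import time

def _on_alarm(signum, frame):
  # Runs in the main thread once Python regains control; the exit time is
  # therefore approximate, not exact.
  del signum, frame
  print('alarm fired, exiting')
  os._exit(0)

signal.signal(signal.SIGALRM, _on_alarm)
signal.alarm(2)   # request SIGALRM in roughly 2 seconds
time.sleep(10)    # interrupted early by the alarm handler
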
  def start(self):
    """Starts processes, one for each task in `cluster_spec`."""
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        p = multi_process_lib.Process(
            target=self._proc_func_wrapper,
            args=(task_type, task_id) + self._args,
            kwargs=self._kwargs)
        p.start()
        self._processes.append(p)

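Because `_proc_func_wrapper` exports the task identity through the `TF_CONFIG` environment variable before calling `proc_func`, a process function can recover its own task type and index from the environment. A hedged sketch, with an illustrative helper name:

import json
import os

def proc_func_that_reports_task():
  # TF_CONFIG is set by the runner's wrapper before proc_func is invoked.
  tf_config = json.loads(os.environ['TF_CONFIG'])
  task_type = tf_config['task']['type']   # e.g. 'worker' or 'ps'
  task_id = tf_config['task']['index']    # e.g. 0, 1, ...
  return '{}-{}'.format(task_type, task_id)
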
  def _queue_to_list(self, queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    # Calling `queue.empty()` is not reliable.
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  def join(self, timeout=None):
    """Joins all the processes with timeout.

    Args:
      data: data to be made available in the queue returned by
        `multi_process_runner.run()`.
      timeout: if set and not all processes report status within roughly
        `timeout` seconds, a `RuntimeError` exception will be thrown.

    Returns:
      It returns a tuple. The first element is a list that stores the return
      data added by subprocesses through `_add_return_data` or through normal
      function return; The second element is a list of the messages streamed to
      stdout and stderr in the subprocesses if `capture_std_stream` is True or
      `None` otherwise.

    Raises:
      RuntimeError: if not all processes report status within `timeout` seconds.
        Or the exception propagated from any child process.
    """
    if not timeout:
      if self._max_run_time:
        timeout = self._max_run_time + 10  # add 10 seconds grace period
      else:
        timeout = float('inf')
    num_returned = 0
    start_time = time.time()
    while num_returned < len(self._processes):
      while True:
        try:
          process_status = self._get_process_status_queue().get(timeout=10)
          break
        except Queue.Empty:
          if time.time() - start_time > timeout:
            # If none of those did, report timeout to user.
            raise RuntimeError(
                'One or more subprocesses timed out. Please use '
                '`--test_arg=--logtostderr` bazel flag to inspect logs for '
                'subprocess debugging info. Number of returned processes is '
                '%d.' % num_returned)

      num_returned += 1
      if isinstance(process_status, _ExcInfoWrapper):
        six.reraise(*process_status.exc_info)
      assert process_status == _FINISH_PROPERLY_MESSAGE

    self._processes = []

    if self._capture_std_stream:
      # TODO(yuefengz): we need to make sure elements match the same process in
      # the two returned lists so as to not surprise users. Consider creating a
      # `ReturnData` class.
      return tuple(
          self._queue_to_list(multi_process_lib.get_user_data()[queue_name])
          for queue_name in [RETURN_VALUE_QUEUE, STD_STREAM_QUEUE])
    else:
      return (self._queue_to_list(
          multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE]), None)

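A usage sketch of the `join` contract with `capture_std_stream=True`; `my_proc_func` is a placeholder and the test binary is assumed to run under `test_main`.

def my_proc_func():
  print('hello from a task')  # captured into the std stream queue
  return 'done'               # captured into the return value queue

runner = MultiProcessRunner(
    my_proc_func,
    multi_worker_test_base.create_cluster_spec(num_workers=2),
    capture_std_stream=True)
runner.start()
return_data, std_stream_data = runner.join()
# return_data holds two 'done' entries; std_stream_data holds the printed
# lines from both workers (ordering across processes is not guaranteed).
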
  def _add_return_data(self, data):
    """Adds return data that will be returned by `join`.

    The function provides a way for child processes to communicate with the
    parent process. Data passed to `_add_return_data` will be available in a
    Python Queue.Queue that is eventually returned by `join`.

    Args:
      data: data to be made available in the queue returned by `join`.
    """
    # TODO(rchao): Incorporate the task type and id information in a data
    # wrapper that becomes what is stored in the queue so we can tell where
    # the data is from.
    multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE].put(data)
    multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE].put(data)

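Besides returning a value from `proc_func`, a child can report intermediate data explicitly through `_add_return_data`, as `test_process_that_exits` in the updated test file does. A sketch of that pattern; the values and timings are illustrative only.

import time

def func_that_reports_early():
  mpr._add_return_data('foo')  # recorded; visible to join() later
  time.sleep(20)               # process is force-exited by max_run_time
  mpr._add_return_data('bar')  # never reached

mpr = MultiProcessRunner(
    func_that_reports_early,
    multi_worker_test_base.create_cluster_spec(num_workers=1),
    max_run_time=10)
mpr.start()
returned_data, _ = mpr.join()  # returned_data == ['foo']
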
  def _add_std_stream_data_flattened(self, data):
    std_stream_queue = multi_process_lib.get_user_data()[
        _AvailableQueues.STD_STREAM_QUEUE]
    # TODO(yuefengz): currently the same queue is used by multiple processes. It
    # is difficult for users to distinguish between logs from different
    # processes.
    std_stream_queue = multi_process_lib.get_user_data()[STD_STREAM_QUEUE]
    for d in list(data):
      std_stream_queue.put(d)

  def _get_internal_queue(self):
    return multi_process_lib.get_user_data()[_AvailableQueues.INTERNAL_QUEUE]
  def _get_process_status_queue(self):
    return multi_process_lib.get_user_data()[PROCESS_STATUS_QUEUE]


def run(proc_func,
        cluster_spec,
        max_run_time=None,
        capture_std_stream=False,
        args=None,
        kwargs=None):  # pylint: disable=g-doc-args
  """Runs functions in local child processes.

  It is a convenience method that creates a `MultiProcessRunner` object and
  invokes `start` and `join` method. Please see these methods for detailed
  documentations.

  Returns:
    A tuple returned from `MultiProcessRunner.join()`.
  """
  runner = MultiProcessRunner(
      proc_func,
      cluster_spec,
      max_run_time=max_run_time,
      capture_std_stream=capture_std_stream,
      args=args,
      kwargs=kwargs)
  runner.start()
  return runner.join()

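A minimal call-site sketch for the convenience wrapper; it is equivalent to constructing a `MultiProcessRunner` and calling `start()` then `join()` yourself, and the sample function below is a placeholder.

def fn(prefix, suffix='!'):
  return prefix + 'done' + suffix

return_data, _ = multi_process_runner.run(
    fn,
    multi_worker_test_base.create_cluster_spec(num_workers=2),
    args=('task says: ',),
    kwargs={'suffix': '.'})
# return_data has one entry per task, e.g. 'task says: done.' twice.
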
def test_main():
  """Main function to be called within `__main__` of a test file."""
  with multi_process_lib.context_manager():
    test.main()

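As the class docstring notes, test binaries that use the runner must hand `__main__` over to `test_main` so that `multi_process_lib` can set up its shared state first; the usual guard looks like:

if __name__ == '__main__':
  multi_process_runner.test_main()
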
def job_count_to_cluster_spec(job_count_dict):
  """Convert a job count dict to cluster spec.

  Args:
    job_count_dict: Dict for task_type/count of such task type.
      {'worker': 1, 'ps': 1} is an example of a cluster with a worker and a
      ps.

  Returns:
    The converted cluster spec dict.
  """

  cluster_spec = {}
  for task_type, count in job_count_dict.items():
    cluster_spec[task_type] = [
        'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
        for _ in range(count)
    ]
  return cluster_spec

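For illustration, the conversion expands each count into that many `localhost` addresses; the ports shown below are made up, since `pick_unused_port()` chooses free ports at runtime.

job_count_to_cluster_spec({'worker': 2, 'ps': 1})
# A possible result (ports vary per run):
# {'worker': ['localhost:20829', 'localhost:23461'],
#  'ps': ['localhost:21117']}
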
@@ -19,23 +19,22 @@ from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test


class MultiProcessRunnerNoInitTest(test.TestCase):

  def test_stdout_captured(self):
  def test_not_calling_correct_main(self):

    def simple_func():
      return 'foobar'

    job_count_dict = {'worker': 1}
    with self.assertRaisesRegexp(RuntimeError,
                                 '`multi_process_runner` is not initialized.'):
      MultiProcessRunner().run(
      multi_process_runner.run(
          simple_func,
          multi_process_runner.job_count_to_cluster_spec(job_count_dict))
          multi_worker_test_base.create_cluster_spec(num_workers=1))


if __name__ == '__main__':

@@ -20,19 +20,15 @@ from __future__ import print_function

import time

from absl import flags
from six.moves import queue as Queue

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.multi_process_runner import MultiProcessRunner
from tensorflow.python.eager import test

flags.DEFINE_boolean(name='test_flag', default=0, help='Test flag')


def proc_func_that_adds_task_type_in_return_data(test_obj):
  test_obj.assertTrue(flags.FLAGS.test_flag == 3)
def proc_func_that_adds_task_type_in_return_data(test_obj, val):
  test_obj.assertEqual(val, 3)
  return multi_worker_test_base.get_task_type()

@@ -55,16 +51,13 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs):
class MultiProcessRunnerTest(test.TestCase):

  def test_multi_process_runner(self):
    job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 2}
    proc_flags = {
        'test_flag': 3,
    }
    returned_data = MultiProcessRunner().run(
    returned_data, _ = multi_process_runner.run(
        proc_func_that_adds_task_type_in_return_data,
        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
        proc_flags=proc_flags,
        args=(self,))
        multi_worker_test_base.create_cluster_spec(
            num_workers=2, num_ps=3, has_eval=1),
        args=(self, 3))

    job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
    for data in returned_data:
      job_count_dict[data] -= 1

@@ -73,31 +66,29 @@ class MultiProcessRunnerTest(test.TestCase):
    self.assertEqual(job_count_dict['evaluator'], 0)

  def test_multi_process_runner_error_propagates_from_subprocesses(self):
    job_count_dict = {'worker': 1, 'ps': 1}
    runner = multi_process_runner.MultiProcessRunner(
        proc_func_that_errors,
        multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
        max_run_time=20)
    runner.start()
    with self.assertRaisesRegexp(ValueError, 'This is an error.'):
      MultiProcessRunner().run(
          proc_func_that_errors,
          multi_process_runner.job_count_to_cluster_spec(job_count_dict),
          timeout=20)
      runner.join()

  def test_multi_process_runner_queue_emptied_between_runs(self):
    job_count_dict = {'worker': 2}
    cluster_spec = multi_process_runner.job_count_to_cluster_spec(
        job_count_dict)
    returned_data = MultiProcessRunner().run(
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    returned_data, _ = multi_process_runner.run(
        proc_func_that_adds_simple_return_data, cluster_spec)
    self.assertTrue(returned_data)
    self.assertEqual(returned_data[0], 'dummy_data')
    self.assertEqual(returned_data[1], 'dummy_data')
    returned_data = MultiProcessRunner().run(proc_func_that_does_nothing,
                                             cluster_spec)
    returned_data, _ = multi_process_runner.run(proc_func_that_does_nothing,
                                                cluster_spec)
    self.assertFalse(returned_data)

  def test_multi_process_runner_args_passed_correctly(self):
    job_count_dict = {'worker': 1}
    returned_data = MultiProcessRunner().run(
    returned_data, _ = multi_process_runner.run(
        proc_func_that_return_args_and_kwargs,
        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        args=('a', 'b'),
        kwargs={'c_k': 'c_v'})
    self.assertEqual(returned_data[0][0], 'a')

@@ -110,11 +101,10 @@ class MultiProcessRunnerTest(test.TestCase):
      print('This is something printed.')
      return 'This is returned data.'

    job_count_dict = {'worker': 2}
    returned_data, std_stream_data = MultiProcessRunner().run(
    returned_data, std_stream_data = multi_process_runner.run(
        simple_print_func,
        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
        return_std_stream=True)
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        capture_std_stream=True)
    num_string_std_stream = len(
        [d for d in std_stream_data if d == 'This is something printed.'])
    num_string_returned_data = len(

@@ -123,34 +113,32 @@ class MultiProcessRunnerTest(test.TestCase):
    self.assertEqual(num_string_returned_data, 2)

  def test_process_that_exits(self):

    mpr = MultiProcessRunner()

    def func_to_exit_in_10_sec():
      time.sleep(5)
      mpr._add_return_data('foo')
      time.sleep(20)
      mpr._add_return_data('bar')

    job_count_dict = {'worker': 1}
    returned_data = mpr.run(
    mpr = multi_process_runner.MultiProcessRunner(
        func_to_exit_in_10_sec,
        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
        time_to_exit=10)
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        max_run_time=10)

    mpr.start()
    returned_data, _ = mpr.join()
    self.assertLen(returned_data, 1)

  def test_signal_doesnt_fire_after_process_exits(self):
    job_count_dict = {'worker': 1}
    mpr = MultiProcessRunner()
    mpr.run(
    mpr = multi_process_runner.MultiProcessRunner(
        proc_func_that_does_nothing,
        multi_process_runner.job_count_to_cluster_spec(job_count_dict),
        time_to_exit=10)
    time.sleep(15)
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        max_run_time=10)
    mpr.start()
    mpr.join()
    with self.assertRaisesRegexp(Queue.Empty, ''):
      # If the signal was fired, another message would be added to internal
      # queue, so verifying it's empty.
      mpr._get_internal_queue().get(block=False)
      mpr._get_process_status_queue().get(block=False)


if __name__ == '__main__':

@@ -78,7 +78,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):

    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
    with multi_process_runner_util.try_run_and_except_connection_error(self):
      multi_process_runner.MultiProcessRunner().run(
      multi_process_runner.run(
          worker_fn,
          cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))