Improve multi_process_runner

This is to prepare for enabling it in OSS.

PiperOrigin-RevId: 315878874
Change-Id: Ib29bccf3c964462a7643df4b1cd011ddda79372b
Ran Chen 2020-06-11 05:15:16 -07:00 committed by TensorFlower Gardener
parent 51f2a966cc
commit c674577870
12 changed files with 403 additions and 321 deletions
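For orientation, here is a hedged usage sketch assembled from the test changes in this diff (it is not part of the commit itself, and it only runs where `multi_process_lib` is fully implemented, since the OSS stub still raises SkipTest):

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test


def proc_func():
  # Runs once inside every simulated task (chief/worker/ps) subprocess.
  return multi_worker_test_base.get_task_type()


class ExampleTest(test.TestCase):

  def test_run(self):
    result = multi_process_runner.run(
        proc_func,
        multi_worker_test_base.create_cluster_spec(num_workers=2))
    # join() now returns one entry per subprocess return value.
    self.assertCountEqual(result.return_value, ['worker', 'worker'])


if __name__ == '__main__':
  # test_main() replaces test.main() so subprocesses are initialized properly.
  multi_process_runner.test_main()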


@@ -36,6 +36,7 @@ tensorflow/third_party/coremltools.BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/cython.BUILD
+tensorflow/third_party/dill.BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/eigen3/BUILD
@@ -196,6 +197,7 @@ tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
+tensorflow/third_party/tblib.BUILD
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/LICENSE


@@ -1708,19 +1708,24 @@ cuda_py_test(
py_library(
    name = "multi_process_runner",
    srcs = ["multi_process_runner.py"],
+    srcs_version = "PY3",
    deps = [
        ":multi_process_lib",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:tf2",
        "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/eager:context",
+        "@absl_py//absl/logging",
+        "@dill_archive//:dill",
        "@six_archive//:six",
+        "@tblib_archive//:tblib",
    ],
)

py_library(
    name = "multi_process_lib",
    srcs = ["multi_process_lib.py"],
-    deps = ["@six_archive//:six"],
+    deps = ["//tensorflow/python:client_testlib"],
)

py_test(
@@ -1745,11 +1750,12 @@ py_test(
    name = "multi_process_runner_test",
    srcs = ["multi_process_runner_test.py"],
    python_version = "PY3",
+    shard_count = 12,
    deps = [
        ":multi_process_runner",
        ":multi_worker_test_base",
        "//tensorflow/python/eager:test",
+        "@absl_py//absl/logging",
+        "@six_archive//:six",
    ],
)
@@ -1757,6 +1763,7 @@ py_test(
    name = "multi_process_runner_no_init_test",
    srcs = ["multi_process_runner_no_init_test.py"],
    python_version = "PY3",
+    tags = ["no_oss"],
    deps = [
        ":multi_process_runner",
        ":multi_worker_test_base",


@@ -18,9 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-import contextlib
+import multiprocessing as _multiprocessing
import unittest

+from tensorflow.python.platform import test
+
+try:
+  multiprocessing = _multiprocessing.get_context('forkserver')
+except ValueError:
+  # forkserver is not available on Windows.
+  multiprocessing = _multiprocessing.get_context('spawn')
+

class Process(object):
  """A process simulating a worker for testing multi-worker training."""
@@ -28,23 +37,14 @@ class Process(object):
  def __init__(self, *args, **kwargs):
    del args, kwargs
    raise unittest.SkipTest(
-        'TODO(b/141874796): Implement OSS version of `multi_process_lib`')
+        'TODO(b/150264776): Implement OSS version of `multi_process_lib`')


-def get_user_data():
-  """Returns the data commonly shared by parent process and subprocesses."""
-  # TODO(b/141874796): Implement OSS version of `multi_process_lib`.
-  pass
+def test_main():
+  """Main function to be called within `__main__` of a test file."""
+  test.main()


-@contextlib.contextmanager
-def context_manager(max_subprocess_count=20, barrier_parties=0):
-  """No-op in OSS. This exists to maintain testing compatibility."""
-  del max_subprocess_count, barrier_parties
-  yield
-
-
-def using_context_manager():
-  """Whether the context manager is being used."""
-  raise unittest.SkipTest(
-      'TODO(b/141874796): Implement OSS version of `multi_process_lib`')
+def initialized():
+  """Returns whether the module is initialized."""
+  return True
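A brief illustrative sketch (mine, not part of the commit) of how the start-method context chosen above is consumed by multi_process_runner later in this diff; every cross-process primitive is created from this one context so forkserver/spawn semantics stay consistent:

# `multiprocessing` below is the forkserver (or spawn) context selected above,
# re-exported by multi_process_lib for the runner to use.
manager = multiprocessing.Manager()
process_status_queue = manager.Queue()  # subprocess -> parent status reports
parent_to_sub_queue = manager.Queue()   # parent -> subprocess control messages
barrier = manager.Barrier(2)            # one party per task in the cluster spec
pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)  # streams stdout/stderr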


@@ -18,6 +18,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import json
@@ -26,6 +27,8 @@ import signal
import sys
import threading
import time
+import unittest

from absl import logging
import six
from six.moves import queue as Queue
@@ -34,7 +37,8 @@ from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import multi_process_lib
from tensorflow.python.eager import context
-from tensorflow.python.platform import test
+
+multiprocessing = multi_process_lib.multiprocessing

# pylint: disable=g-import-not-at-top
try:
@@ -43,61 +47,57 @@ try:
except ImportError:
  faulthandler = None

+# TODO(b/150264776): Remove after resolving CI issue.
+try:
+  import dill
+except ImportError:
+  dill = None
+
+# TODO(b/150264776): Remove after resolving CI issue.
+try:
+  import tblib.pickling_support
+  # For pickling traceback objects.
+  tblib.pickling_support.install()
+except ImportError:
+  pass
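To make the role of these optional imports concrete, here is a hedged, standalone sketch of my own (not from the commit): once tblib's pickling support is installed, a traceback captured in a worker can be pickled onto a queue and re-raised in the parent with six.reraise, which is what join() does later in this file.

import pickle
import sys

import six
import tblib.pickling_support

tblib.pickling_support.install()  # registers picklers for traceback objects


def run_in_worker():
  try:
    raise ValueError('boom in the worker')
  except ValueError:
    # Without tblib, pickling the traceback object in exc_info would fail.
    return pickle.dumps(sys.exc_info())


payload = run_in_worker()        # pretend this crossed a multiprocessing queue
exc_info = pickle.loads(payload)
try:
  six.reraise(*exc_info)         # parent re-raises with the worker's traceback
except ValueError as e:
  print('re-raised in parent:', e)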
# _ProcessStatusInfo contains process status information. When is_successful
# attribute is True, the subprocess has ended successfully, or if False, the
# exception stack trace info is stored in exc_info to pass on to parent process
# to be re-raised.
_ProcessStatusInfo = collections.namedtuple(
-    '_ProcessStatusInfo', ['task_type', 'is_successful', 'exc_info'])
+    '_ProcessStatusInfo',
+    ['task_type', 'is_successful', 'exc_info', 'return_value'])

-# _SubprocessInfo collects basic information of a subprocess such as task type
-# and process id.
-_SubprocessInfo = collections.namedtuple('_SubprocessInfo',
-                                         ['pid', 'task_type', 'task_id'])

# Information returned from a successful MultiProcessRunner run.
MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
                                                  ['return_value', 'stdout'])

+TestEnvironment = collections.namedtuple('TestEnvironment', [
+    'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast',
+    'v2_enabled', 'executing_eagerly'
+])
+
+# Resources for communication between worker processes and the main process.
+#
-# Process status queue is used by `multi_process_runner` internally for
+# `process_status_queue` is used by `multi_process_runner` internally for
# communication from subprocesses to the parent process for whether it's been
# successful, and if not what the error stack trace is.
-PROCESS_STATUS_QUEUE = 'process_status_queue'
-
-# Return value queue is intended to be used by users of `multi_process_runner`
-# for the process function to return information to the caller of
-# `multi_process_runner.run()`.
-RETURN_VALUE_QUEUE = 'return_value_queue'
-
-# Subprocess info queue stores `_SubprocessInfo` for later potential
-# termination by the parent.
-SUBPROCESS_INFO_QUEUE = 'subprocess_info_queue'
-
-# Parent-to-sub queue is used for communications from parent to subprocess.
+# `parent_to_sub_queue` is used for communications from parent to subprocess.
# Currently this is only used to terminate subprocesses.
# TODO(rchao): Remove this once subprocess is terminated by SIGKILL.
-PARENT_TO_SUB_QUEUE = 'parent_to_sub_queue'
-
-# Streaming queue stores the logged and printed messages from subprocesses.
-STREAMING_QUEUE = 'streaming_queue'
-
-# Pipes to stream stdout and stderr from subprocesses to parent process.
-STREAMING_PIPE = 'streaming_pipe'
-
-# Barrier identifier.
-BARRIER = 'barrier'
-
-_DEFAULT_MAX_SUBPROCESS_COUNT = 20
+# `streaming_pipe_w` is to stream stdout and stderr from subprocesses to parent
+# process.
+# `barrier` is a barrier for the party of all subprocesses.
+Resources = collections.namedtuple('Resources', [
+    'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier'
+])

# Default time out sec is selected so that it's handled before the default
# "medium" timeout of the test runs.
_DEFAULT_TIMEOUT_SEC = 200

-# Next pipe index to be global so that pipes are not reused across multiple
-# MultiProcessRunner usages.
-# TODO(rchao): Investigate possibility to remove this variable.
-_next_pipe_index = 0


class MultiProcessRunner(object):
  """A utility class to start multiple processes to simulate a cluster.
@@ -123,6 +123,7 @@ class MultiProcessRunner(object):
               grpc_fail_fast=None,
               stream_stdout=True,
               list_stdout=False,
+               use_dill_for_args=True,
               args=None,
               kwargs=None):
    """Creates a multi-process runner.
@@ -153,6 +154,9 @@
        returned from `MultiProcessRunner.join()`. If True, the list of stdout
        can be retrieved via `MultiProcessRunnerResult.stdout` attribute.
        Defaults to False.
+      use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill
+        can pickle more objects, but doesn't work with types in
+        `multiprocessing` library like `Mutex`.
      args: Positional arguments to be sent to functions run on processes.
      kwargs: Keyword arguments to be sent to functions run on processes.
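As a hedged illustration of why dill is used for proc_func (and optionally for args/kwargs): the standard pickle module cannot serialize a lambda, while dill can. This example is mine, not part of the commit.

import pickle

import dill

fn = lambda x: x * 2

try:
  pickle.dumps(fn)
except Exception as e:  # pickle.PicklingError on most Python versions
  print('pickle failed:', type(e).__name__)

payload = dill.dumps(fn, dill.HIGHEST_PROTOCOL)
restored = dill.loads(payload)
print(restored(21))  # 42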
@@ -165,15 +169,14 @@
      raise ValueError('If chief exists in the cluster, there must be at most '
                       'one chief. Current `cluster_spec` has {} chiefs.'
                       .format(len(cluster_spec['chief'])))

-    assert callable(proc_func)
-
-    if not multi_process_lib.using_context_manager():
+    if not multi_process_lib.initialized():
      raise RuntimeError('`multi_process_runner` is not initialized. '
                         'Please call `multi_process_runner.test_main()` '
                         'within `if __name__ == \'__main__\':` block '
                         'in your python module to properly initialize '
                         '`multi_process_runner`.')
+    if not callable(proc_func):
+      raise ValueError('proc_func is not a callable')

    self._proc_func = proc_func
    self._cluster_spec = cluster_spec
@@ -184,62 +187,90 @@
    # TODO(rchao): Revisit list_stdout argument to consider other solution.
    self._list_stdout = list_stdout
    self._dependence_on_chief = True
+    self._use_dill_for_args = use_dill_for_args

    self._args = args or ()
    self._kwargs = kwargs or {}

-    self._outstanding_subprocess_count = 0
-
    # Child processes should have the same v2 and eager behavior.
    self._v2_enabled = tf2.enabled()
    self._executing_eagerly = context.executing_eagerly()

+    self._joined = False
+    self._processes = {}
+    self._outstanding_subprocess_count = 0
+    self._reading_threads = []
+
+    self._manager = multiprocessing.Manager()
+    self._process_status_queue = self._manager.Queue()
+    self._parent_to_sub_queue = self._manager.Queue()
+    parties = sum(len(addresses) for addresses in self._cluster_spec.values())
+    self._barrier = self._manager.Barrier(parties)
+
+    # We use a queue to collect outputs from worker processes since it's thread
+    # safe.
+    self._streaming_queue = self._manager.Queue()
+
    # This flag will be set to True once terminate_all() is called.
    self._all_forced_terminated = False

  def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
    """Function to continuously read lines from subprocesses."""
-    reader = os.fdopen(pipe_r.fileno(), 'r')
-    while True:
-      read_line = reader.readline()
-      if read_line == 'EOF':
-        reader.close()
-        # The thread that runs `_continuously_readline_from_sub` stops here.
-        # However the threads don't exit until the test exits, so we do not
-        # attempt to join the threads (which leads to timeout).
-        # TODO(rchao): Understand why and do thread joining.
-        break
-      task_string = '[{}-{}]:'.format(task_type, task_id)
-      formatted_line = '{} {}'.format(task_string.ljust(14), read_line)
-      if self._stream_stdout:
-        self._print_stdout_in_parent(formatted_line, task_type, task_id)
-      if self._list_stdout:
-        self._add_stdout_in_queue(formatted_line, task_type, task_id)
-
-  def _print_stdout_in_parent(self, formatted_line, task_type, task_id):
-    del task_type, task_id
-    # Flush True so the logging order from subprocesses is respected.
-    # TODO(rchao): Use a lock here to ensure the printed lines are not broken.
-    print(formatted_line, end='', flush=True)
-
-  def _add_stdout_in_queue(self, formatted_line, task_type, task_id):
-    del task_type, task_id
-    # A queue instead of a simple list is used here due to b/150652733.
-    _resource(STREAMING_QUEUE).put(formatted_line)
-
-  def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id,
-                                           cluster_spec, args, kwargs):
+    with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
+      for line in reader:
+        task_string = '[{}-{}]:'.format(task_type, task_id)
+        formatted_line = '{} {}'.format(task_string.ljust(14), line)
+        if self._stream_stdout:
+          # TODO(rchao): Use a lock here to ensure the printed lines are not
+          # broken.
+          print(formatted_line, end='', flush=True)
+        if self._list_stdout:
+          self._streaming_queue.put(formatted_line)
+
+  def _start_subprocess_and_reading_thread(self,
+                                           task_type,
+                                           task_id,
+                                           cluster_spec=None,
+                                           proc_func=None,
+                                           args=None,
+                                           kwargs=None):
    """Start a subprocess and a thread the reads lines from the subprocess."""
-    global _next_pipe_index
-    pipe_r, pipe_w = _resource(STREAMING_PIPE)[_next_pipe_index]
-    _next_pipe_index += 1
-
-    p = multi_process_lib.Process(
-        target=_Subprocess(),
-        args=(proc_func, task_type, task_id, cluster_spec, self._rpc_layer,
-              self._grpc_fail_fast, self._v2_enabled, self._executing_eagerly,
-              pipe_w) + args,
-        kwargs=kwargs)
+    if dill is None:
+      raise unittest.SkipTest(
+          'TODO(b/150264776): Resolve dependency issue in CI')
+
+    test_env = TestEnvironment(
+        task_type=task_type,
+        task_id=task_id,
+        cluster_spec=cluster_spec or self._cluster_spec,
+        rpc_layer=self._rpc_layer,
+        grpc_fail_fast=self._grpc_fail_fast,
+        v2_enabled=self._v2_enabled,
+        executing_eagerly=self._executing_eagerly,
+    )
+    pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
+    resources = Resources(
+        process_status_queue=self._process_status_queue,
+        parent_to_sub_queue=self._parent_to_sub_queue,
+        streaming_pipe_w=pipe_w,
+        barrier=self._barrier,
+    )
+    if proc_func is None:
+      proc_func, args, kwargs = self._proc_func, self._args, self._kwargs
+    # Always use dill to pickle proc_func so that we support more callable
+    # types, e.g. lambda.
+    proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL)
+    if self._use_dill_for_args:
+      args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
+      kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)
+
+    p = _Process(
+        test_env=test_env,
+        target=_ProcFunc(),
+        args=(resources, test_env, proc_func, args, kwargs,
+              self._use_dill_for_args))
    p.start()
+    self._processes[(task_type, task_id)] = p
    self._outstanding_subprocess_count += 1

    # For each subprocess, we dedicate a thread continuously reading lines
@@ -248,18 +279,15 @@ class MultiProcessRunner(object):
        target=self._continuously_readline_from_sub,
        args=(pipe_r, task_type, task_id))
    thread.start()
+    self._reading_threads.append(thread)

  def start(self):
    """Starts processes, one for each task in `cluster_spec`."""
-
-    global _next_pipe_index
-    self._starting_pipe_index = _next_pipe_index
-
+    if self._processes:
+      raise ValueError('MultiProcessRunner already started.')
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
-        self._start_subprocess_and_reading_thread(self._proc_func, task_type,
-                                                  task_id, self._cluster_spec,
-                                                  self._args, self._kwargs)
+        self._start_subprocess_and_reading_thread(task_type, task_id)

    # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
    # without this the tests become very flaky.
@@ -309,33 +337,22 @@ class MultiProcessRunner(object):
      as_task_type: The task type to be run in the main process.
      as_task_id: The task id to be run in the main process.
    """
-    global _next_pipe_index
-    self._starting_pipe_index = _next_pipe_index
-
+    if self._processes:
+      raise ValueError('MultiProcessRunner already started.')
    for task_type, addresses in self._cluster_spec.items():
      for task_id, _ in enumerate(addresses):
        if not (task_type == as_task_type and task_id == as_task_id):
-          self._start_subprocess_and_reading_thread(self._proc_func, task_type,
-                                                    task_id, self._cluster_spec,
-                                                    self._args, self._kwargs)
+          self._start_subprocess_and_reading_thread(task_type, task_id)

-    tf_config_dict = {
-        'cluster': self._cluster_spec,
-        'task': {
-            'type': as_task_type,
-            'index': as_task_id,
-        },
-    }
-    if self._rpc_layer is not None:
-      tf_config_dict['rpc_layer'] = self._rpc_layer
-    os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
-
+    _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
+                   self._rpc_layer)
    self._proc_func(*self._args, **self._kwargs)

  def start_single_process(self,
                           task_type,
                           task_id,
-                           proc_func=None,
                           cluster_spec=None,
+                           proc_func=None,
                           args=None,
                           kwargs=None):
    """Starts a single process.
@@ -352,19 +369,22 @@
    Args:
      task_type: The task type.
      task_id: The task id.
-      proc_func: The process function to be run on the newly started
-        process. If `None`, the function provided at `__init__` will be used.
      cluster_spec: The cluster spec to be used on the newly started
        process. If `None`, the cluster spec provided at `__init__` will be
        used.
+      proc_func: The process function to be run on the newly started
+        process. If specified, specify `args` and `kwargs` as well. If `None`,
+        the function provided at `__init__` will be used.
      args: Optional positional arguments to be supplied in `proc_func`.
      kwargs: Optional keyword arguments to be supplied in `proc_func`.
    """
-    cluster_spec = cluster_spec or self._cluster_spec
-    proc_func = proc_func or self._proc_func
-    self._start_subprocess_and_reading_thread(proc_func, task_type, task_id,
-                                              cluster_spec, args or (),
-                                              kwargs or {})
+    self._start_subprocess_and_reading_thread(
+        task_type,
+        task_id,
+        cluster_spec=cluster_spec,
+        proc_func=proc_func,
+        args=args or (),
+        kwargs=kwargs or {})

  def _queue_to_list(self, queue_to_convert):
    """Convert `queue.Queue` to `list`."""
@@ -379,25 +399,20 @@
  def get_process_id(self, task_type, task_id):
    """Returns the subprocess id given the task type and task id."""
-    if not hasattr(self, '_pid_dict'):
-      self._pid_dict = {}
-      subprocess_infos = []
-
-      while True:
-        try:
-          subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False)
-          subprocess_infos.append(subprocess_info)
-        except Queue.Empty:
-          break
-
-      for subprocess_info in subprocess_infos:
-        self._pid_dict[(subprocess_info.task_type,
-                        subprocess_info.task_id)] = subprocess_info.pid
-
-      for subprocess_info in subprocess_infos:
-        _resource(SUBPROCESS_INFO_QUEUE).put(subprocess_info)
-
-    return self._pid_dict.get((task_type, task_id), None)
+    p = self._processes.get((task_type, task_id), None)
+    return p.pid if p else None
+
+  def _join_or_terminate(self, task_type, task_id, process, timeout):
+    """Joins a process. If it times out, terminate all procsses."""
+    logging.info('joining %s-%d', task_type, task_id)
+    process.join(timeout)
+    # If exitcode is None, the process aren't terminated and this is a
+    # timeout.
+    if process.exitcode is None:
+      # Force termination to dump worker processes stack trace.
+      self.terminate_all(sig=signal.SIGTERM)
+      raise RuntimeError('%s-%d and possibly more subprocesses timed out.' %
+                         (task_type, task_id))

  def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
    """Joins all the processes with timeout.
@@ -417,84 +432,97 @@ class MultiProcessRunner(object):
      RuntimeError: if not all processes report status approximatelty within
      `timeout` seconds, or there's an exception propagated from any subprocess.
    """
+    if self._joined:
+      raise ValueError("MultiProcessRunner can't be joined twice.")
+    self._joined = True

-    if not timeout:
-      timeout = float('inf')
-    start_time = time.time()
-    while self._outstanding_subprocess_count > 0:
-      try:
-        process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
-
-        self._outstanding_subprocess_count -= 1
-        assert isinstance(process_status, _ProcessStatusInfo)
-        if not process_status.is_successful:
-          six.reraise(*process_status.exc_info)
-
-        if self._dependence_on_chief and process_status.task_type == 'chief':
-          self.terminate_all()
-          break
-      except Queue.Empty:
-        if self._all_forced_terminated:
-          break
-        if time.time() - start_time > timeout:
-          # Send SIGTERM signal to subprocesses to dump their current
-          # stack trace.
-          self.terminate_all(sig=signal.SIGTERM)
-          # If none of those did, report timeout to user.
-          raise RuntimeError('One or more subprocesses timed out. '
-                             'Number of outstanding subprocesses '
-                             'is %d.' % self._outstanding_subprocess_count)
-
-    # Giving threads some time to finish the message reading from subprocesses.
-    time.sleep(5)
-
-    stdout = self._queue_to_list(_resource(STREAMING_QUEUE))
-    return_value = self._queue_to_list(_resource(RETURN_VALUE_QUEUE))
-
-    # Notifying the threads that are reading lines that we should stop.
-    for pipe_index in range(self._starting_pipe_index, _next_pipe_index):  # pylint: disable=protected-access
-      _, pipe_w = _resource(STREAMING_PIPE)[pipe_index]
-      writer = os.fdopen(pipe_w.fileno(), 'w')
-      # Writing end of file message so the threads that's actively reading lines
-      # know to stop.
-      writer.writelines(['EOF'])
-      writer.close()
-
-    return MultiProcessRunnerResult(stdout=stdout, return_value=return_value)
+    chief = self._processes.get(('chief', 0), None)
+    if self._dependence_on_chief and chief:
+      self._join_or_terminate('chief', 0, chief, timeout)
+      # Give other processes a chance to exit on their own.
+      for p in self._processes.values():
+        p.join(timeout=3)
+      self.terminate_all()
+    else:
+      for (task_type, task_id), p in self._processes.items():
+        self._join_or_terminate(task_type, task_id, p, timeout)
+
+    for (task_type, task_id), p in self._processes.items():
+      logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)
+
+    process_statuses = self._queue_to_list(self._process_status_queue)
+    if not self._all_forced_terminated and len(
+        process_statuses) != self._outstanding_subprocess_count:
+      raise RuntimeError(
+          'missing statuses from %d subproceses.' %
+          (self._outstanding_subprocess_count - len(process_statuses)))
+    return_values = []
+    for process_status in process_statuses:
+      assert isinstance(process_status, _ProcessStatusInfo)
+      if not process_status.is_successful:
+        six.reraise(*process_status.exc_info)
+      if process_status.return_value is not None:
+        return_values.append(process_status.return_value)
+
+    logging.info('Joining log reading threads.')
+    for thread in self._reading_threads:
+      thread.join()
+    logging.info('Joined log reading threads.')
+
+    # Clear the alarm.
+    signal.alarm(0)
+
+    stdout = self._queue_to_list(self._streaming_queue)
+
+    return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)

  def terminate(self, task_type, task_id):
    """Terminates the process with `task_type` and `task_id`."""
-    _resource(PARENT_TO_SUB_QUEUE).put('terminate {} {}'.format(
-        task_type, task_id))
+    p = self._processes.get((task_type, task_id), None)
+    if p is None:
+      raise ValueError('{}-{} does not exist'.format(task_type, task_id))
+    # TODO(crccw): change to use Process.terminate() as well.
+    self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id))
+    p.join()

  def terminate_all(self, sig=None):
    """Terminates all subprocesses."""
    # Use SIGKILL as default. In systems where that's unavailable such as
    # windows, use SIGTERM.
    sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
-    subprocess_infos = []
-
-    while True:
-      try:
-        subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False)
-        subprocess_infos.append(subprocess_info)
-      except Queue.Empty:
-        break
-
-    for subprocess_info in subprocess_infos:
-      logging.info('Parent process is now killing PID: %d', subprocess_info.pid)
-      try:
-        os.kill(subprocess_info.pid, sig)
-      except ProcessLookupError:
-        # TODO(rchao): Remove subprocess info from the queue once a subprocess
-        # is terminated.
-        logging.info('PID %d does not exist.', subprocess_info.pid)
-
+    for (task_type, task_id), p in self._processes.items():
+      try:
+        os.kill(p.pid, sig)
+      except ProcessLookupError:
+        logging.info('Attempting to kill %s-%d but it does not exist.',
+                     task_type, task_id)
    self._all_forced_terminated = True
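A hedged sketch of the terminate-and-restart flow the updated tests exercise (it assumes the multi_process_runner and multi_worker_test_base imports shown elsewhere in this diff; it is illustrative, not code from the commit):

def proc_func():
  print('worker alive', flush=True)

mpr = multi_process_runner.MultiProcessRunner(
    proc_func,
    multi_worker_test_base.create_cluster_spec(num_workers=2),
    list_stdout=True)
mpr.start()
mpr.terminate('worker', 0)             # sends a terminate message, then joins it
mpr.start_single_process('worker', 0)  # brings the task back on the same spec
stdout = mpr.join().stdout             # join() may only be called once now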

-class _Subprocess(object):
-  """Represents an internal subprocess used in MultiProcessRunner's context."""
+class _Process(multi_process_lib.Process):
+  """A modified `multiprocessing.Process` that can set up environment variables."""
+
+  # TODO(crccw): consider moving other logics in _ProcFunc to _Process.
+  def __init__(self, test_env, **kwargs):
+    super(_Process, self).__init__(**kwargs)
+    self._test_env = test_env
+    self._actual_run = getattr(self, 'run')
+    self.run = self._run_with_setenv
+
+  def _run_with_setenv(self):
+    # We need to set environment variables before doing anything because
+    # setenv() is not thread-safe.
+    test_env = self._test_env
+    if test_env.grpc_fail_fast is not None:
+      os.environ['GRPC_FAIL_FAST'] = str(test_env.grpc_fail_fast)
+    _set_tf_config(test_env.task_type, test_env.task_id, test_env.cluster_spec,
+                   test_env.rpc_layer)
+    return self._actual_run()
+
+
+class _ProcFunc(object):
+  """Represents a callable to run in a subprocess."""

  @contextlib.contextmanager
  def _runtime_mode(self, executing_eagerly):
@@ -505,21 +533,12 @@ class _Subprocess(object):
      with context.graph_mode():
        yield

-  def _finish_process(self, process_status_info, return_value):
-    """Adds data to queues before program exits."""
-    # Clear the alarm.
-    signal.alarm(0)
-
-    if return_value is not None:
-      self._add_return_data(return_value)
-    _resource(PROCESS_STATUS_QUEUE).put(process_status_info)
-
  def _message_checking_func(self, task_type, task_id):
    """A function that regularly checks messages from parent process."""
    # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
    while True:
      try:
-        message = _resource(PARENT_TO_SUB_QUEUE).get(block=False)
+        message = self._resources.parent_to_sub_queue.get(block=False)
        # Currently the only possible message is termination.
        if not message.startswith('terminate'):
@@ -530,63 +549,75 @@ class _Subprocess(object):
        else:
          # If the message is not targeting this process, put it back to the
          # queue.
-          _resource(PARENT_TO_SUB_QUEUE).put(message)
+          self._resources.parent_to_sub_queue.put(message)
          time.sleep(1)
      except Queue.Empty:
        time.sleep(0.1)
-    self._finish_process(
+    self._resources.process_status_queue.put(
        _ProcessStatusInfo(
-            task_type=task_type, is_successful=True, exc_info=None), None)
+            task_type=task_type,
+            is_successful=True,
+            exc_info=None,
+            return_value=None))
    # `os._exit(0)` is used to more reliably terminate a subprocess.
    os._exit(0)  # pylint: disable=protected-access

-  def __call__(self, proc_func, task_type, task_id, per_process_cluster_spec,
-               rpc_layer, grpc_fail_fast, v2_enabled, executing_eagerly, pipe_w,
-               *arg, **kwargs):
+  def _close_streaming(self):
+    """Close stdout, stderr and streaming pipe.
+
+    We need to explicitly close them since Tensorflow may take a while to exit,
+    so that the reading threads in the main process can exit more quickly.
+    """
+    sys.stdout.flush()
+    sys.stderr.flush()
+    sys.stdout.close()
+    sys.stderr.close()
+    self._resources.streaming_pipe_w.close()
+
+  def __call__(self, resources, test_env, proc_func, args, kwargs,
+               use_dill_for_args):
    """The wrapper function that actually gets run in child process(es)."""
+
+    global _barrier
+
+    self._resources = resources
+    _barrier = self._resources.barrier
+    proc_func = dill.loads(proc_func)
+    if use_dill_for_args:
+      args = dill.loads(args)
+      kwargs = dill.loads(kwargs)
+
    if faulthandler is not None:
      faulthandler.enable()
      faulthandler.register(signal.SIGTERM, chain=True)

+    # All logging should go to stderr to be streamed to the main process.
+    logging.set_stderrthreshold(logging.DEBUG)
+
+    # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
+    # print() and logging.*() write directly to `streaming_pipe_w`.
+    # Unfortunately since we cannot prepend task_type and task_id information to
+    # the streamed logs we will need a thread per subprocess to distinguish
+    # where the piece of message is from.
+    os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
+    os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())
+
    pid = os.getpid()
    logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid,
-                 task_type, task_id)
-    _resource(SUBPROCESS_INFO_QUEUE).put(
-        _SubprocessInfo(pid=pid, task_type=task_type, task_id=task_id))
-
-    # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print() and
-    # logging.*() write directly to `pipe_w`. Unfortunately since we cannot
-    # prepend task_type and task_id information to the streamed logs we will
-    # need a thread per subprocess to distinguish where the piece of message is
-    # from.
-    os.dup2(pipe_w.fileno(), sys.stdout.fileno())
-    os.dup2(pipe_w.fileno(), sys.stderr.fileno())
+                 test_env.task_type, test_env.task_id)

    # The thread will be dedicated to checking messages from the parent process.
    threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._message_checking_func,
-        args=(task_type, task_id),
+        args=(test_env.task_type, test_env.task_id),
        daemon=True).start()

-    if grpc_fail_fast is not None:
-      os.environ['GRPC_FAIL_FAST'] = str(grpc_fail_fast)
-    tf_config_dict = {
-        'cluster': per_process_cluster_spec,
-        'task': {
-            'type': task_type,
-            'index': task_id,
-        },
-    }
-    if rpc_layer is not None:
-      tf_config_dict['rpc_layer'] = rpc_layer
-    os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
-
-    if v2_enabled:
+    if test_env.v2_enabled:
      v2_compat.enable_v2_behavior()

    try:
-      with self._runtime_mode(executing_eagerly):
-        return_value = proc_func(*arg, **kwargs)
+      with self._runtime_mode(test_env.executing_eagerly):
+        return_value = proc_func(*args, **kwargs)
        is_successful = True
        exc_info = None
@@ -606,35 +637,27 @@ class _Subprocess(object):
      raise
    finally:
-      self._finish_process(
-          _ProcessStatusInfo(
-              task_type=task_type,
-              is_successful=is_successful,
-              exc_info=exc_info),
-          return_value)
-
-  def _add_return_data(self, data):
-    """Adds return data that will be returned by `join`.
-
-    The function provides a way for child processes to communicate with the
-    parent process. Data passed to `_add_return_data` will be available in a
-    Python Queue.Queue that is eventually returned by `join`.
-
-    Args:
-      data: data to be made available in the queue returned by `join`.
-    """
-    # TODO(rchao): Incorporate the task type and id information in a data
-    # wrapper that becomes what is stored in the queue so we can tell where
-    # the data is from.
-    _resource(RETURN_VALUE_QUEUE).put(data)
-
-
-def barrier():
-  return multi_process_lib.get_user_data()[BARRIER]
-
-
-def _resource(resource_name):
-  return multi_process_lib.get_user_data()[resource_name]
+      info = _ProcessStatusInfo(
+          task_type=test_env.task_type,
+          is_successful=is_successful,
+          exc_info=exc_info,
+          return_value=return_value)
+      self._resources.process_status_queue.put(info)
+      self._close_streaming()
+
+
+def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None):
+  """Set TF_CONFIG environment variable."""
+  tf_config_dict = {
+      'cluster': cluster_spec,
+      'task': {
+          'type': task_type,
+          'index': task_id,
+      },
+  }
+  if rpc_layer is not None:
+    tf_config_dict['rpc_layer'] = rpc_layer
+  os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
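A hedged example of the TF_CONFIG value _set_tf_config produces for a small cluster (the addresses below are placeholders I made up for illustration):

cluster_spec = {'chief': ['localhost:12345'], 'worker': ['localhost:23456']}
_set_tf_config('worker', 0, cluster_spec, rpc_layer='grpc')
print(os.environ['TF_CONFIG'])
# -> {"cluster": {"chief": ["localhost:12345"], "worker": ["localhost:23456"]},
#     "task": {"type": "worker", "index": 0}, "rpc_layer": "grpc"}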
def run(proc_func,
@@ -670,16 +693,19 @@
  return runner.join(timeout)


-def test_main(max_subprocess_count=_DEFAULT_MAX_SUBPROCESS_COUNT,
-              barrier_parties=0):
-  """Main function to be called within `__main__` of a test file.
-
-  Args:
-    max_subprocess_count: Maximum number of subprocesses that will be used. User
-      of multi_process_runner needs to determine a number at calling this
-      method, and the subprocesses involved later should not exceed this number.
-    barrier_parties: Number of parties the barrier will be used toward. User of
-      multi_process_runner needs to determine a number at calling this method.
-  """
-  with multi_process_lib.context_manager(max_subprocess_count, barrier_parties):
-    test.main()
+# This is set by MultiProcessRunner in worker processes.
+_barrier = None
+
+
+def barrier():
+  if _barrier is None:
+    raise ValueError(
+        'barrier is not defined. It is likely because you are calling barrier()'
+        'in the main process. barrier() can only be called in the subprocesses.'
+    )
+  return _barrier
+
+
+def test_main():
+  """Main function to be called within `__main__` of a test file."""
+  multi_process_lib.test_main()
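A hedged sketch of how a proc_func can use the new module-level barrier() (it mirrors test_barrier in the updated tests below; the .wait() call is standard Barrier usage and is my addition for illustration, not code from the commit):

def proc_func_with_barrier():
  # barrier() is only defined inside subprocesses; calling it in the main
  # process raises ValueError, as the new test asserts.
  multi_process_runner.barrier().wait()
  return 'passed barrier'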


@@ -23,15 +23,13 @@ import os
import threading
import time

from absl import logging
-from six.moves import queue as Queue

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test


-def proc_func_that_adds_task_type_in_return_data(test_obj, val):
-  test_obj.assertEqual(val, 3)
+def proc_func_that_adds_task_type_in_return_data():
  return multi_worker_test_base.get_task_type()
@@ -51,6 +49,10 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs):
  return list(args) + list(kwargs.items())


+def proc_func_with_barrier():
+  return multi_process_runner.barrier()
+
+
class MultiProcessRunnerTest(test.TestCase):

  def _worker_idx(self):
@@ -61,8 +63,7 @@ class MultiProcessRunnerTest(test.TestCase):
    mpr_result = multi_process_runner.run(
        proc_func_that_adds_task_type_in_return_data,
        multi_worker_test_base.create_cluster_spec(
-            num_workers=2, num_ps=3, has_eval=1),
-        args=(self, 3))
+            num_workers=2, num_ps=3, has_eval=1))

    job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
    for data in mpr_result.return_value:
@@ -124,36 +125,22 @@ class MultiProcessRunnerTest(test.TestCase):
  def test_process_that_exits(self):

-    def func_to_exit_in_15_sec():
-      time.sleep(5)
-      print('foo', flush=True)
-      time.sleep(20)
-      print('bar', flush=True)
+    def func_to_exit_in_5_sec():
+      logging.error('foo')
+      time.sleep(10)
+      logging.error('bar')

    mpr = multi_process_runner.MultiProcessRunner(
-        func_to_exit_in_15_sec,
+        func_to_exit_in_5_sec,
        multi_worker_test_base.create_cluster_spec(num_workers=1),
        list_stdout=True,
-        max_run_time=15)
+        max_run_time=5)

    mpr.start()
    stdout = mpr.join().stdout
    self.assertLen([msg for msg in stdout if 'foo' in msg], 1)
    self.assertLen([msg for msg in stdout if 'bar' in msg], 0)

-  def test_signal_doesnt_fire_after_process_exits(self):
-    mpr = multi_process_runner.MultiProcessRunner(
-        proc_func_that_does_nothing,
-        multi_worker_test_base.create_cluster_spec(num_workers=1),
-        max_run_time=10)
-    mpr.start()
-    mpr.join()
-    with self.assertRaisesRegexp(Queue.Empty, ''):
-      # If the signal was fired, another message would be added to internal
-      # queue, so verifying it's empty.
-      multi_process_runner._resource(
-          multi_process_runner.PROCESS_STATUS_QUEUE).get(block=False)
-
  def test_termination(self):

    def proc_func():
@@ -192,7 +179,7 @@ class MultiProcessRunnerTest(test.TestCase):
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        list_stdout=True)
    mpr.start()
-    time.sleep(5)
+    time.sleep(3)
    mpr.terminate('worker', 0)
    mpr.start_single_process('worker', 0)
    std_stream_results = mpr.join().stdout
@@ -273,11 +260,14 @@ class MultiProcessRunnerTest(test.TestCase):
            has_chief=True, num_workers=1),
        list_stdout=True)

-    def follow_ups():
-      time.sleep(1)
+    def eval_func():
      mpr.start_single_process(task_type='evaluator', task_id=0)

-    threading.Thread(target=follow_ups).start()
+    eval_thread = threading.Thread(target=eval_func)
+    eval_thread.start()
    mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
+    eval_thread.join()
    list_to_assert = mpr.join().stdout
    for job in ['worker', 'evaluator']:
      for iteration in range(5):
@@ -291,9 +281,21 @@ class MultiProcessRunnerTest(test.TestCase):
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        list_stdout=True)
    mpr.start()
+    time.sleep(3)
    mpr.terminate_all()
    with self.assertRaisesRegexp(ValueError, 'This is an error.'):
      mpr.join()

+  def test_barrier(self):
+    multi_process_runner.run(
+        proc_func_with_barrier,
+        cluster_spec=multi_worker_test_base.create_cluster_spec(
+            has_chief=True, num_workers=1),
+    )
+
+  def test_barrier_called_in_main_process(self):
+    with self.assertRaises(ValueError):
+      multi_process_runner.barrier()
+

if __name__ == '__main__':
  multi_process_runner.test_main()


@@ -127,4 +127,4 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):

if __name__ == '__main__':
-  multi_process_runner.test_main(barrier_parties=NUM_WORKERS)
+  multi_process_runner.test_main()


@@ -364,7 +364,7 @@ py_test(
    name = "multi_worker_callback_tf2_test",
    srcs = ["multi_worker_callback_tf2_test.py"],
    python_version = "PY3",
-    shard_count = 10,
+    shard_count = 5,
    deps = [
        "//tensorflow/python/distribute:collective_all_reduce_strategy",
        "//tensorflow/python/distribute:combinations",


@@ -345,4 +345,4 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):

if __name__ == '__main__':
-  multi_process_runner.test_main(barrier_parties=2)
+  multi_process_runner.test_main()


@@ -183,6 +183,7 @@ filegroup(
        "@com_google_protobuf//:LICENSE",
        "@com_googlesource_code_re2//:LICENSE",
        "@curl//:COPYING",
+        "@dill_archive//:LICENSE",
        "@dlpack//:LICENSE",
        "@double_conversion//:LICENSE",
        "@eigen_archive//:COPYING.MPL2",
@@ -212,6 +213,7 @@ filegroup(
        "@six_archive//:LICENSE",
        "@snappy//:COPYING",
        "@sobol_data//:LICENSE",
+        "@tblib_archive//:LICENSE",
        "@termcolor_archive//:COPYING.txt",
        "@zlib//:zlib.h",
        "@clog//:LICENSE",


@@ -532,6 +532,28 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
        ],
    )

+    tf_http_archive(
+        name = "dill_archive",
+        build_file = clean_dep("//third_party:dill.BUILD"),
+        urls = [
+            "http://mirror.tensorflow.org/files.pythonhosted.org/packages/c7/11/345f3173809cea7f1a193bfbf02403fff250a3360e0e118a1630985e547d/dill-0.3.1.1.tar.gz",
+            "https://files.pythonhosted.org/packages/c7/11/345f3173809cea7f1a193bfbf02403fff250a3360e0e118a1630985e547d/dill-0.3.1.1.tar.gz",
+        ],
+        sha256 = "42d8ef819367516592a825746a18073ced42ca169ab1f5f4044134703e7a049c",
+        strip_prefix = "dill-0.3.1.1",
+    )
+
+    tf_http_archive(
+        name = "tblib_archive",
+        build_file = clean_dep("//third_party:tblib.BUILD"),
+        urls = [
+            "http://mirror.tensorflow.org/files.pythonhosted.org/packages/ec/c4/8c651f3240a73c28a218194f3d527eb2be5a173d08501060cdee84ade33f/tblib-1.3.2.tar.gz",
+            "https://files.pythonhosted.org/packages/ec/c4/8c651f3240a73c28a218194f3d527eb2be5a173d08501060cdee84ade33f/tblib-1.3.2.tar.gz",
+        ],
+        sha256 = "436e4200e63d92316551179dc540906652878df4ff39b43db30fcf6400444fe7",
+        strip_prefix = "tblib-1.3.2",
+    )
+
    filegroup_external(
        name = "org_python_license",
        licenses = ["notice"],  # Python 2.0

third_party/dill.BUILD (new vendored file, 10 lines)

@@ -0,0 +1,10 @@
+licenses(["notice"])  # BSD 3-clause
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "dill",
+    srcs = glob(["dill/*.py"]),
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+)

third_party/tblib.BUILD (new vendored file, 11 lines)

@@ -0,0 +1,11 @@
+licenses(["notice"])  # BSD
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "tblib",
+    srcs = glob(["src/tblib/*.py"]),
+    imports = ["src"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+)