Improve multi_process_runner

This is to prepare enabling it for OSS.

PiperOrigin-RevId: 314337449
Change-Id: I64d2498ceac4f78638f0cae429d073600eb9a1e1
Author: Zhenyu Tan, 2020-06-02 08:43:16 -07:00 (committed by TensorFlower Gardener)
parent fa75abd1ad
commit 5d7c5237ab
7 changed files with 302 additions and 336 deletions
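
For orientation before the per-file diffs: the tests touched below drive this utility through `multi_process_runner.run()` plus a `test_main()` entry point. The following is an illustrative sketch only, not part of the commit; `proc_func`, the cluster-spec helper, and the assertions are assumed from the test diff further down.

    from tensorflow.python.distribute import multi_process_runner
    from tensorflow.python.distribute import multi_worker_test_base
    from tensorflow.python.eager import test


    def proc_func():
      # Runs once in each subprocess; the runner sets TF_CONFIG for the task.
      return multi_worker_test_base.get_task_type()


    class MyTest(test.TestCase):

      def test_run(self):
        result = multi_process_runner.run(
            proc_func,
            multi_worker_test_base.create_cluster_spec(num_workers=2))
        self.assertLen(result.return_value, 2)  # one entry per subprocess


    if __name__ == '__main__':
      # Used instead of test.main() so the runner is initialized.
      multi_process_runner.test_main()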

tensorflow/python/distribute/BUILD (View File)

@@ -1695,14 +1695,11 @@ cuda_py_test(
 py_library(
     name = "multi_process_runner",
     srcs = ["multi_process_runner.py"],
-    srcs_version = "PY3",
     deps = [
         ":multi_process_lib",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:tf2",
         "//tensorflow/python/compat:v2_compat",
-        "//tensorflow/python/eager:context",
-        "@absl_py//absl/logging",
         "@six_archive//:six",
     ],
 )
@@ -1710,19 +1707,18 @@ py_library(
 py_library(
     name = "multi_process_lib",
     srcs = ["multi_process_lib.py"],
-    deps = ["//tensorflow/python:client_testlib"],
+    deps = ["@six_archive//:six"],
 )
 
 py_test(
     name = "multi_process_runner_test",
     srcs = ["multi_process_runner_test.py"],
     python_version = "PY3",
+    shard_count = 12,
     deps = [
         ":multi_process_runner",
         ":multi_worker_test_base",
         "//tensorflow/python/eager:test",
-        "@absl_py//absl/logging",
-        "@six_archive//:six",
     ],
 )
@@ -1730,7 +1726,6 @@ py_test(
     name = "multi_process_runner_no_init_test",
     srcs = ["multi_process_runner_no_init_test.py"],
     python_version = "PY3",
-    tags = ["no_oss"],
     deps = [
         ":multi_process_runner",
         ":multi_worker_test_base",

tensorflow/python/distribute/multi_process_lib.py (View File)

@@ -18,18 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import multiprocessing as _multiprocessing
+import contextlib
 import unittest
 
-from tensorflow.python.platform import test
-
-try:
-  multiprocessing = _multiprocessing.get_context('forkserver')
-except ValueError:
-  # forkserver is not available on Windows.
-  multiprocessing = _multiprocessing.get_context('spawn')
-
 
 class Process(object):
   """A process simulating a worker for testing multi-worker training."""
@@ -40,11 +31,20 @@ class Process(object):
         'TODO(b/141874796): Implement OSS version of `multi_process_lib`')
 
 
-def test_main():
-  """Main function to be called within `__main__` of a test file."""
-  test.main()
+def get_user_data():
+  """Returns the data commonly shared by parent process and subprocesses."""
+  # TODO(b/141874796): Implement OSS version of `multi_process_lib`.
+  pass
 
 
-def initialized():
-  """Returns whether the module is initialized."""
-  return True
+@contextlib.contextmanager
+def context_manager(max_subprocess_count=20, barrier_parties=0):
+  """No-op in OSS. This exists to maintain testing compatibility."""
+  del max_subprocess_count, barrier_parties
+  yield
+
+
+def using_context_manager():
+  """Whether the context manager is being used."""
+  raise unittest.SkipTest(
+      'TODO(b/141874796): Implement OSS version of `multi_process_lib`')

tensorflow/python/distribute/multi_process_runner.py (View File)

@@ -18,7 +18,6 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import collections
 import contextlib
 import json
@@ -27,19 +26,15 @@ import signal
 import sys
 import threading
 import time
 
 from absl import logging
-import dill
 import six
 from six.moves import queue as Queue
-import tblib.pickling_support
 
 from tensorflow.python import tf2
 from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute import multi_process_lib
 from tensorflow.python.eager import context
-
-multiprocessing = multi_process_lib.multiprocessing
+from tensorflow.python.platform import test
 
 # pylint: disable=g-import-not-at-top
@@ -48,45 +43,61 @@ try:
 except ImportError:
   faulthandler = None
 
-# For pickling traceback objects.
-tblib.pickling_support.install()
-
 # _ProcessStatusInfo contains process status information. When is_successful
 # attribute is True, the subprocess has ended successfully, or if False, the
 # exception stack trace info is stored in exc_info to pass on to parent process
 # to be re-raised.
 _ProcessStatusInfo = collections.namedtuple(
-    '_ProcessStatusInfo',
-    ['task_type', 'is_successful', 'exc_info', 'return_value'])
+    '_ProcessStatusInfo', ['task_type', 'is_successful', 'exc_info'])
+
+# _SubprocessInfo collects basic information of a subprocess such as task type
+# and process id.
+# TODO(rchao): Include task_type and task_id in subprocess info.
+_SubprocessInfo = collections.namedtuple('_SubprocessInfo', ['pid'])
 
 # Information returned from a successful MultiProcessRunner run.
 MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
                                                   ['return_value', 'stdout'])
 
-TestEnvironment = collections.namedtuple('TestEnvironment', [
-    'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast',
-    'v2_enabled', 'executing_eagerly'
-])
-
-# Resources for communication between worker processes and the main process.
-#
-# `process_status_queue` is used by `multi_process_runner` internally for
-# communication from subprocesses to the parent process for whether it's been
-# successful, and if not what the error stack trace is.
-# `parent_to_sub_queue` is used for communications from parent to subprocess.
-# Currently this is only used to terminate subprocesses.
-# TODO(rchao): Remove this once subprocess is terminated by SIGKILL.
-# `streaming_pipe_w` is to stream stdout and stderr from subprocesses to parent
-# process.
-# `barrier` is a barrier for the party of all subprocesses.
-Resources = collections.namedtuple('Resources', [
-    'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier'
-])
+# Process status queue is used by `multi_process_runner` internally for
+# communication from subprocesses to the parent process for whether it's been
+# successful, and if not what the error stack trace is.
+PROCESS_STATUS_QUEUE = 'process_status_queue'
+
+# Return value queue is intended to be used by users of `multi_process_runner`
+# for the process function to return information to the caller of
+# `multi_process_runner.run()`.
+RETURN_VALUE_QUEUE = 'return_value_queue'
+
+# Subprocess info queue stores `_SubprocessInfo` for later potential
+# termination by the parent.
+SUBPROCESS_INFO_QUEUE = 'subprocess_info_queue'
+
+# Parent-to-sub queue is used for communications from parent to subprocess.
+# Currently this is only used to terminate subprocesses.
+# TODO(rchao): Remove this once subprocess is terminated by SIGKILL.
+PARENT_TO_SUB_QUEUE = 'parent_to_sub_queue'
+
+# Streaming queue stores the logged and printed messages from subprocesses.
+STREAMING_QUEUE = 'streaming_queue'
+
+# Pipes to stream stdout and stderr from subprocesses to parent process.
+STREAMING_PIPE = 'streaming_pipe'
+
+# Barrier identifier.
+BARRIER = 'barrier'
+
+_DEFAULT_MAX_SUBPROCESS_COUNT = 20
 
 # Default time out sec is selected so that it's handled before the default
 # "medium" timeout of the test runs.
 _DEFAULT_TIMEOUT_SEC = 200
 
+# Next pipe index to be global so that pipes are not reused across multiple
+# MultiProcessRunner usages.
+# TODO(rchao): Investigate possibility to remove this variable.
+_next_pipe_index = 0
+
 
 class MultiProcessRunner(object):
   """A utility class to start multiple processes to simulate a cluster.
@@ -112,7 +123,6 @@ class MultiProcessRunner(object):
                grpc_fail_fast=None,
                stream_stdout=True,
                list_stdout=False,
-               use_dill_for_args=True,
                args=None,
                kwargs=None):
     """Creates a multi-process runner.
@@ -143,9 +153,6 @@ class MultiProcessRunner(object):
         returned from `MultiProcessRunner.join()`. If True, the list of stdout
         can be retrieved via `MultiProcessRunnerResult.stdout` attribute.
         Defaults to False.
-      use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill
-        can pickle more objects, but doesn't work with types in
-        `multiprocessing` library like `Mutex`.
       args: Positional arguments to be sent to functions run on processes.
       kwargs: Keyword arguments to be sent to functions run on processes.
@@ -159,15 +166,15 @@ class MultiProcessRunner(object):
                          'one chief. Current `cluster_spec` has {} chiefs.'
                          .format(len(cluster_spec['chief'])))
 
-    if not multi_process_lib.initialized():
+    assert callable(proc_func)
+
+    if not multi_process_lib.using_context_manager():
       raise RuntimeError('`multi_process_runner` is not initialized. '
                          'Please call `multi_process_runner.test_main()` '
                          'within `if __name__ == \'__main__\':` block '
                          'in your python module to properly initialize '
                          '`multi_process_runner`.')
-    assert callable(proc_func)
 
     self._proc_func = proc_func
     self._cluster_spec = cluster_spec
     self._rpc_layer = rpc_layer
@@ -177,86 +184,62 @@ class MultiProcessRunner(object):
     # TODO(rchao): Revisit list_stdout argument to consider other solution.
     self._list_stdout = list_stdout
     self._dependence_on_chief = True
-    self._use_dill_for_args = use_dill_for_args
 
     self._args = args or ()
     self._kwargs = kwargs or {}
 
+    self._outstanding_subprocess_count = 0
+
     # Child processes should have the same v2 and eager behavior.
     self._v2_enabled = tf2.enabled()
     self._executing_eagerly = context.executing_eagerly()
 
-    self._joined = False
-    self._processes = {}
-    self._outstanding_subprocess_count = 0
-    self._reading_threads = []
-
-    self._manager = multiprocessing.Manager()
-    self._process_status_queue = self._manager.Queue()
-    self._parent_to_sub_queue = self._manager.Queue()
-    parties = sum(len(addresses) for addresses in self._cluster_spec.values())
-    self._barrier = self._manager.Barrier(parties)
-
-    # We use a queue to collect outputs from worker processes since it's thread
-    # safe.
-    self._streaming_queue = self._manager.Queue()
-
     # This flag will be set to True once terminate_all() is called.
     self._all_forced_terminated = False
 
   def _continuously_readline_from_sub(self, pipe_r, task_type, task_id):
     """Function to continuously read lines from subprocesses."""
-    with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader:
-      for line in reader:
-        task_string = '[{}-{}]:'.format(task_type, task_id)
-        formatted_line = '{} {}'.format(task_string.ljust(14), line)
-        if self._stream_stdout:
-          # TODO(rchao): Use a lock here to ensure the printed lines are not
-          # broken.
-          print(formatted_line, end='', flush=True)
-        if self._list_stdout:
-          self._streaming_queue.put(formatted_line)
-
-  def _start_subprocess_and_reading_thread(self,
-                                           task_type,
-                                           task_id,
-                                           cluster_spec=None,
-                                           proc_func=None,
-                                           args=None,
-                                           kwargs=None):
+    reader = os.fdopen(pipe_r.fileno(), 'r')
+    while True:
+      read_line = reader.readline()
+      if read_line == 'EOF':
+        reader.close()
+        # The thread that runs `_continuously_readline_from_sub` stops here.
+        # However the threads don't exit until the test exits, so we do not
+        # attempt to join the threads (which leads to timeout).
+        # TODO(rchao): Understand why and do thread joining.
+        break
+      task_string = '[{}-{}]:'.format(task_type, task_id)
+      formatted_line = '{} {}'.format(task_string.ljust(14), read_line)
+      if self._stream_stdout:
+        self._print_stdout_in_parent(formatted_line, task_type, task_id)
+      if self._list_stdout:
+        self._add_stdout_in_queue(formatted_line, task_type, task_id)
+
+  def _print_stdout_in_parent(self, formatted_line, task_type, task_id):
+    del task_type, task_id
+    # Flush True so the logging order from subprocesses is respected.
+    # TODO(rchao): Use a lock here to ensure the printed lines are not broken.
+    print(formatted_line, end='', flush=True)
+
+  def _add_stdout_in_queue(self, formatted_line, task_type, task_id):
+    del task_type, task_id
+    # A queue instead of a simple list is used here due to b/150652733.
+    _resource(STREAMING_QUEUE).put(formatted_line)
+
+  def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id,
+                                           cluster_spec, args, kwargs):
     """Start a subprocess and a thread the reads lines from the subprocess."""
+    global _next_pipe_index
+    pipe_r, pipe_w = _resource(STREAMING_PIPE)[_next_pipe_index]
+    _next_pipe_index += 1
 
-    test_env = TestEnvironment(
-        task_type=task_type,
-        task_id=task_id,
-        cluster_spec=cluster_spec or self._cluster_spec,
-        rpc_layer=self._rpc_layer,
-        grpc_fail_fast=self._grpc_fail_fast,
-        v2_enabled=self._v2_enabled,
-        executing_eagerly=self._executing_eagerly,
-    )
-    pipe_r, pipe_w = multiprocessing.Pipe(duplex=False)
-    resources = Resources(
-        process_status_queue=self._process_status_queue,
-        parent_to_sub_queue=self._parent_to_sub_queue,
-        streaming_pipe_w=pipe_w,
-        barrier=self._barrier,
-    )
-    if proc_func is None:
-      proc_func, args, kwargs = self._proc_func, self._args, self._kwargs
-    # Always use dill to pickle proc_func so that we support more callable
-    # types, e.g. lambda.
-    proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL)
-    if self._use_dill_for_args:
-      args = dill.dumps(args, dill.HIGHEST_PROTOCOL)
-      kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL)
-
-    p = _Process(
-        test_env=test_env,
-        target=_ProcFunc(),
-        args=(resources, test_env, proc_func, args, kwargs,
-              self._use_dill_for_args))
+    p = multi_process_lib.Process(
+        target=_Subprocess(),
+        args=(proc_func, task_type, task_id, cluster_spec, self._rpc_layer,
+              self._grpc_fail_fast, self._v2_enabled, self._executing_eagerly,
+              pipe_w) + args,
+        kwargs=kwargs)
     p.start()
-    self._processes[(task_type, task_id)] = p
     self._outstanding_subprocess_count += 1
 
     # For each subprocess, we dedicate a thread continuously reading lines
@@ -265,15 +248,18 @@ class MultiProcessRunner(object):
         target=self._continuously_readline_from_sub,
         args=(pipe_r, task_type, task_id))
     thread.start()
-    self._reading_threads.append(thread)
 
   def start(self):
     """Starts processes, one for each task in `cluster_spec`."""
-    if self._processes:
-      raise ValueError('MultiProcessRunner already started.')
+    global _next_pipe_index
+    self._starting_pipe_index = _next_pipe_index
+
     for task_type, addresses in self._cluster_spec.items():
       for task_id, _ in enumerate(addresses):
-        self._start_subprocess_and_reading_thread(task_type, task_id)
+        self._start_subprocess_and_reading_thread(self._proc_func, task_type,
+                                                  task_id, self._cluster_spec,
+                                                  self._args, self._kwargs)
 
     # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
     # without this the tests become very flaky.
@@ -323,22 +309,33 @@ class MultiProcessRunner(object):
       as_task_type: The task type to be run in the main process.
       as_task_id: The task id to be run in the main process.
     """
-    if self._processes:
-      raise ValueError('MultiProcessRunner already started.')
+    global _next_pipe_index
+    self._starting_pipe_index = _next_pipe_index
+
     for task_type, addresses in self._cluster_spec.items():
       for task_id, _ in enumerate(addresses):
         if not (task_type == as_task_type and task_id == as_task_id):
-          self._start_subprocess_and_reading_thread(task_type, task_id)
+          self._start_subprocess_and_reading_thread(self._proc_func, task_type,
+                                                    task_id, self._cluster_spec,
+                                                    self._args, self._kwargs)
 
-    _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
-                   self._rpc_layer)
+    tf_config_dict = {
+        'cluster': self._cluster_spec,
+        'task': {
+            'type': as_task_type,
+            'index': as_task_id,
+        },
+    }
+    if self._rpc_layer is not None:
+      tf_config_dict['rpc_layer'] = self._rpc_layer
+    os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
+
     self._proc_func(*self._args, **self._kwargs)
 
   def start_single_process(self,
                            task_type,
                            task_id,
-                           cluster_spec=None,
                            proc_func=None,
+                           cluster_spec=None,
                            args=None,
                            kwargs=None):
     """Starts a single process.
@@ -355,22 +352,19 @@ class MultiProcessRunner(object):
     Args:
       task_type: The task type.
       task_id: The task id.
+      proc_func: The process function to be run on the newly started
+        process. If `None`, the function provided at `__init__` will be used.
       cluster_spec: The cluster spec to be used on the newly started
         process. If `None`, the cluster spec provided at `__init__` will be
         used.
-      proc_func: The process function to be run on the newly started
-        process. If specified, specify `args` and `kwargs` as well. If `None`,
-        the function provided at `__init__` will be used.
       args: Optional positional arguments to be supplied in `proc_func`.
       kwargs: Optional keyword arguments to be supplied in `proc_func`.
     """
-    self._start_subprocess_and_reading_thread(
-        task_type,
-        task_id,
-        cluster_spec=cluster_spec,
-        proc_func=proc_func,
-        args=args or (),
-        kwargs=kwargs or {})
+    cluster_spec = cluster_spec or self._cluster_spec
+    proc_func = proc_func or self._proc_func
+    self._start_subprocess_and_reading_thread(proc_func, task_type, task_id,
+                                              cluster_spec, args or (),
+                                              kwargs or {})
 
   def _queue_to_list(self, queue_to_convert):
     """Convert `queue.Queue` to `list`."""
@@ -383,18 +377,6 @@ class MultiProcessRunner(object):
         break
     return list_to_return
 
-  def _join_or_terminate(self, task_type, task_id, process, timeout):
-    """Joins a process. If it times out, terminate all procsses."""
-    logging.info('joining %s-%d', task_type, task_id)
-    process.join(timeout)
-    # If exitcode is None, the process aren't terminated and this is a
-    # timeout.
-    if process.exitcode is None:
-      # Force termination to dump worker processes stack trace.
-      self.terminate_all(sig=signal.SIGTERM)
-      raise RuntimeError('%s-%d and possibly more subprocesses timed out.' %
-                         (task_type, task_id))
-
   def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
     """Joins all the processes with timeout.
@@ -413,97 +395,88 @@ class MultiProcessRunner(object):
       RuntimeError: if not all processes report status approximatelty within
        `timeout` seconds, or there's an exception propagated from any subprocess.
     """
-    if self._joined:
-      raise ValueError("MultiProcessRunner can't be joined twice.")
-    self._joined = True
-
-    chief = self._processes.get(('chief', 0), None)
-    if self._dependence_on_chief and chief:
-      self._join_or_terminate('chief', 0, chief, timeout)
-      # Give other processes a chance to exit on their own.
-      for p in self._processes.values():
-        p.join(timeout=3)
-      self.terminate_all()
-    else:
-      for (task_type, task_id), p in self._processes.items():
-        self._join_or_terminate(task_type, task_id, p, timeout)
-
-    for (task_type, task_id), p in self._processes.items():
-      logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)
-
-    process_statuses = self._queue_to_list(self._process_status_queue)
-    if not self._all_forced_terminated and len(
-        process_statuses) != self._outstanding_subprocess_count:
-      raise RuntimeError(
-          'missing statuses from %d subproceses.' %
-          (self._outstanding_subprocess_count - len(process_statuses)))
-    return_values = []
-    for process_status in process_statuses:
-      assert isinstance(process_status, _ProcessStatusInfo)
-      if not process_status.is_successful:
-        six.reraise(*process_status.exc_info)
-      if process_status.return_value is not None:
-        return_values.append(process_status.return_value)
-
-    logging.info('Joining log reading threads.')
-    for thread in self._reading_threads:
-      thread.join()
-    logging.info('Joined log reading threads.')
-
-    # Clear the alarm.
-    signal.alarm(0)
-
-    stdout = self._queue_to_list(self._streaming_queue)
-
-    return MultiProcessRunnerResult(stdout=stdout, return_value=return_values)
+    if not timeout:
+      timeout = float('inf')
+    start_time = time.time()
+    while self._outstanding_subprocess_count > 0:
+      while True:
+        try:
+          process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10)
+          break
+        except Queue.Empty:
+          if self._all_forced_terminated:
+            break
+          if time.time() - start_time > timeout:
+            # Send SIGTERM signal to subprocesses to dump their current
+            # stack trace.
+            self.terminate_all(sig=signal.SIGTERM)
+            # If none of those did, report timeout to user.
+            raise RuntimeError('One or more subprocesses timed out. '
+                               'Number of outstanding subprocesses '
+                               'is %d.' % self._outstanding_subprocess_count)
+
+      if self._all_forced_terminated:
+        break
+      self._outstanding_subprocess_count -= 1
+
+      assert isinstance(process_status, _ProcessStatusInfo)
+      if not process_status.is_successful:
+        six.reraise(*process_status.exc_info)
+
+      if self._dependence_on_chief and process_status.task_type == 'chief':
+        self.terminate_all()
+        break
+
+    # Giving threads some time to finish the message reading from subprocesses.
+    time.sleep(5)
+
+    stdout = self._queue_to_list(_resource(STREAMING_QUEUE))
+    return_value = self._queue_to_list(_resource(RETURN_VALUE_QUEUE))
+
+    # Notifying the threads that are reading lines that we should stop.
+    for pipe_index in range(self._starting_pipe_index, _next_pipe_index):  # pylint: disable=protected-access
+      _, pipe_w = _resource(STREAMING_PIPE)[pipe_index]
+      writer = os.fdopen(pipe_w.fileno(), 'w')
+      # Writing end of file message so the threads that's actively reading lines
+      # know to stop.
+      writer.writelines(['EOF'])
+      writer.close()
+
+    return MultiProcessRunnerResult(stdout=stdout, return_value=return_value)
 
   def terminate(self, task_type, task_id):
     """Terminates the process with `task_type` and `task_id`."""
-    p = self._processes.get((task_type, task_id), None)
-    if p is None:
-      raise ValueError('{}-{} does not exist'.format(task_type, task_id))
-    # TODO(crccw): change to use Process.terminate() as well.
-    self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id))
-    p.join()
+    _resource(PARENT_TO_SUB_QUEUE).put('terminate {} {}'.format(
+        task_type, task_id))
 
   def terminate_all(self, sig=None):
     """Terminates all subprocesses."""
     # Use SIGKILL as default. In systems where that's unavailable such as
     # windows, use SIGTERM.
     sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM)
-    for (task_type, task_id), p in self._processes.items():
+    subprocess_infos = []
+
+    while True:
       try:
-        os.kill(p.pid, sig)
+        subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False)
+        subprocess_infos.append(subprocess_info)
+      except Queue.Empty:
+        break
+
+    for subprocess_info in subprocess_infos:
+      logging.info('Parent process is now killing PID: %d', subprocess_info.pid)
+      try:
+        os.kill(subprocess_info.pid, sig)
       except ProcessLookupError:
-        logging.info('Attempting to kill %s-%d but it does not exist.',
-                     task_type, task_id)
+        # TODO(rchao): Remove subprocess info from the queue once a subprocess
+        # is terminated.
+        logging.info('PID %d does not exist.', subprocess_info.pid)
+
     self._all_forced_terminated = True
 
-class _Process(multi_process_lib.Process):
-  """A modified `multiprocessing.Process` that can set up environment variables."""
-
-  # TODO(crccw): consider moving other logics in _ProcFunc to _Process.
-
-  def __init__(self, test_env, **kwargs):
-    super(_Process, self).__init__(**kwargs)
-    self._test_env = test_env
-    self._actual_run = getattr(self, 'run')
-    self.run = self._run_with_setenv
-
-  def _run_with_setenv(self):
-    # We need to set environment variables before doing anything because
-    # setenv() is not thread-safe.
-    test_env = self._test_env
-    if test_env.grpc_fail_fast is not None:
-      os.environ['GRPC_FAIL_FAST'] = str(test_env.grpc_fail_fast)
-    _set_tf_config(test_env.task_type, test_env.task_id, test_env.cluster_spec,
-                   test_env.rpc_layer)
-    return self._actual_run()
-
-
-class _ProcFunc(object):
-  """Represents a callable to run in a subprocess."""
+class _Subprocess(object):
+  """Represents an internal subprocess used in MultiProcessRunner's context."""
 
   @contextlib.contextmanager
   def _runtime_mode(self, executing_eagerly):
@@ -514,12 +487,21 @@
       with context.graph_mode():
         yield
 
+  def _finish_process(self, process_status_info, return_value):
+    """Adds data to queues before program exits."""
+    # Clear the alarm.
+    signal.alarm(0)
+
+    if return_value is not None:
+      self._add_return_data(return_value)
+    _resource(PROCESS_STATUS_QUEUE).put(process_status_info)
+
   def _message_checking_func(self, task_type, task_id):
     """A function that regularly checks messages from parent process."""
     # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess.
     while True:
       try:
-        message = self._resources.parent_to_sub_queue.get(block=False)
+        message = _resource(PARENT_TO_SUB_QUEUE).get(block=False)
 
         # Currently the only possible message is termination.
         if not message.startswith('terminate'):
@@ -530,75 +512,62 @@
         else:
           # If the message is not targeting this process, put it back to the
           # queue.
-          self._resources.parent_to_sub_queue.put(message)
+          _resource(PARENT_TO_SUB_QUEUE).put(message)
           time.sleep(1)
       except Queue.Empty:
         time.sleep(0.1)
-    self._resources.process_status_queue.put(
+    self._finish_process(
         _ProcessStatusInfo(
-            task_type=task_type,
-            is_successful=True,
-            exc_info=None,
-            return_value=None))
+            task_type=task_type, is_successful=True, exc_info=None), None)
     # `os._exit(0)` is used to more reliably terminate a subprocess.
     os._exit(0)  # pylint: disable=protected-access
 
-  def _close_streaming(self):
-    """Close stdout, stderr and streaming pipe.
-
-    We need to explicitly close them since Tensorflow may take a while to exit,
-    so that the reading threads in the main process can exit more quickly.
-    """
-    sys.stdout.flush()
-    sys.stderr.flush()
-    sys.stdout.close()
-    sys.stderr.close()
-    self._resources.streaming_pipe_w.close()
-
-  def __call__(self, resources, test_env, proc_func, args, kwargs,
-               use_dill_for_args):
+  def __call__(self, proc_func, task_type, task_id, per_process_cluster_spec,
+               rpc_layer, grpc_fail_fast, v2_enabled, executing_eagerly, pipe_w,
+               *arg, **kwargs):
     """The wrapper function that actually gets run in child process(es)."""
-    global _barrier
-
-    self._resources = resources
-    _barrier = self._resources.barrier
-    proc_func = dill.loads(proc_func)
-    if use_dill_for_args:
-      args = dill.loads(args)
-      kwargs = dill.loads(kwargs)
-
     if faulthandler is not None:
       faulthandler.enable()
       faulthandler.register(signal.SIGTERM, chain=True)
 
-    # All logging should go to stderr to be streamed to the main process.
-    logging.set_stderrthreshold(logging.DEBUG)
-
-    # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so
-    # print() and logging.*() write directly to `streaming_pipe_w`.
-    # Unfortunately since we cannot prepend task_type and task_id information to
-    # the streamed logs we will need a thread per subprocess to distinguish
-    # where the piece of message is from.
-    os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno())
-    os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno())
-
     pid = os.getpid()
     logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid,
-                 test_env.task_type, test_env.task_id)
+                 task_type, task_id)
+    _resource(SUBPROCESS_INFO_QUEUE).put(_SubprocessInfo(pid=pid))
+
+    # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print() and
+    # logging.*() write directly to `pipe_w`. Unfortunately since we cannot
+    # prepend task_type and task_id information to the streamed logs we will
+    # need a thread per subprocess to distinguish where the piece of message is
+    # from.
+    os.dup2(pipe_w.fileno(), sys.stdout.fileno())
+    os.dup2(pipe_w.fileno(), sys.stderr.fileno())
 
     # The thread will be dedicated to checking messages from the parent process.
     threading.Thread(  # pylint: disable=unexpected-keyword-arg
         target=self._message_checking_func,
-        args=(test_env.task_type, test_env.task_id),
+        args=(task_type, task_id),
         daemon=True).start()
 
-    if test_env.v2_enabled:
+    if grpc_fail_fast is not None:
+      os.environ['GRPC_FAIL_FAST'] = str(grpc_fail_fast)
+    tf_config_dict = {
+        'cluster': per_process_cluster_spec,
+        'task': {
+            'type': task_type,
+            'index': task_id,
+        },
+    }
+    if rpc_layer is not None:
+      tf_config_dict['rpc_layer'] = rpc_layer
+    os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
+
+    if v2_enabled:
       v2_compat.enable_v2_behavior()
 
     try:
-      with self._runtime_mode(test_env.executing_eagerly):
-        return_value = proc_func(*args, **kwargs)
+      with self._runtime_mode(executing_eagerly):
+        return_value = proc_func(*arg, **kwargs)
         is_successful = True
         exc_info = None
@@ -618,27 +587,35 @@
         raise
     finally:
-      info = _ProcessStatusInfo(
-          task_type=test_env.task_type,
-          is_successful=is_successful,
-          exc_info=exc_info,
-          return_value=return_value)
-      self._resources.process_status_queue.put(info)
-      self._close_streaming()
+      self._finish_process(
+          _ProcessStatusInfo(
+              task_type=task_type,
+              is_successful=is_successful,
+              exc_info=exc_info),
+          return_value)
+
+  def _add_return_data(self, data):
+    """Adds return data that will be returned by `join`.
+
+    The function provides a way for child processes to communicate with the
+    parent process. Data passed to `_add_return_data` will be available in a
+    Python Queue.Queue that is eventually returned by `join`.
+
+    Args:
+      data: data to be made available in the queue returned by `join`.
+    """
+    # TODO(rchao): Incorporate the task type and id information in a data
+    # wrapper that becomes what is stored in the queue so we can tell where
+    # the data is from.
+    _resource(RETURN_VALUE_QUEUE).put(data)
 
 
-def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None):
-  """Set TF_CONFIG environment variable."""
-  tf_config_dict = {
-      'cluster': cluster_spec,
-      'task': {
-          'type': task_type,
-          'index': task_id,
-      },
-  }
-  if rpc_layer is not None:
-    tf_config_dict['rpc_layer'] = rpc_layer
-  os.environ['TF_CONFIG'] = json.dumps(tf_config_dict)
+def barrier():
+  return multi_process_lib.get_user_data()[BARRIER]
+
+
+def _resource(resource_name):
+  return multi_process_lib.get_user_data()[resource_name]
 
 
 def run(proc_func,
@@ -674,19 +651,16 @@
   return runner.join(timeout)
 
 
-# This is set by MultiProcessRunner in worker processes.
-_barrier = None
-
-
-def barrier():
-  if _barrier is None:
-    raise ValueError(
-        'barrier is not defined. It is likely because you are calling barrier()'
-        'in the main process. barrier() can only be called in the subprocesses.'
-    )
-  return _barrier
-
-
-def test_main():
-  """Main function to be called within `__main__` of a test file."""
-  multi_process_lib.test_main()
+def test_main(max_subprocess_count=_DEFAULT_MAX_SUBPROCESS_COUNT,
+              barrier_parties=0):
+  """Main function to be called within `__main__` of a test file.
+
+  Args:
+    max_subprocess_count: Maximum number of subprocesses that will be used. User
+      of multi_process_runner needs to determine a number at calling this
+      method, and the subprocesses involved later should not exceed this number.
+    barrier_parties: Number of parties the barrier will be used toward. User of
+      multi_process_runner needs to determine a number at calling this method.
+  """
+  with multi_process_lib.context_manager(max_subprocess_count, barrier_parties):
+    test.main()
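
For readers less familiar with this module, a short, hedged sketch of the lower-level API the file exposes after this change (names taken from the diff above; the two-worker cluster spec and `proc_func` are assumed examples, not part of the commit):

    mpr = multi_process_runner.MultiProcessRunner(
        proc_func,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        list_stdout=True)
    mpr.start()
    result = mpr.join()  # MultiProcessRunnerResult(return_value=..., stdout=...)
    for line in result.stdout:
      print(line)  # each captured line carries a '[worker-0]:'-style prefix

Between start() and join(), terminate() and start_single_process() shown in the diff can be used to kill and restart an individual task.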

tensorflow/python/distribute/multi_process_runner_test.py (View File)

@@ -23,13 +23,15 @@ import os
 import threading
 import time
 
 from absl import logging
+from six.moves import queue as Queue
 
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import test
 
 
-def proc_func_that_adds_task_type_in_return_data():
+def proc_func_that_adds_task_type_in_return_data(test_obj, val):
+  test_obj.assertEqual(val, 3)
   return multi_worker_test_base.get_task_type()
@@ -49,10 +51,6 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs):
   return list(args) + list(kwargs.items())
 
 
-def proc_func_with_barrier():
-  return multi_process_runner.barrier()
-
-
 class MultiProcessRunnerTest(test.TestCase):
 
   def _worker_idx(self):
@@ -63,7 +61,8 @@ class MultiProcessRunnerTest(test.TestCase):
     mpr_result = multi_process_runner.run(
         proc_func_that_adds_task_type_in_return_data,
         multi_worker_test_base.create_cluster_spec(
-            num_workers=2, num_ps=3, has_eval=1))
+            num_workers=2, num_ps=3, has_eval=1),
+        args=(self, 3))
 
     job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
     for data in mpr_result.return_value:
@@ -125,22 +124,36 @@ class MultiProcessRunnerTest(test.TestCase):
   def test_process_that_exits(self):
 
-    def func_to_exit_in_5_sec():
-      logging.error('foo')
-      time.sleep(10)
-      logging.error('bar')
+    def func_to_exit_in_15_sec():
+      time.sleep(5)
+      print('foo', flush=True)
+      time.sleep(20)
+      print('bar', flush=True)
 
     mpr = multi_process_runner.MultiProcessRunner(
-        func_to_exit_in_5_sec,
+        func_to_exit_in_15_sec,
         multi_worker_test_base.create_cluster_spec(num_workers=1),
         list_stdout=True,
-        max_run_time=5)
+        max_run_time=15)
 
     mpr.start()
     stdout = mpr.join().stdout
     self.assertLen([msg for msg in stdout if 'foo' in msg], 1)
     self.assertLen([msg for msg in stdout if 'bar' in msg], 0)
 
+  def test_signal_doesnt_fire_after_process_exits(self):
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func_that_does_nothing,
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        max_run_time=10)
+    mpr.start()
+    mpr.join()
+    with self.assertRaisesRegexp(Queue.Empty, ''):
+      # If the signal was fired, another message would be added to internal
+      # queue, so verifying it's empty.
+      multi_process_runner._resource(
+          multi_process_runner.PROCESS_STATUS_QUEUE).get(block=False)
+
   def test_termination(self):
 
     def proc_func():
@@ -179,7 +192,7 @@ class MultiProcessRunnerTest(test.TestCase):
         multi_worker_test_base.create_cluster_spec(num_workers=2),
         list_stdout=True)
     mpr.start()
-    time.sleep(3)
+    time.sleep(5)
     mpr.terminate('worker', 0)
     mpr.start_single_process('worker', 0)
     std_stream_results = mpr.join().stdout
@@ -260,14 +273,11 @@ class MultiProcessRunnerTest(test.TestCase):
             has_chief=True, num_workers=1),
         list_stdout=True)
 
-    def eval_func():
-      time.sleep(1)
+    def follow_ups():
       mpr.start_single_process(task_type='evaluator', task_id=0)
 
-    eval_thread = threading.Thread(target=eval_func)
-    eval_thread.start()
+    threading.Thread(target=follow_ups).start()
     mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
-    eval_thread.join()
     list_to_assert = mpr.join().stdout
     for job in ['worker', 'evaluator']:
       for iteration in range(5):
@@ -275,17 +285,5 @@ class MultiProcessRunnerTest(test.TestCase):
           any('{}-0, i: {}'.format(job, iteration) in line
               for line in list_to_assert))
 
-  def test_barrier(self):
-    multi_process_runner.run(
-        proc_func_with_barrier,
-        cluster_spec=multi_worker_test_base.create_cluster_spec(
-            has_chief=True, num_workers=1),
-    )
-
-  def test_barrier_called_in_main_process(self):
-    with self.assertRaises(ValueError):
-      multi_process_runner.barrier()
-
 
 if __name__ == '__main__':
   multi_process_runner.test_main()

tensorflow/python/distribute/multi_worker_continuous_run_test.py (View File)

@@ -127,4 +127,4 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
 
 if __name__ == '__main__':
-  multi_process_runner.test_main()
+  multi_process_runner.test_main(barrier_parties=NUM_WORKERS)
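
The new `barrier_parties` argument pairs with the module-level `barrier()` helper restored in multi_process_runner.py above. A hedged sketch of the intended pattern, assuming the object returned by `barrier()` behaves like a `multiprocessing`-style barrier with a `wait()` method (the worker count and `proc_func` are illustrative):

    NUM_WORKERS = 2

    def proc_func():
      # All NUM_WORKERS subprocesses rendezvous here before continuing.
      multi_process_runner.barrier().wait()

    if __name__ == '__main__':
      multi_process_runner.test_main(barrier_parties=NUM_WORKERS)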

tensorflow/python/keras/distribute/BUILD (View File)

@@ -364,7 +364,7 @@ py_test(
     name = "multi_worker_callback_tf2_test",
     srcs = ["multi_worker_callback_tf2_test.py"],
     python_version = "PY3",
-    shard_count = 5,
+    shard_count = 10,
     deps = [
         "//tensorflow/python/distribute:collective_all_reduce_strategy",
         "//tensorflow/python/distribute:combinations",

tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py (View File)

@@ -208,7 +208,6 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
             callbacks.BackupAndRestore(backup_dir=bar_dir),
             AssertCallback()
         ])
-    multi_process_runner.barrier()
     test_obj.assertFalse(file_io.file_exists(backup_filepath))
     test_obj.assertTrue(file_io.file_exists(saving_filepath))
@@ -344,4 +343,4 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
 
 if __name__ == '__main__':
-  multi_process_runner.test_main()
+  multi_process_runner.test_main(barrier_parties=2)