From c6745778706f434014ab4bcb3fe5926773f6eed1 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Thu, 11 Jun 2020 05:15:16 -0700 Subject: [PATCH] Improve multi_process_runner This is to prepare enabling it for OSS. PiperOrigin-RevId: 315878874 Change-Id: Ib29bccf3c964462a7643df4b1cd011ddda79372b --- tensorflow/opensource_only.files | 2 + tensorflow/python/distribute/BUILD | 11 +- .../python/distribute/multi_process_lib.py | 34 +- .../python/distribute/multi_process_runner.py | 568 +++++++++--------- .../distribute/multi_process_runner_test.py | 58 +- .../multi_worker_continuous_run_test.py | 2 +- tensorflow/python/keras/distribute/BUILD | 2 +- .../multi_worker_callback_tf2_test.py | 2 +- tensorflow/tools/pip_package/BUILD | 2 + tensorflow/workspace.bzl | 22 + third_party/dill.BUILD | 10 + third_party/tblib.BUILD | 11 + 12 files changed, 403 insertions(+), 321 deletions(-) create mode 100644 third_party/dill.BUILD create mode 100644 third_party/tblib.BUILD diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 5f1f2832cc8..3d57e5f2089 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -36,6 +36,7 @@ tensorflow/third_party/coremltools.BUILD tensorflow/third_party/cub.BUILD tensorflow/third_party/curl.BUILD tensorflow/third_party/cython.BUILD +tensorflow/third_party/dill.BUILD tensorflow/third_party/double_conversion.BUILD tensorflow/third_party/eigen.BUILD tensorflow/third_party/eigen3/BUILD @@ -196,6 +197,7 @@ tensorflow/third_party/systemlibs/swig.BUILD tensorflow/third_party/systemlibs/syslibs_configure.bzl tensorflow/third_party/systemlibs/termcolor.BUILD tensorflow/third_party/systemlibs/zlib.BUILD +tensorflow/third_party/tblib.BUILD tensorflow/third_party/tensorrt/BUILD tensorflow/third_party/tensorrt/BUILD.tpl tensorflow/third_party/tensorrt/LICENSE diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 7451d5b0408..fd655aa75d3 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1708,19 +1708,24 @@ cuda_py_test( py_library( name = "multi_process_runner", srcs = ["multi_process_runner.py"], + srcs_version = "PY3", deps = [ ":multi_process_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:tf2", "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/eager:context", + "@absl_py//absl/logging", + "@dill_archive//:dill", "@six_archive//:six", + "@tblib_archive//:tblib", ], ) py_library( name = "multi_process_lib", srcs = ["multi_process_lib.py"], - deps = ["@six_archive//:six"], + deps = ["//tensorflow/python:client_testlib"], ) py_test( @@ -1745,11 +1750,12 @@ py_test( name = "multi_process_runner_test", srcs = ["multi_process_runner_test.py"], python_version = "PY3", - shard_count = 12, deps = [ ":multi_process_runner", ":multi_worker_test_base", "//tensorflow/python/eager:test", + "@absl_py//absl/logging", + "@six_archive//:six", ], ) @@ -1757,6 +1763,7 @@ py_test( name = "multi_process_runner_no_init_test", srcs = ["multi_process_runner_no_init_test.py"], python_version = "PY3", + tags = ["no_oss"], deps = [ ":multi_process_runner", ":multi_worker_test_base", diff --git a/tensorflow/python/distribute/multi_process_lib.py b/tensorflow/python/distribute/multi_process_lib.py index f3b03ca8bc4..12b81db7189 100644 --- a/tensorflow/python/distribute/multi_process_lib.py +++ b/tensorflow/python/distribute/multi_process_lib.py @@ -18,9 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ 
import print_function -import contextlib +import multiprocessing as _multiprocessing import unittest +from tensorflow.python.platform import test + + +try: + multiprocessing = _multiprocessing.get_context('forkserver') +except ValueError: + # forkserver is not available on Windows. + multiprocessing = _multiprocessing.get_context('spawn') + class Process(object): """A process simulating a worker for testing multi-worker training.""" @@ -28,23 +37,14 @@ class Process(object): def __init__(self, *args, **kwargs): del args, kwargs raise unittest.SkipTest( - 'TODO(b/141874796): Implement OSS version of `multi_process_lib`') + 'TODO(b/150264776): Implement OSS version of `multi_process_lib`') -def get_user_data(): - """Returns the data commonly shared by parent process and subprocesses.""" - # TODO(b/141874796): Implement OSS version of `multi_process_lib`. - pass +def test_main(): + """Main function to be called within `__main__` of a test file.""" + test.main() -@contextlib.contextmanager -def context_manager(max_subprocess_count=20, barrier_parties=0): - """No-op in OSS. This exists to maintain testing compatibility.""" - del max_subprocess_count, barrier_parties - yield - - -def using_context_manager(): - """Whether the context manager is being used.""" - raise unittest.SkipTest( - 'TODO(b/141874796): Implement OSS version of `multi_process_lib`') +def initialized(): + """Returns whether the module is initialized.""" + return True diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 1938805ad97..55992bdae2a 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + import collections import contextlib import json @@ -26,6 +27,8 @@ import signal import sys import threading import time +import unittest + from absl import logging import six from six.moves import queue as Queue @@ -34,7 +37,8 @@ from tensorflow.python import tf2 from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import multi_process_lib from tensorflow.python.eager import context -from tensorflow.python.platform import test + +multiprocessing = multi_process_lib.multiprocessing # pylint: disable=g-import-not-at-top try: @@ -43,61 +47,57 @@ try: except ImportError: faulthandler = None +# TODO(b/150264776): Remove after resolving CI issue. +try: + import dill +except ImportError: + dill = None + +# TODO(b/150264776): Remove after resolving CI issue. +try: + import tblib.pickling_support + # For pickling traceback objects. + tblib.pickling_support.install() +except ImportError: + pass + + # _ProcessStatusInfo contains process status information. When is_successful # attribute is True, the subprocess has ended successfully, or if False, the # exception stack trace info is stored in exc_info to pass on to parent process # to be re-raised. _ProcessStatusInfo = collections.namedtuple( - '_ProcessStatusInfo', ['task_type', 'is_successful', 'exc_info']) - -# _SubprocessInfo collects basic information of a subprocess such as task type -# and process id. -_SubprocessInfo = collections.namedtuple('_SubprocessInfo', - ['pid', 'task_type', 'task_id']) + '_ProcessStatusInfo', + ['task_type', 'is_successful', 'exc_info', 'return_value']) # Information returned from a successful MultiProcessRunner run. 
MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult', ['return_value', 'stdout']) -# Process status queue is used by `multi_process_runner` internally for -# communication from subprocesses to the parent process for whether it's been -# successful, and if not what the error stack trace is. -PROCESS_STATUS_QUEUE = 'process_status_queue' +TestEnvironment = collections.namedtuple('TestEnvironment', [ + 'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast', + 'v2_enabled', 'executing_eagerly' +]) -# Return value queue is intended to be used by users of `multi_process_runner` -# for the process function to return information to the caller of -# `multi_process_runner.run()`. -RETURN_VALUE_QUEUE = 'return_value_queue' - -# Subprocess info queue stores `_SubprocessInfo` for later potential -# termination by the parent. -SUBPROCESS_INFO_QUEUE = 'subprocess_info_queue' - -# Parent-to-sub queue is used for communications from parent to subprocess. -# Currently this is only used to terminate subprocesses. +# Resources for communication between worker processes and the main process. +# +# `process_status_queue` is used by `multi_process_runner` internally for +# communication from subprocesses to the parent process for whether it's been +# successful, and if not what the error stack trace is. +# `parent_to_sub_queue` is used for communications from parent to subprocess. +# Currently this is only used to terminate subprocesses. # TODO(rchao): Remove this once subprocess is terminated by SIGKILL. -PARENT_TO_SUB_QUEUE = 'parent_to_sub_queue' - -# Streaming queue stores the logged and printed messages from subprocesses. -STREAMING_QUEUE = 'streaming_queue' - -# Pipes to stream stdout and stderr from subprocesses to parent process. -STREAMING_PIPE = 'streaming_pipe' - -# Barrier identifier. -BARRIER = 'barrier' - -_DEFAULT_MAX_SUBPROCESS_COUNT = 20 +# `streaming_pipe_w` is to stream stdout and stderr from subprocesses to parent +# process. +# `barrier` is a barrier for the party of all subprocesses. +Resources = collections.namedtuple('Resources', [ + 'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier' +]) # Default time out sec is selected so that it's handled before the default # "medium" timeout of the test runs. _DEFAULT_TIMEOUT_SEC = 200 -# Next pipe index to be global so that pipes are not reused across multiple -# MultiProcessRunner usages. -# TODO(rchao): Investigate possibility to remove this variable. -_next_pipe_index = 0 - class MultiProcessRunner(object): """A utility class to start multiple processes to simulate a cluster. @@ -123,6 +123,7 @@ class MultiProcessRunner(object): grpc_fail_fast=None, stream_stdout=True, list_stdout=False, + use_dill_for_args=True, args=None, kwargs=None): """Creates a multi-process runner. @@ -153,6 +154,9 @@ class MultiProcessRunner(object): returned from `MultiProcessRunner.join()`. If True, the list of stdout can be retrieved via `MultiProcessRunnerResult.stdout` attribute. Defaults to False. + use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill + can pickle more objects, but doesn't work with types in + `multiprocessing` library like `Mutex`. args: Positional arguments to be sent to functions run on processes. kwargs: Keyword arguments to be sent to functions run on processes. @@ -165,15 +169,14 @@ class MultiProcessRunner(object): raise ValueError('If chief exists in the cluster, there must be at most ' 'one chief. Current `cluster_spec` has {} chiefs.' 
.format(len(cluster_spec['chief']))) - - assert callable(proc_func) - - if not multi_process_lib.using_context_manager(): + if not multi_process_lib.initialized(): raise RuntimeError('`multi_process_runner` is not initialized. ' 'Please call `multi_process_runner.test_main()` ' 'within `if __name__ == \'__main__\':` block ' 'in your python module to properly initialize ' '`multi_process_runner`.') + if not callable(proc_func): + raise ValueError('proc_func is not a callable') self._proc_func = proc_func self._cluster_spec = cluster_spec @@ -184,62 +187,90 @@ class MultiProcessRunner(object): # TODO(rchao): Revisit list_stdout argument to consider other solution. self._list_stdout = list_stdout self._dependence_on_chief = True + self._use_dill_for_args = use_dill_for_args self._args = args or () self._kwargs = kwargs or {} - self._outstanding_subprocess_count = 0 - # Child processes should have the same v2 and eager behavior. self._v2_enabled = tf2.enabled() self._executing_eagerly = context.executing_eagerly() + self._joined = False + self._processes = {} + self._outstanding_subprocess_count = 0 + self._reading_threads = [] + + self._manager = multiprocessing.Manager() + self._process_status_queue = self._manager.Queue() + self._parent_to_sub_queue = self._manager.Queue() + parties = sum(len(addresses) for addresses in self._cluster_spec.values()) + self._barrier = self._manager.Barrier(parties) + + # We use a queue to collect outputs from worker processes since it's thread + # safe. + self._streaming_queue = self._manager.Queue() + # This flag will be set to True once terminate_all() is called. self._all_forced_terminated = False def _continuously_readline_from_sub(self, pipe_r, task_type, task_id): """Function to continuously read lines from subprocesses.""" - reader = os.fdopen(pipe_r.fileno(), 'r') - while True: - read_line = reader.readline() - if read_line == 'EOF': - reader.close() - # The thread that runs `_continuously_readline_from_sub` stops here. - # However the threads don't exit until the test exits, so we do not - # attempt to join the threads (which leads to timeout). - # TODO(rchao): Understand why and do thread joining. - break - task_string = '[{}-{}]:'.format(task_type, task_id) - formatted_line = '{} {}'.format(task_string.ljust(14), read_line) - if self._stream_stdout: - self._print_stdout_in_parent(formatted_line, task_type, task_id) - if self._list_stdout: - self._add_stdout_in_queue(formatted_line, task_type, task_id) + with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader: + for line in reader: + task_string = '[{}-{}]:'.format(task_type, task_id) + formatted_line = '{} {}'.format(task_string.ljust(14), line) + if self._stream_stdout: + # TODO(rchao): Use a lock here to ensure the printed lines are not + # broken. + print(formatted_line, end='', flush=True) + if self._list_stdout: + self._streaming_queue.put(formatted_line) - def _print_stdout_in_parent(self, formatted_line, task_type, task_id): - del task_type, task_id - # Flush True so the logging order from subprocesses is respected. - # TODO(rchao): Use a lock here to ensure the printed lines are not broken. - print(formatted_line, end='', flush=True) - - def _add_stdout_in_queue(self, formatted_line, task_type, task_id): - del task_type, task_id - # A queue instead of a simple list is used here due to b/150652733. 
- _resource(STREAMING_QUEUE).put(formatted_line) - - def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id, - cluster_spec, args, kwargs): + def _start_subprocess_and_reading_thread(self, + task_type, + task_id, + cluster_spec=None, + proc_func=None, + args=None, + kwargs=None): """Start a subprocess and a thread the reads lines from the subprocess.""" - global _next_pipe_index - pipe_r, pipe_w = _resource(STREAMING_PIPE)[_next_pipe_index] - _next_pipe_index += 1 - p = multi_process_lib.Process( - target=_Subprocess(), - args=(proc_func, task_type, task_id, cluster_spec, self._rpc_layer, - self._grpc_fail_fast, self._v2_enabled, self._executing_eagerly, - pipe_w) + args, - kwargs=kwargs) + if dill is None: + raise unittest.SkipTest( + 'TODO(b/150264776): Resolve dependency issue in CI') + + test_env = TestEnvironment( + task_type=task_type, + task_id=task_id, + cluster_spec=cluster_spec or self._cluster_spec, + rpc_layer=self._rpc_layer, + grpc_fail_fast=self._grpc_fail_fast, + v2_enabled=self._v2_enabled, + executing_eagerly=self._executing_eagerly, + ) + pipe_r, pipe_w = multiprocessing.Pipe(duplex=False) + resources = Resources( + process_status_queue=self._process_status_queue, + parent_to_sub_queue=self._parent_to_sub_queue, + streaming_pipe_w=pipe_w, + barrier=self._barrier, + ) + if proc_func is None: + proc_func, args, kwargs = self._proc_func, self._args, self._kwargs + # Always use dill to pickle proc_func so that we support more callable + # types, e.g. lambda. + proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL) + if self._use_dill_for_args: + args = dill.dumps(args, dill.HIGHEST_PROTOCOL) + kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL) + + p = _Process( + test_env=test_env, + target=_ProcFunc(), + args=(resources, test_env, proc_func, args, kwargs, + self._use_dill_for_args)) p.start() + self._processes[(task_type, task_id)] = p self._outstanding_subprocess_count += 1 # For each subprocess, we dedicate a thread continuously reading lines @@ -248,18 +279,15 @@ class MultiProcessRunner(object): target=self._continuously_readline_from_sub, args=(pipe_r, task_type, task_id)) thread.start() + self._reading_threads.append(thread) def start(self): """Starts processes, one for each task in `cluster_spec`.""" - - global _next_pipe_index - self._starting_pipe_index = _next_pipe_index - + if self._processes: + raise ValueError('MultiProcessRunner already started.') for task_type, addresses in self._cluster_spec.items(): for task_id, _ in enumerate(addresses): - self._start_subprocess_and_reading_thread(self._proc_func, task_type, - task_id, self._cluster_spec, - self._args, self._kwargs) + self._start_subprocess_and_reading_thread(task_type, task_id) # TODO(rchao): Remove the need of using SIGALRM if possible. At this time, # without this the tests become very flaky. @@ -309,33 +337,22 @@ class MultiProcessRunner(object): as_task_type: The task type to be run in the main process. as_task_id: The task id to be run in the main process. 
""" - global _next_pipe_index - self._starting_pipe_index = _next_pipe_index - + if self._processes: + raise ValueError('MultiProcessRunner already started.') for task_type, addresses in self._cluster_spec.items(): for task_id, _ in enumerate(addresses): if not (task_type == as_task_type and task_id == as_task_id): - self._start_subprocess_and_reading_thread(self._proc_func, task_type, - task_id, self._cluster_spec, - self._args, self._kwargs) - tf_config_dict = { - 'cluster': self._cluster_spec, - 'task': { - 'type': as_task_type, - 'index': as_task_id, - }, - } - if self._rpc_layer is not None: - tf_config_dict['rpc_layer'] = self._rpc_layer - os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) + self._start_subprocess_and_reading_thread(task_type, task_id) + _set_tf_config(as_task_type, as_task_id, self._cluster_spec, + self._rpc_layer) self._proc_func(*self._args, **self._kwargs) def start_single_process(self, task_type, task_id, - proc_func=None, cluster_spec=None, + proc_func=None, args=None, kwargs=None): """Starts a single process. @@ -352,19 +369,22 @@ class MultiProcessRunner(object): Args: task_type: The task type. task_id: The task id. - proc_func: The process function to be run on the newly started - process. If `None`, the function provided at `__init__` will be used. cluster_spec: The cluster spec to be used on the newly started process. If `None`, the cluster spec provided at `__init__` will be used. + proc_func: The process function to be run on the newly started + process. If specified, specify `args` and `kwargs` as well. If `None`, + the function provided at `__init__` will be used. args: Optional positional arguments to be supplied in `proc_func`. kwargs: Optional keyword arguments to be supplied in `proc_func`. """ - cluster_spec = cluster_spec or self._cluster_spec - proc_func = proc_func or self._proc_func - self._start_subprocess_and_reading_thread(proc_func, task_type, task_id, - cluster_spec, args or (), - kwargs or {}) + self._start_subprocess_and_reading_thread( + task_type, + task_id, + cluster_spec=cluster_spec, + proc_func=proc_func, + args=args or (), + kwargs=kwargs or {}) def _queue_to_list(self, queue_to_convert): """Convert `queue.Queue` to `list`.""" @@ -379,25 +399,20 @@ class MultiProcessRunner(object): def get_process_id(self, task_type, task_id): """Returns the subprocess id given the task type and task id.""" - if not hasattr(self, '_pid_dict'): - self._pid_dict = {} - subprocess_infos = [] + p = self._processes.get((task_type, task_id), None) + return p.pid if p else None - while True: - try: - subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False) - subprocess_infos.append(subprocess_info) - except Queue.Empty: - break - - for subprocess_info in subprocess_infos: - self._pid_dict[(subprocess_info.task_type, - subprocess_info.task_id)] = subprocess_info.pid - - for subprocess_info in subprocess_infos: - _resource(SUBPROCESS_INFO_QUEUE).put(subprocess_info) - - return self._pid_dict.get((task_type, task_id), None) + def _join_or_terminate(self, task_type, task_id, process, timeout): + """Joins a process. If it times out, terminate all procsses.""" + logging.info('joining %s-%d', task_type, task_id) + process.join(timeout) + # If exitcode is None, the process aren't terminated and this is a + # timeout. + if process.exitcode is None: + # Force termination to dump worker processes stack trace. + self.terminate_all(sig=signal.SIGTERM) + raise RuntimeError('%s-%d and possibly more subprocesses timed out.' 
% + (task_type, task_id)) def join(self, timeout=_DEFAULT_TIMEOUT_SEC): """Joins all the processes with timeout. @@ -417,84 +432,97 @@ class MultiProcessRunner(object): RuntimeError: if not all processes report status approximatelty within `timeout` seconds, or there's an exception propagated from any subprocess. """ + if self._joined: + raise ValueError("MultiProcessRunner can't be joined twice.") + self._joined = True - if not timeout: - timeout = float('inf') - start_time = time.time() - while self._outstanding_subprocess_count > 0: - try: - process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10) + chief = self._processes.get(('chief', 0), None) + if self._dependence_on_chief and chief: + self._join_or_terminate('chief', 0, chief, timeout) + # Give other processes a chance to exit on their own. + for p in self._processes.values(): + p.join(timeout=3) + self.terminate_all() + else: + for (task_type, task_id), p in self._processes.items(): + self._join_or_terminate(task_type, task_id, p, timeout) - self._outstanding_subprocess_count -= 1 - assert isinstance(process_status, _ProcessStatusInfo) - if not process_status.is_successful: - six.reraise(*process_status.exc_info) + for (task_type, task_id), p in self._processes.items(): + logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode) - if self._dependence_on_chief and process_status.task_type == 'chief': - self.terminate_all() - break - except Queue.Empty: - if self._all_forced_terminated: - break - if time.time() - start_time > timeout: - # Send SIGTERM signal to subprocesses to dump their current - # stack trace. - self.terminate_all(sig=signal.SIGTERM) - # If none of those did, report timeout to user. - raise RuntimeError('One or more subprocesses timed out. ' - 'Number of outstanding subprocesses ' - 'is %d.' % self._outstanding_subprocess_count) + process_statuses = self._queue_to_list(self._process_status_queue) + if not self._all_forced_terminated and len( + process_statuses) != self._outstanding_subprocess_count: + raise RuntimeError( + 'missing statuses from %d subproceses.' % + (self._outstanding_subprocess_count - len(process_statuses))) + return_values = [] + for process_status in process_statuses: + assert isinstance(process_status, _ProcessStatusInfo) + if not process_status.is_successful: + six.reraise(*process_status.exc_info) + if process_status.return_value is not None: + return_values.append(process_status.return_value) - # Giving threads some time to finish the message reading from subprocesses. - time.sleep(5) + logging.info('Joining log reading threads.') + for thread in self._reading_threads: + thread.join() + logging.info('Joined log reading threads.') - stdout = self._queue_to_list(_resource(STREAMING_QUEUE)) - return_value = self._queue_to_list(_resource(RETURN_VALUE_QUEUE)) + # Clear the alarm. + signal.alarm(0) - # Notifying the threads that are reading lines that we should stop. - for pipe_index in range(self._starting_pipe_index, _next_pipe_index): # pylint: disable=protected-access - _, pipe_w = _resource(STREAMING_PIPE)[pipe_index] - writer = os.fdopen(pipe_w.fileno(), 'w') - # Writing end of file message so the threads that's actively reading lines - # know to stop. 
- writer.writelines(['EOF']) - writer.close() + stdout = self._queue_to_list(self._streaming_queue) - return MultiProcessRunnerResult(stdout=stdout, return_value=return_value) + return MultiProcessRunnerResult(stdout=stdout, return_value=return_values) def terminate(self, task_type, task_id): """Terminates the process with `task_type` and `task_id`.""" - _resource(PARENT_TO_SUB_QUEUE).put('terminate {} {}'.format( - task_type, task_id)) + p = self._processes.get((task_type, task_id), None) + if p is None: + raise ValueError('{}-{} does not exist'.format(task_type, task_id)) + # TODO(crccw): change to use Process.terminate() as well. + self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id)) + p.join() def terminate_all(self, sig=None): """Terminates all subprocesses.""" # Use SIGKILL as default. In systems where that's unavailable such as # windows, use SIGTERM. sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM) - subprocess_infos = [] - - while True: + for (task_type, task_id), p in self._processes.items(): try: - subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False) - subprocess_infos.append(subprocess_info) - except Queue.Empty: - break - - for subprocess_info in subprocess_infos: - logging.info('Parent process is now killing PID: %d', subprocess_info.pid) - try: - os.kill(subprocess_info.pid, sig) + os.kill(p.pid, sig) except ProcessLookupError: - # TODO(rchao): Remove subprocess info from the queue once a subprocess - # is terminated. - logging.info('PID %d does not exist.', subprocess_info.pid) - + logging.info('Attempting to kill %s-%d but it does not exist.', + task_type, task_id) self._all_forced_terminated = True -class _Subprocess(object): - """Represents an internal subprocess used in MultiProcessRunner's context.""" +class _Process(multi_process_lib.Process): + """A modified `multiprocessing.Process` that can set up environment variables.""" + + # TODO(crccw): consider moving other logics in _ProcFunc to _Process. + + def __init__(self, test_env, **kwargs): + super(_Process, self).__init__(**kwargs) + self._test_env = test_env + self._actual_run = getattr(self, 'run') + self.run = self._run_with_setenv + + def _run_with_setenv(self): + # We need to set environment variables before doing anything because + # setenv() is not thread-safe. + test_env = self._test_env + if test_env.grpc_fail_fast is not None: + os.environ['GRPC_FAIL_FAST'] = str(test_env.grpc_fail_fast) + _set_tf_config(test_env.task_type, test_env.task_id, test_env.cluster_spec, + test_env.rpc_layer) + return self._actual_run() + + +class _ProcFunc(object): + """Represents a callable to run in a subprocess.""" @contextlib.contextmanager def _runtime_mode(self, executing_eagerly): @@ -505,21 +533,12 @@ class _Subprocess(object): with context.graph_mode(): yield - def _finish_process(self, process_status_info, return_value): - """Adds data to queues before program exits.""" - # Clear the alarm. - signal.alarm(0) - - if return_value is not None: - self._add_return_data(return_value) - _resource(PROCESS_STATUS_QUEUE).put(process_status_info) - def _message_checking_func(self, task_type, task_id): """A function that regularly checks messages from parent process.""" # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess. while True: try: - message = _resource(PARENT_TO_SUB_QUEUE).get(block=False) + message = self._resources.parent_to_sub_queue.get(block=False) # Currently the only possible message is termination. 
if not message.startswith('terminate'): @@ -530,63 +549,75 @@ class _Subprocess(object): else: # If the message is not targeting this process, put it back to the # queue. - _resource(PARENT_TO_SUB_QUEUE).put(message) + self._resources.parent_to_sub_queue.put(message) time.sleep(1) except Queue.Empty: time.sleep(0.1) - self._finish_process( + self._resources.process_status_queue.put( _ProcessStatusInfo( - task_type=task_type, is_successful=True, exc_info=None), None) + task_type=task_type, + is_successful=True, + exc_info=None, + return_value=None)) # `os._exit(0)` is used to more reliably terminate a subprocess. os._exit(0) # pylint: disable=protected-access - def __call__(self, proc_func, task_type, task_id, per_process_cluster_spec, - rpc_layer, grpc_fail_fast, v2_enabled, executing_eagerly, pipe_w, - *arg, **kwargs): + def _close_streaming(self): + """Close stdout, stderr and streaming pipe. + + We need to explicitly close them since Tensorflow may take a while to exit, + so that the reading threads in the main process can exit more quickly. + """ + sys.stdout.flush() + sys.stderr.flush() + sys.stdout.close() + sys.stderr.close() + self._resources.streaming_pipe_w.close() + + def __call__(self, resources, test_env, proc_func, args, kwargs, + use_dill_for_args): """The wrapper function that actually gets run in child process(es).""" + global _barrier + + self._resources = resources + _barrier = self._resources.barrier + proc_func = dill.loads(proc_func) + if use_dill_for_args: + args = dill.loads(args) + kwargs = dill.loads(kwargs) + if faulthandler is not None: faulthandler.enable() faulthandler.register(signal.SIGTERM, chain=True) + # All logging should go to stderr to be streamed to the main process. + logging.set_stderrthreshold(logging.DEBUG) + + # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so + # print() and logging.*() write directly to `streaming_pipe_w`. + # Unfortunately since we cannot prepend task_type and task_id information to + # the streamed logs we will need a thread per subprocess to distinguish + # where the piece of message is from. + os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno()) + os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno()) + pid = os.getpid() logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid, - task_type, task_id) - _resource(SUBPROCESS_INFO_QUEUE).put( - _SubprocessInfo(pid=pid, task_type=task_type, task_id=task_id)) - # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print() and - # logging.*() write directly to `pipe_w`. Unfortunately since we cannot - # prepend task_type and task_id information to the streamed logs we will - # need a thread per subprocess to distinguish where the piece of message is - # from. - os.dup2(pipe_w.fileno(), sys.stdout.fileno()) - os.dup2(pipe_w.fileno(), sys.stderr.fileno()) + test_env.task_type, test_env.task_id) # The thread will be dedicated to checking messages from the parent process. 
threading.Thread( # pylint: disable=unexpected-keyword-arg target=self._message_checking_func, - args=(task_type, task_id), + args=(test_env.task_type, test_env.task_id), daemon=True).start() - if grpc_fail_fast is not None: - os.environ['GRPC_FAIL_FAST'] = str(grpc_fail_fast) - tf_config_dict = { - 'cluster': per_process_cluster_spec, - 'task': { - 'type': task_type, - 'index': task_id, - }, - } - if rpc_layer is not None: - tf_config_dict['rpc_layer'] = rpc_layer - os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) - - if v2_enabled: + if test_env.v2_enabled: v2_compat.enable_v2_behavior() try: - with self._runtime_mode(executing_eagerly): - return_value = proc_func(*arg, **kwargs) + with self._runtime_mode(test_env.executing_eagerly): + return_value = proc_func(*args, **kwargs) is_successful = True exc_info = None @@ -606,35 +637,27 @@ class _Subprocess(object): raise finally: - self._finish_process( - _ProcessStatusInfo( - task_type=task_type, - is_successful=is_successful, - exc_info=exc_info), - return_value) - - def _add_return_data(self, data): - """Adds return data that will be returned by `join`. - - The function provides a way for child processes to communicate with the - parent process. Data passed to `_add_return_data` will be available in a - Python Queue.Queue that is eventually returned by `join`. - - Args: - data: data to be made available in the queue returned by `join`. - """ - # TODO(rchao): Incorporate the task type and id information in a data - # wrapper that becomes what is stored in the queue so we can tell where - # the data is from. - _resource(RETURN_VALUE_QUEUE).put(data) + info = _ProcessStatusInfo( + task_type=test_env.task_type, + is_successful=is_successful, + exc_info=exc_info, + return_value=return_value) + self._resources.process_status_queue.put(info) + self._close_streaming() -def barrier(): - return multi_process_lib.get_user_data()[BARRIER] - - -def _resource(resource_name): - return multi_process_lib.get_user_data()[resource_name] +def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None): + """Set TF_CONFIG environment variable.""" + tf_config_dict = { + 'cluster': cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id, + }, + } + if rpc_layer is not None: + tf_config_dict['rpc_layer'] = rpc_layer + os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) def run(proc_func, @@ -670,16 +693,19 @@ def run(proc_func, return runner.join(timeout) -def test_main(max_subprocess_count=_DEFAULT_MAX_SUBPROCESS_COUNT, - barrier_parties=0): - """Main function to be called within `__main__` of a test file. +# This is set by MultiProcessRunner in worker processes. +_barrier = None - Args: - max_subprocess_count: Maximum number of subprocesses that will be used. User - of multi_process_runner needs to determine a number at calling this - method, and the subprocesses involved later should not exceed this number. - barrier_parties: Number of parties the barrier will be used toward. User of - multi_process_runner needs to determine a number at calling this method. - """ - with multi_process_lib.context_manager(max_subprocess_count, barrier_parties): - test.main() + +def barrier(): + if _barrier is None: + raise ValueError( + 'barrier is not defined. It is likely because you are calling barrier()' + 'in the main process. barrier() can only be called in the subprocesses.' 
+ ) + return _barrier + + +def test_main(): + """Main function to be called within `__main__` of a test file.""" + multi_process_lib.test_main() diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index 69e84581af3..924f0cb5ffe 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -23,15 +23,13 @@ import os import threading import time from absl import logging -from six.moves import queue as Queue from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.eager import test -def proc_func_that_adds_task_type_in_return_data(test_obj, val): - test_obj.assertEqual(val, 3) +def proc_func_that_adds_task_type_in_return_data(): return multi_worker_test_base.get_task_type() @@ -51,6 +49,10 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs): return list(args) + list(kwargs.items()) +def proc_func_with_barrier(): + return multi_process_runner.barrier() + + class MultiProcessRunnerTest(test.TestCase): def _worker_idx(self): @@ -61,8 +63,7 @@ class MultiProcessRunnerTest(test.TestCase): mpr_result = multi_process_runner.run( proc_func_that_adds_task_type_in_return_data, multi_worker_test_base.create_cluster_spec( - num_workers=2, num_ps=3, has_eval=1), - args=(self, 3)) + num_workers=2, num_ps=3, has_eval=1)) job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1} for data in mpr_result.return_value: @@ -124,36 +125,22 @@ class MultiProcessRunnerTest(test.TestCase): def test_process_that_exits(self): - def func_to_exit_in_15_sec(): - time.sleep(5) - print('foo', flush=True) - time.sleep(20) - print('bar', flush=True) + def func_to_exit_in_5_sec(): + logging.error('foo') + time.sleep(10) + logging.error('bar') mpr = multi_process_runner.MultiProcessRunner( - func_to_exit_in_15_sec, + func_to_exit_in_5_sec, multi_worker_test_base.create_cluster_spec(num_workers=1), list_stdout=True, - max_run_time=15) + max_run_time=5) mpr.start() stdout = mpr.join().stdout self.assertLen([msg for msg in stdout if 'foo' in msg], 1) self.assertLen([msg for msg in stdout if 'bar' in msg], 0) - def test_signal_doesnt_fire_after_process_exits(self): - mpr = multi_process_runner.MultiProcessRunner( - proc_func_that_does_nothing, - multi_worker_test_base.create_cluster_spec(num_workers=1), - max_run_time=10) - mpr.start() - mpr.join() - with self.assertRaisesRegexp(Queue.Empty, ''): - # If the signal was fired, another message would be added to internal - # queue, so verifying it's empty. 
- multi_process_runner._resource( - multi_process_runner.PROCESS_STATUS_QUEUE).get(block=False) - def test_termination(self): def proc_func(): @@ -192,7 +179,7 @@ class MultiProcessRunnerTest(test.TestCase): multi_worker_test_base.create_cluster_spec(num_workers=2), list_stdout=True) mpr.start() - time.sleep(5) + time.sleep(3) mpr.terminate('worker', 0) mpr.start_single_process('worker', 0) std_stream_results = mpr.join().stdout @@ -273,11 +260,14 @@ class MultiProcessRunnerTest(test.TestCase): has_chief=True, num_workers=1), list_stdout=True) - def follow_ups(): + def eval_func(): + time.sleep(1) mpr.start_single_process(task_type='evaluator', task_id=0) - threading.Thread(target=follow_ups).start() + eval_thread = threading.Thread(target=eval_func) + eval_thread.start() mpr.start_in_process_as(as_task_type='chief', as_task_id=0) + eval_thread.join() list_to_assert = mpr.join().stdout for job in ['worker', 'evaluator']: for iteration in range(5): @@ -291,9 +281,21 @@ class MultiProcessRunnerTest(test.TestCase): multi_worker_test_base.create_cluster_spec(num_workers=2), list_stdout=True) mpr.start() + time.sleep(3) mpr.terminate_all() with self.assertRaisesRegexp(ValueError, 'This is an error.'): mpr.join() + def test_barrier(self): + multi_process_runner.run( + proc_func_with_barrier, + cluster_spec=multi_worker_test_base.create_cluster_spec( + has_chief=True, num_workers=1), + ) + + def test_barrier_called_in_main_process(self): + with self.assertRaises(ValueError): + multi_process_runner.barrier() + if __name__ == '__main__': multi_process_runner.test_main() diff --git a/tensorflow/python/distribute/multi_worker_continuous_run_test.py b/tensorflow/python/distribute/multi_worker_continuous_run_test.py index 437255c1015..14e0564874b 100644 --- a/tensorflow/python/distribute/multi_worker_continuous_run_test.py +++ b/tensorflow/python/distribute/multi_worker_continuous_run_test.py @@ -127,4 +127,4 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): if __name__ == '__main__': - multi_process_runner.test_main(barrier_parties=NUM_WORKERS) + multi_process_runner.test_main() diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index f989d93e82e..ddf274f299f 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -364,7 +364,7 @@ py_test( name = "multi_worker_callback_tf2_test", srcs = ["multi_worker_callback_tf2_test.py"], python_version = "PY3", - shard_count = 10, + shard_count = 5, deps = [ "//tensorflow/python/distribute:collective_all_reduce_strategy", "//tensorflow/python/distribute:combinations", diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py index 9c7756e3f5c..5de3ff0f613 100644 --- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py @@ -345,4 +345,4 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): if __name__ == '__main__': - multi_process_runner.test_main(barrier_parties=2) + multi_process_runner.test_main() diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 36e20408c53..5e7cf143bba 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -183,6 +183,7 @@ filegroup( "@com_google_protobuf//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@curl//:COPYING", + 
"@dill_archive//:LICENSE", "@dlpack//:LICENSE", "@double_conversion//:LICENSE", "@eigen_archive//:COPYING.MPL2", @@ -212,6 +213,7 @@ filegroup( "@six_archive//:LICENSE", "@snappy//:COPYING", "@sobol_data//:LICENSE", + "@tblib_archive//:LICENSE", "@termcolor_archive//:COPYING.txt", "@zlib//:zlib.h", "@clog//:LICENSE", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 1cd802f80af..cc6cb5033f7 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -532,6 +532,28 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ], ) + tf_http_archive( + name = "dill_archive", + build_file = clean_dep("//third_party:dill.BUILD"), + urls = [ + "http://mirror.tensorflow.org/files.pythonhosted.org/packages/c7/11/345f3173809cea7f1a193bfbf02403fff250a3360e0e118a1630985e547d/dill-0.3.1.1.tar.gz", + "https://files.pythonhosted.org/packages/c7/11/345f3173809cea7f1a193bfbf02403fff250a3360e0e118a1630985e547d/dill-0.3.1.1.tar.gz", + ], + sha256 = "42d8ef819367516592a825746a18073ced42ca169ab1f5f4044134703e7a049c", + strip_prefix = "dill-0.3.1.1", + ) + + tf_http_archive( + name = "tblib_archive", + build_file = clean_dep("//third_party:tblib.BUILD"), + urls = [ + "http://mirror.tensorflow.org/files.pythonhosted.org/packages/ec/c4/8c651f3240a73c28a218194f3d527eb2be5a173d08501060cdee84ade33f/tblib-1.3.2.tar.gz", + "https://files.pythonhosted.org/packages/ec/c4/8c651f3240a73c28a218194f3d527eb2be5a173d08501060cdee84ade33f/tblib-1.3.2.tar.gz", + ], + sha256 = "436e4200e63d92316551179dc540906652878df4ff39b43db30fcf6400444fe7", + strip_prefix = "tblib-1.3.2", + ) + filegroup_external( name = "org_python_license", licenses = ["notice"], # Python 2.0 diff --git a/third_party/dill.BUILD b/third_party/dill.BUILD new file mode 100644 index 00000000000..61eb841c64f --- /dev/null +++ b/third_party/dill.BUILD @@ -0,0 +1,10 @@ +licenses(["notice"]) # BSD 3-clause + +exports_files(["LICENSE"]) + +py_library( + name = "dill", + srcs = glob(["dill/*.py"]), + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) diff --git a/third_party/tblib.BUILD b/third_party/tblib.BUILD new file mode 100644 index 00000000000..baed5c5ea66 --- /dev/null +++ b/third_party/tblib.BUILD @@ -0,0 +1,11 @@ +licenses(["notice"]) # BSD + +exports_files(["LICENSE"]) + +py_library( + name = "tblib", + srcs = glob(["src/tblib/*.py"]), + imports = ["src"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +)