From 5d7c5237ab9d93c35694ea4e1435421a5efd6a77 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Tue, 2 Jun 2020 08:43:16 -0700 Subject: [PATCH] Improve multi_process_runner This is to prepare enabling it for OSS. PiperOrigin-RevId: 314337449 Change-Id: I64d2498ceac4f78638f0cae429d073600eb9a1e1 --- tensorflow/python/distribute/BUILD | 9 +- .../python/distribute/multi_process_lib.py | 32 +- .../python/distribute/multi_process_runner.py | 532 +++++++++--------- .../distribute/multi_process_runner_test.py | 58 +- .../multi_worker_continuous_run_test.py | 2 +- tensorflow/python/keras/distribute/BUILD | 2 +- .../multi_worker_callback_tf2_test.py | 3 +- 7 files changed, 302 insertions(+), 336 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 140b3089e32..26027d46c98 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1695,14 +1695,11 @@ cuda_py_test( py_library( name = "multi_process_runner", srcs = ["multi_process_runner.py"], - srcs_version = "PY3", deps = [ ":multi_process_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:tf2", "//tensorflow/python/compat:v2_compat", - "//tensorflow/python/eager:context", - "@absl_py//absl/logging", "@six_archive//:six", ], ) @@ -1710,19 +1707,18 @@ py_library( py_library( name = "multi_process_lib", srcs = ["multi_process_lib.py"], - deps = ["//tensorflow/python:client_testlib"], + deps = ["@six_archive//:six"], ) py_test( name = "multi_process_runner_test", srcs = ["multi_process_runner_test.py"], python_version = "PY3", + shard_count = 12, deps = [ ":multi_process_runner", ":multi_worker_test_base", "//tensorflow/python/eager:test", - "@absl_py//absl/logging", - "@six_archive//:six", ], ) @@ -1730,7 +1726,6 @@ py_test( name = "multi_process_runner_no_init_test", srcs = ["multi_process_runner_no_init_test.py"], python_version = "PY3", - tags = ["no_oss"], deps = [ ":multi_process_runner", ":multi_worker_test_base", diff --git a/tensorflow/python/distribute/multi_process_lib.py b/tensorflow/python/distribute/multi_process_lib.py index ae9aa494062..f3b03ca8bc4 100644 --- a/tensorflow/python/distribute/multi_process_lib.py +++ b/tensorflow/python/distribute/multi_process_lib.py @@ -18,18 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import multiprocessing as _multiprocessing +import contextlib import unittest -from tensorflow.python.platform import test - - -try: - multiprocessing = _multiprocessing.get_context('forkserver') -except ValueError: - # forkserver is not available on Windows. - multiprocessing = _multiprocessing.get_context('spawn') - class Process(object): """A process simulating a worker for testing multi-worker training.""" @@ -40,11 +31,20 @@ class Process(object): 'TODO(b/141874796): Implement OSS version of `multi_process_lib`') -def test_main(): - """Main function to be called within `__main__` of a test file.""" - test.main() +def get_user_data(): + """Returns the data commonly shared by parent process and subprocesses.""" + # TODO(b/141874796): Implement OSS version of `multi_process_lib`. + pass -def initialized(): - """Returns whether the module is initialized.""" - return True +@contextlib.contextmanager +def context_manager(max_subprocess_count=20, barrier_parties=0): + """No-op in OSS. This exists to maintain testing compatibility.""" + del max_subprocess_count, barrier_parties + yield + + +def using_context_manager(): + """Whether the context manager is being used.""" + raise unittest.SkipTest( + 'TODO(b/141874796): Implement OSS version of `multi_process_lib`') diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 924e7b85b12..e258d98fb8c 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - import collections import contextlib import json @@ -27,19 +26,15 @@ import signal import sys import threading import time - from absl import logging -import dill import six from six.moves import queue as Queue -import tblib.pickling_support from tensorflow.python import tf2 from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import multi_process_lib from tensorflow.python.eager import context - -multiprocessing = multi_process_lib.multiprocessing +from tensorflow.python.platform import test # pylint: disable=g-import-not-at-top try: @@ -48,45 +43,61 @@ try: except ImportError: faulthandler = None -# For pickling traceback objects. -tblib.pickling_support.install() - # _ProcessStatusInfo contains process status information. When is_successful # attribute is True, the subprocess has ended successfully, or if False, the # exception stack trace info is stored in exc_info to pass on to parent process # to be re-raised. _ProcessStatusInfo = collections.namedtuple( - '_ProcessStatusInfo', - ['task_type', 'is_successful', 'exc_info', 'return_value']) + '_ProcessStatusInfo', ['task_type', 'is_successful', 'exc_info']) + +# _SubprocessInfo collects basic information of a subprocess such as task type +# and process id. +# TODO(rchao): Include task_type and task_id in subprocess info. +_SubprocessInfo = collections.namedtuple('_SubprocessInfo', ['pid']) # Information returned from a successful MultiProcessRunner run. MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult', ['return_value', 'stdout']) -TestEnvironment = collections.namedtuple('TestEnvironment', [ - 'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast', - 'v2_enabled', 'executing_eagerly' -]) +# Process status queue is used by `multi_process_runner` internally for +# communication from subprocesses to the parent process for whether it's been +# successful, and if not what the error stack trace is. +PROCESS_STATUS_QUEUE = 'process_status_queue' -# Resources for communication between worker processes and the main process. -# -# `process_status_queue` is used by `multi_process_runner` internally for -# communication from subprocesses to the parent process for whether it's been -# successful, and if not what the error stack trace is. -# `parent_to_sub_queue` is used for communications from parent to subprocess. -# Currently this is only used to terminate subprocesses. +# Return value queue is intended to be used by users of `multi_process_runner` +# for the process function to return information to the caller of +# `multi_process_runner.run()`. +RETURN_VALUE_QUEUE = 'return_value_queue' + +# Subprocess info queue stores `_SubprocessInfo` for later potential +# termination by the parent. +SUBPROCESS_INFO_QUEUE = 'subprocess_info_queue' + +# Parent-to-sub queue is used for communications from parent to subprocess. +# Currently this is only used to terminate subprocesses. # TODO(rchao): Remove this once subprocess is terminated by SIGKILL. -# `streaming_pipe_w` is to stream stdout and stderr from subprocesses to parent -# process. -# `barrier` is a barrier for the party of all subprocesses. -Resources = collections.namedtuple('Resources', [ - 'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier' -]) +PARENT_TO_SUB_QUEUE = 'parent_to_sub_queue' + +# Streaming queue stores the logged and printed messages from subprocesses. +STREAMING_QUEUE = 'streaming_queue' + +# Pipes to stream stdout and stderr from subprocesses to parent process. +STREAMING_PIPE = 'streaming_pipe' + +# Barrier identifier. +BARRIER = 'barrier' + +_DEFAULT_MAX_SUBPROCESS_COUNT = 20 # Default time out sec is selected so that it's handled before the default # "medium" timeout of the test runs. _DEFAULT_TIMEOUT_SEC = 200 +# Next pipe index to be global so that pipes are not reused across multiple +# MultiProcessRunner usages. +# TODO(rchao): Investigate possibility to remove this variable. +_next_pipe_index = 0 + class MultiProcessRunner(object): """A utility class to start multiple processes to simulate a cluster. @@ -112,7 +123,6 @@ class MultiProcessRunner(object): grpc_fail_fast=None, stream_stdout=True, list_stdout=False, - use_dill_for_args=True, args=None, kwargs=None): """Creates a multi-process runner. @@ -143,9 +153,6 @@ class MultiProcessRunner(object): returned from `MultiProcessRunner.join()`. If True, the list of stdout can be retrieved via `MultiProcessRunnerResult.stdout` attribute. Defaults to False. - use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill - can pickle more objects, but doesn't work with types in - `multiprocessing` library like `Mutex`. args: Positional arguments to be sent to functions run on processes. kwargs: Keyword arguments to be sent to functions run on processes. @@ -159,15 +166,15 @@ class MultiProcessRunner(object): 'one chief. Current `cluster_spec` has {} chiefs.' .format(len(cluster_spec['chief']))) - if not multi_process_lib.initialized(): + assert callable(proc_func) + + if not multi_process_lib.using_context_manager(): raise RuntimeError('`multi_process_runner` is not initialized. ' 'Please call `multi_process_runner.test_main()` ' 'within `if __name__ == \'__main__\':` block ' 'in your python module to properly initialize ' '`multi_process_runner`.') - assert callable(proc_func) - self._proc_func = proc_func self._cluster_spec = cluster_spec self._rpc_layer = rpc_layer @@ -177,86 +184,62 @@ class MultiProcessRunner(object): # TODO(rchao): Revisit list_stdout argument to consider other solution. self._list_stdout = list_stdout self._dependence_on_chief = True - self._use_dill_for_args = use_dill_for_args self._args = args or () self._kwargs = kwargs or {} + self._outstanding_subprocess_count = 0 + # Child processes should have the same v2 and eager behavior. self._v2_enabled = tf2.enabled() self._executing_eagerly = context.executing_eagerly() - self._joined = False - self._processes = {} - self._outstanding_subprocess_count = 0 - self._reading_threads = [] - - self._manager = multiprocessing.Manager() - self._process_status_queue = self._manager.Queue() - self._parent_to_sub_queue = self._manager.Queue() - parties = sum(len(addresses) for addresses in self._cluster_spec.values()) - self._barrier = self._manager.Barrier(parties) - - # We use a queue to collect outputs from worker processes since it's thread - # safe. - self._streaming_queue = self._manager.Queue() - # This flag will be set to True once terminate_all() is called. self._all_forced_terminated = False def _continuously_readline_from_sub(self, pipe_r, task_type, task_id): """Function to continuously read lines from subprocesses.""" - with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader: - for line in reader: - task_string = '[{}-{}]:'.format(task_type, task_id) - formatted_line = '{} {}'.format(task_string.ljust(14), line) - if self._stream_stdout: - # TODO(rchao): Use a lock here to ensure the printed lines are not - # broken. - print(formatted_line, end='', flush=True) - if self._list_stdout: - self._streaming_queue.put(formatted_line) + reader = os.fdopen(pipe_r.fileno(), 'r') + while True: + read_line = reader.readline() + if read_line == 'EOF': + reader.close() + # The thread that runs `_continuously_readline_from_sub` stops here. + # However the threads don't exit until the test exits, so we do not + # attempt to join the threads (which leads to timeout). + # TODO(rchao): Understand why and do thread joining. + break + task_string = '[{}-{}]:'.format(task_type, task_id) + formatted_line = '{} {}'.format(task_string.ljust(14), read_line) + if self._stream_stdout: + self._print_stdout_in_parent(formatted_line, task_type, task_id) + if self._list_stdout: + self._add_stdout_in_queue(formatted_line, task_type, task_id) - def _start_subprocess_and_reading_thread(self, - task_type, - task_id, - cluster_spec=None, - proc_func=None, - args=None, - kwargs=None): + def _print_stdout_in_parent(self, formatted_line, task_type, task_id): + del task_type, task_id + # Flush True so the logging order from subprocesses is respected. + # TODO(rchao): Use a lock here to ensure the printed lines are not broken. + print(formatted_line, end='', flush=True) + + def _add_stdout_in_queue(self, formatted_line, task_type, task_id): + del task_type, task_id + # A queue instead of a simple list is used here due to b/150652733. + _resource(STREAMING_QUEUE).put(formatted_line) + + def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id, + cluster_spec, args, kwargs): """Start a subprocess and a thread the reads lines from the subprocess.""" + global _next_pipe_index + pipe_r, pipe_w = _resource(STREAMING_PIPE)[_next_pipe_index] + _next_pipe_index += 1 - test_env = TestEnvironment( - task_type=task_type, - task_id=task_id, - cluster_spec=cluster_spec or self._cluster_spec, - rpc_layer=self._rpc_layer, - grpc_fail_fast=self._grpc_fail_fast, - v2_enabled=self._v2_enabled, - executing_eagerly=self._executing_eagerly, - ) - pipe_r, pipe_w = multiprocessing.Pipe(duplex=False) - resources = Resources( - process_status_queue=self._process_status_queue, - parent_to_sub_queue=self._parent_to_sub_queue, - streaming_pipe_w=pipe_w, - barrier=self._barrier, - ) - if proc_func is None: - proc_func, args, kwargs = self._proc_func, self._args, self._kwargs - # Always use dill to pickle proc_func so that we support more callable - # types, e.g. lambda. - proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL) - if self._use_dill_for_args: - args = dill.dumps(args, dill.HIGHEST_PROTOCOL) - kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL) - - p = _Process( - test_env=test_env, - target=_ProcFunc(), - args=(resources, test_env, proc_func, args, kwargs, - self._use_dill_for_args)) + p = multi_process_lib.Process( + target=_Subprocess(), + args=(proc_func, task_type, task_id, cluster_spec, self._rpc_layer, + self._grpc_fail_fast, self._v2_enabled, self._executing_eagerly, + pipe_w) + args, + kwargs=kwargs) p.start() - self._processes[(task_type, task_id)] = p self._outstanding_subprocess_count += 1 # For each subprocess, we dedicate a thread continuously reading lines @@ -265,15 +248,18 @@ class MultiProcessRunner(object): target=self._continuously_readline_from_sub, args=(pipe_r, task_type, task_id)) thread.start() - self._reading_threads.append(thread) def start(self): """Starts processes, one for each task in `cluster_spec`.""" - if self._processes: - raise ValueError('MultiProcessRunner already started.') + + global _next_pipe_index + self._starting_pipe_index = _next_pipe_index + for task_type, addresses in self._cluster_spec.items(): for task_id, _ in enumerate(addresses): - self._start_subprocess_and_reading_thread(task_type, task_id) + self._start_subprocess_and_reading_thread(self._proc_func, task_type, + task_id, self._cluster_spec, + self._args, self._kwargs) # TODO(rchao): Remove the need of using SIGALRM if possible. At this time, # without this the tests become very flaky. @@ -323,22 +309,33 @@ class MultiProcessRunner(object): as_task_type: The task type to be run in the main process. as_task_id: The task id to be run in the main process. """ - if self._processes: - raise ValueError('MultiProcessRunner already started.') + global _next_pipe_index + self._starting_pipe_index = _next_pipe_index + for task_type, addresses in self._cluster_spec.items(): for task_id, _ in enumerate(addresses): if not (task_type == as_task_type and task_id == as_task_id): - self._start_subprocess_and_reading_thread(task_type, task_id) + self._start_subprocess_and_reading_thread(self._proc_func, task_type, + task_id, self._cluster_spec, + self._args, self._kwargs) + tf_config_dict = { + 'cluster': self._cluster_spec, + 'task': { + 'type': as_task_type, + 'index': as_task_id, + }, + } + if self._rpc_layer is not None: + tf_config_dict['rpc_layer'] = self._rpc_layer + os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) - _set_tf_config(as_task_type, as_task_id, self._cluster_spec, - self._rpc_layer) self._proc_func(*self._args, **self._kwargs) def start_single_process(self, task_type, task_id, - cluster_spec=None, proc_func=None, + cluster_spec=None, args=None, kwargs=None): """Starts a single process. @@ -355,22 +352,19 @@ class MultiProcessRunner(object): Args: task_type: The task type. task_id: The task id. + proc_func: The process function to be run on the newly started + process. If `None`, the function provided at `__init__` will be used. cluster_spec: The cluster spec to be used on the newly started process. If `None`, the cluster spec provided at `__init__` will be used. - proc_func: The process function to be run on the newly started - process. If specified, specify `args` and `kwargs` as well. If `None`, - the function provided at `__init__` will be used. args: Optional positional arguments to be supplied in `proc_func`. kwargs: Optional keyword arguments to be supplied in `proc_func`. """ - self._start_subprocess_and_reading_thread( - task_type, - task_id, - cluster_spec=cluster_spec, - proc_func=proc_func, - args=args or (), - kwargs=kwargs or {}) + cluster_spec = cluster_spec or self._cluster_spec + proc_func = proc_func or self._proc_func + self._start_subprocess_and_reading_thread(proc_func, task_type, task_id, + cluster_spec, args or (), + kwargs or {}) def _queue_to_list(self, queue_to_convert): """Convert `queue.Queue` to `list`.""" @@ -383,18 +377,6 @@ class MultiProcessRunner(object): break return list_to_return - def _join_or_terminate(self, task_type, task_id, process, timeout): - """Joins a process. If it times out, terminate all procsses.""" - logging.info('joining %s-%d', task_type, task_id) - process.join(timeout) - # If exitcode is None, the process aren't terminated and this is a - # timeout. - if process.exitcode is None: - # Force termination to dump worker processes stack trace. - self.terminate_all(sig=signal.SIGTERM) - raise RuntimeError('%s-%d and possibly more subprocesses timed out.' % - (task_type, task_id)) - def join(self, timeout=_DEFAULT_TIMEOUT_SEC): """Joins all the processes with timeout. @@ -413,97 +395,88 @@ class MultiProcessRunner(object): RuntimeError: if not all processes report status approximatelty within `timeout` seconds, or there's an exception propagated from any subprocess. """ - if self._joined: - raise ValueError("MultiProcessRunner can't be joined twice.") - self._joined = True - chief = self._processes.get(('chief', 0), None) - if self._dependence_on_chief and chief: - self._join_or_terminate('chief', 0, chief, timeout) - # Give other processes a chance to exit on their own. - for p in self._processes.values(): - p.join(timeout=3) - self.terminate_all() - else: - for (task_type, task_id), p in self._processes.items(): - self._join_or_terminate(task_type, task_id, p, timeout) + if not timeout: + timeout = float('inf') + start_time = time.time() + while self._outstanding_subprocess_count > 0: + while True: + try: + process_status = _resource(PROCESS_STATUS_QUEUE).get(timeout=10) + break + except Queue.Empty: + if self._all_forced_terminated: + break + if time.time() - start_time > timeout: + # Send SIGTERM signal to subprocesses to dump their current + # stack trace. + self.terminate_all(sig=signal.SIGTERM) + # If none of those did, report timeout to user. + raise RuntimeError('One or more subprocesses timed out. ' + 'Number of outstanding subprocesses ' + 'is %d.' % self._outstanding_subprocess_count) - for (task_type, task_id), p in self._processes.items(): - logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode) - - process_statuses = self._queue_to_list(self._process_status_queue) - if not self._all_forced_terminated and len( - process_statuses) != self._outstanding_subprocess_count: - raise RuntimeError( - 'missing statuses from %d subproceses.' % - (self._outstanding_subprocess_count - len(process_statuses))) - return_values = [] - for process_status in process_statuses: + if self._all_forced_terminated: + break + self._outstanding_subprocess_count -= 1 assert isinstance(process_status, _ProcessStatusInfo) if not process_status.is_successful: six.reraise(*process_status.exc_info) - if process_status.return_value is not None: - return_values.append(process_status.return_value) - logging.info('Joining log reading threads.') - for thread in self._reading_threads: - thread.join() - logging.info('Joined log reading threads.') + if self._dependence_on_chief and process_status.task_type == 'chief': + self.terminate_all() + break - # Clear the alarm. - signal.alarm(0) + # Giving threads some time to finish the message reading from subprocesses. + time.sleep(5) - stdout = self._queue_to_list(self._streaming_queue) + stdout = self._queue_to_list(_resource(STREAMING_QUEUE)) + return_value = self._queue_to_list(_resource(RETURN_VALUE_QUEUE)) - return MultiProcessRunnerResult(stdout=stdout, return_value=return_values) + # Notifying the threads that are reading lines that we should stop. + for pipe_index in range(self._starting_pipe_index, _next_pipe_index): # pylint: disable=protected-access + _, pipe_w = _resource(STREAMING_PIPE)[pipe_index] + writer = os.fdopen(pipe_w.fileno(), 'w') + # Writing end of file message so the threads that's actively reading lines + # know to stop. + writer.writelines(['EOF']) + writer.close() + + return MultiProcessRunnerResult(stdout=stdout, return_value=return_value) def terminate(self, task_type, task_id): """Terminates the process with `task_type` and `task_id`.""" - p = self._processes.get((task_type, task_id), None) - if p is None: - raise ValueError('{}-{} does not exist'.format(task_type, task_id)) - # TODO(crccw): change to use Process.terminate() as well. - self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id)) - p.join() + _resource(PARENT_TO_SUB_QUEUE).put('terminate {} {}'.format( + task_type, task_id)) def terminate_all(self, sig=None): """Terminates all subprocesses.""" # Use SIGKILL as default. In systems where that's unavailable such as # windows, use SIGTERM. sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM) - for (task_type, task_id), p in self._processes.items(): + subprocess_infos = [] + + while True: try: - os.kill(p.pid, sig) + subprocess_info = _resource(SUBPROCESS_INFO_QUEUE).get(block=False) + subprocess_infos.append(subprocess_info) + except Queue.Empty: + break + + for subprocess_info in subprocess_infos: + logging.info('Parent process is now killing PID: %d', subprocess_info.pid) + try: + os.kill(subprocess_info.pid, sig) except ProcessLookupError: - logging.info('Attempting to kill %s-%d but it does not exist.', - task_type, task_id) + # TODO(rchao): Remove subprocess info from the queue once a subprocess + # is terminated. + logging.info('PID %d does not exist.', subprocess_info.pid) + self._all_forced_terminated = True -class _Process(multi_process_lib.Process): - """A modified `multiprocessing.Process` that can set up environment variables.""" - - # TODO(crccw): consider moving other logics in _ProcFunc to _Process. - - def __init__(self, test_env, **kwargs): - super(_Process, self).__init__(**kwargs) - self._test_env = test_env - self._actual_run = getattr(self, 'run') - self.run = self._run_with_setenv - - def _run_with_setenv(self): - # We need to set environment variables before doing anything because - # setenv() is not thread-safe. - test_env = self._test_env - if test_env.grpc_fail_fast is not None: - os.environ['GRPC_FAIL_FAST'] = str(test_env.grpc_fail_fast) - _set_tf_config(test_env.task_type, test_env.task_id, test_env.cluster_spec, - test_env.rpc_layer) - return self._actual_run() - - -class _ProcFunc(object): - """Represents a callable to run in a subprocess.""" +class _Subprocess(object): + """Represents an internal subprocess used in MultiProcessRunner's context.""" @contextlib.contextmanager def _runtime_mode(self, executing_eagerly): @@ -514,12 +487,21 @@ class _ProcFunc(object): with context.graph_mode(): yield + def _finish_process(self, process_status_info, return_value): + """Adds data to queues before program exits.""" + # Clear the alarm. + signal.alarm(0) + + if return_value is not None: + self._add_return_data(return_value) + _resource(PROCESS_STATUS_QUEUE).put(process_status_info) + def _message_checking_func(self, task_type, task_id): """A function that regularly checks messages from parent process.""" # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess. while True: try: - message = self._resources.parent_to_sub_queue.get(block=False) + message = _resource(PARENT_TO_SUB_QUEUE).get(block=False) # Currently the only possible message is termination. if not message.startswith('terminate'): @@ -530,75 +512,62 @@ class _ProcFunc(object): else: # If the message is not targeting this process, put it back to the # queue. - self._resources.parent_to_sub_queue.put(message) + _resource(PARENT_TO_SUB_QUEUE).put(message) time.sleep(1) except Queue.Empty: time.sleep(0.1) - self._resources.process_status_queue.put( + self._finish_process( _ProcessStatusInfo( - task_type=task_type, - is_successful=True, - exc_info=None, - return_value=None)) + task_type=task_type, is_successful=True, exc_info=None), None) # `os._exit(0)` is used to more reliably terminate a subprocess. os._exit(0) # pylint: disable=protected-access - def _close_streaming(self): - """Close stdout, stderr and streaming pipe. - - We need to explicitly close them since Tensorflow may take a while to exit, - so that the reading threads in the main process can exit more quickly. - """ - sys.stdout.flush() - sys.stderr.flush() - sys.stdout.close() - sys.stderr.close() - self._resources.streaming_pipe_w.close() - - def __call__(self, resources, test_env, proc_func, args, kwargs, - use_dill_for_args): + def __call__(self, proc_func, task_type, task_id, per_process_cluster_spec, + rpc_layer, grpc_fail_fast, v2_enabled, executing_eagerly, pipe_w, + *arg, **kwargs): """The wrapper function that actually gets run in child process(es).""" - global _barrier - - self._resources = resources - _barrier = self._resources.barrier - proc_func = dill.loads(proc_func) - if use_dill_for_args: - args = dill.loads(args) - kwargs = dill.loads(kwargs) - if faulthandler is not None: faulthandler.enable() faulthandler.register(signal.SIGTERM, chain=True) - # All logging should go to stderr to be streamed to the main process. - logging.set_stderrthreshold(logging.DEBUG) - - # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so - # print() and logging.*() write directly to `streaming_pipe_w`. - # Unfortunately since we cannot prepend task_type and task_id information to - # the streamed logs we will need a thread per subprocess to distinguish - # where the piece of message is from. - os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno()) - os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno()) - pid = os.getpid() logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid, - test_env.task_type, test_env.task_id) + task_type, task_id) + _resource(SUBPROCESS_INFO_QUEUE).put(_SubprocessInfo(pid=pid)) + # Assign sys.stdout and sys.stderr as duplicates of `pipe_w` so print() and + # logging.*() write directly to `pipe_w`. Unfortunately since we cannot + # prepend task_type and task_id information to the streamed logs we will + # need a thread per subprocess to distinguish where the piece of message is + # from. + os.dup2(pipe_w.fileno(), sys.stdout.fileno()) + os.dup2(pipe_w.fileno(), sys.stderr.fileno()) # The thread will be dedicated to checking messages from the parent process. threading.Thread( # pylint: disable=unexpected-keyword-arg target=self._message_checking_func, - args=(test_env.task_type, test_env.task_id), + args=(task_type, task_id), daemon=True).start() - if test_env.v2_enabled: + if grpc_fail_fast is not None: + os.environ['GRPC_FAIL_FAST'] = str(grpc_fail_fast) + tf_config_dict = { + 'cluster': per_process_cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id, + }, + } + if rpc_layer is not None: + tf_config_dict['rpc_layer'] = rpc_layer + os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) + + if v2_enabled: v2_compat.enable_v2_behavior() try: - with self._runtime_mode(test_env.executing_eagerly): - return_value = proc_func(*args, **kwargs) + with self._runtime_mode(executing_eagerly): + return_value = proc_func(*arg, **kwargs) is_successful = True exc_info = None @@ -618,27 +587,35 @@ class _ProcFunc(object): raise finally: - info = _ProcessStatusInfo( - task_type=test_env.task_type, - is_successful=is_successful, - exc_info=exc_info, - return_value=return_value) - self._resources.process_status_queue.put(info) - self._close_streaming() + self._finish_process( + _ProcessStatusInfo( + task_type=task_type, + is_successful=is_successful, + exc_info=exc_info), + return_value) + + def _add_return_data(self, data): + """Adds return data that will be returned by `join`. + + The function provides a way for child processes to communicate with the + parent process. Data passed to `_add_return_data` will be available in a + Python Queue.Queue that is eventually returned by `join`. + + Args: + data: data to be made available in the queue returned by `join`. + """ + # TODO(rchao): Incorporate the task type and id information in a data + # wrapper that becomes what is stored in the queue so we can tell where + # the data is from. + _resource(RETURN_VALUE_QUEUE).put(data) -def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None): - """Set TF_CONFIG environment variable.""" - tf_config_dict = { - 'cluster': cluster_spec, - 'task': { - 'type': task_type, - 'index': task_id, - }, - } - if rpc_layer is not None: - tf_config_dict['rpc_layer'] = rpc_layer - os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) +def barrier(): + return multi_process_lib.get_user_data()[BARRIER] + + +def _resource(resource_name): + return multi_process_lib.get_user_data()[resource_name] def run(proc_func, @@ -674,19 +651,16 @@ def run(proc_func, return runner.join(timeout) -# This is set by MultiProcessRunner in worker processes. -_barrier = None +def test_main(max_subprocess_count=_DEFAULT_MAX_SUBPROCESS_COUNT, + barrier_parties=0): + """Main function to be called within `__main__` of a test file. - -def barrier(): - if _barrier is None: - raise ValueError( - 'barrier is not defined. It is likely because you are calling barrier()' - 'in the main process. barrier() can only be called in the subprocesses.' - ) - return _barrier - - -def test_main(): - """Main function to be called within `__main__` of a test file.""" - multi_process_lib.test_main() + Args: + max_subprocess_count: Maximum number of subprocesses that will be used. User + of multi_process_runner needs to determine a number at calling this + method, and the subprocesses involved later should not exceed this number. + barrier_parties: Number of parties the barrier will be used toward. User of + multi_process_runner needs to determine a number at calling this method. + """ + with multi_process_lib.context_manager(max_subprocess_count, barrier_parties): + test.main() diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index cf68ffd50d7..1413777d0bc 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -23,13 +23,15 @@ import os import threading import time from absl import logging +from six.moves import queue as Queue from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.eager import test -def proc_func_that_adds_task_type_in_return_data(): +def proc_func_that_adds_task_type_in_return_data(test_obj, val): + test_obj.assertEqual(val, 3) return multi_worker_test_base.get_task_type() @@ -49,10 +51,6 @@ def proc_func_that_return_args_and_kwargs(*args, **kwargs): return list(args) + list(kwargs.items()) -def proc_func_with_barrier(): - return multi_process_runner.barrier() - - class MultiProcessRunnerTest(test.TestCase): def _worker_idx(self): @@ -63,7 +61,8 @@ class MultiProcessRunnerTest(test.TestCase): mpr_result = multi_process_runner.run( proc_func_that_adds_task_type_in_return_data, multi_worker_test_base.create_cluster_spec( - num_workers=2, num_ps=3, has_eval=1)) + num_workers=2, num_ps=3, has_eval=1), + args=(self, 3)) job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1} for data in mpr_result.return_value: @@ -125,22 +124,36 @@ class MultiProcessRunnerTest(test.TestCase): def test_process_that_exits(self): - def func_to_exit_in_5_sec(): - logging.error('foo') - time.sleep(10) - logging.error('bar') + def func_to_exit_in_15_sec(): + time.sleep(5) + print('foo', flush=True) + time.sleep(20) + print('bar', flush=True) mpr = multi_process_runner.MultiProcessRunner( - func_to_exit_in_5_sec, + func_to_exit_in_15_sec, multi_worker_test_base.create_cluster_spec(num_workers=1), list_stdout=True, - max_run_time=5) + max_run_time=15) mpr.start() stdout = mpr.join().stdout self.assertLen([msg for msg in stdout if 'foo' in msg], 1) self.assertLen([msg for msg in stdout if 'bar' in msg], 0) + def test_signal_doesnt_fire_after_process_exits(self): + mpr = multi_process_runner.MultiProcessRunner( + proc_func_that_does_nothing, + multi_worker_test_base.create_cluster_spec(num_workers=1), + max_run_time=10) + mpr.start() + mpr.join() + with self.assertRaisesRegexp(Queue.Empty, ''): + # If the signal was fired, another message would be added to internal + # queue, so verifying it's empty. + multi_process_runner._resource( + multi_process_runner.PROCESS_STATUS_QUEUE).get(block=False) + def test_termination(self): def proc_func(): @@ -179,7 +192,7 @@ class MultiProcessRunnerTest(test.TestCase): multi_worker_test_base.create_cluster_spec(num_workers=2), list_stdout=True) mpr.start() - time.sleep(3) + time.sleep(5) mpr.terminate('worker', 0) mpr.start_single_process('worker', 0) std_stream_results = mpr.join().stdout @@ -260,14 +273,11 @@ class MultiProcessRunnerTest(test.TestCase): has_chief=True, num_workers=1), list_stdout=True) - def eval_func(): - time.sleep(1) + def follow_ups(): mpr.start_single_process(task_type='evaluator', task_id=0) - eval_thread = threading.Thread(target=eval_func) - eval_thread.start() + threading.Thread(target=follow_ups).start() mpr.start_in_process_as(as_task_type='chief', as_task_id=0) - eval_thread.join() list_to_assert = mpr.join().stdout for job in ['worker', 'evaluator']: for iteration in range(5): @@ -275,17 +285,5 @@ class MultiProcessRunnerTest(test.TestCase): any('{}-0, i: {}'.format(job, iteration) in line for line in list_to_assert)) - def test_barrier(self): - multi_process_runner.run( - proc_func_with_barrier, - cluster_spec=multi_worker_test_base.create_cluster_spec( - has_chief=True, num_workers=1), - ) - - def test_barrier_called_in_main_process(self): - with self.assertRaises(ValueError): - multi_process_runner.barrier() - - if __name__ == '__main__': multi_process_runner.test_main() diff --git a/tensorflow/python/distribute/multi_worker_continuous_run_test.py b/tensorflow/python/distribute/multi_worker_continuous_run_test.py index 14e0564874b..437255c1015 100644 --- a/tensorflow/python/distribute/multi_worker_continuous_run_test.py +++ b/tensorflow/python/distribute/multi_worker_continuous_run_test.py @@ -127,4 +127,4 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): if __name__ == '__main__': - multi_process_runner.test_main() + multi_process_runner.test_main(barrier_parties=NUM_WORKERS) diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index ddf274f299f..f989d93e82e 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -364,7 +364,7 @@ py_test( name = "multi_worker_callback_tf2_test", srcs = ["multi_worker_callback_tf2_test.py"], python_version = "PY3", - shard_count = 5, + shard_count = 10, deps = [ "//tensorflow/python/distribute:collective_all_reduce_strategy", "//tensorflow/python/distribute:combinations", diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py index 660a8e8cb6c..8daa46f6ea3 100644 --- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py @@ -208,7 +208,6 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): callbacks.BackupAndRestore(backup_dir=bar_dir), AssertCallback() ]) - multi_process_runner.barrier() test_obj.assertFalse(file_io.file_exists(backup_filepath)) test_obj.assertTrue(file_io.file_exists(saving_filepath)) @@ -344,4 +343,4 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): if __name__ == '__main__': - multi_process_runner.test_main() + multi_process_runner.test_main(barrier_parties=2)