diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py
index e5be4fa4a14..f3ec0d44486 100644
--- a/tensorflow/python/distribute/multi_process_runner.py
+++ b/tensorflow/python/distribute/multi_process_runner.py
@@ -67,7 +67,8 @@ except ImportError:
 # exception stack trace info is stored in exc_info to pass on to parent process
 # to be re-raised.
 _ProcessStatusInfo = collections.namedtuple(
-    '_ProcessStatusInfo', ['is_successful', 'exc_info', 'return_value'])
+    '_ProcessStatusInfo',
+    ['task_type', 'task_id', 'is_successful', 'exc_info', 'return_value'])
 
 # Information returned from a successful MultiProcessRunner run.
 MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult',
@@ -124,6 +125,8 @@ class MultiProcessRunner(object):
                list_stdout=False,
                use_dill_for_args=True,
                daemon=False,
+               dependence_on_chief=True,
+               auto_restart=False,
                args=None,
                kwargs=None):
     """Creates a multi-process runner.
@@ -161,6 +164,11 @@
         can pickle more objects, but doesn't work with types in
         `multiprocessing` library like `Mutex`.
       daemon: Whether to start processes as daemons.
+      dependence_on_chief: Whether to terminate the cluster if the chief exits.
+        If auto_restart is True, it only terminates the cluster if the chief
+        exits with a zero exit code.
+      auto_restart: Whether to automatically restart processes that exit with
+        a non-zero exit code.
       args: Positional arguments to be sent to functions run on processes.
       kwargs: Keyword arguments to be sent to functions run on processes.
 
@@ -190,9 +198,10 @@ class MultiProcessRunner(object):
     self._stream_stdout = stream_stdout
     # TODO(rchao): Revisit list_stdout argument to consider other solution.
     self._list_stdout = list_stdout
-    self._dependence_on_chief = True
+    self._dependence_on_chief = dependence_on_chief
     self._use_dill_for_args = use_dill_for_args
     self._daemon = daemon
+    self._auto_restart = auto_restart
     self._args = args or ()
     self._kwargs = kwargs or {}
 
@@ -201,8 +210,15 @@
     self._executing_eagerly = context.executing_eagerly()
 
     self._joined = False
+    self._process_lock = threading.Lock()
+    # Guarded by self._process_lock.
     self._processes = {}
-    self._outstanding_subprocess_count = 0
+    # Record which processes are terminated. Due to a bug in Python<3.7,
+    # terminated processes return a 255 exit code, which should cause an
+    # exception in join().
+    # https://bugs.python.org/issue30589
+    # Guarded by self._process_lock.
+    self._terminated = set()
     self._reading_threads = []
 
     self._manager = manager()
@@ -215,8 +231,7 @@
     # safe.
     self._streaming_queue = self._manager.Queue()
 
-    # This flag will be set to True once terminate_all() is called.
-    self._all_forced_terminated = False
+    self._watchdog_thread = None
 
   def set_args(self, args=None, kwargs=None):
     self._args = args or self._args
@@ -281,7 +296,7 @@
         daemon=self._daemon)
     p.start()
     self._processes[(task_type, task_id)] = p
-    self._outstanding_subprocess_count += 1
+    self._terminated.discard((task_type, task_id))
 
     # For each subprocess, we dedicate a thread continuously reading lines
     # from them.
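The two new constructor flags are designed to be used together. A minimal usage sketch (assuming the `create_cluster_spec` helper from `multi_worker_test_base`, the same one the tests below use):

    from tensorflow.python.distribute import multi_process_runner
    from tensorflow.python.distribute import multi_worker_test_base

    def proc_func():
      pass  # per-task work, e.g. a distributed training step

    mpr = multi_process_runner.MultiProcessRunner(
        proc_func,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=2),
        auto_restart=True,         # restart tasks that exit with non-zero code
        dependence_on_chief=True)  # tear the cluster down once the chief exits
    mpr.start()
    result = mpr.join()  # a MultiProcessRunnerResult(stdout=..., return_value=...)
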
@@ -291,17 +306,26 @@ class MultiProcessRunner(object):
       thread.start()
       self._reading_threads.append(thread)
 
+    if self._watchdog_thread is None or not self._watchdog_thread.is_alive():
+      self._watchdog_thread = threading.Thread(target=self._process_watchdog)
+      self._watchdog_thread.start()
+
   def start(self):
     """Starts processes, one for each task in `cluster_spec`.
 
     Note that this is best effort by the applicable multiprocessing library,
     and it may take up to seconds for a subprocess to be successfully started.
     """
-    if self._processes:
-      raise ValueError('MultiProcessRunner already started.')
-    for task_type, addresses in self._cluster_spec.items():
-      for task_id, _ in enumerate(addresses):
-        self._start_subprocess_and_reading_thread(task_type, task_id)
+    with self._process_lock:
+      if self._processes:
+        raise ValueError('MultiProcessRunner already started.')
+      if self._joined:
+        raise ValueError('cannot start new processes after '
+                         'MultiProcessRunner.join() is called')
+
+      for task_type, addresses in self._cluster_spec.items():
+        for task_id, _ in enumerate(addresses):
+          self._start_subprocess_and_reading_thread(task_type, task_id)
 
     # TODO(rchao): Remove the need of using SIGALRM if possible. At this time,
     # without this the tests become very flaky.
@@ -353,10 +377,14 @@ class MultiProcessRunner(object):
     """
     if self._processes:
       raise ValueError('MultiProcessRunner already started.')
-    for task_type, addresses in self._cluster_spec.items():
-      for task_id, _ in enumerate(addresses):
-        if not (task_type == as_task_type and task_id == as_task_id):
-          self._start_subprocess_and_reading_thread(task_type, task_id)
+    with self._process_lock:
+      if self._joined:
+        raise ValueError('cannot start new processes after '
+                         'MultiProcessRunner.join() is called')
+      for task_type, addresses in self._cluster_spec.items():
+        for task_id, _ in enumerate(addresses):
+          if not (task_type == as_task_type and task_id == as_task_id):
+            self._start_subprocess_and_reading_thread(task_type, task_id)
 
     _set_tf_config(as_task_type, as_task_id, self._cluster_spec,
                    self._rpc_layer)
@@ -392,13 +420,17 @@
       args: Optional positional arguments to be supplied in `proc_func`.
       kwargs: Optional keyword arguments to be supplied in `proc_func`.
     """
-    self._start_subprocess_and_reading_thread(
-        task_type,
-        task_id,
-        cluster_spec=cluster_spec,
-        proc_func=proc_func,
-        args=args or (),
-        kwargs=kwargs or {})
+    with self._process_lock:
+      if self._joined:
+        raise ValueError('cannot start new processes after '
+                         'MultiProcessRunner.join() is called')
+      self._start_subprocess_and_reading_thread(
+          task_type,
+          task_id,
+          cluster_spec=cluster_spec,
+          proc_func=proc_func,
+          args=args or (),
+          kwargs=kwargs or {})
 
   def _queue_to_list(self, queue_to_convert):
     """Convert `queue.Queue` to `list`."""
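All three start paths now take `self._process_lock` and refuse to run once `join()` has been called, so a runner cannot be reused across a join. A short lifecycle sketch (names as in the sketch above):

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(proc_func, cluster_spec)
    mpr.start()
    mpr.join()
    try:
      mpr.start_single_process('worker', 0)  # runner has already been joined
    except ValueError:
      # A fresh runner must be constructed for another run.
      mpr = multi_process_runner.MultiProcessRunner(proc_func, cluster_spec)
      mpr.start()
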
@@ -411,9 +443,17 @@
         break
     return list_to_return
 
+  def _get_process_statuses(self):
+    # One worker may have multiple statuses. We only keep the last one.
+    statuses = {}
+    for status in self._queue_to_list(self._process_status_queue):
+      statuses[(status.task_type, status.task_id)] = status
+    return statuses
+
   def get_process_id(self, task_type, task_id):
     """Returns the subprocess id given the task type and task id."""
-    p = self._processes.get((task_type, task_id), None)
+    with self._process_lock:
+      p = self._processes.get((task_type, task_id), None)
     return p.pid if p else None
 
   def get_process_exit_code(self, task_type, task_id):
@@ -430,22 +470,54 @@
       KeyError: If the corresponding subprocess is not found with
         `task_type` and `task_id`.
     """
-    p = self._processes[(task_type, task_id)]
+    with self._process_lock:
+      p = self._processes[(task_type, task_id)]
     return p.exitcode if p else None
 
-  def _join_or_terminate(self, task_type, task_id, process, timeout):
-    """Joins a process. If it times out, terminate all procsses."""
-    logging.info('joining %s-%d', task_type, task_id)
-    process.join(timeout)
-    # If exitcode is None, the process aren't terminated and this is a
-    # timeout.
-    if process.exitcode is None:
-      # Force termination to dump worker processes stack trace.
-      self.terminate_all(sig=signal.SIGTERM)
-      process_statuses = self._queue_to_list(self._process_status_queue)
-      raise SubprocessTimeoutError(
-          '%s-%d and possibly more subprocesses timed out.' %
-          (task_type, task_id), self._get_mpr_result(process_statuses))
+  def _process_watchdog(self):
+    """Simulates a cluster management system.
+
+    - If auto_restart is True, it restarts processes that exit with a non-zero
+      exit code. Note that when join() times out it overrides auto_restart to
+      False.
+    - If dependence_on_chief is True, it terminates all processes once the chief
+      exits. If auto_restart is also True, it only terminates all processes if
+      the chief exits with a zero exit code, otherwise it restarts the chief.
+
+    This runs in self._watchdog_thread.
+    """
+    while True:
+      time.sleep(1)
+      with self._process_lock:
+        chief = self._processes.get(('chief', 0), None)
+        # Terminate the cluster when _dependence_on_chief is True if either:
+        # - chief has exited with zero exit code.
+        # - chief has exited with non-zero exit code and self._auto_restart is
+        #   False.
+        if chief and self._dependence_on_chief and chief.exitcode is not None:
+          if chief.exitcode == 0 or (not self._auto_restart):
+            for p in self._processes.values():
+              # Give other processes a chance to exit on their own.
+              p.join(timeout=3)
+            self._terminate_all()
+            for p in self._processes.values():
+              p.join()
+            return
+
+        # Auto restart failed processes if self._auto_restart is True.
+        if self._auto_restart:
+          has_failure = False
+          for (task_type, task_id), p in self._processes.items():
+            if p.exitcode is not None and p.exitcode != 0:
+              has_failure = True
+              logging.info('Restarting failed %s-%d', task_type, task_id)
+              self._start_subprocess_and_reading_thread(task_type, task_id)
+          if has_failure:
+            continue
+
+        # Exit the thread if all processes have exited at this point.
+        if all(p.exitcode is not None for p in self._processes.values()):
+          return
 
   def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
     """Joins all the processes with timeout.
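Stripped of locking and logging, the watchdog above reduces to the following supervision pattern (a standalone sketch over plain `multiprocessing.Process` objects; `restart` and `kill_all` are placeholder callbacks, not runner APIs):

    import time

    def watch(processes, auto_restart, dependence_on_chief, restart, kill_all):
      # processes: {(task_type, task_id): multiprocessing.Process}
      while True:
        time.sleep(1)  # poll exit codes once per second
        chief = processes.get(('chief', 0))
        if dependence_on_chief and chief and chief.exitcode is not None:
          if chief.exitcode == 0 or not auto_restart:
            kill_all()  # bring the rest of the cluster down
            return
        if auto_restart:
          failed = [k for k, p in processes.items()
                    if p.exitcode not in (None, 0)]
          for k in failed:
            restart(k)  # replace each failed task with a fresh process
          if failed:
            continue
        if all(p.exitcode is not None for p in processes.values()):
          return  # every task exited; nothing left to supervise
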
""" - if self._joined: - raise ValueError("MultiProcessRunner can't be joined twice.") - self._joined = True + with self._process_lock: + if self._joined: + raise ValueError("MultiProcessRunner can't be joined twice.") + self._joined = True - chief = self._processes.get(('chief', 0), None) - if self._dependence_on_chief and chief: - self._join_or_terminate('chief', 0, chief, timeout) - # Give other processes a chance to exit on their own. - for p in self._processes.values(): - p.join(timeout=3) - self.terminate_all() - else: - for (task_type, task_id), p in self._processes.items(): - self._join_or_terminate(task_type, task_id, p, timeout) + self._watchdog_thread.join(timeout) + if self._watchdog_thread.is_alive(): + # Timeout. Force termination to dump worker processes stack trace. + with self._process_lock: + self._auto_restart = False + self.terminate_all(sig=signal.SIGTERM) + self._watchdog_thread.join() + process_statuses = self._get_process_statuses() + raise SubprocessTimeoutError('one or more subprocesses timed out.', + self._get_mpr_result(process_statuses)) for (task_type, task_id), p in self._processes.items(): logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode) - process_statuses = self._queue_to_list(self._process_status_queue) - for process_status in process_statuses: + process_statuses = self._get_process_statuses() + for process_status in process_statuses.values(): assert isinstance(process_status, _ProcessStatusInfo) if not process_status.is_successful: six.reraise(*process_status.exc_info) # Checking all the processes that are expected to exit properly. for (task_type, task_id), p in self._processes.items(): - if self._dependence_on_chief and chief and task_type != 'chief': - # If _dependence_on_chief, other processes may have been - # forced-terminated, which is expected. - continue - # Successfully exiting process has exit code 0. - if p.exitcode is None or p.exitcode > 0: + # Successfully exiting process has exit code 0. We ignore processes that + # are terminated. + assert p.exitcode is not None + if (p.exitcode > 0 and (task_type, task_id) not in self._terminated): raise UnexpectedSubprocessExitError( - 'Subprocess %s-%d exited with exit code %d. See logs for details.' % - (task_type, task_id, p.exitcode), + 'Subprocess %s-%d exited with exit code %s. See logs for details.' + % (task_type, task_id, p.exitcode), self._get_mpr_result(process_statuses)) logging.info('Joining log reading threads.') @@ -539,34 +610,60 @@ class MultiProcessRunner(object): def _get_mpr_result(self, process_statuses): stdout = self._queue_to_list(self._streaming_queue) return_values = [] - for process_status in process_statuses: + for process_status in process_statuses.values(): if process_status.return_value is not None: return_values.append(process_status.return_value) return MultiProcessRunnerResult(stdout=stdout, return_value=return_values) def terminate(self, task_type, task_id): - """Terminates the process with `task_type` and `task_id`.""" - p = self._processes.get((task_type, task_id), None) - if p is None: - raise ValueError('{}-{} does not exist'.format(task_type, task_id)) - # TODO(crccw): change to use Process.terminate() as well. - self._parent_to_sub_queue.put('terminate {} {}'.format(task_type, task_id)) - p.join() + """Terminates the process with `task_type` and `task_id`. + + If auto_retart=True, the terminated task will be restarted unless the chief + has already exited with zero exit code. + + Args: + task_type: the task type. + task_id: the task id. 
+ + """ + with self._process_lock: + p = self._processes.get((task_type, task_id), None) + if p is None: + raise ValueError('{}-{} does not exist'.format(task_type, task_id)) + self._terminated.add((task_type, task_id)) + # TODO(crccw): change to use Process.terminate() as well. + self._parent_to_sub_queue.put('terminate {} {}'.format( + task_type, task_id)) + p.join() + + def _terminate_all(self, sig=None): + """Terminates all subprocesses. + + The caller is required to hold self._process_lock. + + Args: + sig: the signal used to terminate the process. The default is SIGKILL. + """ - def terminate_all(self, sig=None): - """Terminates all subprocesses.""" # Use SIGKILL as default. In systems where that's unavailable such as # windows, use SIGTERM. sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM) for (task_type, task_id), p in self._processes.items(): + if p.exitcode is not None: + continue try: os.kill(p.pid, sig) + self._terminated.add((task_type, task_id)) logging.info('%s-%d terminated with signal %r.', task_type, task_id, sig) except ProcessLookupError: logging.info('Attempting to kill %s-%d but it does not exist.', task_type, task_id) - self._all_forced_terminated = True + + def terminate_all(self, sig=None): + """Terminates all subprocesses.""" + with self._process_lock: + self._terminate_all(sig) class _Process(multi_process_lib.Process): @@ -625,11 +722,13 @@ class _ProcFunc(object): time.sleep(0.1) self._resources.process_status_queue.put( _ProcessStatusInfo( + task_type=task_type, + task_id=task_id, is_successful=True, exc_info=None, return_value=None)) - # `os._exit(0)` is used to more reliably terminate a subprocess. - os._exit(0) # pylint: disable=protected-access + # `os._exit(1)` is used to more reliably terminate a subprocess. + os._exit(1) # pylint: disable=protected-access def _close_streaming(self): """Close stdout, stderr and streaming pipe. @@ -685,7 +784,8 @@ class _ProcFunc(object): v2_compat.enable_v2_behavior() with self._runtime_mode(test_env.executing_eagerly): - info = _run_contained(proc_func, args, kwargs) + info = _run_contained(test_env.task_type, test_env.task_id, proc_func, + args, kwargs) self._resources.process_status_queue.put(info) # Re-raise the exception in addition to reporting it to the parent @@ -774,7 +874,7 @@ class MultiProcessPoolRunner(object): task_type, task_id, proc_func=_pool_runner_worker, - args=(initializer, conn2)) + args=(task_type, task_id, initializer, conn2)) def run(self, proc_func, args=None, kwargs=None): """Runs `proc_func` with `args` and `kwargs` on all jobs. @@ -819,7 +919,7 @@ class MultiProcessPoolRunner(object): return return_values -def _pool_runner_worker(initializer, conn): +def _pool_runner_worker(task_type, task_id, initializer, conn): """Function that runs on the workers in a pool. It listens for callables to run and returns the result until `conn` is closed. @@ -827,8 +927,10 @@ def _pool_runner_worker(initializer, conn): `conn`. Args: - initializer: A callable to execute during startup. - conn: A multiprocessing.Connection object to listen for tasks and send + task_type: the task type. + task_id: the task index. + initializer: a callable to execute during startup. + conn: a multiprocessing.Connection object to listen for tasks and send results. 
""" if initializer: @@ -840,22 +942,24 @@ def _pool_runner_worker(initializer, conn): except EOFError: break proc_func = dill.loads(proc_func) - info = _run_contained(proc_func, args, kwargs) + info = _run_contained(task_type, task_id, proc_func, args, kwargs) sys.stdout.flush() sys.stderr.flush() conn.send(info) -def _run_contained(proc_func, args, kwargs): +def _run_contained(task_type, task_id, proc_func, args, kwargs): """Runs `proc_func` with `args` and `kwargs`. The function returns _ProcessStatusInfo which captures the return value and the exception. Args: - proc_func: The function to be run. - args: Optional positional arguments to be supplied in `proc_func`. - kwargs: Optional keyword arguments to be supplied in `proc_func`. + task_type: the task type. + task_id: the task index. + proc_func: the function to be run. + args: optional positional arguments to be supplied in `proc_func`. + kwargs: optional keyword arguments to be supplied in `proc_func`. Returns: a _ProcessStatusInfo. @@ -868,6 +972,8 @@ def _run_contained(proc_func, args, kwargs): return_value = proc_func(*args, **kwargs) is_successful = True return _ProcessStatusInfo( + task_type=task_type, + task_id=task_id, is_successful=is_successful, exc_info=exc_info, return_value=return_value) @@ -877,6 +983,8 @@ def _run_contained(proc_func, args, kwargs): except Exception: # pylint: disable=broad-except exc_info = sys.exc_info() return _ProcessStatusInfo( + task_type=task_type, + task_id=task_id, is_successful=is_successful, exc_info=exc_info, return_value=return_value) diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index c6266a5be26..0aa214d3ca4 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -156,11 +156,8 @@ class MultiProcessRunnerTest(test.TestCase): mpr.start() time.sleep(5) mpr.terminate('worker', 0) - with self.assertRaises( - multi_process_runner.UnexpectedSubprocessExitError) as cm: - mpr.join() - std_stream_results = cm.exception.mpr_result.stdout + std_stream_results = mpr.join().stdout # Worker 0 is terminated in the middle, so it should not have iteration 9 # printed. @@ -388,6 +385,99 @@ class MultiProcessRunnerTest(test.TestCase): 'Subprocess worker-0 exited with exit code 10'): mpr.join() + def test_auto_restart(self): + + def proc_func(counter): + counter.value += 1 + if counter.value == 1: + raise ValueError + + manager = multi_process_runner.manager() + counter = manager.Value(int, 0) + mpr = multi_process_runner.MultiProcessRunner( + proc_func, + multi_worker_test_base.create_cluster_spec(num_workers=1), + args=(counter,), + auto_restart=True) + mpr.start() + mpr.join() + self.assertEqual(counter.value, 2) + + def test_auto_restart_and_timeout(self): + + def proc_func(): + time.sleep(1) + raise ValueError + + mpr = multi_process_runner.MultiProcessRunner( + proc_func, + multi_worker_test_base.create_cluster_spec(num_workers=1), + auto_restart=True) + mpr.start() + with self.assertRaises(multi_process_runner.SubprocessTimeoutError): + mpr.join(timeout=10) + + def test_auto_restart_and_chief(self): + # If the chief has exited with zero exit code, auto restart should stop + # restarting other tasks even if they fail. 
@@ -388,6 +385,99 @@ class MultiProcessRunnerTest(test.TestCase):
                                 'Subprocess worker-0 exited with exit code 10'):
       mpr.join()
 
+  def test_auto_restart(self):
+
+    def proc_func(counter):
+      counter.value += 1
+      if counter.value == 1:
+        raise ValueError
+
+    manager = multi_process_runner.manager()
+    counter = manager.Value(int, 0)
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        args=(counter,),
+        auto_restart=True)
+    mpr.start()
+    mpr.join()
+    self.assertEqual(counter.value, 2)
+
+  def test_auto_restart_and_timeout(self):
+
+    def proc_func():
+      time.sleep(1)
+      raise ValueError
+
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(num_workers=1),
+        auto_restart=True)
+    mpr.start()
+    with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
+      mpr.join(timeout=10)
+
+  def test_auto_restart_and_chief(self):
+    # If the chief has exited with zero exit code, auto restart should stop
+    # restarting other tasks even if they fail.
+
+    def proc_func():
+      time.sleep(1)
+      if multi_worker_test_base.get_task_type() != 'chief':
+        raise ValueError
+
+    manager = multi_process_runner.manager()
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(
+            has_chief=True, num_workers=1),
+        auto_restart=True)
+    mpr.start()
+    with self.assertRaises(ValueError):
+      mpr.join(timeout=10)
+
+  def test_auto_restart_failure_immediate_after_restart(self):
+    # Test the case when worker-0 fails immediately after worker-1 restarts.
+
+    def proc_func():
+      time.sleep(5)
+
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(
+            has_chief=False, num_workers=2),
+        auto_restart=True)
+    mpr.start()
+    pid = mpr.get_process_id('worker', 1)
+    mpr.terminate('worker', 1)
+    while mpr.get_process_id('worker', 1) == pid:
+      time.sleep(0.1)
+    mpr.terminate('worker', 0)
+    mpr.join(timeout=20)
+
+  def test_auto_restart_terminate(self):
+    # Tasks terminated by the user should also be restarted.
+
+    def proc_func(counter):
+      counter.value += 1
+      if counter.value == 1:
+        time.sleep(100)
+
+    manager = multi_process_runner.manager()
+    counter = manager.Value(int, 0)
+
+    mpr = multi_process_runner.MultiProcessRunner(
+        proc_func,
+        multi_worker_test_base.create_cluster_spec(
+            has_chief=False, num_workers=1),
+        args=(counter,),
+        auto_restart=True)
+    mpr.start()
+    time.sleep(3)
+    mpr.terminate('worker', 0)
+    mpr.join(timeout=20)
+    self.assertEqual(counter.value, 2)
+
 
 class MultiProcessRunnerTest(test.TestCase):
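One practical consequence the new tests rely on: under `auto_restart` a `proc_func` may execute more than once, so any cross-attempt state has to live outside the subprocess, e.g. in a `multi_process_runner.manager()` value as the tests do. A restart-safe sketch of that pattern:

    # Shared state goes through manager(), since a task may run several times
    # under auto_restart; each restart gets a fresh process but the same value.
    manager = multi_process_runner.manager()
    attempts = manager.Value(int, 0)

    def proc_func(attempts):
      attempts.value += 1
      if attempts.value < 3:
        raise ValueError('simulated failure; the watchdog restarts this task')
      return 'ok after %d attempts' % attempts.value

Note also that with no 'chief' task in the cluster spec, the chief-dependence branch of the watchdog never triggers, so `join()` simply waits until every task has exited, restarting failed ones along the way.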