From a9064fdddc9df688db1e9f3c8be9c25fc0fd8c14 Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Fri, 1 Mar 2019 19:21:20 -0800 Subject: [PATCH] Add a method which makes a best effort attempt to remove any remnants of the pool created when multiprocessing=True is set. PiperOrigin-RevId: 236417625 --- tensorflow/python/keras/BUILD | 1 - .../keras/engine/training_generator_test.py | 42 ++++- tensorflow/python/keras/utils/data_utils.py | 159 +++++++++++++++++- .../v1/tensorflow.keras.experimental.pbtxt | 4 + .../v2/tensorflow.keras.experimental.pbtxt | 4 + 5 files changed, 199 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 37d1f5ccf6e..57bd3c54495 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -1171,7 +1171,6 @@ tf_py_test( shard_count = 6, tags = [ "no_oss", - "notap", #TODO(b/123544294): Re-enable this test. "notsan", ], ) diff --git a/tensorflow/python/keras/engine/training_generator_test.py b/tensorflow/python/keras/engine/training_generator_test.py index 6b754c18b3d..393de931cf0 100644 --- a/tensorflow/python/keras/engine/training_generator_test.py +++ b/tensorflow/python/keras/engine/training_generator_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import time import unittest from absl.testing import parameterized @@ -34,6 +35,7 @@ from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.optimizer_v2 import rmsprop +from tensorflow.python.keras.utils import data_utils from tensorflow.python.platform import test from tensorflow.python.util import nest @@ -61,7 +63,40 @@ def custom_generator(mode=2): yield x, y, w -class TestGeneratorMethods(keras_parameterized.TestCase): +class ForkRobustTestCase(keras_parameterized.TestCase): + _sleep_at_end = False + + def 
setUp(self): + # When setting up a test simply make a best effort to start from a clean + # state. + self._starting_remnants = data_utils.terminate_keras_multiprocessing_pools( + use_sigkill=False) + + self._sleep_at_end = False + super(ForkRobustTestCase, self).setUp() + + def tearDown(self): + # Give multiprocessing pools some time to finish on their own before + # cleanup_all_keras_forkpools yanks the rug out from under them. This is + # particularly important because calling .close() on a pool that is already + # in the process of spinning down can cause an uncatchable segmentation + # fault at which point the tearDown will hang. + if self._sleep_at_end: + time.sleep(1) + + # If a test finishes and leaves behind uncleanable artifacts then that is a + # failure condition. However, if the state was not clean to begin with the + # test should not fail on that account. + new_remnants = set(data_utils.terminate_keras_multiprocessing_pools( + use_sigkill=True)).difference(self._starting_remnants) + + if new_remnants: + raise ValueError('Test left behind stubborn orphans:\n {}'.format( + '\n '.join(new_remnants))) + super(ForkRobustTestCase, self).tearDown() + + +class TestGeneratorMethods(ForkRobustTestCase): @unittest.skipIf( os.name == 'nt', @@ -76,6 +111,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase): optimizer=rmsprop.RMSprop(1e-3), metrics=['mae', metrics_module.CategoricalAccuracy()]) + self._sleep_at_end = True model.fit_generator(custom_generator(), steps_per_epoch=5, epochs=1, @@ -117,6 +153,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase): metrics=['mae', metrics_module.CategoricalAccuracy()], run_eagerly=testing_utils.should_run_eagerly()) + self._sleep_at_end = True model.evaluate_generator(custom_generator(), steps=5, max_queue_size=10, @@ -143,6 +180,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase): num_hidden=3, num_classes=4, input_dim=2) model.run_eagerly = testing_utils.should_run_eagerly() + 
self._sleep_at_end = True model.predict_generator(custom_generator(), steps=5, max_queue_size=10, @@ -268,7 +306,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase): model.predict(ones_generator(), steps=2) -class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase): +class TestGeneratorMethodsWithSequences(ForkRobustTestCase): @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index 0f6e89b4d27..3b02e0342f7 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -20,16 +20,19 @@ from __future__ import print_function from abc import abstractmethod from contextlib import closing +import gc import hashlib import multiprocessing from multiprocessing.pool import ThreadPool import os import random import shutil +import signal import sys import tarfile import threading import time +import weakref import zipfile import numpy as np @@ -423,11 +426,126 @@ _SHARED_SEQUENCES = {} _SEQUENCE_COUNTER = None +# Because multiprocessing pools are inherently unsafe, starting from a clean +# state can be essential to avoiding deadlocks. In order to accomplish this, we +# need to be able to check on the status of Pools that we create. +_DATA_POOLS = weakref.WeakSet() +_WORKER_ID_QUEUE = multiprocessing.Queue() +_WORKER_IDS = set() + + def init_pool(seqs): global _SHARED_SEQUENCES _SHARED_SEQUENCES = seqs +@keras_export('keras.experimental.terminate_keras_multiprocessing_pools') +def terminate_keras_multiprocessing_pools(grace_period=0.1, use_sigkill=False): + """Destroy Keras' multiprocessing pools to prevent deadlocks. + + In general multiprocessing.Pool can interact quite badly with other, seemingly + unrelated, parts of a codebase due to Pool's reliance on fork. This method + cleans up all pools which are known to belong to Keras (and thus can be safely + terminated). 
+ + Args: + grace_period: Time (in seconds) to wait for process cleanup to propagate. + use_sigkill: Boolean of whether or not to perform a cleanup pass using + SIGKILL. + + Returns: + A list of human readable strings describing all issues encountered. It is up + to the caller to decide whether to treat this as an error condition. + """ + errors = [] + + # First cleanup the pools spawned by Keras. If we start killing workers and + # a parent pool is still alive it will just spawn replacements which we don't + # want. + gc.collect() + for pool in _DATA_POOLS: + pool.close() + pool.terminate() + # We do not join the pool, because that would wait forever if a worker + # refused to exit. + + # Finally, delete our reference to the pool so that we do not block garbage + # collection. + del pool + + # If there were any pools, sleep for a small grace period to allow everything + # to finalize. + if _DATA_POOLS: + time.sleep(grace_period) + + # Now we kill any workers which are still alive. However we must compare + # the worker identifier to the set of identifiers which are known to have been + # spawned by pools belonging to Keras to avoid deleting unrelated workers. + # First we call the .terminate() method of a worker, and then if it still + # persists we directly send a signal to the process. Certain worker tasks may + # be able to gracefully handle shutdown, so we send a SIGTERM and then + # optionally follow up with a SIGKILL. + visited_workers = set() + cleanup_passes = ['.terminate', 'SIGTERM'] + if use_sigkill: + cleanup_passes.append('SIGKILL') + cleanup_passes.append('log') + + for cleanup_pass in cleanup_passes: + while True: + # In rare cases, queue.qsize() overestimates the number of elements. This + # loop is designed to be more robust. 
+ try: + _WORKER_IDS.add(_WORKER_ID_QUEUE.get_nowait()) + except queue.Empty: + break + + gc.collect() + workers_terminated_this_pass = False + for worker in multiprocessing.active_children(): + ident = worker.ident + if ident in _WORKER_IDS and worker.is_alive(): + try: + if cleanup_pass == '.terminate': + # First we ask nicely. + worker.terminate() + worker.join(timeout=grace_period) + visited_workers.add(ident) + workers_terminated_this_pass = True + elif cleanup_pass in ('SIGTERM', 'SIGKILL'): + # Then we ask increasingly tersely. + os.kill(worker.pid, signal.SIGKILL if cleanup_pass == 'SIGKILL' + else signal.SIGTERM) + workers_terminated_this_pass = True + + elif cleanup_pass == 'log': + # And finally we give up and log the failure. + errors.append('worker still alive: {}, pid={}, hash={}' + .format(worker.name, worker.pid, hash(worker))) + + except OSError: + # Worker exited since the start of this loop. + pass + + if workers_terminated_this_pass: + # There can be a small propagation delay between worker destruction and + # workers reporting False for is_alive and no longer appearing in the + # list of active children. Once again, we sleep for a small grace period. + # This prevents false positives from workers which are simply still in the + # process of spinning down. + time.sleep(grace_period) + + # Finally we remove the visited worker ids to handle the edge case that a + # pid is reused. + _WORKER_IDS.difference_update(visited_workers) + + gc.collect() + for pool in _DATA_POOLS: + errors.append('pool still exists: {}, hash={}'.format(pool, hash(pool))) + + return errors + + def get_index(uid, i): """Get the value from the Sequence `uid` at index `i`. 
@@ -596,8 +714,11 @@ class OrderedEnqueuer(SequenceEnqueuer):
       Function, a Function to initialize the pool
     """
     def pool_fn(seqs):
-      return multiprocessing.Pool(
-          workers, initializer=init_pool_generator, initargs=(seqs, None))
+      pool = multiprocessing.Pool(
+          workers, initializer=init_pool_generator,
+          initargs=(seqs, None, _WORKER_ID_QUEUE))
+      _DATA_POOLS.add(pool)
+      return pool
 
     return pool_fn
 
@@ -620,6 +741,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
     for i in sequence:
       if self.stop_signal.is_set():
         return
+
       self.queue.put(
           executor.apply_async(get_index, (self.uid, i)), block=True)
 
@@ -655,13 +777,31 @@ class OrderedEnqueuer(SequenceEnqueuer):
       six.reraise(*sys.exc_info())
 
 
-def init_pool_generator(gens, random_seed=None):
+def init_pool_generator(gens, random_seed=None, id_queue=None):
+  """Initializer function for pool workers.
+
+  Args:
+    gens: State which should be made available to worker processes.
+    random_seed: An optional value with which to seed child processes.
+    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
+      that a worker process was created by Keras and can be terminated using
+      the terminate_keras_multiprocessing_pools utility.
+  """
   global _SHARED_SEQUENCES
   _SHARED_SEQUENCES = gens
 
+  worker_proc = multiprocessing.current_process()
+
+  # name isn't used for anything, but setting a more descriptive name is helpful
+  # when diagnosing orphaned processes.
+  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)
+
   if random_seed is not None:
-    ident = multiprocessing.current_process().ident
-    np.random.seed(random_seed + ident)
+    np.random.seed(random_seed + worker_proc.ident)
+
+  if id_queue is not None:
+    # If a worker dies during init, the pool will just create a replacement.
+ id_queue.put(worker_proc.ident, block=True, timeout=0.1) def next_sample(uid): @@ -713,9 +853,11 @@ class GeneratorEnqueuer(SequenceEnqueuer): A Function to initialize the pool """ def pool_fn(seqs): - return multiprocessing.Pool(workers, - initializer=init_pool_generator, - initargs=(seqs, self.random_seed)) + pool = multiprocessing.Pool( + workers, initializer=init_pool_generator, + initargs=(seqs, self.random_seed, _WORKER_ID_QUEUE)) + _DATA_POOLS.add(pool) + return pool return pool_fn def _run(self): @@ -725,6 +867,7 @@ class GeneratorEnqueuer(SequenceEnqueuer): while True: if self.stop_signal.is_set(): return + self.queue.put( executor.apply_async(next_sample, (self.uid,)), block=True) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt index 65b82a3f322..bfd169a9b35 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt @@ -32,4 +32,8 @@ tf_module { name: "load_from_saved_model" argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "terminate_keras_multiprocessing_pools" + argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], " + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt index 65b82a3f322..bfd169a9b35 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt @@ -32,4 +32,8 @@ tf_module { name: "load_from_saved_model" argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "terminate_keras_multiprocessing_pools" + argspec: 
"args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], " + } }