Add a method which makes a best-effort attempt to remove any remnants of the pool created when multiprocessing=True is set.

PiperOrigin-RevId: 236417625
This commit is contained in:
Taylor Robie 2019-03-01 19:21:20 -08:00 committed by TensorFlower Gardener
parent 4de282cf8a
commit a9064fdddc
5 changed files with 199 additions and 11 deletions

View File

@ -1171,7 +1171,6 @@ tf_py_test(
shard_count = 6, shard_count = 6,
tags = [ tags = [
"no_oss", "no_oss",
"notap", #TODO(b/123544294): Re-enable this test.
"notsan", "notsan",
], ],
) )

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import time
import unittest import unittest
from absl.testing import parameterized from absl.testing import parameterized
@ -34,6 +35,7 @@ from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.optimizer_v2 import rmsprop from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -61,7 +63,40 @@ def custom_generator(mode=2):
yield x, y, w yield x, y, w
class ForkRobustTestCase(keras_parameterized.TestCase):
  """Test case that scrubs leftover Keras multiprocessing pools per test.

  setUp records any pre-existing pool remnants; tearDown forcibly cleans up
  and fails the test only for remnants the test itself created.
  """

  # When True, tearDown sleeps briefly before cleanup to let pools spin down.
  _sleep_at_end = False

  def setUp(self):
    # When setting up a test simply make a best effort to start from a clean
    # state. Remnants that already existed are recorded so the test is not
    # blamed for them in tearDown.
    self._starting_remnants = data_utils.terminate_keras_multiprocessing_pools(
        use_sigkill=False)

    self._sleep_at_end = False
    super(ForkRobustTestCase, self).setUp()

  def tearDown(self):
    # Give multiprocessing pools some time to finish on their own before
    # terminate_keras_multiprocessing_pools yanks the rug out from under them.
    # This is particularly important because calling .close() on a pool that is
    # already in the process of spinning down can cause an uncatchable
    # segmentation fault, at which point the tearDown will hang.
    if self._sleep_at_end:
      time.sleep(1)

    # If a test finishes and leaves behind uncleanable artifacts then that is a
    # failure condition. However, if the state was not clean to begin with the
    # test should not fail on that account.
    new_remnants = set(data_utils.terminate_keras_multiprocessing_pools(
        use_sigkill=True)).difference(self._starting_remnants)

    if new_remnants:
      raise ValueError('Test left behind stubborn orphans:\n {}'.format(
          '\n '.join(new_remnants)))
    super(ForkRobustTestCase, self).tearDown()
class TestGeneratorMethods(ForkRobustTestCase):
@unittest.skipIf( @unittest.skipIf(
os.name == 'nt', os.name == 'nt',
@ -76,6 +111,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
optimizer=rmsprop.RMSprop(1e-3), optimizer=rmsprop.RMSprop(1e-3),
metrics=['mae', metrics_module.CategoricalAccuracy()]) metrics=['mae', metrics_module.CategoricalAccuracy()])
self._sleep_at_end = True
model.fit_generator(custom_generator(), model.fit_generator(custom_generator(),
steps_per_epoch=5, steps_per_epoch=5,
epochs=1, epochs=1,
@ -117,6 +153,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
metrics=['mae', metrics_module.CategoricalAccuracy()], metrics=['mae', metrics_module.CategoricalAccuracy()],
run_eagerly=testing_utils.should_run_eagerly()) run_eagerly=testing_utils.should_run_eagerly())
self._sleep_at_end = True
model.evaluate_generator(custom_generator(), model.evaluate_generator(custom_generator(),
steps=5, steps=5,
max_queue_size=10, max_queue_size=10,
@ -143,6 +180,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
num_hidden=3, num_classes=4, input_dim=2) num_hidden=3, num_classes=4, input_dim=2)
model.run_eagerly = testing_utils.should_run_eagerly() model.run_eagerly = testing_utils.should_run_eagerly()
self._sleep_at_end = True
model.predict_generator(custom_generator(), model.predict_generator(custom_generator(),
steps=5, steps=5,
max_queue_size=10, max_queue_size=10,
@ -268,7 +306,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
model.predict(ones_generator(), steps=2) model.predict(ones_generator(), steps=2)
class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase): class TestGeneratorMethodsWithSequences(ForkRobustTestCase):
@keras_parameterized.run_with_all_model_types @keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes

View File

@ -20,16 +20,19 @@ from __future__ import print_function
from abc import abstractmethod from abc import abstractmethod
from contextlib import closing from contextlib import closing
import gc
import hashlib import hashlib
import multiprocessing import multiprocessing
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
import os import os
import random import random
import shutil import shutil
import signal
import sys import sys
import tarfile import tarfile
import threading import threading
import time import time
import weakref
import zipfile import zipfile
import numpy as np import numpy as np
@ -423,11 +426,126 @@ _SHARED_SEQUENCES = {}
_SEQUENCE_COUNTER = None _SEQUENCE_COUNTER = None
# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = multiprocessing.Queue()
_WORKER_IDS = set()
def init_pool(seqs): def init_pool(seqs):
global _SHARED_SEQUENCES global _SHARED_SEQUENCES
_SHARED_SEQUENCES = seqs _SHARED_SEQUENCES = seqs
@keras_export('keras.experimental.terminate_keras_multiprocessing_pools')
def terminate_keras_multiprocessing_pools(grace_period=0.1, use_sigkill=False):
  """Destroy Keras' multiprocessing pools to prevent deadlocks.

  In general multiprocessing.Pool can interact quite badly with other, seemingly
  unrelated, parts of a codebase due to Pool's reliance on fork. This method
  cleans up all pools which are known to belong to Keras (and thus can be safely
  terminated).

  Args:
    grace_period: Time (in seconds) to wait for process cleanup to propagate.
    use_sigkill: Boolean of whether or not to perform a cleanup pass using
      SIGKILL.

  Returns:
    A list of human readable strings describing all issues encountered. It is up
    to the caller to decide whether to treat this as an error condition.
  """
  errors = []

  # Shut down the Keras-owned pools before touching their workers: a pool that
  # is still alive simply spawns replacements for any worker we kill.
  gc.collect()
  for pool in _DATA_POOLS:
    pool.close()
    pool.terminate()
    # Joining is deliberately skipped; it would block forever if a worker
    # refused to exit.

    # Drop this frame's reference so it does not keep the pool alive and block
    # garbage collection.
    del pool

  # If any pools existed, pause briefly so the terminations can finalize.
  if _DATA_POOLS:
    time.sleep(grace_period)

  # Next, kill any workers which are still alive. A worker is only touched if
  # its identifier appears in the set of ids known to have been produced by
  # Keras-owned pools, so unrelated processes are left alone. Escalation order:
  # the worker's .terminate() method first, then SIGTERM (some worker tasks can
  # shut down gracefully), then optionally SIGKILL, and finally just logging.
  reaped_ids = set()
  cleanup_passes = ['.terminate', 'SIGTERM']
  if use_sigkill:
    cleanup_passes.append('SIGKILL')
  cleanup_passes.append('log')

  for cleanup_pass in cleanup_passes:
    # Drain the id queue with get_nowait in a loop rather than trusting
    # queue.qsize(), which can overestimate the number of elements.
    while True:
      try:
        _WORKER_IDS.add(_WORKER_ID_QUEUE.get_nowait())
      except queue.Empty:
        break

    gc.collect()
    terminated_this_pass = False
    for worker in multiprocessing.active_children():
      ident = worker.ident
      if ident in _WORKER_IDS and worker.is_alive():
        try:
          if cleanup_pass == '.terminate':
            # First we ask nicely.
            worker.terminate()
            worker.join(timeout=grace_period)
            reaped_ids.add(ident)
            terminated_this_pass = True
          elif cleanup_pass in ('SIGTERM', 'SIGKILL'):
            # Then we ask increasingly tersely.
            os.kill(worker.pid,
                    signal.SIGKILL if cleanup_pass == 'SIGKILL'
                    else signal.SIGTERM)
            terminated_this_pass = True
          elif cleanup_pass == 'log':
            # And finally we give up and log the failure.
            errors.append('worker still alive: {}, pid={}, hash={}'
                          .format(worker.name, worker.pid, hash(worker)))
        except OSError:
          # Worker exited since the start of this loop.
          pass

    if terminated_this_pass:
      # is_alive() and active_children() can lag actual worker destruction by a
      # small propagation delay, so sleep for a grace period to avoid reporting
      # workers that are merely still spinning down as failures.
      time.sleep(grace_period)

  # Forget the reaped ids so a recycled pid is not mistaken for a Keras worker
  # on some later call.
  _WORKER_IDS.difference_update(reaped_ids)

  gc.collect()
  for pool in _DATA_POOLS:
    errors.append('pool still exists: {}, hash={}'.format(pool, hash(pool)))

  return errors
def get_index(uid, i): def get_index(uid, i):
"""Get the value from the Sequence `uid` at index `i`. """Get the value from the Sequence `uid` at index `i`.
@ -596,8 +714,11 @@ class OrderedEnqueuer(SequenceEnqueuer):
Function, a Function to initialize the pool Function, a Function to initialize the pool
""" """
def pool_fn(seqs): def pool_fn(seqs):
return multiprocessing.Pool( pool = multiprocessing.Pool(
workers, initializer=init_pool_generator, initargs=(seqs, None)) workers, initializer=init_pool_generator,
initargs=(seqs, None, _WORKER_ID_QUEUE))
_DATA_POOLS.add(pool)
return pool
return pool_fn return pool_fn
@ -620,6 +741,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
for i in sequence: for i in sequence:
if self.stop_signal.is_set(): if self.stop_signal.is_set():
return return
self.queue.put( self.queue.put(
executor.apply_async(get_index, (self.uid, i)), block=True) executor.apply_async(get_index, (self.uid, i)), block=True)
@ -655,13 +777,31 @@ class OrderedEnqueuer(SequenceEnqueuer):
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
def init_pool_generator(gens, random_seed=None, id_queue=None):
  """Initializer function for pool workers.

  Args:
    gens: State which should be made available to worker processes.
    random_seed: An optional value with which to seed child processes.
    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
      that a worker process was created by Keras and can be terminated using
      the terminate_keras_multiprocessing_pools utility.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  proc = multiprocessing.current_process()

  # The name isn't used for anything, but a more descriptive one is helpful
  # when diagnosing orphaned processes.
  proc.name = 'Keras_worker_{}'.format(proc.name)

  if random_seed is not None:
    # Offset the seed by the worker's ident so each worker draws a distinct
    # random stream.
    np.random.seed(random_seed + proc.ident)

  if id_queue is not None:
    # If a worker dies during init, the pool will just create a replacement.
    id_queue.put(proc.ident, block=True, timeout=0.1)
def next_sample(uid): def next_sample(uid):
@ -713,9 +853,11 @@ class GeneratorEnqueuer(SequenceEnqueuer):
A Function to initialize the pool A Function to initialize the pool
""" """
def pool_fn(seqs): def pool_fn(seqs):
return multiprocessing.Pool(workers, pool = multiprocessing.Pool(
initializer=init_pool_generator, workers, initializer=init_pool_generator,
initargs=(seqs, self.random_seed)) initargs=(seqs, self.random_seed, _WORKER_ID_QUEUE))
_DATA_POOLS.add(pool)
return pool
return pool_fn return pool_fn
def _run(self): def _run(self):
@ -725,6 +867,7 @@ class GeneratorEnqueuer(SequenceEnqueuer):
while True: while True:
if self.stop_signal.is_set(): if self.stop_signal.is_set():
return return
self.queue.put( self.queue.put(
executor.apply_async(next_sample, (self.uid,)), block=True) executor.apply_async(next_sample, (self.uid,)), block=True)

View File

@ -32,4 +32,8 @@ tf_module {
name: "load_from_saved_model" name: "load_from_saved_model"
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "terminate_keras_multiprocessing_pools"
argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
}
} }

View File

@ -32,4 +32,8 @@ tf_module {
name: "load_from_saved_model" name: "load_from_saved_model"
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "terminate_keras_multiprocessing_pools"
argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
}
} }