Add a method which makes a best effort attempt to remove any remnants of the pool created when multiprocessing=True is set.
PiperOrigin-RevId: 236417625
This commit is contained in:
parent
4de282cf8a
commit
a9064fdddc
@ -1171,7 +1171,6 @@ tf_py_test(
|
|||||||
shard_count = 6,
|
shard_count = 6,
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss",
|
"no_oss",
|
||||||
"notap", #TODO(b/123544294): Re-enable this test.
|
|
||||||
"notsan",
|
"notsan",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
@ -34,6 +35,7 @@ from tensorflow.python.keras import metrics as metrics_module
|
|||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.keras.engine import training_generator
|
from tensorflow.python.keras.engine import training_generator
|
||||||
from tensorflow.python.keras.optimizer_v2 import rmsprop
|
from tensorflow.python.keras.optimizer_v2 import rmsprop
|
||||||
|
from tensorflow.python.keras.utils import data_utils
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
@ -61,7 +63,40 @@ def custom_generator(mode=2):
|
|||||||
yield x, y, w
|
yield x, y, w
|
||||||
|
|
||||||
|
|
||||||
class TestGeneratorMethods(keras_parameterized.TestCase):
|
class ForkRobustTestCase(keras_parameterized.TestCase):
|
||||||
|
_sleep_at_end = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# When setting up a test simply make a best effort to start from a clean
|
||||||
|
# state.
|
||||||
|
self._starting_remnants = data_utils.terminate_keras_multiprocessing_pools(
|
||||||
|
use_sigkill=False)
|
||||||
|
|
||||||
|
self._sleep_at_end = False
|
||||||
|
super(ForkRobustTestCase, self).setUp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
# Give multiprocessing pools some time to finish on their own before
|
||||||
|
# cleanup_all_keras_forkpools yanks the rug out from under them. This is
|
||||||
|
# particularly important because calling .close() on a pool that is already
|
||||||
|
# in the process of spinning down can cause an uncatchable segmentation
|
||||||
|
# fault at which point the tearDown will hang.
|
||||||
|
if self._sleep_at_end:
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# If a test finishes and leaves behind uncleanable artifacts then that is a
|
||||||
|
# failure condition. However, if the state was not clean to begin with the
|
||||||
|
# test should not fail on that account.
|
||||||
|
new_remnants = set(data_utils.terminate_keras_multiprocessing_pools(
|
||||||
|
use_sigkill=True)).difference(self._starting_remnants)
|
||||||
|
|
||||||
|
if new_remnants:
|
||||||
|
raise ValueError('Test left behind stubborn orphans:\n {}'.format(
|
||||||
|
'\n '.join(new_remnants)))
|
||||||
|
super(ForkRobustTestCase, self).tearDown()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeneratorMethods(ForkRobustTestCase):
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
os.name == 'nt',
|
os.name == 'nt',
|
||||||
@ -76,6 +111,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
|
|||||||
optimizer=rmsprop.RMSprop(1e-3),
|
optimizer=rmsprop.RMSprop(1e-3),
|
||||||
metrics=['mae', metrics_module.CategoricalAccuracy()])
|
metrics=['mae', metrics_module.CategoricalAccuracy()])
|
||||||
|
|
||||||
|
self._sleep_at_end = True
|
||||||
model.fit_generator(custom_generator(),
|
model.fit_generator(custom_generator(),
|
||||||
steps_per_epoch=5,
|
steps_per_epoch=5,
|
||||||
epochs=1,
|
epochs=1,
|
||||||
@ -117,6 +153,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
|
|||||||
metrics=['mae', metrics_module.CategoricalAccuracy()],
|
metrics=['mae', metrics_module.CategoricalAccuracy()],
|
||||||
run_eagerly=testing_utils.should_run_eagerly())
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
|
||||||
|
self._sleep_at_end = True
|
||||||
model.evaluate_generator(custom_generator(),
|
model.evaluate_generator(custom_generator(),
|
||||||
steps=5,
|
steps=5,
|
||||||
max_queue_size=10,
|
max_queue_size=10,
|
||||||
@ -143,6 +180,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
|
|||||||
num_hidden=3, num_classes=4, input_dim=2)
|
num_hidden=3, num_classes=4, input_dim=2)
|
||||||
model.run_eagerly = testing_utils.should_run_eagerly()
|
model.run_eagerly = testing_utils.should_run_eagerly()
|
||||||
|
|
||||||
|
self._sleep_at_end = True
|
||||||
model.predict_generator(custom_generator(),
|
model.predict_generator(custom_generator(),
|
||||||
steps=5,
|
steps=5,
|
||||||
max_queue_size=10,
|
max_queue_size=10,
|
||||||
@ -268,7 +306,7 @@ class TestGeneratorMethods(keras_parameterized.TestCase):
|
|||||||
model.predict(ones_generator(), steps=2)
|
model.predict(ones_generator(), steps=2)
|
||||||
|
|
||||||
|
|
||||||
class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):
|
class TestGeneratorMethodsWithSequences(ForkRobustTestCase):
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
@ -20,16 +20,19 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
|
import gc
|
||||||
import hashlib
|
import hashlib
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import weakref
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -423,11 +426,126 @@ _SHARED_SEQUENCES = {}
|
|||||||
_SEQUENCE_COUNTER = None
|
_SEQUENCE_COUNTER = None
|
||||||
|
|
||||||
|
|
||||||
|
# Because multiprocessing pools are inherently unsafe, starting from a clean
|
||||||
|
# state can be essential to avoiding deadlocks. In order to accomplish this, we
|
||||||
|
# need to be able to check on the status of Pools that we create.
|
||||||
|
_DATA_POOLS = weakref.WeakSet()
|
||||||
|
_WORKER_ID_QUEUE = multiprocessing.Queue()
|
||||||
|
_WORKER_IDS = set()
|
||||||
|
|
||||||
|
|
||||||
def init_pool(seqs):
|
def init_pool(seqs):
|
||||||
global _SHARED_SEQUENCES
|
global _SHARED_SEQUENCES
|
||||||
_SHARED_SEQUENCES = seqs
|
_SHARED_SEQUENCES = seqs
|
||||||
|
|
||||||
|
|
||||||
|
@keras_export('keras.experimental.terminate_keras_multiprocessing_pools')
|
||||||
|
def terminate_keras_multiprocessing_pools(grace_period=0.1, use_sigkill=False):
|
||||||
|
"""Destroy Keras' multiprocessing pools to prevent deadlocks.
|
||||||
|
|
||||||
|
In general multiprocessing.Pool can interact quite badly with other, seemingly
|
||||||
|
unrelated, parts of a codebase due to Pool's reliance on fork. This method
|
||||||
|
cleans up all pools which are known to belong to Keras (and thus can be safely
|
||||||
|
terminated).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grace_period: Time (in seconds) to wait for process cleanup to propagate.
|
||||||
|
use_sigkill: Boolean of whether or not to perform a cleanup pass using
|
||||||
|
SIGKILL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of human readable strings describing all issues encountered. It is up
|
||||||
|
to the caller to decide whether to treat this as an error condition.
|
||||||
|
"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
# First cleanup the pools spawned by Keras. If we start killing workers and
|
||||||
|
# a parent pool is still alive it will just spawn replacements which we don't
|
||||||
|
# want.
|
||||||
|
gc.collect()
|
||||||
|
for pool in _DATA_POOLS:
|
||||||
|
pool.close()
|
||||||
|
pool.terminate()
|
||||||
|
# We do not join the pool, because that would wait forever if a worker
|
||||||
|
# refused to exit.
|
||||||
|
|
||||||
|
# Finally, delete our reference to the pool so that we do not block garbage
|
||||||
|
# collection.
|
||||||
|
del pool
|
||||||
|
|
||||||
|
# If there were any pools, sleep for a small grace period to allow everything
|
||||||
|
# to finalize.
|
||||||
|
if _DATA_POOLS:
|
||||||
|
time.sleep(grace_period)
|
||||||
|
|
||||||
|
# Now we kill any workers which are still alive. However we must compare
|
||||||
|
# the worker identifier to the set of identifiers which are known to have been
|
||||||
|
# spawned by pools belonging to Keras to avoid deleting unrelated workers.
|
||||||
|
# First we call the .terminate() method of a worker, and then if it still
|
||||||
|
# persists we directly send a signal to the process. Certain worker tasks may
|
||||||
|
# be able to gracefully handle shutdown, so we send a SIGTERM and then
|
||||||
|
# optionally follow up with a SIGKILL.
|
||||||
|
visited_workers = set()
|
||||||
|
cleanup_passes = ['.terminate', 'SIGTERM']
|
||||||
|
if use_sigkill:
|
||||||
|
cleanup_passes.append('SIGKILL')
|
||||||
|
cleanup_passes.append('log')
|
||||||
|
|
||||||
|
for cleanup_pass in cleanup_passes:
|
||||||
|
while True:
|
||||||
|
# In rare cases, queue.qsize() overestimates the number of elements. This
|
||||||
|
# loop is designed to be more robust.
|
||||||
|
try:
|
||||||
|
_WORKER_IDS.add(_WORKER_ID_QUEUE.get_nowait())
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
workers_terminated_this_pass = False
|
||||||
|
for worker in multiprocessing.active_children():
|
||||||
|
ident = worker.ident
|
||||||
|
if ident in _WORKER_IDS and worker.is_alive():
|
||||||
|
try:
|
||||||
|
if cleanup_pass == '.terminate':
|
||||||
|
# First we ask nicely.
|
||||||
|
worker.terminate()
|
||||||
|
worker.join(timeout=grace_period)
|
||||||
|
visited_workers.add(ident)
|
||||||
|
workers_terminated_this_pass = True
|
||||||
|
elif cleanup_pass in ('SIGTERM', 'SIGKILL'):
|
||||||
|
# Then we ask increasingly tersely.
|
||||||
|
os.kill(worker.pid, signal.SIGKILL if cleanup_pass == 'SIGKILL'
|
||||||
|
else signal.SIGTERM)
|
||||||
|
workers_terminated_this_pass = True
|
||||||
|
|
||||||
|
elif cleanup_pass == 'log':
|
||||||
|
# And finally we give up and log the failure.
|
||||||
|
errors.append('worker still alive: {}, pid={}, hash={}'
|
||||||
|
.format(worker.name, worker.pid, hash(worker)))
|
||||||
|
|
||||||
|
except OSError:
|
||||||
|
# Worker exited since the start of this loop.
|
||||||
|
pass
|
||||||
|
|
||||||
|
if workers_terminated_this_pass:
|
||||||
|
# There can be a small propagation delay between worker destruction and
|
||||||
|
# workers reporting False for is_alive and no longer appearing in the
|
||||||
|
# list of active children. Once again, we sleep for a small grace period.
|
||||||
|
# This prevents false positives from workers which are simply still in the
|
||||||
|
# process of spinning down.
|
||||||
|
time.sleep(grace_period)
|
||||||
|
|
||||||
|
# Finally we remove the visited worker ids to handle the edge case that a
|
||||||
|
# pid is reused.
|
||||||
|
_WORKER_IDS.difference_update(visited_workers)
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
for pool in _DATA_POOLS:
|
||||||
|
errors.append('pool still exists: {}, hash={}'.format(pool, hash(pool)))
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
def get_index(uid, i):
|
def get_index(uid, i):
|
||||||
"""Get the value from the Sequence `uid` at index `i`.
|
"""Get the value from the Sequence `uid` at index `i`.
|
||||||
|
|
||||||
@ -596,8 +714,11 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
|||||||
Function, a Function to initialize the pool
|
Function, a Function to initialize the pool
|
||||||
"""
|
"""
|
||||||
def pool_fn(seqs):
|
def pool_fn(seqs):
|
||||||
return multiprocessing.Pool(
|
pool = multiprocessing.Pool(
|
||||||
workers, initializer=init_pool_generator, initargs=(seqs, None))
|
workers, initializer=init_pool_generator,
|
||||||
|
initargs=(seqs, None, _WORKER_ID_QUEUE))
|
||||||
|
_DATA_POOLS.add(pool)
|
||||||
|
return pool
|
||||||
|
|
||||||
return pool_fn
|
return pool_fn
|
||||||
|
|
||||||
@ -620,6 +741,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
|||||||
for i in sequence:
|
for i in sequence:
|
||||||
if self.stop_signal.is_set():
|
if self.stop_signal.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
self.queue.put(
|
self.queue.put(
|
||||||
executor.apply_async(get_index, (self.uid, i)), block=True)
|
executor.apply_async(get_index, (self.uid, i)), block=True)
|
||||||
|
|
||||||
@ -655,13 +777,31 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
|||||||
six.reraise(*sys.exc_info())
|
six.reraise(*sys.exc_info())
|
||||||
|
|
||||||
|
|
||||||
def init_pool_generator(gens, random_seed=None):
|
def init_pool_generator(gens, random_seed=None, id_queue=None):
|
||||||
|
"""Initializer function for pool workers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gens: State which should be made available to worker processes.
|
||||||
|
random_seed: An optional value with which to seed child processes.
|
||||||
|
id_queue: A multiprocessing Queue of worker ids. This is used to indicate
|
||||||
|
that a worker process was created by Keras and can be terminated using
|
||||||
|
the cleanup_all_keras_forkpools utility.
|
||||||
|
"""
|
||||||
global _SHARED_SEQUENCES
|
global _SHARED_SEQUENCES
|
||||||
_SHARED_SEQUENCES = gens
|
_SHARED_SEQUENCES = gens
|
||||||
|
|
||||||
|
worker_proc = multiprocessing.current_process()
|
||||||
|
|
||||||
|
# name isn't used for anything, but setting a more descriptive name is helpful
|
||||||
|
# when diagnosing orphaned processes.
|
||||||
|
worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)
|
||||||
|
|
||||||
if random_seed is not None:
|
if random_seed is not None:
|
||||||
ident = multiprocessing.current_process().ident
|
np.random.seed(random_seed + worker_proc.ident)
|
||||||
np.random.seed(random_seed + ident)
|
|
||||||
|
if id_queue is not None:
|
||||||
|
# If a worker dies during init, the pool will just create a replacement.
|
||||||
|
id_queue.put(worker_proc.ident, block=True, timeout=0.1)
|
||||||
|
|
||||||
|
|
||||||
def next_sample(uid):
|
def next_sample(uid):
|
||||||
@ -713,9 +853,11 @@ class GeneratorEnqueuer(SequenceEnqueuer):
|
|||||||
A Function to initialize the pool
|
A Function to initialize the pool
|
||||||
"""
|
"""
|
||||||
def pool_fn(seqs):
|
def pool_fn(seqs):
|
||||||
return multiprocessing.Pool(workers,
|
pool = multiprocessing.Pool(
|
||||||
initializer=init_pool_generator,
|
workers, initializer=init_pool_generator,
|
||||||
initargs=(seqs, self.random_seed))
|
initargs=(seqs, self.random_seed, _WORKER_ID_QUEUE))
|
||||||
|
_DATA_POOLS.add(pool)
|
||||||
|
return pool
|
||||||
return pool_fn
|
return pool_fn
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
@ -725,6 +867,7 @@ class GeneratorEnqueuer(SequenceEnqueuer):
|
|||||||
while True:
|
while True:
|
||||||
if self.stop_signal.is_set():
|
if self.stop_signal.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
self.queue.put(
|
self.queue.put(
|
||||||
executor.apply_async(next_sample, (self.uid,)), block=True)
|
executor.apply_async(next_sample, (self.uid,)), block=True)
|
||||||
|
|
||||||
|
@ -32,4 +32,8 @@ tf_module {
|
|||||||
name: "load_from_saved_model"
|
name: "load_from_saved_model"
|
||||||
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "terminate_keras_multiprocessing_pools"
|
||||||
|
argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -32,4 +32,8 @@ tf_module {
|
|||||||
name: "load_from_saved_model"
|
name: "load_from_saved_model"
|
||||||
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "terminate_keras_multiprocessing_pools"
|
||||||
|
argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user