parent 837e5feed2
commit 386165b089
@@ -1171,6 +1171,7 @@ tf_py_test(
     shard_count = 6,
     tags = [
         "no_oss",
+        "notap",  # TODO(b/123544294): Re-enable this test.
         "notsan",
     ],
 )

@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 import os
-import time
 import unittest

 from absl.testing import parameterized
@@ -35,7 +34,6 @@ from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.optimizer_v2 import rmsprop
-from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest

@@ -63,40 +61,7 @@ def custom_generator(mode=2):
       yield x, y, w


-class ForkRobustTestCase(keras_parameterized.TestCase):
-  _sleep_at_end = False
-
-  def setUp(self):
-    # When setting up a test simply make a best effort to start from a clean
-    # state.
-    self._starting_remnants = data_utils.terminate_keras_multiprocessing_pools(
-        use_sigkill=False)
-
-    self._sleep_at_end = False
-    super(ForkRobustTestCase, self).setUp()
-
-  def tearDown(self):
-    # Give multiprocessing pools some time to finish on their own before
-    # cleanup_all_keras_forkpools yanks the rug out from under them. This is
-    # particularly important because calling .close() on a pool that is already
-    # in the process of spinning down can cause an uncatchable segmentation
-    # fault at which point the tearDown will hang.
-    if self._sleep_at_end:
-      time.sleep(1)
-
-    # If a test finishes and leaves behind uncleanable artifacts then that is a
-    # failure condition. However, if the state was not clean to begin with the
-    # test should not fail on that account.
-    new_remnants = set(data_utils.terminate_keras_multiprocessing_pools(
-        use_sigkill=True)).difference(self._starting_remnants)
-
-    if new_remnants:
-      raise ValueError('Test left behind stubborn orphans:\n {}'.format(
-          '\n '.join(new_remnants)))
-    super(ForkRobustTestCase, self).tearDown()
-
-
-class TestGeneratorMethods(ForkRobustTestCase):
+class TestGeneratorMethods(keras_parameterized.TestCase):

   @unittest.skipIf(
       os.name == 'nt',
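
The fixture deleted above snapshots leaked pool remnants in `setUp` and diffs against that snapshot in `tearDown`, so a test only fails for orphans it created itself. A minimal self-contained sketch of that set-difference pattern, with a hypothetical `leaked_resources()` helper standing in for `data_utils.terminate_keras_multiprocessing_pools`:

```python
import unittest

_OPEN_RESOURCES = set()  # stand-in registry; Keras tracked live pool workers


def leaked_resources():
  """Hypothetical stand-in for terminate_keras_multiprocessing_pools()."""
  return set(_OPEN_RESOURCES)


class LeakCheckedTestCase(unittest.TestCase):

  def setUp(self):
    # Snapshot whatever is already leaked so pre-existing mess does not
    # fail this particular test.
    self._starting_remnants = leaked_resources()
    super(LeakCheckedTestCase, self).setUp()

  def tearDown(self):
    # Only remnants that appeared during the test count as failures.
    new_remnants = leaked_resources().difference(self._starting_remnants)
    if new_remnants:
      raise ValueError('Test left behind stubborn orphans:\n  {}'.format(
          '\n  '.join(sorted(new_remnants))))
    super(LeakCheckedTestCase, self).tearDown()
```
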
@@ -111,7 +76,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
         optimizer=rmsprop.RMSprop(1e-3),
         metrics=['mae', metrics_module.CategoricalAccuracy()])

-    self._sleep_at_end = True
     model.fit_generator(custom_generator(),
                         steps_per_epoch=5,
                         epochs=1,
@@ -153,7 +117,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
         metrics=['mae', metrics_module.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly())

-    self._sleep_at_end = True
     model.evaluate_generator(custom_generator(),
                              steps=5,
                              max_queue_size=10,
@@ -180,7 +143,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
         num_hidden=3, num_classes=4, input_dim=2)
     model.run_eagerly = testing_utils.should_run_eagerly()

-    self._sleep_at_end = True
     model.predict_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
@@ -306,7 +268,7 @@ class TestGeneratorMethods(ForkRobustTestCase):
     model.predict(ones_generator(), steps=2)


-class TestGeneratorMethodsWithSequences(ForkRobustTestCase):
+class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):

   @keras_parameterized.run_with_all_model_types
   @keras_parameterized.run_all_keras_modes

@@ -20,19 +20,16 @@ from __future__ import print_function

 from abc import abstractmethod
 from contextlib import closing
-import gc
 import hashlib
 import multiprocessing
 from multiprocessing.pool import ThreadPool
 import os
 import random
 import shutil
-import signal
 import sys
 import tarfile
 import threading
 import time
-import weakref
 import zipfile

 import numpy as np
@@ -426,126 +423,11 @@ _SHARED_SEQUENCES = {}
 _SEQUENCE_COUNTER = None


-# Because multiprocessing pools are inherently unsafe, starting from a clean
-# state can be essential to avoiding deadlocks. In order to accomplish this, we
-# need to be able to check on the status of Pools that we create.
-_DATA_POOLS = weakref.WeakSet()
-_WORKER_ID_QUEUE = multiprocessing.Queue()
-_WORKER_IDS = set()
-
-
 def init_pool(seqs):
   global _SHARED_SEQUENCES
   _SHARED_SEQUENCES = seqs


-@keras_export('keras.experimental.terminate_keras_multiprocessing_pools')
-def terminate_keras_multiprocessing_pools(grace_period=0.1, use_sigkill=False):
-  """Destroy Keras' multiprocessing pools to prevent deadlocks.
-
-  In general multiprocessing.Pool can interact quite badly with other, seemingly
-  unrelated, parts of a codebase due to Pool's reliance on fork. This method
-  cleans up all pools which are known to belong to Keras (and thus can be safely
-  terminated).
-
-  Args:
-    grace_period: Time (in seconds) to wait for process cleanup to propagate.
-    use_sigkill: Boolean of whether or not to perform a cleanup pass using
-      SIGKILL.
-
-  Returns:
-    A list of human readable strings describing all issues encountered. It is up
-    to the caller to decide whether to treat this as an error condition.
-  """
-  errors = []
-
-  # First cleanup the pools spawned by Keras. If we start killing workers and
-  # a parent pool is still alive it will just spawn replacements which we don't
-  # want.
-  gc.collect()
-  for pool in _DATA_POOLS:
-    pool.close()
-    pool.terminate()
-    # We do not join the pool, because that would wait forever if a worker
-    # refused to exit.
-
-    # Finally, delete our reference to the pool so that we do not block garbage
-    # collection.
-    del pool
-
-  # If there were any pools, sleep for a small grace period to allow everything
-  # to finalize.
-  if _DATA_POOLS:
-    time.sleep(grace_period)
-
-  # Now we kill any workers which are still alive. However we must compare
-  # the worker identifier to the set of identifiers which are known to have been
-  # spawned by pools belonging to Keras to avoid deleting unrelated workers.
-  # First we call the .terminate() method of a worker, and then if it still
-  # persists we directly send a signal to the process. Certain worker tasks may
-  # be able to gracefully handle shutdown, so we send a SIGTERM and then
-  # optionally follow up with a SIGKILL.
-  visited_workers = set()
-  cleanup_passes = ['.terminate', 'SIGTERM']
-  if use_sigkill:
-    cleanup_passes.append('SIGKILL')
-  cleanup_passes.append('log')
-
-  for cleanup_pass in cleanup_passes:
-    while True:
-      # In rare cases, queue.qsize() overestimates the number of elements. This
-      # loop is designed to be more robust.
-      try:
-        _WORKER_IDS.add(_WORKER_ID_QUEUE.get_nowait())
-      except queue.Empty:
-        break
-
-    gc.collect()
-    workers_terminated_this_pass = False
-    for worker in multiprocessing.active_children():
-      ident = worker.ident
-      if ident in _WORKER_IDS and worker.is_alive():
-        try:
-          if cleanup_pass == '.terminate':
-            # First we ask nicely.
-            worker.terminate()
-            worker.join(timeout=grace_period)
-            visited_workers.add(ident)
-            workers_terminated_this_pass = True
-          elif cleanup_pass in ('SIGTERM', 'SIGKILL'):
-            # Then we ask increasingly tersely.
-            os.kill(worker.pid, signal.SIGKILL if cleanup_pass == 'SIGKILL'
-                    else signal.SIGTERM)
-            workers_terminated_this_pass = True
-
-          elif cleanup_pass == 'log':
-            # And finally we give up and log the failure.
-            errors.append('worker still alive: {}, pid={}, hash={}'
-                          .format(worker.name, worker.pid, hash(worker)))
-
-        except OSError:
-          # Worker exited since the start of this loop.
-          pass
-
-    if workers_terminated_this_pass:
-      # There can be a small propagation delay between worker destruction and
-      # workers reporting False for is_alive and no longer appearing in the
-      # list of active children. Once again, we sleep for a small grace period.
-      # This prevents false positives from workers which are simply still in the
-      # process of spinning down.
-      time.sleep(grace_period)
-
-    # Finally we remove the visited worker ids to handle the edge case that a
-    # pid is reused.
-    _WORKER_IDS.difference_update(visited_workers)
-
-  gc.collect()
-  for pool in _DATA_POOLS:
-    errors.append('pool still exists: {}, hash={}'.format(pool, hash(pool)))
-
-  return errors
-
-
 def get_index(uid, i):
   """Get the value from the Sequence `uid` at index `i`.

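
The removed utility escalates through cleanup passes: `.terminate()` first, then SIGTERM, optionally SIGKILL, and finally logging whatever survived. The sketch below illustrates that escalation on a single tracked `multiprocessing.Process`; it assumes POSIX signals and is not the deleted TensorFlow implementation:

```python
import multiprocessing
import os
import signal
import time


def shut_down(worker, grace_period=0.1, use_sigkill=False):
  """Escalate: .terminate(), then SIGTERM, optionally SIGKILL, then log."""
  errors = []
  passes = ['.terminate', 'SIGTERM'] + (['SIGKILL'] if use_sigkill else [])
  for cleanup_pass in passes + ['log']:
    if not worker.is_alive():
      return errors  # Worker is gone; no need to escalate further.
    try:
      if cleanup_pass == '.terminate':
        worker.terminate()                     # First we ask nicely.
        worker.join(timeout=grace_period)
      elif cleanup_pass in ('SIGTERM', 'SIGKILL'):
        os.kill(worker.pid, getattr(signal, cleanup_pass))
        time.sleep(grace_period)               # Let the signal propagate.
      else:
        errors.append('worker still alive: pid={}'.format(worker.pid))
    except OSError:
      pass  # Worker exited between the liveness check and the kill.
  return errors


if __name__ == '__main__':
  p = multiprocessing.Process(target=time.sleep, args=(60,))
  p.start()
  print(shut_down(p, use_sigkill=True))  # [] once the worker is gone
```
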
@@ -714,11 +596,8 @@ class OrderedEnqueuer(SequenceEnqueuer):
       Function, a Function to initialize the pool
     """
     def pool_fn(seqs):
-      pool = multiprocessing.Pool(
-          workers, initializer=init_pool_generator,
-          initargs=(seqs, None, _WORKER_ID_QUEUE))
-      _DATA_POOLS.add(pool)
-      return pool
+      return multiprocessing.Pool(
+          workers, initializer=init_pool_generator, initargs=(seqs, None))

     return pool_fn

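
Both sides of this hunk wrap pool construction in a closure so the pool is built lazily inside the enqueuer; the reverted variant additionally registered every pool in a module-level `weakref.WeakSet` so a cleanup pass could find live pools without keeping dead ones alive. A rough sketch of that registration pattern (`make_pool_fn` and `_noop_init` are illustrative names, not Keras API):

```python
import multiprocessing
import weakref

# Weak references: tracking a pool for cleanup must not block its GC.
_TRACKED_POOLS = weakref.WeakSet()


def _noop_init():
  pass


def make_pool_fn(workers, initializer=_noop_init, initargs=()):
  """Factory mirroring the removed pool_fn: create, register, return."""
  def pool_fn():
    pool = multiprocessing.Pool(workers, initializer=initializer,
                                initargs=initargs)
    _TRACKED_POOLS.add(pool)  # visible to a later cleanup pass
    return pool
  return pool_fn


if __name__ == '__main__':
  pool = make_pool_fn(2)()
  print(pool in _TRACKED_POOLS)  # True
  pool.terminate()
```
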
@@ -741,7 +620,6 @@ class OrderedEnqueuer(SequenceEnqueuer):
         for i in sequence:
           if self.stop_signal.is_set():
             return
-
           self.queue.put(
               executor.apply_async(get_index, (self.uid, i)), block=True)

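
The loop retained here gets its back-pressure from the bounded queue: `put(..., block=True)` stalls the producer once `max_queue_size` `apply_async` futures are in flight. A standalone sketch of the same producer pattern, using a `ThreadPool` and a toy `square` task for brevity:

```python
import queue
import threading
from multiprocessing.pool import ThreadPool


def square(x):
  return x * x


def produce(items, executor, out_queue, stop_signal):
  # put(block=True) on a bounded queue blocks the producer whenever the
  # queue is full -- the back-pressure the enqueuer loop above relies on.
  for i in items:
    if stop_signal.is_set():
      return
    out_queue.put(executor.apply_async(square, (i,)), block=True)


if __name__ == '__main__':
  results = queue.Queue(maxsize=10)  # plays the role of max_queue_size
  with ThreadPool(2) as executor:
    produce(range(5), executor, results, threading.Event())
    while not results.empty():
      print(results.get().get())  # 0 1 4 9 16
```
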
@@ -777,31 +655,13 @@ class OrderedEnqueuer(SequenceEnqueuer):
       six.reraise(*sys.exc_info())


-def init_pool_generator(gens, random_seed=None, id_queue=None):
-  """Initializer function for pool workers.
-
-  Args:
-    gens: State which should be made available to worker processes.
-    random_seed: An optional value with which to seed child processes.
-    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
-      that a worker process was created by Keras and can be terminated using
-      the cleanup_all_keras_forkpools utility.
-  """
+def init_pool_generator(gens, random_seed=None):
   global _SHARED_SEQUENCES
   _SHARED_SEQUENCES = gens

-  worker_proc = multiprocessing.current_process()
-
-  # name isn't used for anything, but setting a more descriptive name is helpful
-  # when diagnosing orphaned processes.
-  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)
-
   if random_seed is not None:
-    np.random.seed(random_seed + worker_proc.ident)
-
-  if id_queue is not None:
-    # If a worker dies during init, the pool will just create a replacement.
-    id_queue.put(worker_proc.ident, block=True, timeout=0.1)
+    ident = multiprocessing.current_process().ident
+    np.random.seed(random_seed + ident)


 def next_sample(uid):
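
The restored initializer offsets the NumPy seed by the worker's `ident` so that sibling pool workers draw distinct but reproducible random streams rather than inheriting identical RNG state. A small runnable sketch of that offset-by-ident idea (`init_worker` and `draw` are illustrative names):

```python
import multiprocessing

import numpy as np


def init_worker(random_seed=None):
  """Pool initializer: offset the seed by this worker's ident."""
  if random_seed is not None:
    ident = multiprocessing.current_process().ident
    np.random.seed(random_seed + ident)


def draw(_):
  return int(np.random.randint(0, 1000))


if __name__ == '__main__':
  with multiprocessing.Pool(2, initializer=init_worker, initargs=(42,)) as p:
    # Workers are seeded 42 + pid_a and 42 + pid_b: reproducible per
    # worker, but the two workers do not mirror each other's draws.
    print(p.map(draw, range(4)))
```
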
@@ -853,11 +713,9 @@ class GeneratorEnqueuer(SequenceEnqueuer):
       A Function to initialize the pool
     """
     def pool_fn(seqs):
-      pool = multiprocessing.Pool(
-          workers, initializer=init_pool_generator,
-          initargs=(seqs, self.random_seed, _WORKER_ID_QUEUE))
-      _DATA_POOLS.add(pool)
-      return pool
+      return multiprocessing.Pool(workers,
+                                  initializer=init_pool_generator,
+                                  initargs=(seqs, self.random_seed))
     return pool_fn

   def _run(self):
@@ -867,7 +725,6 @@ class GeneratorEnqueuer(SequenceEnqueuer):
       while True:
         if self.stop_signal.is_set():
           return
-
         self.queue.put(
             executor.apply_async(next_sample, (self.uid,)), block=True)

@@ -32,8 +32,4 @@ tf_module {
     name: "load_from_saved_model"
     argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "terminate_keras_multiprocessing_pools"
-    argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
-  }
 }

@@ -32,8 +32,4 @@ tf_module {
     name: "load_from_saved_model"
    argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "terminate_keras_multiprocessing_pools"
-    argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
-  }
 }