parent
837e5feed2
commit
386165b089
@ -1171,6 +1171,7 @@ tf_py_test(
|
|||||||
shard_count = 6,
|
shard_count = 6,
|
||||||
tags = [
|
tags = [
|
||||||
"no_oss",
|
"no_oss",
|
||||||
|
"notap", #TODO(b/123544294): Re-enable this test.
|
||||||
"notsan",
|
"notsan",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
@ -35,7 +34,6 @@ from tensorflow.python.keras import metrics as metrics_module
|
|||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.keras.engine import training_generator
|
from tensorflow.python.keras.engine import training_generator
|
||||||
from tensorflow.python.keras.optimizer_v2 import rmsprop
|
from tensorflow.python.keras.optimizer_v2 import rmsprop
|
||||||
from tensorflow.python.keras.utils import data_utils
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
@ -63,40 +61,7 @@ def custom_generator(mode=2):
|
|||||||
yield x, y, w
|
yield x, y, w
|
||||||
|
|
||||||
|
|
||||||
class ForkRobustTestCase(keras_parameterized.TestCase):
|
class TestGeneratorMethods(keras_parameterized.TestCase):
|
||||||
_sleep_at_end = False
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
# When setting up a test simply make a best effort to start from a clean
|
|
||||||
# state.
|
|
||||||
self._starting_remnants = data_utils.terminate_keras_multiprocessing_pools(
|
|
||||||
use_sigkill=False)
|
|
||||||
|
|
||||||
self._sleep_at_end = False
|
|
||||||
super(ForkRobustTestCase, self).setUp()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
# Give multiprocessing pools some time to finish on their own before
|
|
||||||
# cleanup_all_keras_forkpools yanks the rug out from under them. This is
|
|
||||||
# particularly important because calling .close() on a pool that is already
|
|
||||||
# in the process of spinning down can cause an uncatchable segmentation
|
|
||||||
# fault at which point the tearDown will hang.
|
|
||||||
if self._sleep_at_end:
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# If a test finishes and leaves behind uncleanable artifacts then that is a
|
|
||||||
# failure condition. However, if the state was not clean to begin with the
|
|
||||||
# test should not fail on that account.
|
|
||||||
new_remnants = set(data_utils.terminate_keras_multiprocessing_pools(
|
|
||||||
use_sigkill=True)).difference(self._starting_remnants)
|
|
||||||
|
|
||||||
if new_remnants:
|
|
||||||
raise ValueError('Test left behind stubborn orphans:\n {}'.format(
|
|
||||||
'\n '.join(new_remnants)))
|
|
||||||
super(ForkRobustTestCase, self).tearDown()
|
|
||||||
|
|
||||||
|
|
||||||
class TestGeneratorMethods(ForkRobustTestCase):
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
os.name == 'nt',
|
os.name == 'nt',
|
||||||
@ -111,7 +76,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
|
|||||||
optimizer=rmsprop.RMSprop(1e-3),
|
optimizer=rmsprop.RMSprop(1e-3),
|
||||||
metrics=['mae', metrics_module.CategoricalAccuracy()])
|
metrics=['mae', metrics_module.CategoricalAccuracy()])
|
||||||
|
|
||||||
self._sleep_at_end = True
|
|
||||||
model.fit_generator(custom_generator(),
|
model.fit_generator(custom_generator(),
|
||||||
steps_per_epoch=5,
|
steps_per_epoch=5,
|
||||||
epochs=1,
|
epochs=1,
|
||||||
@ -153,7 +117,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
|
|||||||
metrics=['mae', metrics_module.CategoricalAccuracy()],
|
metrics=['mae', metrics_module.CategoricalAccuracy()],
|
||||||
run_eagerly=testing_utils.should_run_eagerly())
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
|
||||||
self._sleep_at_end = True
|
|
||||||
model.evaluate_generator(custom_generator(),
|
model.evaluate_generator(custom_generator(),
|
||||||
steps=5,
|
steps=5,
|
||||||
max_queue_size=10,
|
max_queue_size=10,
|
||||||
@ -180,7 +143,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
|
|||||||
num_hidden=3, num_classes=4, input_dim=2)
|
num_hidden=3, num_classes=4, input_dim=2)
|
||||||
model.run_eagerly = testing_utils.should_run_eagerly()
|
model.run_eagerly = testing_utils.should_run_eagerly()
|
||||||
|
|
||||||
self._sleep_at_end = True
|
|
||||||
model.predict_generator(custom_generator(),
|
model.predict_generator(custom_generator(),
|
||||||
steps=5,
|
steps=5,
|
||||||
max_queue_size=10,
|
max_queue_size=10,
|
||||||
@ -306,7 +268,7 @@ class TestGeneratorMethods(ForkRobustTestCase):
|
|||||||
model.predict(ones_generator(), steps=2)
|
model.predict(ones_generator(), steps=2)
|
||||||
|
|
||||||
|
|
||||||
class TestGeneratorMethodsWithSequences(ForkRobustTestCase):
|
class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
@ -20,19 +20,16 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
import gc
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import weakref
|
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -426,126 +423,11 @@ _SHARED_SEQUENCES = {}
|
|||||||
_SEQUENCE_COUNTER = None
|
_SEQUENCE_COUNTER = None
|
||||||
|
|
||||||
|
|
||||||
# Because multiprocessing pools are inherently unsafe, starting from a clean
|
|
||||||
# state can be essential to avoiding deadlocks. In order to accomplish this, we
|
|
||||||
# need to be able to check on the status of Pools that we create.
|
|
||||||
_DATA_POOLS = weakref.WeakSet()
|
|
||||||
_WORKER_ID_QUEUE = multiprocessing.Queue()
|
|
||||||
_WORKER_IDS = set()
|
|
||||||
|
|
||||||
|
|
||||||
def init_pool(seqs):
|
def init_pool(seqs):
|
||||||
global _SHARED_SEQUENCES
|
global _SHARED_SEQUENCES
|
||||||
_SHARED_SEQUENCES = seqs
|
_SHARED_SEQUENCES = seqs
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.experimental.terminate_keras_multiprocessing_pools')
|
|
||||||
def terminate_keras_multiprocessing_pools(grace_period=0.1, use_sigkill=False):
|
|
||||||
"""Destroy Keras' multiprocessing pools to prevent deadlocks.
|
|
||||||
|
|
||||||
In general multiprocessing.Pool can interact quite badly with other, seemingly
|
|
||||||
unrelated, parts of a codebase due to Pool's reliance on fork. This method
|
|
||||||
cleans up all pools which are known to belong to Keras (and thus can be safely
|
|
||||||
terminated).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
grace_period: Time (in seconds) to wait for process cleanup to propagate.
|
|
||||||
use_sigkill: Boolean of whether or not to perform a cleanup pass using
|
|
||||||
SIGKILL.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of human readable strings describing all issues encountered. It is up
|
|
||||||
to the caller to decide whether to treat this as an error condition.
|
|
||||||
"""
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
# First cleanup the pools spawned by Keras. If we start killing workers and
|
|
||||||
# a parent pool is still alive it will just spawn replacements which we don't
|
|
||||||
# want.
|
|
||||||
gc.collect()
|
|
||||||
for pool in _DATA_POOLS:
|
|
||||||
pool.close()
|
|
||||||
pool.terminate()
|
|
||||||
# We do not join the pool, because that would wait forever if a worker
|
|
||||||
# refused to exit.
|
|
||||||
|
|
||||||
# Finally, delete our reference to the pool so that we do not block garbage
|
|
||||||
# collection.
|
|
||||||
del pool
|
|
||||||
|
|
||||||
# If there were any pools, sleep for a small grace period to allow everything
|
|
||||||
# to finalize.
|
|
||||||
if _DATA_POOLS:
|
|
||||||
time.sleep(grace_period)
|
|
||||||
|
|
||||||
# Now we kill any workers which are still alive. However we must compare
|
|
||||||
# the worker identifier to the set of identifiers which are known to have been
|
|
||||||
# spawned by pools belonging to Keras to avoid deleting unrelated workers.
|
|
||||||
# First we call the .terminate() method of a worker, and then if it still
|
|
||||||
# persists we directly send a signal to the process. Certain worker tasks may
|
|
||||||
# be able to gracefully handle shutdown, so we send a SIGTERM and then
|
|
||||||
# optionally follow up with a SIGKILL.
|
|
||||||
visited_workers = set()
|
|
||||||
cleanup_passes = ['.terminate', 'SIGTERM']
|
|
||||||
if use_sigkill:
|
|
||||||
cleanup_passes.append('SIGKILL')
|
|
||||||
cleanup_passes.append('log')
|
|
||||||
|
|
||||||
for cleanup_pass in cleanup_passes:
|
|
||||||
while True:
|
|
||||||
# In rare cases, queue.qsize() overestimates the number of elements. This
|
|
||||||
# loop is designed to be more robust.
|
|
||||||
try:
|
|
||||||
_WORKER_IDS.add(_WORKER_ID_QUEUE.get_nowait())
|
|
||||||
except queue.Empty:
|
|
||||||
break
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
workers_terminated_this_pass = False
|
|
||||||
for worker in multiprocessing.active_children():
|
|
||||||
ident = worker.ident
|
|
||||||
if ident in _WORKER_IDS and worker.is_alive():
|
|
||||||
try:
|
|
||||||
if cleanup_pass == '.terminate':
|
|
||||||
# First we ask nicely.
|
|
||||||
worker.terminate()
|
|
||||||
worker.join(timeout=grace_period)
|
|
||||||
visited_workers.add(ident)
|
|
||||||
workers_terminated_this_pass = True
|
|
||||||
elif cleanup_pass in ('SIGTERM', 'SIGKILL'):
|
|
||||||
# Then we ask increasingly tersely.
|
|
||||||
os.kill(worker.pid, signal.SIGKILL if cleanup_pass == 'SIGKILL'
|
|
||||||
else signal.SIGTERM)
|
|
||||||
workers_terminated_this_pass = True
|
|
||||||
|
|
||||||
elif cleanup_pass == 'log':
|
|
||||||
# And finally we give up and log the failure.
|
|
||||||
errors.append('worker still alive: {}, pid={}, hash={}'
|
|
||||||
.format(worker.name, worker.pid, hash(worker)))
|
|
||||||
|
|
||||||
except OSError:
|
|
||||||
# Worker exited since the start of this loop.
|
|
||||||
pass
|
|
||||||
|
|
||||||
if workers_terminated_this_pass:
|
|
||||||
# There can be a small propagation delay between worker destruction and
|
|
||||||
# workers reporting False for is_alive and no longer appearing in the
|
|
||||||
# list of active children. Once again, we sleep for a small grace period.
|
|
||||||
# This prevents false positives from workers which are simply still in the
|
|
||||||
# process of spinning down.
|
|
||||||
time.sleep(grace_period)
|
|
||||||
|
|
||||||
# Finally we remove the visited worker ids to handle the edge case that a
|
|
||||||
# pid is reused.
|
|
||||||
_WORKER_IDS.difference_update(visited_workers)
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
for pool in _DATA_POOLS:
|
|
||||||
errors.append('pool still exists: {}, hash={}'.format(pool, hash(pool)))
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
def get_index(uid, i):
|
def get_index(uid, i):
|
||||||
"""Get the value from the Sequence `uid` at index `i`.
|
"""Get the value from the Sequence `uid` at index `i`.
|
||||||
|
|
||||||
@ -714,11 +596,8 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
|||||||
Function, a Function to initialize the pool
|
Function, a Function to initialize the pool
|
||||||
"""
|
"""
|
||||||
def pool_fn(seqs):
|
def pool_fn(seqs):
|
||||||
pool = multiprocessing.Pool(
|
return multiprocessing.Pool(
|
||||||
workers, initializer=init_pool_generator,
|
workers, initializer=init_pool_generator, initargs=(seqs, None))
|
||||||
initargs=(seqs, None, _WORKER_ID_QUEUE))
|
|
||||||
_DATA_POOLS.add(pool)
|
|
||||||
return pool
|
|
||||||
|
|
||||||
return pool_fn
|
return pool_fn
|
||||||
|
|
||||||
@ -741,7 +620,6 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
|||||||
for i in sequence:
|
for i in sequence:
|
||||||
if self.stop_signal.is_set():
|
if self.stop_signal.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
self.queue.put(
|
self.queue.put(
|
||||||
executor.apply_async(get_index, (self.uid, i)), block=True)
|
executor.apply_async(get_index, (self.uid, i)), block=True)
|
||||||
|
|
||||||
@ -777,31 +655,13 @@ class OrderedEnqueuer(SequenceEnqueuer):
|
|||||||
six.reraise(*sys.exc_info())
|
six.reraise(*sys.exc_info())
|
||||||
|
|
||||||
|
|
||||||
def init_pool_generator(gens, random_seed=None, id_queue=None):
|
def init_pool_generator(gens, random_seed=None):
|
||||||
"""Initializer function for pool workers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gens: State which should be made available to worker processes.
|
|
||||||
random_seed: An optional value with which to seed child processes.
|
|
||||||
id_queue: A multiprocessing Queue of worker ids. This is used to indicate
|
|
||||||
that a worker process was created by Keras and can be terminated using
|
|
||||||
the cleanup_all_keras_forkpools utility.
|
|
||||||
"""
|
|
||||||
global _SHARED_SEQUENCES
|
global _SHARED_SEQUENCES
|
||||||
_SHARED_SEQUENCES = gens
|
_SHARED_SEQUENCES = gens
|
||||||
|
|
||||||
worker_proc = multiprocessing.current_process()
|
|
||||||
|
|
||||||
# name isn't used for anything, but setting a more descriptive name is helpful
|
|
||||||
# when diagnosing orphaned processes.
|
|
||||||
worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)
|
|
||||||
|
|
||||||
if random_seed is not None:
|
if random_seed is not None:
|
||||||
np.random.seed(random_seed + worker_proc.ident)
|
ident = multiprocessing.current_process().ident
|
||||||
|
np.random.seed(random_seed + ident)
|
||||||
if id_queue is not None:
|
|
||||||
# If a worker dies during init, the pool will just create a replacement.
|
|
||||||
id_queue.put(worker_proc.ident, block=True, timeout=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
def next_sample(uid):
|
def next_sample(uid):
|
||||||
@ -853,11 +713,9 @@ class GeneratorEnqueuer(SequenceEnqueuer):
|
|||||||
A Function to initialize the pool
|
A Function to initialize the pool
|
||||||
"""
|
"""
|
||||||
def pool_fn(seqs):
|
def pool_fn(seqs):
|
||||||
pool = multiprocessing.Pool(
|
return multiprocessing.Pool(workers,
|
||||||
workers, initializer=init_pool_generator,
|
initializer=init_pool_generator,
|
||||||
initargs=(seqs, self.random_seed, _WORKER_ID_QUEUE))
|
initargs=(seqs, self.random_seed))
|
||||||
_DATA_POOLS.add(pool)
|
|
||||||
return pool
|
|
||||||
return pool_fn
|
return pool_fn
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
@ -867,7 +725,6 @@ class GeneratorEnqueuer(SequenceEnqueuer):
|
|||||||
while True:
|
while True:
|
||||||
if self.stop_signal.is_set():
|
if self.stop_signal.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
self.queue.put(
|
self.queue.put(
|
||||||
executor.apply_async(next_sample, (self.uid,)), block=True)
|
executor.apply_async(next_sample, (self.uid,)), block=True)
|
||||||
|
|
||||||
|
@ -32,8 +32,4 @@ tf_module {
|
|||||||
name: "load_from_saved_model"
|
name: "load_from_saved_model"
|
||||||
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
|
||||||
name: "terminate_keras_multiprocessing_pools"
|
|
||||||
argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -32,8 +32,4 @@ tf_module {
|
|||||||
name: "load_from_saved_model"
|
name: "load_from_saved_model"
|
||||||
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
|
||||||
name: "terminate_keras_multiprocessing_pools"
|
|
||||||
argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user