Automated rollback of commit a9064fdddc

PiperOrigin-RevId: 236430122
RJ Ryan, 2019-03-01 23:01:26 -08:00, committed by TensorFlower Gardener
parent 837e5feed2
commit 386165b089
5 changed files with 11 additions and 199 deletions


@@ -1171,6 +1171,7 @@ tf_py_test(
     shard_count = 6,
     tags = [
         "no_oss",
+        "notap",  # TODO(b/123544294): Re-enable this test.
        "notsan",
     ],
 )


@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import time
 import unittest
 
 from absl.testing import parameterized
@@ -35,7 +34,6 @@ from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.optimizer_v2 import rmsprop
-from tensorflow.python.keras.utils import data_utils
 
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -63,40 +61,7 @@ def custom_generator(mode=2):
     yield x, y, w
 
 
-class ForkRobustTestCase(keras_parameterized.TestCase):
-  _sleep_at_end = False
-
-  def setUp(self):
-    # When setting up a test simply make a best effort to start from a clean
-    # state.
-    self._starting_remnants = data_utils.terminate_keras_multiprocessing_pools(
-        use_sigkill=False)
-    self._sleep_at_end = False
-    super(ForkRobustTestCase, self).setUp()
-
-  def tearDown(self):
-    # Give multiprocessing pools some time to finish on their own before
-    # cleanup_all_keras_forkpools yanks the rug out from under them. This is
-    # particularly important because calling .close() on a pool that is already
-    # in the process of spinning down can cause an uncatchable segmentation
-    # fault at which point the tearDown will hang.
-    if self._sleep_at_end:
-      time.sleep(1)
-
-    # If a test finishes and leaves behind uncleanable artifacts then that is a
-    # failure condition. However, if the state was not clean to begin with the
-    # test should not fail on that account.
-    new_remnants = set(data_utils.terminate_keras_multiprocessing_pools(
-        use_sigkill=True)).difference(self._starting_remnants)
-
-    if new_remnants:
-      raise ValueError('Test left behind stubborn orphans:\n {}'.format(
-          '\n '.join(new_remnants)))
-
-    super(ForkRobustTestCase, self).tearDown()
-
-
-class TestGeneratorMethods(ForkRobustTestCase):
+class TestGeneratorMethods(keras_parameterized.TestCase):
 
   @unittest.skipIf(
       os.name == 'nt',
@@ -111,7 +76,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
         optimizer=rmsprop.RMSprop(1e-3),
         metrics=['mae', metrics_module.CategoricalAccuracy()])
 
-    self._sleep_at_end = True
     model.fit_generator(custom_generator(),
                         steps_per_epoch=5,
                         epochs=1,
@@ -153,7 +117,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
         metrics=['mae', metrics_module.CategoricalAccuracy()],
         run_eagerly=testing_utils.should_run_eagerly())
 
-    self._sleep_at_end = True
     model.evaluate_generator(custom_generator(),
                              steps=5,
                              max_queue_size=10,
@@ -180,7 +143,6 @@ class TestGeneratorMethods(ForkRobustTestCase):
         num_hidden=3, num_classes=4, input_dim=2)
     model.run_eagerly = testing_utils.should_run_eagerly()
 
-    self._sleep_at_end = True
     model.predict_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
@@ -306,7 +268,7 @@ class TestGeneratorMethods(ForkRobustTestCase):
       model.predict(ones_generator(), steps=2)
 
 
-class TestGeneratorMethodsWithSequences(ForkRobustTestCase):
+class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):
 
   @keras_parameterized.run_with_all_model_types
   @keras_parameterized.run_all_keras_modes

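For context on the test changes above: the deleted fixture wrapped every generator test in a pool-leak check, recording pre-existing remnants in setUp and failing in tearDown only on newly orphaned pools. Below is a minimal standalone sketch of that pattern; `terminate_pools` is a hypothetical stand-in for the rolled-back `data_utils.terminate_keras_multiprocessing_pools`, not TensorFlow code.

import time
import unittest


def terminate_pools(use_sigkill=False):
  # Hypothetical stand-in: a real implementation would terminate tracked
  # pools and workers (escalating to SIGKILL if requested) and return one
  # human-readable string per remnant it could not clean up.
  del use_sigkill
  return []


class PoolLeakCheckingTestCase(unittest.TestCase):
  _sleep_at_end = False

  def setUp(self):
    # Best effort to start from a clean state; remnants that predate the
    # test are recorded so the test is not blamed for them.
    self._starting_remnants = set(terminate_pools(use_sigkill=False))
    self._sleep_at_end = False
    super(PoolLeakCheckingTestCase, self).setUp()

  def tearDown(self):
    # Give pools a moment to spin down on their own; closing a pool that is
    # already shutting down can segfault, so the cleanup is not rushed.
    if self._sleep_at_end:
      time.sleep(1)
    new_remnants = set(
        terminate_pools(use_sigkill=True)) - self._starting_remnants
    if new_remnants:
      raise ValueError('Test left behind stubborn orphans:\n {}'.format(
          '\n '.join(sorted(new_remnants))))
    super(PoolLeakCheckingTestCase, self).tearDown()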

@@ -20,19 +20,16 @@ from __future__ import print_function
 from abc import abstractmethod
 from contextlib import closing
-import gc
 import hashlib
 import multiprocessing
 from multiprocessing.pool import ThreadPool
 import os
 import random
 import shutil
-import signal
 import sys
 import tarfile
 import threading
 import time
-import weakref
 import zipfile
 
 import numpy as np
 import six
@@ -426,126 +423,11 @@ _SHARED_SEQUENCES = {}
 _SEQUENCE_COUNTER = None
 
-# Because multiprocessing pools are inherently unsafe, starting from a clean
-# state can be essential to avoiding deadlocks. In order to accomplish this, we
-# need to be able to check on the status of Pools that we create.
-_DATA_POOLS = weakref.WeakSet()
-_WORKER_ID_QUEUE = multiprocessing.Queue()
-_WORKER_IDS = set()
-
 
 def init_pool(seqs):
   global _SHARED_SEQUENCES
   _SHARED_SEQUENCES = seqs
 
 
-@keras_export('keras.experimental.terminate_keras_multiprocessing_pools')
-def terminate_keras_multiprocessing_pools(grace_period=0.1, use_sigkill=False):
-  """Destroy Keras' multiprocessing pools to prevent deadlocks.
-
-  In general multiprocessing.Pool can interact quite badly with other, seemingly
-  unrelated, parts of a codebase due to Pool's reliance on fork. This method
-  cleans up all pools which are known to belong to Keras (and thus can be safely
-  terminated).
-
-  Args:
-    grace_period: Time (in seconds) to wait for process cleanup to propagate.
-    use_sigkill: Boolean of whether or not to perform a cleanup pass using
-      SIGKILL.
-
-  Returns:
-    A list of human readable strings describing all issues encountered. It is up
-    to the caller to decide whether to treat this as an error condition.
-  """
-  errors = []
-
-  # First cleanup the pools spawned by Keras. If we start killing workers and
-  # a parent pool is still alive it will just spawn replacements which we don't
-  # want.
-  gc.collect()
-  for pool in _DATA_POOLS:
-    pool.close()
-    pool.terminate()
-
-    # We do not join the pool, because that would wait forever if a worker
-    # refused to exit.
-
-    # Finally, delete our reference to the pool so that we do not block garbage
-    # collection.
-    del pool
-
-  # If there were any pools, sleep for a small grace period to allow everything
-  # to finalize.
-  if _DATA_POOLS:
-    time.sleep(grace_period)
-
-  # Now we kill any workers which are still alive. However we must compare
-  # the worker identifier to the set of identifiers which are known to have been
-  # spawned by pools belonging to Keras to avoid deleting unrelated workers.
-  # First we call the .terminate() method of a worker, and then if it still
-  # persists we directly send a signal to the process. Certain worker tasks may
-  # be able to gracefully handle shutdown, so we send a SIGTERM and then
-  # optionally follow up with a SIGKILL.
-  visited_workers = set()
-  cleanup_passes = ['.terminate', 'SIGTERM']
-  if use_sigkill:
-    cleanup_passes.append('SIGKILL')
-  cleanup_passes.append('log')
-
-  for cleanup_pass in cleanup_passes:
-    while True:
-      # In rare cases, queue.qsize() overestimates the number of elements. This
-      # loop is designed to be more robust.
-      try:
-        _WORKER_IDS.add(_WORKER_ID_QUEUE.get_nowait())
-      except queue.Empty:
-        break
-
-    gc.collect()
-    workers_terminated_this_pass = False
-    for worker in multiprocessing.active_children():
-      ident = worker.ident
-      if ident in _WORKER_IDS and worker.is_alive():
-        try:
-          if cleanup_pass == '.terminate':
-            # First we ask nicely.
-            worker.terminate()
-            worker.join(timeout=grace_period)
-            visited_workers.add(ident)
-            workers_terminated_this_pass = True
-
-          elif cleanup_pass in ('SIGTERM', 'SIGKILL'):
-            # Then we ask increasingly tersely.
-            os.kill(worker.pid, signal.SIGKILL if cleanup_pass == 'SIGKILL'
-                    else signal.SIGTERM)
-            workers_terminated_this_pass = True
-
-          elif cleanup_pass == 'log':
-            # And finally we give up and log the failure.
-            errors.append('worker still alive: {}, pid={}, hash={}'
-                          .format(worker.name, worker.pid, hash(worker)))
-
-        except OSError:
-          # Worker exited since the start of this loop.
-          pass
-
-    if workers_terminated_this_pass:
-      # There can be a small propagation delay between worker destruction and
-      # workers reporting False for is_alive and no longer appearing in the
-      # list of active children. Once again, we sleep for a small grace period.
-      # This prevents false positives from workers which are simply still in the
-      # process of spinning down.
-      time.sleep(grace_period)
-
-  # Finally we remove the visited worker ids to handle the edge case that a
-  # pid is reused.
-  _WORKER_IDS.difference_update(visited_workers)
-
-  gc.collect()
-  for pool in _DATA_POOLS:
-    errors.append('pool still exists: {}, hash={}'.format(pool, hash(pool)))
-
-  return errors
-
-
 def get_index(uid, i):
   """Get the value from the Sequence `uid` at index `i`.
@@ -714,11 +596,8 @@ class OrderedEnqueuer(SequenceEnqueuer):
       Function, a Function to initialize the pool
     """
 
     def pool_fn(seqs):
-      pool = multiprocessing.Pool(
-          workers, initializer=init_pool_generator,
-          initargs=(seqs, None, _WORKER_ID_QUEUE))
-      _DATA_POOLS.add(pool)
-      return pool
+      return multiprocessing.Pool(
+          workers, initializer=init_pool_generator, initargs=(seqs, None))
 
     return pool_fn
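The removed branch of `pool_fn` registered each new pool in the module-level `weakref.WeakSet` so the cleanup utility could later find it without the registry itself keeping dead pools alive. A minimal sketch of that registration pattern, with hypothetical names:

import multiprocessing
import weakref

# A WeakSet holds the pools for later inspection without blocking their GC:
# entries disappear on their own once a pool is garbage collected.
_DATA_POOLS = weakref.WeakSet()


def make_tracked_pool(workers, initializer=None, initargs=()):
  pool = multiprocessing.Pool(workers, initializer=initializer,
                              initargs=initargs)
  _DATA_POOLS.add(pool)  # Registered for later emergency cleanup.
  return pool


if __name__ == '__main__':
  pool = make_tracked_pool(2)
  assert len(_DATA_POOLS) == 1
  pool.terminate()
  pool.join()
  del pool  # Once collected, the WeakSet entry vanishes without bookkeeping.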
@@ -741,7 +620,6 @@ class OrderedEnqueuer(SequenceEnqueuer):
        for i in sequence:
          if self.stop_signal.is_set():
            return
-
          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)
 
@@ -777,31 +655,13 @@ class OrderedEnqueuer(SequenceEnqueuer):
       six.reraise(*sys.exc_info())
 
 
-def init_pool_generator(gens, random_seed=None, id_queue=None):
-  """Initializer function for pool workers.
-
-  Args:
-    gens: State which should be made available to worker processes.
-    random_seed: An optional value with which to seed child processes.
-    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
-      that a worker process was created by Keras and can be terminated using
-      the cleanup_all_keras_forkpools utility.
-  """
+def init_pool_generator(gens, random_seed=None):
   global _SHARED_SEQUENCES
   _SHARED_SEQUENCES = gens
 
-  worker_proc = multiprocessing.current_process()
-
-  # name isn't used for anything, but setting a more descriptive name is helpful
-  # when diagnosing orphaned processes.
-  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)
-
   if random_seed is not None:
-    np.random.seed(random_seed + worker_proc.ident)
-
-  if id_queue is not None:
-    # If a worker dies during init, the pool will just create a replacement.
-    id_queue.put(worker_proc.ident, block=True, timeout=0.1)
+    ident = multiprocessing.current_process().ident
+    np.random.seed(random_seed + ident)
 
 
 def next_sample(uid):
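The surviving logic in `init_pool_generator` seeds NumPy per worker by offsetting the base seed with the process ident, so forked workers do not replay identical random streams from the parent's RNG state. A small self-contained demonstration of the same idea; `_init_worker` and `_sample` are hypothetical helpers mirroring that logic, not TensorFlow code:

import multiprocessing

import numpy as np


def _init_worker(random_seed):
  # Mirrors the retained logic: offset the base seed by the process ident
  # so each worker draws a distinct yet reproducible stream.
  if random_seed is not None:
    ident = multiprocessing.current_process().ident
    np.random.seed(random_seed + ident)


def _sample(_):
  return float(np.random.uniform())


if __name__ == '__main__':
  pool = multiprocessing.Pool(2, initializer=_init_worker, initargs=(42,))
  # Without per-worker seeding, forked workers can inherit the same parent
  # RNG state and emit identical values; with it, the streams diverge.
  print(pool.map(_sample, range(4)))
  pool.close()
  pool.join()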
@@ -853,11 +713,9 @@ class GeneratorEnqueuer(SequenceEnqueuer):
       A Function to initialize the pool
     """
 
     def pool_fn(seqs):
-      pool = multiprocessing.Pool(
-          workers, initializer=init_pool_generator,
-          initargs=(seqs, self.random_seed, _WORKER_ID_QUEUE))
-      _DATA_POOLS.add(pool)
-      return pool
+      return multiprocessing.Pool(workers,
+                                  initializer=init_pool_generator,
+                                  initargs=(seqs, self.random_seed))
 
     return pool_fn
@@ -867,7 +725,6 @@ def _run(self):
       while True:
        if self.stop_signal.is_set():
          return
-
        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)
 
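After the rollback, both enqueuers share one construction pattern: worker state travels through `initializer`/`initargs` into a module-level global (as with `_SHARED_SEQUENCES`), and work items are submitted with `apply_async`. A self-contained sketch of that pattern, using hypothetical names:

import multiprocessing

# Set in each worker by the initializer, analogous to _SHARED_SEQUENCES.
_SHARED = None


def _init_pool(shared):
  global _SHARED
  _SHARED = shared


def _get_item(i):
  return _SHARED[i]  # Workers read the state installed at pool start-up.


if __name__ == '__main__':
  pool = multiprocessing.Pool(2, initializer=_init_pool,
                              initargs=(['a', 'b', 'c'],))
  futures = [pool.apply_async(_get_item, (i,)) for i in range(3)]
  print([f.get() for f in futures])  # ['a', 'b', 'c']
  pool.close()
  pool.join()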


@@ -32,8 +32,4 @@ tf_module {
     name: "load_from_saved_model"
     argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "terminate_keras_multiprocessing_pools"
-    argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
-  }
 }


@@ -32,8 +32,4 @@ tf_module {
     name: "load_from_saved_model"
     argspec: "args=[\'saved_model_path\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "terminate_keras_multiprocessing_pools"
-    argspec: "args=[\'grace_period\', \'use_sigkill\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\'], "
-  }
 }