Automated rollback of commit f4496744c1d9c371e300b0bcec4549409661b70a

PiperOrigin-RevId: 256210774
Mihai Maruseac, 2019-07-02 11:58:52 -07:00, committed by TensorFlower Gardener
parent 0d6c633206
commit 758db76393
2 changed files with 60 additions and 54 deletions

[File 1 of 2: Bazel BUILD file]

@@ -375,6 +375,14 @@ cuda_py_test(
"//tensorflow/python/keras",
],
shard_count = 14,
# TODO(b/132384649): Enable once fixed.
tags = [
"manual",
"no_oss",
"no_tap",
"nogpu", # b/136499799
"noguitar",
],
xla_enable_strict_auto_jit = True,
)

[File 2 of 2: Keras multi-worker fault tolerance test, Python]

@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import json
import os
import signal
import sys
import tempfile
import threading
@@ -29,6 +28,7 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.keras.distribute import multi_worker_testing_utils
@@ -128,19 +128,29 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
def _independent_worker_fn(*args, **kwargs): # pylint: disable=unused-argument
with test.mock.patch.object(dc, '_run_std_server',
self._make_mock_run_std_server()):
# Condition variable that blocks the thread that represents the
# restarted chief.
cv = kwargs.get('cv', None)
# `before_restart` is True for the threads that represent the original
# chief and non-chief worker, and False for threads that represent the
# restarted chief and non-chief workers.
before_restart = kwargs['before_restart']
if kwargs['new_chief']:
# `new_chief` is only True for the restarted chief thread. It waits
# until the non-chief is preempted and restarted, simulating the causality
# where the chief's restart results from the non-chief's failure.
cv.acquire()
while not hasattr(cv, 'preempted'):
cv.wait()
cv.release()
# Model building under strategy scope. The following is the code we
# expect the user to run on every worker.
strategy = get_strategy_object(strategy_cls)
batch_size = 64
steps = 3
steps = 2
train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
batch_size, steps)
with strategy.scope():
model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
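For context on the hunk above: the comments describe a condition-variable handshake in which the thread simulating the restarted chief blocks until the thread simulating the original non-chief marks itself preempted (the matching cv.notify() appears later in this diff). A minimal, self-contained sketch of the same pattern, with hypothetical function names that are not part of the test:

import threading

cv = threading.Condition()

def restarted_chief():
    # Block until the non-chief signals that it has been preempted.
    cv.acquire()
    while not hasattr(cv, 'preempted'):
        cv.wait()
    cv.release()

def original_non_chief():
    # Simulate preemption, then wake the waiting restarted chief.
    cv.acquire()
    cv.preempted = True
    cv.notify()
    cv.release()

t = threading.Thread(target=restarted_chief)
t.start()
original_non_chief()
t.join()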
@@ -148,15 +158,9 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
# following code: one represents the restart of the non-chief, and one
# represents the restart of the chief as a result of the restart of the
# non-chief (so the training can continue in sync).
def start_new_thread(new_chief):
def start_new_thread(new_chief=False):
new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])
# Update the ports in the new chief and new worker threads.
new_thread_tf_config['cluster']['worker'] = kwargs['reserved_ports']
# Since both the new chief and new worker threads are started from the
# worker thread, we need to overwrite the TF_CONFIG task index.
new_thread_tf_config['task']['index'] = 0 if new_chief else 1
return self._run_task_in_thread(
task_fn=_independent_worker_fn,
cluster_spec=None,
@@ -164,8 +168,16 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
task_id=None,
tf_config=new_thread_tf_config,
before_restart=False,
cv=cv,
new_chief=new_chief)
if test_base.is_chief() and before_restart:
# The chief starts a new thread (which will be blocked on a condition
# variable until the non-chief's new thread is started). The thread
# for the (recovered) chief is started before entering `fit()` because
# the original chief thread will eventually hang and be ignored.
start_new_thread(new_chief=True)
try:
class CkptSavedEpochAssertingCallback(callbacks.Callback):
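The start_new_thread helper above rewrites the TF_CONFIG environment variable so that the simulated replacement workers listen on reserved ports and report the correct task index (0 for the new chief, 1 for the new worker). A minimal sketch of that kind of rewrite; the cluster addresses and port values here are illustrative, not the test's actual values:

import json
import os

# Hypothetical starting configuration for a two-worker cluster.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['localhost:12345', 'localhost:12346']},
    'task': {'type': 'worker', 'index': 1},
})

def rewrite_tf_config(reserved_ports, new_chief=False):
    # Parse the current TF_CONFIG, swap in the reserved addresses, and
    # overwrite the task index: 0 for the new chief, 1 otherwise.
    config = json.loads(os.environ['TF_CONFIG'])
    config['cluster']['worker'] = reserved_ports
    config['task']['index'] = 0 if new_chief else 1
    return config

print(rewrite_tf_config(['localhost:23456', 'localhost:23457'], new_chief=True))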
@@ -211,38 +223,31 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
self._barrier._counter = 0
self._barrier._flag = False
# At this point we block the original non-chief thread, and
# start the new threads that simulate the restarted chief and
# non-chief, join the threads, and return.
new_chief_thread = start_new_thread(new_chief=True)
new_worker_thread = start_new_thread(new_chief=False)
self.join_independent_workers([new_chief_thread, new_worker_thread])
# Now that the non-chief has been preempted, it notifies the thread
# that simulates the restarted chief to start so they can be back in
# sync.
cv.acquire()
cv.preempted = True
cv.notify()
cv.release()
# At this point we should discard the original non-chief thread,
# start the new thread that simulates the restarted non-chief, and
# then join that thread and return.
self.join_independent_workers([start_new_thread()])
return
# Successful end of a `fit()` call.
with self._lock:
self._successful_thread_ends += 1
self._successful_thread_ends += 1
self.assertFalse(before_restart)
# Common parameters
num_workers = 2
num_epoch = 3
num_epoch = 2
# History list storing the results for the preemption and no-preemption cases.
self._histories = []
# Lock required to prevent race condition between two threads.
self._lock = threading.Lock()
strategy = get_strategy_object(strategy_cls)
def handler(signum, frame):
del signum, frame
# `session.run()` within `model.fit()` can time out. We skip the test in
# that case, as the timeout doesn't represent a failure of this test.
self.skipTest('Skipping test due to `session.run()` timeout.')
signal.signal(signal.SIGALRM, handler)
# Set the alarm to fire within 5 minutes, before the test times out and fails.
signal.alarm(240)
def get_saving_dir_and_filepath():
saving_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
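Together with the `import signal` removal earlier in this file, this hunk drops the SIGALRM escape hatch that previously skipped the test when `session.run()` hung. For reference, the general shape of that Unix-only pattern, as a standalone sketch rather than the test's actual code:

import signal

def _on_timeout(signum, frame):
    del signum, frame
    raise TimeoutError('operation did not finish in time')

signal.signal(signal.SIGALRM, _on_timeout)
signal.alarm(240)  # Deliver SIGALRM after 240 seconds.
try:
    pass  # ... potentially hanging work would go here ...
finally:
    signal.alarm(0)  # Always cancel a pending alarm.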
@@ -270,22 +275,18 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
threads_to_join = [threads['worker'][0]]
self.join_independent_workers(threads_to_join)
# `self.test_skipped_reason` could be set when a non-main thread attempts
# to skip the test.
# `multi_worker_test_base.skip_if_grpc_server_cant_be_started()` is an
# example of where this can be set. Since raising `SkipTest` in a non-main
# thread doesn't actually skip the test, we check if the test should be
# skipped here once we have joined the threads.
if getattr(self, 'test_skipped_reason', None) is not None:
self.skipTest(self.test_skipped_reason)
self.assertTrue(
training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath))
self.assertEqual(self._successful_thread_ends, 2)
try:
training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath)
except errors.NotFoundError:
self.skipTest('To be understood why in rare cases the checkpoint '
'doesn\'t exist')
if self._successful_thread_ends != 2:
self.skipTest('To be understood why in rare cases a thread can disappear')
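The skip logic above works around a unittest limitation that the diff's comments call out: raising `SkipTest` from a non-main thread does not actually skip the test. A minimal sketch of the record-then-skip-after-join pattern, in a hypothetical standalone test case:

import threading
import unittest

class ExampleTest(unittest.TestCase):

    def test_worker_requested_skip(self):
        def worker():
            # A worker thread cannot skip the test directly, so it
            # records the reason for the main thread to act on.
            self.test_skipped_reason = 'worker could not start its server'

        t = threading.Thread(target=worker)
        t.start()
        t.join()
        # Only the main thread can actually skip the test.
        if getattr(self, 'test_skipped_reason', None) is not None:
            self.skipTest(self.test_skipped_reason)

if __name__ == '__main__':
    unittest.main()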
# Case 2: Training for `num_epoch` epochs with preemptions.
# The preemption is simulated at both the epoch boundary and the batch
# boundary.
cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
cv = threading.Condition()
self._barrier = dc._Barrier(2)
# Ports reserved for new threads simulating recovery.
reserved_ports = [
@@ -303,6 +304,7 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
saving_filepath=saving_filepath,
reserved_ports=reserved_ports,
before_restart=True,
cv=cv,
new_chief=False)
threads_to_join = []
if strategy.extended.experimental_between_graph:
@@ -312,12 +314,14 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
else:
threads_to_join = [threads['worker'][0]]
self.join_independent_workers(threads_to_join)
if getattr(self, 'test_skipped_reason', None) is not None:
self.skipTest(self.test_skipped_reason)
self.assertTrue(
training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath))
self.assertEqual(self._successful_thread_ends, 2)
try:
training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath)
except errors.NotFoundError:
self.skipTest('To be understood why in rare cases the checkpoint '
'doesn\'t exist')
if self._successful_thread_ends != 2:
self.skipTest('To be understood why in rare cases a thread can disappear')
def assert_all_elements_are_identical(list_to_check):
first_item = list_to_check[0]
@@ -333,12 +337,6 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
# The length of `self._histories` would be num_workers * num_runs (3).
self.assertLen(self._histories, 4)
# Results from case 1 should have 3 full epochs.
self.assertLen(self._histories[0]['acc'], 3)
# Results from case 2 should only have 2 full epochs because it restarted at
# epoch 1.
self.assertLen(self._histories[-1]['acc'], 2)
if __name__ == '__main__':
with test.mock.patch.object(sys, 'exit', os._exit):
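The `__main__` block patches `sys.exit` with `os._exit`, presumably so that any thread or child calling `sys.exit()` terminates the whole process immediately instead of raising `SystemExit` inside a thread, where it could leave a multi-worker test hanging. A minimal illustration of the patch itself, outside the test harness (the assertion is only for demonstration):

import os
import sys
from unittest import mock

# While the patch is active, sys.exit is the raw os._exit, which ends
# the process at once and skips interpreter cleanup and atexit handlers.
with mock.patch.object(sys, 'exit', os._exit):
    assert sys.exit is os._exit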