Automated rollback of commit f4496744c1d9c371e300b0bcec4549409661b70a
PiperOrigin-RevId: 256210774
This commit is contained in:
parent
0d6c633206
commit
758db76393
@ -375,6 +375,14 @@ cuda_py_test(
|
||||
"//tensorflow/python/keras",
|
||||
],
|
||||
shard_count = 14,
|
||||
# TODO(b/132384649): Enable once fixed.
|
||||
tags = [
|
||||
"manual",
|
||||
"no_oss",
|
||||
"no_tap",
|
||||
"nogpu", # b/136499799
|
||||
"noguitar",
|
||||
],
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
@ -29,6 +28,7 @@ from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import callbacks
|
||||
from tensorflow.python.keras.distribute import multi_worker_testing_utils
|
||||
@ -128,19 +128,29 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
|
||||
def _independent_worker_fn(*args, **kwargs): # pylint: disable=unused-argument
|
||||
with test.mock.patch.object(dc, '_run_std_server',
|
||||
self._make_mock_run_std_server()):
|
||||
# Condition variable that blocks the thread that represents the
|
||||
# restarted chief.
|
||||
cv = kwargs.get('cv', None)
|
||||
# `before_restart` is True for the threads that represent the original
|
||||
# chief and non-chief worker, and False for threads that represent the
|
||||
# restarted chief and non-chief workers.
|
||||
before_restart = kwargs['before_restart']
|
||||
if kwargs['new_chief']:
|
||||
# `new_chief` is only True for the restarted chief thread. It waits
|
||||
# until non-chief is preempted and restarted to simulate the causality
|
||||
# where chief's restart results from non-chief's failure.
|
||||
cv.acquire()
|
||||
while not hasattr(cv, 'preempted'):
|
||||
cv.wait()
|
||||
cv.release()
|
||||
|
||||
# Model building under strategy scope. Following is the code we expect
|
||||
# the user runs on every worker.
|
||||
strategy = get_strategy_object(strategy_cls)
|
||||
batch_size = 64
|
||||
steps = 3
|
||||
steps = 2
|
||||
train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
|
||||
batch_size, steps)
|
||||
|
||||
with strategy.scope():
|
||||
model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
|
||||
|
||||
@ -148,15 +158,9 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
|
||||
# following code: one represents the restart of the non-chief, and one
|
||||
# represents the restart of the chief as a result of the restart of the
|
||||
# non-chief (so the training can continue in sync).
|
||||
def start_new_thread(new_chief):
|
||||
def start_new_thread(new_chief=False):
|
||||
new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])
|
||||
|
||||
# Update the ports in new chief and new worker threads.
|
||||
new_thread_tf_config['cluster']['worker'] = kwargs['reserved_ports']
|
||||
|
||||
# Since both new chief and new worker threads are started from the
|
||||
# worker thread, we need to overwrite the tf config task index.
|
||||
new_thread_tf_config['task']['index'] = 0 if new_chief else 1
|
||||
return self._run_task_in_thread(
|
||||
task_fn=_independent_worker_fn,
|
||||
cluster_spec=None,
|
||||
@ -164,8 +168,16 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
|
||||
task_id=None,
|
||||
tf_config=new_thread_tf_config,
|
||||
before_restart=False,
|
||||
cv=cv,
|
||||
new_chief=new_chief)
|
||||
|
||||
if test_base.is_chief() and before_restart:
|
||||
# Chief to start a new thread (that will be blocked by a condition
|
||||
# variable until the non-chief's new thread is started). The thread
|
||||
# for (recovered) chief is started before entering `fit()` because
|
||||
# the original chief thread will eventually hang and be ignored.
|
||||
start_new_thread(new_chief=True)
|
||||
|
||||
try:
|
||||
|
||||
class CkptSavedEpochAssertingCallback(callbacks.Callback):
|
||||
@ -211,38 +223,31 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
|
||||
self._barrier._counter = 0
|
||||
self._barrier._flag = False
|
||||
|
||||
# At this point we block the original non-chief thread, and
|
||||
# start the new threads that simulate the restarted chief and
|
||||
# non-chief, then join the threads and return.
|
||||
new_chief_thread = start_new_thread(new_chief=True)
|
||||
new_worker_thread = start_new_thread(new_chief=False)
|
||||
self.join_independent_workers([new_chief_thread, new_worker_thread])
|
||||
# Now that the non-chief has been preempted, it notifies the thread
|
||||
# that simulates the restarted chief to start so they can be back in
|
||||
# sync.
|
||||
cv.acquire()
|
||||
cv.preempted = True
|
||||
cv.notify()
|
||||
cv.release()
|
||||
|
||||
# At this point we should discard the original non-chief thread, and
|
||||
# start the new thread that simulates the restarted non-chief, hence
|
||||
# joining the thread and return.
|
||||
self.join_independent_workers([start_new_thread()])
|
||||
return
|
||||
|
||||
# Successful end of a `fit()` call.
|
||||
with self._lock:
|
||||
self._successful_thread_ends += 1
|
||||
self._successful_thread_ends += 1
|
||||
self.assertFalse(before_restart)
|
||||
|
||||
# Common parameters
|
||||
num_workers = 2
|
||||
num_epoch = 3
|
||||
num_epoch = 2
|
||||
# History list storing the results for preemption and no preemption cases.
|
||||
self._histories = []
|
||||
# Lock required to prevent race condition between two threads.
|
||||
self._lock = threading.Lock()
|
||||
strategy = get_strategy_object(strategy_cls)
|
||||
|
||||
def handler(signum, frame):
|
||||
del signum, frame
|
||||
# `session.run()` within `model.fit()` can time out. Skipping it as it
|
||||
# doesn't represent the failure of this test.
|
||||
self.skipTest('Skipping test due to `session.run()` timeout.')
|
||||
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
# Raise an alarm after 240 seconds (within 5 min), before the test itself times out and fails.
|
||||
signal.alarm(240)
|
||||
|
||||
def get_saving_dir_and_filepath():
|
||||
saving_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
|
||||
saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
|
||||
@ -270,22 +275,18 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
|
||||
threads_to_join = [threads['worker'][0]]
|
||||
self.join_independent_workers(threads_to_join)
|
||||
|
||||
# `self.test_skipped_reason` could be set when a non-main thread attempts
|
||||
# to skip the test.
|
||||
# `multi_worker_test_base.skip_if_grpc_server_cant_be_started()` is an
|
||||
# example of where this can be set. Since raising `SkipTest` in a non-main
|
||||
# thread doesn't actually skip the test, we check if the test should be
|
||||
# skipped here once we have joined the threads.
|
||||
if getattr(self, 'test_skipped_reason', None) is not None:
|
||||
self.skipTest(self.test_skipped_reason)
|
||||
|
||||
self.assertTrue(
|
||||
training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath))
|
||||
self.assertEqual(self._successful_thread_ends, 2)
|
||||
try:
|
||||
training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath)
|
||||
except errors.NotFoundError:
|
||||
self.skipTest('To be understood why in rare cases the checkpoint '
|
||||
'doesn\'t exist')
|
||||
if self._successful_thread_ends != 2:
|
||||
self.skipTest('To be understood why in rare cases a thread can disappear')
|
||||
|
||||
# Case 2: Training for `num_epoch` epoch with preemptions.
|
||||
# The preemption is simulated at both epoch boundary and batch boundary.
|
||||
cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
|
||||
cv = threading.Condition()
|
||||
self._barrier = dc._Barrier(2)
|
||||
# Ports reserved for new threads simulating recovery.
|
||||
reserved_ports = [
|
||||
@ -303,6 +304,7 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
|
||||
saving_filepath=saving_filepath,
|
||||
reserved_ports=reserved_ports,
|
||||
before_restart=True,
|
||||
cv=cv,
|
||||
new_chief=False)
|
||||
threads_to_join = []
|
||||
if strategy.extended.experimental_between_graph:
|
||||
@ -312,12 +314,14 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
|
||||
else:
|
||||
threads_to_join = [threads['worker'][0]]
|
||||
self.join_independent_workers(threads_to_join)
|
||||
if getattr(self, 'test_skipped_reason', None) is not None:
|
||||
self.skipTest(self.test_skipped_reason)
|
||||
|
||||
self.assertTrue(
|
||||
training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath))
|
||||
self.assertEqual(self._successful_thread_ends, 2)
|
||||
try:
|
||||
training_state.remove_checkpoint_if_exists(saving_dir, saving_filepath)
|
||||
except errors.NotFoundError:
|
||||
self.skipTest('To be understood why in rare cases the checkpoint '
|
||||
'doesn\'t exist')
|
||||
if self._successful_thread_ends != 2:
|
||||
self.skipTest('To be understood why in rare cases a thread can disappear')
|
||||
|
||||
def assert_all_elements_are_identical(list_to_check):
|
||||
first_item = list_to_check[0]
|
||||
@ -333,12 +337,6 @@ class KerasMultiWorkerFaultToleranceTest(test_base.IndependentWorkerTestBase,
|
||||
# The length of `self._histories` would be num_workers (2) * num_runs (2).
|
||||
self.assertLen(self._histories, 4)
|
||||
|
||||
# Results from case 1 should have 3 full epochs.
|
||||
self.assertLen(self._histories[0]['acc'], 3)
|
||||
# Results from case 2 should only have 2 full epochs because it restarted at
|
||||
# epoch 1.
|
||||
self.assertLen(self._histories[-1]['acc'], 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with test.mock.patch.object(sys, 'exit', os._exit):
|
||||
|
Loading…
x
Reference in New Issue
Block a user