Deflake fault_tolerance_test and enable it in OSS.

This CL introduces a reliable way to create test scenarios where workers are preempted in the middle of running remote functions. Previously, a time-based mechanism was used, which turned out to be flaky in the OSS test.

PiperOrigin-RevId: 344927040
Change-Id: I90472c94e3410d368383bf63ee537dfb7c1a073c
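
In essence, the CL replaces sleep-based timing with an explicit control variable. A minimal single-process sketch of the idea, using the public tf API (the variable and function names here are illustrative, not the test's actual ones):

    import tensorflow as tf

    do_infinite_step = tf.Variable(False)  # flipped by the test, not the model
    w = tf.Variable(tf.random.uniform((10, 10)))

    def _one_step():
      w.assign_add(tf.matmul(tf.random.uniform((10, 10)), w))

    @tf.function
    def train_fn():
      _one_step()
      # While the flag is True this loop never exits, so the function is
      # guaranteed to be mid-execution whenever the test preempts a worker.
      while do_infinite_step:
        _one_step()

With the flag set to True before scheduling, a worker can be killed at any point and a function is certain to be in flight; assigning False afterwards lets the retried functions run to completion.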
Author: Chenkai Kuang, 2020-11-30 18:10:40 -08:00 (committed by TensorFlower Gardener)
Commit: 3666f6bc53 (parent d844ce6ae5)
2 changed files with 50 additions and 27 deletions

BUILD

@@ -94,9 +94,8 @@ tf_py_test(
     name = "fault_tolerance_test",
     srcs = ["fault_tolerance_test.py"],
     python_version = "PY3",
-    shard_count = 9,
+    shard_count = 10,
     tags = [
-        "no_oss",  # TODO(b/168772720)
         "noasan",  # Multi-process runner does not work with test sanitizers
         "notsan",  # Multi-process runner does not work with test sanitizers
     ],

fault_tolerance_test.py

@@ -59,24 +59,31 @@ class Model(object):
   def build(self):
     self.w = variables.Variable(
-        initial_value=random_ops.random_uniform((1000, 1000)),
-        dtype=dtypes.float32)
+        initial_value=random_ops.random_uniform((10, 10)), dtype=dtypes.float32)
     self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
+    # Allow external control to make the model run its train_fn in an infinite
+    # loop. This allows us to reliably test worker preemption in the middle of
+    # function execution.
+    self.do_infinite_step = variables.Variable(False)

     def dataset_fn():
-      data = random_ops.random_uniform((1000, 1000))
+      data = random_ops.random_uniform((10, 10))
       dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
       return dataset

     self.iterator = iter(
         self.cluster_coord.create_per_worker_dataset(dataset_fn))

+  def _train_fn_internal(self, iterator):
+    x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
+    x = math_ops.matmul(random_ops.random_uniform((10, 10)), x)
+    self.w.assign_add(x)
+
   @def_function.function
   def train_fn(self, iterator):
-    for _ in math_ops.range(5):
-      x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
-      x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), x)
-      self.w.assign_add(x)
+    self._train_fn_internal(iterator)
+    while self.do_infinite_step:
+      self._train_fn_internal(iterator)
     self.iterations.assign_add(1)

   def schedule_training_functions(self, num_steps):
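
The while loop above only works as a test hook because AutoGraph re-reads do_infinite_step on every iteration, so an assignment made outside the running function is eventually observed. A hedged single-process demo of that property (the threading setup is illustrative; the real test flips the variable from the coordinator while the function runs on a separate worker process):

    import threading
    import time
    import tensorflow as tf

    flag = tf.Variable(True)
    count = tf.Variable(0)

    @tf.function
    def spin():
      # Compiles to a tf.while_loop whose condition re-reads `flag`,
      # so an assignment from another thread terminates the loop.
      while flag:
        count.assign_add(1)

    t = threading.Thread(target=spin)
    t.start()
    time.sleep(0.5)     # let the loop run a few iterations
    flag.assign(False)  # external control, as join_training_functions does below
    t.join()
    assert count.numpy() > 0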
@@ -85,6 +92,7 @@ class Model(object):
       self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))

   def join_training_functions(self):
+    self.do_infinite_step.assign(False)
     self.cluster_coord.join()
@@ -148,23 +156,31 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     restart_thread.start()
     return restart_thread

-  def testOneWorkerPreemption(self):
+  # A blackbox test to make sure the model can still train when there is
+  # worker preemption.
+  def testWorkerPreemptionBetweenFunctions(self):
     model = Model(self.cluster_coord)
-    model.schedule_training_functions(10)
-    time.sleep(1)  # Let it run a couple steps.
-    self.assertFalse(
-        self.cluster_coord.done(), "cluster finishes work before restart, this"
-        " is most likely due to the test runs in more powerful machine"
-        " compared to the one it previously runs. This setup is brittle but"
-        " there are no easy better alternatives. To fix the failure, consider"
-        " adding more work to the cluster, e.g, scheduling more functions.")
-    self._restart(5, "worker")
     model.schedule_training_functions(2)
     model.join_training_functions()
-    self.assertGreaterEqual(model.iterations.numpy(), 10)
+    self.assertEqual(model.iterations.numpy(), 2)
+
+    self._restart(downtime_secs=2, job="worker")
+
+    model.schedule_training_functions(2)
+    model.join_training_functions()
+    self.assertEqual(model.iterations.numpy(), 4)
+
+  def testWorkerPreemptionMidstFunction(self):
+    model = Model(self.cluster_coord)
+    model.do_infinite_step.assign(True)
+
+    model.schedule_training_functions(4)
+    # Model does infinite training step, so at this moment, we expect to have 2
+    # infinite closures inflight, and 2 closures in the queue.
+    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+      time.sleep(0.1)
+    self.assertFalse(self.cluster_coord.done())
+    self._restart(downtime_secs=2, job="worker")
+    model.join_training_functions()
+    self.assertGreaterEqual(model.iterations.numpy(), 4)

   def testOneWorkerPreemptionWithCancellation(self):
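
The pattern above, spinning on the in-flight closure count rather than sleeping for a fixed interval, is what makes the test independent of machine speed. A generic sketch of the same wait-for-condition idiom (wait_until is a hypothetical helper, and _closure_queue._inflight_closure_count is a private TF internal that may change between versions):

    import time

    def wait_until(predicate, timeout=30.0, interval=0.1):
      # Poll instead of sleeping a fixed amount: a fast machine passes
      # immediately, a slow one simply polls longer, and only a genuine
      # hang hits the timeout.
      deadline = time.time() + timeout
      while not predicate():
        if time.time() > deadline:
          raise TimeoutError("condition not met within %.1fs" % timeout)
        time.sleep(interval)

    # Usage mirroring the test:
    # wait_until(
    #     lambda: coord.cluster._closure_queue._inflight_closure_count >= 2)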
@@ -335,9 +351,13 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
   def testTwoWorkersPreempted(self):
     model = Model(self.cluster_coord)
+    model.do_infinite_step.assign(True)
     model.schedule_training_functions(10)
-    time.sleep(1)
+    # Model does infinite training step, so at this moment, we expect to have 2
+    # infinite closures inflight, and 8 closures in the queue.
+    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+      time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
     self._cluster.kill_task("worker", 1)
@@ -355,9 +375,13 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
   def testWorkerContinuousFailure(self):
     model = Model(self.cluster_coord)
+    model.do_infinite_step.assign(True)
     model.schedule_training_functions(10)
-    time.sleep(1)
+    # Model does infinite training step, so at this moment, we expect to have 2
+    # infinite closures inflight, and 8 closures in the queue.
+    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+      time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
     time.sleep(2)
@@ -399,7 +423,7 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     self.thread_coord = thread_coordinator.Coordinator(
         clean_stop_exception_types=[])
-    self.testOneWorkerPreemption()
+    self.testWorkerPreemptionMidstFunction()
     self.thread_coord = thread_coordinator.Coordinator(
         clean_stop_exception_types=[])