Deflake fault_tolerance_test and enable it in OSS.
This CL introduces a reliable way to create test scenarios where workers are preempted in the middle of running remote functions. Previously a time-based mechanism is used, and turns out to be flaky in OSS test. PiperOrigin-RevId: 344927040 Change-Id: I90472c94e3410d368383bf63ee537dfb7c1a073c
This commit is contained in:
parent
d844ce6ae5
commit
3666f6bc53
@ -94,9 +94,8 @@ tf_py_test(
|
||||
name = "fault_tolerance_test",
|
||||
srcs = ["fault_tolerance_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 9,
|
||||
shard_count = 10,
|
||||
tags = [
|
||||
"no_oss", # TODO(b/168772720)
|
||||
"noasan", # Multi-process runner does not work with test sanitizers
|
||||
"notsan", # Multi-process runner does not work with test sanitizers
|
||||
],
|
||||
|
||||
@ -59,24 +59,31 @@ class Model(object):
|
||||
|
||||
def build(self):
|
||||
self.w = variables.Variable(
|
||||
initial_value=random_ops.random_uniform((1000, 1000)),
|
||||
dtype=dtypes.float32)
|
||||
initial_value=random_ops.random_uniform((10, 10)), dtype=dtypes.float32)
|
||||
self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
|
||||
# Allow external control to make the model run its train_fn in an infinite
|
||||
# loop. This allows us to reliably test worker preemption in the middle of
|
||||
# function execution.
|
||||
self.do_infinite_step = variables.Variable(False)
|
||||
|
||||
def dataset_fn():
|
||||
data = random_ops.random_uniform((1000, 1000))
|
||||
data = random_ops.random_uniform((10, 10))
|
||||
dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
|
||||
return dataset
|
||||
|
||||
self.iterator = iter(
|
||||
self.cluster_coord.create_per_worker_dataset(dataset_fn))
|
||||
|
||||
def _train_fn_internal(self, iterator):
|
||||
x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
|
||||
x = math_ops.matmul(random_ops.random_uniform((10, 10)), x)
|
||||
self.w.assign_add(x)
|
||||
|
||||
@def_function.function
|
||||
def train_fn(self, iterator):
|
||||
for _ in math_ops.range(5):
|
||||
x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
|
||||
x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), x)
|
||||
self.w.assign_add(x)
|
||||
self._train_fn_internal(iterator)
|
||||
while self.do_infinite_step:
|
||||
self._train_fn_internal(iterator)
|
||||
self.iterations.assign_add(1)
|
||||
|
||||
def schedule_training_functions(self, num_steps):
|
||||
@ -85,6 +92,7 @@ class Model(object):
|
||||
self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))
|
||||
|
||||
def join_training_functions(self):
|
||||
self.do_infinite_step.assign(False)
|
||||
self.cluster_coord.join()
|
||||
|
||||
|
||||
@ -148,23 +156,31 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
restart_thread.start()
|
||||
return restart_thread
|
||||
|
||||
def testOneWorkerPreemption(self):
|
||||
# A blackbox test to make sure the model can still train when there is
|
||||
# worker preemption.
|
||||
def testWorkerPreemptionBetweenFunctions(self):
|
||||
model = Model(self.cluster_coord)
|
||||
model.schedule_training_functions(10)
|
||||
|
||||
time.sleep(1) # Let it run a couple steps.
|
||||
self.assertFalse(
|
||||
self.cluster_coord.done(), "cluster finishes work before restart, this"
|
||||
" is most likely due to the test runs in more powerful machine"
|
||||
" compared to the one it previously runs. This setup is brittle but"
|
||||
" there are no easy better alternatives. To fix the failure, consider"
|
||||
" adding more work to the cluster, e.g, scheduling more functions.")
|
||||
self._restart(5, "worker")
|
||||
|
||||
model.schedule_training_functions(2)
|
||||
model.join_training_functions()
|
||||
self.assertGreaterEqual(model.iterations.numpy(), 10)
|
||||
self.assertEqual(model.iterations.numpy(), 2)
|
||||
|
||||
self._restart(downtime_secs=2, job="worker")
|
||||
|
||||
model.schedule_training_functions(2)
|
||||
model.join_training_functions()
|
||||
self.assertEqual(model.iterations.numpy(), 4)
|
||||
|
||||
def testWorkerPreemptionMidstFunction(self):
|
||||
model = Model(self.cluster_coord)
|
||||
model.do_infinite_step.assign(True)
|
||||
|
||||
model.schedule_training_functions(4)
|
||||
# Model does infinite training step, so at this moment, we expect to have 2
|
||||
# infinite closures inflight, and 2 closures in the queue.
|
||||
while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
|
||||
time.sleep(0.1)
|
||||
self.assertFalse(self.cluster_coord.done())
|
||||
self._restart(downtime_secs=2, job="worker")
|
||||
model.join_training_functions()
|
||||
self.assertGreaterEqual(model.iterations.numpy(), 4)
|
||||
|
||||
def testOneWorkerPreemptionWithCancellation(self):
|
||||
|
||||
@ -335,9 +351,13 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
|
||||
def testTwoWorkersPreempted(self):
|
||||
model = Model(self.cluster_coord)
|
||||
model.do_infinite_step.assign(True)
|
||||
model.schedule_training_functions(10)
|
||||
|
||||
time.sleep(1)
|
||||
# Model does infinite training step, so at this moment, we expect to have 2
|
||||
# infinite closures inflight, and 8 closures in the queue.
|
||||
while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
|
||||
time.sleep(0.1)
|
||||
self.assertFalse(self.cluster_coord.done())
|
||||
self._cluster.kill_task("worker", 0)
|
||||
self._cluster.kill_task("worker", 1)
|
||||
@ -355,9 +375,13 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
|
||||
def testWorkerContinuousFailure(self):
|
||||
model = Model(self.cluster_coord)
|
||||
model.do_infinite_step.assign(True)
|
||||
model.schedule_training_functions(10)
|
||||
|
||||
time.sleep(1)
|
||||
# Model does infinite training step, so at this moment, we expect to have 2
|
||||
# infinite closures inflight, and 8 closures in the queue.
|
||||
while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
|
||||
time.sleep(0.1)
|
||||
self.assertFalse(self.cluster_coord.done())
|
||||
self._cluster.kill_task("worker", 0)
|
||||
time.sleep(2)
|
||||
@ -399,7 +423,7 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
|
||||
self.thread_coord = thread_coordinator.Coordinator(
|
||||
clean_stop_exception_types=[])
|
||||
self.testOneWorkerPreemption()
|
||||
self.testWorkerPreemptionMidstFunction()
|
||||
|
||||
self.thread_coord = thread_coordinator.Coordinator(
|
||||
clean_stop_exception_types=[])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user