From 3666f6bc5345de415f2482c7045b44c67d8348f3 Mon Sep 17 00:00:00 2001
From: Chenkai Kuang
Date: Mon, 30 Nov 2020 18:10:40 -0800
Subject: [PATCH] Deflake fault_tolerance_test and enable it in OSS.

This CL introduces a reliable way to create test scenarios where workers are
preempted in the middle of running remote functions. Previously, a time-based
mechanism was used, which turned out to be flaky in the OSS test.

PiperOrigin-RevId: 344927040
Change-Id: I90472c94e3410d368383bf63ee537dfb7c1a073c
---
 .../python/distribute/coordinator/BUILD  |  3 +-
 .../coordinator/fault_tolerance_test.py  | 74 ++++++++++++-------
 2 files changed, 50 insertions(+), 27 deletions(-)

diff --git a/tensorflow/python/distribute/coordinator/BUILD b/tensorflow/python/distribute/coordinator/BUILD
index 10c7c1dae84..1da6ca6f338 100644
--- a/tensorflow/python/distribute/coordinator/BUILD
+++ b/tensorflow/python/distribute/coordinator/BUILD
@@ -94,9 +94,8 @@ tf_py_test(
     name = "fault_tolerance_test",
     srcs = ["fault_tolerance_test.py"],
     python_version = "PY3",
-    shard_count = 9,
+    shard_count = 10,
     tags = [
-        "no_oss",  # TODO(b/168772720)
         "noasan",  # Multi-process runner does not work with test sanitizers
         "notsan",  # Multi-process runner does not work with test sanitizers
     ],
diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
index 099d3eed46c..17472e0c6d2 100644
--- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
+++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
@@ -59,24 +59,31 @@ class Model(object):

   def build(self):
     self.w = variables.Variable(
-        initial_value=random_ops.random_uniform((1000, 1000)),
-        dtype=dtypes.float32)
+        initial_value=random_ops.random_uniform((10, 10)), dtype=dtypes.float32)
     self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
+    # Allow external control to make the model run its train_fn in an infinite
+    # loop. This allows us to reliably test worker preemption in the middle of
+    # function execution.
+    self.do_infinite_step = variables.Variable(False)

     def dataset_fn():
-      data = random_ops.random_uniform((1000, 1000))
+      data = random_ops.random_uniform((10, 10))
       dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
       return dataset

     self.iterator = iter(
         self.cluster_coord.create_per_worker_dataset(dataset_fn))

+  def _train_fn_internal(self, iterator):
+    x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
+    x = math_ops.matmul(random_ops.random_uniform((10, 10)), x)
+    self.w.assign_add(x)
+
   @def_function.function
   def train_fn(self, iterator):
-    for _ in math_ops.range(5):
-      x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
-      x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), x)
-      self.w.assign_add(x)
+    self._train_fn_internal(iterator)
+    while self.do_infinite_step:
+      self._train_fn_internal(iterator)
     self.iterations.assign_add(1)

   def schedule_training_functions(self, num_steps):
@@ -85,6 +92,7 @@ class Model(object):
       self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))

   def join_training_functions(self):
+    self.do_infinite_step.assign(False)
     self.cluster_coord.join()


@@ -148,23 +156,31 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     restart_thread.start()
     return restart_thread

-  def testOneWorkerPreemption(self):
-    # A blackbox test to make sure the model can still train when there is
-    # worker preemption.
+  def testWorkerPreemptionBetweenFunctions(self):
     model = Model(self.cluster_coord)
-    model.schedule_training_functions(10)
-
-    time.sleep(1)  # Let it run a couple steps.
-    self.assertFalse(
-        self.cluster_coord.done(), "cluster finishes work before restart, this"
-        " is most likely due to the test runs in more powerful machine"
-        " compared to the one it previously runs. This setup is brittle but"
-        " there are no easy better alternatives. To fix the failure, consider"
-        " adding more work to the cluster, e.g, scheduling more functions.")
-    self._restart(5, "worker")
-
+    model.schedule_training_functions(2)
     model.join_training_functions()
-    self.assertGreaterEqual(model.iterations.numpy(), 10)
+    self.assertEqual(model.iterations.numpy(), 2)
+
+    self._restart(downtime_secs=2, job="worker")
+
+    model.schedule_training_functions(2)
+    model.join_training_functions()
+    self.assertEqual(model.iterations.numpy(), 4)
+
+  def testWorkerPreemptionMidstFunction(self):
+    model = Model(self.cluster_coord)
+    model.do_infinite_step.assign(True)
+
+    model.schedule_training_functions(4)
+    # Model does infinite training step, so at this moment, we expect to have 2
+    # infinite closures inflight, and 2 closures in the queue.
+    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+      time.sleep(0.1)
+    self.assertFalse(self.cluster_coord.done())
+    self._restart(downtime_secs=2, job="worker")
+    model.join_training_functions()
+    self.assertGreaterEqual(model.iterations.numpy(), 4)

   def testOneWorkerPreemptionWithCancellation(self):

@@ -335,9 +351,13 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring

   def testTwoWorkersPreempted(self):
     model = Model(self.cluster_coord)
+    model.do_infinite_step.assign(True)
     model.schedule_training_functions(10)

-    time.sleep(1)
+    # Model does infinite training step, so at this moment, we expect to have 2
+    # infinite closures inflight, and 8 closures in the queue.
+    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+      time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
     self._cluster.kill_task("worker", 1)
@@ -355,9 +375,13 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring

   def testWorkerContinuousFailure(self):
     model = Model(self.cluster_coord)
+    model.do_infinite_step.assign(True)
     model.schedule_training_functions(10)

-    time.sleep(1)
+    # Model does infinite training step, so at this moment, we expect to have 2
+    # infinite closures inflight, and 8 closures in the queue.
+    while self.cluster_coord.cluster._closure_queue._inflight_closure_count < 2:
+      time.sleep(0.1)
     self.assertFalse(self.cluster_coord.done())
     self._cluster.kill_task("worker", 0)
     time.sleep(2)
@@ -399,7 +423,7 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     self.thread_coord = thread_coordinator.Coordinator(
         clean_stop_exception_types=[])

-    self.testOneWorkerPreemption()
+    self.testWorkerPreemptionMidstFunction()

     self.thread_coord = thread_coordinator.Coordinator(
         clean_stop_exception_types=[])
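
Note on the mechanism (illustrative sketch, not part of the patch): the deflaking works by
giving the test explicit control over how long a scheduled function runs. A boolean variable
keeps every scheduled train_fn spinning, the test waits until closures are actually executing
on workers, preempts a worker, and only then clears the flag so the remaining work can finish.
The sketch below restates that pattern outside the multi-process test harness. It is a best-effort
illustration under assumptions: the helper run_preemption_scenario and its num_steps parameter are
made up for this note, the coordinator argument is assumed to be an existing
tf.distribute.experimental.coordinator.ClusterCoordinator whose variables live under the matching
ParameterServerStrategy scope, and the per-step work is a stand-in for the test's matmul/assign_add
on an iterator.

import tensorflow as tf

# Coordinator-controlled switch: while True, scheduled functions keep looping.
do_infinite_step = tf.Variable(False)
iterations = tf.Variable(0, dtype=tf.int32)
w = tf.Variable(tf.random.uniform((10, 10)))

def _one_step():
  # Stand-in for the per-step work done by the test's Model.
  w.assign_add(tf.matmul(tf.random.uniform((10, 10)), w))

@tf.function
def train_fn():
  _one_step()
  # Keep the worker busy until the coordinator clears the flag, so a worker
  # preemption is guaranteed to land while this function is still executing.
  while do_infinite_step:
    _one_step()
  iterations.assign_add(1)

def run_preemption_scenario(coordinator, num_steps=4):
  # Hypothetical driver mirroring testWorkerPreemptionMidstFunction; `coordinator`
  # must be a ClusterCoordinator created by the caller.
  do_infinite_step.assign(True)            # scheduled functions now run until told to stop
  for _ in range(num_steps):
    coordinator.schedule(train_fn)
  # The real test additionally polls a private counter
  # (cluster._closure_queue._inflight_closure_count) until closures are running
  # on workers, then kills and restarts a worker task at this point.
  do_infinite_step.assign(False)           # let surviving/restarted workers finish
  coordinator.join()                       # raises if any closure ultimately failed

The design choice this illustrates: instead of sleeping and hoping the preemption overlaps a
running function (the old time-based approach), the test makes the overlap a precondition it can
observe and control, which is what allows the test to be enabled in OSS.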