diff --git a/tensorflow/python/distribute/coordinator/BUILD b/tensorflow/python/distribute/coordinator/BUILD index 1da6ca6f338..4ab1d5973d5 100644 --- a/tensorflow/python/distribute/coordinator/BUILD +++ b/tensorflow/python/distribute/coordinator/BUILD @@ -94,7 +94,7 @@ tf_py_test( name = "fault_tolerance_test", srcs = ["fault_tolerance_test.py"], python_version = "PY3", - shard_count = 10, + shard_count = 21, tags = [ "noasan", # Multi-process runner does not work with test sanitizers "notsan", # Multi-process runner does not work with test sanitizers diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py index 00891845f07..6f5fc053b4b 100644 --- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py +++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py @@ -96,13 +96,10 @@ class Model(object): self.cluster_coord.join() -class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring +class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring - NUM_WORKERS = 2 - NUM_PS = 2 - - def setUp(self): - super(FaultToleranceTest, self).setUp() + def setUp(self, num_workers, num_ps): + super(BaseFaultToleranceTest, self).setUp() # Set the environment variable to prevent hanging upon job failure and # restart. 
Note that it defaults to 'use_caller' at Google, but defaults @@ -110,9 +107,7 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring os.environ["GRPC_FAIL_FAST"] = "use_caller" self._cluster = multi_worker_test_base.create_multi_process_cluster( - num_workers=FaultToleranceTest.NUM_WORKERS, - num_ps=FaultToleranceTest.NUM_PS, - rpc_layer="grpc") + num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc") self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict() self._cluster_def["chief"] = [ "localhost:%d" % multi_worker_test_base.pick_unused_port() @@ -127,9 +122,10 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring self.thread_coord = thread_coordinator.Coordinator( clean_stop_exception_types=[]) + self.num_workers = num_workers def tearDown(self): - super(FaultToleranceTest, self).tearDown() + super(BaseFaultToleranceTest, self).tearDown() self._cluster.stop() def _restart(self, downtime_secs, job): @@ -173,10 +169,11 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring model.do_infinite_step.assign(True) model.schedule_training_functions(4) - # Model does infinite training step, so at this moment, we expect to have 2 - # infinite closures inflight, and 2 closures in the queue. + # Model does infinite training step, so at this moment, we expect to have + # `self.num_workers` infinite closures inflight, and `4-self.num_workers` + # closures in the queue. 
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count - < 2): + < self.num_workers): time.sleep(0.1) self.assertFalse(self.cluster_coord.done()) self._restart(downtime_secs=2, job="worker") @@ -351,6 +348,8 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring str(e)) def testTwoWorkersPreempted(self): + if self.num_workers < 2: + self.skipTest("Worker number is less than 2.") model = Model(self.cluster_coord) model.do_infinite_step.assign(True) model.schedule_training_functions(10) @@ -380,10 +379,11 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring model.do_infinite_step.assign(True) model.schedule_training_functions(10) - # Model does infinite training step, so at this moment, we expect to have 2 - # infinite closures inflight, and 8 closures in the queue. + # Model does infinite training step, so at this moment, we expect to have + # `self.num_workers` infinite closures inflight, and `10-self.num_workers` + # closures in the queue. while (self.cluster_coord._cluster._closure_queue._inflight_closure_count - < 2): + < self.num_workers): time.sleep(0.1) self.assertFalse(self.cluster_coord.done()) self._cluster.kill_task("worker", 0) @@ -440,6 +440,31 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring # self.testWorkerContinuousFailure() +class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase): + """Multi worker fault tolerance tests. + + This covers the ordinary cases where multiple workers and PS are used. + """ + + def setUp(self): + super(MultiWorkerFaultToleranceTest, self).setUp(2, 2) + + +class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase): + """Single worker fault tolerance tests. 
+ + This covers the cases that ensure training can continue in a single-worker + cluster, even if the only worker can become unavailable at some point and + be recovered (if there are multiple workers, it is possible that the training + succeeds with the workers that did not fail). Realistically single worker + is very rarely used, but the tests are important to ensure the correct + behaviors. + """ + + def setUp(self): + super(SingleWorkerFaultToleranceTest, self).setUp(1, 1) + + if __name__ == "__main__": v2_compat.enable_v2_behavior() multi_process_runner.test_main()