PSv2: Merge cluster_coordinator_mpr_test into fault_tolerance_test: step 1, adding SingleWorkerFaultToleranceTest for single worker fault tolerance cases.

PiperOrigin-RevId: 350623998
Change-Id: Ia9a072a27f95c3d1931fd838f9939a28846d9c53
This commit is contained in:
Rick Chao 2021-01-07 13:11:41 -08:00 committed by TensorFlower Gardener
parent f1f4b915cd
commit 49e3675367
2 changed files with 42 additions and 17 deletions

View File

@ -94,7 +94,7 @@ tf_py_test(
name = "fault_tolerance_test", name = "fault_tolerance_test",
srcs = ["fault_tolerance_test.py"], srcs = ["fault_tolerance_test.py"],
python_version = "PY3", python_version = "PY3",
shard_count = 10, shard_count = 21,
tags = [ tags = [
"noasan", # Multi-process runner does not work with test sanitizers "noasan", # Multi-process runner does not work with test sanitizers
"notsan", # Multi-process runner does not work with test sanitizers "notsan", # Multi-process runner does not work with test sanitizers

View File

@ -96,13 +96,10 @@ class Model(object):
self.cluster_coord.join() self.cluster_coord.join()
class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring
NUM_WORKERS = 2 def setUp(self, num_workers, num_ps):
NUM_PS = 2 super(BaseFaultToleranceTest, self).setUp()
def setUp(self):
super(FaultToleranceTest, self).setUp()
# Set the environment variable to prevent hanging upon job failure and # Set the environment variable to prevent hanging upon job failure and
# restart. Note that it defaults to 'use_caller' at Google, but defaults # restart. Note that it defaults to 'use_caller' at Google, but defaults
@ -110,9 +107,7 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
os.environ["GRPC_FAIL_FAST"] = "use_caller" os.environ["GRPC_FAIL_FAST"] = "use_caller"
self._cluster = multi_worker_test_base.create_multi_process_cluster( self._cluster = multi_worker_test_base.create_multi_process_cluster(
num_workers=FaultToleranceTest.NUM_WORKERS, num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
num_ps=FaultToleranceTest.NUM_PS,
rpc_layer="grpc")
self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict() self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
self._cluster_def["chief"] = [ self._cluster_def["chief"] = [
"localhost:%d" % multi_worker_test_base.pick_unused_port() "localhost:%d" % multi_worker_test_base.pick_unused_port()
@ -127,9 +122,10 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
self.thread_coord = thread_coordinator.Coordinator( self.thread_coord = thread_coordinator.Coordinator(
clean_stop_exception_types=[]) clean_stop_exception_types=[])
self.num_workers = num_workers
def tearDown(self): def tearDown(self):
super(FaultToleranceTest, self).tearDown() super(BaseFaultToleranceTest, self).tearDown()
self._cluster.stop() self._cluster.stop()
def _restart(self, downtime_secs, job): def _restart(self, downtime_secs, job):
@ -173,10 +169,11 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
model.do_infinite_step.assign(True) model.do_infinite_step.assign(True)
model.schedule_training_functions(4) model.schedule_training_functions(4)
# Model does infinite training step, so at this moment, we expect to have 2 # Model does infinite training step, so at this moment, we expect to have
# infinite closures inflight, and 2 closures in the queue. # `self.num_workers` infinite closures inflight, and `4-self.num_workers`
# closures in the queue.
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
< 2): < self.num_workers):
time.sleep(0.1) time.sleep(0.1)
self.assertFalse(self.cluster_coord.done()) self.assertFalse(self.cluster_coord.done())
self._restart(downtime_secs=2, job="worker") self._restart(downtime_secs=2, job="worker")
@ -351,6 +348,8 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
str(e)) str(e))
def testTwoWorkersPreempted(self): def testTwoWorkersPreempted(self):
if self.num_workers < 2:
self.skipTest("Worker number is less than 2.")
model = Model(self.cluster_coord) model = Model(self.cluster_coord)
model.do_infinite_step.assign(True) model.do_infinite_step.assign(True)
model.schedule_training_functions(10) model.schedule_training_functions(10)
@ -380,10 +379,11 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
model.do_infinite_step.assign(True) model.do_infinite_step.assign(True)
model.schedule_training_functions(10) model.schedule_training_functions(10)
# Model does infinite training step, so at this moment, we expect to have 2 # Model does infinite training step, so at this moment, we expect to have
# infinite closures inflight, and 8 closures in the queue. # `self.num_workers` infinite closures inflight, and `10-self.num_workers`
# closures in the queue.
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
< 2): < self.num_workers):
time.sleep(0.1) time.sleep(0.1)
self.assertFalse(self.cluster_coord.done()) self.assertFalse(self.cluster_coord.done())
self._cluster.kill_task("worker", 0) self._cluster.kill_task("worker", 0)
@ -440,6 +440,31 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
# self.testWorkerContinuousFailure() # self.testWorkerContinuousFailure()
class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Fault tolerance tests for the ordinary multi-worker, multi-PS setup.

  Runs the cases inherited from `BaseFaultToleranceTest` against a cluster
  created with two workers and two PS.
  """

  def setUp(self):
    # The base `setUp` takes the cluster size explicitly; spell the arguments
    # out by name so the two counts cannot be confused.
    super(MultiWorkerFaultToleranceTest, self).setUp(
        num_workers=2, num_ps=2)
class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Fault tolerance tests for a cluster with a single worker.

  These cases ensure that training can continue in a single-worker cluster
  even when the only worker becomes unavailable at some point and later
  recovers. (With multiple workers, training may succeed using just the
  workers that did not fail.) Realistically, a single worker is very rarely
  used, but the tests are important for ensuring correct behavior.
  """

  def setUp(self):
    # One worker and one PS: the smallest cluster the base `setUp` is asked
    # to create in this file.
    super(SingleWorkerFaultToleranceTest, self).setUp(
        num_workers=1, num_ps=1)
if __name__ == "__main__": if __name__ == "__main__":
v2_compat.enable_v2_behavior() v2_compat.enable_v2_behavior()
multi_process_runner.test_main() multi_process_runner.test_main()