PSv2: Merge cluster_coordinator_mpr_test into fault_tolerance_test: step 1, adding SingleWorkerFaultToleranceTest for single worker fault tolerance cases.
PiperOrigin-RevId: 350623998 Change-Id: Ia9a072a27f95c3d1931fd838f9939a28846d9c53
This commit is contained in:
parent
f1f4b915cd
commit
49e3675367
@ -94,7 +94,7 @@ tf_py_test(
|
|||||||
name = "fault_tolerance_test",
|
name = "fault_tolerance_test",
|
||||||
srcs = ["fault_tolerance_test.py"],
|
srcs = ["fault_tolerance_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 10,
|
shard_count = 21,
|
||||||
tags = [
|
tags = [
|
||||||
"noasan", # Multi-process runner does not work with test sanitizers
|
"noasan", # Multi-process runner does not work with test sanitizers
|
||||||
"notsan", # Multi-process runner does not work with test sanitizers
|
"notsan", # Multi-process runner does not work with test sanitizers
|
||||||
|
@ -96,13 +96,10 @@ class Model(object):
|
|||||||
self.cluster_coord.join()
|
self.cluster_coord.join()
|
||||||
|
|
||||||
|
|
||||||
class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring
|
||||||
|
|
||||||
NUM_WORKERS = 2
|
def setUp(self, num_workers, num_ps):
|
||||||
NUM_PS = 2
|
super(BaseFaultToleranceTest, self).setUp()
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
super(FaultToleranceTest, self).setUp()
|
|
||||||
|
|
||||||
# Set the environment variable to prevent hanging upon job failure and
|
# Set the environment variable to prevent hanging upon job failure and
|
||||||
# restart. Note that it defaults to 'use_caller' at Google, but defaults
|
# restart. Note that it defaults to 'use_caller' at Google, but defaults
|
||||||
@ -110,9 +107,7 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
||||||
|
|
||||||
self._cluster = multi_worker_test_base.create_multi_process_cluster(
|
self._cluster = multi_worker_test_base.create_multi_process_cluster(
|
||||||
num_workers=FaultToleranceTest.NUM_WORKERS,
|
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
|
||||||
num_ps=FaultToleranceTest.NUM_PS,
|
|
||||||
rpc_layer="grpc")
|
|
||||||
self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
|
self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
|
||||||
self._cluster_def["chief"] = [
|
self._cluster_def["chief"] = [
|
||||||
"localhost:%d" % multi_worker_test_base.pick_unused_port()
|
"localhost:%d" % multi_worker_test_base.pick_unused_port()
|
||||||
@ -127,9 +122,10 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
|
|
||||||
self.thread_coord = thread_coordinator.Coordinator(
|
self.thread_coord = thread_coordinator.Coordinator(
|
||||||
clean_stop_exception_types=[])
|
clean_stop_exception_types=[])
|
||||||
|
self.num_workers = num_workers
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
super(FaultToleranceTest, self).tearDown()
|
super(BaseFaultToleranceTest, self).tearDown()
|
||||||
self._cluster.stop()
|
self._cluster.stop()
|
||||||
|
|
||||||
def _restart(self, downtime_secs, job):
|
def _restart(self, downtime_secs, job):
|
||||||
@ -173,10 +169,11 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
model.do_infinite_step.assign(True)
|
model.do_infinite_step.assign(True)
|
||||||
|
|
||||||
model.schedule_training_functions(4)
|
model.schedule_training_functions(4)
|
||||||
# Model does infinite training step, so at this moment, we expect to have 2
|
# Model does infinite training step, so at this moment, we expect to have
|
||||||
# infinite closures inflight, and 2 closures in the queue.
|
# `self.num_workers` infinite closures inflight, and `4-self.num_workers`
|
||||||
|
# closures in the queue.
|
||||||
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
||||||
< 2):
|
< self.num_workers):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
self.assertFalse(self.cluster_coord.done())
|
self.assertFalse(self.cluster_coord.done())
|
||||||
self._restart(downtime_secs=2, job="worker")
|
self._restart(downtime_secs=2, job="worker")
|
||||||
@ -351,6 +348,8 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
str(e))
|
str(e))
|
||||||
|
|
||||||
def testTwoWorkersPreempted(self):
|
def testTwoWorkersPreempted(self):
|
||||||
|
if self.num_workers < 2:
|
||||||
|
self.skipTest("Worker number is less than 2.")
|
||||||
model = Model(self.cluster_coord)
|
model = Model(self.cluster_coord)
|
||||||
model.do_infinite_step.assign(True)
|
model.do_infinite_step.assign(True)
|
||||||
model.schedule_training_functions(10)
|
model.schedule_training_functions(10)
|
||||||
@ -380,10 +379,11 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
model.do_infinite_step.assign(True)
|
model.do_infinite_step.assign(True)
|
||||||
model.schedule_training_functions(10)
|
model.schedule_training_functions(10)
|
||||||
|
|
||||||
# Model does infinite training step, so at this moment, we expect to have 2
|
# Model does infinite training step, so at this moment, we expect to have
|
||||||
# infinite closures inflight, and 8 closures in the queue.
|
# `self.num_workers` infinite closures inflight, and `10-self.num_workers`
|
||||||
|
# closures in the queue.
|
||||||
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
||||||
< 2):
|
< self.num_workers):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
self.assertFalse(self.cluster_coord.done())
|
self.assertFalse(self.cluster_coord.done())
|
||||||
self._cluster.kill_task("worker", 0)
|
self._cluster.kill_task("worker", 0)
|
||||||
@ -440,6 +440,31 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
# self.testWorkerContinuousFailure()
|
# self.testWorkerContinuousFailure()
|
||||||
|
|
||||||
|
|
||||||
|
class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
|
||||||
|
"""Multi worker fault tolerance tests.
|
||||||
|
|
||||||
|
This covers the ordinary cases where multiple workers and PS are used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(MultiWorkerFaultToleranceTest, self).setUp(2, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
|
||||||
|
"""Single worker fault tolerance tests.
|
||||||
|
|
||||||
|
This covers the cases that ensure training can continue in a single-worker
|
||||||
|
cluster, even if the only worker can become unavailable at some point and
|
||||||
|
recovered (if there are multiple workers, it is possible that the training
|
||||||
|
succeeds with the workers that did not fail). Realistically single worker
|
||||||
|
is very rarely used, but the tests are important to ensure the correct
|
||||||
|
behaviors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(SingleWorkerFaultToleranceTest, self).setUp(1, 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
v2_compat.enable_v2_behavior()
|
v2_compat.enable_v2_behavior()
|
||||||
multi_process_runner.test_main()
|
multi_process_runner.test_main()
|
||||||
|
Loading…
Reference in New Issue
Block a user