PSv2: Merge cluster_coordinator_mpr_test into fault_tolerance_test: step 1, adding SingleWorkerFaultToleranceTest for single worker fault tolerance cases.
PiperOrigin-RevId: 350623998 Change-Id: Ia9a072a27f95c3d1931fd838f9939a28846d9c53
This commit is contained in:
parent
f1f4b915cd
commit
49e3675367
@ -94,7 +94,7 @@ tf_py_test(
|
||||
name = "fault_tolerance_test",
|
||||
srcs = ["fault_tolerance_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 10,
|
||||
shard_count = 21,
|
||||
tags = [
|
||||
"noasan", # Multi-process runner does not work with test sanitizers
|
||||
"notsan", # Multi-process runner does not work with test sanitizers
|
||||
|
@ -96,13 +96,10 @@ class Model(object):
|
||||
self.cluster_coord.join()
|
||||
|
||||
|
||||
class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring
|
||||
|
||||
NUM_WORKERS = 2
|
||||
NUM_PS = 2
|
||||
|
||||
def setUp(self):
|
||||
super(FaultToleranceTest, self).setUp()
|
||||
def setUp(self, num_workers, num_ps):
|
||||
super(BaseFaultToleranceTest, self).setUp()
|
||||
|
||||
# Set the environment variable to prevent hanging upon job failure and
|
||||
# restart. Note that it defaults to 'use_caller' at Google, but defaults
|
||||
@ -110,9 +107,7 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
os.environ["GRPC_FAIL_FAST"] = "use_caller"
|
||||
|
||||
self._cluster = multi_worker_test_base.create_multi_process_cluster(
|
||||
num_workers=FaultToleranceTest.NUM_WORKERS,
|
||||
num_ps=FaultToleranceTest.NUM_PS,
|
||||
rpc_layer="grpc")
|
||||
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
|
||||
self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
|
||||
self._cluster_def["chief"] = [
|
||||
"localhost:%d" % multi_worker_test_base.pick_unused_port()
|
||||
@ -127,9 +122,10 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
|
||||
self.thread_coord = thread_coordinator.Coordinator(
|
||||
clean_stop_exception_types=[])
|
||||
self.num_workers = num_workers
|
||||
|
||||
def tearDown(self):
|
||||
super(FaultToleranceTest, self).tearDown()
|
||||
super(BaseFaultToleranceTest, self).tearDown()
|
||||
self._cluster.stop()
|
||||
|
||||
def _restart(self, downtime_secs, job):
|
||||
@ -173,10 +169,11 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
model.do_infinite_step.assign(True)
|
||||
|
||||
model.schedule_training_functions(4)
|
||||
# Model does infinite training step, so at this moment, we expect to have 2
|
||||
# infinite closures inflight, and 2 closures in the queue.
|
||||
# Model does infinite training step, so at this moment, we expect to have
|
||||
# `self.num_workers` infinite closures inflight, and `4-self.num_workers`
|
||||
# closures in the queue.
|
||||
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
||||
< 2):
|
||||
< self.num_workers):
|
||||
time.sleep(0.1)
|
||||
self.assertFalse(self.cluster_coord.done())
|
||||
self._restart(downtime_secs=2, job="worker")
|
||||
@ -351,6 +348,8 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
str(e))
|
||||
|
||||
def testTwoWorkersPreempted(self):
|
||||
if self.num_workers < 2:
|
||||
self.skipTest("Worker number is less than 2.")
|
||||
model = Model(self.cluster_coord)
|
||||
model.do_infinite_step.assign(True)
|
||||
model.schedule_training_functions(10)
|
||||
@ -380,10 +379,11 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
model.do_infinite_step.assign(True)
|
||||
model.schedule_training_functions(10)
|
||||
|
||||
# Model does infinite training step, so at this moment, we expect to have 2
|
||||
# infinite closures inflight, and 8 closures in the queue.
|
||||
# Model does infinite training step, so at this moment, we expect to have
|
||||
# `self.num_workers` infinite closures inflight, and `10-self.num_workers`
|
||||
# closures in the queue.
|
||||
while (self.cluster_coord._cluster._closure_queue._inflight_closure_count
|
||||
< 2):
|
||||
< self.num_workers):
|
||||
time.sleep(0.1)
|
||||
self.assertFalse(self.cluster_coord.done())
|
||||
self._cluster.kill_task("worker", 0)
|
||||
@ -440,6 +440,31 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
# self.testWorkerContinuousFailure()
|
||||
|
||||
|
||||
class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
|
||||
"""Multi worker fault tolerance tests.
|
||||
|
||||
This covers the ordinary cases where multiple workers and PS are used.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super(MultiWorkerFaultToleranceTest, self).setUp(2, 2)
|
||||
|
||||
|
||||
class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
|
||||
"""Single worker fault tolerance tests.
|
||||
|
||||
This covers the cases that ensure training can continue in a single-worker
|
||||
cluster, even if the only worker can become unavailable at some point and
|
||||
recovered (if there are multiple workers, it is possible that the training
|
||||
succeeds with the workers that did not fail). Realistically single worker
|
||||
is very rarely used, but the tests are important to ensure the correct
|
||||
behaviors.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super(SingleWorkerFaultToleranceTest, self).setUp(1, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
v2_compat.enable_v2_behavior()
|
||||
multi_process_runner.test_main()
|
||||
|
Loading…
Reference in New Issue
Block a user