Use MPR for fault tolerance test
PiperOrigin-RevId: 327766188 Change-Id: I247539f5561940a29fef658818b1e815dd194c1d
This commit is contained in:
parent
e918c5c7ea
commit
f74cc7a696
tensorflow/python/distribute
@ -870,6 +870,7 @@ py_library(
|
||||
srcs = ["multi_worker_test_base.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":multi_process_runner",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:distributed_framework_test_lib",
|
||||
@ -879,12 +880,22 @@ py_library(
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:training_lib",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:remote",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "multi_worker_test_base_test",
|
||||
srcs = ["multi_worker_test_base_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":multi_worker_test_base",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "checkpoint_utils_test",
|
||||
size = "medium",
|
||||
|
@ -41,6 +41,9 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.framework import errors
|
||||
@ -200,6 +203,156 @@ def create_in_process_cluster(num_workers,
|
||||
return cluster
|
||||
|
||||
|
||||
class MultiProcessCluster(object):
|
||||
"""A cluster of TensorFlow servers in separate processes.
|
||||
|
||||
This class is not thread-safe.
|
||||
"""
|
||||
|
||||
def __init__(self, cluster_resolver):
|
||||
self._cluster_resolver = cluster_resolver
|
||||
self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
|
||||
self._rpc_layer = cluster_resolver.rpc_layer
|
||||
self._start_events = {}
|
||||
self._finish_events = {}
|
||||
self._mpr_manager = multi_process_runner.manager()
|
||||
|
||||
def task_function(start_events, finish_events):
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
cluster_spec = cluster_resolver.cluster_spec()
|
||||
task_type = cluster_resolver.task_type
|
||||
task_id = cluster_resolver.task_id
|
||||
rpc_layer = cluster_resolver.rpc_layer
|
||||
|
||||
logging.info(
|
||||
'Starting server with cluster_spec = %r, task_type = %r, '
|
||||
'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
|
||||
rpc_layer)
|
||||
|
||||
# TODO(yuefengz): support GPU clusters.
|
||||
server_config = config_pb2.ConfigProto()
|
||||
server_config.device_count['GPU'] = 0
|
||||
|
||||
server_lib.Server(
|
||||
cluster_spec,
|
||||
job_name=task_type,
|
||||
protocol=rpc_layer,
|
||||
task_index=task_id,
|
||||
config=server_config,
|
||||
start=True)
|
||||
|
||||
start_event = start_events[task_type][task_id]
|
||||
start_event.set()
|
||||
|
||||
finish_event = finish_events[task_type][task_id]
|
||||
finish_event.wait()
|
||||
|
||||
os._exit(0) # pylint: disable=protected-access
|
||||
|
||||
self._task_function = task_function
|
||||
self._mpr = None
|
||||
|
||||
def start(self):
|
||||
"""Starts one TensorFlow server for each task in the cluster_resolver.
|
||||
|
||||
It will wait until all the servers are up before returns.
|
||||
"""
|
||||
if self._mpr:
|
||||
raise ValueError('The cluster has already been started.')
|
||||
for task_type, task_addresses in self._cluster_spec.items():
|
||||
self._start_events[task_type] = []
|
||||
self._finish_events[task_type] = []
|
||||
for _ in task_addresses:
|
||||
self._start_events[task_type].append(self._mpr_manager.Event())
|
||||
self._finish_events[task_type].append(self._mpr_manager.Event())
|
||||
|
||||
self._mpr = multi_process_runner.MultiProcessRunner(
|
||||
self._task_function,
|
||||
self._cluster_spec,
|
||||
args=(self._start_events, self._finish_events),
|
||||
rpc_layer=self._rpc_layer,
|
||||
stream_stdout=False,
|
||||
list_stdout=False,
|
||||
use_dill_for_args=False)
|
||||
self._mpr.start()
|
||||
for task_type, task_addresses in self._cluster_spec.items():
|
||||
for i in range(len(task_addresses)):
|
||||
self._start_events[task_type][i].wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stops all the servers."""
|
||||
for task_type, task_addresses in self._cluster_spec.items():
|
||||
for i in range(len(task_addresses)):
|
||||
self._finish_events[task_type][i].set()
|
||||
try:
|
||||
self._mpr.join()
|
||||
except multi_process_runner.UnexpectedSubprocessExitError:
|
||||
# TODO(yuefengz): investigate why processes exit with 255.
|
||||
pass
|
||||
self._mpr = None
|
||||
self._start_events = {}
|
||||
self._finish_events = {}
|
||||
|
||||
def kill_task(self, task_type, task_id):
|
||||
"""Kill a server given task_type and task_id.
|
||||
|
||||
Args:
|
||||
task_type: the type of the task such as "worker".
|
||||
task_id: the id the task such as 1.
|
||||
"""
|
||||
assert self._mpr
|
||||
if (not self._start_events[task_type][task_id].is_set() or
|
||||
self._finish_events[task_type][task_id].is_set()):
|
||||
raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))
|
||||
|
||||
self._finish_events[task_type][task_id].set()
|
||||
self._mpr._processes[(task_type, task_id)].join()
|
||||
|
||||
def start_task(self, task_type, task_id):
|
||||
"""Starts a server given task_type and task_id.
|
||||
|
||||
Args:
|
||||
task_type: the type of the task such as "worker".
|
||||
task_id: the id the task such as 1.
|
||||
|
||||
Raises:
|
||||
ValueError: if the server alreay exists.
|
||||
"""
|
||||
assert self._mpr
|
||||
|
||||
if (not self._start_events[task_type][task_id].is_set() or
|
||||
not self._finish_events[task_type][task_id].is_set()):
|
||||
raise ValueError(
|
||||
'The task %s:%d is still alive. You cannot start another one.' %
|
||||
(task_type, task_id))
|
||||
self._start_events[task_type][task_id] = self._mpr_manager.Event()
|
||||
self._finish_events[task_type][task_id] = self._mpr_manager.Event()
|
||||
self._mpr.start_single_process(task_type=task_type, task_id=task_id)
|
||||
self._start_events[task_type][task_id].wait()
|
||||
|
||||
@property
|
||||
def cluster_resolver(self):
|
||||
return copy.deepcopy(self._cluster_resolver)
|
||||
|
||||
|
||||
def create_multi_process_cluster(num_workers,
|
||||
num_ps,
|
||||
has_chief=False,
|
||||
has_eval=False,
|
||||
rpc_layer='grpc'):
|
||||
cluster_spec = create_cluster_spec(
|
||||
has_chief=has_chief,
|
||||
num_workers=num_workers,
|
||||
num_ps=num_ps,
|
||||
has_eval=has_eval)
|
||||
|
||||
cluster = MultiProcessCluster(
|
||||
SimpleClusterResolver(
|
||||
server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer))
|
||||
cluster.start()
|
||||
return cluster
|
||||
|
||||
|
||||
# TODO(rchao): Remove `test_obj` once estimator repo picks up the updated
|
||||
# nightly TF.
|
||||
def create_cluster_spec(has_chief=False,
|
||||
|
82
tensorflow/python/distribute/multi_worker_test_base_test.py
Normal file
82
tensorflow/python/distribute/multi_worker_test_base_test.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for multi-process clusters."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager import test
|
||||
|
||||
|
||||
class MultiProcessClusterTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(MultiProcessClusterTest, self).setUp()
|
||||
self._cluster = multi_worker_test_base.create_multi_process_cluster(
|
||||
num_workers=2, num_ps=1, has_chief=True, rpc_layer="grpc")
|
||||
remote.connect_to_cluster(
|
||||
self._cluster.cluster_resolver.cluster_spec(), protocol="grpc")
|
||||
context.ensure_initialized()
|
||||
|
||||
def testClusterIsAlive(self):
|
||||
self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
|
||||
self.assertTrue(context.check_alive("/job:worker/replica:0/task:1"))
|
||||
self.assertTrue(context.check_alive("/job:ps/replica:0/task:0"))
|
||||
self.assertTrue(context.check_alive("/job:chief/replica:0/task:0"))
|
||||
|
||||
def testKillAndStartTask(self):
|
||||
self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
|
||||
|
||||
# It is not allowed to start a task before killing it.
|
||||
with self.assertRaises(ValueError):
|
||||
self._cluster.start_task("worker", 0)
|
||||
|
||||
self._cluster.kill_task("worker", 0)
|
||||
self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
|
||||
|
||||
# The task is already killed.
|
||||
with self.assertRaises(ValueError):
|
||||
self._cluster.kill_task("worker", 0)
|
||||
|
||||
self._cluster.start_task("worker", 0)
|
||||
|
||||
# Without a call to update_server_def, the next check_alive will return
|
||||
# False. Alternatively sleeping for 2 seconds here also works.
|
||||
context.context().update_server_def(context.get_server_def())
|
||||
|
||||
self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
|
||||
|
||||
def testStop(self):
|
||||
self._cluster.stop()
|
||||
self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
|
||||
self.assertFalse(context.check_alive("/job:worker/replica:0/task:1"))
|
||||
self.assertFalse(context.check_alive("/job:ps/replica:0/task:0"))
|
||||
self.assertFalse(context.check_alive("/job:chief/replica:0/task:0"))
|
||||
|
||||
def testClusterResolverProperty(self):
|
||||
cluster_spec = self._cluster.cluster_resolver.cluster_spec().as_dict()
|
||||
|
||||
self.assertEqual(len(cluster_spec["worker"]), 2)
|
||||
self.assertEqual(len(cluster_spec["ps"]), 1)
|
||||
self.assertEqual(len(cluster_spec["chief"]), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
multi_process_runner.test_main()
|
Loading…
Reference in New Issue
Block a user