From 40345bd2c3cfdcb095f8cdd7595f4a1eb9698f8f Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Tue, 11 Dec 2018 14:18:26 -0800 Subject: [PATCH] Re-submit the coordinator change. PiperOrigin-RevId: 225069740 --- .../python/estimator_training_test.py | 7 ++-- .../python/multi_worker_test_base.py | 10 ++++-- .../distribute/distribute_coordinator.py | 35 +++++++++++++------ 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index b369a7fefe6..3f55a8a1c8b 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) + threads_to_join = [] for task_type, ts in threads.items(): if task_type == PS: continue for t in ts: - t.join() + threads_to_join.append(t) + self.join_independent_workers(threads_to_join) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) @@ -413,8 +415,7 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) - threads[WORKER][0].join() - threads[EVALUATOR][0].join() + self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]]) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 147c9b83f86..b05aac431f6 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -40,6 +40,7 @@ from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator from tensorflow.python.training import server_lib ASSIGNED_PORTS = set() @@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase): self._mock_os_env = MockOsEnv() self._mock_context = test.mock.patch.object(os, 'environ', self._mock_os_env) + self._coord = coordinator.Coordinator() super(IndependentWorkerTestBase, self).setUp() self._mock_context.__enter__() @@ -368,8 +370,9 @@ class IndependentWorkerTestBase(test.TestCase): super(IndependentWorkerTestBase, self).tearDown() def _task_thread(self, task_fn, tf_config, *args, **kwargs): - os.environ['TF_CONFIG'] = json.dumps(tf_config) - task_fn(*args, **kwargs) + with self._coord.stop_on_exception(): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, *args, **kwargs): @@ -403,3 +406,6 @@ class IndependentWorkerTestBase(test.TestCase): *args, **kwargs) threads[task_type].append(t) return threads + + def join_independent_workers(self, worker_threads): + self._coord.join(worker_threads) diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py index c0f9b8a1fdf..78c995a5782 100644 --- a/tensorflow/python/distribute/distribute_coordinator.py +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -29,6 +29,7 @@ from tensorflow.python.client import session from tensorflow.python.distribute import distribute_coordinator_context from tensorflow.python.distribute import multi_worker_util from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib @@ -328,7 +329,8 @@ def _run_single_worker(worker_fn, task_id, session_config, rpc_layer="", - worker_barrier=None): + worker_barrier=None, + coord=None): """Runs a single worker by calling `worker_fn` under context.""" session_config = copy.deepcopy(session_config) strategy = copy.deepcopy(strategy) @@ -350,7 +352,11 @@ def _run_single_worker(worker_fn, rpc_layer=rpc_layer, worker_barrier=worker_barrier) with context: - return worker_fn(strategy) + if coord: + with coord.stop_on_exception(): + return worker_fn(strategy) + else: + return worker_fn(strategy) def _split_cluster_for_evaluator(cluster_spec, task_type): @@ -423,6 +429,7 @@ def _run_std_server(cluster_spec=None, def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, cluster_spec, session_config, rpc_layer): """Runs a standalone client for between-graph replication.""" + coord = coordinator.Coordinator() eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( @@ -431,6 +438,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, session_config), kwargs={ "rpc_layer": rpc_layer, + "coord": coord, }) eval_thread.start() @@ -444,18 +452,18 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, session_config), kwargs={ "rpc_layer": rpc_layer, - "worker_barrier": worker_barrier + "worker_barrier": worker_barrier, + "coord": coord, }) t.start() threads.append(t) - # TODO(yuefengz): wrap threads into thread coordinator? - for t in threads: - t.join() - - # TODO(yuefengz): is it necessary to join eval thread? if eval_thread: - eval_thread.join() + # TODO(yuefengz): is it necessary to join eval thread? + threads_to_join = threads + [eval_thread] + else: + threads_to_join = threads + coord.join(threads_to_join) # TODO(yuefengz): we probably want to return results from all workers? return None @@ -464,6 +472,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, cluster_spec, session_config, rpc_layer): """Runs a standalone client for in-graph replication.""" + coord = coordinator.Coordinator() eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( @@ -472,6 +481,7 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, session_config), kwargs={ "rpc_layer": rpc_layer, + "coord": coord, }) eval_thread.start() @@ -482,9 +492,12 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, None, None, session_config, - rpc_layer=rpc_layer) + rpc_layer=rpc_layer, + coord=coord) + if eval_thread: - eval_thread.join() + coord.join([eval_thread]) + return worker_result