Re-submit the coordinator change.

PiperOrigin-RevId: 225069740
This commit is contained in:
Yuefeng Zhou 2018-12-11 14:18:26 -08:00 committed by TensorFlower Gardener
parent 5dd912f2d7
commit 40345bd2c3
3 changed files with 36 additions and 16 deletions

View File

@@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest(
     threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                  cluster_spec, train_distribute,
                                                  eval_distribute)
+    threads_to_join = []
     for task_type, ts in threads.items():
       if task_type == PS:
         continue
       for t in ts:
-        t.join()
+        threads_to_join.append(t)
+    self.join_independent_workers(threads_to_join)
     estimator = self._get_estimator(train_distribute, eval_distribute)
     self._inspect_train_and_eval_events(estimator)
@@ -413,8 +415,7 @@ class DistributeCoordinatorIntegrationTest(
     threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                  cluster_spec, train_distribute,
                                                  eval_distribute)
-    threads[WORKER][0].join()
-    threads[EVALUATOR][0].join()
+    self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]])
     estimator = self._get_estimator(train_distribute, eval_distribute)
     self._inspect_train_and_eval_events(estimator)

View File

@@ -40,6 +40,7 @@ from tensorflow.python.client import session
 from tensorflow.python.estimator import run_config
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
 from tensorflow.python.training import server_lib

 ASSIGNED_PORTS = set()
@@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase):
     self._mock_os_env = MockOsEnv()
     self._mock_context = test.mock.patch.object(os, 'environ',
                                                 self._mock_os_env)
+    self._coord = coordinator.Coordinator()
     super(IndependentWorkerTestBase, self).setUp()
     self._mock_context.__enter__()
@@ -368,8 +370,9 @@ class IndependentWorkerTestBase(test.TestCase):
     super(IndependentWorkerTestBase, self).tearDown()

   def _task_thread(self, task_fn, tf_config, *args, **kwargs):
-    os.environ['TF_CONFIG'] = json.dumps(tf_config)
-    task_fn(*args, **kwargs)
+    with self._coord.stop_on_exception():
+      os.environ['TF_CONFIG'] = json.dumps(tf_config)
+      task_fn(*args, **kwargs)

   def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                           *args, **kwargs):
@@ -403,3 +406,6 @@ class IndependentWorkerTestBase(test.TestCase):
                                      *args, **kwargs)
       threads[task_type].append(t)
     return threads
+
+  def join_independent_workers(self, worker_threads):
+    self._coord.join(worker_threads)

View File

@@ -29,6 +29,7 @@ from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator_context
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import server_lib
@@ -328,7 +329,8 @@ def _run_single_worker(worker_fn,
                        task_id,
                        session_config,
                        rpc_layer="",
-                       worker_barrier=None):
+                       worker_barrier=None,
+                       coord=None):
   """Runs a single worker by calling `worker_fn` under context."""
   session_config = copy.deepcopy(session_config)
   strategy = copy.deepcopy(strategy)
@@ -350,7 +352,11 @@ def _run_single_worker(worker_fn,
       rpc_layer=rpc_layer,
       worker_barrier=worker_barrier)
   with context:
-    return worker_fn(strategy)
+    if coord:
+      with coord.stop_on_exception():
+        return worker_fn(strategy)
+    else:
+      return worker_fn(strategy)

 def _split_cluster_for_evaluator(cluster_spec, task_type):
@@ -423,6 +429,7 @@ def _run_std_server(cluster_spec=None,
 def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                               cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for between-graph replication."""
+  coord = coordinator.Coordinator()
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
@@ -431,6 +438,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
             session_config),
         kwargs={
             "rpc_layer": rpc_layer,
+            "coord": coord,
         })
     eval_thread.start()
@@ -444,18 +452,18 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
             session_config),
         kwargs={
             "rpc_layer": rpc_layer,
-            "worker_barrier": worker_barrier
+            "worker_barrier": worker_barrier,
+            "coord": coord,
         })
     t.start()
     threads.append(t)

-  # TODO(yuefengz): wrap threads into thread coordinator?
-  for t in threads:
-    t.join()
-
-  # TODO(yuefengz): is it necessary to join eval thread?
   if eval_thread:
-    eval_thread.join()
+    # TODO(yuefengz): is it necessary to join eval thread?
+    threads_to_join = threads + [eval_thread]
+  else:
+    threads_to_join = threads
+  coord.join(threads_to_join)

   # TODO(yuefengz): we probably want to return results from all workers?
   return None
@@ -464,6 +472,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
 def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                          cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for in-graph replication."""
+  coord = coordinator.Coordinator()
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
@@ -472,6 +481,7 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
           session_config),
       kwargs={
           "rpc_layer": rpc_layer,
+          "coord": coord,
       })
     eval_thread.start()
@@ -482,9 +492,12 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
       None,
       None,
       session_config,
-      rpc_layer=rpc_layer)
+      rpc_layer=rpc_layer,
+      coord=coord)
+
   if eval_thread:
-    eval_thread.join()
+    coord.join([eval_thread])

   return worker_result