Re-submit the coordinator change.

PiperOrigin-RevId: 225069740
Author: Yuefeng Zhou (2018-12-11 14:18:26 -08:00), committed by TensorFlower Gardener
parent 5dd912f2d7
commit 40345bd2c3
3 changed files with 36 additions and 16 deletions


@@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest(
     threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                  cluster_spec, train_distribute,
                                                  eval_distribute)
+    threads_to_join = []
     for task_type, ts in threads.items():
       if task_type == PS:
         continue
       for t in ts:
-        t.join()
+        threads_to_join.append(t)
+    self.join_independent_workers(threads_to_join)
     estimator = self._get_estimator(train_distribute, eval_distribute)
     self._inspect_train_and_eval_events(estimator)
@@ -413,8 +415,7 @@ class DistributeCoordinatorIntegrationTest(
     threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                  cluster_spec, train_distribute,
                                                  eval_distribute)
-    threads[WORKER][0].join()
-    threads[EVALUATOR][0].join()
+    self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]])
     estimator = self._get_estimator(train_distribute, eval_distribute)
     self._inspect_train_and_eval_events(estimator)
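
The point of routing the joins through join_independent_workers is that a plain Thread.join() never surfaces an exception raised inside a worker thread, whereas a tf.train.Coordinator re-raises the first recorded failure in the joining thread. A minimal standalone sketch of that behavior (the _worker function and its should_fail flag are invented for illustration):

import threading

from tensorflow.python.training import coordinator

coord = coordinator.Coordinator()


def _worker(should_fail):
  # stop_on_exception() records any exception on the coordinator instead of
  # letting it die silently inside the thread.
  with coord.stop_on_exception():
    if should_fail:
      raise ValueError('worker failed')


threads = [threading.Thread(target=_worker, args=(i == 0,)) for i in range(3)]
for t in threads:
  t.start()
# Waits for all threads and re-raises the ValueError from the failing worker;
# a bare t.join() loop would have returned without reporting any error.
coord.join(threads)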


@@ -40,6 +40,7 @@ from tensorflow.python.client import session
 from tensorflow.python.estimator import run_config
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
 from tensorflow.python.training import server_lib

 ASSIGNED_PORTS = set()
@@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase):
     self._mock_os_env = MockOsEnv()
     self._mock_context = test.mock.patch.object(os, 'environ',
                                                 self._mock_os_env)
+    self._coord = coordinator.Coordinator()
     super(IndependentWorkerTestBase, self).setUp()
     self._mock_context.__enter__()
@@ -368,6 +370,7 @@ class IndependentWorkerTestBase(test.TestCase):
     super(IndependentWorkerTestBase, self).tearDown()

   def _task_thread(self, task_fn, tf_config, *args, **kwargs):
+    with self._coord.stop_on_exception():
       os.environ['TF_CONFIG'] = json.dumps(tf_config)
       task_fn(*args, **kwargs)
@@ -403,3 +406,6 @@ class IndependentWorkerTestBase(test.TestCase):
                                      *args, **kwargs)
         threads[task_type].append(t)
     return threads
+
+  def join_independent_workers(self, worker_threads):
+    self._coord.join(worker_threads)
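
With this change the test base class owns a single Coordinator: every task thread started by _task_thread records its failures on it, and join_independent_workers re-raises them in the test's main thread. A condensed sketch of the same pattern (the class _WorkerThreadHarness is a hypothetical, stripped-down stand-in for IndependentWorkerTestBase, not the real class):

import json
import os
import threading

from tensorflow.python.training import coordinator


class _WorkerThreadHarness(object):
  """Condensed sketch of the thread-running pattern used by the test base."""

  def __init__(self):
    self._coord = coordinator.Coordinator()

  def run_task_in_thread(self, task_fn, tf_config, *args, **kwargs):
    def _task_thread():
      # Failures inside the task are recorded on the shared coordinator.
      with self._coord.stop_on_exception():
        os.environ['TF_CONFIG'] = json.dumps(tf_config)
        task_fn(*args, **kwargs)

    t = threading.Thread(target=_task_thread)
    t.start()
    return t

  def join_independent_workers(self, worker_threads):
    # Blocks until the threads finish and re-raises the first recorded error,
    # so a failing worker fails the test instead of passing silently.
    self._coord.join(worker_threads)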


@@ -29,6 +29,7 @@ from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator_context
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import server_lib
@@ -328,7 +329,8 @@ def _run_single_worker(worker_fn,
                        task_id,
                        session_config,
                        rpc_layer="",
-                       worker_barrier=None):
+                       worker_barrier=None,
+                       coord=None):
   """Runs a single worker by calling `worker_fn` under context."""
   session_config = copy.deepcopy(session_config)
   strategy = copy.deepcopy(strategy)
@@ -350,6 +352,10 @@
       rpc_layer=rpc_layer,
       worker_barrier=worker_barrier)
   with context:
+    if coord:
+      with coord.stop_on_exception():
+        return worker_fn(strategy)
+    else:
       return worker_fn(strategy)
@@ -423,6 +429,7 @@ def _run_std_server(cluster_spec=None,
 def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                               cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for between-graph replication."""
+  coord = coordinator.Coordinator()
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
@@ -431,6 +438,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
+            "coord": coord,
         })
     eval_thread.start()
@@ -444,18 +452,18 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                 session_config),
           kwargs={
               "rpc_layer": rpc_layer,
-              "worker_barrier": worker_barrier
+              "worker_barrier": worker_barrier,
+              "coord": coord,
           })
       t.start()
       threads.append(t)
-  # TODO(yuefengz): wrap threads into thread coordinator?
-  for t in threads:
-    t.join()
-  # TODO(yuefengz): is it necessary to join eval thread?
   if eval_thread:
-    eval_thread.join()
+    # TODO(yuefengz): is it necessary to join eval thread?
+    threads_to_join = threads + [eval_thread]
+  else:
+    threads_to_join = threads
+  coord.join(threads_to_join)
   # TODO(yuefengz): we probably want to return results from all workers?
   return None
@@ -464,6 +472,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
 def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                          cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for in-graph replication."""
+  coord = coordinator.Coordinator()
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
@@ -472,6 +481,7 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
+            "coord": coord,
         })
     eval_thread.start()
@@ -482,9 +492,12 @@
       None,
       None,
       session_config,
-      rpc_layer=rpc_layer)
+      rpc_layer=rpc_layer,
+      coord=coord)
   if eval_thread:
-    eval_thread.join()
+    coord.join([eval_thread])
   return worker_result
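
Both standalone clients now create one Coordinator, thread it into _run_single_worker through the new coord argument, and finish with a single coord.join(...) over the worker and evaluator threads, replacing the TODO-marked per-thread joins. A hedged sketch of that plumbing (the names _run_worker_sketch and _run_client_sketch are invented, and the real functions take many more arguments):

import threading

from tensorflow.python.training import coordinator


def _run_worker_sketch(worker_fn, coord=None):
  # Mirrors the new optional `coord` argument of _run_single_worker: when a
  # coordinator is supplied, exceptions are recorded on it; otherwise the
  # previous behavior is unchanged.
  if coord:
    with coord.stop_on_exception():
      return worker_fn()
  else:
    return worker_fn()


def _run_client_sketch(worker_fns, eval_fn=None):
  # One coordinator is shared by every worker thread and the optional
  # evaluator thread, so a single join surfaces the first failure anywhere.
  coord = coordinator.Coordinator()
  threads = []
  for fn in worker_fns:
    t = threading.Thread(target=_run_worker_sketch, args=(fn, coord))
    t.start()
    threads.append(t)
  if eval_fn is not None:
    t = threading.Thread(target=_run_worker_sketch, args=(eval_fn, coord))
    t.start()
    threads.append(t)
  coord.join(threads)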