Re-submit the coordinator change.
PiperOrigin-RevId: 225069740
parent 5dd912f2d7
commit 40345bd2c3
@@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest(
     threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                  cluster_spec, train_distribute,
                                                  eval_distribute)
+    threads_to_join = []
     for task_type, ts in threads.items():
       if task_type == PS:
         continue
       for t in ts:
-        t.join()
+        threads_to_join.append(t)
+    self.join_independent_workers(threads_to_join)
 
     estimator = self._get_estimator(train_distribute, eval_distribute)
     self._inspect_train_and_eval_events(estimator)
@@ -413,8 +415,7 @@ class DistributeCoordinatorIntegrationTest(
     threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                  cluster_spec, train_distribute,
                                                  eval_distribute)
-    threads[WORKER][0].join()
-    threads[EVALUATOR][0].join()
+    self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]])
 
     estimator = self._get_estimator(train_distribute, eval_distribute)
     self._inspect_train_and_eval_events(estimator)
@@ -40,6 +40,7 @@ from tensorflow.python.client import session
 from tensorflow.python.estimator import run_config
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
 from tensorflow.python.training import server_lib
 
 ASSIGNED_PORTS = set()
@@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase):
     self._mock_os_env = MockOsEnv()
     self._mock_context = test.mock.patch.object(os, 'environ',
                                                 self._mock_os_env)
+    self._coord = coordinator.Coordinator()
     super(IndependentWorkerTestBase, self).setUp()
     self._mock_context.__enter__()
 
@@ -368,8 +370,9 @@ class IndependentWorkerTestBase(test.TestCase):
     super(IndependentWorkerTestBase, self).tearDown()
 
   def _task_thread(self, task_fn, tf_config, *args, **kwargs):
-    os.environ['TF_CONFIG'] = json.dumps(tf_config)
-    task_fn(*args, **kwargs)
+    with self._coord.stop_on_exception():
+      os.environ['TF_CONFIG'] = json.dumps(tf_config)
+      task_fn(*args, **kwargs)
 
   def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                           *args, **kwargs):
@@ -403,3 +406,6 @@ class IndependentWorkerTestBase(test.TestCase):
                                      *args, **kwargs)
       threads[task_type].append(t)
     return threads
+
+  def join_independent_workers(self, worker_threads):
+    self._coord.join(worker_threads)
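All of the test-base hunks above follow one pattern: each independent-worker thread now runs its task under a shared Coordinator, and tests join through the new join_independent_workers helper, so an exception raised inside any task thread is re-raised in the main test thread instead of being silently lost. Below is a minimal standalone sketch of that pattern, assuming only plain threading and the public tf.train.Coordinator API (the same class the diff imports as tensorflow.python.training.coordinator); the task function and thread count are illustrative, not from the commit.

import threading

import tensorflow as tf

coord = tf.train.Coordinator()

def run_task(task_id):
  # stop_on_exception() reports any exception to the coordinator and
  # suppresses it in this thread; the main thread sees it at join time.
  with coord.stop_on_exception():
    if task_id == 1:
      raise ValueError('task %d failed' % task_id)

threads = [threading.Thread(target=run_task, args=(i,)) for i in range(2)]
for t in threads:
  t.start()
coord.join(threads)  # waits for all threads, then re-raises the ValueError

Unlike a plain Thread.join(), coord.join() both waits for the threads and re-raises the first failure they reported, which is what join_independent_workers relies on.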
@@ -29,6 +29,7 @@ from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator_context
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import server_lib
 
@@ -328,7 +329,8 @@ def _run_single_worker(worker_fn,
                        task_id,
                        session_config,
                        rpc_layer="",
-                       worker_barrier=None):
+                       worker_barrier=None,
+                       coord=None):
   """Runs a single worker by calling `worker_fn` under context."""
   session_config = copy.deepcopy(session_config)
   strategy = copy.deepcopy(strategy)
@@ -350,7 +352,11 @@ def _run_single_worker(worker_fn,
       rpc_layer=rpc_layer,
       worker_barrier=worker_barrier)
   with context:
-    return worker_fn(strategy)
+    if coord:
+      with coord.stop_on_exception():
+        return worker_fn(strategy)
+    else:
+      return worker_fn(strategy)
 
 
 def _split_cluster_for_evaluator(cluster_spec, task_type):
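The duplicated return worker_fn(strategy) in the hunk above is the simplest way to make the coordinator optional. For reference only, an equivalent shape using the standard library's contextlib.ExitStack is sketched below; the helper name _call_worker_fn is hypothetical and not part of this commit.

import contextlib

def _call_worker_fn(worker_fn, strategy, coord=None):
  # Enter coord.stop_on_exception() only when a coordinator was supplied,
  # without repeating the worker_fn call in two branches.
  with contextlib.ExitStack() as stack:
    if coord:
      stack.enter_context(coord.stop_on_exception())
    return worker_fn(strategy)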
@@ -423,6 +429,7 @@ def _run_std_server(cluster_spec=None,
 def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                               cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for between-graph replication."""
+  coord = coordinator.Coordinator()
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
@@ -431,6 +438,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
+            "coord": coord,
         })
     eval_thread.start()
 
@@ -444,18 +452,18 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                 session_config),
           kwargs={
               "rpc_layer": rpc_layer,
-              "worker_barrier": worker_barrier
+              "worker_barrier": worker_barrier,
+              "coord": coord,
           })
       t.start()
       threads.append(t)
 
-  # TODO(yuefengz): wrap threads into thread coordinator?
-  for t in threads:
-    t.join()
-
-  # TODO(yuefengz): is it necessary to join eval thread?
   if eval_thread:
-    eval_thread.join()
+    # TODO(yuefengz): is it necessary to join eval thread?
+    threads_to_join = threads + [eval_thread]
+  else:
+    threads_to_join = threads
+  coord.join(threads_to_join)
 
   # TODO(yuefengz): we probably want to return results from all workers?
   return None
@@ -464,6 +472,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
 def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                          cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for in-graph replication."""
+  coord = coordinator.Coordinator()
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
@@ -472,6 +481,7 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
+            "coord": coord,
         })
     eval_thread.start()
 
@@ -482,9 +492,12 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
       None,
       None,
       session_config,
-      rpc_layer=rpc_layer)
+      rpc_layer=rpc_layer,
+      coord=coord)
 
   if eval_thread:
-    eval_thread.join()
+    coord.join([eval_thread])
 
   return worker_result
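Net effect of the client-side changes: the worker and evaluator threads used to be joined with plain Thread.join(), which never reports an exception raised inside the thread; joining through the coordinator re-raises that exception in the standalone client. A small illustration of the difference, assuming tf.train.Coordinator and an illustrative fail() function (not from the commit):

import threading

import tensorflow as tf

def fail():
  raise RuntimeError('evaluator crashed')

# Plain join: the exception only prints a traceback inside the thread and
# the caller continues as if nothing happened.
t = threading.Thread(target=fail)
t.start()
t.join()

# Coordinator-based join: the failure is recorded by stop_on_exception()
# and re-raised in the caller by coord.join().
coord = tf.train.Coordinator()

def fail_under_coord():
  with coord.stop_on_exception():
    fail()

t2 = threading.Thread(target=fail_under_coord)
t2.start()
try:
  coord.join([t2])
except RuntimeError as e:
  print('caught in main thread:', e)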