Re-submit the coordinator change.
PiperOrigin-RevId: 225069740
commit 40345bd2c3
parent 5dd912f2d7
@@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest(
     threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                  cluster_spec, train_distribute,
                                                  eval_distribute)
+    threads_to_join = []
     for task_type, ts in threads.items():
       if task_type == PS:
         continue
       for t in ts:
-        t.join()
+        threads_to_join.append(t)
+    self.join_independent_workers(threads_to_join)
 
     estimator = self._get_estimator(train_distribute, eval_distribute)
     self._inspect_train_and_eval_events(estimator)
@@ -413,8 +415,7 @@ class DistributeCoordinatorIntegrationTest(
     threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                  cluster_spec, train_distribute,
                                                  eval_distribute)
-    threads[WORKER][0].join()
-    threads[EVALUATOR][0].join()
+    self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]])
 
     estimator = self._get_estimator(train_distribute, eval_distribute)
     self._inspect_train_and_eval_events(estimator)
@@ -40,6 +40,7 @@ from tensorflow.python.client import session
 from tensorflow.python.estimator import run_config
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
 from tensorflow.python.training import server_lib
 
 ASSIGNED_PORTS = set()
@@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase):
     self._mock_os_env = MockOsEnv()
     self._mock_context = test.mock.patch.object(os, 'environ',
                                                 self._mock_os_env)
+    self._coord = coordinator.Coordinator()
     super(IndependentWorkerTestBase, self).setUp()
     self._mock_context.__enter__()
 
@@ -368,8 +370,9 @@ class IndependentWorkerTestBase(test.TestCase):
     super(IndependentWorkerTestBase, self).tearDown()
 
   def _task_thread(self, task_fn, tf_config, *args, **kwargs):
-    os.environ['TF_CONFIG'] = json.dumps(tf_config)
-    task_fn(*args, **kwargs)
+    with self._coord.stop_on_exception():
+      os.environ['TF_CONFIG'] = json.dumps(tf_config)
+      task_fn(*args, **kwargs)
 
   def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                           *args, **kwargs):
@@ -403,3 +406,6 @@ class IndependentWorkerTestBase(test.TestCase):
                                      *args, **kwargs)
         threads[task_type].append(t)
     return threads
+
+  def join_independent_workers(self, worker_threads):
+    self._coord.join(worker_threads)
@@ -29,6 +29,7 @@ from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator_context
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import server_lib
 
@@ -328,7 +329,8 @@ def _run_single_worker(worker_fn,
                        task_id,
                        session_config,
                        rpc_layer="",
-                       worker_barrier=None):
+                       worker_barrier=None,
+                       coord=None):
   """Runs a single worker by calling `worker_fn` under context."""
   session_config = copy.deepcopy(session_config)
   strategy = copy.deepcopy(strategy)
@@ -350,7 +352,11 @@ def _run_single_worker(worker_fn,
       rpc_layer=rpc_layer,
       worker_barrier=worker_barrier)
   with context:
-    return worker_fn(strategy)
+    if coord:
+      with coord.stop_on_exception():
+        return worker_fn(strategy)
+    else:
+      return worker_fn(strategy)
 
 
 def _split_cluster_for_evaluator(cluster_spec, task_type):
@@ -423,6 +429,7 @@ def _run_std_server(cluster_spec=None,
 def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                               cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for between-graph replication."""
+  coord = coordinator.Coordinator()
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
@@ -431,6 +438,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
+            "coord": coord,
         })
     eval_thread.start()
 
@@ -444,18 +452,18 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
-            "worker_barrier": worker_barrier
+            "worker_barrier": worker_barrier,
+            "coord": coord,
         })
     t.start()
     threads.append(t)
 
-  # TODO(yuefengz): wrap threads into thread coordinator?
-  for t in threads:
-    t.join()
-
-  # TODO(yuefengz): is it necessary to join eval thread?
   if eval_thread:
-    eval_thread.join()
+    # TODO(yuefengz): is it necessary to join eval thread?
+    threads_to_join = threads + [eval_thread]
+  else:
+    threads_to_join = threads
+  coord.join(threads_to_join)
 
   # TODO(yuefengz): we probably want to return results from all workers?
   return None
@@ -464,6 +472,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
 def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                          cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for in-graph replication."""
+  coord = coordinator.Coordinator()
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
@@ -472,6 +481,7 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
+            "coord": coord,
         })
     eval_thread.start()
 
@@ -482,9 +492,12 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
       None,
       None,
       session_config,
-      rpc_layer=rpc_layer)
+      rpc_layer=rpc_layer,
+      coord=coord)
+
   if eval_thread:
-    eval_thread.join()
+    coord.join([eval_thread])
+
   return worker_result
 
 
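Note on the pattern: instead of calling Thread.join() directly, the diff above routes worker and evaluator threads through a thread coordinator, so an exception raised inside any thread (captured by coord.stop_on_exception()) is re-raised when the owning code calls coord.join(). The following standalone sketch is illustrative only and not part of the commit; it uses the public tf.train.Coordinator alias of the class imported above from tensorflow.python.training.coordinator, and the worker function in it is made up for the example.

# Standalone sketch (not from the commit): exception propagation via
# tf.train.Coordinator. Assumes TF 1.x; in TF 2.x the class is available
# as tf.compat.v1.train.Coordinator. The worker below is hypothetical.
import threading

import tensorflow as tf


def _worker(coord, task_id):
  # stop_on_exception() records the exception on the coordinator and
  # requests a stop instead of letting the thread die silently.
  with coord.stop_on_exception():
    if task_id == 1:
      raise ValueError("worker %d failed" % task_id)


coord = tf.train.Coordinator()
threads = []
for task_id in range(2):
  t = threading.Thread(target=_worker, args=(coord, task_id))
  t.start()
  threads.append(t)

try:
  # join() waits for all threads and re-raises the recorded exception,
  # so the failure surfaces in the joining thread (the test case or the
  # standalone client) rather than being swallowed.
  coord.join(threads)
except ValueError as e:
  print("caught worker failure:", e)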