Skip creating std server for evaluator.
PiperOrigin-RevId: 350173166
Change-Id: I78cc40331574f4fb7cd960e5ea30a77878e0a22d
This commit is contained in:
parent
dff50dd3ad
commit
36b4ef7a81
@@ -169,7 +169,7 @@ class _WorkerContext(object):
|
|||||||
def _get_master_target(self):
|
def _get_master_target(self):
|
||||||
"""Return the master target for a task."""
|
"""Return the master target for a task."""
|
||||||
# If cluster_spec is None or empty, we use local master.
|
# If cluster_spec is None or empty, we use local master.
|
||||||
if not self._cluster_spec:
|
if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# If task_type is None, then it is in-graph replicated training. In this
|
# If task_type is None, then it is in-graph replicated training. In this
|
||||||
@@ -842,7 +842,8 @@ def run_distribute_coordinator(worker_fn,
|
|||||||
session_config, cluster_spec,
|
session_config, cluster_spec,
|
||||||
task_type, task_id)
|
task_type, task_id)
|
||||||
|
|
||||||
if not getattr(strategy.extended, "_std_server_started", False):
|
if (task_type != _TaskType.EVALUATOR and
|
||||||
|
not getattr(strategy.extended, "_std_server_started", False)):
|
||||||
# Right now, with eager mode, context is configured with a std server at
|
# Right now, with eager mode, context is configured with a std server at
|
||||||
# the very beginning while with graph mode the std server is started when
|
# the very beginning while with graph mode the std server is started when
|
||||||
# distribute coordinator is called. We should consolidate these two paths.
|
# distribute coordinator is called. We should consolidate these two paths.
|
||||||
|
|||||||
@@ -589,8 +589,7 @@ class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
|
|||||||
# and distributed_mode.
|
# and distributed_mode.
|
||||||
self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
|
self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
|
||||||
_bytes_to_str(self._workers[0].target)), 3, True, True))
|
_bytes_to_str(self._workers[0].target)), 3, True, True))
|
||||||
self.assertEqual(self._worker_context[EVALUATOR][0],
|
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
|
||||||
("fake_evaluator", 3, True, False))
|
|
||||||
|
|
||||||
|
|
||||||
class DistributeCoordinatorTestIndependentWorkerMode(
|
class DistributeCoordinatorTestIndependentWorkerMode(
|
||||||
@@ -755,19 +754,15 @@ class DistributeCoordinatorTestIndependentWorkerMode(
|
|||||||
# and distributed_mode.
|
# and distributed_mode.
|
||||||
self.assertEqual(self._worker_context["None"][0],
|
self.assertEqual(self._worker_context["None"][0],
|
||||||
(_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
|
(_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
|
||||||
self.assertEqual(self._worker_context[EVALUATOR][0],
|
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
|
||||||
(cluster_spec[EVALUATOR][0], 3, True, False))
|
|
||||||
|
|
||||||
# Make sure each worker runs a std server.
|
# Make sure each worker runs a std server.
|
||||||
self.assertEqual(len(self._std_servers), 2)
|
self.assertEqual(len(self._std_servers), 1)
|
||||||
self.assertTrue(WORKER in self._std_servers)
|
self.assertTrue(WORKER in self._std_servers)
|
||||||
self.assertTrue(EVALUATOR in self._std_servers)
|
|
||||||
self.assertEqual(len(self._std_servers[WORKER]), 3)
|
self.assertEqual(len(self._std_servers[WORKER]), 3)
|
||||||
self.assertEqual(len(self._std_servers[EVALUATOR]), 1)
|
|
||||||
self.assertFalse(self._std_servers[WORKER][0].joined)
|
self.assertFalse(self._std_servers[WORKER][0].joined)
|
||||||
self.assertTrue(self._std_servers[WORKER][1].joined)
|
self.assertTrue(self._std_servers[WORKER][1].joined)
|
||||||
self.assertTrue(self._std_servers[WORKER][2].joined)
|
self.assertTrue(self._std_servers[WORKER][2].joined)
|
||||||
self.assertFalse(self._std_servers[EVALUATOR][0].joined)
|
|
||||||
|
|
||||||
def testRunStdServerInGoogleEnvironment(self):
|
def testRunStdServerInGoogleEnvironment(self):
|
||||||
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
|
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user