Skip creating std server for evaluator.

PiperOrigin-RevId: 350173166
Change-Id: I78cc40331574f4fb7cd960e5ea30a77878e0a22d
Authored by Yuefeng Zhou on 2021-01-05 10:45:29 -08:00; committed by TensorFlower Gardener.
Parent: dff50dd3ad
Commit: 36b4ef7a81

2 changed files with 6 additions and 10 deletions

--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py

@@ -169,7 +169,7 @@ class _WorkerContext(object):
   def _get_master_target(self):
     """Return the master target for a task."""
     # If cluster_spec is None or empty, we use local master.
-    if not self._cluster_spec:
+    if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
       return ""

     # If task_type is None, then it is in-graph replicated training. In this
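
For illustration, here is a minimal standalone sketch of the master-target resolution this hunk changes; it is not the actual tf.distribute code, and the helper name and arguments are made up. The point is that an evaluator task now falls back to the local master even when a cluster spec is present.

# Minimal sketch of the new resolution logic; names other than "evaluator"
# and the "" local-master convention are assumptions for this example.
def get_master_target(cluster_spec, task_type, task_id, rpc_layer="grpc"):
  # Evaluator tasks (or no cluster at all) use the local master, so no
  # std server is needed for them.
  if not cluster_spec or task_type == "evaluator":
    return ""
  address = cluster_spec[task_type][task_id]
  return "%s://%s" % (rpc_layer, address) if rpc_layer else address

# The evaluator ignores its own cluster entry; workers still resolve theirs.
spec = {"worker": ["localhost:12345"], "evaluator": ["localhost:23456"]}
assert get_master_target(spec, "evaluator", 0) == ""
assert get_master_target(spec, "worker", 0) == "grpc://localhost:12345"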
@@ -842,7 +842,8 @@ def run_distribute_coordinator(worker_fn,
                                            session_config, cluster_spec,
                                            task_type, task_id)
-    if not getattr(strategy.extended, "_std_server_started", False):
+    if (task_type != _TaskType.EVALUATOR and
+        not getattr(strategy.extended, "_std_server_started", False)):
       # Right now, with eager mode, context is configured with a std server at
       # the very beginning while with graph mode the std server is started when
       # distribute coordinator is called. We should consolidate these two paths.
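
The guard above can be read as the following sketch, with stand-in names (the real logic lives in run_distribute_coordinator): a std server is only started for non-evaluator tasks whose strategy has not already started one.

# Sketch of the std-server guard; "start_server_fn" and the boolean return
# value are illustrative, not part of the TensorFlow API.
def maybe_start_std_server(strategy_extended, task_type, start_server_fn):
  already_started = getattr(strategy_extended, "_std_server_started", False)
  if task_type != "evaluator" and not already_started:
    # Chief/worker/ps tasks still get a std server; the evaluator is
    # skipped because it now uses the local master (see the hunk above).
    start_server_fn()
    return True
  return False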

--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py

@@ -589,8 +589,7 @@ class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
     # and distributed_mode.
     self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
         _bytes_to_str(self._workers[0].target)), 3, True, True))
-    self.assertEqual(self._worker_context[EVALUATOR][0],
-                     ("fake_evaluator", 3, True, False))
+    self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))


 class DistributeCoordinatorTestIndependentWorkerMode(
@@ -755,19 +754,15 @@ class DistributeCoordinatorTestIndependentWorkerMode(
     # and distributed_mode.
     self.assertEqual(self._worker_context["None"][0],
                      (_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
-    self.assertEqual(self._worker_context[EVALUATOR][0],
-                     (cluster_spec[EVALUATOR][0], 3, True, False))
+    self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))

     # Make sure each worker runs a std server.
-    self.assertEqual(len(self._std_servers), 2)
+    self.assertEqual(len(self._std_servers), 1)
     self.assertTrue(WORKER in self._std_servers)
-    self.assertTrue(EVALUATOR in self._std_servers)
     self.assertEqual(len(self._std_servers[WORKER]), 3)
-    self.assertEqual(len(self._std_servers[EVALUATOR]), 1)
     self.assertFalse(self._std_servers[WORKER][0].joined)
     self.assertTrue(self._std_servers[WORKER][1].joined)
     self.assertTrue(self._std_servers[WORKER][2].joined)
-    self.assertFalse(self._std_servers[EVALUATOR][0].joined)

   def testRunStdServerInGoogleEnvironment(self):
     cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
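
Taken together, the updated tests pin down the new contract. The snippet below restates it outside the test harness; the layout of the context tuple is inferred from the tests' trailing comment and should be treated as an assumption.

# Post-change expectations, expressed as a standalone check. The context
# tuple is assumed to be (master_target, num_workers, is_chief,
# distributed_mode), per the "# and distributed_mode." comment above.
def check_expectations(worker_context, std_servers):
  assert worker_context["evaluator"][0] == ("", 3, True, False)
  assert set(std_servers) == {"worker"}   # no evaluator std server anymore
  assert len(std_servers["worker"]) == 3  # one std server per worker task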