From 4af87127490470cb6464546113a3c59d8167f8ae Mon Sep 17 00:00:00 2001 From: Rick Chao Date: Fri, 22 Feb 2019 15:35:52 -0800 Subject: [PATCH] Internal refactoring. PiperOrigin-RevId: 235271388 --- .../python/multi_worker_test_base.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 7dca13a5b41..d951a8b783e 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -37,12 +37,16 @@ except ImportError as _error: # pylint: disable=invalid-name from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session +from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.estimator import run_config from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import coordinator from tensorflow.python.training import server_lib + +original_run_std_server = dc._run_std_server # pylint: disable=protected-access + ASSIGNED_PORTS = set() lock = threading.Lock() @@ -357,6 +361,22 @@ class MockOsEnv(collections.Mapping): class IndependentWorkerTestBase(test.TestCase): """Testing infra for independent workers.""" + def _make_mock_run_std_server(self): + thread_local = threading.local() + + def _mock_run_std_server(*args, **kwargs): + ret = original_run_std_server(*args, **kwargs) + # Wait for all std servers to be brought up in order to reduce the chance + # of remote sessions taking local ports that have been assigned to std + # servers. Only call this barrier the first time this function is run for + # each thread. + if not getattr(thread_local, 'server_started', False): + self._barrier.wait() + thread_local.server_started = True + return ret + + return _mock_run_std_server + def setUp(self): self._mock_os_env = MockOsEnv() self._mock_context = test.mock.patch.object(os, 'environ',