Internal refactoring.

PiperOrigin-RevId: 235271388
This commit is contained in:
Rick Chao 2019-02-22 15:35:52 -08:00 committed by TensorFlower Gardener
parent 135fee1685
commit 4af8712749

View File

@ -37,12 +37,16 @@ except ImportError as _error: # pylint: disable=invalid-name
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
original_run_std_server = dc._run_std_server # pylint: disable=protected-access
ASSIGNED_PORTS = set()
lock = threading.Lock()
@ -357,6 +361,22 @@ class MockOsEnv(collections.Mapping):
class IndependentWorkerTestBase(test.TestCase):
"""Testing infra for independent workers."""
def _make_mock_run_std_server(self):
thread_local = threading.local()
def _mock_run_std_server(*args, **kwargs):
ret = original_run_std_server(*args, **kwargs)
# Wait for all std servers to be brought up in order to reduce the chance
# of remote sessions taking local ports that have been assigned to std
# servers. Only call this barrier the first time this function is run for
# each thread.
if not getattr(thread_local, 'server_started', False):
self._barrier.wait()
thread_local.server_started = True
return ret
return _mock_run_std_server
def setUp(self):
self._mock_os_env = MockOsEnv()
self._mock_context = test.mock.patch.object(os, 'environ',