Internal refactoring.
PiperOrigin-RevId: 235271388
This commit is contained in:
parent
135fee1685
commit
4af8712749
@ -37,12 +37,16 @@ except ImportError as _error: # pylint: disable=invalid-name
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.estimator import run_config
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
original_run_std_server = dc._run_std_server # pylint: disable=protected-access
|
||||
|
||||
ASSIGNED_PORTS = set()
|
||||
lock = threading.Lock()
|
||||
|
||||
@ -357,6 +361,22 @@ class MockOsEnv(collections.Mapping):
|
||||
class IndependentWorkerTestBase(test.TestCase):
|
||||
"""Testing infra for independent workers."""
|
||||
|
||||
def _make_mock_run_std_server(self):
|
||||
thread_local = threading.local()
|
||||
|
||||
def _mock_run_std_server(*args, **kwargs):
|
||||
ret = original_run_std_server(*args, **kwargs)
|
||||
# Wait for all std servers to be brought up in order to reduce the chance
|
||||
# of remote sessions taking local ports that have been assigned to std
|
||||
# servers. Only call this barrier the first time this function is run for
|
||||
# each thread.
|
||||
if not getattr(thread_local, 'server_started', False):
|
||||
self._barrier.wait()
|
||||
thread_local.server_started = True
|
||||
return ret
|
||||
|
||||
return _mock_run_std_server
|
||||
|
||||
def setUp(self):
|
||||
self._mock_os_env = MockOsEnv()
|
||||
self._mock_context = test.mock.patch.object(os, 'environ',
|
||||
|
Loading…
Reference in New Issue
Block a user