Add environment and rpc_layer to the TF_CONFIG environment variable in distribute coordinator.
PiperOrigin-RevId: 210197404
commit 04ffe2f349 (parent ca94990804)
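
With this change, the coordinator reads two additional optional keys from the TF_CONFIG environment variable: "rpc_layer", which is forwarded as the protocol of the standard server, and "environment", which selects a fake in-process server when set to "google". A minimal sketch of setting such a TF_CONFIG is below; the cluster addresses, the "task" entry, and the "grpc" value are illustrative placeholders, not part of this commit.

import json
import os

# Hypothetical cluster; the host:port strings are placeholders.
tf_config = {
    "cluster": {
        "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
        "ps": ["ps0.example.com:2222"],
    },
    "task": {"type": "worker", "index": 0},
    "rpc_layer": "grpc",      # picked up via tf_config.get("rpc_layer", rpc_layer)
    "environment": "google",  # makes _run_std_server return the fake server
}
os.environ["TF_CONFIG"] = json.dumps(tf_config)
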
@@ -22,9 +22,12 @@ import copy
 import json
 import os
 import threading
+import time

 from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator_context
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import server_lib

@@ -332,16 +335,38 @@ def _run_std_server(cluster_spec=None,
                     task_type=None,
                     task_id=None,
                     session_config=None,
-                    rpc_layer=None):
+                    rpc_layer=None,
+                    environment=None):
   """Runs a standard server."""
-  server = server_lib.Server(
-      cluster_spec,
-      job_name=task_type,
-      task_index=task_id,
-      config=session_config,
-      protocol=rpc_layer)
-  server.start()
-  return server
+
+  class _FakeServer(object):
+    """A fake server that runs a master session."""
+
+    def start(self):
+      assert cluster_spec
+      target = cluster_spec.task_address(task_type, task_id)
+      if rpc_layer:
+        target = rpc_layer + "://" + target
+      # A tensorflow server starts when a remote session is created.
+      session.Session(target=target, config=session_config)
+
+    def join(self):
+      while True:
+        time.sleep(5)
+
+  if environment == "google":
+    server = _FakeServer()
+    server.start()
+    return server
+  else:
+    server = server_lib.Server(
+        cluster_spec,
+        job_name=task_type,
+        task_index=task_id,
+        config=session_config,
+        protocol=rpc_layer)
+    server.start()
+    return server


 def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
@@ -541,8 +566,18 @@ def run_distribute_coordinator(worker_fn,
                      "`tf.train.ClusterDef` object")
   # TODO(yuefengz): validate cluster_spec.
+
+  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
+  environment = tf_config.get("environment", None)
+
+  if cluster_spec:
+    logging.info(
+        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
+        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
+        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)
+
   if not cluster_spec:
     # `mode` is ignored in the local case.
     logging.info("Running local Distribute Coordinator.")
     _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                        rpc_layer)
     if eval_fn:
@@ -564,7 +599,11 @@ def run_distribute_coordinator(worker_fn,
     else:
       # If not a client job, run the standard server.
       server = _run_std_server(
-          cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+          cluster_spec=cluster_spec,
+          task_type=task_type,
+          task_id=task_id,
+          rpc_layer=rpc_layer,
+          environment=environment)
       server.join()
   else:
     if mode != CoordinatorMode.INDEPENDENT_WORKER:
@@ -575,7 +614,11 @@ def run_distribute_coordinator(worker_fn,

     # Every one starts a standard server.
     server = _run_std_server(
-        cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+        cluster_spec=cluster_spec,
+        task_type=task_type,
+        task_id=task_id,
+        rpc_layer=rpc_layer,
+        environment=environment)

     if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
       if strategy.between_graph:
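
The coordinator now forwards both values to `_run_std_server` at every call site, as shown above. The test changes below exercise this path; for orientation, a rough usage sketch of launching a ps task so the coordinator picks the values up from TF_CONFIG. The import path, addresses, and "grpc" value are assumptions mirroring the tests, not something this commit prescribes.

import json
import os

# Assumed import path for the coordinator module.
from tensorflow.python.distribute import distribute_coordinator

cluster_spec = {"worker": ["localhost:2222"], "ps": ["localhost:2223"]}  # placeholder addresses
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": cluster_spec,
    "rpc_layer": "grpc",  # ends up as _run_std_server(..., rpc_layer="grpc")
})

# For a ps task this blocks in server.join(); the tests below therefore run it
# on a thread and mock out the server or time.sleep.
distribute_coordinator.run_distribute_coordinator(
    None,  # worker_fn is not used for a ps task (the tests pass None)
    None,  # strategy is not used for a ps task (the tests pass None)
    mode=distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER,
    cluster_spec=cluster_spec,
    task_type="ps",
    task_id=0)
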
@@ -20,8 +20,10 @@ from __future__ import print_function

 import contextlib
 import copy
 import json
 import os
 import sys
+import time
+import threading
 import six

@@ -59,6 +61,8 @@ INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
 NUM_WORKERS = 3
 NUM_PS = 2

+original_sys_exit = sys.exit
+

 def _bytes_to_str(maybe_bytes):
   if isinstance(maybe_bytes, six.string_types):
@@ -369,7 +373,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
                           cluster_spec=None,
                           task_type=None,
                           task_id=None,
-                          rpc_layer=None):
+                          rpc_layer=None,
+                          environment=None):
    task_type = str(task_type)
    task_id = task_id or 0
    with self._lock:
@@ -730,6 +735,63 @@ class DistributeCoordinatorTestInpendentWorkerMode(
     self.assertTrue(self._std_servers[WORKER][2].joined)
     self.assertFalse(self._std_servers[EVALUATOR][0].joined)

+  def testRunStdServerInGoogleEnvironment(self):
+    cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
+    tf_config = {"cluster": cluster_spec, "environment": "google"}
+
+    joined = [False]
+
+    def _fake_sleep(_):
+      joined[0] = True
+      original_sys_exit(0)
+
+    def _thread_fn(cluster_spec):
+      distribute_coordinator.run_distribute_coordinator(
+          None,
+          None,
+          mode=INDEPENDENT_WORKER,
+          cluster_spec=cluster_spec,
+          task_type="ps",
+          task_id=0)
+
+    with test.mock.patch.dict(
+        "os.environ",
+        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+            time, "sleep", _fake_sleep):
+      t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
+      t.start()
+      t.join()
+    self.assertTrue(joined[0])
+
+  def testRpcLayerEnvironmentVariable(self):
+    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+    tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}
+
+    rpc_layer_from_coordinator = [None]
+
+    def _run_mock_server(cluster_spec=None,
+                         task_type=None,
+                         task_id=None,
+                         session_config=None,
+                         rpc_layer=None,
+                         environment=None):
+      del cluster_spec, task_type, task_id, session_config, environment
+      rpc_layer_from_coordinator[0] = rpc_layer
+      return MockServer()
+
+    with test.mock.patch.dict(
+        "os.environ",
+        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+            distribute_coordinator, "_run_std_server", _run_mock_server):
+      distribute_coordinator.run_distribute_coordinator(
+          None,
+          None,
+          mode=INDEPENDENT_WORKER,
+          cluster_spec=cluster_spec,
+          task_type="ps",
+          task_id=0)
+    self.assertEqual(rpc_layer_from_coordinator[0], "cake")
+

 if __name__ == "__main__":
   # TODO(yuefengz): find a smart way to terminite std server threads.