Add environment and rpc_layer to the TF_CONFIG environment variable in distribute coordinator.

PiperOrigin-RevId: 210197404
This commit is contained in:
Yuefeng Zhou 2018-08-24 20:49:01 -07:00 committed by TensorFlower Gardener
parent ca94990804
commit 04ffe2f349
2 changed files with 117 additions and 12 deletions

View File

@ -22,9 +22,12 @@ import copy
import json
import os
import threading
import time
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@ -332,16 +335,38 @@ def _run_std_server(cluster_spec=None,
task_type=None,
task_id=None,
session_config=None,
rpc_layer=None):
rpc_layer=None,
environment=None):
"""Runs a standard server."""
server = server_lib.Server(
cluster_spec,
job_name=task_type,
task_index=task_id,
config=session_config,
protocol=rpc_layer)
server.start()
return server
class _FakeServer(object):
"""A fake server that runs a master session."""
def start(self):
assert cluster_spec
target = cluster_spec.task_address(task_type, task_id)
if rpc_layer:
target = rpc_layer + "://" + target
# A tensorflow server starts when a remote session is created.
session.Session(target=target, config=session_config)
def join(self):
while True:
time.sleep(5)
if environment == "google":
server = _FakeServer()
server.start()
return server
else:
server = server_lib.Server(
cluster_spec,
job_name=task_type,
task_index=task_id,
config=session_config,
protocol=rpc_layer)
server.start()
return server
def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
@ -541,8 +566,18 @@ def run_distribute_coordinator(worker_fn,
"`tf.train.ClusterDef` object")
# TODO(yuefengz): validate cluster_spec.
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
environment = tf_config.get("environment", None)
if cluster_spec:
logging.info(
"Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
"task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)
if not cluster_spec:
# `mode` is ignored in the local case.
logging.info("Running local Distribute Coordinator.")
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
if eval_fn:
@ -564,7 +599,11 @@ def run_distribute_coordinator(worker_fn,
else:
# If not a client job, run the standard server.
server = _run_std_server(
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
rpc_layer=rpc_layer,
environment=environment)
server.join()
else:
if mode != CoordinatorMode.INDEPENDENT_WORKER:
@ -575,7 +614,11 @@ def run_distribute_coordinator(worker_fn,
# Every one starts a standard server.
server = _run_std_server(
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
rpc_layer=rpc_layer,
environment=environment)
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
if strategy.between_graph:

View File

@ -20,8 +20,10 @@ from __future__ import print_function
import contextlib
import copy
import json
import os
import sys
import time
import threading
import six
@ -59,6 +61,8 @@ INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
NUM_WORKERS = 3
NUM_PS = 2
original_sys_exit = sys.exit
def _bytes_to_str(maybe_bytes):
if isinstance(maybe_bytes, six.string_types):
@ -369,7 +373,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
cluster_spec=None,
task_type=None,
task_id=None,
rpc_layer=None):
rpc_layer=None,
environment=None):
task_type = str(task_type)
task_id = task_id or 0
with self._lock:
@ -730,6 +735,63 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertTrue(self._std_servers[WORKER][2].joined)
self.assertFalse(self._std_servers[EVALUATOR][0].joined)
def testRunStdServerInGoogleEnvironment(self):
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
tf_config = {"cluster": cluster_spec, "environment": "google"}
joined = [False]
def _fake_sleep(_):
joined[0] = True
original_sys_exit(0)
def _thread_fn(cluster_spec):
distribute_coordinator.run_distribute_coordinator(
None,
None,
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="ps",
task_id=0)
with test.mock.patch.dict(
"os.environ",
{"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
time, "sleep", _fake_sleep):
t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
t.start()
t.join()
self.assertTrue(joined[0])
def testRpcLayerEnvironmentVariable(self):
cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}
rpc_layer_from_coordinator = [None]
def _run_mock_server(cluster_spec=None,
task_type=None,
task_id=None,
session_config=None,
rpc_layer=None,
environment=None):
del cluster_spec, task_type, task_id, session_config, environment
rpc_layer_from_coordinator[0] = rpc_layer
return MockServer()
with test.mock.patch.dict(
"os.environ",
{"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
distribute_coordinator, "_run_std_server", _run_mock_server):
distribute_coordinator.run_distribute_coordinator(
None,
None,
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="ps",
task_id=0)
self.assertEqual(rpc_layer_from_coordinator[0], "cake")
if __name__ == "__main__":
# TODO(yuefengz): find a smart way to terminite std server threads.