# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A component for running distributed TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import os
import threading
import time

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib


_thread_local = threading.local()


class _TaskType(object):
  PS = "ps"
  WORKER = "worker"
  CHIEF = "chief"
  EVALUATOR = "evaluator"
  CLIENT = "client"


# TODO(yuefengz): support another mode where the client colocates with one
# worker.
class CoordinatorMode(object):
  """Specify how distribute coordinator runs."""
  # The default mode where distribute coordinator will run as a standalone
  # client and connects to remote servers for training. Each remote server can
  # use the distribute coordinator binary with task_type set correctly which
  # will then turn into standard servers.
  STANDALONE_CLIENT = "standalone_client"

  # The distribute coordinator runs on each worker. It will run a standard
  # server on each worker and optionally run the `worker_fn` that is configured
  # to talk to its standard server.
  INDEPENDENT_WORKER = "independent_worker"

class _Barrier(object):
  """A reusable barrier class for worker synchronization."""

  def __init__(self, num_participants):
    """Initializes the barrier object.

    Args:
      num_participants: an integer which is the expected number of calls to
        `wait` that pass through this barrier.
    """
    self._num_participants = num_participants
    self._counter = 0
    self._flag = False
    self._local_sense = threading.local()
    self._lock = threading.Lock()
    self._condition = threading.Condition()

  def wait(self):
    """Waits until all other callers reach the same wait call."""
    self._local_sense.value = not self._flag
    with self._lock:
      self._counter += 1
      if self._counter == self._num_participants:
        self._counter = 0
        self._flag = self._local_sense.value
    with self._condition:
      while self._flag != self._local_sense.value:
        self._condition.wait()
      self._condition.notify_all()

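# Example (illustrative sketch, not part of the original module): `_Barrier` is
# a sense-reversing barrier, so a single instance can be reused across rounds.
# The participant count and the function name below are hypothetical.
#
#   barrier = _Barrier(num_participants=3)
#
#   def _example_participant():
#     # ... per-worker work for this round ...
#     barrier.wait()  # blocks until all 3 participants have called wait()
#     # ... all participants proceed together into the next round ...
#
#   threads = [threading.Thread(target=_example_participant) for _ in range(3)]
#   for t in threads:
#     t.start()
#   for t in threads:
#     t.join()
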
def _get_num_workers(cluster_spec):
  """Gets number of workers including chief."""
  if not cluster_spec:
    return 0
  return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
      cluster_spec.as_dict().get(_TaskType.CHIEF, []))

class _WorkerContext(object):
  """The worker context class.

  This context object provides configuration information for each task. One
  context manager with a worker context object is created per invocation of
  the `worker_fn`, inside which `get_current_worker_context` can be called to
  access the worker context object.
  """

  def __init__(self,
               strategy,
               cluster_spec,
               task_type,
               task_id,
               session_config=None,
               rpc_layer="grpc",
               worker_barrier=None):
    """Initialize the worker context object.

    Args:
      strategy: a `DistributionStrategy` object.
      cluster_spec: a ClusterSpec object. It can be empty or None in the local
        training case.
      task_type: a string indicating the role of the corresponding task, such
        as "worker" or "ps". It can be None if it is local training or in-graph
        replicated training.
      task_id: an integer indicating id of the corresponding task. It can be
        None if it is local training or in-graph replicated training.
      session_config: an optional `tf.compat.v1.ConfigProto` object.
      rpc_layer: optional string specifying the RPC protocol for communication
        with worker masters. If None or empty, hosts in the `cluster_spec` will
        be used directly.
      worker_barrier: optional, the barrier object for worker synchronization.
    """
    self._strategy = strategy
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._session_config = session_config
    self._worker_barrier = worker_barrier
    self._rpc_layer = rpc_layer
    self._master_target = self._get_master_target()
    self._num_workers = _get_num_workers(cluster_spec)
    self._is_chief_node = self._is_chief()

  def _debug_message(self):
    if self._cluster_spec:
      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
          self._cluster_spec, self.task_type, self.task_id)
    else:
      return "[local]"

  def __enter__(self):
    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = self

  def __exit__(self, unused_exception_type, unused_exception_value,
               unused_traceback):
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = None

  def _get_master_target(self):
    """Return the master target for a task."""
    # If cluster_spec is None or empty, we use local master.
    if not self._cluster_spec:
      return ""

    # If task_type is None, then it is in-graph replicated training. In this
    # case we use the chief or first worker's master target.
    if not self._task_type:
      if _TaskType.CHIEF in self._cluster_spec.jobs:
        task_type = _TaskType.CHIEF
        task_id = 0
      else:
        assert _TaskType.WORKER in self._cluster_spec.jobs
        task_type = _TaskType.WORKER
        task_id = 0
    else:
      task_type = self._task_type
      task_id = self._task_id

    prefix = ""
    if self._rpc_layer:
      prefix = self._rpc_layer + "://"
    return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]

  def _is_chief(self):
    """Return whether the task is the chief worker."""
    if (not self._cluster_spec or
        self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]):
      return True

    # If not local and chief not in the cluster_spec, use the first worker as
    # chief.
    if (_TaskType.CHIEF not in self._cluster_spec.jobs and
        self._task_type == _TaskType.WORKER and self._task_id == 0):
      return True
    return False

  def wait_for_other_workers(self):
    """Waits for other workers to reach the same call to this method.

    Raises:
      ValueError: if `worker_barrier` is not passed to the __init__ method.
    """
    if not self._worker_barrier:
      # TODO(yuefengz): we should throw an error in independent worker mode.
      return
    self._worker_barrier.wait()

  def session_creator(self,
                      scaffold=None,
                      config=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      max_wait_secs=7200):
    """Returns a session creator.

    The returned session creator will be configured with the correct master
    target and session configs. It will also run either init ops or ready ops
    by querying the `strategy` object when `create_session` is called on it.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the
        graph.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint
        file. Only one of `checkpoint_dir` or `checkpoint_filename_with_path`
        can be specified.
      max_wait_secs: Maximum time to wait for the session to become available.

    Returns:
      a descendant of SessionCreator.
    """
    if config:
      session_config = copy.deepcopy(config)
      session_config.MergeFrom(self._session_config)
    else:
      session_config = self._session_config

    if not self._strategy or self._strategy.extended.experimental_should_init:
      logging.info("Creating chief session creator with config: %r", config)
      return monitored_session.ChiefSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
      logging.info("Creating worker session creator with config: %r", config)
      return monitored_session.WorkerSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          max_wait_secs=max_wait_secs)

  @property
  def session_config(self):
    return copy.deepcopy(self._session_config)

  @property
  def has_barrier(self):
    """Whether the barrier is set or not."""
    return self._worker_barrier is not None

  @property
  def distributed_mode(self):
    """Whether it is distributed training or not."""
    return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR

  @property
  def cluster_spec(self):
    """Returns a copy of the cluster_spec object."""
    return copy.deepcopy(self._cluster_spec)

  @property
  def task_type(self):
    """Returns the role of the corresponding task."""
    return self._task_type

  @property
  def task_id(self):
    """Returns the id or index of the corresponding task."""
    return self._task_id

  @property
  def master_target(self):
    """Returns the session master for the corresponding task to connect to."""
    return self._master_target

  @property
  def is_chief(self):
    """Returns whether the task is a chief node."""
    return self._is_chief_node

  @property
  def num_workers(self):
    """Returns number of workers in the cluster, including chief."""
    return self._num_workers

  @property
  def experimental_should_init(self):
    """Whether to run init ops."""
    return self._strategy.extended.experimental_should_init

  @property
  def should_checkpoint(self):
    """Whether to save checkpoint."""
    return self._strategy.extended.should_checkpoint

  @property
  def should_save_summary(self):
    """Whether to save summaries."""
    return self._strategy.extended.should_save_summary

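# Example (illustrative sketch, not part of the original module): a typical
# between-graph `worker_fn` reads its configuration from the worker context
# and builds a `MonitoredSession` from the context's session creator. The
# graph-building helper `_build_train_op` below is hypothetical.
#
#   def example_worker_fn(strategy):
#     context = distribute_coordinator_context.get_current_worker_context()
#     train_op = _build_train_op(strategy)  # hypothetical model/train graph
#     with monitored_session.MonitoredSession(
#         session_creator=context.session_creator(), hooks=[]) as sess:
#       while not sess.should_stop():
#         sess.run(train_op)
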
def _run_single_worker(worker_fn,
                       strategy,
                       cluster_spec,
                       task_type,
                       task_id,
                       session_config,
                       rpc_layer="",
                       worker_barrier=None,
                       coord=None):
  """Runs a single worker by calling `worker_fn` under context."""
  session_config = copy.deepcopy(session_config)
  strategy = copy.deepcopy(strategy)
  # If there is an EVALUATOR task, we run single-machine eval on that task.
  if task_type == _TaskType.EVALUATOR:
    # It is possible to not have a strategy object for EVALUATOR task.
    if strategy:
      strategy.configure(session_config)
  else:
    assert strategy
    strategy.configure(session_config, cluster_spec, task_type, task_id)

  context = _WorkerContext(
      strategy,
      cluster_spec,
      task_type,
      task_id,
      session_config=session_config,
      rpc_layer=rpc_layer,
      worker_barrier=worker_barrier)
  with context:
    if coord:
      with coord.stop_on_exception():
        return worker_fn(strategy)
    else:
      return worker_fn(strategy)

def _split_cluster_for_evaluator(cluster_spec, task_type):
  """Split the cluster for evaluator since it needn't talk to other tasks."""
  # Splitting the cluster is important to prevent the evaluator from talking to
  # other tasks in the cluster. Since we allow the evaluator not to use
  # distribution strategies, ops in the evaluator task may have unspecified
  # devices. Those ops may end up on other tasks if we don't split the cluster.
  # Note: if you bypass distribute coordinator and bring the cluster yourself,
  # you can equivalently set device filters to split clusters. This is already
  # done by distribution strategy's `update_config_proto` method.
  new_cluster_spec = multi_worker_util.normalize_cluster_spec(
      cluster_spec).as_dict()
  if task_type == _TaskType.EVALUATOR:
    assert _TaskType.EVALUATOR in new_cluster_spec
    new_cluster_spec = {
        _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
    }
  else:
    new_cluster_spec.pop(_TaskType.EVALUATOR, None)
  return multi_worker_util.normalize_cluster_spec(new_cluster_spec)

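# Example (illustrative sketch, not part of the original module): with a
# hypothetical cluster
#
#   cluster = {"chief": ["host0:2222"],
#              "worker": ["host1:2222", "host2:2222"],
#              "evaluator": ["host3:2222"]}
#
# `_split_cluster_for_evaluator(cluster, "evaluator")` keeps only
# {"evaluator": ["host3:2222"]}, while any other task_type (e.g. "worker")
# gets the cluster with the "evaluator" job removed.
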
def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
                    rpc_layer=None,
                    environment=None):
  """Runs a standard server."""
  # Check if the Server is already running. If so, assert that no configuration
  # options have changed, and return the existing Server. This allows us to
  # call `run_distribute_coordinator` multiple times.
  if getattr(_thread_local, "server", None) is not None:
    assert _thread_local.cluster_spec == cluster_spec
    assert _thread_local.task_type == task_type
    assert _thread_local.task_id == task_id
    assert _thread_local.session_config_str == repr(session_config)
    assert _thread_local.rpc_layer == rpc_layer
    assert _thread_local.environment == environment
    return _thread_local.server
  else:
    # This method is not thread-safe.
    _thread_local.server_started = True
    _thread_local.cluster_spec = cluster_spec
    _thread_local.task_type = task_type
    _thread_local.task_id = task_id
    _thread_local.session_config_str = repr(session_config)
    _thread_local.rpc_layer = rpc_layer
    _thread_local.environment = environment

  assert cluster_spec
  target = cluster_spec.task_address(task_type, task_id)
  if rpc_layer:
    target = rpc_layer + "://" + target

  class _FakeServer(object):
    """A fake server that runs a master session."""

    def start(self):
      # A tensorflow server starts when a remote session is created.
      logging.info(
          "Creating a remote session to start a TensorFlow server, "
          "target = %r, session_config=%r", target, session_config)
      session.Session(target=target, config=session_config)

    def join(self):
      while True:
        time.sleep(5)

  if environment == "google":
    server = _FakeServer()
  else:
    if session_config:
      logging.info(
          "Starting standard TensorFlow server, target = %r, session_config= "
          "%r", target, session_config)
    else:
      logging.info("Starting standard TensorFlow server, target = %r", target)
    cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
    server = server_lib.Server(
        cluster_spec,
        job_name=task_type,
        task_index=task_id,
        config=session_config,
        protocol=rpc_layer)

  server.start()
  _thread_local.server = server
  return server

def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                              cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for between-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  threads = []
  worker_barrier = _Barrier(_get_num_workers(cluster_spec))
  for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      t = threading.Thread(
          target=_run_single_worker,
          args=(worker_fn, strategy, cluster_spec, task_type, task_id,
                session_config),
          kwargs={
              "rpc_layer": rpc_layer,
              "worker_barrier": worker_barrier,
              "coord": coord,
          })
      t.start()
      threads.append(t)

  if eval_thread:
    # TODO(yuefengz): is it necessary to join eval thread?
    threads_to_join = threads + [eval_thread]
  else:
    threads_to_join = threads
  coord.join(threads_to_join)

  # TODO(yuefengz): we probably want to return results from all workers?
  return None

def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                         cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for in-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  worker_result = _run_single_worker(
      worker_fn,
      strategy,
      cluster_spec,
      None,
      None,
      session_config,
      rpc_layer=rpc_layer,
      coord=coord)

  if eval_thread:
    coord.join([eval_thread])

  return worker_result

def _configure_session_config_for_std_servers(
    strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
  # pylint: disable=g-doc-args
  """Call strategy's `configure` to mutate the session_config.

  The session_config is currently needed as default config for a TensorFlow
  server. In the future, we should be able to remove this method and only pass
  the session config to a client session.
  """
  if task_type == _TaskType.EVALUATOR:
    if eval_strategy:
      eval_strategy.configure(session_config=session_config)
  else:
    # The strategy may be shared in standalone client mode.
    strategy = copy.deepcopy(strategy)
    strategy.configure(
        session_config=session_config,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id)
  # Remove the device filters specific to the strategy, so that the
  # TensorFlow server brought up with one strategy can be used by other
  # strategies. The device filters can be set in the client side as well.
  del session_config.device_filters[:]

def run_standard_tensorflow_server(session_config=None):
  """Starts a standard TensorFlow server.

  This method parses configurations from the "TF_CONFIG" environment variable
  and starts a TensorFlow server. The "TF_CONFIG" is typically a json string
  and must contain information about the cluster and the role of the server in
  the cluster. One example is:

  TF_CONFIG='{
      "cluster": {
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 1}
  }'

  This "TF_CONFIG" specifies that there are 3 workers and 2 ps tasks in the
  cluster and the current role is worker 1.

  Valid task types are "chief", "worker", "ps" and "evaluator" and you can have
  at most one "chief" and at most one "evaluator".

  An optional key-value pair that can be specified is "rpc_layer". The default
  value is "grpc".

  Args:
    session_config: an optional `tf.compat.v1.ConfigProto` object. Users can
      pass in the session config object to configure server-local devices.

  Returns:
    a `tf.distribute.Server` object which has already been started.

  Raises:
    ValueError: if the "TF_CONFIG" environment is not complete.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if "cluster" not in tf_config:
    raise ValueError("\"cluster\" is not found in TF_CONFIG.")
  cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
  if "task" not in tf_config:
    raise ValueError("\"task\" is not found in TF_CONFIG.")
  task_env = tf_config["task"]
  if "type" not in task_env:
    raise ValueError(
        "\"task_type\" is not found in the `task` part of TF_CONFIG.")
  task_type = task_env["type"]
  task_id = int(task_env.get("index", 0))

  rpc_layer = tf_config.get("rpc_layer", "grpc")

  session_config = session_config or config_pb2.ConfigProto()
  # Set the collective group leader for collective ops to initialize collective
  # ops when the server starts.
  if "chief" in cluster_spec.jobs:
    session_config.experimental.collective_group_leader = (
        "/job:chief/replica:0/task:0")
  else:
    if "worker" not in cluster_spec.jobs:
      raise ValueError(
          "You must have `chief` or `worker` jobs in the `cluster_spec`.")
    session_config.experimental.collective_group_leader = (
        "/job:worker/replica:0/task:0")

  server = _run_std_server(
      cluster_spec=cluster_spec,
      task_type=task_type,
      task_id=task_id,
      session_config=session_config,
      rpc_layer=rpc_layer)
  server.start()
  return server

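# Example (illustrative sketch, not part of the original module): a worker or
# ps binary could bring up its standard server from "TF_CONFIG" roughly like
# this. The host addresses below are hypothetical.
#
#   os.environ["TF_CONFIG"] = json.dumps({
#       "cluster": {
#           "worker": ["host1:2222", "host2:2222"],
#           "ps": ["host3:2222"],
#       },
#       "task": {"type": "worker", "index": 0},
#   })
#   server = run_standard_tensorflow_server()
#   server.join()  # block forever, serving the worker 0 task
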
# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
                               strategy,
                               eval_fn=None,
                               eval_strategy=None,
                               mode=CoordinatorMode.STANDALONE_CLIENT,
                               cluster_spec=None,
                               task_type=None,
                               task_id=None,
                               session_config=None,
                               rpc_layer="grpc"):
  """Runs the coordinator for distributed TensorFlow.

  This function runs a split coordinator for distributed TensorFlow in its
  default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
  specifying server addresses and their roles in a cluster, this coordinator
  will figure out how to set them up, give the underlying function the right
  targets for master sessions via a scope object and coordinate their training.
  The cluster consisting of standard servers needs to be brought up either with
  the standard server binary or with a binary running distribute coordinator
  with `task_type` set to non-client type which will then turn into standard
  servers.

  In addition to being the distribute coordinator, this is also the source of
  configurations for each job in the distributed training. As there are
  multiple ways to configure a distributed TensorFlow cluster, its context
  object provides these configurations so that users or higher-level APIs don't
  have to figure out the configuration for each job by themselves.

  In the between-graph replicated training, this coordinator will create
  multiple threads, each of which calls the `worker_fn` which is supposed to
  create its own graph and connect to one worker master given by its context
  object. In the in-graph replicated training, it has only one thread calling
  this `worker_fn`.

  Another mode is the INDEPENDENT_WORKER mode where each server runs a
  distribute coordinator which will start a standard server and optionally run
  `worker_fn` depending on whether it is between-graph training or in-graph
  replicated training.

  The `strategy` object is expected to be a DistributionStrategy object which
  has implemented methods needed by the distribute coordinator such as
  `configure(session_config, cluster_spec, task_type, task_id)` which
  configures the strategy object for a specific task and an
  `experimental_should_init` property which instructs the distribute
  coordinator whether to run init ops for a task. The distribute coordinator
  will make a copy of the `strategy` object, call its `configure` method and
  pass it to `worker_fn` as an argument.

  The `worker_fn` defines the training logic and is called under its own
  worker context which can be accessed via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
  task_type, task_id, master target and so on. Since `worker_fn` will be called
  in a thread and possibly multiple times, the caller should be careful when it
  accesses global data. For example, it is unsafe to define flags in a
  `worker_fn` or to define different environment variables for different
  `worker_fn`s.

  The `worker_fn` for the between-graph replication is defined as if there is
  only one worker corresponding to the `worker_fn` and possibly ps jobs. For
  example, when training with parameter servers, it assigns variables to
  parameter servers and all other operations to that worker. In the in-graph
  replication case, the `worker_fn` has to define operations for all worker
  jobs. Using a distribution strategy can simplify the `worker_fn` by not
  having to worry about the replication and device assignment of variables and
  operations.

  This method is intended to be invoked by high-level APIs so that users don't
  have to explicitly call it to run this coordinator. For those who don't use
  high-level APIs, to change a program to use this coordinator, wrap everything
  in the program after global data definitions (such as command-line flag
  definitions) into the `worker_fn` and get task-specific configurations from
  the worker context.

  The `cluster_spec` can be either passed by the argument or parsed from the
  "TF_CONFIG" environment variable. Example of a TF_CONFIG:
  ```
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
  ```

  If `cluster_spec` is not given in any format, it becomes local training and
  this coordinator will connect to a local session.

  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
  will be created to call `eval_fn` with its `task_type` set to "evaluator". If
  `eval_fn` is not defined, fall back to `worker_fn`. This implies that
  evaluation will be done on a single machine if there is an "evaluator" task.
  If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the
  `worker_fn` for how to do evaluation.

  Args:
    worker_fn: the function to be called. The function should accept a
      `strategy` object and will be given access to a context object via a
      context manager scope.
    strategy: a DistributionStrategy object specifying whether it should
      run between-graph replicated training or not, whether to run init ops,
      etc. This object will also be configured given `session_config`,
      `cluster_spec`, `task_type` and `task_id`.
    eval_fn: optional function for the "evaluator" task. If `eval_fn` is not
      passed in but an "evaluator" task is found in the `cluster_spec`, the
      `worker_fn` will be used for this task.
    eval_strategy: optional DistributionStrategy object for the "evaluator"
      task.
    mode: in which mode this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
      in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
    session_config: an optional `tf.compat.v1.ConfigProto` object which will be
      passed to `strategy`'s `configure` method and used to create a session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
    ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or
      a ClusterSpec.

  Returns:
    In the client job, return the value returned by `worker_fn` if
    it is in-graph replication or INDEPENDENT_WORKER mode; return None
    otherwise.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
  environment = tf_config.get("environment", None)

  if not cluster_spec:
    cluster_spec = tf_config.get("cluster", {})
    task_env = tf_config.get("task", {})
    if task_env:
      task_type = task_env.get("type", task_type)
      task_id = int(task_env.get("index", task_id))

  if cluster_spec:
    # TODO(yuefengz): validate cluster_spec.
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  elif hasattr(strategy.extended, "_cluster_resolver"):
    cluster_resolver = strategy.extended._cluster_resolver  # pylint: disable=protected-access
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    rpc_layer = cluster_resolver.rpc_layer or rpc_layer
    environment = cluster_resolver.environment
    cluster_spec = cluster_resolver.cluster_spec()

  # Setting the session config is necessary for some strategies such as
  # CollectiveAllReduceStrategy.
  session_config = session_config or config_pb2.ConfigProto(
      allow_soft_placement=True)

  if cluster_spec:
    logging.info(
        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)

  if not cluster_spec:
    # `mode` is ignored in the local case.
    logging.info("Running local Distribute Coordinator.")
    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                       rpc_layer)
    if eval_fn:
      _run_single_worker(eval_fn, eval_strategy, None, None, None,
                         session_config, rpc_layer)
    else:
      logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
  elif mode == CoordinatorMode.STANDALONE_CLIENT:
    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # The client must know the cluster but servers in the cluster don't have to
    # know the client.
    if task_type in [_TaskType.CLIENT, None]:
      if strategy.extended.experimental_between_graph:
        return _run_between_graph_client(worker_fn, strategy, eval_fn,
                                         eval_strategy, cluster_spec,
                                         session_config, rpc_layer)
      else:
        return _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                                    cluster_spec, session_config, rpc_layer)
    else:
      # If not a client job, run the standard server.
      _configure_session_config_for_std_servers(strategy, eval_strategy,
                                                session_config, cluster_spec,
                                                task_type, task_id)
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
      server.join()
  else:
    if mode != CoordinatorMode.INDEPENDENT_WORKER:
      raise ValueError("Unexpected coordinator mode: %r" % mode)

    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # Every one starts a standard server, get session config from `configure`
    # method.
    _configure_session_config_for_std_servers(strategy, eval_strategy,
                                              session_config, cluster_spec,
                                              task_type, task_id)

    if not getattr(strategy.extended, "_std_server_started", False):
      # Right now, with eager mode, context is configured with a std server at
      # the very beginning while with graph mode the std server is started when
      # distribute coordinator is called. We should consolidate these two
      # paths.
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
      if strategy.extended.experimental_between_graph:
        # All jobs run `worker_fn` if between-graph.
        return _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
                                  task_id, session_config, rpc_layer)
      else:
        # Only one node runs `worker_fn` if in-graph.
        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
        if context.is_chief:
          return _run_single_worker(worker_fn, strategy, cluster_spec, None,
                                    None, session_config, rpc_layer)
        else:
          server.join()
    elif task_type == _TaskType.EVALUATOR:
      return _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
                                task_id, session_config, rpc_layer)
    else:
      if task_type != _TaskType.PS:
        raise ValueError("Unexpected task_type: %r" % task_type)
      server.join()
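

# Example (illustrative sketch, not part of the original module): running the
# coordinator in the default standalone-client mode. `example_worker_fn` and
# the host addresses are hypothetical; `strategy` is assumed to be a
# DistributionStrategy instance supporting between-graph replication.
#
#   cluster = {"chief": ["host0:2222"],
#              "worker": ["host1:2222", "host2:2222"]}
#
#   def example_worker_fn(strategy):
#     context = distribute_coordinator_context.get_current_worker_context()
#     logging.info("Running worker %s/%s against %s", context.task_type,
#                  context.task_id, context.master_target)
#     # ... build the graph and train using context.session_creator() ...
#
#   run_distribute_coordinator(
#       example_worker_fn,
#       strategy,
#       mode=CoordinatorMode.STANDALONE_CLIENT,
#       cluster_spec=cluster,
#       rpc_layer="grpc")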