# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes implementing a multi-worker ps DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_LOCAL_CPU = "/device:CPU:0"
_LOCAL_GPU_0 = "/device:GPU:0"


# TODO(yuefengz): maybe cache variables on local CPU.
@tf_export("distribute.experimental.ParameterServerStrategy")
class ParameterServerStrategy(distribute_lib.DistributionStrategy):
  """A parameter server DistributionStrategy.

  This strategy class works for both local training and between-graph
  replicated training for multiple workers. It uses `TFConfigClusterResolver`
  to detect configurations for multi-worker training. In multi-worker training
  mode, i.e. when `TFConfigClusterResolver` has detected the 'TF_CONFIG'
  environment variable and 'TF_CONFIG' has a cluster spec, variables and
  updates to those variables are assigned to parameter servers and other
  operations are assigned to workers. In local training mode, variables are
  assigned to the local CPU or the only GPU. When each worker has more than
  one GPU, operations will be replicated on these GPUs. In both cases,
  operations are replicated but variables are not, and these workers share a
  common view of which parameter server a variable is assigned to.

  This class assumes between-graph replication will be used and works on a
  graph for a particular worker. Note that each graph and worker is
  independent. This means that while each worker will synchronously compute a
  single gradient update across all GPUs, updates between workers proceed
  asynchronously. Operations that occur only on the first replica (such as
  incrementing the global step) will occur on the first replica *of every
  worker*.

  It is expected to call `call_for_each_replica(fn, ...)` for any
  operations which potentially can be replicated across replicas (i.e. multiple
  GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra
  caution needs to be taken:

  1) It is generally not recommended to open a device scope under the
  strategy's scope. A device scope (i.e. calling `tf.device`) will be merged
  with or override the device for operations but will not change the device
  for variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.colocate_with`) under the strategy's scope. For colocating variables,
  use `strategy.extended.colocate_vars_with` instead. Colocation of ops may
  create conflicts of device assignment.
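
  A sketch of one way to use this strategy through the Estimator API is shown
  below; `model_fn` and `input_fn` stand in for user-defined functions and are
  not part of this module. In multi-worker mode, the 'TF_CONFIG' environment
  variable must additionally be set on every task.

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(train_distribute=strategy)
  estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
  estimator.train(input_fn=input_fn, steps=100)
  ```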
  """

  def __init__(self):
    """Initializes this strategy with default TFConfigClusterResolver."""
    super(ParameterServerStrategy, self).__init__(
        ParameterServerStrategyExtended(self))


class ParameterServerStrategyExtended(
    distribute_lib.DistributionStrategyExtended):
  """Implementation of ParameterServerStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=TFConfigClusterResolver()):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(cluster_resolver)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self, cluster_resolver):
    if cluster_resolver.cluster_spec().as_dict():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(cluster_resolver)
    # Save the num_gpus_per_worker for the configure method.
    self._num_gpus_per_worker = cluster_resolver.num_accelerators()

  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of `ClusterResolver`.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
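    # For orientation, a sketch of a 'TF_CONFIG' value that, via
    # `TFConfigClusterResolver`, leads to this code path; the resolver parses
    # it into the cluster spec, task type and task id used below. The
    # addresses are placeholders:
    #
    #   {
    #     "cluster": {
    #       "chief": ["chief.example.com:2222"],
    #       "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
    #       "ps": ["ps0.example.com:2222", "ps1.example.com:2222"]
    #     },
    #     "task": {"type": "worker", "index": 1}
    #   }
    #
    # With the two "ps" tasks above, `replica_device_setter` places variables
    # on /job:ps/task:0, /job:ps/task:1, /job:ps/task:0, ... in round-robin
    # order.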
    num_gpus = cluster_resolver.num_accelerators()
    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(worker_device)

    # Define compute devices, a tuple of device strings with one entry per
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus))
    else:
      compute_devices = (worker_device,)

    self._device_map = values.ReplicaDeviceMap(compute_devices)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(worker_device, compute_devices)])

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        cluster=cluster_spec)

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end
    # up on other workers.
    self._default_device = worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, device_map = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._device_map,
        self._variable_device)

  def _initialize_local(self, cluster_resolver):
    """Initializes internal devices for local training."""
    worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(worker_device)
    num_gpus = cluster_resolver.num_accelerators()
    # Define compute devices, a tuple of device strings with one entry per
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    if num_gpus > 0:
      compute_devices = tuple(map("/device:GPU:{}".format, range(num_gpus)))
    else:
      compute_devices = (_LOCAL_CPU,)

    self._device_map = values.ReplicaDeviceMap(compute_devices)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(worker_device, compute_devices)])

    # If there is only one GPU, put everything on that GPU. Otherwise, place
    # variables on CPU.
    if num_gpus == 1:
      assert len(compute_devices) == 1
      self._variable_device = _LOCAL_GPU_0
      self._parameter_devices = (_LOCAL_GPU_0,)
    else:
      self._variable_device = _LOCAL_CPU
      self._parameter_devices = (_LOCAL_CPU,)

    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy with compute_devices = %r, "
        "variable_device = %r", compute_devices, self._variable_device)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context])

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = values.LogicalDeviceSpec(
          device_map=self._device_map, logical_device=0)
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go
  # through this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, *args, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: %s for variable: "
                         "%s" % (aggregation, kwargs["name"]))

      def var_creator(*args, **kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(*args, **kwargs)
        wrapped = values.AggregatingVariable(
            self._container_strategy(), v, aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
    else:
      var_creator = next_creator

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(*args, **kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(*args, **kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(*args, **kwargs)

  def _call_for_each_replica(self, fn, args, kwargs):
    # pylint: disable=protected-access
    return mirrored_strategy._call_for_each_replica(
        self._container_strategy(), self._device_map, fn, args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._input_workers.worker_devices[0]))

  def _reduce_to(self, reduce_op, value, destinations):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs)

  def _select_single_value(self, structured):
    """Selects a single value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, values.Mirrored):
        if len(x.devices) == 1:
          return x.primary
        else:
          raise ValueError(
              "You cannot update variable with a Mirrored object with multiple "
              "components %r when using ParameterServerStrategy. You must "
              "specify a single value or a Mirrored with a single value." % x)
      elif isinstance(x, values.PerReplica):
        raise ValueError(
            "You cannot update variable with a PerReplica object %r when using "
            "ParameterServerStrategy. You must specify a single value or a "
            "Mirrored with a single value." % x)
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, values.AggregatingVariable):
      var = var.get()
    if not isinstance(var, resource_variable_ops.ResourceVariable):
      raise ValueError(
          "You can not update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._unwrap, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._unwrap, result)

  def _unwrap(self, val):
    if isinstance(val, values.DistributedValues):
      return val.values
    return (val,)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class.

    The strategy object will be re-initialized if `cluster_spec` is given but
    was not passed to the constructor.

    Args:
      session_config: not used currently.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configuration.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in the constructor since
      # _configure doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators=self._num_gpus_per_worker)
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
    if self._task_type not in ["chief", "worker"]:
      return updated_config
    del updated_config.device_filters[:]
    updated_config.device_filters.extend(
        ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
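    # For example, the task of type "worker" with id 0 ends up with the
    # device filters ["/job:worker/task:0", "/job:ps"], so its session sees
    # its own devices and the ps devices but not those of other workers.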
    return updated_config

  @property
  def _num_replicas_in_sync(self):
    return self._device_map.num_replicas_in_graph

  @property
  def worker_devices(self):
    return self._device_map.all_devices

  @property
  def worker_devices_by_replica(self):
    return self._device_map.devices_by_replica

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True