Refactor collective utils to operate on a single replica

We used to have build_collective_* utilities that launch a collective on all replicas at once. However, each replica is independent, and the utilities become simpler if they only launch the collective on one replica.

This is mainly to prepare for switching to Collective V2 ops and to reduce the use of merge_call in MirroredStrategy/MultiWorkerMirroredStrategy. For the former, the switch will be easier after this change. For the latter, we can call CollectiveReplicaLauncher.all_reduce directly from ReplicaContext.all_reduce, without needing merge_call in MS/MWMS.

Note that the original CollectiveKeys design is outdated: instance keys are now scoped under each group key and are no longer global.
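For illustration, a minimal sketch of the intended flow after this change (device strings, group size, and the reduced tensors are placeholders; the classes are the ones defined in cross_device_utils in this change):

from tensorflow.python.distribute import cross_device_utils

devices = ('/device:GPU:0', '/device:GPU:1')
collective_keys = cross_device_utils.CollectiveKeys(group_key_start=1)
group_key = collective_keys.get_group_key(devices)

# One launcher per replica: each launcher only launches collectives on its
# own device, and instance keys are allocated per (group_key, device).
launchers = [
    cross_device_utils.CollectiveReplicaLauncher(
        group_key, len(devices), collective_keys, d)
    for d in devices
]
# In graph mode or inside a tf.function, each replica reduces its own tensor:
# reduced_i = launchers[i].all_reduce(per_replica_tensors[i])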

PiperOrigin-RevId: 337007588
Change-Id: I111efad0ecbe829260c8e4d2f555c6b2dc1d4c05
Ran Chen 2020-10-13 19:57:19 -07:00 committed by TensorFlower Gardener
parent c451bbccd9
commit 746af76e4f
9 changed files with 321 additions and 476 deletions


@ -96,13 +96,14 @@ py_library(
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:collective_ops",
"//tensorflow/python:device",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nccl_ops",
"//tensorflow/python:platform",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
],
)


@ -101,6 +101,9 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
# TODO(anjalisridhar): Update our guides with examples showing how we can use
# the cluster_resolver argument.
# The starting number for collective keys. This should only be set in tests.
_collective_key_base = 0
def __init__(
self,
communication=cross_device_ops_lib.CollectiveCommunication.AUTO,
@ -362,7 +365,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
local_devices = (self._worker_device,)
self._collective_keys = cross_device_utils.CollectiveKeys()
self._collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base) # pylint: disable=protected-access
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=local_devices,
group_size=len(local_devices) * self._num_workers,
@ -428,7 +432,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
group_key = self._collective_keys.get_group_key([device])
group_size = self._num_workers
collective_instance_key = (
self._collective_keys.get_variable_instance_key())
self._collective_keys.get_instance_key(group_key, device))
with ops.device(device):
initial_value = kwargs["initial_value"]


@ -29,7 +29,6 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
@ -59,6 +58,8 @@ from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec
CollectiveAllReduceStrategy = (
collective_all_reduce_strategy.CollectiveAllReduceStrategy)
CollectiveAllReduceExtended = (
collective_all_reduce_strategy.CollectiveAllReduceExtended)
@ -93,14 +94,10 @@ def create_test_objects(cluster_spec=None,
class CollectiveAllReduceStrategyTestBase(
multi_worker_test_base.MultiWorkerTestBase):
collective_key_base = 0
def setUp(self):
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, ayushd): enable it to reuse collective keys in different
# tests.
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
CollectiveAllReduceStrategy._collective_key_base += 100000
super(CollectiveAllReduceStrategyTestBase, self).setUp()
def _get_test_object(self, task_type, task_id, num_gpus=0):
@ -109,18 +106,6 @@ class CollectiveAllReduceStrategyTestBase(
task_type=task_type,
task_id=task_id,
num_gpus=num_gpus)
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
op_instance_key_start=100 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
variable_instance_key_start=10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
strategy.extended._host_cross_device_ops._collective_keys = collective_keys
return strategy, target, session_config
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):


@ -36,7 +36,7 @@ from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import executor as executor_lib
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@ -144,8 +144,8 @@ def _normalize_value_destination_pairs(value_destination_pairs):
def _validate_value_destination_pairs(value_destination_pairs):
"""Validates value_destination_pairs are valid."""
# TODO(yuefengz): raise exceptions instead of returning False.
# pylint: disable=g-missing-docstring
if not value_destination_pairs: return False
if not isinstance(value_destination_pairs, (list, tuple)): return False
if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
@ -197,7 +197,7 @@ def simple_broadcast(value, destinations, always_mirrored=False):
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
reduce_op):
# pylint: disable=g-missing-docstring
"""Reduces the value by accumulation_fn and reduce_op."""
all_values = per_replica_value.values
if not all_values:
raise ValueError("`per_replica_value` must be non-empty")
@ -1019,7 +1019,6 @@ class CollectiveAllReduce(CrossDeviceOps):
if group_size % len(devices) > 0:
raise ValueError("group_size must be divisible by the number of devices.")
self._devices = tuple(device_util.canonicalize(d) for d in devices)
self._group_size = group_size
self._collective_keys = (collective_keys or
cross_device_utils.CollectiveKeys())
@ -1039,14 +1038,21 @@ class CollectiveAllReduce(CrossDeviceOps):
# This deadlocks since neither collective is able to finish.
self._lock = threading.Lock()
self._devices = tuple(device_util.canonicalize(d) for d in devices)
group_key = self._collective_keys.get_group_key(self._devices)
# Collective ops requires all devices to participate and is blocking. In
# eager, we need one async executor for each device to be able to launch
# them altogether. Note that async doesn't imply concurrency. Within an
# async executor operations are still executed sequentially. In graph or
# function building, the executors are not used.
self._executors = []
for _ in range(len(devices)):
self._executors.append(executor.new_executor(enable_async=True))
self._launchers = []
for device in self._devices:
executor = executor_lib.new_executor(enable_async=True)
self._executors.append(executor)
launcher = cross_device_utils.CollectiveReplicaLauncher(
group_key, group_size, self._collective_keys, device, executor)
self._launchers.append(launcher)
super(CollectiveAllReduce, self).__init__()
@ -1148,63 +1154,39 @@ class CollectiveAllReduce(CrossDeviceOps):
# queuing time due to concurrent intense computation.
#
# TODO(b/147393503): explore solutions for optimal ordering.
packs = cross_device_utils.pack_by_size(
list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)
values_by_device = [[] for _ in range(len(self._devices))]
for per_replica in reversed(per_replica_values):
for i in range(len(self._devices)):
values_by_device[i].append(per_replica.values[i])
if batch_size > 1:
logging.info(
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
"group_size = %d, communication_hint = %s, num_packs = %d",
batch_size, len(self._devices), self._group_size, communication,
len(packs))
else:
logging.log_first_n(
logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
"num_devices = %d, group_size = %d, communication_hint = %s, "
"num_packs = %d" % (batch_size, len(
self._devices), self._group_size, communication, len(packs)), 10)
reduced_values = []
outputs_by_device = []
with self._lock:
for pack in packs:
# By placing all CollectiveReduce ops in a pack under single name scope,
# we ensure they will be picked up by the `ScopedAllocator` grappler
# optimizer and packed into a single all-reduce.
with ops.name_scope("allreduce"):
for per_replica in pack:
# Add control dependencies per device from the last gradients to the
# current set, in order to serialize NCCL launches.
if (communication == CollectiveCommunication.NCCL.value and
reduced_values):
control_inputs = list(reduced_values[-1])
else:
control_inputs = None
reduced_values.append(
cross_device_utils.build_collective_reduce(
per_replica.values,
self._devices,
self._group_size,
self._collective_keys,
"Add",
"Id",
communication,
control_inputs,
executors=self._executors,
timeout=experimental_hints.timeout_seconds))
for i in range(len(self._devices)):
packs = cross_device_utils.group_by_size(
values_by_device[i], experimental_hints.bytes_per_pack)
if not context.executing_eagerly() and i == 0:
logging.info(
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
"group_size = %d, communication_hint = %s, num_packs = %d",
batch_size, len(self._launchers), self._group_size, communication,
len(packs))
outputs_by_device.append(self._launchers[i].batch_all_reduce(
packs, communication, experimental_hints.timeout_seconds))
for e in self._executors:
e.wait()
mirrored = []
# Reverse the order of reduced value to recover the order in the input.
for value in reversed(reduced_values):
for values in zip(*outputs_by_device):
if reduce_op == reduce_util.ReduceOp.MEAN:
for i, v in enumerate(value):
values = list(values)
for i, v in enumerate(values):
with ops.device(v.device):
value[i] = v / self._group_size
values[i] = v / self._group_size
mirrored.append(
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
return mirrored
distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
# Reverse the order of reduced value to recover the order in the input.
return list(reversed(mirrored))
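As an aside, the transposition above from per-replica values to per-device lists can be sketched in plain Python (the gradient names g*_d* are placeholders):

# Two PerReplica-like values over two devices. The inputs are iterated in
# reverse before grouping, and the final Mirrored list is reversed back.
per_replica_values = [('g0_d0', 'g0_d1'), ('g1_d0', 'g1_d1')]
num_devices = 2
values_by_device = [[] for _ in range(num_devices)]
for per_replica in reversed(per_replica_values):
  for i in range(num_devices):
    values_by_device[i].append(per_replica[i])
# values_by_device == [['g1_d0', 'g0_d0'], ['g1_d1', 'g0_d1']]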
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values,
experimental_hints):
@ -1217,24 +1199,16 @@ class CollectiveAllReduce(CrossDeviceOps):
# Pass self._communication to the runtime as a communication hint.
communication_hint = self._communication.value
# For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): switch to NCCL for all collectives when communication
# is NCCL.
if self._communication == CollectiveCommunication.NCCL and len(
per_replica_values) == 1:
communication_hint = CollectiveCommunication.AUTO.value
gathered_values = []
with self._lock, ops.name_scope("allreduce"):
with self._lock:
for per_replica in per_replica_values:
gathered_values.append(
cross_device_utils.build_collective_gather_indexed_slices(
per_replica.values,
self._devices,
self._group_size,
self._collective_keys,
communication_hint,
timeout=experimental_hints.timeout_seconds))
outputs = []
for i in range(len(self._devices)):
outputs.append(self._launchers[i].all_reduce_indexed_slices(
per_replica.values[i], communication_hint,
experimental_hints.timeout_seconds))
gathered_values.append(outputs)
mirrored = []
for value in gathered_values:
@ -1281,11 +1255,6 @@ class CollectiveAllReduce(CrossDeviceOps):
batch_size = len(per_replica_values)
# Pass self._communication to the runtime as a communication hint.
communication = self._communication.value
# For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): switch to NCCL for all collectives when communication
# is NCCL.
if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
communication = CollectiveCommunication.AUTO.value
logging.log_first_n(
logging.INFO, "Collective batch_all_gather: %d all-gathers, "
@ -1296,21 +1265,12 @@ class CollectiveAllReduce(CrossDeviceOps):
gathered_values = []
with self._lock, ops.name_scope("allgather"):
for per_replica in per_replica_values:
if (communication == CollectiveCommunication.NCCL.value and
gathered_values):
control_inputs = list(gathered_values[-1])
else:
control_inputs = None
gathered_values.append(
cross_device_utils.build_collective_gather(
per_replica.values,
self._devices,
self._group_size,
self._collective_keys,
axis,
communication,
control_inputs,
timeout=experimental_hints.timeout_seconds))
outputs = []
for i in range(len(self._devices)):
outputs.append(self._launchers[i].all_gather(
per_replica.values[i], axis, communication,
experimental_hints.timeout_seconds))
gathered_values.append(outputs)
return gathered_values
if context.executing_eagerly():
@ -1319,8 +1279,7 @@ class CollectiveAllReduce(CrossDeviceOps):
gathered_values = compute_gathered_values()
mirrored = []
# Reverse the order of gathered value to recover the order in the input.
for value in reversed(gathered_values):
for value in gathered_values:
mirrored.append(
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
return mirrored


@ -24,7 +24,6 @@ import threading
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -34,7 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
from tensorflow.python.platform import tf_logging as logging
OP_INSTANCE_KEY_START_NUMBER = 100
INSTANCE_KEY_START_NUMBER = 100
def aggregate_gradients_using_nccl(replica_grads):
@ -181,69 +180,66 @@ class CollectiveKeys(object):
*Instance key*: an integer key to identify the set of corresponding tensors on
different devices in a device group that need to be all-reduced.
*Graph key*: an integer key that is unique per graph. This is used to support
multiple graphs per client session. It must be non-zero and set in the
`config` argument of each call to `session.run`.
This class is thread safe.
"""
def __init__(self,
group_key_start=1,
op_instance_key_start=OP_INSTANCE_KEY_START_NUMBER,
variable_instance_key_start=1000000):
def __init__(self, group_key_start=1):
"""Initializes the object.
Args:
group_key_start: the starting integer of group key.
op_instance_key_start: the starting integer of instance key for ops.
variable_instance_key_start: the starting integer of instance key for
variables.
"""
self._group_key = group_key_start
self._group_key_table = {}
assert op_instance_key_start != variable_instance_key_start
self._op_instance_key = op_instance_key_start
self._variable_instance_key = variable_instance_key_start
self._instance_key_table = {}
self._lock = threading.Lock()
def get_group_key(self, devices):
"""Returns a group key for the set of devices.
Args:
devices: list of strings naming devices in a collective group.
devices: a list of canonical device strings in a collective group.
Returns:
int key uniquely identifying the set of device names.
"""
parsed = [pydev.DeviceSpec.from_string(d) for d in devices]
# In the between-graph replicated training, different workers need to get
# the same device key. So we remove the task_type and task_id from the
# devices.
# TODO(yuefengz): in the in-graph replicated training, we need to include
# task_type and task_id.
names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed])
key_id = ','.join(names)
key_id = hash(tuple(sorted(devices)))
with self._lock:
if key_id not in self._group_key_table:
new_key = self._group_key
self._group_key += 1
self._group_key_table[key_id] = new_key
self._instance_key_table[new_key] = {}
for device in devices:
self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
return self._group_key_table[key_id]
def get_op_instance_key(self):
"""Returns a new instance key for use in defining a collective op."""
with self._lock:
v = self._op_instance_key
self._op_instance_key += 1
return v
def get_instance_key(self, group_key, device):
"""Returns a new instance key for use in defining a collective op.
def get_variable_instance_key(self):
"""Returns a new instance key for use in creating a Variable."""
You should call this once for each collective op of a collective instance.
Args:
group_key: the group key returned by get_group_key(). You should not
assign the group key yourself.
device: a canonical device string. It should be the device this collective
op is on.
Returns:
a new instance key.
Raises:
ValueError: when the group key is invalid or the device is not in the
group.
"""
with self._lock:
v = self._variable_instance_key
self._variable_instance_key += 1
group = self._instance_key_table.get(group_key, None)
if group is None:
raise ValueError('group {} not found'.format(group_key))
if device not in group:
raise ValueError('{} not in group {}'.format(device, group_key))
v = group[device]
group[device] += 1
return v
def __deepcopy__(self, memo):
@ -252,135 +248,146 @@ class CollectiveKeys(object):
copied = CollectiveKeys()
copied._group_key = self._group_key
copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
copied._op_instance_key = self._op_instance_key
copied._variable_instance_key = self._variable_instance_key
copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo)
return copied
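For illustration, a small sketch of the new per-group, per-device instance key allocation (device strings are placeholders; INSTANCE_KEY_START_NUMBER is 100, as defined above):

keys = CollectiveKeys(group_key_start=1)
gk = keys.get_group_key(['/device:GPU:0', '/device:GPU:1'])
keys.get_instance_key(gk, '/device:GPU:0')  # 100, first key for this device
keys.get_instance_key(gk, '/device:GPU:0')  # 101, counter advances per device
keys.get_instance_key(gk, '/device:GPU:1')  # 100, independent of GPU:0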
def build_collective_reduce(input_tensors,
devices,
group_size,
collective_keys,
reduction_op='Add',
unary_op='Id',
communication_hint='AUTO',
control_inputs=None,
executors=None,
timeout=None):
"""Build a subgraph that does one full all-reduce, using the collective Op.
class CollectiveReplicaLauncher(object):
"""Launch collectives on one replica."""
If called in eager mode, it's required to supply a list of async executors for
each input Tensor.
def __init__(self,
group_key,
group_size,
collective_keys,
device,
executor=None):
if executor and not executor.is_async():
raise ValueError('executor must be async')
self._group_key = group_key
self._group_size = group_size
self._collective_keys = collective_keys
self._device = device
self._executor = executor
Args:
input_tensors: tensors within a single worker graph that are to be reduced
together; must be one per device.
devices: a list of device strings to run the collective on.
group_size: total number of devices globally that will be doing this same
reduction. The reduction will actually include the corresponding tensors
at all these workers.
collective_keys: a CollectiveKeys object.
reduction_op: string naming the reduction op.
unary_op: string naming the unary final op.
communication_hint: string providing hint to runtime for choosing collective
implementation.
control_inputs: if not None, add control edges between control_inputs and
(index-wise) corresponding collective_reduce tensors
executors: a list of async executor. Required for eager execution.
timeout: a float or None. The timeout in seconds.
Returns:
An array of final tensors, one per device, computed by the full reduction.
Raises:
ValueError: There must be at least two tensors over all the workers.
"""
if context.executing_eagerly():
if (not executors or len(executors) != len(input_tensors) or
not all(e.is_async() for e in executors)):
raise ValueError(
'collectives requires async executors for each device in eager mode')
if len(input_tensors) != len(devices):
raise ValueError('collective requires one input tensor for each device, '
'len(input_tensors) = %d, len(devices) = %d' %
(len(input_tensors), len(devices)))
if group_size < 2:
return input_tensors
group_key = collective_keys.get_group_key(devices)
instance_key = collective_keys.get_op_instance_key()
subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec
out_tensors = []
for idx, input_tensor in enumerate(input_tensors):
def _executor_scope(self):
if context.executing_eagerly() and not self._executor:
raise ValueError('collectives requires an async executor in eager mode')
if context.executing_eagerly():
executor_scope = context.executor_scope(executors[idx])
else:
executor_scope = ops.NullContextmanager()
with executor_scope, \
ops.device(devices[idx]), \
ops.control_dependencies(
_control_input(devices, control_inputs, idx)):
out_tensor = collective_ops.all_reduce(
return context.executor_scope(self._executor)
return ops.NullContextmanager()
def _control_input(self, control_input):
if control_input is not None:
return ops.control_dependencies([control_input])
return ops.NullContextmanager()
def all_reduce(self,
input_tensor,
control_input=None,
communication_hint='AUTO',
timeout=0):
"""All-reduce a dense tensor.
This can be called in eager mode if an async executor is supplied when
creating the launcher.
Args:
input_tensor: a dense tensor. It must have the same shape on all replicas.
control_input: if not None, add control edges between control_input and
the all-reduce.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
Returns:
The reduced tensor.
"""
instance_key = self._collective_keys.get_instance_key(
self._group_key, self._device)
with self._executor_scope(), \
ops.device(self._device), \
self._control_input(control_input):
return collective_ops.all_reduce(
input_tensor,
group_size,
group_key,
self._group_size,
self._group_key,
instance_key,
reduction_op,
unary_op,
subdiv_offsets,
communication_hint,
communication_hint=communication_hint,
timeout=timeout)
out_tensors.append(out_tensor)
return out_tensors
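A minimal usage sketch for the per-replica all_reduce above (assumed placeholders: `launcher` is a CollectiveReplicaLauncher built as in the constructor, `grad` is a dense tensor on this replica's device):

# In graph mode or inside a tf.function (or in eager mode with an async
# executor supplied at construction time):
reduced = launcher.all_reduce(grad, communication_hint='RING', timeout=10)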
def batch_all_reduce(self,
input_tensor_packs,
communication_hint='AUTO',
timeout=0):
"""Batch all-reduce dense tensors.
def build_collective_gather(input_tensors,
devices,
group_size,
collective_keys,
axis,
communication_hint='AUTO',
control_inputs=None,
timeout=None):
"""Build a subgraph that does one full all-gather, using the collective Op.
This takes a list of batches of tensors. Using multiple batches has the
benefit that the all-reduce does not need to wait for all inputs to be ready
before it starts.
This method must be called in graph mode or inside a tf.function.
This can be called in eager mode if an async executor is supplied when
creating the launcher.
Args:
input_tensors: tensors within a single worker graph that are to be gathered
together; must be one per device. Input tensors cannot have rank 0.
devices: a list of device strings to run the collective on.
group_size: total number of devices globally that will be doing this same
gathering. The gathering will actually include the corresponding tensors
at all these workers.
collective_keys: a CollectiveKeys object.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
communication_hint: string providing hint to runtime for choosing collective
implementation. Available options are `AUTO`, `NCCL`, and `RING`.
control_inputs: if not None, add control edges between control_inputs and
(index-wise) corresponding collective_gather tensors
timeout: a float or None. The timeout in seconds.
Args:
input_tensor_packs: a list of lists of dense tensors.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
Returns:
An array of final tensors, one per device, computed by the full gather.
"""
if len(input_tensors) != len(devices):
raise ValueError(
'collective requires one input tensor for each device, %d != %d' %
(len(input_tensors), len(devices)))
Returns:
A flat list of reduced tensors.
"""
outputs = []
for pack in input_tensor_packs:
# By placing all CollectiveReduce ops in a batch under single name scope,
# we ensure they will be picked up by the `ScopedAllocator` grappler
# optimizer and packed into a single all-reduce.
with ops.name_scope('allreduce'):
# TODO(b/169168846): insert a parallel all_gather to verify packings
# are the same on each replica.
for input_tensor in pack:
if communication_hint == 'NCCL' and outputs:
control_input = outputs[-1]
else:
control_input = None
outputs.append(
self.all_reduce(input_tensor, control_input, communication_hint,
timeout))
return outputs
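An illustrative caller-side sketch of batch_all_reduce (assumed placeholders: `launcher` is this replica's CollectiveReplicaLauncher and `replica_grads` is the list of this replica's gradient tensors; group_by_size is defined later in this file):

# Group this replica's gradients into ~4MB packs, then batch all-reduce them.
packs = group_by_size(replica_grads, bytes_per_pack=4 * 1024 * 1024)
reduced = launcher.batch_all_reduce(packs, communication_hint='NCCL', timeout=5)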
if group_size < 2:
return input_tensors
group_key = collective_keys.get_group_key(devices)
instance_key_tensor = collective_keys.get_op_instance_key()
instance_key_shape = collective_keys.get_op_instance_key()
def all_gather(self,
input_tensor,
axis,
communication_hint='AUTO',
timeout=0):
"""All-gather a dense tensor.
out_tensors = []
for idx, input_tensor in enumerate(input_tensors):
with ops.device(devices[idx]), ops.control_dependencies(
_control_input(devices, control_inputs, idx)):
This method must be called inside a tf.function.
Args:
input_tensor: a dense tensor. It must have the same rank on all replicas,
and dimensions other than `axis` need to be the same as well.
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
range [0, rank(value)).
communication_hint: string providing hint to runtime for choosing
collective implementation. Available options are `AUTO`, `NCCL`, and
`RING`.
timeout: a float. The timeout in seconds.
Returns:
The gathered Tensor.
Raises:
RuntimeError: if called in eager mode.
"""
if context.executing_eagerly():
raise RuntimeError('all_gather in eager mode is not supported')
instance_key_tensor = self._collective_keys.get_instance_key(
self._group_key, self._device)
instance_key_shape = self._collective_keys.get_instance_key(
self._group_key, self._device)
with ops.device(self._device):
# 1. Transpose
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
@ -394,8 +401,8 @@ def build_collective_gather(input_tensors,
# 2. Pad
gathered_shape = collective_ops.all_gather(
array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
group_size,
group_key,
self._group_size,
self._group_key,
instance_key_shape,
communication_hint,
timeout=timeout)
@ -406,8 +413,8 @@ def build_collective_gather(input_tensors,
# 3. Gather
gather_padded_out_tensor = collective_ops.all_gather(
padded_input_tensor,
group_size,
group_key,
self._group_size,
self._group_key,
instance_key_tensor,
communication_hint,
timeout=timeout)
@ -424,93 +431,56 @@ def build_collective_gather(input_tensors,
(math_ops.range(1, axis + 1), [0],
math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
axis=0)
out_tensor = array_ops.transpose(out_tensor_t, perm=perm_after)
out_tensors.append(out_tensor)
return out_tensors
return array_ops.transpose(out_tensor_t, perm=perm_after)
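A hedged usage sketch of all_gather (assumed placeholders: `launcher` and the input tensor `t`; def_function is an assumed import on the caller side, since all_gather must run in graph mode or inside a tf.function):

@def_function.function
def gather_batch(t):
  # Concatenates the per-replica tensors along axis 0 across the group.
  return launcher.all_gather(t, axis=0, communication_hint='RING', timeout=10)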
def all_reduce_indexed_slices(self,
input_slices,
communication_hint='AUTO',
timeout=0):
"""All-reduce an IndexedSlices.
def _pad_util(input_tensor, full_axis_dim):
"""Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
tensor_rank = array_ops.rank(input_tensor)
paddings_axis = [[0, missing_axis_dim]]
paddings = array_ops.concat([
paddings_axis,
array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
],
axis=0)
padded_input_tensor = array_ops.pad(input_tensor, paddings)
return padded_input_tensor
This method must be called inside a tf.function.
Args:
input_slices: an IndexedSlices.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
def build_collective_gather_indexed_slices(input_slices_list,
devices,
group_size,
collective_keys,
communication_hint='AUTO',
control_inputs=None,
timeout=None):
"""Build a subgraph that all-gathers IndexedSlices using the collective Op.
Returns:
The reduced IndexedSlices.
This method must be called in graph mode or inside a tf.function.
Raises:
RuntimeError: if called in eager mode.
"""
if context.executing_eagerly():
raise RuntimeError(
'all_reduce_indexed_slices in eager mode is not supported')
Args:
input_slices_list: a list of IndexedSlices within a single worker graph that
are to be gathered together; must be one per device.
devices: a list of device strings to run the collective on.
group_size: total number of devices globally that will be doing this same
gathering. The gathering will actually include the corresponding tensors
at all these workers.
collective_keys: a CollectiveKeys object.
communication_hint: string providing hint to runtime for choosing collective
implementation.
control_inputs: if not None, add control edges between control_inputs and
(index-wise) corresponding collective_reduce tensors
timeout: a float or None. The timeout in seconds.
gather_length_key = self._collective_keys.get_instance_key(
self._group_key, self._device)
gather_indices_key = self._collective_keys.get_instance_key(
self._group_key, self._device)
gather_values_key = self._collective_keys.get_instance_key(
self._group_key, self._device)
reduce_densified_key = self._collective_keys.get_instance_key(
self._group_key, self._device)
Returns:
An array of final IndexedSlices, one per device, computed by the full
gather.
Raises:
ValueError: if control_inputs is not None and doesn't match the length and
devices of inputs.
"""
assert not context.executing_eagerly(), (
'build_collective_gather_indexed_slices can only be called in graph mode'
' or inside tf.function')
if len(input_slices_list) != len(devices):
raise ValueError(
'collective requires one input IndexedSlice for each device, %d != %d' %
(len(input_slices_list), len(devices)))
if group_size < 2:
return input_slices_list
group_key = collective_keys.get_group_key(devices)
gather_length_key = collective_keys.get_op_instance_key()
gather_indices_key = collective_keys.get_op_instance_key()
gather_values_key = collective_keys.get_op_instance_key()
reduce_densified_key = collective_keys.get_op_instance_key()
# Current CollectiveAllGather implementations require input IndexedSlices to
# have consistent length across the board, we handle the reduction of
# IndexedSlices as follows:
# 1. Gather the lengths of IndexedSlices from all participants.
# 2. If they have consistent length, apply all_gather.
# 3. Otherwise convert IndexedSlices to dense tensors and apply
# all_reduce.
out_slices_list = []
for idx, input_slices in enumerate(input_slices_list):
# pylint: disable = cell-var-from-loop
with ops.device(devices[idx]):
# Current CollectiveAllGather implementations require input IndexedSlices to
# have consistent length across the board, we handle the reduction of
# IndexedSlices as follows:
# 1. Gather the lengths of IndexedSlices from all participants.
# 2. If they have consistent length, apply all_gather.
# 3. Otherwise convert IndexedSlices to dense tensors and apply
# all_reduce.
with ops.device(self._device):
def all_gather():
"""Use all_gather to aggregate `IndexedSlices`."""
all_values = collective_ops.all_gather(
input_slices.values,
group_size,
group_key,
self._group_size,
self._group_key,
gather_values_key,
communication_hint,
timeout=timeout)
@ -519,8 +489,8 @@ def build_collective_gather_indexed_slices(input_slices_list,
with ops.control_dependencies(control):
all_indices = collective_ops.all_gather(
input_slices.indices,
group_size,
group_key,
self._group_size,
self._group_key,
gather_indices_key,
communication_hint,
timeout=timeout)
@ -534,8 +504,8 @@ def build_collective_gather_indexed_slices(input_slices_list,
densified = ops.convert_to_tensor(input_slices)
reduced = collective_ops.all_reduce(
densified,
group_size,
group_key,
self._group_size,
self._group_key,
reduce_densified_key,
'Add',
'Id', [0],
@ -550,23 +520,18 @@ def build_collective_gather_indexed_slices(input_slices_list,
dense_shape=input_slices.dense_shape)
length = array_ops.shape(input_slices.indices)
with ops.control_dependencies(
_control_input(input_slices, control_inputs, idx)):
all_lengths = collective_ops.all_gather(
length,
group_size,
group_key,
gather_length_key,
communication_hint,
timeout=timeout)
out_slices = control_flow_ops.cond(
all_lengths = collective_ops.all_gather(
length,
self._group_size,
self._group_key,
gather_length_key,
communication_hint,
timeout=timeout)
return control_flow_ops.cond(
math_ops.equal(
math_ops.reduce_max(all_lengths),
math_ops.reduce_min(all_lengths)), all_gather,
densify_and_all_reduce)
out_slices_list.append(out_slices)
# pylint: enable=cell-var-from-loop
return out_slices_list
def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
@ -653,56 +618,35 @@ def stitch_values(values_and_indices_list):
return result
def per_replica_num_elements(per_replica):
"""Returns the static number of elements of one replica.
def group_by_size(input_tensors, bytes_per_pack):
"""Groups `input_tensors` into chunks of `bytes_per_pack`.
Args:
per_replica: A PerReplica of Tensor or IndexedSlices.
Returns:
Number of elements. None if some replica has a different or unknown shape.
"""
values = per_replica._values # pylint: disable=protected-access
s0 = values[0].shape
for v in values:
assert not isinstance(v, ops.IndexedSlices)
if v.shape != s0:
return None
return s0.num_elements()
def pack_by_size(per_replica_list, bytes_per_pack):
"""Packs `per_replica_list` into chunks of `bytes_per_pack`.
The method preserves the original order of `per_replica_list`. The packing is
The method preserves the original order of `input_tensors`. The grouping is
best effort, each pack could have more or less bytes than `bytes_per_pack`.
It only packs values with known shape. Note that, the usage is different from
`cross_device_ops._pack_tensors`, this function is intended to work with the
ScopeAllocator style batching used in `CollectiveAllReduce`.
It only groups values with known shape.
Args:
per_replica_list: A list of PerReplica.
bytes_per_pack: Bytes per pack.
input_tensors: a list of Tensor.
bytes_per_pack: an integer.
Returns:
A list of packs of PerReplica. All values are packed into one pack if
`bytes_per_pack` is zero or any of the value has unknown shape.
A list of packs of Tensor. All values are grouped into one pack if
`bytes_per_pack` is zero or any of the value has unknown shape.
"""
if bytes_per_pack == 0:
return [per_replica_list]
return [input_tensors]
packs = []
last_pack_size = 0
for value in per_replica_list:
num_elements = per_replica_num_elements(value)
for value in input_tensors:
num_elements = value.shape.num_elements()
if num_elements is None:
# Can't pack values with unknown shape.
logging.warning(
'not packing values due to the unknown or inconsistent shape of %s',
value)
return [per_replica_list]
size = num_elements * value._primary.dtype.size # pylint: disable=protected-access
return [input_tensors]
size = num_elements * value.dtype.size
# Try to keep each pack as close to bytes_per_pack as possible, while each
# pack is at least bytes_per_pack large. I.E. we err on the side of having
# few but large packs.
@ -714,24 +658,15 @@ def pack_by_size(per_replica_list, bytes_per_pack):
return packs
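For reference, a small example mirroring the updated unit test below (the first three dtypes are assumed 4-byte types, consistent with the byte sizes noted in the test):

values = [
    array_ops.ones([2, 4, 4], dtype=dtypes.float32),  # 128 bytes
    array_ops.ones([8], dtype=dtypes.float32),        # 32 bytes
    array_ops.ones([10, 10], dtype=dtypes.float32),   # 400 bytes
    array_ops.ones([1], dtype=dtypes.int32),          # 4 bytes
]
packs = group_by_size(values, bytes_per_pack=200)
# packs[0] holds the first three tensors: the pack keeps growing until it has
# reached at least bytes_per_pack (128 + 32 = 160 < 200, so the 400-byte
# tensor still joins). packs[1] holds the remaining 4-byte tensor.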
def _control_input(devices, control_inputs, idx):
"""Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies.
This is a helper function for building collective ops.
Args:
devices: a list of device strings the collective run on.
control_inputs: a list or None.
idx: the index into `inputs` and `control_inputs`.
Returns:
A one item list of the `idx`-th element of `control_inputs`, or an empty
list if `control_inputs` is None.
"""
if control_inputs is None:
return []
if len(control_inputs) != len(devices):
raise ValueError(
'control_inputs must match the length of the devices, %s != %s' %
(len(control_inputs), len(devices)))
return [control_inputs[idx]]
def _pad_util(input_tensor, full_axis_dim):
"""Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
tensor_rank = array_ops.rank(input_tensor)
paddings_axis = [[0, missing_axis_dim]]
paddings = array_ops.concat([
paddings_axis,
array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
],
axis=0)
padded_input_tensor = array_ops.pad(input_tensor, paddings)
return padded_input_tensor


@ -23,7 +23,6 @@ from absl.testing import parameterized
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -114,11 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
device_util.resolve(destination), device_util.resolve(result.device))
class PackBySizeTest(test.TestCase):
def assertShape(self, per_replica, shape):
for v in per_replica._values: # pylint: disable=protected-access
self.assertEqual(v.shape, shape)
class GroupBySizeTest(test.TestCase):
def testPreferLargerPack(self):
# Each packs except the last one should be equal or larger than
@ -133,49 +128,38 @@ class PackBySizeTest(test.TestCase):
# size = 1 * 4 = 4
array_ops.ones([1], dtype=dtypes.int32),
]
per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=200)
packs = cross_device_utils.group_by_size(values, bytes_per_pack=200)
self.assertLen(packs, 2)
self.assertLen(packs[0], 3)
self.assertShape(packs[0][0], [2, 4, 4])
self.assertShape(packs[0][1], [8])
self.assertShape(packs[0][2], [10, 10])
self.assertEqual(packs[0][0].shape, [2, 4, 4])
self.assertEqual(packs[0][1].shape, [8])
self.assertEqual(packs[0][2].shape, [10, 10])
self.assertLen(packs[1], 1)
self.assertShape(packs[1][0], [1])
self.assertEqual(packs[1][0].shape, [1])
def testZeroBytesPerPack(self):
values = [
array_ops.ones([1], dtype=dtypes.float32),
array_ops.ones([2], dtype=dtypes.float32),
]
per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=0)
packs = cross_device_utils.group_by_size(values, bytes_per_pack=0)
self.assertLen(packs, 1)
self.assertLen(packs[0], 2)
self.assertShape(packs[0][0], [1])
self.assertShape(packs[0][1], [2])
self.assertEqual(packs[0][0].shape, [1])
self.assertEqual(packs[0][1].shape, [2])
def testUnknownShape(self):
def create_placeholder(shape, dtype):
with ops.Graph().as_default():
return array_ops.placeholder(dtype=dtype, shape=shape)
per_replica_values = [
value_lib.PerReplica([
array_ops.ones([10, 10], dtype=dtypes.float32),
array_ops.ones([10, 10], dtype=dtypes.float32),
]),
value_lib.PerReplica([
array_ops.ones([10, 10], dtype=dtypes.float32),
create_placeholder([None, 10], dtype=dtypes.float32),
]),
values = [
array_ops.ones([10, 10], dtype=dtypes.float32),
create_placeholder([None, 10], dtype=dtypes.float32),
]
packs = cross_device_utils.pack_by_size(
per_replica_values, bytes_per_pack=1)
packs = cross_device_utils.group_by_size(values, bytes_per_pack=1)
self.assertLen(packs, 1)
self.assertEqual(packs[0], per_replica_values)
self.assertEqual(packs[0], values)
if __name__ == "__main__":


@ -27,7 +27,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import cluster_resolver
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
@ -451,6 +451,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
# Reusing keys is not supported well. So we have to give a different
# collective key base for different tests.
CollectiveAllReduceTest.collective_key_base += 100000
mwms_lib.CollectiveAllReduceStrategy._collective_key_base = (
CollectiveAllReduceTest.collective_key_base)
def _get_test_objects(self,
task_type,
@ -460,10 +462,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
use_strategy_object=False,
local_mode=False):
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
variable_instance_key_start=10000 +
CollectiveAllReduceTest.collective_key_base)
group_key_start=10 + CollectiveAllReduceTest.collective_key_base)
if local_mode:
if num_gpus:
devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
@ -471,13 +470,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
devices = ["/device:CPU:0"]
if use_strategy_object:
strategy = (
collective_all_reduce_strategy.CollectiveAllReduceStrategy
._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
strategy.extended._host_cross_device_ops._collective_keys = (
collective_keys)
strategy = (mwms_lib.CollectiveAllReduceStrategy
._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access
return strategy, devices, ""
else:
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
@ -507,10 +501,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
task_type=task_type,
task_id=task_id,
num_accelerators={"GPU": num_gpus})
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
strategy = mwms_lib.CollectiveAllReduceStrategy(
cluster_resolver=resolver, communication=communication)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
return (strategy, devices,
"grpc://" + self._cluster_spec[task_type][task_id])
else:


@ -689,7 +689,7 @@ cuda_py_test(
name = "multi_worker_test",
srcs = ["multi_worker_test.py"],
python_version = "PY3",
shard_count = 32,
shard_count = 2,
tags = [
"multi_and_single_gpu",
"no_oss", # TODO(b/130369494): Investigate why it times out on OSS.


@ -24,9 +24,8 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
@ -80,7 +79,7 @@ def create_test_objects(cluster_spec=None,
ClusterSpec({}), num_accelerators={'GPU': num_gpus})
target = ''
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
strategy = mwms_lib.CollectiveAllReduceStrategy(
cluster_resolver=cluster_resolver)
sess_config = strategy.update_config_proto(sess_config)
@ -95,9 +94,7 @@ class CollectiveAllReduceStrategyTestBase(
def setUp(self):
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, ayushd): enable it to reuse collective keys in different
# tests.
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
mwms_lib.CollectiveAllReduceStrategy._collective_key_base += 100000
super(CollectiveAllReduceStrategyTestBase, self).setUp()
def _get_test_object(self, task_type, task_id, num_gpus=0):
@ -106,18 +103,6 @@ class CollectiveAllReduceStrategyTestBase(
task_type=task_type,
task_id=task_id,
num_gpus=num_gpus)
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
op_instance_key_start=100 +
CollectiveAllReduceStrategyTestBase.collective_key_base,
variable_instance_key_start=10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
strategy.extended._host_cross_device_ops._collective_keys = collective_keys
return strategy, target, session_config
def _test_complex_model(self, task_type, task_id, num_gpus):