Refactor collective utils to be of one replica
We used to have build_collective_* that launches collective on all replicas. However each replica is independent and it can simplify the utils if it only launches on one replica. This is mainly to prepare for switching to Collective V2 ops and reduce the use of merge_call in MirroredStrategy/MultiWorkerMirroredStrategy. For the former, it will be easier after this change. For the latter, we can directly call CollectiveReplicaLauncher.all_reduce from ReplicaContext.all_reduce without the need of merge_call in MS/MWMS. Note that the original CollecticeKeys is outdated. Now instance key is scoped under each group key, and is no longer global. PiperOrigin-RevId: 337007588 Change-Id: I111efad0ecbe829260c8e4d2f555c6b2dc1d4c05
This commit is contained in:
parent
c451bbccd9
commit
746af76e4f
@ -96,13 +96,14 @@ py_library(
|
||||
":values",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:collective_ops",
|
||||
"//tensorflow/python:device",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nccl_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -101,6 +101,9 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
|
||||
# TODO(anjalisridhar): Update our guides with examples showing how we can use
|
||||
# the cluster_resolver argument.
|
||||
|
||||
# The starting number for collective keys. This should only be set in tests.
|
||||
_collective_key_base = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
communication=cross_device_ops_lib.CollectiveCommunication.AUTO,
|
||||
@ -362,7 +365,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
else:
|
||||
local_devices = (self._worker_device,)
|
||||
|
||||
self._collective_keys = cross_device_utils.CollectiveKeys()
|
||||
self._collective_keys = cross_device_utils.CollectiveKeys(
|
||||
group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base) # pylint: disable=protected-access
|
||||
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
devices=local_devices,
|
||||
group_size=len(local_devices) * self._num_workers,
|
||||
@ -428,7 +432,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
group_key = self._collective_keys.get_group_key([device])
|
||||
group_size = self._num_workers
|
||||
collective_instance_key = (
|
||||
self._collective_keys.get_variable_instance_key())
|
||||
self._collective_keys.get_instance_key(group_key, device))
|
||||
|
||||
with ops.device(device):
|
||||
initial_value = kwargs["initial_value"]
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribute_utils
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
@ -59,6 +58,8 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
|
||||
|
||||
CollectiveAllReduceStrategy = (
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy)
|
||||
CollectiveAllReduceExtended = (
|
||||
collective_all_reduce_strategy.CollectiveAllReduceExtended)
|
||||
|
||||
@ -93,14 +94,10 @@ def create_test_objects(cluster_spec=None,
|
||||
class CollectiveAllReduceStrategyTestBase(
|
||||
multi_worker_test_base.MultiWorkerTestBase):
|
||||
|
||||
collective_key_base = 0
|
||||
|
||||
def setUp(self):
|
||||
# We use a different key_base for each test so that collective keys won't be
|
||||
# reused.
|
||||
# TODO(yuefengz, ayushd): enable it to reuse collective keys in different
|
||||
# tests.
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
|
||||
CollectiveAllReduceStrategy._collective_key_base += 100000
|
||||
super(CollectiveAllReduceStrategyTestBase, self).setUp()
|
||||
|
||||
def _get_test_object(self, task_type, task_id, num_gpus=0):
|
||||
@ -109,18 +106,6 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
task_type=task_type,
|
||||
task_id=task_id,
|
||||
num_gpus=num_gpus)
|
||||
|
||||
collective_keys = cross_device_utils.CollectiveKeys(
|
||||
group_key_start=10 +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base,
|
||||
op_instance_key_start=100 +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base,
|
||||
variable_instance_key_start=10000 +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base)
|
||||
strategy.extended._collective_keys = collective_keys
|
||||
strategy.extended._cross_device_ops._collective_keys = collective_keys
|
||||
strategy.extended._host_cross_device_ops._collective_keys = collective_keys
|
||||
|
||||
return strategy, target, session_config
|
||||
|
||||
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
|
||||
|
@ -36,7 +36,7 @@ from tensorflow.python.distribute import tpu_values
|
||||
from tensorflow.python.distribute import values as value_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import executor
|
||||
from tensorflow.python.eager import executor as executor_lib
|
||||
from tensorflow.python.framework import kernels
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
@ -144,8 +144,8 @@ def _normalize_value_destination_pairs(value_destination_pairs):
|
||||
|
||||
|
||||
def _validate_value_destination_pairs(value_destination_pairs):
|
||||
"""Validates value_destination_pairs are valid."""
|
||||
# TODO(yuefengz): raise exceptions instead of returning False.
|
||||
# pylint: disable=g-missing-docstring
|
||||
if not value_destination_pairs: return False
|
||||
if not isinstance(value_destination_pairs, (list, tuple)): return False
|
||||
if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
|
||||
@ -197,7 +197,7 @@ def simple_broadcast(value, destinations, always_mirrored=False):
|
||||
|
||||
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
|
||||
reduce_op):
|
||||
# pylint: disable=g-missing-docstring
|
||||
"""Reduces the value by accumulation_fn and reduce_op."""
|
||||
all_values = per_replica_value.values
|
||||
if not all_values:
|
||||
raise ValueError("`per_replica_value` must be non-empty")
|
||||
@ -1019,7 +1019,6 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
if group_size % len(devices) > 0:
|
||||
raise ValueError("group_size must be divisible by the number of devices.")
|
||||
|
||||
self._devices = tuple(device_util.canonicalize(d) for d in devices)
|
||||
self._group_size = group_size
|
||||
self._collective_keys = (collective_keys or
|
||||
cross_device_utils.CollectiveKeys())
|
||||
@ -1039,14 +1038,21 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
# This deadlocks since neither collective is able to finish.
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._devices = tuple(device_util.canonicalize(d) for d in devices)
|
||||
group_key = self._collective_keys.get_group_key(self._devices)
|
||||
# Collective ops requires all devices to participate and is blocking. In
|
||||
# eager, we need one async executor for each device to be able to launch
|
||||
# them altogether. Note that async doesn't imply concurrency. Within an
|
||||
# async executor operations are still executed sequentially. In graph or
|
||||
# function building, the executors are not used.
|
||||
self._executors = []
|
||||
for _ in range(len(devices)):
|
||||
self._executors.append(executor.new_executor(enable_async=True))
|
||||
self._launchers = []
|
||||
for device in self._devices:
|
||||
executor = executor_lib.new_executor(enable_async=True)
|
||||
self._executors.append(executor)
|
||||
launcher = cross_device_utils.CollectiveReplicaLauncher(
|
||||
group_key, group_size, self._collective_keys, device, executor)
|
||||
self._launchers.append(launcher)
|
||||
|
||||
super(CollectiveAllReduce, self).__init__()
|
||||
|
||||
@ -1148,63 +1154,39 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
# queuing time due to concurrent intense computation.
|
||||
#
|
||||
# TODO(b/147393503): explore solutions for optimal ordering.
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)
|
||||
values_by_device = [[] for _ in range(len(self._devices))]
|
||||
for per_replica in reversed(per_replica_values):
|
||||
for i in range(len(self._devices)):
|
||||
values_by_device[i].append(per_replica.values[i])
|
||||
|
||||
if batch_size > 1:
|
||||
logging.info(
|
||||
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
|
||||
"group_size = %d, communication_hint = %s, num_packs = %d",
|
||||
batch_size, len(self._devices), self._group_size, communication,
|
||||
len(packs))
|
||||
else:
|
||||
logging.log_first_n(
|
||||
logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
|
||||
"num_devices = %d, group_size = %d, communication_hint = %s, "
|
||||
"num_packs = %d" % (batch_size, len(
|
||||
self._devices), self._group_size, communication, len(packs)), 10)
|
||||
|
||||
reduced_values = []
|
||||
outputs_by_device = []
|
||||
with self._lock:
|
||||
for pack in packs:
|
||||
# By placing all CollectiveReduce ops in a pack under single name scope,
|
||||
# we ensure they will be picked up by the `ScopedAllocator` grappler
|
||||
# optimizer and packed into a single all-reduce.
|
||||
with ops.name_scope("allreduce"):
|
||||
for per_replica in pack:
|
||||
# Add control dependencies per device from the last gradients to the
|
||||
# current set, in order to serialize NCCL launches.
|
||||
if (communication == CollectiveCommunication.NCCL.value and
|
||||
reduced_values):
|
||||
control_inputs = list(reduced_values[-1])
|
||||
else:
|
||||
control_inputs = None
|
||||
reduced_values.append(
|
||||
cross_device_utils.build_collective_reduce(
|
||||
per_replica.values,
|
||||
self._devices,
|
||||
self._group_size,
|
||||
self._collective_keys,
|
||||
"Add",
|
||||
"Id",
|
||||
communication,
|
||||
control_inputs,
|
||||
executors=self._executors,
|
||||
timeout=experimental_hints.timeout_seconds))
|
||||
for i in range(len(self._devices)):
|
||||
packs = cross_device_utils.group_by_size(
|
||||
values_by_device[i], experimental_hints.bytes_per_pack)
|
||||
if not context.executing_eagerly() and i == 0:
|
||||
logging.info(
|
||||
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
|
||||
"group_size = %d, communication_hint = %s, num_packs = %d",
|
||||
batch_size, len(self._launchers), self._group_size, communication,
|
||||
len(packs))
|
||||
outputs_by_device.append(self._launchers[i].batch_all_reduce(
|
||||
packs, communication, experimental_hints.timeout_seconds))
|
||||
|
||||
for e in self._executors:
|
||||
e.wait()
|
||||
|
||||
mirrored = []
|
||||
# Reverse the order of reduced value to recover the order in the input.
|
||||
for value in reversed(reduced_values):
|
||||
for values in zip(*outputs_by_device):
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
for i, v in enumerate(value):
|
||||
values = list(values)
|
||||
for i, v in enumerate(values):
|
||||
with ops.device(v.device):
|
||||
value[i] = v / self._group_size
|
||||
values[i] = v / self._group_size
|
||||
mirrored.append(
|
||||
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
|
||||
return mirrored
|
||||
distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
|
||||
# Reverse the order of reduced value to recover the order in the input.
|
||||
return list(reversed(mirrored))
|
||||
|
||||
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values,
|
||||
experimental_hints):
|
||||
@ -1217,24 +1199,16 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
|
||||
# Pass self._communication to the runtime as a communication hint.
|
||||
communication_hint = self._communication.value
|
||||
# For now, we use NCCL only when batch_size > 1.
|
||||
# TODO(b/132575814): switch to NCCL for all collectives when communication
|
||||
# is NCCL.
|
||||
if self._communication == CollectiveCommunication.NCCL and len(
|
||||
per_replica_values) == 1:
|
||||
communication_hint = CollectiveCommunication.AUTO.value
|
||||
|
||||
gathered_values = []
|
||||
with self._lock, ops.name_scope("allreduce"):
|
||||
with self._lock:
|
||||
for per_replica in per_replica_values:
|
||||
gathered_values.append(
|
||||
cross_device_utils.build_collective_gather_indexed_slices(
|
||||
per_replica.values,
|
||||
self._devices,
|
||||
self._group_size,
|
||||
self._collective_keys,
|
||||
communication_hint,
|
||||
timeout=experimental_hints.timeout_seconds))
|
||||
outputs = []
|
||||
for i in range(len(self._devices)):
|
||||
outputs.append(self._launchers[i].all_reduce_indexed_slices(
|
||||
per_replica.values[i], communication_hint,
|
||||
experimental_hints.timeout_seconds))
|
||||
gathered_values.append(outputs)
|
||||
|
||||
mirrored = []
|
||||
for value in gathered_values:
|
||||
@ -1281,11 +1255,6 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
batch_size = len(per_replica_values)
|
||||
# Pass self._communication to the runtime as a communication hint.
|
||||
communication = self._communication.value
|
||||
# For now, we use NCCL only when batch_size > 1.
|
||||
# TODO(b/132575814): switch to NCCL for all collectives when communication
|
||||
# is NCCL.
|
||||
if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
|
||||
communication = CollectiveCommunication.AUTO.value
|
||||
|
||||
logging.log_first_n(
|
||||
logging.INFO, "Collective batch_all_gather: %d all-gathers, "
|
||||
@ -1296,21 +1265,12 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
gathered_values = []
|
||||
with self._lock, ops.name_scope("allgather"):
|
||||
for per_replica in per_replica_values:
|
||||
if (communication == CollectiveCommunication.NCCL.value and
|
||||
gathered_values):
|
||||
control_inputs = list(gathered_values[-1])
|
||||
else:
|
||||
control_inputs = None
|
||||
gathered_values.append(
|
||||
cross_device_utils.build_collective_gather(
|
||||
per_replica.values,
|
||||
self._devices,
|
||||
self._group_size,
|
||||
self._collective_keys,
|
||||
axis,
|
||||
communication,
|
||||
control_inputs,
|
||||
timeout=experimental_hints.timeout_seconds))
|
||||
outputs = []
|
||||
for i in range(len(self._devices)):
|
||||
outputs.append(self._launchers[i].all_gather(
|
||||
per_replica.values[i], axis, communication,
|
||||
experimental_hints.timeout_seconds))
|
||||
gathered_values.append(outputs)
|
||||
return gathered_values
|
||||
|
||||
if context.executing_eagerly():
|
||||
@ -1319,8 +1279,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
gathered_values = compute_gathered_values()
|
||||
|
||||
mirrored = []
|
||||
# Reverse the order of gathered value to recover the order in the input.
|
||||
for value in reversed(gathered_values):
|
||||
for value in gathered_values:
|
||||
mirrored.append(
|
||||
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
|
||||
return mirrored
|
||||
|
@ -24,7 +24,6 @@ import threading
|
||||
from tensorflow.python.distribute import values as value_lib
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -34,7 +33,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nccl_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
OP_INSTANCE_KEY_START_NUMBER = 100
|
||||
INSTANCE_KEY_START_NUMBER = 100
|
||||
|
||||
|
||||
def aggregate_gradients_using_nccl(replica_grads):
|
||||
@ -181,69 +180,66 @@ class CollectiveKeys(object):
|
||||
*Instance key*: an integer key to identify the set of same counterpart of
|
||||
tensors on different devices in a device group that need to be all-reduced.
|
||||
|
||||
"Graph key": an integer key that is unique key graph. This is used to support
|
||||
multiple graphs per client session. It must be non-zero and set in the
|
||||
`config` argument of each call to `session.run`.
|
||||
|
||||
This class is thread safe.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
group_key_start=1,
|
||||
op_instance_key_start=OP_INSTANCE_KEY_START_NUMBER,
|
||||
variable_instance_key_start=1000000):
|
||||
def __init__(self, group_key_start=1):
|
||||
"""Initializes the object.
|
||||
|
||||
Args:
|
||||
group_key_start: the starting integer of group key.
|
||||
op_instance_key_start: the starting integer of instance key for ops.
|
||||
variable_instance_key_start: the starting integer of instance key for
|
||||
variables.
|
||||
"""
|
||||
self._group_key = group_key_start
|
||||
self._group_key_table = {}
|
||||
|
||||
assert op_instance_key_start != variable_instance_key_start
|
||||
self._op_instance_key = op_instance_key_start
|
||||
self._variable_instance_key = variable_instance_key_start
|
||||
self._instance_key_table = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_group_key(self, devices):
|
||||
"""Returns a group key for the set of devices.
|
||||
|
||||
Args:
|
||||
devices: list of strings naming devices in a collective group.
|
||||
devices: a list of canonical device strings in a collective group.
|
||||
|
||||
Returns:
|
||||
int key uniquely identifying the set of device names.
|
||||
"""
|
||||
parsed = [pydev.DeviceSpec.from_string(d) for d in devices]
|
||||
# In the between-graph replicated training, different workers need to get
|
||||
# the same device key. So we remove the task_type and task_id from the
|
||||
# devices.
|
||||
# TODO(yuefengz): in the in-graph replicated training, we need to include
|
||||
# task_type and task_id.
|
||||
names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed])
|
||||
key_id = ','.join(names)
|
||||
key_id = hash(tuple(sorted(devices)))
|
||||
with self._lock:
|
||||
if key_id not in self._group_key_table:
|
||||
new_key = self._group_key
|
||||
self._group_key += 1
|
||||
self._group_key_table[key_id] = new_key
|
||||
self._instance_key_table[new_key] = {}
|
||||
for device in devices:
|
||||
self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
|
||||
return self._group_key_table[key_id]
|
||||
|
||||
def get_op_instance_key(self):
|
||||
"""Returns a new instance key for use in defining a collective op."""
|
||||
with self._lock:
|
||||
v = self._op_instance_key
|
||||
self._op_instance_key += 1
|
||||
return v
|
||||
def get_instance_key(self, group_key, device):
|
||||
"""Returns a new instance key for use in defining a collective op.
|
||||
|
||||
def get_variable_instance_key(self):
|
||||
"""Returns a new instance key for use in creating a Variable."""
|
||||
You should call this once per each collective op of a collective instance.
|
||||
|
||||
Args:
|
||||
group_key: the group key returned by get_group_key(). You should not
|
||||
assign the group key yourself.
|
||||
device: a canonical device string. It should be the device this collective
|
||||
op is on.
|
||||
|
||||
Returns:
|
||||
a new instance key.
|
||||
|
||||
Raises:
|
||||
ValueError: when the group key is invalid or the device is not in the
|
||||
group.
|
||||
"""
|
||||
with self._lock:
|
||||
v = self._variable_instance_key
|
||||
self._variable_instance_key += 1
|
||||
group = self._instance_key_table.get(group_key, None)
|
||||
if group is None:
|
||||
raise ValueError('group {} not found'.format(group_key))
|
||||
if device not in group:
|
||||
raise ValueError('{} not in group {}'.format(device, group_key))
|
||||
v = group[device]
|
||||
group[device] += 1
|
||||
return v
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
@ -252,135 +248,146 @@ class CollectiveKeys(object):
|
||||
copied = CollectiveKeys()
|
||||
copied._group_key = self._group_key
|
||||
copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
|
||||
copied._op_instance_key = self._op_instance_key
|
||||
copied._variable_instance_key = self._variable_instance_key
|
||||
copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo)
|
||||
return copied
|
||||
|
||||
|
||||
def build_collective_reduce(input_tensors,
|
||||
devices,
|
||||
group_size,
|
||||
collective_keys,
|
||||
reduction_op='Add',
|
||||
unary_op='Id',
|
||||
communication_hint='AUTO',
|
||||
control_inputs=None,
|
||||
executors=None,
|
||||
timeout=None):
|
||||
"""Build a subgraph that does one full all-reduce, using the collective Op.
|
||||
class CollectiveReplicaLauncher(object):
|
||||
"""Launch collectives on one replica."""
|
||||
|
||||
If called in eager mode, it's required to supply a list of async executors for
|
||||
each input Tensor.
|
||||
def __init__(self,
|
||||
group_key,
|
||||
group_size,
|
||||
collective_keys,
|
||||
device,
|
||||
executor=None):
|
||||
if executor and not executor.is_async():
|
||||
raise ValueError('executor must be async')
|
||||
self._group_key = group_key
|
||||
self._group_size = group_size
|
||||
self._collective_keys = collective_keys
|
||||
self._device = device
|
||||
self._executor = executor
|
||||
|
||||
Args:
|
||||
input_tensors: tensors within a single worker graph that are to be reduced
|
||||
together; must be one per device.
|
||||
devices: a list of device strings to run the collective on.
|
||||
group_size: total number of devices globally that will be doing this same
|
||||
reduction. The reduction will actually include the corresponding tensors
|
||||
at all these workers.
|
||||
collective_keys: a CollectiveKeys object.
|
||||
reduction_op: string naming the reduction op.
|
||||
unary_op: string naming the unary final op.
|
||||
communication_hint: string providing hint to runtime for choosing collective
|
||||
implementation.
|
||||
control_inputs: if not None, add control edges between control_inputs and
|
||||
(index-wise) corresponding collective_reduce tensors
|
||||
executors: a list of async executor. Required for eager execution.
|
||||
timeout: a float or None. The timeout in seconds.
|
||||
|
||||
Returns:
|
||||
An array of final tensors, one per device, computed by the full reduction.
|
||||
|
||||
Raises:
|
||||
ValueError: There must be at least two tensors over all the workers.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
if (not executors or len(executors) != len(input_tensors) or
|
||||
not all(e.is_async() for e in executors)):
|
||||
raise ValueError(
|
||||
'collectives requires async executors for each device in eager mode')
|
||||
if len(input_tensors) != len(devices):
|
||||
raise ValueError('collective requires one input tensor for each device, '
|
||||
'len(input_tensors) = %d, len(devices) = %d' %
|
||||
(len(input_tensors), len(devices)))
|
||||
|
||||
if group_size < 2:
|
||||
return input_tensors
|
||||
group_key = collective_keys.get_group_key(devices)
|
||||
instance_key = collective_keys.get_op_instance_key()
|
||||
subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec
|
||||
|
||||
out_tensors = []
|
||||
for idx, input_tensor in enumerate(input_tensors):
|
||||
def _executor_scope(self):
|
||||
if context.executing_eagerly() and not self._executor:
|
||||
raise ValueError('collectives requires a async executor in eager mode')
|
||||
if context.executing_eagerly():
|
||||
executor_scope = context.executor_scope(executors[idx])
|
||||
else:
|
||||
executor_scope = ops.NullContextmanager()
|
||||
with executor_scope, \
|
||||
ops.device(devices[idx]), \
|
||||
ops.control_dependencies(
|
||||
_control_input(devices, control_inputs, idx)):
|
||||
out_tensor = collective_ops.all_reduce(
|
||||
return context.executor_scope(self._executor)
|
||||
return ops.NullContextmanager()
|
||||
|
||||
def _control_input(self, control_input):
|
||||
if control_input is not None:
|
||||
return ops.control_dependencies([control_input])
|
||||
return ops.NullContextmanager()
|
||||
|
||||
def all_reduce(self,
|
||||
input_tensor,
|
||||
control_input=None,
|
||||
communication_hint='AUTO',
|
||||
timeout=0):
|
||||
"""All-reduce a dense tensor.
|
||||
|
||||
This can be called in eager mode if a async executor is supplied when
|
||||
creating the launcher.
|
||||
|
||||
Args:
|
||||
input_tensor: a dense tensor. It must have the same shape on all replicas.
|
||||
control_input: if not None, add control edges between control_input and
|
||||
the all-reduce.
|
||||
communication_hint: string providing hint to runtime for choosing
|
||||
collective implementation.
|
||||
timeout: a float. The timeout in seconds.
|
||||
|
||||
Returns:
|
||||
The reduced tensor.
|
||||
"""
|
||||
instance_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
with self._executor_scope(), \
|
||||
ops.device(self._device), \
|
||||
self._control_input(control_input):
|
||||
return collective_ops.all_reduce(
|
||||
input_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
reduction_op,
|
||||
unary_op,
|
||||
subdiv_offsets,
|
||||
communication_hint,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout)
|
||||
out_tensors.append(out_tensor)
|
||||
return out_tensors
|
||||
|
||||
def batch_all_reduce(self,
|
||||
input_tensor_packs,
|
||||
communication_hint='AUTO',
|
||||
timeout=0):
|
||||
"""Batch all-reduce dense tensors.
|
||||
|
||||
def build_collective_gather(input_tensors,
|
||||
devices,
|
||||
group_size,
|
||||
collective_keys,
|
||||
axis,
|
||||
communication_hint='AUTO',
|
||||
control_inputs=None,
|
||||
timeout=None):
|
||||
"""Build a subgraph that does one full all-gather, using the collective Op.
|
||||
This takes a list of batches of tensors. Using multiple batches have the
|
||||
benefit that it doesn't need to wait for all inputs to be ready to start the
|
||||
all-reduce.
|
||||
|
||||
This method must be called in graph mode or inside a tf.function.
|
||||
This can be called in eager mode if a async executor is supplied when
|
||||
creating the launcher.
|
||||
|
||||
Args:
|
||||
input_tensors: tensors within a single worker graph that are to be gathered
|
||||
together; must be one per device. Input tensors cannot have rank 0.
|
||||
devices: a list of device strings to run the collective on.
|
||||
group_size: total number of devices globally that will be doing this same
|
||||
gathering. The gathering will actually include the corresponding tensors
|
||||
at all these workers.
|
||||
collective_keys: a CollectiveKeys object.
|
||||
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
|
||||
range [0, rank(value)).
|
||||
communication_hint: string providing hint to runtime for choosing collective
|
||||
implementation. Available options are `AUTO`, `NCCL`, and `RING`.
|
||||
control_inputs: if not None, add control edges between control_inputs and
|
||||
(index-wise) corresponding collective_gather tensors
|
||||
timeout: a float or None. The timeout in seconds.
|
||||
Args:
|
||||
input_tensor_packs: a list of lists of dense tensors.
|
||||
communication_hint: string providing hint to runtime for choosing
|
||||
collective implementation.
|
||||
timeout: a float. The timeout in seconds.
|
||||
|
||||
Returns:
|
||||
An array of final tensors, one per device, computed by the full gather.
|
||||
"""
|
||||
if len(input_tensors) != len(devices):
|
||||
raise ValueError(
|
||||
'collective requires one input tensor for each device, %d != %d' %
|
||||
(len(input_tensors), len(devices)))
|
||||
Returns:
|
||||
A flat list of reduced tensors.
|
||||
"""
|
||||
outputs = []
|
||||
for pack in input_tensor_packs:
|
||||
# By placing all CollectiveReduce ops in a batch under single name scope,
|
||||
# we ensure they will be picked up by the `ScopedAllocator` grappler
|
||||
# optimizer and packed into a single all-reduce.
|
||||
with ops.name_scope('allreduce'):
|
||||
# TODO(b/169168846): inserts a parallel all_gather to verify packings
|
||||
# are the same on each replica.
|
||||
for input_tensor in pack:
|
||||
if communication_hint == 'NCCL' and outputs:
|
||||
control_input = outputs[-1]
|
||||
else:
|
||||
control_input = None
|
||||
outputs.append(
|
||||
self.all_reduce(input_tensor, control_input, communication_hint,
|
||||
timeout))
|
||||
return outputs
|
||||
|
||||
if group_size < 2:
|
||||
return input_tensors
|
||||
group_key = collective_keys.get_group_key(devices)
|
||||
instance_key_tensor = collective_keys.get_op_instance_key()
|
||||
instance_key_shape = collective_keys.get_op_instance_key()
|
||||
def all_gather(self,
|
||||
input_tensor,
|
||||
axis,
|
||||
communication_hint='AUTO',
|
||||
timeout=0):
|
||||
"""All-gather a dense tensor.
|
||||
|
||||
out_tensors = []
|
||||
for idx, input_tensor in enumerate(input_tensors):
|
||||
with ops.device(devices[idx]), ops.control_dependencies(
|
||||
_control_input(devices, control_inputs, idx)):
|
||||
This method must be called inside a tf.function.
|
||||
|
||||
Args:
|
||||
input_tensor: a dense tensor. It must have the same rank on all replicas,
|
||||
and dimensions other than `axis` need to be the same as well.
|
||||
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
|
||||
range [0, rank(value)).
|
||||
communication_hint: string providing hint to runtime for choosing
|
||||
collective implementation. Available options are `AUTO`, `NCCL`, and
|
||||
`RING`.
|
||||
timeout: a float. The timeout in seconds.
|
||||
|
||||
Returns:
|
||||
The gathered Tensor.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if called in eager mode.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('all_gather in eager mode is not supported')
|
||||
|
||||
instance_key_tensor = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
instance_key_shape = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
with ops.device(self._device):
|
||||
# 1. Transpose
|
||||
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
|
||||
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
|
||||
@ -394,8 +401,8 @@ def build_collective_gather(input_tensors,
|
||||
# 2. Pad
|
||||
gathered_shape = collective_ops.all_gather(
|
||||
array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
|
||||
group_size,
|
||||
group_key,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key_shape,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
@ -406,8 +413,8 @@ def build_collective_gather(input_tensors,
|
||||
# 3. Gather
|
||||
gather_padded_out_tensor = collective_ops.all_gather(
|
||||
padded_input_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key_tensor,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
@ -424,93 +431,56 @@ def build_collective_gather(input_tensors,
|
||||
(math_ops.range(1, axis + 1), [0],
|
||||
math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
|
||||
axis=0)
|
||||
out_tensor = array_ops.transpose(out_tensor_t, perm=perm_after)
|
||||
out_tensors.append(out_tensor)
|
||||
return out_tensors
|
||||
return array_ops.transpose(out_tensor_t, perm=perm_after)
|
||||
|
||||
def all_reduce_indexed_slices(self,
|
||||
input_slices,
|
||||
communication_hint='AUTO',
|
||||
timeout=0):
|
||||
"""All-reduce an IndexedSlices.
|
||||
|
||||
def _pad_util(input_tensor, full_axis_dim):
|
||||
"""Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
|
||||
missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
|
||||
tensor_rank = array_ops.rank(input_tensor)
|
||||
paddings_axis = [[0, missing_axis_dim]]
|
||||
paddings = array_ops.concat([
|
||||
paddings_axis,
|
||||
array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
|
||||
],
|
||||
axis=0)
|
||||
padded_input_tensor = array_ops.pad(input_tensor, paddings)
|
||||
return padded_input_tensor
|
||||
This method must be called inside a tf.function.
|
||||
|
||||
Args:
|
||||
input_slices: an IndexedSlices.
|
||||
communication_hint: string providing hint to runtime for choosing
|
||||
collective implementation.
|
||||
timeout: a float. The timeout in seconds.
|
||||
|
||||
def build_collective_gather_indexed_slices(input_slices_list,
|
||||
devices,
|
||||
group_size,
|
||||
collective_keys,
|
||||
communication_hint='AUTO',
|
||||
control_inputs=None,
|
||||
timeout=None):
|
||||
"""Build a subgraph that all-gathers IndexedSlices using the collective Op.
|
||||
Returns:
|
||||
The reduced IndexedSlices.
|
||||
|
||||
This method must be called in graph mode or inside a tf.function.
|
||||
Raises:
|
||||
RuntimeError: if called in eager mode.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
'all_reduce_indexed_slices in eager mode is not supported')
|
||||
|
||||
Args:
|
||||
input_slices_list: a list of IndexedSlices within a single worker graph that
|
||||
are to be gathered together; must be one per device.
|
||||
devices: a list of device strings to run the collective on.
|
||||
group_size: total number of devices globally that will be doing this same
|
||||
gathering. The gathering will actually include the corresponding tensors
|
||||
at all these workers.
|
||||
collective_keys: a CollectiveKeys object.
|
||||
communication_hint: string providing hint to runtime for choosing collective
|
||||
implementation.
|
||||
control_inputs: if not None, add control edges between control_inputs and
|
||||
(index-wise) corresponding collective_reduce tensors
|
||||
timeout: a float or None. The timeout in seconds.
|
||||
gather_length_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
gather_indices_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
gather_values_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
reduce_densified_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
|
||||
Returns:
|
||||
An array of final IndexedSlices, one per device, computed by the full
|
||||
gather.
|
||||
|
||||
Raises:
|
||||
ValueError: if control_inputs is not None and doesn't match the length and
|
||||
devices of inputs.
|
||||
"""
|
||||
assert not context.executing_eagerly(), (
|
||||
'build_collective_gather_indexed_slices can only be called in graph mode'
|
||||
' or inside tf.function')
|
||||
if len(input_slices_list) != len(devices):
|
||||
raise ValueError(
|
||||
'collective requires one input IndexedSlice for each device, %d != %d' %
|
||||
(len(input_slices_list), len(devices)))
|
||||
|
||||
if group_size < 2:
|
||||
return input_slices_list
|
||||
|
||||
group_key = collective_keys.get_group_key(devices)
|
||||
gather_length_key = collective_keys.get_op_instance_key()
|
||||
gather_indices_key = collective_keys.get_op_instance_key()
|
||||
gather_values_key = collective_keys.get_op_instance_key()
|
||||
reduce_densified_key = collective_keys.get_op_instance_key()
|
||||
|
||||
# Current CollectiveAllGather implementations require input IndexedSlices to
|
||||
# have consistent length across the board, we handle the reduction of
|
||||
# IndexedSlices as follows:
|
||||
# 1. Gather the lengths of IndexedSlices from all participants.
|
||||
# 2. If they have consistent length, apply all_gather.
|
||||
# 3. Otherwise convert IndexedSlices to dense tensors and apply
|
||||
# all_reduce.
|
||||
out_slices_list = []
|
||||
for idx, input_slices in enumerate(input_slices_list):
|
||||
# pylint: disable = cell-var-from-loop
|
||||
with ops.device(devices[idx]):
|
||||
# Current CollectiveAllGather implementations require input IndexedSlices to
|
||||
# have consistent length across the board, we handle the reduction of
|
||||
# IndexedSlices as follows:
|
||||
# 1. Gather the lengths of IndexedSlices from all participants.
|
||||
# 2. If they have consistent length, apply all_gather.
|
||||
# 3. Otherwise convert IndexedSlices to dense tensors and apply
|
||||
# all_reduce.
|
||||
with ops.device(self._device):
|
||||
|
||||
def all_gather():
|
||||
"""Use all_gather to aggregate `IndexedSlices`."""
|
||||
all_values = collective_ops.all_gather(
|
||||
input_slices.values,
|
||||
group_size,
|
||||
group_key,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
gather_values_key,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
@ -519,8 +489,8 @@ def build_collective_gather_indexed_slices(input_slices_list,
|
||||
with ops.control_dependencies(control):
|
||||
all_indices = collective_ops.all_gather(
|
||||
input_slices.indices,
|
||||
group_size,
|
||||
group_key,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
gather_indices_key,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
@ -534,8 +504,8 @@ def build_collective_gather_indexed_slices(input_slices_list,
|
||||
densified = ops.convert_to_tensor(input_slices)
|
||||
reduced = collective_ops.all_reduce(
|
||||
densified,
|
||||
group_size,
|
||||
group_key,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
reduce_densified_key,
|
||||
'Add',
|
||||
'Id', [0],
|
||||
@ -550,23 +520,18 @@ def build_collective_gather_indexed_slices(input_slices_list,
|
||||
dense_shape=input_slices.dense_shape)
|
||||
|
||||
length = array_ops.shape(input_slices.indices)
|
||||
with ops.control_dependencies(
|
||||
_control_input(input_slices, control_inputs, idx)):
|
||||
all_lengths = collective_ops.all_gather(
|
||||
length,
|
||||
group_size,
|
||||
group_key,
|
||||
gather_length_key,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
out_slices = control_flow_ops.cond(
|
||||
all_lengths = collective_ops.all_gather(
|
||||
length,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
gather_length_key,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(
|
||||
math_ops.reduce_max(all_lengths),
|
||||
math_ops.reduce_min(all_lengths)), all_gather,
|
||||
densify_and_all_reduce)
|
||||
out_slices_list.append(out_slices)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
return out_slices_list
|
||||
|
||||
|
||||
def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
|
||||
@ -653,56 +618,35 @@ def stitch_values(values_and_indices_list):
|
||||
return result
|
||||
|
||||
|
||||
def per_replica_num_elements(per_replica):
|
||||
"""Returns the static number of elements of one replica.
|
||||
def group_by_size(input_tensors, bytes_per_pack):
|
||||
"""Groups `input_tensors` into chunks of `bytes_per_pack`.
|
||||
|
||||
Args:
|
||||
per_replica: A PerReplica of Tensor or IndexedSlices.
|
||||
|
||||
Returns:
|
||||
Number of elements. None if some replica has a different or unknown shape.
|
||||
"""
|
||||
|
||||
values = per_replica._values # pylint: disable=protected-access
|
||||
s0 = values[0].shape
|
||||
for v in values:
|
||||
assert not isinstance(v, ops.IndexedSlices)
|
||||
if v.shape != s0:
|
||||
return None
|
||||
return s0.num_elements()
|
||||
|
||||
|
||||
def pack_by_size(per_replica_list, bytes_per_pack):
|
||||
"""Packs `per_replica_list` into chunks of `bytes_per_pack`.
|
||||
|
||||
The method preserves the original order of `per_replica_list`. The packing is
|
||||
The method preserves the original order of `input_tensors`. The grouping is
|
||||
best effort, each pack could have more or less bytes than `bytes_per_pack`.
|
||||
It only packs values with known shape. Note that, the usage is different from
|
||||
`cross_device_ops._pack_tensors`, this function is intended to work with the
|
||||
ScopeAllocator style batching used in `CollectiveAllReduce`.
|
||||
It only groups values with known shape.
|
||||
|
||||
Args:
|
||||
per_replica_list: A list of PerReplica.
|
||||
bytes_per_pack: Bytes per pack.
|
||||
input_tensors: a list of Tensor.
|
||||
bytes_per_pack: an integer.
|
||||
|
||||
Returns:
|
||||
A list of packs of PerReplica. All values are packed into one pack if
|
||||
`bytes_per_pack` is zero or any of the value has unknown shape.
|
||||
A list of packs of Tensor. All values are grouped into one pack if
|
||||
`bytes_per_pack` is zero or any of the value has unknown shape.
|
||||
"""
|
||||
|
||||
if bytes_per_pack == 0:
|
||||
return [per_replica_list]
|
||||
return [input_tensors]
|
||||
packs = []
|
||||
last_pack_size = 0
|
||||
for value in per_replica_list:
|
||||
num_elements = per_replica_num_elements(value)
|
||||
for value in input_tensors:
|
||||
num_elements = value.shape.num_elements()
|
||||
if num_elements is None:
|
||||
# Can't pack values with unknown shape.
|
||||
logging.warning(
|
||||
'not packing values due to the unknown or inconsistent shape of %s',
|
||||
value)
|
||||
return [per_replica_list]
|
||||
size = num_elements * value._primary.dtype.size # pylint: disable=protected-access
|
||||
return [input_tensors]
|
||||
size = num_elements * value.dtype.size
|
||||
# Try to keep each pack as close to bytes_per_pack as possible, while each
|
||||
# pack is at least bytes_per_pack large. I.E. we err on the side of having
|
||||
# few but large packs.
|
||||
@ -714,24 +658,15 @@ def pack_by_size(per_replica_list, bytes_per_pack):
|
||||
return packs
|
||||
|
||||
|
||||
def _control_input(devices, control_inputs, idx):
|
||||
"""Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies.
|
||||
|
||||
This is a helper function for building collective ops.
|
||||
|
||||
Args:
|
||||
devices: a list of device strings the collective run on.
|
||||
control_inputs: a list or None.
|
||||
idx: the index into `inputs` and `control_inputs`.
|
||||
|
||||
Returns:
|
||||
A one item list of the `idx`-th element of `control_inputs`, or an empty
|
||||
list if `control_inputs` is None.
|
||||
"""
|
||||
if control_inputs is None:
|
||||
return []
|
||||
if len(control_inputs) != len(devices):
|
||||
raise ValueError(
|
||||
'control_inputs must match the length of the devices, %s != %s' %
|
||||
(len(control_inputs), len(devices)))
|
||||
return [control_inputs[idx]]
|
||||
def _pad_util(input_tensor, full_axis_dim):
|
||||
"""Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
|
||||
missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
|
||||
tensor_rank = array_ops.rank(input_tensor)
|
||||
paddings_axis = [[0, missing_axis_dim]]
|
||||
paddings = array_ops.concat([
|
||||
paddings_axis,
|
||||
array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
|
||||
],
|
||||
axis=0)
|
||||
padded_input_tensor = array_ops.pad(input_tensor, paddings)
|
||||
return padded_input_tensor
|
||||
|
@ -23,7 +23,6 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import values as value_lib
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -114,11 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
|
||||
device_util.resolve(destination), device_util.resolve(result.device))
|
||||
|
||||
|
||||
class PackBySizeTest(test.TestCase):
|
||||
|
||||
def assertShape(self, per_replica, shape):
|
||||
for v in per_replica._values: # pylint: disable=protected-access
|
||||
self.assertEqual(v.shape, shape)
|
||||
class GroupBySizeTest(test.TestCase):
|
||||
|
||||
def testPreferLargerPack(self):
|
||||
# Each packs except the last one should be equal or larger than
|
||||
@ -133,49 +128,38 @@ class PackBySizeTest(test.TestCase):
|
||||
# size = 1 * 4 = 4
|
||||
array_ops.ones([1], dtype=dtypes.int32),
|
||||
]
|
||||
per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
per_replica_values, bytes_per_pack=200)
|
||||
packs = cross_device_utils.group_by_size(values, bytes_per_pack=200)
|
||||
self.assertLen(packs, 2)
|
||||
self.assertLen(packs[0], 3)
|
||||
self.assertShape(packs[0][0], [2, 4, 4])
|
||||
self.assertShape(packs[0][1], [8])
|
||||
self.assertShape(packs[0][2], [10, 10])
|
||||
self.assertEqual(packs[0][0].shape, [2, 4, 4])
|
||||
self.assertEqual(packs[0][1].shape, [8])
|
||||
self.assertEqual(packs[0][2].shape, [10, 10])
|
||||
self.assertLen(packs[1], 1)
|
||||
self.assertShape(packs[1][0], [1])
|
||||
self.assertEqual(packs[1][0].shape, [1])
|
||||
|
||||
def testZeroBytesPerPack(self):
|
||||
values = [
|
||||
array_ops.ones([1], dtype=dtypes.float32),
|
||||
array_ops.ones([2], dtype=dtypes.float32),
|
||||
]
|
||||
per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
per_replica_values, bytes_per_pack=0)
|
||||
packs = cross_device_utils.group_by_size(values, bytes_per_pack=0)
|
||||
self.assertLen(packs, 1)
|
||||
self.assertLen(packs[0], 2)
|
||||
self.assertShape(packs[0][0], [1])
|
||||
self.assertShape(packs[0][1], [2])
|
||||
self.assertEqual(packs[0][0].shape, [1])
|
||||
self.assertEqual(packs[0][1].shape, [2])
|
||||
|
||||
def testUnknownShape(self):
|
||||
def create_placeholder(shape, dtype):
|
||||
with ops.Graph().as_default():
|
||||
return array_ops.placeholder(dtype=dtype, shape=shape)
|
||||
|
||||
per_replica_values = [
|
||||
value_lib.PerReplica([
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
]),
|
||||
value_lib.PerReplica([
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
create_placeholder([None, 10], dtype=dtypes.float32),
|
||||
]),
|
||||
values = [
|
||||
array_ops.ones([10, 10], dtype=dtypes.float32),
|
||||
create_placeholder([None, 10], dtype=dtypes.float32),
|
||||
]
|
||||
packs = cross_device_utils.pack_by_size(
|
||||
per_replica_values, bytes_per_pack=1)
|
||||
packs = cross_device_utils.group_by_size(values, bytes_per_pack=1)
|
||||
self.assertLen(packs, 1)
|
||||
self.assertEqual(packs[0], per_replica_values)
|
||||
self.assertEqual(packs[0], values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -27,7 +27,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.distribute import cluster_resolver
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
|
||||
from tensorflow.python.distribute import collective_util
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
@ -451,6 +451,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
# Reusing keys is not supported well. So we have to give a different
|
||||
# collective key base for different tests.
|
||||
CollectiveAllReduceTest.collective_key_base += 100000
|
||||
mwms_lib.CollectiveAllReduceStrategy._collective_key_base = (
|
||||
CollectiveAllReduceTest.collective_key_base)
|
||||
|
||||
def _get_test_objects(self,
|
||||
task_type,
|
||||
@ -460,10 +462,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
use_strategy_object=False,
|
||||
local_mode=False):
|
||||
collective_keys = cross_device_utils.CollectiveKeys(
|
||||
group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
|
||||
op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
|
||||
variable_instance_key_start=10000 +
|
||||
CollectiveAllReduceTest.collective_key_base)
|
||||
group_key_start=10 + CollectiveAllReduceTest.collective_key_base)
|
||||
if local_mode:
|
||||
if num_gpus:
|
||||
devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
|
||||
@ -471,13 +470,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
devices = ["/device:CPU:0"]
|
||||
|
||||
if use_strategy_object:
|
||||
strategy = (
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy
|
||||
._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access
|
||||
strategy.extended._collective_keys = collective_keys
|
||||
strategy.extended._cross_device_ops._collective_keys = collective_keys
|
||||
strategy.extended._host_cross_device_ops._collective_keys = (
|
||||
collective_keys)
|
||||
strategy = (mwms_lib.CollectiveAllReduceStrategy
|
||||
._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access
|
||||
return strategy, devices, ""
|
||||
else:
|
||||
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
@ -507,10 +501,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
task_type=task_type,
|
||||
task_id=task_id,
|
||||
num_accelerators={"GPU": num_gpus})
|
||||
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
|
||||
strategy = mwms_lib.CollectiveAllReduceStrategy(
|
||||
cluster_resolver=resolver, communication=communication)
|
||||
strategy.extended._collective_keys = collective_keys
|
||||
strategy.extended._cross_device_ops._collective_keys = collective_keys
|
||||
return (strategy, devices,
|
||||
"grpc://" + self._cluster_spec[task_type][task_id])
|
||||
else:
|
||||
|
@ -689,7 +689,7 @@ cuda_py_test(
|
||||
name = "multi_worker_test",
|
||||
srcs = ["multi_worker_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 32,
|
||||
shard_count = 2,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_oss", # TODO(b/130369494): Investigate why it times out on OSS.
|
||||
|
@ -24,9 +24,8 @@ import numpy as np
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
@ -80,7 +79,7 @@ def create_test_objects(cluster_spec=None,
|
||||
ClusterSpec({}), num_accelerators={'GPU': num_gpus})
|
||||
target = ''
|
||||
|
||||
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
|
||||
strategy = mwms_lib.CollectiveAllReduceStrategy(
|
||||
cluster_resolver=cluster_resolver)
|
||||
sess_config = strategy.update_config_proto(sess_config)
|
||||
|
||||
@ -95,9 +94,7 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
def setUp(self):
|
||||
# We use a different key_base for each test so that collective keys won't be
|
||||
# reused.
|
||||
# TODO(yuefengz, ayushd): enable it to reuse collective keys in different
|
||||
# tests.
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
|
||||
mwms_lib.CollectiveAllReduceStrategy._collective_key_base += 100000
|
||||
super(CollectiveAllReduceStrategyTestBase, self).setUp()
|
||||
|
||||
def _get_test_object(self, task_type, task_id, num_gpus=0):
|
||||
@ -106,18 +103,6 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
task_type=task_type,
|
||||
task_id=task_id,
|
||||
num_gpus=num_gpus)
|
||||
|
||||
collective_keys = cross_device_utils.CollectiveKeys(
|
||||
group_key_start=10 +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base,
|
||||
op_instance_key_start=100 +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base,
|
||||
variable_instance_key_start=10000 +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base)
|
||||
strategy.extended._collective_keys = collective_keys
|
||||
strategy.extended._cross_device_ops._collective_keys = collective_keys
|
||||
strategy.extended._host_cross_device_ops._collective_keys = collective_keys
|
||||
|
||||
return strategy, target, session_config
|
||||
|
||||
def _test_complex_model(self, task_type, task_id, num_gpus):
|
||||
|
Loading…
Reference in New Issue
Block a user