Explicitly take the set of devices in CollectiveAllReduce
We used to infer the devices from the inputs, but sometimes the inputs don't have device placement. E.g. when passing into or returning from tf.function, the device placement may be lost. Instead of inferring from the inputs we should just be explicit about the collective devices. PiperOrigin-RevId: 316743112 Change-Id: I2f6995f2f4cc86864723e203deb7562363cdbc38
This commit is contained in:
parent
89df3ddcd5
commit
1a342fb760
@ -1029,14 +1029,21 @@ cuda_py_test(
|
||||
":collective_util",
|
||||
":combinations",
|
||||
":cross_device_ops",
|
||||
":mirrored_strategy",
|
||||
":cross_device_utils",
|
||||
":device_util",
|
||||
":multi_worker_test_base",
|
||||
":multi_worker_util",
|
||||
":reduce_util",
|
||||
":strategy_combinations",
|
||||
":values",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:collective_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:kernels",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
|
@ -175,7 +175,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
self._communication = communication
|
||||
self._initialize_strategy(self._cluster_resolver)
|
||||
self._cfer_fn_cache = weakref.WeakKeyDictionary()
|
||||
assert isinstance(self._get_cross_device_ops(),
|
||||
assert isinstance(self._cross_device_ops,
|
||||
cross_device_ops_lib.CollectiveAllReduce)
|
||||
|
||||
def _initialize_strategy(self, cluster_resolver):
|
||||
@ -217,12 +217,18 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
|
||||
|
||||
self._collective_keys = cross_device_utils.CollectiveKeys()
|
||||
# TODO(yuefengz): remove num_gpus_per_worker from CollectiveAllReduce.
|
||||
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
num_workers=self._num_workers,
|
||||
num_gpus_per_worker=num_gpus,
|
||||
devices=local_devices,
|
||||
group_size=len(local_devices),
|
||||
collective_keys=self._collective_keys,
|
||||
communication=self._communication)
|
||||
# CrossDeviceOps for per host tensors.
|
||||
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
devices=[self._worker_device],
|
||||
group_size=self._num_workers,
|
||||
collective_keys=self._collective_keys,
|
||||
communication=cross_device_ops_lib.CollectiveCommunication.RING,
|
||||
)
|
||||
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
|
||||
local_devices)
|
||||
|
||||
@ -324,10 +330,17 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
|
||||
self._collective_keys = cross_device_utils.CollectiveKeys()
|
||||
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
num_workers=self._num_workers,
|
||||
num_gpus_per_worker=num_gpus,
|
||||
devices=local_devices,
|
||||
group_size=len(local_devices) * self._num_workers,
|
||||
collective_keys=self._collective_keys,
|
||||
communication=self._communication)
|
||||
# CrossDeviceOps for per host tensors.
|
||||
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
devices=[self._worker_device],
|
||||
group_size=self._num_workers,
|
||||
collective_keys=self._collective_keys,
|
||||
communication=cross_device_ops_lib.CollectiveCommunication.RING,
|
||||
)
|
||||
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
|
||||
local_devices)
|
||||
host_device = device_util.get_host_for_device(self._worker_device)
|
||||
@ -474,7 +487,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
num_accelerators={"GPU": self._num_gpus_per_worker},
|
||||
rpc_layer=self._rpc_layer)
|
||||
self._initialize_multi_worker(cluster_resolver)
|
||||
assert isinstance(self._get_cross_device_ops(),
|
||||
assert isinstance(self._cross_device_ops,
|
||||
cross_device_ops_lib.CollectiveAllReduce)
|
||||
|
||||
if session_config:
|
||||
@ -518,6 +531,22 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
|
||||
return updated_config
|
||||
|
||||
def _get_cross_device_ops(self, value):
|
||||
# CollectiveAllReduce works on a predefined set of devices. In most cases
|
||||
# they should be the compute devices, but certain use cases may reduce host
|
||||
# tensors as well (e.g. early stopping). We infer the cross_device_ops to
|
||||
# use based on the number of devices, since inputs don't always have device
|
||||
# annotations. The compute devices one is preferred since we can potentially
|
||||
# leverage NCCL.
|
||||
if isinstance(value, values.DistributedValues):
|
||||
num_devices = len(value._values) # pylint: disable=protected-access
|
||||
else:
|
||||
num_devices = 1
|
||||
if num_devices == len(self.worker_devices):
|
||||
return self._cross_device_ops
|
||||
else:
|
||||
return self._host_cross_device_ops
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
if (isinstance(value, values.Mirrored) and
|
||||
reduce_op == reduce_util.ReduceOp.MEAN):
|
||||
@ -538,7 +567,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
# be 0.
|
||||
return cross_device_ops_lib.reduce_non_distributed_value(
|
||||
reduce_op, value, destinations, len(self.worker_devices))
|
||||
return self._get_cross_device_ops().reduce(
|
||||
return self._get_cross_device_ops(value).reduce(
|
||||
reduce_op,
|
||||
value,
|
||||
destinations=destinations,
|
||||
|
@ -116,7 +116,8 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
variable_instance_key_start=10000 +
|
||||
CollectiveAllReduceStrategyTestBase.collective_key_base)
|
||||
strategy.extended._collective_keys = collective_keys
|
||||
strategy.extended._cross_device_ops._collective_keys = (collective_keys)
|
||||
strategy.extended._cross_device_ops._collective_keys = collective_keys
|
||||
strategy.extended._host_cross_device_ops._collective_keys = collective_keys
|
||||
|
||||
return strategy, target, session_config
|
||||
|
||||
|
@ -165,7 +165,8 @@ def get_devices_from(destinations):
|
||||
|
||||
|
||||
def _devices_match(left, right):
|
||||
return set(get_devices_from(left)) == set(get_devices_from(right))
|
||||
return left is right or set(get_devices_from(left)) == set(
|
||||
get_devices_from(right))
|
||||
|
||||
|
||||
def _all_devices_match(value_destination_pairs):
|
||||
@ -936,20 +937,24 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_workers=1,
|
||||
num_gpus_per_worker=0,
|
||||
devices,
|
||||
group_size,
|
||||
collective_keys=None,
|
||||
communication=CollectiveCommunication.AUTO):
|
||||
"""Initializes the object.
|
||||
|
||||
Args:
|
||||
num_workers: number of workers in the between-graph replicated training.
|
||||
num_gpus_per_worker: number of GPUs per worker.
|
||||
devices: a list of device strings to run collectives on.
|
||||
group_size: the global group size. For between-graph replicated training
|
||||
it's the total number of devices across all workers.
|
||||
collective_keys: an optional CollectiveKey object.
|
||||
communication: indicates which collective communication to use.
|
||||
"""
|
||||
self._num_workers = num_workers
|
||||
self._num_gpus_per_worker = num_gpus_per_worker
|
||||
if group_size % len(devices) > 0:
|
||||
raise ValueError("group_size must be divisible by the number of devices.")
|
||||
|
||||
self._devices = tuple(device_util.canonicalize(d) for d in devices)
|
||||
self._group_size = group_size
|
||||
self._collective_keys = (collective_keys or
|
||||
cross_device_utils.CollectiveKeys())
|
||||
self._communication = communication
|
||||
@ -963,15 +968,15 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
# async executor operations are still executed sequentially. In graph or
|
||||
# function building, the executors are not used.
|
||||
self._executors = []
|
||||
for _ in range(self._num_gpus_per_worker or 1):
|
||||
# If num_gpus_per_worker is zero, we assume there's only one device (CPU).
|
||||
for _ in range(len(devices)):
|
||||
self._executors.append(executor.new_executor(enable_async=True))
|
||||
|
||||
super(CollectiveAllReduce, self).__init__()
|
||||
|
||||
@property
|
||||
def _num_between_graph_workers(self):
|
||||
return self._num_workers
|
||||
# Currently we only support equal number of devices on each worker.
|
||||
return self._group_size / len(self._devices)
|
||||
|
||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||
experimental_hints):
|
||||
@ -979,8 +984,7 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
experimental_hints)[0]
|
||||
devices = get_devices_from(destinations)
|
||||
|
||||
if (isinstance(all_reduced, value_lib.Mirrored) and
|
||||
(all_reduced._devices == devices)): # pylint: disable=protected-access
|
||||
if _devices_match(per_replica_value, destinations):
|
||||
return all_reduced
|
||||
|
||||
# Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
|
||||
@ -1069,14 +1073,16 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
|
||||
if batch_size > 1:
|
||||
logging.info(
|
||||
"Collective batch_all_reduce: %d all-reduces, num_workers = %d, "
|
||||
"communication_hint = %s, num_packs = %d", batch_size,
|
||||
self._num_workers, communication, len(packs))
|
||||
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
|
||||
"group_size = %d, communication_hint = %s, num_packs = %d",
|
||||
batch_size, len(self._devices), self._group_size, communication,
|
||||
len(packs))
|
||||
else:
|
||||
logging.log_first_n(
|
||||
logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
|
||||
"num_workers = %d, communication_hint = %s, num_packs = %d" %
|
||||
(batch_size, self._num_workers, communication, len(packs)), 10)
|
||||
"num_devices = %d, group_size = %d, communication_hint = %s, "
|
||||
"num_packs = %d" % (batch_size, len(
|
||||
self._devices), self._group_size, communication, len(packs)), 10)
|
||||
|
||||
reduced_values = []
|
||||
for pack in packs:
|
||||
@ -1094,21 +1100,25 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
control_inputs = None
|
||||
reduced_values.append(
|
||||
cross_device_utils.build_collective_reduce(
|
||||
per_replica.values, self._num_workers,
|
||||
self._collective_keys, "Add", "Id", communication,
|
||||
control_inputs, executors=self._executors))
|
||||
per_replica.values,
|
||||
self._devices,
|
||||
self._group_size,
|
||||
self._collective_keys,
|
||||
"Add",
|
||||
"Id",
|
||||
communication,
|
||||
control_inputs,
|
||||
executors=self._executors))
|
||||
|
||||
mirrored = []
|
||||
# Reverse the order of reduced value to recover the order in the input.
|
||||
for value in reversed(reduced_values):
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
# Assume each worker has the same number of replicas.
|
||||
num_replicas = len(value) * self._num_workers
|
||||
for i, v in enumerate(value):
|
||||
with ops.device(v.device):
|
||||
value[i] = v / num_replicas
|
||||
mirrored.append(distribute_utils.regroup(value,
|
||||
wrap_class=value_lib.Mirrored))
|
||||
value[i] = v / self._group_size
|
||||
mirrored.append(
|
||||
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
|
||||
return mirrored
|
||||
|
||||
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
|
||||
@ -1116,8 +1126,8 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
|
||||
logging.log_first_n(
|
||||
logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
|
||||
"%d all-reduces, num_workers = %d" %
|
||||
(len(per_replica_values), self._num_workers), 10)
|
||||
"%d all-reduces, group_size = %d" %
|
||||
(len(per_replica_values), self._group_size), 10)
|
||||
|
||||
# Pass self._communication to the runtime as a communication hint.
|
||||
communication_hint = self._communication.value
|
||||
@ -1133,25 +1143,24 @@ class CollectiveAllReduce(CrossDeviceOps):
|
||||
for per_replica in per_replica_values:
|
||||
gathered_values.append(
|
||||
cross_device_utils.build_collective_gather_indexed_slices(
|
||||
per_replica.values, self._num_workers, self._collective_keys,
|
||||
communication_hint))
|
||||
per_replica.values, self._devices, self._group_size,
|
||||
self._collective_keys, communication_hint))
|
||||
|
||||
mirrored = []
|
||||
for value in gathered_values:
|
||||
if reduce_op == reduce_util.ReduceOp.MEAN:
|
||||
# Assume each worker has the same number of replicas.
|
||||
num_replicas = len(value) * self._num_workers
|
||||
for i, v in enumerate(value):
|
||||
with ops.device(v.device):
|
||||
value[i].values = value[i].values / num_replicas
|
||||
mirrored.append(distribute_utils.regroup(value,
|
||||
wrap_class=value_lib.Mirrored))
|
||||
value[i].values = value[i].values / self._group_size
|
||||
mirrored.append(
|
||||
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
|
||||
return mirrored
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
# distribute_coordinator deep-copies the strategy object, so
|
||||
# CollectiveAllReduce needs to support deep copy as well.
|
||||
return CollectiveAllReduce(self._num_workers, self._num_gpus_per_worker,
|
||||
return CollectiveAllReduce(self._devices, self._group_size,
|
||||
self._collective_keys, self._communication)
|
||||
|
||||
|
||||
|
@ -26,6 +26,7 @@ import time
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.distribute import cluster_resolver
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import collective_util
|
||||
from tensorflow.python.distribute import combinations
|
||||
@ -34,10 +35,12 @@ from tensorflow.python.distribute import cross_device_utils
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribute_utils
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import values as value_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import kernels
|
||||
@ -125,7 +128,10 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(ops.convert_to_tensor(left)),
|
||||
self.evaluate(ops.convert_to_tensor(right)))
|
||||
|
||||
def _assert_mirrored_equal(self, left_list, right_list, sess,
|
||||
def _assert_mirrored_equal(self,
|
||||
left_list,
|
||||
right_list,
|
||||
sess=None,
|
||||
run_options=None):
|
||||
if not isinstance(left_list, list):
|
||||
left_list, right_list = [left_list], [right_list]
|
||||
@ -142,17 +148,14 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
left, right = [left], [right]
|
||||
|
||||
for left_value, right_value in zip(left, right):
|
||||
self.assertEqual(left_value.device, right_value.device)
|
||||
self.assertEqual(
|
||||
device_util.resolve(left_value.device),
|
||||
device_util.resolve(right_value.device))
|
||||
|
||||
# Densify IndexedSlices.
|
||||
left = [ops.convert_to_tensor(v) for v in left]
|
||||
right = [ops.convert_to_tensor(v) for v in right]
|
||||
if context.executing_eagerly():
|
||||
# Optional args in session run are not supported when eager execution
|
||||
# is enabled.
|
||||
assert run_options is None
|
||||
left, right = sess.run((left, right))
|
||||
else:
|
||||
if not context.executing_eagerly():
|
||||
left, right = sess.run((left, right), options=run_options)
|
||||
for left_value, right_value in zip(left, right):
|
||||
self.assertAllEqual(left_value, right_value)
|
||||
@ -525,8 +528,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
return strategy, devices, ""
|
||||
else:
|
||||
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
1,
|
||||
num_gpus,
|
||||
devices=devices,
|
||||
group_size=len(devices),
|
||||
collective_keys=collective_keys,
|
||||
communication=communication)
|
||||
return collective_all_reduce_ops, devices, ""
|
||||
@ -545,26 +548,28 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
]
|
||||
|
||||
if use_strategy_object:
|
||||
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
|
||||
communication=communication)
|
||||
strategy.configure(
|
||||
cluster_spec=self._cluster_spec,
|
||||
resolver = cluster_resolver.SimpleClusterResolver(
|
||||
cluster_spec=multi_worker_util.normalize_cluster_spec(
|
||||
self._cluster_spec),
|
||||
task_type=task_type,
|
||||
task_id=task_id)
|
||||
task_id=task_id,
|
||||
num_accelerators={"GPU": num_gpus})
|
||||
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
|
||||
cluster_resolver=resolver, communication=communication)
|
||||
strategy.extended._collective_keys = collective_keys
|
||||
strategy.extended._cross_device_ops._collective_keys = collective_keys
|
||||
return (strategy, devices,
|
||||
"grpc://" + self._cluster_spec[task_type][task_id])
|
||||
else:
|
||||
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
|
||||
NUM_WORKERS,
|
||||
num_gpus,
|
||||
devices=devices,
|
||||
group_size=len(devices) * NUM_WORKERS,
|
||||
collective_keys=collective_keys,
|
||||
communication=communication)
|
||||
return (collective_all_reduce_ops, devices,
|
||||
"grpc://" + self._cluster_spec[task_type][task_id])
|
||||
|
||||
def _assert_mirrored_equal(self, left_list, right_list, sess):
|
||||
def _assert_mirrored_equal(self, left_list, right_list, sess=None):
|
||||
if context.executing_eagerly():
|
||||
run_options = None
|
||||
else:
|
||||
@ -895,6 +900,45 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
self.assertAllEqual(reduced[1].values, [4.0, 4.0])
|
||||
t.join()
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
required_gpus=2,
|
||||
mode="eager",
|
||||
communication=[
|
||||
CollectiveCommunication.NCCL, CollectiveCommunication.RING
|
||||
]))
|
||||
def testInputsAreFunctionArgs(self, communication):
|
||||
# Function inputs don't have device placement.
|
||||
hints = collective_util.Hints(bytes_per_pack=1)
|
||||
collective, devices, _ = self._get_test_objects(
|
||||
None,
|
||||
None,
|
||||
num_gpus=2,
|
||||
communication=communication,
|
||||
use_strategy_object=False,
|
||||
local_mode=True)
|
||||
devices = [device_util.canonicalize(d) for d in devices]
|
||||
|
||||
@def_function.function
|
||||
def reduce_fn(v):
|
||||
self.assertEqual(v.values[0].device, "")
|
||||
self.assertEqual(v.values[1].device, "")
|
||||
# We only use NCCL for batch reduce with two or more values, so we use two
|
||||
# values here.
|
||||
reduced = collective.batch_reduce(
|
||||
reduce_util.ReduceOp.SUM, [(v, v), (v, v)], experimental_hints=hints)
|
||||
self.assertEqual(reduced[0].values[0].device, devices[0])
|
||||
self.assertEqual(reduced[0].values[1].device, devices[1])
|
||||
self.assertEqual(reduced[1].values[0].device, devices[0])
|
||||
self.assertEqual(reduced[1].values[1].device, devices[1])
|
||||
# Returning Mirrored only evaluates the primary value, which causes
|
||||
# hanging,
|
||||
return [reduced[0].values, reduced[1].values]
|
||||
|
||||
v = _make_per_replica([1.0, 2.0], devices)
|
||||
reduced = reduce_fn(v)
|
||||
self.assertAllEqual(self.evaluate(reduced), [[3.0, 3.0], [3.0, 3.0]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set default inter op thread pool size to one to ensure we don't exhaust the
|
||||
|
@ -305,19 +305,6 @@ class CollectiveKeys(object):
|
||||
self._group_key_table[key_id] = new_key
|
||||
return self._group_key_table[key_id]
|
||||
|
||||
def get_group_key_of_tensors(self, tensors):
|
||||
"""Returns a group key for set of tensors.
|
||||
|
||||
Args:
|
||||
tensors: list of `Tensor`s in a collective group. Each tensor must be on a
|
||||
different device.
|
||||
|
||||
Returns:
|
||||
int key uniquely identifying the set of devices of these tensors.
|
||||
"""
|
||||
devices = [t.device for t in tensors]
|
||||
return self.get_group_key(devices)
|
||||
|
||||
def get_op_instance_key(self):
|
||||
"""Returns a new instance key for use in defining a collective op."""
|
||||
v = self._get_thread_local_object().op_instance_key
|
||||
@ -332,7 +319,8 @@ class CollectiveKeys(object):
|
||||
|
||||
|
||||
def build_collective_reduce(input_tensors,
|
||||
num_workers,
|
||||
devices,
|
||||
group_size,
|
||||
collective_keys,
|
||||
reduction_op='Add',
|
||||
unary_op='Id',
|
||||
@ -347,9 +335,10 @@ def build_collective_reduce(input_tensors,
|
||||
Args:
|
||||
input_tensors: tensors within a single worker graph that are to be reduced
|
||||
together; must be one per device.
|
||||
num_workers: total number of workers with identical independent graphs that
|
||||
will be doing this same reduction. The reduction will actually include
|
||||
the corresponding tensors at all these workers.
|
||||
devices: a list of device strings to run the collective on.
|
||||
group_size: total number of devices globally that will be doing this same
|
||||
reduction. The reduction will actually include the corresponding tensors
|
||||
at all these workers.
|
||||
collective_keys: a CollectiveKeys object.
|
||||
reduction_op: string naming the reduction op.
|
||||
unary_op: string naming the unary final op.
|
||||
@ -370,11 +359,14 @@ def build_collective_reduce(input_tensors,
|
||||
not all(e.is_async() for e in executors)):
|
||||
raise ValueError(
|
||||
'collectives requires async executors for each device in eager mode')
|
||||
if len(input_tensors) != len(devices):
|
||||
raise ValueError('collective requires one input tensor for each device, '
|
||||
'len(input_tensors) = %d, len(devices) = %d' %
|
||||
(len(input_tensors), len(devices)))
|
||||
|
||||
group_size = len(input_tensors) * num_workers
|
||||
if group_size < 2:
|
||||
return input_tensors
|
||||
group_key = collective_keys.get_group_key_of_tensors(input_tensors)
|
||||
group_key = collective_keys.get_group_key(devices)
|
||||
instance_key = collective_keys.get_op_instance_key()
|
||||
subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec
|
||||
|
||||
@ -385,9 +377,9 @@ def build_collective_reduce(input_tensors,
|
||||
else:
|
||||
executor_scope = ops.NullContextmanager()
|
||||
with executor_scope, \
|
||||
ops.device(input_tensor.device), \
|
||||
ops.device(devices[idx]), \
|
||||
ops.control_dependencies(
|
||||
_control_input(input_tensors, control_inputs, idx)):
|
||||
_control_input(devices, control_inputs, idx)):
|
||||
out_tensor = collective_ops.all_reduce(input_tensor, group_size,
|
||||
group_key, instance_key,
|
||||
reduction_op, unary_op,
|
||||
@ -397,7 +389,8 @@ def build_collective_reduce(input_tensors,
|
||||
|
||||
|
||||
def build_collective_gather(input_tensors,
|
||||
num_workers,
|
||||
devices,
|
||||
group_size,
|
||||
collective_keys,
|
||||
communication_hint='AUTO',
|
||||
control_inputs=None):
|
||||
@ -408,9 +401,10 @@ def build_collective_gather(input_tensors,
|
||||
Args:
|
||||
input_tensors: tensors within a single worker graph that are to be gathered
|
||||
together; must be one per device.
|
||||
num_workers: total number of workers with identical independent graphs that
|
||||
will be doing this same reduction. The reduction will actually include
|
||||
the corresponding tensors at all these workers.
|
||||
devices: a list of device strings to run the collective on.
|
||||
group_size: total number of devices globally that will be doing this same
|
||||
gathering. The gathering will actually include the corresponding tensors
|
||||
at all these workers.
|
||||
collective_keys: a CollectiveKeys object.
|
||||
communication_hint: string providing hint to runtime for choosing collective
|
||||
implementation.
|
||||
@ -423,18 +417,21 @@ def build_collective_gather(input_tensors,
|
||||
assert not context.executing_eagerly(), (
|
||||
'build_collective_gather can only be called in graph mode or inside '
|
||||
'tf.function')
|
||||
if len(input_tensors) != len(devices):
|
||||
raise ValueError(
|
||||
'collective requires one input tensor for each device, %d != %d' %
|
||||
(len(input_tensors), len(devices)))
|
||||
|
||||
group_size = len(input_tensors) * num_workers
|
||||
if group_size < 2:
|
||||
return input_tensors
|
||||
group_key = collective_keys.get_group_key_of_tensors(input_tensors)
|
||||
group_key = collective_keys.get_group_key(devices)
|
||||
instance_key = collective_keys.get_op_instance_key()
|
||||
|
||||
out_tensors = []
|
||||
for idx, input_tensor in enumerate(input_tensors):
|
||||
with ops.device(input_tensor.device):
|
||||
with ops.device(devices[idx]):
|
||||
with ops.control_dependencies(
|
||||
_control_input(input_tensors, control_inputs, idx)):
|
||||
_control_input(devices, control_inputs, idx)):
|
||||
out_tensor = collective_ops.all_gather(input_tensor, group_size,
|
||||
group_key, instance_key,
|
||||
communication_hint)
|
||||
@ -443,7 +440,8 @@ def build_collective_gather(input_tensors,
|
||||
|
||||
|
||||
def build_collective_gather_indexed_slices(input_slices_list,
|
||||
num_workers,
|
||||
devices,
|
||||
group_size,
|
||||
collective_keys,
|
||||
communication_hint='AUTO',
|
||||
control_inputs=None):
|
||||
@ -454,9 +452,10 @@ def build_collective_gather_indexed_slices(input_slices_list,
|
||||
Args:
|
||||
input_slices_list: a list of IndexedSlices within a single worker graph that
|
||||
are to be gathered together; must be one per device.
|
||||
num_workers: total number of workers with identical independent graphs that
|
||||
will be doing this same reduction. The reduction will actually include
|
||||
the corresponding tensors at all these workers.
|
||||
devices: a list of device strings to run the collective on.
|
||||
group_size: total number of devices globally that will be doing this same
|
||||
gathering. The gathering will actually include the corresponding tensors
|
||||
at all these workers.
|
||||
collective_keys: a CollectiveKeys object.
|
||||
communication_hint: string providing hint to runtime for choosing collective
|
||||
implementation.
|
||||
@ -474,12 +473,15 @@ def build_collective_gather_indexed_slices(input_slices_list,
|
||||
assert not context.executing_eagerly(), (
|
||||
'build_collective_gather_indexed_slices can only be called in graph mode'
|
||||
' or inside tf.function')
|
||||
if len(input_slices_list) != len(devices):
|
||||
raise ValueError(
|
||||
'collective requires one input IndexedSlice for each device, %d != %d' %
|
||||
(len(input_slices_list), len(devices)))
|
||||
|
||||
group_size = len(input_slices_list) * num_workers
|
||||
if group_size < 2:
|
||||
return input_slices_list
|
||||
|
||||
group_key = collective_keys.get_group_key_of_tensors(input_slices_list)
|
||||
group_key = collective_keys.get_group_key(devices)
|
||||
gather_length_key = collective_keys.get_op_instance_key()
|
||||
gather_indices_key = collective_keys.get_op_instance_key()
|
||||
gather_values_key = collective_keys.get_op_instance_key()
|
||||
@ -495,7 +497,7 @@ def build_collective_gather_indexed_slices(input_slices_list,
|
||||
out_slices_list = []
|
||||
for idx, input_slices in enumerate(input_slices_list):
|
||||
# pylint: disable = cell-var-from-loop
|
||||
with ops.device(input_slices.device):
|
||||
with ops.device(devices[idx]):
|
||||
|
||||
def all_gather():
|
||||
"""Use all_gather to aggregate `IndexedSlices`."""
|
||||
@ -967,14 +969,13 @@ def pack_by_size(per_replica_list, bytes_per_pack):
|
||||
return packs
|
||||
|
||||
|
||||
def _control_input(inputs, control_inputs, idx):
|
||||
def _control_input(devices, control_inputs, idx):
|
||||
"""Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies.
|
||||
|
||||
This is a helper function for building collective ops. The function checks
|
||||
that the devices of control_inputs and inputs match.
|
||||
This is a helper function for building collective ops.
|
||||
|
||||
Args:
|
||||
inputs: a list of `Tensor`s
|
||||
devices: a list of device strings the collective run on.
|
||||
control_inputs: a list or None.
|
||||
idx: the index into `inputs` and `control_inputs`.
|
||||
|
||||
@ -984,12 +985,8 @@ def _control_input(inputs, control_inputs, idx):
|
||||
"""
|
||||
if control_inputs is None:
|
||||
return []
|
||||
if len(control_inputs) != len(inputs):
|
||||
if len(control_inputs) != len(devices):
|
||||
raise ValueError(
|
||||
'control_inputs must match the length of the inputs, %s != %s' %
|
||||
(len(control_inputs), len(inputs)))
|
||||
if control_inputs[idx].device != inputs[idx].device:
|
||||
raise ValueError(
|
||||
'control_inputs must match the device of the inputs, %s != %s' %
|
||||
(control_inputs[idx].device, inputs[idx].device))
|
||||
'control_inputs must match the length of the devices, %s != %s' %
|
||||
(len(control_inputs), len(devices)))
|
||||
return [control_inputs[idx]]
|
||||
|
@ -579,7 +579,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
if not destinations:
|
||||
# TODO(josh11b): Use current logical device instead of 0 here.
|
||||
destinations = self._devices
|
||||
return self._get_cross_device_ops().broadcast(tensor, destinations)
|
||||
return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)
|
||||
|
||||
def _call_for_each_replica(self, fn, args, kwargs):
|
||||
return mirrored_run.call_for_each_replica(
|
||||
@ -608,7 +608,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
updated_config.isolate_session_state = True
|
||||
return updated_config
|
||||
|
||||
def _get_cross_device_ops(self):
|
||||
def _get_cross_device_ops(self, value):
|
||||
del value # Unused.
|
||||
return self._cross_device_ops or self._inferred_cross_device_ops
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
|
||||
@ -623,7 +624,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
# be 0.
|
||||
return cross_device_ops_lib.reduce_non_distributed_value(
|
||||
reduce_op, value, destinations, self._num_replicas_in_sync)
|
||||
return self._get_cross_device_ops().reduce(
|
||||
return self._get_cross_device_ops(value).reduce(
|
||||
reduce_op,
|
||||
value,
|
||||
destinations=destinations,
|
||||
@ -631,9 +632,15 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
|
||||
def _batch_reduce_to(self, reduce_op, value_destination_pairs,
|
||||
experimental_hints):
|
||||
return self._get_cross_device_ops().batch_reduce(reduce_op,
|
||||
value_destination_pairs,
|
||||
experimental_hints)
|
||||
cross_device_ops = None
|
||||
for value, _ in value_destination_pairs:
|
||||
if cross_device_ops is None:
|
||||
cross_device_ops = self._get_cross_device_ops(value)
|
||||
elif cross_device_ops is not self._get_cross_device_ops(value):
|
||||
raise ValueError("inputs to batch_reduce_to must be either all on the "
|
||||
"the host or all on the compute devices")
|
||||
return cross_device_ops.batch_reduce(reduce_op, value_destination_pairs,
|
||||
experimental_hints)
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
# TODO(josh11b): In eager mode, use one thread per device.
|
||||
|
@ -65,14 +65,14 @@ class StrategyReduceTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
|
||||
mode=['eager']))
|
||||
class DistributedCollectiveAllReduceStrategyTest(
|
||||
strategy_test_lib.DistributionTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
|
||||
mode=['eager']))
|
||||
def testDatasetFromFunction(self, strategy):
|
||||
def dataset_fn(input_context):
|
||||
global_batch_size = 10
|
||||
@ -95,6 +95,45 @@ class DistributedCollectiveAllReduceStrategyTest(
|
||||
sum_value.numpy(),
|
||||
expected_sum_on_workers[multi_worker_test_base.get_task_index()])
|
||||
|
||||
def testReduceHostTensor(self, strategy):
|
||||
reduced = strategy.reduce(
|
||||
reduce_util.ReduceOp.SUM, array_ops.identity(1.), axis=None)
|
||||
self.assertEqual(reduced.numpy(), 2.)
|
||||
|
||||
def testReduceToHostTensor(self, strategy):
|
||||
value = array_ops.identity(1.)
|
||||
reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
|
||||
value)
|
||||
self.assertEqual(reduced.numpy(), 2.)
|
||||
|
||||
def testBatchReduceToHostTensor(self, strategy):
|
||||
value = array_ops.identity(1.)
|
||||
reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
|
||||
[(value, value),
|
||||
(value, value)])
|
||||
self.assertAllEqual(reduced, [2., 2.])
|
||||
|
||||
def testReduceDeviceTensors(self, strategy):
|
||||
value = strategy.run(lambda: array_ops.identity(1.))
|
||||
reduced = strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)
|
||||
self.assertEqual(reduced.numpy(), 2.)
|
||||
|
||||
def testReduceToDeviceTensors(self, strategy):
|
||||
value = strategy.run(lambda: array_ops.identity(1.))
|
||||
reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
|
||||
value)
|
||||
self.assertEqual(reduced.numpy(), 2.)
|
||||
|
||||
def testBatchReduceToDeviceTensors(self, strategy):
|
||||
value = strategy.run(lambda: array_ops.identity(1.))
|
||||
reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
|
||||
[(value, value),
|
||||
(value, value)])
|
||||
self.assertAllEqual(reduced, [2., 2.])
|
||||
|
||||
# TODO(crccw): add a test that mixes device and host tensors after multi
|
||||
# worker strategy combinations can run on a fixed number of GPUs.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
combinations.main()
|
||||
|
Loading…
Reference in New Issue
Block a user