Explicitly take the set of devices in CollectiveAllReduce

We used to infer the devices from the inputs, but sometimes the inputs don't
have device placement; for example, when tensors are passed into or returned
from a tf.function, their device placement may be lost.

Instead of inferring them from the inputs, we should be explicit about the
collective devices.

PiperOrigin-RevId: 316743112
Change-Id: I2f6995f2f4cc86864723e203deb7562363cdbc38
Ran Chen 2020-06-16 13:11:14 -07:00 committed by TensorFlower Gardener
parent 89df3ddcd5
commit 1a342fb760
8 changed files with 252 additions and 119 deletions
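
For context, here is a minimal sketch of the API change this commit makes, based only on the diff below; the device strings and worker counts are illustrative. CollectiveAllReduce now receives its devices and global group size explicitly instead of num_workers/num_gpus_per_worker:

from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib

# Before this commit (removed): CollectiveAllReduce(num_workers=..., num_gpus_per_worker=..., ...)
# After this commit: the collective devices are explicit, so they no longer
# have to be inferred from input tensors that may lack device placement.
collective = cross_device_ops_lib.CollectiveAllReduce(
    devices=["/job:worker/task:0/device:GPU:0",   # local devices on this worker
             "/job:worker/task:0/device:GPU:1"],
    group_size=4,                                 # 2 workers x 2 GPUs globally
    communication=cross_device_ops_lib.CollectiveCommunication.AUTO)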

View File

@@ -1029,14 +1029,21 @@ cuda_py_test(
":collective_util",
":combinations",
":cross_device_ops",
":mirrored_strategy",
":cross_device_utils",
":device_util",
":multi_worker_test_base",
":multi_worker_util",
":reduce_util",
":strategy_combinations",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:collective_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:kernels",
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:parameterized",

View File

@@ -175,7 +175,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
self._communication = communication
self._initialize_strategy(self._cluster_resolver)
self._cfer_fn_cache = weakref.WeakKeyDictionary()
assert isinstance(self._get_cross_device_ops(),
assert isinstance(self._cross_device_ops,
cross_device_ops_lib.CollectiveAllReduce)
def _initialize_strategy(self, cluster_resolver):
@@ -217,12 +217,18 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
self._collective_keys = cross_device_utils.CollectiveKeys()
# TODO(yuefengz): remove num_gpus_per_worker from CollectiveAllReduce.
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
num_workers=self._num_workers,
num_gpus_per_worker=num_gpus,
devices=local_devices,
group_size=len(local_devices),
collective_keys=self._collective_keys,
communication=self._communication)
# CrossDeviceOps for per host tensors.
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=[self._worker_device],
group_size=self._num_workers,
collective_keys=self._collective_keys,
communication=cross_device_ops_lib.CollectiveCommunication.RING,
)
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
local_devices)
@@ -324,10 +330,17 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
self._collective_keys = cross_device_utils.CollectiveKeys()
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
num_workers=self._num_workers,
num_gpus_per_worker=num_gpus,
devices=local_devices,
group_size=len(local_devices) * self._num_workers,
collective_keys=self._collective_keys,
communication=self._communication)
# CrossDeviceOps for per host tensors.
self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=[self._worker_device],
group_size=self._num_workers,
collective_keys=self._collective_keys,
communication=cross_device_ops_lib.CollectiveCommunication.RING,
)
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
local_devices)
host_device = device_util.get_host_for_device(self._worker_device)
@@ -474,7 +487,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
num_accelerators={"GPU": self._num_gpus_per_worker},
rpc_layer=self._rpc_layer)
self._initialize_multi_worker(cluster_resolver)
assert isinstance(self._get_cross_device_ops(),
assert isinstance(self._cross_device_ops,
cross_device_ops_lib.CollectiveAllReduce)
if session_config:
@@ -518,6 +531,22 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
return updated_config
def _get_cross_device_ops(self, value):
# CollectiveAllReduce works on a predefined set of devices. In most cases
# they should be the compute devices, but certain use cases may reduce host
# tensors as well (e.g. early stopping). We infer the cross_device_ops to
# use based on the number of devices, since inputs don't always have device
# annotations. The compute devices one is preferred since we can potentially
# leverage NCCL.
if isinstance(value, values.DistributedValues):
num_devices = len(value._values) # pylint: disable=protected-access
else:
num_devices = 1
if num_devices == len(self.worker_devices):
return self._cross_device_ops
else:
return self._host_cross_device_ops
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
if (isinstance(value, values.Mirrored) and
reduce_op == reduce_util.ReduceOp.MEAN):
@@ -538,7 +567,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
# be 0.
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, len(self.worker_devices))
return self._get_cross_device_ops().reduce(
return self._get_cross_device_ops(value).reduce(
reduce_op,
value,
destinations=destinations,
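
The _get_cross_device_ops(value) method added above routes each value to one of the two CollectiveAllReduce instances by counting its components, since function inputs may carry no device annotations. Below is a standalone sketch of that routing; the argument names are hypothetical stand-ins for self.worker_devices, self._cross_device_ops and self._host_cross_device_ops:

def pick_cross_device_ops(value, worker_devices, compute_ops, host_ops):
  # A distributed value with one component per compute device can use the
  # compute-device collective (which may leverage NCCL); anything else, such
  # as a single per-host tensor, falls back to the per-host RING collective.
  num_components = len(value.values) if hasattr(value, "values") else 1
  if num_components == len(worker_devices):
    return compute_ops
  return host_ops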

View File

@@ -116,7 +116,8 @@ class CollectiveAllReduceStrategyTestBase(
variable_instance_key_start=10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = (collective_keys)
strategy.extended._cross_device_ops._collective_keys = collective_keys
strategy.extended._host_cross_device_ops._collective_keys = collective_keys
return strategy, target, session_config

View File

@@ -165,7 +165,8 @@ def get_devices_from(destinations):
def _devices_match(left, right):
return set(get_devices_from(left)) == set(get_devices_from(right))
return left is right or set(get_devices_from(left)) == set(
get_devices_from(right))
def _all_devices_match(value_destination_pairs):
@@ -936,20 +937,24 @@ class CollectiveAllReduce(CrossDeviceOps):
"""
def __init__(self,
num_workers=1,
num_gpus_per_worker=0,
devices,
group_size,
collective_keys=None,
communication=CollectiveCommunication.AUTO):
"""Initializes the object.
Args:
num_workers: number of workers in the between-graph replicated training.
num_gpus_per_worker: number of GPUs per worker.
devices: a list of device strings to run collectives on.
group_size: the global group size. For between-graph replicated training
it's the total number of devices across all workers.
collective_keys: an optional CollectiveKey object.
communication: indicates which collective communication to use.
"""
self._num_workers = num_workers
self._num_gpus_per_worker = num_gpus_per_worker
if group_size % len(devices) > 0:
raise ValueError("group_size must be divisible by the number of devices.")
self._devices = tuple(device_util.canonicalize(d) for d in devices)
self._group_size = group_size
self._collective_keys = (collective_keys or
cross_device_utils.CollectiveKeys())
self._communication = communication
@@ -963,15 +968,15 @@ class CollectiveAllReduce(CrossDeviceOps):
# async executor operations are still executed sequentially. In graph or
# function building, the executors are not used.
self._executors = []
for _ in range(self._num_gpus_per_worker or 1):
# If num_gpus_per_worker is zero, we assume there's only one device (CPU).
for _ in range(len(devices)):
self._executors.append(executor.new_executor(enable_async=True))
super(CollectiveAllReduce, self).__init__()
@property
def _num_between_graph_workers(self):
return self._num_workers
# Currently we only support equal number of devices on each worker.
return self._group_size / len(self._devices)
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
experimental_hints):
@@ -979,8 +984,7 @@ class CollectiveAllReduce(CrossDeviceOps):
experimental_hints)[0]
devices = get_devices_from(destinations)
if (isinstance(all_reduced, value_lib.Mirrored) and
(all_reduced._devices == devices)): # pylint: disable=protected-access
if _devices_match(per_replica_value, destinations):
return all_reduced
# Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
@@ -1069,14 +1073,16 @@ class CollectiveAllReduce(CrossDeviceOps):
if batch_size > 1:
logging.info(
"Collective batch_all_reduce: %d all-reduces, num_workers = %d, "
"communication_hint = %s, num_packs = %d", batch_size,
self._num_workers, communication, len(packs))
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
"group_size = %d, communication_hint = %s, num_packs = %d",
batch_size, len(self._devices), self._group_size, communication,
len(packs))
else:
logging.log_first_n(
logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
"num_workers = %d, communication_hint = %s, num_packs = %d" %
(batch_size, self._num_workers, communication, len(packs)), 10)
"num_devices = %d, group_size = %d, communication_hint = %s, "
"num_packs = %d" % (batch_size, len(
self._devices), self._group_size, communication, len(packs)), 10)
reduced_values = []
for pack in packs:
@@ -1094,21 +1100,25 @@ class CollectiveAllReduce(CrossDeviceOps):
control_inputs = None
reduced_values.append(
cross_device_utils.build_collective_reduce(
per_replica.values, self._num_workers,
self._collective_keys, "Add", "Id", communication,
control_inputs, executors=self._executors))
per_replica.values,
self._devices,
self._group_size,
self._collective_keys,
"Add",
"Id",
communication,
control_inputs,
executors=self._executors))
mirrored = []
# Reverse the order of reduced value to recover the order in the input.
for value in reversed(reduced_values):
if reduce_op == reduce_util.ReduceOp.MEAN:
# Assume each worker has the same number of replicas.
num_replicas = len(value) * self._num_workers
for i, v in enumerate(value):
with ops.device(v.device):
value[i] = v / num_replicas
mirrored.append(distribute_utils.regroup(value,
wrap_class=value_lib.Mirrored))
value[i] = v / self._group_size
mirrored.append(
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
return mirrored
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
@@ -1116,8 +1126,8 @@ class CollectiveAllReduce(CrossDeviceOps):
logging.log_first_n(
logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
"%d all-reduces, num_workers = %d" %
(len(per_replica_values), self._num_workers), 10)
"%d all-reduces, group_size = %d" %
(len(per_replica_values), self._group_size), 10)
# Pass self._communication to the runtime as a communication hint.
communication_hint = self._communication.value
@@ -1133,25 +1143,24 @@ class CollectiveAllReduce(CrossDeviceOps):
for per_replica in per_replica_values:
gathered_values.append(
cross_device_utils.build_collective_gather_indexed_slices(
per_replica.values, self._num_workers, self._collective_keys,
communication_hint))
per_replica.values, self._devices, self._group_size,
self._collective_keys, communication_hint))
mirrored = []
for value in gathered_values:
if reduce_op == reduce_util.ReduceOp.MEAN:
# Assume each worker has the same number of replicas.
num_replicas = len(value) * self._num_workers
for i, v in enumerate(value):
with ops.device(v.device):
value[i].values = value[i].values / num_replicas
mirrored.append(distribute_utils.regroup(value,
wrap_class=value_lib.Mirrored))
value[i].values = value[i].values / self._group_size
mirrored.append(
distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
return mirrored
def __deepcopy__(self, memo):
# distribute_coordinator deep-copies the strategy object, so
# CollectiveAllReduce needs to support deep copy as well.
return CollectiveAllReduce(self._num_workers, self._num_gpus_per_worker,
return CollectiveAllReduce(self._devices, self._group_size,
self._collective_keys, self._communication)
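
As a quick worked check of the new constructor contract: group_size is the global number of devices, it must be divisible by the number of local devices, and the quotient is the number of between-graph workers. A small sketch in plain Python, mirroring the divisibility check and _num_between_graph_workers above:

def num_between_graph_workers(group_size, num_local_devices):
  if group_size % num_local_devices:
    raise ValueError("group_size must be divisible by the number of devices.")
  return group_size // num_local_devices

assert num_between_graph_workers(group_size=8, num_local_devices=2) == 4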

View File

@@ -26,6 +26,7 @@ import time
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import cluster_resolver
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
@@ -34,10 +35,12 @@ from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import kernels
@@ -125,7 +128,10 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
self.evaluate(ops.convert_to_tensor(left)),
self.evaluate(ops.convert_to_tensor(right)))
def _assert_mirrored_equal(self, left_list, right_list, sess,
def _assert_mirrored_equal(self,
left_list,
right_list,
sess=None,
run_options=None):
if not isinstance(left_list, list):
left_list, right_list = [left_list], [right_list]
@@ -142,17 +148,14 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
left, right = [left], [right]
for left_value, right_value in zip(left, right):
self.assertEqual(left_value.device, right_value.device)
self.assertEqual(
device_util.resolve(left_value.device),
device_util.resolve(right_value.device))
# Densify IndexedSlices.
left = [ops.convert_to_tensor(v) for v in left]
right = [ops.convert_to_tensor(v) for v in right]
if context.executing_eagerly():
# Optional args in session run are not supported when eager execution
# is enabled.
assert run_options is None
left, right = sess.run((left, right))
else:
if not context.executing_eagerly():
left, right = sess.run((left, right), options=run_options)
for left_value, right_value in zip(left, right):
self.assertAllEqual(left_value, right_value)
@@ -525,8 +528,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
return strategy, devices, ""
else:
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
1,
num_gpus,
devices=devices,
group_size=len(devices),
collective_keys=collective_keys,
communication=communication)
return collective_all_reduce_ops, devices, ""
@@ -545,26 +548,28 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
]
if use_strategy_object:
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
communication=communication)
strategy.configure(
cluster_spec=self._cluster_spec,
resolver = cluster_resolver.SimpleClusterResolver(
cluster_spec=multi_worker_util.normalize_cluster_spec(
self._cluster_spec),
task_type=task_type,
task_id=task_id)
task_id=task_id,
num_accelerators={"GPU": num_gpus})
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
cluster_resolver=resolver, communication=communication)
strategy.extended._collective_keys = collective_keys
strategy.extended._cross_device_ops._collective_keys = collective_keys
return (strategy, devices,
"grpc://" + self._cluster_spec[task_type][task_id])
else:
collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
NUM_WORKERS,
num_gpus,
devices=devices,
group_size=len(devices) * NUM_WORKERS,
collective_keys=collective_keys,
communication=communication)
return (collective_all_reduce_ops, devices,
"grpc://" + self._cluster_spec[task_type][task_id])
def _assert_mirrored_equal(self, left_list, right_list, sess):
def _assert_mirrored_equal(self, left_list, right_list, sess=None):
if context.executing_eagerly():
run_options = None
else:
@@ -895,6 +900,45 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertAllEqual(reduced[1].values, [4.0, 4.0])
t.join()
@combinations.generate(
combinations.combine(
required_gpus=2,
mode="eager",
communication=[
CollectiveCommunication.NCCL, CollectiveCommunication.RING
]))
def testInputsAreFunctionArgs(self, communication):
# Function inputs don't have device placement.
hints = collective_util.Hints(bytes_per_pack=1)
collective, devices, _ = self._get_test_objects(
None,
None,
num_gpus=2,
communication=communication,
use_strategy_object=False,
local_mode=True)
devices = [device_util.canonicalize(d) for d in devices]
@def_function.function
def reduce_fn(v):
self.assertEqual(v.values[0].device, "")
self.assertEqual(v.values[1].device, "")
# We only use NCCL for batch reduce with two or more values, so we use two
# values here.
reduced = collective.batch_reduce(
reduce_util.ReduceOp.SUM, [(v, v), (v, v)], experimental_hints=hints)
self.assertEqual(reduced[0].values[0].device, devices[0])
self.assertEqual(reduced[0].values[1].device, devices[1])
self.assertEqual(reduced[1].values[0].device, devices[0])
self.assertEqual(reduced[1].values[1].device, devices[1])
# Returning Mirrored only evaluates the primary value, which causes
# hanging,
return [reduced[0].values, reduced[1].values]
v = _make_per_replica([1.0, 2.0], devices)
reduced = reduce_fn(v)
self.assertAllEqual(self.evaluate(reduced), [[3.0, 3.0], [3.0, 3.0]])
if __name__ == "__main__":
# Set default inter op thread pool size to one to ensure we don't exhaust the

View File

@@ -305,19 +305,6 @@ class CollectiveKeys(object):
self._group_key_table[key_id] = new_key
return self._group_key_table[key_id]
def get_group_key_of_tensors(self, tensors):
"""Returns a group key for set of tensors.
Args:
tensors: list of `Tensor`s in a collective group. Each tensor must be on a
different device.
Returns:
int key uniquely identifying the set of devices of these tensors.
"""
devices = [t.device for t in tensors]
return self.get_group_key(devices)
def get_op_instance_key(self):
"""Returns a new instance key for use in defining a collective op."""
v = self._get_thread_local_object().op_instance_key
@@ -332,7 +319,8 @@
def build_collective_reduce(input_tensors,
num_workers,
devices,
group_size,
collective_keys,
reduction_op='Add',
unary_op='Id',
@@ -347,9 +335,10 @@ def build_collective_reduce(input_tensors,
Args:
input_tensors: tensors within a single worker graph that are to be reduced
together; must be one per device.
num_workers: total number of workers with identical independent graphs that
will be doing this same reduction. The reduction will actually include
the corresponding tensors at all these workers.
devices: a list of device strings to run the collective on.
group_size: total number of devices globally that will be doing this same
reduction. The reduction will actually include the corresponding tensors
at all these workers.
collective_keys: a CollectiveKeys object.
reduction_op: string naming the reduction op.
unary_op: string naming the unary final op.
@@ -370,11 +359,14 @@ def build_collective_reduce(input_tensors,
not all(e.is_async() for e in executors)):
raise ValueError(
'collectives requires async executors for each device in eager mode')
if len(input_tensors) != len(devices):
raise ValueError('collective requires one input tensor for each device, '
'len(input_tensors) = %d, len(devices) = %d' %
(len(input_tensors), len(devices)))
group_size = len(input_tensors) * num_workers
if group_size < 2:
return input_tensors
group_key = collective_keys.get_group_key_of_tensors(input_tensors)
group_key = collective_keys.get_group_key(devices)
instance_key = collective_keys.get_op_instance_key()
subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec
@@ -385,9 +377,9 @@ def build_collective_reduce(input_tensors,
else:
executor_scope = ops.NullContextmanager()
with executor_scope, \
ops.device(input_tensor.device), \
ops.device(devices[idx]), \
ops.control_dependencies(
_control_input(input_tensors, control_inputs, idx)):
_control_input(devices, control_inputs, idx)):
out_tensor = collective_ops.all_reduce(input_tensor, group_size,
group_key, instance_key,
reduction_op, unary_op,
@@ -397,7 +389,8 @@ def build_collective_reduce(input_tensors,
def build_collective_gather(input_tensors,
num_workers,
devices,
group_size,
collective_keys,
communication_hint='AUTO',
control_inputs=None):
@@ -408,9 +401,10 @@ def build_collective_gather(input_tensors,
Args:
input_tensors: tensors within a single worker graph that are to be gathered
together; must be one per device.
num_workers: total number of workers with identical independent graphs that
will be doing this same reduction. The reduction will actually include
the corresponding tensors at all these workers.
devices: a list of device strings to run the collective on.
group_size: total number of devices globally that will be doing this same
gathering. The gathering will actually include the corresponding tensors
at all these workers.
collective_keys: a CollectiveKeys object.
communication_hint: string providing hint to runtime for choosing collective
implementation.
@@ -423,18 +417,21 @@ def build_collective_gather(input_tensors,
assert not context.executing_eagerly(), (
'build_collective_gather can only be called in graph mode or inside '
'tf.function')
if len(input_tensors) != len(devices):
raise ValueError(
'collective requires one input tensor for each device, %d != %d' %
(len(input_tensors), len(devices)))
group_size = len(input_tensors) * num_workers
if group_size < 2:
return input_tensors
group_key = collective_keys.get_group_key_of_tensors(input_tensors)
group_key = collective_keys.get_group_key(devices)
instance_key = collective_keys.get_op_instance_key()
out_tensors = []
for idx, input_tensor in enumerate(input_tensors):
with ops.device(input_tensor.device):
with ops.device(devices[idx]):
with ops.control_dependencies(
_control_input(input_tensors, control_inputs, idx)):
_control_input(devices, control_inputs, idx)):
out_tensor = collective_ops.all_gather(input_tensor, group_size,
group_key, instance_key,
communication_hint)
@@ -443,7 +440,8 @@ def build_collective_gather(input_tensors,
def build_collective_gather_indexed_slices(input_slices_list,
num_workers,
devices,
group_size,
collective_keys,
communication_hint='AUTO',
control_inputs=None):
@@ -454,9 +452,10 @@ def build_collective_gather_indexed_slices(input_slices_list,
Args:
input_slices_list: a list of IndexedSlices within a single worker graph that
are to be gathered together; must be one per device.
num_workers: total number of workers with identical independent graphs that
will be doing this same reduction. The reduction will actually include
the corresponding tensors at all these workers.
devices: a list of device strings to run the collective on.
group_size: total number of devices globally that will be doing this same
gathering. The gathering will actually include the corresponding tensors
at all these workers.
collective_keys: a CollectiveKeys object.
communication_hint: string providing hint to runtime for choosing collective
implementation.
@@ -474,12 +473,15 @@ def build_collective_gather_indexed_slices(input_slices_list,
assert not context.executing_eagerly(), (
'build_collective_gather_indexed_slices can only be called in graph mode'
' or inside tf.function')
if len(input_slices_list) != len(devices):
raise ValueError(
'collective requires one input IndexedSlice for each device, %d != %d' %
(len(input_slices_list), len(devices)))
group_size = len(input_slices_list) * num_workers
if group_size < 2:
return input_slices_list
group_key = collective_keys.get_group_key_of_tensors(input_slices_list)
group_key = collective_keys.get_group_key(devices)
gather_length_key = collective_keys.get_op_instance_key()
gather_indices_key = collective_keys.get_op_instance_key()
gather_values_key = collective_keys.get_op_instance_key()
@@ -495,7 +497,7 @@ def build_collective_gather_indexed_slices(input_slices_list,
out_slices_list = []
for idx, input_slices in enumerate(input_slices_list):
# pylint: disable = cell-var-from-loop
with ops.device(input_slices.device):
with ops.device(devices[idx]):
def all_gather():
"""Use all_gather to aggregate `IndexedSlices`."""
@@ -967,14 +969,13 @@ def pack_by_size(per_replica_list, bytes_per_pack):
return packs
def _control_input(inputs, control_inputs, idx):
def _control_input(devices, control_inputs, idx):
"""Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies.
This is a helper function for building collective ops. The function checks
that the devices of control_inputs and inputs match.
This is a helper function for building collective ops.
Args:
inputs: a list of `Tensor`s
devices: a list of device strings the collective run on.
control_inputs: a list or None.
idx: the index into `inputs` and `control_inputs`.
@@ -984,12 +985,8 @@ def _control_input(inputs, control_inputs, idx):
"""
if control_inputs is None:
return []
if len(control_inputs) != len(inputs):
if len(control_inputs) != len(devices):
raise ValueError(
'control_inputs must match the length of the inputs, %s != %s' %
(len(control_inputs), len(inputs)))
if control_inputs[idx].device != inputs[idx].device:
raise ValueError(
'control_inputs must match the device of the inputs, %s != %s' %
(control_inputs[idx].device, inputs[idx].device))
'control_inputs must match the length of the devices, %s != %s' %
(len(control_inputs), len(devices)))
return [control_inputs[idx]]
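
A hedged usage sketch of the updated build_collective_reduce signature, which now takes devices and group_size explicitly. This degenerate single-device call only exercises the new arguments; with group_size < 2 the helper returns the inputs unchanged, so it runs without a multi-device collective group (eager mode is assumed, hence the async executor):

from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.eager import executor
from tensorflow.python.framework import constant_op

collective_keys = cross_device_utils.CollectiveKeys()
async_executor = executor.new_executor(enable_async=True)
reduced = cross_device_utils.build_collective_reduce(
    input_tensors=[constant_op.constant(1.0)],
    devices=["/device:CPU:0"],  # the collective now runs on explicit devices
    group_size=1,               # < 2, so the inputs are returned as-is
    collective_keys=collective_keys,
    executors=[async_executor])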

View File

@@ -579,7 +579,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
if not destinations:
# TODO(josh11b): Use current logical device instead of 0 here.
destinations = self._devices
return self._get_cross_device_ops().broadcast(tensor, destinations)
return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)
def _call_for_each_replica(self, fn, args, kwargs):
return mirrored_run.call_for_each_replica(
@@ -608,7 +608,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
updated_config.isolate_session_state = True
return updated_config
def _get_cross_device_ops(self):
def _get_cross_device_ops(self, value):
del value # Unused.
return self._cross_device_ops or self._inferred_cross_device_ops
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
@@ -623,7 +624,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
# be 0.
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, self._num_replicas_in_sync)
return self._get_cross_device_ops().reduce(
return self._get_cross_device_ops(value).reduce(
reduce_op,
value,
destinations=destinations,
@@ -631,9 +632,15 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
def _batch_reduce_to(self, reduce_op, value_destination_pairs,
experimental_hints):
return self._get_cross_device_ops().batch_reduce(reduce_op,
value_destination_pairs,
experimental_hints)
cross_device_ops = None
for value, _ in value_destination_pairs:
if cross_device_ops is None:
cross_device_ops = self._get_cross_device_ops(value)
elif cross_device_ops is not self._get_cross_device_ops(value):
raise ValueError("inputs to batch_reduce_to must be either all on the "
"the host or all on the compute devices")
return cross_device_ops.batch_reduce(reduce_op, value_destination_pairs,
experimental_hints)
def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device.

View File

@@ -65,14 +65,14 @@ class StrategyReduceTest(test.TestCase, parameterized.TestCase):
self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)
@combinations.generate(
combinations.combine(
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
mode=['eager']))
class DistributedCollectiveAllReduceStrategyTest(
strategy_test_lib.DistributionTestBase,
parameterized.TestCase):
@combinations.generate(
combinations.combine(
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
mode=['eager']))
def testDatasetFromFunction(self, strategy):
def dataset_fn(input_context):
global_batch_size = 10
@@ -95,6 +95,45 @@ class DistributedCollectiveAllReduceStrategyTest(
sum_value.numpy(),
expected_sum_on_workers[multi_worker_test_base.get_task_index()])
def testReduceHostTensor(self, strategy):
reduced = strategy.reduce(
reduce_util.ReduceOp.SUM, array_ops.identity(1.), axis=None)
self.assertEqual(reduced.numpy(), 2.)
def testReduceToHostTensor(self, strategy):
value = array_ops.identity(1.)
reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
value)
self.assertEqual(reduced.numpy(), 2.)
def testBatchReduceToHostTensor(self, strategy):
value = array_ops.identity(1.)
reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
[(value, value),
(value, value)])
self.assertAllEqual(reduced, [2., 2.])
def testReduceDeviceTensors(self, strategy):
value = strategy.run(lambda: array_ops.identity(1.))
reduced = strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)
self.assertEqual(reduced.numpy(), 2.)
def testReduceToDeviceTensors(self, strategy):
value = strategy.run(lambda: array_ops.identity(1.))
reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
value)
self.assertEqual(reduced.numpy(), 2.)
def testBatchReduceToDeviceTensors(self, strategy):
value = strategy.run(lambda: array_ops.identity(1.))
reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
[(value, value),
(value, value)])
self.assertAllEqual(reduced, [2., 2.])
# TODO(crccw): add a test that mixes device and host tensors after multi
# worker strategy combinations can run on a fixed number of GPUs.
if __name__ == '__main__':
combinations.main()