Support reducing IndexedSlices in CollectiveAllReduce._all_reduce.
PiperOrigin-RevId: 358860726
Change-Id: I1f2f4333fe87a22e4902dbea8d787c9fdf50fb9b
This commit is contained in: parent 6b350fcde9, commit 85c597aaa2
tensorflow/python/distribute
@@ -70,8 +70,9 @@ def validate_destinations(destinations):
   """Validates the `destination` is one of expected types."""
   if not isinstance(
       destinations,
-      (value_lib.DistributedValues, ops.Tensor, ps_values.AggregatingVariable,
-       six.string_types, tpu_values.TPUMirroredVariable
+      (value_lib.DistributedValues, ops.Tensor, ops.IndexedSlices,
+       ps_values.AggregatingVariable, six.string_types,
+       tpu_values.TPUMirroredVariable
       )) and not resource_variable_ops.is_resource_variable(destinations):
     raise ValueError("destinations must be one of a `DistributedValues` object,"
                      " a tf.Variable object, or a device string.")
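Illustration of the widened check above: `tf.IndexedSlices` is now an accepted destination type alongside tensors, variables, and device strings. A standalone sketch, where `_looks_like_valid_destination` is a hypothetical, simplified stand-in for the full `validate_destinations` (the real check also accepts the `tf.distribute` value classes):

import tensorflow as tf

def _looks_like_valid_destination(destinations):
  # Simplified stand-in: the point of the change is that tf.IndexedSlices is
  # now an accepted destination type.
  return isinstance(destinations, (tf.Tensor, tf.IndexedSlices, str))

sparse = tf.IndexedSlices(values=tf.constant([[1.0]]),
                          indices=tf.constant([0]),
                          dense_shape=tf.constant([5, 1]))
assert _looks_like_valid_destination(sparse)           # newly accepted
assert _looks_like_valid_destination("/device:CPU:0")  # device strings still work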
@@ -532,25 +533,25 @@ class CrossDeviceOps(object):
   def _all_reduce(self, reduce_op, value, replica_id, options):
     """All-reduce the `value` across all replicas so that all get the result.

-    `value` can be a nested structure of tensors. The implementation should
-    generally batch the all-reduces when possible. `options` can be set to
-    hint the batching behavior.
+    `value` can be a nested structure of tensors or `IndexedSlices`. The
+    implementation should generally batch the all-reduces when possible.
+    `options` can be set to hint the batching behavior.

     This API must be called in a replica context.

     Args:
       reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
-        be combined. Allows using string representation of the enum such as
-        "SUM", "MEAN".
-      value: Value to be reduced. A tensor or a nested structure of tensors.
+        be combined.
+      value: Value to be reduced. A tensor or a nested structure of tensors or
+        `IndexedSlices`.
       replica_id: An integer indicating the id of the replica where this
         all_reduce is called under. This is the local replica id that ranges
         from 0 to len(local_devices) - 1.
       options: A `tf.distribute.experimental.CommunicationOptions`.

     Returns:
-      A tensor or a nested structure of tensors with the reduced values. The
-      structure is the same as `value`.
+      A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with
+      the reduced values. The structure is the same as `value`.
     """
     raise NotImplementedError("_all_reduce must be implemented in descendants.")

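The docstring above promises that the result has the same nested structure as `value`. A minimal sketch of the `tf.nest` round trip the implementation relies on (via `nest.flatten` and `nest.pack_sequence_as` in the hunk below); the per-leaf all-reduce is replaced by an identity stand-in here:

import tensorflow as tf

# A nested structure mixing a dense tensor and an IndexedSlices, like the
# `value` argument documented above.
value = (tf.constant(1.0),
         tf.IndexedSlices(values=tf.constant([[1.0]]),
                          indices=tf.constant([0]),
                          dense_shape=tf.constant([5, 1])))

# Flatten to leaves (IndexedSlices is treated as a single leaf by default),
# reduce each leaf, then pack the results back into the original structure.
flat = tf.nest.flatten(value)
reduced_flat = flat  # stand-in for the per-leaf all-reduce
result = tf.nest.pack_sequence_as(value, reduced_flat)
assert isinstance(result, tuple) and len(result) == 2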
@@ -1083,48 +1084,74 @@ class CollectiveAllReduce(CrossDeviceOps):
     # TODO(b/122840926): reuse this method in _batch_all_reduce.
     flat_values = nest.flatten(value)

-    if isinstance(flat_values[0], ops.IndexedSlices):
-      raise NotImplementedError("all_reduce doesn't support IndexedSlices.")
-
-    batch_size = len(flat_values)
-
     implementation = options.implementation.value
     # If NCCL launches can't be ordered (self._limited_nccl == True), we only
-    # use NCCL only when batch_size > 1, hoping that there's only one batched
-    # all-reduce, which is the gradients.
+    # use NCCL when batch_size > 1, hoping that there's only one batched
+    # all-reduce, which is the gradient aggregation in optimizer. For TF 2.x,
+    # NCCL launches are always ordered.
     if (self._limited_nccl and
         options.implementation == CommunicationImplementation.NCCL and
-        batch_size == 1):
+        len(flat_values) == 1):
       implementation = CommunicationImplementation.AUTO.value

-    # Reverse the lists so that there's a better chance that values follow
-    # the order in which they are calculated (e.g. when they're gradients), so
-    # as to overlap calculation with communication. However, this may not be
-    # optimal for cases like gradients of complicated non-sequential models.
-    #
-    # Note that we reverse the list before packing so that the first pack won't
-    # be too small, since it's more likely for first few packs to have long
-    # queuing time due to concurrent intense computation.
-    #
-    # TODO(b/147393503): explore solutions for optimal ordering.
-    flat_values.reverse()
-    packs = cross_device_utils.group_by_size(flat_values,
-                                             options.bytes_per_pack)
-
     launcher = self._launchers[replica_id]
-    if not context.executing_eagerly() and replica_id == 0:
-      logging.info(
-          "Collective all_reduce: %d all-reduces, num_devices = %d, "
-          "group_size = %d, implementation = %s, num_packs = %d", batch_size,
-          len(self._launchers), self._group_size, implementation, len(packs))
-    flat_results = launcher.batch_all_reduce(packs, implementation,
-                                             options.timeout_seconds)
+    dense_values, dense_indices, sparse_values, sparse_indices = (
+        cross_device_utils.split_by_sparsity(flat_values))
+    dense_results = []
+    sparse_results = []

-    if reduce_op == reduce_util.ReduceOp.MEAN:
-      for i, v in enumerate(flat_results):
-        flat_results[i] = v / self._group_size
-    flat_results.reverse()
+    if dense_values:
+      # Reverse the lists so that there's a better chance that values follow
+      # the order in which they are calculated (e.g. when they're gradients), so
+      # as to overlap calculation with communication. However, this may not be
+      # optimal for cases like gradients of complicated non-sequential models.
+      #
+      # Note that we reverse the list before packing so that the first pack
+      # won't be too small, since it's more likely for first few packs to have
+      # long queuing time due to concurrent intense computation.
+      #
+      # TODO(b/147393503): explore solutions for optimal ordering.
+      dense_values.reverse()
+      packs = cross_device_utils.group_by_size(dense_values,
+                                               options.bytes_per_pack)
+
+      if not context.executing_eagerly() and replica_id == 0:
+        logging.info(
+            "Collective all_reduce tensors: %d all_reduces, num_devices = %d, "
+            "group_size = %d, implementation = %s, num_packs = %d",
+            len(dense_values), len(self._launchers), self._group_size,
+            implementation, len(packs))
+
+      dense_results = launcher.batch_all_reduce(packs, implementation,
+                                                options.timeout_seconds)
+      if reduce_op == reduce_util.ReduceOp.MEAN:
+        for i, v in enumerate(dense_results):
+          with ops.device(self._devices[replica_id]):
+            dense_results[i] = v / self._group_size
+      dense_results.reverse()
+
+    if sparse_values:
+      if not context.executing_eagerly() and replica_id == 0:
+        logging.info(
+            "Collective all_reduce IndexedSlices: %d all_reduces, num_devices "
+            "= %d, group_size = %d, implementation = %s", len(sparse_values),
+            len(self._launchers), self._group_size, implementation)
+
+      for indexed_slice in sparse_values:
+        sparse_results.append(
+            launcher.all_reduce_indexed_slices(indexed_slice, implementation,
+                                               options.timeout_seconds))
+
+      if reduce_op == reduce_util.ReduceOp.MEAN:
+        for i, v in enumerate(sparse_results):
+          with ops.device(self._devices[replica_id]):
+            sparse_results[i] = ops.IndexedSlices(
+                values=sparse_results[i].values / self._group_size,
+                indices=sparse_results[i].indices,
+                dense_shape=sparse_results[i].dense_shape)

+    flat_results = cross_device_utils.stitch_values(
+        ((dense_results, dense_indices), (sparse_results, sparse_indices)))
     return nest.pack_sequence_as(value, flat_results)

   def reduce_implementation(self, reduce_op, per_replica_value, destinations,

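The new code path in the hunk above leans on two helpers: `split_by_sparsity`, which partitions the flattened leaves into dense tensors and `IndexedSlices` while remembering their original positions, and `stitch_values`, which puts the reduced results back in order. A minimal pure-Python sketch of that contract (an illustrative stand-in, not the `cross_device_utils` implementation; `is_sparse` is a hypothetical predicate):

def split_by_sparsity(values, is_sparse):
  """Partitions `values` into dense and sparse leaves, keeping their indices."""
  dense_values, dense_indices, sparse_values, sparse_indices = [], [], [], []
  for i, v in enumerate(values):
    if is_sparse(v):
      sparse_values.append(v)
      sparse_indices.append(i)
    else:
      dense_values.append(v)
      dense_indices.append(i)
  return dense_values, dense_indices, sparse_values, sparse_indices


def stitch_values(values_and_indices_list):
  """Re-interleaves per-kind results into their original flat positions."""
  result = [None] * sum(len(v) for v, _ in values_and_indices_list)
  for values, indices in values_and_indices_list:
    for v, i in zip(values, indices):
      result[i] = v
  return result


# Round trip: reduce each kind separately, then restore the original order.
flat = ["dense_a", "sparse_b", "dense_c"]
d, di, s, si = split_by_sparsity(flat, is_sparse=lambda x: x.startswith("sparse"))
assert stitch_values(((d, di), (s, si))) == flat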
@@ -162,10 +162,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
       gpu_per_process: number of GPUs (0 if no GPUs) used by each process.

     Returns:
-      A tuple of (collective, devices, group_size) where collective is an instance
+      A tuple of (collective, devices, pid) where collective is an instance
       of `CollectiveAllReduce`, devices are a list of local devices (str)
-      attached to the current process, and group_size is the group_size of
-      collective.
+      attached to the current process, and pid is the id of this process among
+      all participant processes.
     """

     cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
@@ -319,8 +319,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           ],
           reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
           prefer_unique_instance_key=[True, False]))
-  def testAllReduceDense(self, num_processes, required_gpus, implementation,
-                         reduce_op, prefer_unique_instance_key):
+  def testReduceDense(self, num_processes, required_gpus, implementation,
+                      reduce_op, prefer_unique_instance_key):
     if (required_gpus == 0 and
         implementation == CommunicationImplementation.NCCL):
       self.skipTest("Skip CPU + NCCL combination")
@@ -361,8 +361,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           # TODO(b/166682130): add MEAN reduce once the bug is fixed.
           reduce_op=ReduceOp.SUM,
           prefer_unique_instance_key=[True, False]))
-  def testAllReduceSparse(self, num_processes, required_gpus, implementation,
-                          reduce_op, prefer_unique_instance_key):
+  def testReduceSparse(self, num_processes, required_gpus, implementation,
+                       reduce_op, prefer_unique_instance_key):
     if (required_gpus == 0 and
         implementation == CommunicationImplementation.NCCL):
       self.skipTest("Skip CPU + NCCL combination")
@@ -410,7 +410,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):

   @combinations.generate(
       combinations.combine(prefer_unique_instance_key=[True, False]))
-  def testAllReduceSparseVariableLength(self, prefer_unique_instance_key):
+  def testReduceSparseVariableLength(self, prefer_unique_instance_key):
     # One device per process, 2 processes, 2 replicas in total.
     inputs = [
         IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
@@ -441,9 +441,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           ],
           reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
           prefer_unique_instance_key=[True, False]))
-  def testBatchAllReduceDense(self, num_processes, required_gpus,
-                              implementation, reduce_op,
-                              prefer_unique_instance_key):
+  def testBatchReduceDense(self, num_processes, required_gpus, implementation,
+                           reduce_op, prefer_unique_instance_key):
     if (required_gpus == 0 and
         implementation == CommunicationImplementation.NCCL):
       self.skipTest("Skip CPU + NCCL combination")
@@ -485,9 +484,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           # TODO(b/166682130): add MEAN reduce once the bug is fixed.
           reduce_op=ReduceOp.SUM,
           prefer_unique_instance_key=[True, False]))
-  def testBatchAllReduceSparse(self, num_processes, required_gpus,
-                               implementation, reduce_op,
-                               prefer_unique_instance_key):
+  def testBatchReduceSparse(self, num_processes, required_gpus, implementation,
+                            reduce_op, prefer_unique_instance_key):
     if (required_gpus == 0 and
         implementation == CommunicationImplementation.NCCL):
       self.skipTest("Skip CPU + NCCL combination")
@@ -571,8 +569,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           ],
           reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
       ))
-  def testCollectiveAllReduce(self, num_processes, required_gpus,
-                              implementation, reduce_op):
+  def testAllReduceDense(self, num_processes, required_gpus, implementation,
+                         reduce_op):
     if (required_gpus == 0 and
         implementation == CommunicationImplementation.NCCL):
       self.skipTest("Skip CPU + NCCL combination")
@@ -623,6 +621,126 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):

     get_global_mpr(num_processes).run(replica_fn)

+  @combinations.generate(
+      combinations.combine(
+          num_processes=[1, 2],
+          required_gpus=[0, 1, 2],
+          implementation=[
+              CommunicationImplementation.AUTO,
+              CommunicationImplementation.RING,
+              CommunicationImplementation.NCCL,
+          ],
+          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
+      ))
+  def testAllReduceSparse(self, num_processes, required_gpus, implementation,
+                          reduce_op):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
+    if (num_processes == 2 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
+                    "physical GPUs for every process.")
+
+    def replica_fn():
+      collective, devices, _ = self.make_collective(num_processes,
+                                                    required_gpus)
+      options = collective_util.Options(implementation=implementation)
+      group_size = num_processes * (required_gpus or 1)
+
+      @def_function.function
+      def collective_all_reduce():
+        results = []
+        for replica_id, device in enumerate(devices):
+          with ops.device(device):
+            value = IndexedSlices(
+                values=array_ops.identity([[1.]]),
+                indices=array_ops.identity([0]),
+                dense_shape=array_ops.identity([5, 1]))
+            results.append(
+                collective._all_reduce(reduce_op, value, replica_id, options))
+        return results
+
+      got = collective_all_reduce()
+      if reduce_op == ReduceOp.SUM:
+        expect = [IndexedSlices([[1. * group_size]], [0], [5, 1])
+                 ] * len(devices)
+      elif reduce_op == ReduceOp.MEAN:
+        expect = [IndexedSlices([[1.]], [0], [5, 1])] * len(devices)
+      self.assertAllClose(
+          nest.map_structure(ops.convert_to_tensor, got),
+          nest.map_structure(ops.convert_to_tensor, expect))
+
+      @def_function.function
+      def collective_batch_all_reduce():
+        results = []
+        for replica_id, device in enumerate(devices):
+          with ops.device(device):
+            value = (IndexedSlices(
+                array_ops.identity([[1.]]), array_ops.identity([0]),
+                array_ops.identity([5, 1])),
+                     IndexedSlices(
+                         array_ops.identity([[3.]]), array_ops.identity([2]),
+                         array_ops.identity([5, 1])))
+            results.append(
+                collective._all_reduce(reduce_op, value, replica_id, options))
+        return results
+
+      got = collective_batch_all_reduce()
+      if reduce_op == ReduceOp.SUM:
+        expect = [(IndexedSlices([[1. * group_size]], [0], [5, 1]),
+                   IndexedSlices([[3. * group_size]], [2], [5, 1]))
+                 ] * len(devices)
+      elif reduce_op == ReduceOp.MEAN:
+        expect = [(IndexedSlices([[1.]], [0], [5, 1]),
+                   IndexedSlices([[3.]], [2], [5, 1]))] * len(devices)
+      self.assertAllClose(
+          nest.map_structure(ops.convert_to_tensor, got),
+          nest.map_structure(ops.convert_to_tensor, expect))
+
+    get_global_mpr(num_processes).run(replica_fn)
+
+  @combinations.generate(
+      combinations.combine(
+          num_processes=2,
+          required_gpus=0,
+          implementation=CommunicationImplementation.AUTO,
+          reduce_op=ReduceOp.SUM))
+  def testAllReduceMixedDenseAndSparse(self, num_processes, required_gpus,
+                                       implementation, reduce_op):
+
+    def replica_fn():
+      collective, devices, _ = self.make_collective(num_processes,
+                                                    required_gpus)
+      options = collective_util.Options(implementation=implementation)
+      group_size = num_processes * (required_gpus or 1)
+
+      @def_function.function
+      def collective_batch_all_reduce():
+        results = []
+        for replica_id, device in enumerate(devices):
+          with ops.device(device):
+            value = (IndexedSlices(
+                array_ops.identity([[1.]]), array_ops.identity([0]),
+                array_ops.identity([5, 1])), array_ops.identity(1.0),
+                     IndexedSlices(
+                         array_ops.identity([[3.]]), array_ops.identity([2]),
+                         array_ops.identity([5, 1])), array_ops.identity(2.0))
+            results.append(
+                collective._all_reduce(reduce_op, value, replica_id, options))
+        return results
+
+      got = collective_batch_all_reduce()
+      expect = [
+          (IndexedSlices([[1. * group_size]], [0], [5, 1]), 1.0 * group_size,
+           IndexedSlices([[3. * group_size]], [2], [5, 1]), 2.0 * group_size)
+      ] * len(devices)
+      self.assertAllClose(
+          nest.map_structure(ops.convert_to_tensor, got),
+          nest.map_structure(ops.convert_to_tensor, expect))
+
+    get_global_mpr(num_processes).run(replica_fn)
+
   @combinations.generate(
       combinations.combine(
           num_processes=[1, 2],
@@ -603,8 +603,9 @@ def copy_tensor_or_indexed_slices_to_device(value, device):
 def is_indexed_slices(value):
   if isinstance(value, ops.IndexedSlices):
     return True
-  assert isinstance(value, value_lib.DistributedValues)
-  return all(isinstance(v, ops.IndexedSlices) for v in value.values)
+  if isinstance(value, value_lib.DistributedValues):
+    return all(isinstance(v, ops.IndexedSlices) for v in value.values)
+  return False


 def split_by_sparsity(values):
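The behavioral change above: inputs that are neither `IndexedSlices` nor `DistributedValues` (for example plain tensors) no longer trip an assertion; the predicate simply answers `False`. A tiny standalone sketch of the new logic, using a hypothetical `_is_indexed_slices` mirror rather than the internal helper:

import tensorflow as tf

def _is_indexed_slices(value):
  # Mirrors the new logic: True for IndexedSlices (the DistributedValues
  # branch is not modeled here), False for everything else instead of
  # asserting.
  if isinstance(value, tf.IndexedSlices):
    return True
  return False

assert _is_indexed_slices(
    tf.IndexedSlices(values=tf.constant([[1.0]]),
                     indices=tf.constant([0]),
                     dense_shape=tf.constant([5, 1])))
assert not _is_indexed_slices(tf.constant(1.0))  # previously an AssertionError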
@@ -2409,8 +2409,7 @@ class StrategyExtendedV2(object):

     Args:
       reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
-        be combined. Allows using string representation of the enum such as
-        "SUM", "MEAN".
+        be combined.
       value: Value to be reduced. A tensor or a nested structure of tensors.
       options: A `tf.distribute.experimental.CommunicationOptions`. Options to
         perform collective operations. This overrides the default options if the
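With the string shorthand dropped from this docstring (and from the test below, which now passes `reduce_util.ReduceOp.SUM` instead of `'SUM'`), callers are nudged toward the enum. A minimal sketch of the analogous public-API call, assuming a default `MirroredStrategy` (this uses `ReplicaContext.all_reduce`, not the private `_replica_ctx_all_reduce` touched by the diff):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn():
  value = tf.identity(1.0)
  ctx = tf.distribute.get_replica_context()
  # Pass the enum member rather than the string "SUM".
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)

results = strategy.run(replica_fn)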
@@ -116,13 +116,14 @@ class ReduceTest(test.TestCase, parameterized.TestCase):
         strategy=[
             strategy_combinations.multi_worker_mirrored_2x1_cpu,
             strategy_combinations.multi_worker_mirrored_2x1_gpu,
+            strategy_combinations.multi_worker_mirrored_2x2_gpu,
             strategy_combinations.tpu_strategy,
         ] + strategy_combinations.strategies_minus_tpu,
         tf_function=[combinations.tf_function, combinations.no_tf_function],
         mode=['eager']))
 class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):

-  def testBasic(self, strategy, tf_function):
+  def testDense(self, strategy, tf_function):
     if (isinstance(strategy, tpu_strategy.TPUStrategy) and
         tf_function is combinations.no_tf_function):
       self.skipTest('Skip TPUStrategy + eager combination.')
@@ -132,7 +133,8 @@ class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):

       def replica_fn():
         value = array_ops.identity(1.0)
-        reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
+        reduced = strategy.extended._replica_ctx_all_reduce(
+            reduce_util.ReduceOp.SUM, value)
         return reduced

       return strategy.experimental_local_results(strategy.run(replica_fn))
@@ -140,27 +142,75 @@ class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
     got = fn()[0]
     self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)

-  def testNestedInput(self, strategy, tf_function):
-    # TODO(b/122840926): enable this test once cl/122840926 is submitted.
-    self.skipTest('Enable after cl/353109164 is submitted.')
-
-    if (isinstance(strategy, tpu_strategy.TPUStrategy) and
-        tf_function is combinations.no_tf_function):
-      self.skipTest('Skip TPUStrategy + eager combination.')
+  def testSparse(self, strategy, tf_function):
+    if tf_function is combinations.no_tf_function:
+      self.skipTest('Skip IndexedSlices + eager combination.')

     @tf_function
     def fn():

       def replica_fn():
-        value = (array_ops.identity(1.0), array_ops.identity(1.0))
-        reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
+        value = ops.IndexedSlices(
+            values=array_ops.identity([[1.0]]),
+            indices=array_ops.identity([0]),
+            dense_shape=array_ops.identity([5, 1]))
+        reduced = strategy.extended._replica_ctx_all_reduce(
+            reduce_util.ReduceOp.SUM, value)
         return reduced

       return strategy.experimental_local_results(strategy.run(replica_fn))

     got = fn()[0]
-    self.assertEqual(got, (1.0 * strategy.num_replicas_in_sync,
-                           2.0 * strategy.num_replicas_in_sync))
+    expect = ops.IndexedSlices(
+        values=array_ops.identity([[1.0 * strategy.num_replicas_in_sync]]),
+        indices=array_ops.identity([0]),
+        dense_shape=array_ops.identity([5, 1]))
+    self.assertAllEqual(
+        ops.convert_to_tensor(got), ops.convert_to_tensor(expect))
+
+  def testNestedInput(self, strategy, tf_function):
+    # TODO(b/122840926): enable this test once cl/356664949 is submitted.
+    self.skipTest('Enable after cl/353109164 is submitted.')
+
+    if tf_function is combinations.no_tf_function:
+      self.skipTest('Skip IndexedSlices + eager combination.')
+
+    @tf_function
+    def fn():
+
+      def replica_fn():
+        value = (array_ops.identity(1.0),
+                 ops.IndexedSlices(
+                     values=array_ops.identity([[1.0]]),
+                     indices=array_ops.identity([0]),
+                     dense_shape=array_ops.identity([5, 1])),
+                 array_ops.identity(2.0),
+                 ops.IndexedSlices(
+                     values=array_ops.identity([[2.0]]),
+                     indices=array_ops.identity([1]),
+                     dense_shape=array_ops.identity([5, 1])))
+        reduced = strategy.extended._replica_ctx_all_reduce(
+            reduce_util.ReduceOp.SUM, value)
+        return reduced
+
+      return strategy.experimental_local_results(strategy.run(replica_fn))
+
+    got = fn()[0]
+    expect = (1.0 * strategy.num_replicas_in_sync,
+              ops.IndexedSlices(
+                  values=array_ops.identity(
+                      [[1.0 * strategy.num_replicas_in_sync]]),
+                  indices=array_ops.identity([0]),
+                  dense_shape=array_ops.identity([5, 1])),
+              2.0 * strategy.num_replicas_in_sync,
+              ops.IndexedSlices(
+                  values=array_ops.identity(
+                      [[2.0 * strategy.num_replicas_in_sync]]),
+                  indices=array_ops.identity([1]),
+                  dense_shape=array_ops.identity([5, 1])))
+    self.assertAllEqual(
+        nest.map_structure(ops.convert_to_tensor, got),
+        nest.map_structure(ops.convert_to_tensor, expect))


 def _make_indexed_slices(values, indices, dense_shape):