cross_device_ops reduce() should always return Tensors

Currently some bypass code paths return the inputs as-is. When the inputs are
variables, we should still return Tensors for simplicity.

PiperOrigin-RevId: 306283948
Change-Id: I295a9ff0769b26da07347611d8ef027fafd5b93d
Author: Ran Chen
Date: 2020-04-13 12:19:00 -07:00
Committed by: TensorFlower Gardener
Parent: 81594ed391
Commit: 19e0a5d35f

2 changed files with 34 additions and 7 deletions
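
For illustration only (not part of the commit), a minimal sketch of the behavior this
change targets, assuming the public tf.distribute API and a single-device
MirroredStrategy, which is exactly the case the bypass path used to short-circuit:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
    with strategy.scope():
      v = tf.Variable(1.)

    # Reduce a distributed variable onto itself; the single-device shortcut in
    # CrossDeviceOps.reduce() is taken here.
    cross_device_ops = tf.distribute.ReductionToOneDevice()
    reduced = cross_device_ops.reduce(tf.distribute.ReduceOp.MEAN, v, v)

    # With this change every component of the result is a plain tf.Tensor,
    # even though a variable was passed in.
    for t in strategy.experimental_local_results(reduced):
      assert isinstance(t, tf.Tensor)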

tensorflow/python/distribute/cross_device_ops.py

@@ -196,9 +196,6 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
     raise ValueError("`per_replica_value` must be non-empty")
   count = len(all_values)
-  if (count == 1 and all_values[0].device == reduce_to_device):
-    return all_values[0]
-
   with ops.device(reduce_to_device):
     with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
       reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
@@ -236,7 +233,8 @@ class CrossDeviceOps(object):
     Args:
       reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
         per_replica_value will be reduced.
-      per_replica_value: A PerReplica object or a tensor with device set.
+      per_replica_value: A `tf.distribute.DistributedValues` object or a tensor
+        with device set.
       destinations: the reduction destinations.
       experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
         to perform collective operations.
@@ -257,9 +255,9 @@ class CrossDeviceOps(object):
     if self._num_between_graph_workers == 1 and len(
         per_replica_value.values) == 1 and _devices_match(
             per_replica_value, destinations):
-      return value_lib.regroup(
-          per_replica_value.values,
-          wrap_class=value_lib.Mirrored)
+      with ops.device(per_replica_value.values[0].device):
+        v = array_ops.identity(per_replica_value.values[0])
+      return value_lib.regroup((v,), wrap_class=value_lib.Mirrored)
 
     if experimental_hints is None:
       experimental_hints = collective_util.Hints()
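
The important bit in the hunk above is the `array_ops.identity` wrap: reading a
variable through identity yields a Tensor snapshot of its current value, so the
shortcut path now hands back Tensors like every other path. A minimal illustrative
sketch of that effect, using the public `tf.identity` rather than the internal
`array_ops` module:

    import tensorflow as tf

    v = tf.Variable(3.)
    t = tf.identity(v)  # snapshot of the variable's current value

    # Callers downstream see an ordinary Tensor, never the variable itself.
    assert isinstance(t, tf.Tensor)
    assert not isinstance(t, tf.Variable)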

tensorflow/python/distribute/cross_device_ops_test.py

@@ -31,6 +31,7 @@ from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import values as value_lib
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
@@ -39,6 +40,7 @@ from tensorflow.python.framework import kernels
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
 
 
 def _get_devices(devices):
@@ -386,6 +388,33 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
     self._testIndexedSlicesAllReduce(devices, cross_device_ops_instance,
                                      reduce_op, batch_reduce)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          cross_device_ops_instance=[
+              combinations.NamedObject(
+                  "ReductionToOneDevice",
+                  cross_device_ops_lib.ReductionToOneDevice()),
+              combinations.NamedObject(
+                  "AllReduceCrossDeviceOps",
+                  cross_device_ops_lib.AllReduceCrossDeviceOps("ring"))
+          ],
+          batch_reduce=[True, False],
+          mode=["graph", "eager"]))
+  def testReduceDistributedVariable(self, distribution,
+                                    cross_device_ops_instance, batch_reduce):
+    with distribution.scope():
+      v = variables.Variable(1.)
+    if batch_reduce:
+      result = cross_device_ops_instance.batch_reduce(reduce_util.ReduceOp.MEAN,
+                                                      [(v, v)])[0]
+    else:
+      result = cross_device_ops_instance.reduce(reduce_util.ReduceOp.MEAN, v, v)
+    for v in result.values:
+      self.assertIsInstance(v, ops.Tensor)
+    self.evaluate(variables.global_variables_initializer())
+    self.assertAllEqual(self.evaluate(result.values), [1.0, 1.0])
 
 
 class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
                                     CrossDeviceOpsTestBase):