cross_device_ops reduce() should always return Tensors
Currently some bypass code paths return inputs as is. When the inputs are variables, we should still return Tensors for simplicity.

PiperOrigin-RevId: 306283948
Change-Id: I295a9ff0769b26da07347611d8ef027fafd5b93d
This commit is contained in:
parent 81594ed391
commit 19e0a5d35f
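
For context, a minimal sketch of the contract this change enforces, using public TF 2.x APIs. A single-CPU MirroredStrategy stands in for the GPU+CPU setup the new test uses; the variable and destinations below are illustrative, not taken from the diff.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  v = tf.Variable(1.0)

cross_device_ops = tf.distribute.ReductionToOneDevice()
# The single-value "devices already match" bypass used to hand the input back
# as is; after this change every local result is a plain Tensor.
reduced = cross_device_ops.reduce(tf.distribute.ReduceOp.MEAN, v, v)
for t in strategy.experimental_local_results(reduced):
  assert isinstance(t, tf.Tensor)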
@@ -196,9 +196,6 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
     raise ValueError("`per_replica_value` must be non-empty")
   count = len(all_values)
 
-  if (count == 1 and all_values[0].device == reduce_to_device):
-    return all_values[0]
-
   with ops.device(reduce_to_device):
     with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
       reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
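
With the bypass removed, reduction to one device always runs the aggregation path. A simplified, self-contained sketch of that behaviour; the helper name and the use of tf.add_n are illustrative (the real code delegates to cross_device_utils.aggregate_tensors_or_indexed_slices and the configured accumulation_fn).

import tensorflow as tf

def simple_reduce_sketch(values, reduce_to_device, reduce_op):
  """Aggregates `values` on `reduce_to_device` and returns a Tensor."""
  if not values:
    raise ValueError("`values` must be non-empty")
  with tf.device(reduce_to_device):
    reduced = tf.add_n([tf.convert_to_tensor(v) for v in values])
    if reduce_op == tf.distribute.ReduceOp.MEAN:
      reduced = reduced / len(values)
  return reduced

# Even a single variable already living on the target device comes back as a
# Tensor rather than as the variable itself.
out = simple_reduce_sketch([tf.Variable(2.0)], "/cpu:0", tf.distribute.ReduceOp.SUM)
assert isinstance(out, tf.Tensor)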
@@ -236,7 +233,8 @@ class CrossDeviceOps(object):
     Args:
       reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
         per_replica_value will be reduced.
-      per_replica_value: A PerReplica object or a tensor with device set.
+      per_replica_value: A `tf.distribute.DistributedValues` object or a tensor
+        with device set.
       destinations: the reduction destinations.
       experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
         to perform collective operations.
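
A short usage sketch of the arguments documented above, using the public TF 2.x API; the tensor-with-device form of per_replica_value is shown, and the values and destination below are illustrative.

import tensorflow as tf

cross_device_ops = tf.distribute.ReductionToOneDevice(reduce_to_device="/cpu:0")
# reduce_op is a tf.distribute.ReduceOp; per_replica_value may be a
# tf.distribute.DistributedValues or, as here, a tensor with a device set;
# destinations says where the reduced value should live.
x = tf.constant([1.0, 2.0])  # eager tensors carry a device
reduced = cross_device_ops.reduce(tf.distribute.ReduceOp.SUM, x, "/cpu:0")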
@@ -257,9 +255,9 @@ class CrossDeviceOps(object):
     if self._num_between_graph_workers == 1 and len(
         per_replica_value.values) == 1 and _devices_match(
             per_replica_value, destinations):
-      return value_lib.regroup(
-          per_replica_value.values,
-          wrap_class=value_lib.Mirrored)
+      with ops.device(per_replica_value.values[0].device):
+        v = array_ops.identity(per_replica_value.values[0])
+      return value_lib.regroup((v,), wrap_class=value_lib.Mirrored)
 
     if experimental_hints is None:
       experimental_hints = collective_util.Hints()
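
The core of the fix above is that reading the value through array_ops.identity under its own device yields a device-pinned Tensor rather than the original object; an equivalent sketch with the public API:

import tensorflow as tf

v = tf.Variable(3.0)
with tf.device(v.device):
  t = tf.identity(v)  # reads the variable; the result is a Tensor on v's device

assert isinstance(t, tf.Tensor)
assert not isinstance(t, tf.Variable)

The remaining hunks update the accompanying test module to pin this behaviour down.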
@@ -31,6 +31,7 @@ from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import values as value_lib
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
@@ -39,6 +40,7 @@ from tensorflow.python.framework import kernels
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
 
 
 def _get_devices(devices):
@@ -386,6 +388,33 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
     self._testIndexedSlicesAllReduce(devices, cross_device_ops_instance,
                                      reduce_op, batch_reduce)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          cross_device_ops_instance=[
+              combinations.NamedObject(
+                  "ReductionToOneDevice",
+                  cross_device_ops_lib.ReductionToOneDevice()),
+              combinations.NamedObject(
+                  "AllReduceCrossDeviceOps",
+                  cross_device_ops_lib.AllReduceCrossDeviceOps("ring"))
+          ],
+          batch_reduce=[True, False],
+          mode=["graph", "eager"]))
+  def testReduceDistributedVariable(self, distribution,
+                                    cross_device_ops_instance, batch_reduce):
+    with distribution.scope():
+      v = variables.Variable(1.)
+    if batch_reduce:
+      result = cross_device_ops_instance.batch_reduce(reduce_util.ReduceOp.MEAN,
+                                                      [(v, v)])[0]
+    else:
+      result = cross_device_ops_instance.reduce(reduce_util.ReduceOp.MEAN, v, v)
+    for v in result.values:
+      self.assertIsInstance(v, ops.Tensor)
+    self.evaluate(variables.global_variables_initializer())
+    self.assertAllEqual(self.evaluate(result.values), [1.0, 1.0])
+
 
 class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
                                     CrossDeviceOpsTestBase):