Improved fix for tf.distribute.Strategy.reduce() when in TF v1 graph mode, axis is not None, and the tensor's shape is not fully known, addressing a TODO. Also add a test for this case.

PiperOrigin-RevId: 244924874
parent 2f9b121791
commit 3ee29f7b04
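For context, a minimal sketch (not part of the commit) of the call this change fixes, written against the same internal modules the test below uses; the single-CPU device list is an assumption for illustration:

from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import constant_op

strategy = mirrored_strategy.MirroredStrategy(["/cpu:0"])  # assumed device list
with strategy.scope():
  per_replica = strategy.extended.call_for_each_replica(
      lambda: constant_op.constant([1.0, 2.0, 3.0]))
  # axis=0 additionally reduces within each replica's tensor; for MEAN the
  # denominator is the size along that axis, which in TF v1 graph mode may
  # only be known at runtime. That dynamic-shape case is what this fixes.
  mean = strategy.reduce(reduce_util.ReduceOp.MEAN, per_replica, axis=0)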
tensorflow/python/distribute/BUILD
@@ -877,10 +877,10 @@ cuda_py_test(
     name = "mirrored_strategy_test",
     srcs = ["mirrored_strategy_test.py"],
     additional_deps = [
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
+        ":combinations",
+        ":strategy_combinations",
         ":mirrored_strategy",
-        "//tensorflow/python/distribute:multi_worker_test_base",
+        ":multi_worker_test_base",
         ":strategy_test_lib",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
@@ -888,6 +888,7 @@ cuda_py_test(
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:layers",
         "//tensorflow/python:state_ops",
+        "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/distribute:distribute_lib",
tensorflow/python/distribute/distribute_lib.py
@@ -33,6 +33,7 @@ from tensorflow.python.eager import context as eager_context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import custom_gradient
@@ -561,14 +562,15 @@ class Strategy(object):
             raise ValueError(
                 "`axis` = %r out of range for `value` with rank %d" %
                 (axis, v.shape.rank))
-        # TODO(anjalisridhar): Added a second condition to handle the case of
-        # dynamic shapes when using tf.functions. We might want to remove this
-        # static shape case and always calculate the shape of v.
-        if (v.shape[axis] is not None and
-            [x for x in v.get_shape().as_list() if x]):
+        # TF v2 returns `None` for unknown dimensions and an integer for
+        # known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
+        # or tensor_shape.Dimension(integer). `dimension_value` hides this
+        # difference, always returning `None` or an integer.
+        dim = tensor_shape.dimension_value(v.shape[axis])
+        if dim is not None:
           # By returning a python value in the static shape case, we can
           # maybe get a fast path for reducing the denominator.
-          return numer, v.shape[axis]
+          return numer, dim
       elif axis < 0:
         axis = axis + array_ops.rank(v)
       denom = array_ops.shape_v2(v, out_type=dtypes.int64)[axis]
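A standalone illustration of the `dimension_value` behavior the new comment describes (assumes TF 1.13+, where tensor_shape.dimension_value exists):

from tensorflow.python.framework import tensor_shape

shape = tensor_shape.TensorShape([None, 3])
# With v2 TensorShape behavior, shape[0] is already None; with v1 behavior
# it is Dimension(None). dimension_value normalizes both representations
# to a plain None or a plain int.
assert tensor_shape.dimension_value(shape[0]) is None
assert tensor_shape.dimension_value(shape[1]) == 3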
tensorflow/python/distribute/mirrored_strategy_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.keras.layers import core as keras_core
@@ -107,6 +108,20 @@ class MirroredTwoDeviceDistributionTest(
     expected = sum(range(distribution.num_replicas_in_sync))
     self.assertEqual(expected, self.evaluate(reduced))
 
+  def reduce_axis_helper(self, distribution, replica_squared_fn):
+    with distribution.scope():
+      num_replicas = distribution.num_replicas_in_sync
+      result = distribution.extended.call_for_each_replica(replica_squared_fn)
+      # sum
+      reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=0)
+      expected = sum(x * (x + 1) for x in range(num_replicas))
+      self.assertNear(expected, self.evaluate(reduced), 0.00001)
+
+      # mean
+      reduced = distribution.reduce(reduce_util.ReduceOp.MEAN, result, axis=0)
+      expected /= sum(x + 1 for x in range(num_replicas))
+      self.assertNear(expected, self.evaluate(reduced), 0.00001)
+
   def testReduceAxisToCpu(self, distribution):
     for dtype in (dtypes.float32, dtypes.int32):
       def replica_squared_fn(dtype=dtype):
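To see where the helper's expected values come from: replica x produces the list [x] repeated (x + 1) times, so reducing it along axis 0 yields x * (x + 1). A standalone check for two replicas (illustration only, not part of the change):

num_replicas = 2  # replica 0 -> [0], replica 1 -> [1, 1]
expected_sum = sum(x * (x + 1) for x in range(num_replicas))  # 0*1 + 1*2 = 2
total_elements = sum(x + 1 for x in range(num_replicas))      # 1 + 2 = 3
expected_mean = expected_sum / total_elements                 # MEAN = 2/3
assert (expected_sum, total_elements) == (2, 3)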
@@ -114,18 +129,31 @@ class MirroredTwoDeviceDistributionTest(
         replica_id = _replica_id_as_int()
         return math_ops.cast([replica_id] * (replica_id + 1), dtype)
 
-      with distribution.scope():
-        num_replicas = distribution.num_replicas_in_sync
-        result = distribution.extended.call_for_each_replica(replica_squared_fn)
-        # sum
-        reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=0)
-        expected = sum(x * (x + 1) for x in range(num_replicas))
-        self.assertNear(expected, self.evaluate(reduced), 0.00001)
+      self.reduce_axis_helper(distribution, replica_squared_fn)
 
-        # mean
-        reduced = distribution.reduce(reduce_util.ReduceOp.MEAN, result, axis=0)
-        expected /= sum(x + 1 for x in range(num_replicas))
-        self.assertNear(expected, self.evaluate(reduced), 0.00001)
+  def set_v2_tensorshape(self, v2):
+    if v2:
+      tensor_shape.enable_v2_tensorshape()
+    else:
+      tensor_shape.disable_v2_tensorshape()
+
+  def testReduceAxisToCpuUnknownShape(self, distribution):
+    original_v2 = tensor_shape._TENSORSHAPE_V2_OVERRIDE  # pylint: disable=protected-access
+    try:
+      for v2 in (False, True):
+        self.set_v2_tensorshape(v2)
+        for dtype in (dtypes.float32, dtypes.int32):
+          for shape in ((None,), None):  # Test both unknown size and rank.
+            def replica_squared_fn(dtype=dtype, shape=shape):
+              # Lists with different lengths on different replicas.
+              replica_id = _replica_id_as_int()
+              tensor = math_ops.cast([replica_id] * (replica_id + 1), dtype)
+              # Erase shape information
+              return array_ops.placeholder_with_default(tensor, shape=shape)
+
+            self.reduce_axis_helper(distribution, replica_squared_fn)
+    finally:
+      self.set_v2_tensorshape(original_v2)
 
   def testMakeInputFnIteratorWithDataset(self, distribution):
     dataset_fn = lambda: dataset_ops.Dataset.range(10)
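A standalone sketch of how the new test erases static shape information: placeholder_with_default keeps the runtime value but substitutes the given partial static shape, so shape=(None,) hides the size and shape=None hides the rank as well (TF v1 graph mode assumed):

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops

tensor = constant_op.constant([1.0, 2.0])
unknown_size = array_ops.placeholder_with_default(tensor, shape=(None,))
unknown_rank = array_ops.placeholder_with_default(tensor, shape=None)
print(unknown_size.shape)  # (?,): size along axis 0 unknown
print(unknown_rank.shape)  # <unknown>: rank unknown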