Require that axis arguments to reduction ops are unique.

PiperOrigin-RevId: 319310499
Change-Id: I6c1e4b875d57d3c9cb67bd529fe6fb2a499f90f7
This commit is contained in:
A. Unique TensorFlower 2020-07-01 15:42:26 -07:00 committed by TensorFlower Gardener
parent 05bd6dcff4
commit 48627d22a1
4 changed files with 63 additions and 48 deletions

View File

@ -69,6 +69,11 @@ Status SimplifyHelper(const Tensor& data, const Tensor& axis,
" dimension(s)"); " dimension(s)");
} }
index = (index + data.dims()) % data.dims(); index = (index + data.dims()) % data.dims();
if (bitmap[index]) {
return errors::InvalidArgument(
"Invalid reduction arguments: Axes contains duplicate dimension: ",
index);
}
bitmap[index] = true; bitmap[index] = true;
} }
return Status::OK(); return Status::OK();

View File

@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -340,6 +341,14 @@ class SumReductionTest(BaseReductionTest):
".*must be at most rank 1.*"): ".*must be at most rank 1.*"):
math_ops.reduce_sum(c_unknown, reduction_axes) math_ops.reduce_sum(c_unknown, reduction_axes)
def testInvalidRepeatedReductionIndices(self):
reduction_axes = constant_op.constant([0, 0])
c = constant_op.constant([1.0, 2.0])
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
".*Axes contains duplicate dimension: 0.*"):
self.evaluate(math_ops.reduce_sum(c, reduction_axes))
# Int64?? # Int64??
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1

View File

@ -1828,18 +1828,18 @@ def _ReductionDims(x, axis): # pylint: disable=invalid-name
if axis is not None: if axis is not None:
return axis return axis
else: else:
# Fast path: avoid creating Rank and Range ops if ndims is known. x_rank = None
if isinstance(x, ops.Tensor): if isinstance(x, ops.Tensor):
rank = x.shape.rank x_rank = x.shape.rank
if rank is not None:
return constant_op.constant(np.arange(rank, dtype=np.int32))
elif (isinstance(x, sparse_tensor.SparseTensor) and elif (isinstance(x, sparse_tensor.SparseTensor) and
x.dense_shape.shape.is_fully_defined()): x.dense_shape.shape.is_fully_defined()):
rank = x.dense_shape.shape.dims[0].value # sparse.dense_shape is 1-D. x_rank = x.dense_shape.shape.dims[0].value # sparse.dense_shape is 1-D.
return constant_op.constant(np.arange(rank, dtype=np.int32)) # Fast path: avoid creating Rank and Range ops if ndims is known.
if x_rank:
# Otherwise, we rely on Range and Rank to do the right thing at run-time. return constant_op.constant(np.arange(x_rank, dtype=np.int32))
return range(0, array_ops.rank(x)) else:
# Otherwise, we rely on Range and Rank to do the right thing at run-time.
return range(0, array_ops.rank(x))
def _has_fully_defined_shape(tensor): def _has_fully_defined_shape(tensor):
@ -1870,8 +1870,8 @@ def reduce_sum_v1(input_tensor,
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -1920,8 +1920,8 @@ def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -1997,8 +1997,8 @@ def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2193,8 +2193,8 @@ def reduce_mean_v1(input_tensor,
Reduces `input_tensor` along the dimensions given in `axis` by computing the Reduces `input_tensor` along the dimensions given in `axis` by computing the
mean of elements across the dimensions in `axis`. mean of elements across the dimensions in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a tensor with a single If `axis` is None, all dimensions are reduced, and a tensor with a single
element is returned. element is returned.
@ -2255,8 +2255,8 @@ def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis` by computing the Reduces `input_tensor` along the dimensions given in `axis` by computing the
mean of elements across the dimensions in `axis`. mean of elements across the dimensions in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions are retained of the entries in `axis`, which must be unique. If `keepdims` is true, the
with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a tensor with a single If `axis` is None, all dimensions are reduced, and a tensor with a single
element is returned. element is returned.
@ -2314,8 +2314,8 @@ def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2375,8 +2375,8 @@ def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2469,8 +2469,8 @@ def reduce_prod_v1(input_tensor,
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2515,8 +2515,8 @@ def reduce_min_v1(input_tensor,
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2553,8 +2553,8 @@ def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2602,8 +2602,8 @@ def reduce_max_v1(input_tensor,
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2640,8 +2640,8 @@ def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2707,8 +2707,8 @@ def reduce_all_v1(input_tensor,
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2754,8 +2754,8 @@ def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2807,8 +2807,8 @@ def reduce_any_v1(input_tensor,
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2854,8 +2854,8 @@ def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` is None, all dimensions are reduced, and a If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2907,8 +2907,8 @@ def reduce_logsumexp_v1(input_tensor,
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a If `axis` has no entries, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.
@ -2956,8 +2956,8 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
Reduces `input_tensor` along the dimensions given in `axis`. Reduces `input_tensor` along the dimensions given in `axis`.
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
entry in `axis`. If `keepdims` is true, the reduced dimensions of the entries in `axis`, which must be unique. If `keepdims` is true, the
are retained with length 1. reduced dimensions are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a If `axis` has no entries, all dimensions are reduced, and a
tensor with a single element is returned. tensor with a single element is returned.

View File

@ -57,13 +57,14 @@ class ReduceTest(test_util.TensorFlowTestCase):
def testReduceExplicitAxes(self): def testReduceExplicitAxes(self):
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
with test_util.device(use_gpu=True): with test_util.device(use_gpu=True):
for axis in (0, -2, (0, 0), (0, -2)): for axis in (0, -2):
self.assertAllEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), self.assertAllEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)),
[5, 7, 9]) [5, 7, 9])
for axis in (1, -1, (1, 1), (1, -1)): for axis in (1, -1):
self.assertAllEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), self.assertAllEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)),
[6, 15]) [6, 15])
for axis in (None, (0, 1), (-1, -2), (-2, -1, 0, 1)): for axis in (None, (0, 1), (1, 0), (-1, 0), (0, -1), (-2, 1), (1, -2),
(-1, -2), (-2, -1)):
self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21) self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21)
def testReduceInvalidAxis(self): def testReduceInvalidAxis(self):