From 48627d22a1d85a37b6a92c134e55901e0346d4f9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 1 Jul 2020 15:42:26 -0700 Subject: [PATCH] Require that axis arguments to reduction ops are unique. PiperOrigin-RevId: 319310499 Change-Id: I6c1e4b875d57d3c9cb67bd529fe6fb2a499f90f7 --- .../core/kernels/reduction_ops_common.cc | 5 ++ .../python/kernel_tests/reduction_ops_test.py | 9 ++ tensorflow/python/ops/math_ops.py | 90 +++++++++---------- tensorflow/python/ops/math_ops_test.py | 7 +- 4 files changed, 63 insertions(+), 48 deletions(-) diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc index c341e330178..2e21094cc49 100644 --- a/tensorflow/core/kernels/reduction_ops_common.cc +++ b/tensorflow/core/kernels/reduction_ops_common.cc @@ -69,6 +69,11 @@ Status SimplifyHelper(const Tensor& data, const Tensor& axis, " dimension(s)"); } index = (index + data.dims()) % data.dims(); + if (bitmap[index]) { + return errors::InvalidArgument( + "Invalid reduction arguments: Axes contains duplicate dimension: ", + index); + } bitmap[index] = true; } return Status::OK(); diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index c0ad81f3055..8bf5a08a358 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -340,6 +341,14 @@ class SumReductionTest(BaseReductionTest): ".*must be at most rank 1.*"): math_ops.reduce_sum(c_unknown, reduction_axes) + def testInvalidRepeatedReductionIndices(self): + reduction_axes = constant_op.constant([0, 0]) + c = constant_op.constant([1.0, 2.0]) + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + ".*Axes contains duplicate dimension: 0.*"): + self.evaluate(math_ops.reduce_sum(c, reduction_axes)) + # Int64?? @test_util.run_deprecated_v1 diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 7810dae2688..d587285a36e 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1828,18 +1828,18 @@ def _ReductionDims(x, axis): # pylint: disable=invalid-name if axis is not None: return axis else: - # Fast path: avoid creating Rank and Range ops if ndims is known. + x_rank = None if isinstance(x, ops.Tensor): - rank = x.shape.rank - if rank is not None: - return constant_op.constant(np.arange(rank, dtype=np.int32)) + x_rank = x.shape.rank elif (isinstance(x, sparse_tensor.SparseTensor) and x.dense_shape.shape.is_fully_defined()): - rank = x.dense_shape.shape.dims[0].value # sparse.dense_shape is 1-D. - return constant_op.constant(np.arange(rank, dtype=np.int32)) - - # Otherwise, we rely on Range and Rank to do the right thing at run-time. - return range(0, array_ops.rank(x)) + x_rank = x.dense_shape.shape.dims[0].value # sparse.dense_shape is 1-D. + # Fast path: avoid creating Rank and Range ops if ndims is known. + if x_rank: + return constant_op.constant(np.arange(x_rank, dtype=np.int32)) + else: + # Otherwise, we rely on Range and Rank to do the right thing at run-time. + return range(0, array_ops.rank(x)) def _has_fully_defined_shape(tensor): @@ -1870,8 +1870,8 @@ def reduce_sum_v1(input_tensor, Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -1920,8 +1920,8 @@ def reduce_sum(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -1997,8 +1997,8 @@ def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2193,8 +2193,8 @@ def reduce_mean_v1(input_tensor, Reduces `input_tensor` along the dimensions given in `axis` by computing the mean of elements across the dimensions in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2255,8 +2255,8 @@ def reduce_mean(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis` by computing the mean of elements across the dimensions in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions are retained - with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2314,8 +2314,8 @@ def reduce_variance(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2375,8 +2375,8 @@ def reduce_std(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2469,8 +2469,8 @@ def reduce_prod_v1(input_tensor, Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2515,8 +2515,8 @@ def reduce_min_v1(input_tensor, Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2553,8 +2553,8 @@ def reduce_min(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2602,8 +2602,8 @@ def reduce_max_v1(input_tensor, Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2640,8 +2640,8 @@ def reduce_max(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2707,8 +2707,8 @@ def reduce_all_v1(input_tensor, Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2754,8 +2754,8 @@ def reduce_all(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2807,8 +2807,8 @@ def reduce_any_v1(input_tensor, Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2854,8 +2854,8 @@ def reduce_any(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. @@ -2907,8 +2907,8 @@ def reduce_logsumexp_v1(input_tensor, Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` has no entries, all dimensions are reduced, and a tensor with a single element is returned. @@ -2956,8 +2956,8 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. If `axis` has no entries, all dimensions are reduced, and a tensor with a single element is returned. diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index c5448a39be4..bf15bf86ee2 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -57,13 +57,14 @@ class ReduceTest(test_util.TensorFlowTestCase): def testReduceExplicitAxes(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) with test_util.device(use_gpu=True): - for axis in (0, -2, (0, 0), (0, -2)): + for axis in (0, -2): self.assertAllEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), [5, 7, 9]) - for axis in (1, -1, (1, 1), (1, -1)): + for axis in (1, -1): self.assertAllEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), [6, 15]) - for axis in (None, (0, 1), (-1, -2), (-2, -1, 0, 1)): + for axis in (None, (0, 1), (1, 0), (-1, 0), (0, -1), (-2, 1), (1, -2), + (-1, -2), (-2, -1)): self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21) def testReduceInvalidAxis(self):