diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc index 8818f7befb6..5511b8f43cc 100644 --- a/tensorflow/core/kernels/reduction_ops_common.cc +++ b/tensorflow/core/kernels/reduction_ops_common.cc @@ -61,12 +61,13 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, gtl::InlinedVector bitmap(data.dims(), false); auto axis_vec = axis.flat(); for (int64 i = 0; i < axis.NumElements(); ++i) { - const int32 index = axis_vec(i); - if (index < 0 || index >= data.dims()) { + int32 index = axis_vec(i); + if (index < -data.dims() || index >= data.dims()) { return errors::InvalidArgument("Invalid reduction dimension (", index, " for input with ", data.dims(), " dimension(s)"); } + index = (index + data.dims()) % data.dims(); bitmap[index] = true; } diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index d330040db4d..e9d21fe1a40 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -27,34 +27,40 @@ from tensorflow.python.ops import math_ops class ReducedShapeTest(tf.test.TestCase): + def _check(self, shape, axes, result): + output = math_ops.reduced_shape(shape, axes=axes) + self.assertAllEqual(output.eval(), result) + def testSimple(self): with self.test_session(): - def check(shape, axes, result): - output = math_ops.reduced_shape(shape, axes=axes) - self.assertAllEqual(output.eval(), result) - check([3], [], [3]) - check([3], [0], [1]) - check([5, 3], [], [5, 3]) - check([5, 3], [0], [1, 3]) - check([5, 3], [1], [5, 1]) - check([5, 3], [0, 1], [1, 1]) + self._check([3], [], [3]) + self._check([3], [0], [1]) + self._check([5, 3], [], [5, 3]) + self._check([5, 3], [0], [1, 3]) + self._check([5, 3], [1], [5, 1]) + self._check([5, 3], [0, 1], [1, 1]) def testZeros(self): """Check that reduced_shape does the right thing with zero dimensions.""" with self.test_session(): - def check(shape, axes, result): - output = math_ops.reduced_shape(shape, axes=axes) - self.assertAllEqual(output.eval(), result) - check([0], [], [0]) - check([0], [0], [1]) - check([0, 3], [], [0, 3]) - check([0, 3], [0], [1, 3]) - check([0, 3], [1], [0, 1]) - check([0, 3], [0, 1], [1, 1]) - check([3, 0], [], [3, 0]) - check([3, 0], [0], [1, 0]) - check([3, 0], [1], [3, 1]) - check([3, 0], [0, 1], [1, 1]) + self._check([0], [], [0]) + self._check([0], [0], [1]) + self._check([0, 3], [], [0, 3]) + self._check([0, 3], [0], [1, 3]) + self._check([0, 3], [1], [0, 1]) + self._check([0, 3], [0, 1], [1, 1]) + self._check([3, 0], [], [3, 0]) + self._check([3, 0], [0], [1, 0]) + self._check([3, 0], [1], [3, 1]) + self._check([3, 0], [0, 1], [1, 1]) + + def testNegAxes(self): + with self.test_session(): + self._check([10, 10, 10], [-1], [10, 10, 1]) + self._check([10, 10, 10], [-1, 2], [10, 10, 1]) + self._check([10, 10, 10], [-1, -1], [10, 10, 1]) + self._check([10, 10, 10], [-1, 0], [1, 10, 1]) + self._check([10, 10, 10], [-3], [1, 10, 10]) class SumReductionTest(tf.test.TestCase): @@ -110,6 +116,9 @@ class SumReductionTest(tf.test.TestCase): self._compareAll(np_arr, [1, 2]) self._compareAll(np_arr, [0, 2]) self._compareAll(np_arr, [0, 1, 2]) + self._compareAll(np_arr, [-1]) + self._compareAll(np_arr, [-1, -3]) + self._compareAll(np_arr, [-1, 1]) def testFloatReduce4D(self): # Create a 4D array of floats and reduce across some @@ -167,7 +176,7 @@ class SumReductionTest(tf.test.TestCase): input_tensor = tf.convert_to_tensor(np_arr) with self.assertRaisesWithPredicateMatch( ValueError, lambda e: "Invalid reduction dimension" in str(e)): - tf.reduce_sum(input_tensor, [-1]) + tf.reduce_sum(input_tensor, [-3]) with self.assertRaisesWithPredicateMatch( ValueError, lambda e: "Invalid reduction dimension" in str(e)): tf.reduce_sum(input_tensor, [2]) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index adb673f8510..c622d834906 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1527,10 +1527,14 @@ def _ReductionShape(op): reduction_indices = np.ravel(reduction_indices) for reduction_index in reduction_indices: - if reduction_index < 0 or reduction_index >= input_shape.ndims: + if (reduction_index < -input_shape.ndims or + reduction_index >= input_shape.ndims): raise ValueError("Invalid reduction dimension %d for input with %d " "dimensions" % (reduction_index, input_shape.ndims)) + reduction_indices = set([(x + input_shape.ndims) % input_shape.ndims + for x in reduction_indices]) + returned_dims = [] if keep_dims: for i, dim in enumerate(input_shape.dims): @@ -1624,6 +1628,7 @@ def reduced_shape(input_shape, axes): axes = to_int32(axes) # [1, 2] input_rank = array_ops.size(input_shape) # 4 + axes = (axes + input_rank) % input_rank axes_shape = array_ops.shape(axes) # [2] return gen_data_flow_ops.dynamic_stitch( # [2, 1, 1, 7] [range(input_rank), # [0, 1, 2, 3]