Support negative values in the reduction_indices argument of reduce_*
functions. Fixes #2426 Change: 122735328
This commit is contained in:
parent
144855b385
commit
cce41d3fa3
@ -61,12 +61,13 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
|
|||||||
gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
|
gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
|
||||||
auto axis_vec = axis.flat<int32>();
|
auto axis_vec = axis.flat<int32>();
|
||||||
for (int64 i = 0; i < axis.NumElements(); ++i) {
|
for (int64 i = 0; i < axis.NumElements(); ++i) {
|
||||||
const int32 index = axis_vec(i);
|
int32 index = axis_vec(i);
|
||||||
if (index < 0 || index >= data.dims()) {
|
if (index < -data.dims() || index >= data.dims()) {
|
||||||
return errors::InvalidArgument("Invalid reduction dimension (", index,
|
return errors::InvalidArgument("Invalid reduction dimension (", index,
|
||||||
" for input with ", data.dims(),
|
" for input with ", data.dims(),
|
||||||
" dimension(s)");
|
" dimension(s)");
|
||||||
}
|
}
|
||||||
|
index = (index + data.dims()) % data.dims();
|
||||||
bitmap[index] = true;
|
bitmap[index] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,34 +27,40 @@ from tensorflow.python.ops import math_ops
|
|||||||
|
|
||||||
class ReducedShapeTest(tf.test.TestCase):
|
class ReducedShapeTest(tf.test.TestCase):
|
||||||
|
|
||||||
def testSimple(self):
|
def _check(self, shape, axes, result):
|
||||||
with self.test_session():
|
|
||||||
def check(shape, axes, result):
|
|
||||||
output = math_ops.reduced_shape(shape, axes=axes)
|
output = math_ops.reduced_shape(shape, axes=axes)
|
||||||
self.assertAllEqual(output.eval(), result)
|
self.assertAllEqual(output.eval(), result)
|
||||||
check([3], [], [3])
|
|
||||||
check([3], [0], [1])
|
def testSimple(self):
|
||||||
check([5, 3], [], [5, 3])
|
with self.test_session():
|
||||||
check([5, 3], [0], [1, 3])
|
self._check([3], [], [3])
|
||||||
check([5, 3], [1], [5, 1])
|
self._check([3], [0], [1])
|
||||||
check([5, 3], [0, 1], [1, 1])
|
self._check([5, 3], [], [5, 3])
|
||||||
|
self._check([5, 3], [0], [1, 3])
|
||||||
|
self._check([5, 3], [1], [5, 1])
|
||||||
|
self._check([5, 3], [0, 1], [1, 1])
|
||||||
|
|
||||||
def testZeros(self):
|
def testZeros(self):
|
||||||
"""Check that reduced_shape does the right thing with zero dimensions."""
|
"""Check that reduced_shape does the right thing with zero dimensions."""
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
def check(shape, axes, result):
|
self._check([0], [], [0])
|
||||||
output = math_ops.reduced_shape(shape, axes=axes)
|
self._check([0], [0], [1])
|
||||||
self.assertAllEqual(output.eval(), result)
|
self._check([0, 3], [], [0, 3])
|
||||||
check([0], [], [0])
|
self._check([0, 3], [0], [1, 3])
|
||||||
check([0], [0], [1])
|
self._check([0, 3], [1], [0, 1])
|
||||||
check([0, 3], [], [0, 3])
|
self._check([0, 3], [0, 1], [1, 1])
|
||||||
check([0, 3], [0], [1, 3])
|
self._check([3, 0], [], [3, 0])
|
||||||
check([0, 3], [1], [0, 1])
|
self._check([3, 0], [0], [1, 0])
|
||||||
check([0, 3], [0, 1], [1, 1])
|
self._check([3, 0], [1], [3, 1])
|
||||||
check([3, 0], [], [3, 0])
|
self._check([3, 0], [0, 1], [1, 1])
|
||||||
check([3, 0], [0], [1, 0])
|
|
||||||
check([3, 0], [1], [3, 1])
|
def testNegAxes(self):
|
||||||
check([3, 0], [0, 1], [1, 1])
|
with self.test_session():
|
||||||
|
self._check([10, 10, 10], [-1], [10, 10, 1])
|
||||||
|
self._check([10, 10, 10], [-1, 2], [10, 10, 1])
|
||||||
|
self._check([10, 10, 10], [-1, -1], [10, 10, 1])
|
||||||
|
self._check([10, 10, 10], [-1, 0], [1, 10, 1])
|
||||||
|
self._check([10, 10, 10], [-3], [1, 10, 10])
|
||||||
|
|
||||||
|
|
||||||
class SumReductionTest(tf.test.TestCase):
|
class SumReductionTest(tf.test.TestCase):
|
||||||
@ -110,6 +116,9 @@ class SumReductionTest(tf.test.TestCase):
|
|||||||
self._compareAll(np_arr, [1, 2])
|
self._compareAll(np_arr, [1, 2])
|
||||||
self._compareAll(np_arr, [0, 2])
|
self._compareAll(np_arr, [0, 2])
|
||||||
self._compareAll(np_arr, [0, 1, 2])
|
self._compareAll(np_arr, [0, 1, 2])
|
||||||
|
self._compareAll(np_arr, [-1])
|
||||||
|
self._compareAll(np_arr, [-1, -3])
|
||||||
|
self._compareAll(np_arr, [-1, 1])
|
||||||
|
|
||||||
def testFloatReduce4D(self):
|
def testFloatReduce4D(self):
|
||||||
# Create a 4D array of floats and reduce across some
|
# Create a 4D array of floats and reduce across some
|
||||||
@ -167,7 +176,7 @@ class SumReductionTest(tf.test.TestCase):
|
|||||||
input_tensor = tf.convert_to_tensor(np_arr)
|
input_tensor = tf.convert_to_tensor(np_arr)
|
||||||
with self.assertRaisesWithPredicateMatch(
|
with self.assertRaisesWithPredicateMatch(
|
||||||
ValueError, lambda e: "Invalid reduction dimension" in str(e)):
|
ValueError, lambda e: "Invalid reduction dimension" in str(e)):
|
||||||
tf.reduce_sum(input_tensor, [-1])
|
tf.reduce_sum(input_tensor, [-3])
|
||||||
with self.assertRaisesWithPredicateMatch(
|
with self.assertRaisesWithPredicateMatch(
|
||||||
ValueError, lambda e: "Invalid reduction dimension" in str(e)):
|
ValueError, lambda e: "Invalid reduction dimension" in str(e)):
|
||||||
tf.reduce_sum(input_tensor, [2])
|
tf.reduce_sum(input_tensor, [2])
|
||||||
|
@ -1527,10 +1527,14 @@ def _ReductionShape(op):
|
|||||||
reduction_indices = np.ravel(reduction_indices)
|
reduction_indices = np.ravel(reduction_indices)
|
||||||
|
|
||||||
for reduction_index in reduction_indices:
|
for reduction_index in reduction_indices:
|
||||||
if reduction_index < 0 or reduction_index >= input_shape.ndims:
|
if (reduction_index < -input_shape.ndims or
|
||||||
|
reduction_index >= input_shape.ndims):
|
||||||
raise ValueError("Invalid reduction dimension %d for input with %d "
|
raise ValueError("Invalid reduction dimension %d for input with %d "
|
||||||
"dimensions" % (reduction_index, input_shape.ndims))
|
"dimensions" % (reduction_index, input_shape.ndims))
|
||||||
|
|
||||||
|
reduction_indices = set([(x + input_shape.ndims) % input_shape.ndims
|
||||||
|
for x in reduction_indices])
|
||||||
|
|
||||||
returned_dims = []
|
returned_dims = []
|
||||||
if keep_dims:
|
if keep_dims:
|
||||||
for i, dim in enumerate(input_shape.dims):
|
for i, dim in enumerate(input_shape.dims):
|
||||||
@ -1624,6 +1628,7 @@ def reduced_shape(input_shape, axes):
|
|||||||
axes = to_int32(axes) # [1, 2]
|
axes = to_int32(axes) # [1, 2]
|
||||||
|
|
||||||
input_rank = array_ops.size(input_shape) # 4
|
input_rank = array_ops.size(input_shape) # 4
|
||||||
|
axes = (axes + input_rank) % input_rank
|
||||||
axes_shape = array_ops.shape(axes) # [2]
|
axes_shape = array_ops.shape(axes) # [2]
|
||||||
return gen_data_flow_ops.dynamic_stitch( # [2, 1, 1, 7]
|
return gen_data_flow_ops.dynamic_stitch( # [2, 1, 1, 7]
|
||||||
[range(input_rank), # [0, 1, 2, 3]
|
[range(input_rank), # [0, 1, 2, 3]
|
||||||
|
Loading…
Reference in New Issue
Block a user