Improve error messages from helper function array_ops.get_positive_axis.

PiperOrigin-RevId: 297615399
Change-Id: I5c3817e1e9dfefe0acec8abb87bce15d3cfb967b
This commit is contained in:
Edward Loper 2020-02-27 09:27:31 -08:00 committed by TensorFlower Gardener
parent 456f260a3a
commit a16ca400dc
7 changed files with 44 additions and 23 deletions

View File

@ -5396,7 +5396,7 @@ def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
return tensor return tensor
def get_positive_axis(axis, ndims): def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
"""Validate an `axis` parameter, and normalize it to be positive. """Validate an `axis` parameter, and normalize it to be positive.
If `ndims` is known (i.e., not `None`), then check that `axis` is in the If `ndims` is known (i.e., not `None`), then check that `axis` is in the
@ -5408,6 +5408,8 @@ def get_positive_axis(axis, ndims):
Args: Args:
axis: An integer constant axis: An integer constant
ndims: An integer constant, or `None` ndims: An integer constant, or `None`
axis_name: The name of `axis` (for error messages).
ndims_name: The name of `ndims` (for error messages).
Returns: Returns:
The normalized `axis` value. The normalized `axis` value.
@ -5417,17 +5419,19 @@ def get_positive_axis(axis, ndims):
`ndims is None`. `ndims is None`.
""" """
if not isinstance(axis, int): if not isinstance(axis, int):
raise TypeError("axis must be an int; got %s" % type(axis).__name__) raise TypeError("%s must be an int; got %s" %
(axis_name, type(axis).__name__))
if ndims is not None: if ndims is not None:
if 0 <= axis < ndims: if 0 <= axis < ndims:
return axis return axis
elif -ndims <= axis < 0: elif -ndims <= axis < 0:
return axis + ndims return axis + ndims
else: else:
raise ValueError("axis=%s out of bounds: expected %s<=axis<%s" % raise ValueError("%s=%s out of bounds: expected %s<=%s<%s" %
(axis, -ndims, ndims)) (axis_name, axis, -ndims, axis_name, ndims))
elif axis < 0: elif axis < 0:
raise ValueError("axis may only be negative if ndims is statically known.") raise ValueError("%s may only be negative if %s is statically known." %
(axis_name, ndims_name))
return axis return axis
@ -5485,7 +5489,7 @@ def repeat_with_axis(data, repeats, axis, name=None):
data_shape = shape(data) data_shape = shape(data)
# If `axis` is negative, then convert it to a positive value. # If `axis` is negative, then convert it to a positive value.
axis = get_positive_axis(axis, data.shape.ndims) axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
# Check data Tensor shapes. # Check data Tensor shapes.
if repeats.shape.ndims == 1: if repeats.shape.ndims == 1:

View File

@ -436,7 +436,7 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
return array_ops.expand_dims(input, axis) return array_ops.expand_dims(input, axis)
ndims = None if input.shape.ndims is None else input.shape.ndims + 1 ndims = None if input.shape.ndims is None else input.shape.ndims + 1
axis = ragged_util.get_positive_axis(axis, ndims) axis = array_ops.get_positive_axis(axis, ndims, ndims_name='rank(input)')
if axis == 0: if axis == 0:
return ragged_tensor.RaggedTensor.from_uniform_row_length( return ragged_tensor.RaggedTensor.from_uniform_row_length(
@ -521,7 +521,8 @@ def ragged_one_hot(indices,
indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
indices, name='indices') indices, name='indices')
if axis is not None: if axis is not None:
axis = ragged_util.get_positive_axis(axis, indices.shape.ndims) axis = array_ops.get_positive_axis(
axis, indices.shape.ndims, ndims_name='rank(indices)')
if axis < indices.ragged_rank: if axis < indices.ragged_rank:
raise ValueError('axis may not be less than indices.ragged_rank.') raise ValueError('axis may not be less than indices.ragged_rank.')
return indices.with_flat_values( return indices.with_flat_values(
@ -672,8 +673,11 @@ def reverse(tensor, axis, name=None):
tensor, name='tensor') tensor, name='tensor')
# Allow usage of negative values to specify innermost axes. # Allow usage of negative values to specify innermost axes.
axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank) axis = [
for dim in axis] array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i,
'rank(tensor)')
for i, dim in enumerate(axis)
]
# We only need to slice up to the max axis. If the axis list # We only need to slice up to the max axis. If the axis list
# is empty, it should be 0. # is empty, it should be 0.

View File

@ -161,7 +161,7 @@ def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
rt.shape.assert_has_rank(ndims) rt.shape.assert_has_rank(ndims)
out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1 out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
axis = ragged_util.get_positive_axis(axis, out_ndims) axis = array_ops.get_positive_axis(axis, out_ndims)
if stack_values and ndims == 1 and axis == 0: if stack_values and ndims == 1 and axis == 0:
return ragged_tensor.RaggedTensor.from_row_lengths( return ragged_tensor.RaggedTensor.from_row_lengths(

View File

@ -29,7 +29,6 @@ from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -501,7 +500,9 @@ def ragged_reduce_aggregate(reduce_op,
# as the sort with negative axis will have different orders. # as the sort with negative axis will have different orders.
# See GitHub issue 27497. # See GitHub issue 27497.
axis = [ axis = [
ragged_util.get_positive_axis(a, rt_input.shape.ndims) for a in axis array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
'rank(input_tensor)')
for i, a in enumerate(axis)
] ]
# When reducing multiple axes, just reduce one at a time. This is less # When reducing multiple axes, just reduce one at a time. This is less
# efficient, and only works for associative ops. (In particular, it # efficient, and only works for associative ops. (In particular, it
@ -518,7 +519,8 @@ def ragged_reduce_aggregate(reduce_op,
rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor( rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
rt_input, name='rt_input') rt_input, name='rt_input')
axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims) axis = array_ops.get_positive_axis(
axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')
if axis == 0: if axis == 0:
# out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N] # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]

View File

@ -204,28 +204,28 @@ class RaggedMergeDimsOpTest(test_util.TensorFlowTestCase,
'outer_axis': {}, 'outer_axis': {},
'inner_axis': 1, 'inner_axis': 1,
'exception': TypeError, 'exception': TypeError,
'message': 'axis must be an int', 'message': 'outer_axis must be an int',
}, },
{ {
'rt': [[1]], 'rt': [[1]],
'outer_axis': 1, 'outer_axis': 1,
'inner_axis': {}, 'inner_axis': {},
'exception': TypeError, 'exception': TypeError,
'message': 'axis must be an int', 'message': 'inner_axis must be an int',
}, },
{ {
'rt': [[1]], 'rt': [[1]],
'outer_axis': 1, 'outer_axis': 1,
'inner_axis': 3, 'inner_axis': 3,
'exception': ValueError, 'exception': ValueError,
'message': 'axis=3 out of bounds: expected -2<=axis<2', 'message': 'inner_axis=3 out of bounds: expected -2<=inner_axis<2',
}, },
{ {
'rt': [[1]], 'rt': [[1]],
'outer_axis': 1, 'outer_axis': 1,
'inner_axis': -3, 'inner_axis': -3,
'exception': ValueError, 'exception': ValueError,
'message': 'axis=-3 out of bounds: expected -2<=axis<2', 'message': 'inner_axis=-3 out of bounds: expected -2<=inner_axis<2',
}, },
{ {
'rt': [[1]], 'rt': [[1]],

View File

@ -25,7 +25,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
@ -66,7 +65,10 @@ def squeeze(input, axis=None, name=None): # pylint: disable=redefined-builtin
dense_dims = [] dense_dims = []
ragged_dims = [] ragged_dims = []
# Normalize all the dims in axis to be positive # Normalize all the dims in axis to be positive
axis = [ragged_util.get_positive_axis(d, input.shape.ndims) for d in axis] axis = [
array_ops.get_positive_axis(d, input.shape.ndims, 'axis[%d]' % i,
'rank(input)') for i, d in enumerate(axis)
]
for dim in axis: for dim in axis:
if dim > input.ragged_rank: if dim > input.ragged_rank:
dense_dims.append(dim - input.ragged_rank) dense_dims.append(dim - input.ragged_rank)

View File

@ -1337,7 +1337,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
return self._cached_row_lengths return self._cached_row_lengths
with ops.name_scope(name, "RaggedRowLengths", [self]): with ops.name_scope(name, "RaggedRowLengths", [self]):
axis = ragged_util.get_positive_axis(axis, self.shape.ndims) axis = array_ops.get_positive_axis(
axis, self.shape.rank, ndims_name="rank(self)")
if axis == 0: if axis == 0:
return self.nrows() return self.nrows()
elif axis == 1: elif axis == 1:
@ -1557,8 +1558,16 @@ class RaggedTensor(composite_tensor.CompositeTensor):
`self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N` `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
is the total number of slices in the merged dimensions. is the total number of slices in the merged dimensions.
""" """
outer_axis = ragged_util.get_positive_axis(outer_axis, self.shape.ndims) outer_axis = array_ops.get_positive_axis(
inner_axis = ragged_util.get_positive_axis(inner_axis, self.shape.ndims) outer_axis,
self.shape.rank,
axis_name="outer_axis",
ndims_name="rank(self)")
inner_axis = array_ops.get_positive_axis(
inner_axis,
self.shape.rank,
axis_name="inner_axis",
ndims_name="rank(self)")
if not outer_axis < inner_axis: if not outer_axis < inner_axis:
raise ValueError("Expected outer_axis (%d) to be less than " raise ValueError("Expected outer_axis (%d) to be less than "
"inner_axis (%d)" % (outer_axis, inner_axis)) "inner_axis (%d)" % (outer_axis, inner_axis))