Improve error messages from helper function array_ops.get_positive_axis
.
PiperOrigin-RevId: 297615399 Change-Id: I5c3817e1e9dfefe0acec8abb87bce15d3cfb967b
This commit is contained in:
parent
456f260a3a
commit
a16ca400dc
@ -5396,7 +5396,7 @@ def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
|
||||
return tensor
|
||||
|
||||
|
||||
def get_positive_axis(axis, ndims):
|
||||
def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
|
||||
"""Validate an `axis` parameter, and normalize it to be positive.
|
||||
|
||||
If `ndims` is known (i.e., not `None`), then check that `axis` is in the
|
||||
@ -5408,6 +5408,8 @@ def get_positive_axis(axis, ndims):
|
||||
Args:
|
||||
axis: An integer constant
|
||||
ndims: An integer constant, or `None`
|
||||
axis_name: The name of `axis` (for error messages).
|
||||
ndims_name: The name of `ndims` (for error messages).
|
||||
|
||||
Returns:
|
||||
The normalized `axis` value.
|
||||
@ -5417,17 +5419,19 @@ def get_positive_axis(axis, ndims):
|
||||
`ndims is None`.
|
||||
"""
|
||||
if not isinstance(axis, int):
|
||||
raise TypeError("axis must be an int; got %s" % type(axis).__name__)
|
||||
raise TypeError("%s must be an int; got %s" %
|
||||
(axis_name, type(axis).__name__))
|
||||
if ndims is not None:
|
||||
if 0 <= axis < ndims:
|
||||
return axis
|
||||
elif -ndims <= axis < 0:
|
||||
return axis + ndims
|
||||
else:
|
||||
raise ValueError("axis=%s out of bounds: expected %s<=axis<%s" %
|
||||
(axis, -ndims, ndims))
|
||||
raise ValueError("%s=%s out of bounds: expected %s<=%s<%s" %
|
||||
(axis_name, axis, -ndims, axis_name, ndims))
|
||||
elif axis < 0:
|
||||
raise ValueError("axis may only be negative if ndims is statically known.")
|
||||
raise ValueError("%s may only be negative if %s is statically known." %
|
||||
(axis_name, ndims_name))
|
||||
return axis
|
||||
|
||||
|
||||
@ -5485,7 +5489,7 @@ def repeat_with_axis(data, repeats, axis, name=None):
|
||||
data_shape = shape(data)
|
||||
|
||||
# If `axis` is negative, then convert it to a positive value.
|
||||
axis = get_positive_axis(axis, data.shape.ndims)
|
||||
axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
|
||||
|
||||
# Check data Tensor shapes.
|
||||
if repeats.shape.ndims == 1:
|
||||
|
@ -436,7 +436,7 @@ def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
return array_ops.expand_dims(input, axis)
|
||||
|
||||
ndims = None if input.shape.ndims is None else input.shape.ndims + 1
|
||||
axis = ragged_util.get_positive_axis(axis, ndims)
|
||||
axis = array_ops.get_positive_axis(axis, ndims, ndims_name='rank(input)')
|
||||
|
||||
if axis == 0:
|
||||
return ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
@ -521,7 +521,8 @@ def ragged_one_hot(indices,
|
||||
indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
indices, name='indices')
|
||||
if axis is not None:
|
||||
axis = ragged_util.get_positive_axis(axis, indices.shape.ndims)
|
||||
axis = array_ops.get_positive_axis(
|
||||
axis, indices.shape.ndims, ndims_name='rank(indices)')
|
||||
if axis < indices.ragged_rank:
|
||||
raise ValueError('axis may not be less than indices.ragged_rank.')
|
||||
return indices.with_flat_values(
|
||||
@ -672,8 +673,11 @@ def reverse(tensor, axis, name=None):
|
||||
tensor, name='tensor')
|
||||
|
||||
# Allow usage of negative values to specify innermost axes.
|
||||
axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank)
|
||||
for dim in axis]
|
||||
axis = [
|
||||
array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i,
|
||||
'rank(tensor)')
|
||||
for i, dim in enumerate(axis)
|
||||
]
|
||||
|
||||
# We only need to slice up to the max axis. If the axis list
|
||||
# is empty, it should be 0.
|
||||
|
@ -161,7 +161,7 @@ def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
|
||||
rt.shape.assert_has_rank(ndims)
|
||||
|
||||
out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
|
||||
axis = ragged_util.get_positive_axis(axis, out_ndims)
|
||||
axis = array_ops.get_positive_axis(axis, out_ndims)
|
||||
|
||||
if stack_values and ndims == 1 and axis == 0:
|
||||
return ragged_tensor.RaggedTensor.from_row_lengths(
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.ops import gen_ragged_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.ops.ragged import segment_id_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -501,7 +500,9 @@ def ragged_reduce_aggregate(reduce_op,
|
||||
# as the sort with negative axis will have different orders.
|
||||
# See GitHub issue 27497.
|
||||
axis = [
|
||||
ragged_util.get_positive_axis(a, rt_input.shape.ndims) for a in axis
|
||||
array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
|
||||
'rank(input_tensor)')
|
||||
for i, a in enumerate(axis)
|
||||
]
|
||||
# When reducing multiple axes, just reduce one at a time. This is less
|
||||
# efficient, and only works for associative ops. (In particular, it
|
||||
@ -518,7 +519,8 @@ def ragged_reduce_aggregate(reduce_op,
|
||||
rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
rt_input, name='rt_input')
|
||||
|
||||
axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)
|
||||
axis = array_ops.get_positive_axis(
|
||||
axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')
|
||||
|
||||
if axis == 0:
|
||||
# out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
|
||||
|
@ -204,28 +204,28 @@ class RaggedMergeDimsOpTest(test_util.TensorFlowTestCase,
|
||||
'outer_axis': {},
|
||||
'inner_axis': 1,
|
||||
'exception': TypeError,
|
||||
'message': 'axis must be an int',
|
||||
'message': 'outer_axis must be an int',
|
||||
},
|
||||
{
|
||||
'rt': [[1]],
|
||||
'outer_axis': 1,
|
||||
'inner_axis': {},
|
||||
'exception': TypeError,
|
||||
'message': 'axis must be an int',
|
||||
'message': 'inner_axis must be an int',
|
||||
},
|
||||
{
|
||||
'rt': [[1]],
|
||||
'outer_axis': 1,
|
||||
'inner_axis': 3,
|
||||
'exception': ValueError,
|
||||
'message': 'axis=3 out of bounds: expected -2<=axis<2',
|
||||
'message': 'inner_axis=3 out of bounds: expected -2<=inner_axis<2',
|
||||
},
|
||||
{
|
||||
'rt': [[1]],
|
||||
'outer_axis': 1,
|
||||
'inner_axis': -3,
|
||||
'exception': ValueError,
|
||||
'message': 'axis=-3 out of bounds: expected -2<=axis<2',
|
||||
'message': 'inner_axis=-3 out of bounds: expected -2<=inner_axis<2',
|
||||
},
|
||||
{
|
||||
'rt': [[1]],
|
||||
|
@ -25,7 +25,6 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
|
||||
|
||||
|
||||
@ -66,7 +65,10 @@ def squeeze(input, axis=None, name=None): # pylint: disable=redefined-builtin
|
||||
dense_dims = []
|
||||
ragged_dims = []
|
||||
# Normalize all the dims in axis to be positive
|
||||
axis = [ragged_util.get_positive_axis(d, input.shape.ndims) for d in axis]
|
||||
axis = [
|
||||
array_ops.get_positive_axis(d, input.shape.ndims, 'axis[%d]' % i,
|
||||
'rank(input)') for i, d in enumerate(axis)
|
||||
]
|
||||
for dim in axis:
|
||||
if dim > input.ragged_rank:
|
||||
dense_dims.append(dim - input.ragged_rank)
|
||||
|
@ -1337,7 +1337,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
return self._cached_row_lengths
|
||||
|
||||
with ops.name_scope(name, "RaggedRowLengths", [self]):
|
||||
axis = ragged_util.get_positive_axis(axis, self.shape.ndims)
|
||||
axis = array_ops.get_positive_axis(
|
||||
axis, self.shape.rank, ndims_name="rank(self)")
|
||||
if axis == 0:
|
||||
return self.nrows()
|
||||
elif axis == 1:
|
||||
@ -1557,8 +1558,16 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
`self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
|
||||
is the total number of slices in the merged dimensions.
|
||||
"""
|
||||
outer_axis = ragged_util.get_positive_axis(outer_axis, self.shape.ndims)
|
||||
inner_axis = ragged_util.get_positive_axis(inner_axis, self.shape.ndims)
|
||||
outer_axis = array_ops.get_positive_axis(
|
||||
outer_axis,
|
||||
self.shape.rank,
|
||||
axis_name="outer_axis",
|
||||
ndims_name="rank(self)")
|
||||
inner_axis = array_ops.get_positive_axis(
|
||||
inner_axis,
|
||||
self.shape.rank,
|
||||
axis_name="inner_axis",
|
||||
ndims_name="rank(self)")
|
||||
if not outer_axis < inner_axis:
|
||||
raise ValueError("Expected outer_axis (%d) to be less than "
|
||||
"inner_axis (%d)" % (outer_axis, inner_axis))
|
||||
|
Loading…
x
Reference in New Issue
Block a user