Improve error messages from helper function array_ops.get_positive_axis.

PiperOrigin-RevId: 297615399
Change-Id: I5c3817e1e9dfefe0acec8abb87bce15d3cfb967b
parent 456f260a3a
commit a16ca400dc
@@ -5396,7 +5396,7 @@ def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
   return tensor
 
 
-def get_positive_axis(axis, ndims):
+def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
   """Validate an `axis` parameter, and normalize it to be positive.
 
   If `ndims` is known (i.e., not `None`), then check that `axis` is in the
@@ -5408,6 +5408,8 @@ def get_positive_axis(axis, ndims):
   Args:
     axis: An integer constant
     ndims: An integer constant, or `None`
+    axis_name: The name of `axis` (for error messages).
+    ndims_name: The name of `ndims` (for error messages).
 
   Returns:
     The normalized `axis` value.
@@ -5417,17 +5419,19 @@ def get_positive_axis(axis, ndims):
       `ndims is None`.
   """
   if not isinstance(axis, int):
-    raise TypeError("axis must be an int; got %s" % type(axis).__name__)
+    raise TypeError("%s must be an int; got %s" %
+                    (axis_name, type(axis).__name__))
   if ndims is not None:
     if 0 <= axis < ndims:
       return axis
     elif -ndims <= axis < 0:
       return axis + ndims
     else:
-      raise ValueError("axis=%s out of bounds: expected %s<=axis<%s" %
-                       (axis, -ndims, ndims))
+      raise ValueError("%s=%s out of bounds: expected %s<=%s<%s" %
+                       (axis_name, axis, -ndims, axis_name, ndims))
   elif axis < 0:
-    raise ValueError("axis may only be negative if ndims is statically known.")
+    raise ValueError("%s may only be negative if %s is statically known." %
+                     (axis_name, ndims_name))
   return axis
 
 
@@ -5485,7 +5489,7 @@ def repeat_with_axis(data, repeats, axis, name=None):
     data_shape = shape(data)
 
     # If `axis` is negative, then convert it to a positive value.
-    axis = get_positive_axis(axis, data.shape.ndims)
+    axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
 
     # Check data Tensor shapes.
     if repeats.shape.ndims == 1:
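Taken together, the hunks above give the helper its final form. The following is a standalone sketch reconstructed from the diff (not imported from TensorFlow), so the new error text can be exercised directly; the trailing call is a hypothetical usage showing how a caller-supplied `axis_name` surfaces in the message.

def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
  # Standalone reconstruction of the updated helper, for illustration only.
  if not isinstance(axis, int):
    raise TypeError("%s must be an int; got %s" %
                    (axis_name, type(axis).__name__))
  if ndims is not None:
    if 0 <= axis < ndims:
      return axis
    elif -ndims <= axis < 0:
      return axis + ndims
    else:
      raise ValueError("%s=%s out of bounds: expected %s<=%s<%s" %
                       (axis_name, axis, -ndims, axis_name, ndims))
  elif axis < 0:
    raise ValueError("%s may only be negative if %s is statically known." %
                     (axis_name, ndims_name))
  return axis

# Hypothetical usage: the caller's own argument name now appears in the error.
get_positive_axis(3, 2, axis_name="inner_axis")
# -> ValueError: inner_axis=3 out of bounds: expected -2<=inner_axis<2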
@@ -436,7 +436,7 @@ def expand_dims(input, axis, name=None):  # pylint: disable=redefined-builtin
       return array_ops.expand_dims(input, axis)
 
     ndims = None if input.shape.ndims is None else input.shape.ndims + 1
-    axis = ragged_util.get_positive_axis(axis, ndims)
+    axis = array_ops.get_positive_axis(axis, ndims, ndims_name='rank(input)')
 
     if axis == 0:
       return ragged_tensor.RaggedTensor.from_uniform_row_length(
@@ -521,7 +521,8 @@ def ragged_one_hot(indices,
     indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
         indices, name='indices')
     if axis is not None:
-      axis = ragged_util.get_positive_axis(axis, indices.shape.ndims)
+      axis = array_ops.get_positive_axis(
+          axis, indices.shape.ndims, ndims_name='rank(indices)')
       if axis < indices.ragged_rank:
         raise ValueError('axis may not be less than indices.ragged_rank.')
     return indices.with_flat_values(
@@ -672,8 +673,11 @@ def reverse(tensor, axis, name=None):
         tensor, name='tensor')
 
     # Allow usage of negative values to specify innermost axes.
-    axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank)
-            for dim in axis]
+    axis = [
+        array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i,
+                                    'rank(tensor)')
+        for i, dim in enumerate(axis)
+    ]
 
     # We only need to slice up to the max axis. If the axis list
     # is empty, it should be 0.
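When a list of axes is normalized, as in the ragged reverse() hunk above, each element now gets its own indexed name ('axis[%d]' % i), so an out-of-range entry points at the exact position that is wrong. A minimal sketch of that pattern, reusing the reconstructed get_positive_axis from the earlier example:

# Per-element naming: the second entry is out of range for a rank-3 tensor.
rank = 3
axis = [0, 5]
axis = [
    get_positive_axis(dim, rank, 'axis[%d]' % i, 'rank(tensor)')
    for i, dim in enumerate(axis)
]
# -> ValueError: axis[1]=5 out of bounds: expected -3<=axis[1]<3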
@@ -161,7 +161,7 @@ def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
       rt.shape.assert_has_rank(ndims)
 
   out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
-  axis = ragged_util.get_positive_axis(axis, out_ndims)
+  axis = array_ops.get_positive_axis(axis, out_ndims)
 
   if stack_values and ndims == 1 and axis == 0:
     return ragged_tensor.RaggedTensor.from_row_lengths(
@@ -29,7 +29,6 @@ from tensorflow.python.ops import gen_ragged_math_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_functional_ops
 from tensorflow.python.ops.ragged import ragged_tensor
-from tensorflow.python.ops.ragged import ragged_util
 from tensorflow.python.ops.ragged import segment_id_ops
 from tensorflow.python.util.tf_export import tf_export
 
@@ -501,7 +500,9 @@ def ragged_reduce_aggregate(reduce_op,
     # as the sort with negative axis will have different orders.
     # See GitHub issue 27497.
     axis = [
-        ragged_util.get_positive_axis(a, rt_input.shape.ndims) for a in axis
+        array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
+                                    'rank(input_tensor)')
+        for i, a in enumerate(axis)
     ]
     # When reducing multiple axes, just reduce one at a time. This is less
     # efficient, and only works for associative ops. (In particular, it
@@ -518,7 +519,8 @@ def ragged_reduce_aggregate(reduce_op,
   rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
       rt_input, name='rt_input')
 
-  axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)
+  axis = array_ops.get_positive_axis(
+      axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')
 
   if axis == 0:
     # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
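The ndims_name='rank(input_tensor)' passed above matters most when the rank is statically unknown: the only check that can be made is that the axis is non-negative, and the message now names the caller-facing argument rather than the internal `ndims`. A short sketch against the reconstructed helper from the first example:

# With an unknown rank (ndims is None), a negative axis is rejected using the
# caller-supplied names.
get_positive_axis(-1, None, ndims_name='rank(input_tensor)')
# -> ValueError: axis may only be negative if rank(input_tensor) is statically known.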
@@ -204,28 +204,28 @@ class RaggedMergeDimsOpTest(test_util.TensorFlowTestCase,
           'outer_axis': {},
           'inner_axis': 1,
           'exception': TypeError,
-          'message': 'axis must be an int',
+          'message': 'outer_axis must be an int',
       },
       {
          'rt': [[1]],
          'outer_axis': 1,
          'inner_axis': {},
          'exception': TypeError,
-          'message': 'axis must be an int',
+          'message': 'inner_axis must be an int',
       },
       {
          'rt': [[1]],
          'outer_axis': 1,
          'inner_axis': 3,
          'exception': ValueError,
-          'message': 'axis=3 out of bounds: expected -2<=axis<2',
+          'message': 'inner_axis=3 out of bounds: expected -2<=inner_axis<2',
       },
       {
          'rt': [[1]],
          'outer_axis': 1,
          'inner_axis': -3,
          'exception': ValueError,
-          'message': 'axis=-3 out of bounds: expected -2<=axis<2',
+          'message': 'inner_axis=-3 out of bounds: expected -2<=inner_axis<2',
       },
       {
          'rt': [[1]],
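The updated expectations above assert that the names threaded through merge_dims() ('outer_axis', 'inner_axis') show up in the raised messages. A standalone unittest sketch of the same assertions, written against the reconstructed helper from the first example rather than the TensorFlow test harness:

import unittest

class GetPositiveAxisMessageTest(unittest.TestCase):
  # Standalone sketch, not the TensorFlow harness: checks that caller-supplied
  # names appear in the raised messages, mirroring the expectations above.

  def test_bad_type_uses_axis_name(self):
    with self.assertRaisesRegex(TypeError, 'outer_axis must be an int'):
      get_positive_axis({}, 2, axis_name='outer_axis')

  def test_out_of_bounds_uses_axis_name(self):
    with self.assertRaisesRegex(ValueError, 'expected -2<=inner_axis<2'):
      get_positive_axis(3, 2, axis_name='inner_axis')

if __name__ == '__main__':
  unittest.main()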
@@ -25,7 +25,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
-from tensorflow.python.ops.ragged import ragged_util
 from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
 
 
@@ -66,7 +65,10 @@ def squeeze(input, axis=None, name=None):  # pylint: disable=redefined-builtin
   dense_dims = []
   ragged_dims = []
   # Normalize all the dims in axis to be positive
-  axis = [ragged_util.get_positive_axis(d, input.shape.ndims) for d in axis]
+  axis = [
+      array_ops.get_positive_axis(d, input.shape.ndims, 'axis[%d]' % i,
+                                  'rank(input)') for i, d in enumerate(axis)
+  ]
   for dim in axis:
     if dim > input.ragged_rank:
       dense_dims.append(dim - input.ragged_rank)
@@ -1337,7 +1337,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
       return self._cached_row_lengths
 
     with ops.name_scope(name, "RaggedRowLengths", [self]):
-      axis = ragged_util.get_positive_axis(axis, self.shape.ndims)
+      axis = array_ops.get_positive_axis(
+          axis, self.shape.rank, ndims_name="rank(self)")
       if axis == 0:
         return self.nrows()
       elif axis == 1:
@@ -1557,8 +1558,16 @@ class RaggedTensor(composite_tensor.CompositeTensor):
      `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
      is the total number of slices in the merged dimensions.
    """
-    outer_axis = ragged_util.get_positive_axis(outer_axis, self.shape.ndims)
-    inner_axis = ragged_util.get_positive_axis(inner_axis, self.shape.ndims)
+    outer_axis = array_ops.get_positive_axis(
+        outer_axis,
+        self.shape.rank,
+        axis_name="outer_axis",
+        ndims_name="rank(self)")
+    inner_axis = array_ops.get_positive_axis(
+        inner_axis,
+        self.shape.rank,
+        axis_name="inner_axis",
+        ndims_name="rank(self)")
     if not outer_axis < inner_axis:
       raise ValueError("Expected outer_axis (%d) to be less than "
                        "inner_axis (%d)" % (outer_axis, inner_axis))
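Normalization itself is unchanged by this commit; only the error strings differ. For merge_dims() on a rank-2 RaggedTensor, for instance, valid negative axes still map to the same positive values before the outer_axis < inner_axis check. A small sketch reusing the reconstructed helper above, with the names the method now passes:

# Valid negative axes normalize exactly as before.
outer_axis = get_positive_axis(-2, 2, axis_name="outer_axis",
                               ndims_name="rank(self)")
inner_axis = get_positive_axis(-1, 2, axis_name="inner_axis",
                               ndims_name="rank(self)")
assert (outer_axis, inner_axis) == (0, 1)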
|
Loading…
x
Reference in New Issue
Block a user