Add ragged dispatch for array_ops.reverse

PiperOrigin-RevId: 272229377
This commit is contained in:
Irina Bejan 2019-10-01 09:37:30 -07:00 committed by TensorFlower Gardener
parent 97757f34d2
commit 70e85fa79a
5 changed files with 164 additions and 1 deletions

View File

@ -1106,3 +1106,17 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "ragged_reverse_op_test",
srcs = ["ragged_reverse_op_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -644,3 +644,54 @@ def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
with ops.control_dependencies([check]):
return stack_dynamic_partitions(data.values, partitions.values,
num_partitions)
#===============================================================================
# Reverse
#===============================================================================
def reverse(tensor, axis, name=None):
"""Reverses a RaggedTensor along the specified axes.
#### Example:
>>> data = tf.ragged.constant([
... [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
>>> tf.reverse(data, axis=[0, 2])
<tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>
Args:
tensor: A 'RaggedTensor' to reverse.
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices
of the axes to reverse.
name: A name prefix for the returned tensor (optional).
Returns:
A 'RaggedTensor'.
"""
type_error_msg = ('`axis` must be a list of int or a constant tensor'
'when reversing axes in a ragged tensor')
with ops.name_scope(name, 'Reverse', [tensor, axis]):
if isinstance(axis, ops.Tensor):
axis = tensor_util.constant_value(axis)
if axis is None:
raise TypeError(type_error_msg)
elif not (isinstance(axis, (list, tuple)) and
all(isinstance(dim, int) for dim in axis)):
raise TypeError(type_error_msg)
tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
tensor, name='tensor')
# Allow usage of negative values to specify innermost axes.
axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank)
for dim in axis]
# We only need to slice up to the max axis. If the axis list
# is empty, it should be 0.
slices = [slice(None)] * (max(axis) + 1 if axis else 0)
for dim in axis:
slices[dim] = slice(None, None, -1)
return tensor[tuple(slices)]

View File

@ -461,6 +461,7 @@ _RAGGED_DISPATCH_OPS = [
'indices']),
(array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']),
(array_ops.rank, ragged_array_ops.rank, ['input']),
(array_ops.reverse, ragged_array_ops.reverse, ['tensor']),
(array_ops.size, _ragged_size_v1, ['input']),
(array_ops.size_v2, ragged_array_ops.size, ['input']),
(array_ops.squeeze, _ragged_squeeze_v1, ['input']),

View File

@ -756,6 +756,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
ragged_factory_ops.constant_value([[1]])
],
result_is_list=True),
dict(
op=array_ops.reverse,
kwargs={
'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
'axis': [0, -1]
},
expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]]))
])
def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
kwargs=None):
@ -802,7 +809,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
'strings.substr', 'strings.to_hash_bucket_fast',
'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
'truncatemod', 'zeros_like', 'dynamic_partition'
'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse'
]
# Ops that should be listed as supported in v1 only.

View File

@ -0,0 +1,90 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ragged_array_ops.reverse."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedReverseOpTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@parameterized.parameters([
dict(
descr='Docstring example 1',
data=[[[1, 2], [3, 4]],
[[5, 6]],
[[7, 8], [9, 10], [11, 12]]],
axis=[0, 2],
expected=[[[8, 7], [10, 9], [12, 11]],
[[6, 5]],
[[2, 1], [4, 3]]]),
dict(
descr='data.shape=[5, (D2)]; axis=[0]',
data=[[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]],
axis=[0],
expected=[[1, 2, 3], [], [7, 8, 9], [3, 4, 5, 6], [1, 2]]),
dict(
descr='data.shape=[5, (D2)]; axis=[1]',
data=[[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]],
axis=[1],
expected=[[2, 1], [6, 5, 4, 3], [9, 8, 7], [], [3, 2, 1]]),
dict(
descr='data.shape=[5, (D2), (D3)]; axis=[0, -1]',
data=[[[1], [2, 3]], [[4, 5], [6, 7]], [[8]]],
axis=[0, -1],
expected=[[[8]], [[5, 4], [7, 6]], [[1], [3, 2]]]),
dict(
descr='data.shape=[2, (D2), 2]; axis=[2]',
data=[[[1, 2], [3, 4]], [[5, 6]]],
axis=[2],
expected=[[[2, 1], [4, 3]], [[6, 5]]],
ragged_rank=1),
dict(
descr='data.shape=[2, (D2), (D3)]; axis=[-1]',
data=[[[1, 2], [3, 4]], [[5, 6]]],
axis=[-1],
expected=[[[2, 1], [4, 3]], [[6, 5]]]),
dict(
descr='data.shape=[2, (D2), (D3)]; axis=[]',
data=[[[1, 2], [3, 4]], [[5, 6]]],
axis=[],
expected=[[[1, 2], [3, 4]], [[5, 6]]])
]) # pyformat: disable
def testReverse(self, descr, data, axis, expected, ragged_rank=None):
data = ragged_factory_ops.constant(data, ragged_rank=ragged_rank)
result = ragged_array_ops.reverse(data, axis)
expected = ragged_factory_ops.constant(expected, ragged_rank=ragged_rank)
self.assertAllClose(result, expected)
def testErrors(self):
self.assertRaisesRegexp(
TypeError, '`axis` must be a list of int or a constant tensor *',
ragged_array_ops.reverse,
ragged_factory_ops.constant([[1], [2, 3]], ragged_rank=1),
[0, None])
if __name__ == '__main__':
googletest.main()