diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index ecc8905d141..d381086acbb 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -1106,3 +1106,17 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "ragged_reverse_op_test", + srcs = ["ragged_reverse_op_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":ragged_array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/ops/ragged/ragged_array_ops.py b/tensorflow/python/ops/ragged/ragged_array_ops.py index d5fcabf410a..013c49b8e22 100644 --- a/tensorflow/python/ops/ragged/ragged_array_ops.py +++ b/tensorflow/python/ops/ragged/ragged_array_ops.py @@ -644,3 +644,54 @@ def stack_dynamic_partitions(data, partitions, num_partitions, name=None): with ops.control_dependencies([check]): return stack_dynamic_partitions(data.values, partitions.values, num_partitions) + + +#=============================================================================== +# Reverse +#=============================================================================== +def reverse(tensor, axis, name=None): + """Reverses a RaggedTensor along the specified axes. + + #### Example: + + >>> data = tf.ragged.constant([ + ... [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]]) + >>> tf.reverse(data, axis=[0, 2]) + + + Args: + tensor: A 'RaggedTensor' to reverse. + axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices + of the axes to reverse. + name: A name prefix for the returned tensor (optional). + + Returns: + A 'RaggedTensor'. + """ + type_error_msg = ('`axis` must be a list of int or a constant tensor' + 'when reversing axes in a ragged tensor') + + with ops.name_scope(name, 'Reverse', [tensor, axis]): + if isinstance(axis, ops.Tensor): + axis = tensor_util.constant_value(axis) + if axis is None: + raise TypeError(type_error_msg) + elif not (isinstance(axis, (list, tuple)) and + all(isinstance(dim, int) for dim in axis)): + raise TypeError(type_error_msg) + + tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor( + tensor, name='tensor') + + # Allow usage of negative values to specify innermost axes. + axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank) + for dim in axis] + + # We only need to slice up to the max axis. If the axis list + # is empty, it should be 0. + slices = [slice(None)] * (max(axis) + 1 if axis else 0) + + for dim in axis: + slices[dim] = slice(None, None, -1) + + return tensor[tuple(slices)] diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py index 871c7ee9c71..8a32e407fd5 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch.py @@ -461,6 +461,7 @@ _RAGGED_DISPATCH_OPS = [ 'indices']), (array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']), (array_ops.rank, ragged_array_ops.rank, ['input']), + (array_ops.reverse, ragged_array_ops.reverse, ['tensor']), (array_ops.size, _ragged_size_v1, ['input']), (array_ops.size_v2, ragged_array_ops.size, ['input']), (array_ops.squeeze, _ragged_squeeze_v1, ['input']), diff --git a/tensorflow/python/ops/ragged/ragged_dispatch_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py index 9b75109db6e..eafd2cde32d 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch_test.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py @@ -756,6 +756,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, ragged_factory_ops.constant_value([[1]]) ], result_is_list=True), + dict( + op=array_ops.reverse, + kwargs={ + 'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]), + 'axis': [0, -1] + }, + expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]])) ]) def testRaggedDispatch(self, op, expected, args=(), result_is_list=False, kwargs=None): @@ -802,7 +809,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, 'strings.substr', 'strings.to_hash_bucket_fast', 'strings.to_hash_bucket_strong', 'strings.to_hash_bucket', 'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv', - 'truncatemod', 'zeros_like', 'dynamic_partition' + 'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse' ] # Ops that should be listed as supported in v1 only. diff --git a/tensorflow/python/ops/ragged/ragged_reverse_op_test.py b/tensorflow/python/ops/ragged/ragged_reverse_op_test.py new file mode 100644 index 00000000000..c0bd40941ab --- /dev/null +++ b/tensorflow/python/ops/ragged/ragged_reverse_op_test.py @@ -0,0 +1,90 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ragged_array_ops.reverse.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.framework import test_util +from tensorflow.python.ops.ragged import ragged_array_ops +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.platform import googletest + + +@test_util.run_all_in_graph_and_eager_modes +class RaggedReverseOpTest(test_util.TensorFlowTestCase, + parameterized.TestCase): + + @parameterized.parameters([ + dict( + descr='Docstring example 1', + data=[[[1, 2], [3, 4]], + [[5, 6]], + [[7, 8], [9, 10], [11, 12]]], + axis=[0, 2], + expected=[[[8, 7], [10, 9], [12, 11]], + [[6, 5]], + [[2, 1], [4, 3]]]), + dict( + descr='data.shape=[5, (D2)]; axis=[0]', + data=[[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]], + axis=[0], + expected=[[1, 2, 3], [], [7, 8, 9], [3, 4, 5, 6], [1, 2]]), + dict( + descr='data.shape=[5, (D2)]; axis=[1]', + data=[[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]], + axis=[1], + expected=[[2, 1], [6, 5, 4, 3], [9, 8, 7], [], [3, 2, 1]]), + dict( + descr='data.shape=[5, (D2), (D3)]; axis=[0, -1]', + data=[[[1], [2, 3]], [[4, 5], [6, 7]], [[8]]], + axis=[0, -1], + expected=[[[8]], [[5, 4], [7, 6]], [[1], [3, 2]]]), + dict( + descr='data.shape=[2, (D2), 2]; axis=[2]', + data=[[[1, 2], [3, 4]], [[5, 6]]], + axis=[2], + expected=[[[2, 1], [4, 3]], [[6, 5]]], + ragged_rank=1), + dict( + descr='data.shape=[2, (D2), (D3)]; axis=[-1]', + data=[[[1, 2], [3, 4]], [[5, 6]]], + axis=[-1], + expected=[[[2, 1], [4, 3]], [[6, 5]]]), + dict( + descr='data.shape=[2, (D2), (D3)]; axis=[]', + data=[[[1, 2], [3, 4]], [[5, 6]]], + axis=[], + expected=[[[1, 2], [3, 4]], [[5, 6]]]) + ]) # pyformat: disable + def testReverse(self, descr, data, axis, expected, ragged_rank=None): + data = ragged_factory_ops.constant(data, ragged_rank=ragged_rank) + result = ragged_array_ops.reverse(data, axis) + expected = ragged_factory_ops.constant(expected, ragged_rank=ragged_rank) + self.assertAllClose(result, expected) + + def testErrors(self): + self.assertRaisesRegexp( + TypeError, '`axis` must be a list of int or a constant tensor *', + ragged_array_ops.reverse, + ragged_factory_ops.constant([[1], [2, 3]], ragged_rank=1), + [0, None]) + + +if __name__ == '__main__': + googletest.main()