Add ragged dispatch for array_ops.reverse
PiperOrigin-RevId: 272229377
This commit is contained in:
parent
97757f34d2
commit
70e85fa79a
@ -1106,3 +1106,17 @@ py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ragged_reverse_op_test",
|
||||
srcs = ["ragged_reverse_op_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
@ -644,3 +644,54 @@ def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
|
||||
with ops.control_dependencies([check]):
|
||||
return stack_dynamic_partitions(data.values, partitions.values,
|
||||
num_partitions)
|
||||
|
||||
|
||||
#===============================================================================
|
||||
# Reverse
|
||||
#===============================================================================
|
||||
def reverse(tensor, axis, name=None):
|
||||
"""Reverses a RaggedTensor along the specified axes.
|
||||
|
||||
#### Example:
|
||||
|
||||
>>> data = tf.ragged.constant([
|
||||
... [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
|
||||
>>> tf.reverse(data, axis=[0, 2])
|
||||
<tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>
|
||||
|
||||
Args:
|
||||
tensor: A 'RaggedTensor' to reverse.
|
||||
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices
|
||||
of the axes to reverse.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
A 'RaggedTensor'.
|
||||
"""
|
||||
type_error_msg = ('`axis` must be a list of int or a constant tensor'
|
||||
'when reversing axes in a ragged tensor')
|
||||
|
||||
with ops.name_scope(name, 'Reverse', [tensor, axis]):
|
||||
if isinstance(axis, ops.Tensor):
|
||||
axis = tensor_util.constant_value(axis)
|
||||
if axis is None:
|
||||
raise TypeError(type_error_msg)
|
||||
elif not (isinstance(axis, (list, tuple)) and
|
||||
all(isinstance(dim, int) for dim in axis)):
|
||||
raise TypeError(type_error_msg)
|
||||
|
||||
tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
tensor, name='tensor')
|
||||
|
||||
# Allow usage of negative values to specify innermost axes.
|
||||
axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank)
|
||||
for dim in axis]
|
||||
|
||||
# We only need to slice up to the max axis. If the axis list
|
||||
# is empty, it should be 0.
|
||||
slices = [slice(None)] * (max(axis) + 1 if axis else 0)
|
||||
|
||||
for dim in axis:
|
||||
slices[dim] = slice(None, None, -1)
|
||||
|
||||
return tensor[tuple(slices)]
|
||||
|
@ -461,6 +461,7 @@ _RAGGED_DISPATCH_OPS = [
|
||||
'indices']),
|
||||
(array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']),
|
||||
(array_ops.rank, ragged_array_ops.rank, ['input']),
|
||||
(array_ops.reverse, ragged_array_ops.reverse, ['tensor']),
|
||||
(array_ops.size, _ragged_size_v1, ['input']),
|
||||
(array_ops.size_v2, ragged_array_ops.size, ['input']),
|
||||
(array_ops.squeeze, _ragged_squeeze_v1, ['input']),
|
||||
|
@ -756,6 +756,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
ragged_factory_ops.constant_value([[1]])
|
||||
],
|
||||
result_is_list=True),
|
||||
dict(
|
||||
op=array_ops.reverse,
|
||||
kwargs={
|
||||
'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
|
||||
'axis': [0, -1]
|
||||
},
|
||||
expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]]))
|
||||
])
|
||||
def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
|
||||
kwargs=None):
|
||||
@ -802,7 +809,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
'strings.substr', 'strings.to_hash_bucket_fast',
|
||||
'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
|
||||
'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
|
||||
'truncatemod', 'zeros_like', 'dynamic_partition'
|
||||
'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse'
|
||||
]
|
||||
|
||||
# Ops that should be listed as supported in v1 only.
|
||||
|
90
tensorflow/python/ops/ragged/ragged_reverse_op_test.py
Normal file
90
tensorflow/python/ops/ragged/ragged_reverse_op_test.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for ragged_array_ops.reverse."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedReverseOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
descr='Docstring example 1',
|
||||
data=[[[1, 2], [3, 4]],
|
||||
[[5, 6]],
|
||||
[[7, 8], [9, 10], [11, 12]]],
|
||||
axis=[0, 2],
|
||||
expected=[[[8, 7], [10, 9], [12, 11]],
|
||||
[[6, 5]],
|
||||
[[2, 1], [4, 3]]]),
|
||||
dict(
|
||||
descr='data.shape=[5, (D2)]; axis=[0]',
|
||||
data=[[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]],
|
||||
axis=[0],
|
||||
expected=[[1, 2, 3], [], [7, 8, 9], [3, 4, 5, 6], [1, 2]]),
|
||||
dict(
|
||||
descr='data.shape=[5, (D2)]; axis=[1]',
|
||||
data=[[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]],
|
||||
axis=[1],
|
||||
expected=[[2, 1], [6, 5, 4, 3], [9, 8, 7], [], [3, 2, 1]]),
|
||||
dict(
|
||||
descr='data.shape=[5, (D2), (D3)]; axis=[0, -1]',
|
||||
data=[[[1], [2, 3]], [[4, 5], [6, 7]], [[8]]],
|
||||
axis=[0, -1],
|
||||
expected=[[[8]], [[5, 4], [7, 6]], [[1], [3, 2]]]),
|
||||
dict(
|
||||
descr='data.shape=[2, (D2), 2]; axis=[2]',
|
||||
data=[[[1, 2], [3, 4]], [[5, 6]]],
|
||||
axis=[2],
|
||||
expected=[[[2, 1], [4, 3]], [[6, 5]]],
|
||||
ragged_rank=1),
|
||||
dict(
|
||||
descr='data.shape=[2, (D2), (D3)]; axis=[-1]',
|
||||
data=[[[1, 2], [3, 4]], [[5, 6]]],
|
||||
axis=[-1],
|
||||
expected=[[[2, 1], [4, 3]], [[6, 5]]]),
|
||||
dict(
|
||||
descr='data.shape=[2, (D2), (D3)]; axis=[]',
|
||||
data=[[[1, 2], [3, 4]], [[5, 6]]],
|
||||
axis=[],
|
||||
expected=[[[1, 2], [3, 4]], [[5, 6]]])
|
||||
]) # pyformat: disable
|
||||
def testReverse(self, descr, data, axis, expected, ragged_rank=None):
|
||||
data = ragged_factory_ops.constant(data, ragged_rank=ragged_rank)
|
||||
result = ragged_array_ops.reverse(data, axis)
|
||||
expected = ragged_factory_ops.constant(expected, ragged_rank=ragged_rank)
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def testErrors(self):
|
||||
self.assertRaisesRegexp(
|
||||
TypeError, '`axis` must be a list of int or a constant tensor *',
|
||||
ragged_array_ops.reverse,
|
||||
ragged_factory_ops.constant([[1], [2, 3]], ragged_rank=1),
|
||||
[0, None])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
Loading…
Reference in New Issue
Block a user