Add ragged dispatch for array_ops.reverse
PiperOrigin-RevId: 272229377
This commit is contained in:
parent
97757f34d2
commit
70e85fa79a
@ -1106,3 +1106,17 @@ py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "ragged_reverse_op_test",
|
||||||
|
srcs = ["ragged_reverse_op_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":ragged_array_ops",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -644,3 +644,54 @@ def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
|
|||||||
with ops.control_dependencies([check]):
|
with ops.control_dependencies([check]):
|
||||||
return stack_dynamic_partitions(data.values, partitions.values,
|
return stack_dynamic_partitions(data.values, partitions.values,
|
||||||
num_partitions)
|
num_partitions)
|
||||||
|
|
||||||
|
|
||||||
|
#===============================================================================
|
||||||
|
# Reverse
|
||||||
|
#===============================================================================
|
||||||
|
def reverse(tensor, axis, name=None):
|
||||||
|
"""Reverses a RaggedTensor along the specified axes.
|
||||||
|
|
||||||
|
#### Example:
|
||||||
|
|
||||||
|
>>> data = tf.ragged.constant([
|
||||||
|
... [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
|
||||||
|
>>> tf.reverse(data, axis=[0, 2])
|
||||||
|
<tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: A 'RaggedTensor' to reverse.
|
||||||
|
axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices
|
||||||
|
of the axes to reverse.
|
||||||
|
name: A name prefix for the returned tensor (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 'RaggedTensor'.
|
||||||
|
"""
|
||||||
|
type_error_msg = ('`axis` must be a list of int or a constant tensor'
|
||||||
|
'when reversing axes in a ragged tensor')
|
||||||
|
|
||||||
|
with ops.name_scope(name, 'Reverse', [tensor, axis]):
|
||||||
|
if isinstance(axis, ops.Tensor):
|
||||||
|
axis = tensor_util.constant_value(axis)
|
||||||
|
if axis is None:
|
||||||
|
raise TypeError(type_error_msg)
|
||||||
|
elif not (isinstance(axis, (list, tuple)) and
|
||||||
|
all(isinstance(dim, int) for dim in axis)):
|
||||||
|
raise TypeError(type_error_msg)
|
||||||
|
|
||||||
|
tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||||
|
tensor, name='tensor')
|
||||||
|
|
||||||
|
# Allow usage of negative values to specify innermost axes.
|
||||||
|
axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank)
|
||||||
|
for dim in axis]
|
||||||
|
|
||||||
|
# We only need to slice up to the max axis. If the axis list
|
||||||
|
# is empty, it should be 0.
|
||||||
|
slices = [slice(None)] * (max(axis) + 1 if axis else 0)
|
||||||
|
|
||||||
|
for dim in axis:
|
||||||
|
slices[dim] = slice(None, None, -1)
|
||||||
|
|
||||||
|
return tensor[tuple(slices)]
|
||||||
|
@ -461,6 +461,7 @@ _RAGGED_DISPATCH_OPS = [
|
|||||||
'indices']),
|
'indices']),
|
||||||
(array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']),
|
(array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']),
|
||||||
(array_ops.rank, ragged_array_ops.rank, ['input']),
|
(array_ops.rank, ragged_array_ops.rank, ['input']),
|
||||||
|
(array_ops.reverse, ragged_array_ops.reverse, ['tensor']),
|
||||||
(array_ops.size, _ragged_size_v1, ['input']),
|
(array_ops.size, _ragged_size_v1, ['input']),
|
||||||
(array_ops.size_v2, ragged_array_ops.size, ['input']),
|
(array_ops.size_v2, ragged_array_ops.size, ['input']),
|
||||||
(array_ops.squeeze, _ragged_squeeze_v1, ['input']),
|
(array_ops.squeeze, _ragged_squeeze_v1, ['input']),
|
||||||
|
@ -756,6 +756,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
ragged_factory_ops.constant_value([[1]])
|
ragged_factory_ops.constant_value([[1]])
|
||||||
],
|
],
|
||||||
result_is_list=True),
|
result_is_list=True),
|
||||||
|
dict(
|
||||||
|
op=array_ops.reverse,
|
||||||
|
kwargs={
|
||||||
|
'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
|
||||||
|
'axis': [0, -1]
|
||||||
|
},
|
||||||
|
expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]]))
|
||||||
])
|
])
|
||||||
def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
|
def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
|
||||||
kwargs=None):
|
kwargs=None):
|
||||||
@ -802,7 +809,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
'strings.substr', 'strings.to_hash_bucket_fast',
|
'strings.substr', 'strings.to_hash_bucket_fast',
|
||||||
'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
|
'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
|
||||||
'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
|
'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
|
||||||
'truncatemod', 'zeros_like', 'dynamic_partition'
|
'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Ops that should be listed as supported in v1 only.
|
# Ops that should be listed as supported in v1 only.
|
||||||
|
90
tensorflow/python/ops/ragged/ragged_reverse_op_test.py
Normal file
90
tensorflow/python/ops/ragged/ragged_reverse_op_test.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for ragged_array_ops.reverse."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
class RaggedReverseOpTest(test_util.TensorFlowTestCase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.parameters([
|
||||||
|
dict(
|
||||||
|
descr='Docstring example 1',
|
||||||
|
data=[[[1, 2], [3, 4]],
|
||||||
|
[[5, 6]],
|
||||||
|
[[7, 8], [9, 10], [11, 12]]],
|
||||||
|
axis=[0, 2],
|
||||||
|
expected=[[[8, 7], [10, 9], [12, 11]],
|
||||||
|
[[6, 5]],
|
||||||
|
[[2, 1], [4, 3]]]),
|
||||||
|
dict(
|
||||||
|
descr='data.shape=[5, (D2)]; axis=[0]',
|
||||||
|
data=[[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]],
|
||||||
|
axis=[0],
|
||||||
|
expected=[[1, 2, 3], [], [7, 8, 9], [3, 4, 5, 6], [1, 2]]),
|
||||||
|
dict(
|
||||||
|
descr='data.shape=[5, (D2)]; axis=[1]',
|
||||||
|
data=[[1, 2], [3, 4, 5, 6], [7, 8, 9], [], [1, 2, 3]],
|
||||||
|
axis=[1],
|
||||||
|
expected=[[2, 1], [6, 5, 4, 3], [9, 8, 7], [], [3, 2, 1]]),
|
||||||
|
dict(
|
||||||
|
descr='data.shape=[5, (D2), (D3)]; axis=[0, -1]',
|
||||||
|
data=[[[1], [2, 3]], [[4, 5], [6, 7]], [[8]]],
|
||||||
|
axis=[0, -1],
|
||||||
|
expected=[[[8]], [[5, 4], [7, 6]], [[1], [3, 2]]]),
|
||||||
|
dict(
|
||||||
|
descr='data.shape=[2, (D2), 2]; axis=[2]',
|
||||||
|
data=[[[1, 2], [3, 4]], [[5, 6]]],
|
||||||
|
axis=[2],
|
||||||
|
expected=[[[2, 1], [4, 3]], [[6, 5]]],
|
||||||
|
ragged_rank=1),
|
||||||
|
dict(
|
||||||
|
descr='data.shape=[2, (D2), (D3)]; axis=[-1]',
|
||||||
|
data=[[[1, 2], [3, 4]], [[5, 6]]],
|
||||||
|
axis=[-1],
|
||||||
|
expected=[[[2, 1], [4, 3]], [[6, 5]]]),
|
||||||
|
dict(
|
||||||
|
descr='data.shape=[2, (D2), (D3)]; axis=[]',
|
||||||
|
data=[[[1, 2], [3, 4]], [[5, 6]]],
|
||||||
|
axis=[],
|
||||||
|
expected=[[[1, 2], [3, 4]], [[5, 6]]])
|
||||||
|
]) # pyformat: disable
|
||||||
|
def testReverse(self, descr, data, axis, expected, ragged_rank=None):
|
||||||
|
data = ragged_factory_ops.constant(data, ragged_rank=ragged_rank)
|
||||||
|
result = ragged_array_ops.reverse(data, axis)
|
||||||
|
expected = ragged_factory_ops.constant(expected, ragged_rank=ragged_rank)
|
||||||
|
self.assertAllClose(result, expected)
|
||||||
|
|
||||||
|
def testErrors(self):
|
||||||
|
self.assertRaisesRegexp(
|
||||||
|
TypeError, '`axis` must be a list of int or a constant tensor *',
|
||||||
|
ragged_array_ops.reverse,
|
||||||
|
ragged_factory_ops.constant([[1], [2, 3]], ragged_rank=1),
|
||||||
|
[0, None])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
googletest.main()
|
Loading…
Reference in New Issue
Block a user