Add ragged tensor support to tf.squeeze

PiperOrigin-RevId: 244385979
This commit is contained in:
A. Unique TensorFlower 2019-04-19 10:58:51 -07:00 committed by TensorFlower Gardener
parent 6aca010fce
commit 58f67785f6
6 changed files with 475 additions and 2 deletions

View File

@ -3100,6 +3100,7 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
@tf_export(v1=["squeeze"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"squeeze_dims")
def squeeze(input, axis=None, name=None, squeeze_dims=None):
@ -3125,12 +3126,18 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None):
tf.shape(tf.squeeze(t, [2, 4])) # [1, 2, 3, 1]
```
Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
time, where `N` is the number of elements in the squeezed dimensions.
Args:
input: A `Tensor`. The `input` to squeeze.
axis: An optional list of `ints`. Defaults to `[]`. If specified, only
squeezes the dimensions listed. The dimension index starts at 0. It is an
error to squeeze a dimension that is not 1. Must be in the range
`[-rank(input), rank(input))`.
Must be specified if `input` is a `RaggedTensor`.
name: A name for the operation (optional).
squeeze_dims: Deprecated keyword argument that is now axis.
@ -3150,6 +3157,7 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None):
@tf_export("squeeze", v1=[])
@dispatch.add_dispatch_support
def squeeze_v2(input, axis=None, name=None):  # pylint: disable=redefined-builtin
  """V2 endpoint for squeeze; forwards every argument to `squeeze`.

  Args:
    input: Forwarded unchanged as the `input` argument of `squeeze`.
    axis: Forwarded unchanged as the `axis` argument of `squeeze`.
    name: Forwarded unchanged as the `name` argument of `squeeze`.

  Returns:
    Whatever `squeeze` returns for the given arguments.
  """
  return squeeze(input, axis=axis, name=name)

View File

@ -269,6 +269,19 @@ py_library(
],
)
py_library(
    name = "ragged_squeeze_op",
    srcs = ["ragged_squeeze_op.py"],
    srcs_version = "PY2AND3",
    # Direct deps mirror the imports of ragged_squeeze_op.py: besides the
    # targets already listed, the module imports ragged_util,
    # control_flow_ops, dtypes and math_ops.
    deps = [
        ":ragged_tensor",
        ":ragged_tensor_shape",
        ":ragged_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:ops",
    ],
)
py_library(
name = "ragged_tensor",
srcs = ["ragged_tensor.py"],
@ -391,6 +404,7 @@ py_library(
":ragged_array_ops",
":ragged_batch_gather_ops",
":ragged_math_ops",
":ragged_squeeze_op",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_util",
@ -1012,3 +1026,19 @@ py_test(
"//tensorflow/python:platform_test",
],
)
py_test(
    name = "ragged_squeeze_op_test",
    srcs = ["ragged_squeeze_op_test.py"],
    srcs_version = "PY2AND3",
    # Direct deps mirror the imports of ragged_squeeze_op_test.py, which also
    # uses array_ops (dense reference results) and errors
    # (InvalidArgumentError expectations).
    deps = [
        ":ragged_conversion_ops",
        ":ragged_factory_ops",
        ":ragged_squeeze_op",
        ":ragged_test_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
        "@absl_py//absl/testing:parameterized",
    ],
)

View File

@ -36,10 +36,12 @@ from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_squeeze_op
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
@ -432,6 +434,11 @@ def _ragged_size_v1(input, name=None, out_type=dtypes.int32): # pylint: disable
return ragged_array_ops.size(input=input, out_type=out_type, name=name)
def _ragged_squeeze_v1(input, axis=None, name=None, squeeze_dims=None):  # pylint: disable=redefined-builtin
  """Ragged dispatch target for the v1 `squeeze` signature.

  Resolves the effective axis from `axis` and the deprecated `squeeze_dims`
  keyword via `deprecation.deprecated_argument_lookup`, then delegates to
  `ragged_squeeze_op.squeeze`.
  """
  resolved_axis = deprecation.deprecated_argument_lookup(
      'axis', axis, 'squeeze_dims', squeeze_dims)
  return ragged_squeeze_op.squeeze(input, axis=resolved_axis, name=name)
# (original_op, ragged_op, ragged_args)
_RAGGED_DISPATCH_OPS = [
(array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
@ -442,11 +449,13 @@ _RAGGED_DISPATCH_OPS = [
(array_ops.gather, _ragged_gather_v1, ['params', 'indices']),
(array_ops.gather_v2, ragged_gather_ops.gather, ['params', 'indices']),
(array_ops.gather_nd, _ragged_gather_nd_v1, ['params', 'indices']),
(array_ops.gather_nd_v2, ragged_gather_ops.gather_nd,
['params', 'indices']),
(array_ops.gather_nd_v2, ragged_gather_ops.gather_nd, ['params',
'indices']),
(array_ops.rank, ragged_array_ops.rank, ['input']),
(array_ops.size, _ragged_size_v1, ['input']),
(array_ops.size_v2, ragged_array_ops.size, ['input']),
(array_ops.squeeze, _ragged_squeeze_v1, ['input']),
(array_ops.squeeze_v2, ragged_squeeze_op.squeeze, ['input']),
(array_ops.stack, ragged_concat_ops.stack, ['[values]']),
(array_ops.tile, ragged_array_ops.tile, ['input']),
(array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']),

View File

@ -695,6 +695,20 @@ class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
op=array_ops.size_v2,
kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
expected=3),
dict(
op=array_ops.squeeze,
kwargs={
'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
'axis': [0]
},
expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
dict(
op=array_ops.squeeze_v2,
kwargs={
'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
'axis': [0]
},
expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
])
def testRaggedDispatch(self, op, expected, args=(), kwargs=None):
if kwargs is None: kwargs = {}

View File

@ -0,0 +1,120 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator Squeeze for RaggedTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
def squeeze(input, axis=None, name=None):  # pylint: disable=redefined-builtin
  """Ragged compatible squeeze.

  If `input` is a `tf.Tensor`, then this calls `tf.squeeze`.

  If `input` is a `tf.RaggedTensor`, then this operation takes `O(N)` time,
  where `N` is the number of elements in the squeezed dimensions.

  Args:
    input: A potentially ragged tensor. The input to squeeze.
    axis: An optional int, or list/tuple of ints. Defaults to `None`. If
      `input` is not ragged, it is passed through to `tf.squeeze`. If `input`
      is ragged, it must be specified and non-empty, and only the listed
      dimensions are squeezed. Note that it is an error to squeeze a
      dimension that is not 1. It must be in the range of
      [-rank(input), rank(input)).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor. Contains the same data as input,
    but has one or more dimensions of size 1 removed.

  Raises:
    ValueError: If `input` is ragged and `axis` is `None` or empty.
    TypeError: If `input` is ragged and `axis` is neither an int nor a list
      or tuple of ints.
  """
  with ops.name_scope(name, 'RaggedSqueeze', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
    if isinstance(input, ops.Tensor):
      return array_ops.squeeze(input, axis, name)

    if axis is None:
      raise ValueError('Ragged.squeeze must have an axis argument.')
    if isinstance(axis, int):
      axis = [axis]
    elif ((not isinstance(axis, (list, tuple))) or
          (not all(isinstance(d, int) for d in axis))):
      raise TypeError('Axis must be a list or tuple of integers.')
    if not axis:
      # An empty axis list would otherwise fall through every branch below
      # and return `input` unchanged, contradicting the documented contract
      # that ragged inputs need explicit dimensions to squeeze.
      raise ValueError('Ragged.squeeze must have an axis argument.')

    # Normalize all the dims in axis to be positive.
    axis = [ragged_util.get_positive_axis(d, input.shape.ndims) for d in axis]

    # Dimensions up to and including `ragged_rank` are described by the
    # nested row splits; the rest live in `flat_values`, where they are
    # shifted down by `ragged_rank`.
    ragged_rank = input.ragged_rank
    ragged_dims = [dim for dim in axis if dim <= ragged_rank]
    dense_dims = [dim - ragged_rank for dim in axis if dim > ragged_rank]

    # Make sure the specified ragged dimensions are squeezable.
    assertion_list = []
    scalar_tensor_one = constant_op.constant(1, dtype=dtypes.int64)
    for i, r in enumerate(input.nested_row_lengths()):
      # nested_row_lengths()[i] holds the row lengths of dimension #(i+1).
      if i + 1 in ragged_dims:
        assertion_list.append(
            control_flow_ops.Assert(
                math_ops.reduce_all(math_ops.equal(r, scalar_tensor_one)),
                ['the given axis (axis = %d) is not squeezable!' % (i + 1)]))
    if 0 in ragged_dims:
      # The outermost dimension has size 1 iff row_splits has two entries.
      scalar_tensor_two = constant_op.constant(2, dtype=dtypes.int32)
      assertion_list.append(
          control_flow_ops.Assert(
              math_ops.equal(
                  array_ops.size(input.row_splits), scalar_tensor_two),
              ['the given axis (axis = 0) is not squeezable!']))

    # Till now, we are sure that the ragged dimensions are squeezable.
    squeezed_rt = control_flow_ops.with_dependencies(assertion_list,
                                                     input.flat_values)

    if dense_dims:
      # Gives error if the dense dimension is not squeezable.
      squeezed_rt = array_ops.squeeze(squeezed_rt, dense_dims)

    # Each row_splits tensor is for dimension #(i+1); keep only those whose
    # dimension is not being squeezed.
    remaining_row_splits = [
        row_split
        for i, row_split in enumerate(input.nested_row_splits)
        if (i + 1) not in ragged_dims
    ]
    # Take care of the first row if it is to be squeezed.
    if remaining_row_splits and 0 in ragged_dims:
      remaining_row_splits.pop(0)

    squeezed_rt = RaggedTensor.from_nested_row_splits(squeezed_rt,
                                                      remaining_row_splits)

    # Corner case: when removing all the ragged dimensions and the output is
    # a scalar tensor e.g. ragged.squeeze(ragged.constant([[[1]]])).
    if set(range(0, input.ragged_rank + 1)).issubset(set(ragged_dims)):
      squeezed_rt = array_ops.squeeze(squeezed_rt, [0], name)
    return squeezed_rt

View File

@ -0,0 +1,292 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ragged.squeeze."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_squeeze_op
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
                        parameterized.TestCase):
  # Exercises ragged_squeeze_op.squeeze, mostly by comparing its result with
  # array_ops.squeeze applied to the equivalent dense constant.

  @parameterized.parameters([
      {'input_list': []},
      {'input_list': [[]], 'squeeze_ranks': [0]},
      {'input_list': [[[[], []], [[], []]]], 'squeeze_ranks': [0]},
  ])
  def test_passing_empty(self, input_list, squeeze_ranks=None):
    # Inputs without any values must agree with the dense squeeze.
    rt = ragged_squeeze_op.squeeze(
        ragged_factory_ops.constant(input_list), squeeze_ranks)
    dt = array_ops.squeeze(constant_op.constant(input_list), squeeze_ranks)
    self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt), dt)

  @parameterized.parameters([
      {'input_list': [[1]], 'squeeze_ranks': [0]},
      {'input_list': [[1]], 'squeeze_ranks': [0, 1]},
      {'input_list': [[1, 2]], 'squeeze_ranks': [0]},
      {'input_list': [[1], [2]], 'squeeze_ranks': [1]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [1]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [3]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0, 3]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0, 1]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [1, 3]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0, 1, 3]},
      {'input_list': [[[1], [2]], [[3], [4]]], 'squeeze_ranks': [2]},
      {'input_list': [[1], [2]], 'squeeze_ranks': [-1]},
  ])
  def test_passing_simple(self, input_list, squeeze_ranks=None):
    # Rectangular data built with ragged_factory_ops.constant must agree with
    # the dense squeeze of the same nested list.
    rt = ragged_squeeze_op.squeeze(
        ragged_factory_ops.constant(input_list), squeeze_ranks)
    dt = array_ops.squeeze(constant_op.constant(input_list), squeeze_ranks)
    self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt), dt)

  @parameterized.parameters([
      # ragged_conversion_ops.from_tensor does not work for this
      # {'input_list': [1]},
      {'input_list': [[1]], 'squeeze_ranks': [0]},
      {'input_list': [[1, 2]], 'squeeze_ranks': [0]},
      {'input_list': [[1], [2]], 'squeeze_ranks': [1]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [1]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [3]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0, 3]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0, 1]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [1, 3]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0, 1, 3]},
      {'input_list': [[[1], [2]], [[3], [4]]], 'squeeze_ranks': [2]},
  ])
  def test_passing_simple_from_dense(self, input_list, squeeze_ranks=None):
    # Same comparison as test_passing_simple, but the ragged input is
    # produced from the dense tensor with ragged_conversion_ops.from_tensor.
    dt = constant_op.constant(input_list)
    rt = ragged_conversion_ops.from_tensor(dt)
    rt_s = ragged_squeeze_op.squeeze(rt, squeeze_ranks)
    dt_s = array_ops.squeeze(dt, squeeze_ranks)
    self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt_s), dt_s)

  @parameterized.parameters([
      {
          'input_list': [[[[[[1]], [[1, 2]]]], [[[[]], [[]]]]]],
          'output_list': [[[1], [1, 2]], [[], []]],
          'squeeze_ranks': [0, 2, 4]
      },
      {
          'input_list': [[[[[[1]], [[1, 2]]]], [[[[]], [[]]]]]],
          'output_list': [[[[[1]], [[1, 2]]]], [[[[]], [[]]]]],
          'squeeze_ranks': [0]
      },
  ])
  def test_passing_ragged(self, input_list, output_list, squeeze_ranks=None):
    # Genuinely ragged inputs are compared against an explicit expected
    # ragged result instead of a dense reference.
    rt = ragged_factory_ops.constant(input_list)
    rt_s = ragged_squeeze_op.squeeze(rt, squeeze_ranks)
    ref = ragged_factory_ops.constant(output_list)
    self.assertRaggedEqual(rt_s, ref)

  def test_passing_text(self):
    # String-valued ragged input with several squeezed dimensions at once.
    rt = ragged_factory_ops.constant([[[[[[[['H']], [['e']], [['l']], [['l']],
                                           [['o']]],
                                          [[['W']], [['o']], [['r']], [['l']],
                                           [['d']], [['!']]]]],
                                        [[[[['T']], [['h']], [['i']], [['s']]],
                                          [[['i']], [['s']]],
                                          [[['M']], [['e']], [['h']], [['r']],
                                           [['d']], [['a']], [['d']]],
                                          [[['.']]]]]]]])
    output_list = [[['H', 'e', 'l', 'l', 'o'], ['W', 'o', 'r', 'l', 'd', '!']],
                   [['T', 'h', 'i', 's'], ['i', 's'],
                    ['M', 'e', 'h', 'r', 'd', 'a', 'd'], ['.']]]
    ref = ragged_factory_ops.constant(output_list)
    rt_s = ragged_squeeze_op.squeeze(rt, [0, 1, 3, 6, 7])
    self.assertRaggedEqual(rt_s, ref)

  @parameterized.parameters([
      {'input_list': [[]], 'squeeze_ranks': [1]},
      {'input_list': [[1, 2]], 'squeeze_ranks': [1]},
      {'input_list': [[1], [2]], 'squeeze_ranks': [0]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0, 2]},
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [2]},
      {'input_list': [[[1], [2]], [[3], [4]]], 'squeeze_ranks': [0]},
      {'input_list': [[[1], [2]], [[3], [4]]], 'squeeze_ranks': [1]},
      {'input_list': [[], []], 'squeeze_ranks': [1]},
      {'input_list': [[[], []], [[], []]], 'squeeze_ranks': [1]},
  ])
  def test_failing_InvalidArgumentError(self, input_list, squeeze_ranks):
    # Squeezing a dimension whose size is not 1 must fail when the result is
    # evaluated.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(
          ragged_squeeze_op.squeeze(
              ragged_factory_ops.constant(input_list), squeeze_ranks))

  @parameterized.parameters([
      {'input_list': [[]]},
      {'input_list': [[1]]},
      {'input_list': [[1, 2]]},
      {'input_list': [[[1], [2]], [[3], [4]]]},
      {'input_list': [[[[12], [11]]]]},
  ])
  def test_failing_no_squeeze_dim_specified(self, input_list):
    # Ragged inputs require an explicit axis argument.
    with self.assertRaises(ValueError):
      ragged_squeeze_op.squeeze(ragged_factory_ops.constant(input_list))

  @parameterized.parameters([
      {'input_list': [[[[12], [11]]]], 'squeeze_ranks': [0, 1, 3]},
  ])
  def test_failing_axis_is_not_a_list(self, input_list, squeeze_ranks):
    # Passing the axis as a tensor rather than Python ints is rejected. The
    # tensor is built outside the assertRaises block so that only the squeeze
    # call itself can satisfy the expectation.
    tensor_ranks = constant_op.constant(squeeze_ranks)
    with self.assertRaises(TypeError):
      ragged_squeeze_op.squeeze(
          ragged_factory_ops.constant(input_list), tensor_ranks)
# Run the test cases in this module when the file is executed directly.
if __name__ == '__main__':
  googletest.main()